squarem-JAXopt

This page contains the API documentation for squarem-JAXopt.

SquaremAcceleration

class squarem_jaxopt.SquaremAcceleration(fixed_point_fun, maxiter=100, tol=1e-05, has_aux=False, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit=True, unroll='auto')

SQUAREM acceleration method for solving fixed-point problems.
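For orientation, one SQUAREM cycle can be sketched as below: two plain evaluations of the map give a first difference r and a second difference v, a quadratic extrapolation with steplength α = ‖r‖/‖v‖ follows, and one more evaluation stabilizes the result. This is a minimal sketch assuming x is a single JAX array rather than a general pytree. The helpers squarem_step and squarem_solve are hypothetical (not part of the squarem_jaxopt API), and the fixed clamp bounds on α simplify the safeguards a production implementation may use.

```python
import jax.numpy as jnp


def squarem_step(F, x0, step_min=1.0, step_max=4.0):
    """One SQUAREM cycle for a fixed-point map F (hypothetical helper)."""
    # Two plain fixed-point evaluations.
    x1 = F(x0)
    x2 = F(x1)
    r = x1 - x0          # first difference
    v = (x2 - x1) - r    # second difference
    sr2 = jnp.sum(r ** 2)
    sv2 = jnp.sum(v ** 2)
    # Steplength alpha = ||r|| / ||v||; fall back to 1 (a plain double step) if v == 0.
    alpha = jnp.where(sv2 > 0, jnp.sqrt(sr2 / sv2), 1.0)
    # Fixed clamp bounds (an assumption; simpler than adaptive step bounds).
    alpha = jnp.clip(alpha, step_min, step_max)
    # Quadratic extrapolation; alpha = 1 reproduces x2 exactly.
    x_extra = x0 + 2 * alpha * r + alpha ** 2 * v
    # Stabilization: one more evaluation of F at the extrapolated point.
    return F(x_extra)


def squarem_solve(F, x0, maxiter=100, tol=1e-5):
    """Repeat SQUAREM cycles until successive iterates are within tol (hypothetical)."""
    x = x0
    for _ in range(maxiter):
        x_new = squarem_step(F, x)
        if jnp.linalg.norm(x_new - x) < tol:
            return x_new
        x = x_new
    return x


# Usage: a linear contraction F(x) = A x + b whose fixed point is (I - A)^{-1} b = [2, 10].
A = jnp.diag(jnp.array([0.5, 0.9]))
b = jnp.array([1.0, 1.0])
x_star = squarem_solve(lambda x: A @ x + b, jnp.zeros(2))
```

With both step bounds set to 1, a cycle reduces to three plain evaluations of F; a larger α extrapolates further along the quadratic path.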

fixed_point_fun

a function fixed_point_fun(x, *args, **kwargs) returning a pytree with the same structure and type as x. The function should satisfy the assumptions of the Banach fixed-point theorem, i.e. it should be a contraction mapping.
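As an illustration of the expected signature, the hypothetical function below receives extra positional arguments A and b through *args. The map T(x) = Ax + b is a contraction in the Euclidean norm when the largest singular value of A is below 1, which meets the Banach theorem's assumptions, so plain fixed-point iteration converges to the unique solution of (I − A)x = b. The matrix and vector values are arbitrary choices for this sketch.

```python
import jax.numpy as jnp


def fixed_point_fun(x, A, b):
    # Signature fixed_point_fun(x, *args): A and b arrive as extra positional arguments.
    # The output has the same shape and dtype as x.
    return A @ x + b


A = jnp.array([[0.5, 0.1], [0.0, 0.8]])
b = jnp.array([1.0, 2.0])

# Largest singular value of A = Lipschitz constant of the map; below 1 means a contraction.
lipschitz = jnp.linalg.norm(A, ord=2)

# Plain fixed-point iteration x_{k+1} = T(x_k) from an arbitrary start.
x = jnp.zeros(2)
for _ in range(200):
    x = fixed_point_fun(x, A, b)

# The unique fixed point solves (I - A) x = b, i.e. x* = [4, 10].
x_star = jnp.linalg.solve(jnp.eye(2) - A, b)
```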

maxiter

maximum number of iterations.

tol

tolerance (stopping criterion).

has_aux

whether fixed_point_fun returns additional data (default: False). If True, the fixed point is computed only with respect to the first element of the returned sequence; the other elements are carried along during the computation.
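A minimal sketch of what a has_aux=True function might look like, assuming it returns a tuple whose first element is the next iterate. The hypothetical aux dictionary is carried along but plays no role in the fixed-point condition; the manual loop only mimics that behaviour and is not the library's internal loop.

```python
import jax.numpy as jnp

A = jnp.array([[0.5, 0.1], [0.0, 0.8]])
b = jnp.array([1.0, 2.0])


def fixed_point_fun_aux(x, A, b):
    x_next = A @ x + b
    # Auxiliary data: carried along, not used in the fixed-point condition.
    aux = {"step_norm": jnp.linalg.norm(x_next - x)}
    return x_next, aux


# Iterate on the first element only; keep the most recent auxiliary data.
x = jnp.zeros(2)
for _ in range(200):
    x, aux = fixed_point_fun_aux(x, A, b)
```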

verbose

whether to print information on every iteration or not.

implicit_diff

whether to use implicit differentiation (True) or autodiff of the unrolled iterations (False).
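To make the trade-off concrete: if x*(θ) is a fixed point of T(·, θ) and I − ∂T/∂x is invertible there, the implicit function theorem gives ∂x*/∂θ = (I − ∂T/∂x)⁻¹ ∂T/∂θ evaluated at x*, so the derivative comes from one linear solve instead of differentiating through every iteration. The sketch below compares both routes on a hypothetical map T(x, θ) = Ax + θ; it illustrates the idea and is not the library's implementation.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[0.5, 0.1], [0.0, 0.8]])


def T(x, theta):
    # Contraction in x for fixed theta; its fixed point is (I - A)^{-1} theta.
    return A @ x + theta


def unrolled_fixed_point(theta, n_iter=200):
    # Plain iterations; autodiff differentiates through every step.
    x = jnp.zeros_like(theta)
    for _ in range(n_iter):
        x = T(x, theta)
    return x


def implicit_jacobian(x_star, theta):
    # Implicit function theorem: (I - dT/dx) dx*/dtheta = dT/dtheta at the fixed point.
    dT_dx = jax.jacobian(T, argnums=0)(x_star, theta)
    dT_dtheta = jax.jacobian(T, argnums=1)(x_star, theta)
    eye = jnp.eye(x_star.shape[0])
    return jnp.linalg.solve(eye - dT_dx, dT_dtheta)


theta = jnp.array([1.0, 2.0])
x_star = unrolled_fixed_point(theta)
J_implicit = implicit_jacobian(x_star, theta)            # one linear solve
J_unrolled = jax.jacobian(unrolled_fixed_point)(theta)   # backprop through 200 steps
```

Both Jacobians approximate (I − A)⁻¹ = [[2, 1], [0, 5]]. The implicit route costs one linear solve and does not store the intermediate iterates, whereas the unrolled route keeps all 200 steps for backpropagation.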

implicit_diff_solve

the linear system solver to use for implicit differentiation.
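Reverse-mode implicit differentiation of a fixed point requires solving a linear system with the transposed operator, (I − ∂T/∂x)ᵀu = g, and implicit_diff_solve selects the solver for such systems. As a hedged sketch, assuming the solver is a callable of the form solve(matvec, b) that only needs matrix-vector products (the convention of JAXopt's linear_solve module; check it against this library), a matrix-free GMRES solver could look like this. GMRES is chosen because the operator is generally non-symmetric; gmres_solve and transposed_matvec are hypothetical names.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

A = jnp.array([[0.5, 0.1], [0.0, 0.8]])  # dT/dx for T(x) = A x + b


def gmres_solve(matvec, b):
    # Matrix-free: only matvec(u) is evaluated, never the full matrix.
    u, _ = gmres(matvec, b)
    return u


def transposed_matvec(u):
    # Applies (I - dT/dx)^T to u without forming the matrix.
    return u - A.T @ u


g = jnp.array([1.0, 1.0])  # an incoming cotangent
u = gmres_solve(transposed_matvec, g)
```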

jit

whether to JIT-compile the optimization loop (default: True).
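What JIT compilation of the loop involves can be illustrated with a generic pattern: a fixed-point loop written with jax.lax.while_loop that stops on maxiter or tol and is compiled once by jax.jit. This sketch shows the general technique with a hypothetical function name; it is not the library's internal loop, and the stopping rule (Euclidean norm of the last step) is an assumption.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[0.5, 0.1], [0.0, 0.8]])
b = jnp.array([1.0, 2.0])


def F(x):
    return A @ x + b


@jax.jit
def fixed_point_loop(x0, maxiter=100, tol=1e-5):
    # Loop state: (iteration count, current iterate, size of the last step).
    def cond(state):
        i, _, step = state
        return (i < maxiter) & (step > tol)

    def body(state):
        i, x, _ = state
        x_next = F(x)
        return i + 1, x_next, jnp.linalg.norm(x_next - x)

    # Explicit dtypes keep the loop carry types identical across iterations.
    init = (jnp.zeros((), dtype=jnp.int32), x0, jnp.full((), jnp.inf, dtype=x0.dtype))
    i, x, _ = jax.lax.while_loop(cond, body, init)
    return x, i


x_fp, n_iter = fixed_point_loop(jnp.zeros(2))
```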

unroll

whether to unroll the optimization loop (default: “auto”).
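Unrolling matters for differentiation. As a general JAX fact, jax.lax.while_loop supports forward-mode but not reverse-mode autodiff, whereas jax.lax.fori_loop with static bounds is lowered to a scan and can be traversed by jax.grad. The sketch below differentiates through such a fixed-count loop; how squarem_jaxopt resolves “auto” is not specified on this page, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[0.5, 0.1], [0.0, 0.8]])


def unrolled_solution(theta, n_iter=200):
    # Fixed trip count with static bounds: fori_loop is lowered to a scan,
    # so reverse-mode autodiff (jax.grad) can traverse every iteration.
    def body(_, x):
        return A @ x + theta

    return jax.lax.fori_loop(0, n_iter, body, jnp.zeros_like(theta))


def loss(theta):
    return jnp.sum(unrolled_solution(theta))


grad = jax.grad(loss)(jnp.array([1.0, 2.0]))
```

Here the gradient approximates (I − A)⁻ᵀ applied to a vector of ones, i.e. the column sums of (I − A)⁻¹.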

References

https://doi.org/10.18637/jss.v092.i07