squarem-JAXopt
This page contains the API documentation for squarem-JAXopt.
SquaremAcceleration
- class squarem_jaxopt.SquaremAcceleration(fixed_point_fun, maxiter=100, tol=1e-05, has_aux=False, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit=True, unroll='auto')
SQUAREM acceleration method for solving fixed-point problems.
- fixed_point_fun
a function fixed_point_fun(x, *args, **kwargs) returning a pytree with the same structure and type as x. The function should fulfill the assumptions of the Banach fixed-point theorem.
- maxiter
maximum number of iterations.
- tol
tolerance (stopping criterion)
- has_aux
whether fixed_point_fun returns additional data (default: False). If True, the fixed point is computed only with respect to the first element of the sequence returned; the other elements are carried along during the computation.
- verbose
whether to print information on every iteration or not.
- implicit_diff
whether to use implicit differentiation (True) or automatic differentiation of the unrolled iterations (False).
- implicit_diff_solve
the linear system solver to use for implicit differentiation.
- jit
whether to JIT-compile the optimization loop (default: True).
- unroll
whether to unroll the optimization loop (default: 'auto').
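To illustrate how fixed_point_fun, maxiter, tol and has_aux fit together, here is a minimal pure-Python sketch of a squared-extrapolation loop. It is not the code of squarem_jaxopt and does not call it. The names squarem_sketch and _norm, the list-of-floats representation of x, the step length alpha = -||r|| / ||v|| (one common SQUAREM choice), and the stopping rule ||F(x) - x|| < tol are all assumptions made for this example.

```python
import math


def _norm(v):
    # Euclidean norm of a list of floats (hypothetical helper).
    return math.sqrt(sum(c * c for c in v))


def squarem_sketch(fixed_point_fun, x0, maxiter=100, tol=1e-5, has_aux=False):
    """Illustrative squared-extrapolation loop; an assumption-laden sketch."""

    def step(x):
        # With has_aux=True the map returns (new_x, aux); only new_x is
        # used for the fixed point, aux is simply carried along.
        out = fixed_point_fun(x)
        return (out[0], out[1]) if has_aux else (out, None)

    x = list(x0)
    aux = None
    for iteration in range(maxiter):
        x1, aux = step(x)
        r = [b - a for a, b in zip(x, x1)]  # first difference F(x) - x
        if _norm(r) < tol:  # assumed stopping criterion
            return x1, aux, iteration
        x2, aux = step(x1)
        # Second difference (F(F(x)) - F(x)) - (F(x) - x).
        v = [(c - b) - ri for b, c, ri in zip(x1, x2, r)]
        norm_v = _norm(v)
        if norm_v == 0.0:
            # No curvature information: fall back to the plain iterate.
            x = x2
            continue
        alpha = -_norm(r) / norm_v  # assumed step-length choice
        # Squared extrapolation, followed by one stabilizing map evaluation.
        x_sq = [a - 2.0 * alpha * ri + alpha * alpha * vi
                for a, ri, vi in zip(x, r, v)]
        x, aux = step(x_sq)
    return x, aux, maxiter


def cos_map(x):
    # A contraction near its fixed point (about 0.739).
    return [math.cos(x[0])]


def cos_map_with_aux(x):
    # Same map, returning extra data as a second element.
    return [math.cos(x[0])], {"last_input": x[0]}


solution, _, n_iter = squarem_sketch(cos_map, [1.0])
solution_aux, aux, _ = squarem_sketch(cos_map_with_aux, [1.0], has_aux=True)
```

In this sketch, has_aux=True changes only how the return value of the map is unpacked, so both calls reach the same fixed point. The auxiliary dictionary holds whatever the last evaluation of the map produced.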
References