Source code for squarem_jaxopt.squarem_jaxopt

"""Implementation of the SQUAREM accelerator method in JAXopt.

References:
Du, Y., & Varadhan, R. (2020). SQUAREM: An R Package for Off-the-Shelf Acceleration of EM, MM and Other EM-Like Monotone Algorithms. Journal of Statistical Software, 92(7), 1–41. https://doi.org/10.18637/jss.v092.i07.
Blondel, M., Berthet, Q., Cuturi, M., Frostig, R., Hoyer, S., Llinares-López, F., Pedregosa, F., & Vert, J.-P. (2021). Efficient and Modular Implicit Differentiation. arXiv. https://arxiv.org/abs/2105.15183.
"""

from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import Optional
from typing import Union

from dataclasses import dataclass

import jax.numpy as jnp

from jaxopt._src import base
from jaxopt._src.tree_util import tree_l2_norm, tree_sub


class SquaremState(NamedTuple):
    """Named tuple containing state information.
    Attributes:
      iter_num: iteration number
      error: residuals of current estimate
      aux: auxiliary output of fixed_point_fun when has_aux=True
      num_fun_eval: number of function evaluations
    """

    iter_num: jnp.ndarray | int
    error: jnp.ndarray | float
    aux: Optional[Any] = None
    num_fun_eval: jnp.ndarray | int = jnp.asarray(0)
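
# A minimal sketch of how the state is consumed (illustrative, assuming some
# contraction ``f``; see the usage example at the end of this module):
#
#   params, state = SquaremAcceleration(f).run(x0)
#   state.iter_num      # number of SQUAREM iterations performed
#   state.error         # l2 norm of the last parameter change
#   state.num_fun_eval  # 3 evaluations of fixed_point_fun per iteration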


@dataclass(eq=False)
class SquaremAcceleration(base.IterativeSolver):
    """SQUAREM accelerator method for solving fixed points.

    Attributes:
      fixed_point_fun: a function ``fixed_point_fun(x, *args, **kwargs)``
        returning a pytree with the same structure and type as x.
        The function should fulfill the Banach fixed-point theorem's
        assumptions.
      maxiter: maximum number of iterations.
      tol: tolerance (stopping criterion).
      has_aux: whether fixed_point_fun returns additional data (default:
        False). If True, the fixed point is computed only with respect to
        the first element of the returned tuple; the other elements are
        carried along during computation.
      verbose: whether to print information on every iteration or not.
      implicit_diff: whether to enable implicit differentiation or autodiff
        of unrolled iterations.
      implicit_diff_solve: the linear system solver to use.
      jit: whether to JIT-compile the optimization loop (default: True).
      unroll: whether to unroll the optimization loop (default: "auto").

    References:
      https://doi.org/10.18637/jss.v092.i07
    """

    fixed_point_fun: Callable
    maxiter: int = 100
    tol: float = 1e-5
    has_aux: bool = False
    verbose: Union[bool, int] = False
    implicit_diff: bool = True
    implicit_diff_solve: Optional[Callable] = None
    jit: bool = True
    unroll: base.AutoOrBoolean = "auto"

    def init_state(self, init_params, *args, **kwargs) -> SquaremState:
        """Initialize the solver state.

        Args:
          init_params: initial guess of the fixed point, pytree.
          *args: additional positional arguments to be passed to
            ``fixed_point_fun``.
          **kwargs: additional keyword arguments to be passed to
            ``fixed_point_fun``.
        Returns:
          state
        """
        return SquaremState(
            iter_num=jnp.asarray(0),
            error=jnp.asarray(jnp.inf),
            aux=None,
            num_fun_eval=jnp.asarray(0, base.NUM_EVAL_DTYPE),
        )

    def squarem_step(self, params: Any, *args, **kwargs) -> tuple[Any, Any]:
        """Update the fixed-point estimate by one SQUAREM cycle.

        Args:
          params: pytree containing the parameters.
          *args: additional positional arguments to be passed to
            ``fixed_point_fun``.
          **kwargs: additional keyword arguments to be passed to
            ``fixed_point_fun``.
        Returns:
          (next_params, aux)
        """
        params1 = self._fun(params, *args, **kwargs)[0]
        params2 = self._fun(params1, *args, **kwargs)[0]

        # Accelerated step.
        r = params1 - params  # change
        v = params2 - params1 - r  # curvature
        alpha = -jnp.sqrt(jnp.sum(r**2) / jnp.sum(v**2))
        # Fall back to the plain fixed-point update params2 when the step
        # length is undefined (0/0 at convergence).
        params3 = jnp.where(
            jnp.isnan(alpha), params2, params - 2 * alpha * r + (alpha**2) * v
        )
        # Stabilization step: apply the map once more to the extrapolation.
        return self._fun(params3, *args, **kwargs)

    def update(self, params: Any, state: SquaremState, *args, **kwargs) -> base.OptStep:
        """Performs one iteration of the SQUAREM accelerator method.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          *args: additional positional arguments to be passed to
            ``fixed_point_fun``.
          **kwargs: additional keyword arguments to be passed to
            ``fixed_point_fun``.
        Returns:
          (params, state)
        """
        next_params, aux = self.squarem_step(params, *args, **kwargs)
        error = tree_l2_norm(tree_sub(next_params, params))
        next_state = SquaremState(
            iter_num=state.iter_num + 1,
            error=error,
            aux=aux,
            num_fun_eval=state.num_fun_eval + 3,
        )
        if self.verbose:
            self.log_info(next_state, error_name="Distance btw Iterates")
        return base.OptStep(params=next_params, state=next_state)

    def optimality_fun(self, params, *args, **kwargs):
        """Optimality function mapping compatible with ``@custom_root``."""
        new_params, _ = self._fun(params, *args, **kwargs)
        return tree_sub(new_params, params)

    def __post_init__(self):
        super().__post_init__()
        if self.has_aux:
            self._fun = self.fixed_point_fun
        else:
            self._fun = lambda *a, **kw: (self.fixed_point_fun(*a, **kw), None)
        self.reference_signature = self.fixed_point_fun
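
# A minimal usage sketch. Assumptions: this module is importable under the
# name ``squarem_jaxopt``, and the affine map below is only an illustrative
# contraction; any map satisfying the Banach assumptions works in its place.
if __name__ == "__main__":
    import jax

    def fixed_point_fun(x, A, b):
        # Affine contraction x -> A @ x + b with spectral radius of A below
        # one; its unique fixed point is (I - A)^{-1} b.
        return A @ x + b

    A = jnp.array([[0.5, 0.1], [0.0, 0.4]])
    b = jnp.array([1.0, 1.0])

    solver = SquaremAcceleration(fixed_point_fun, maxiter=100, tol=1e-8)
    params, state = solver.run(jnp.zeros(2), A, b)
    print(params, state.iter_num, state.error)

    # With implicit_diff=True (the default), the solution map is
    # differentiable with respect to the auxiliary arguments:
    jac = jax.jacobian(lambda b_: solver.run(jnp.zeros(2), A, b_).params)(b)
    print(jac)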