squarem-JAXopt

A JAX implementation of SQUAREM, an accelerator for solving fixed-point equations that was introduced by Varadhan and Roland (2008) and is described in detail by Du and Varadhan (2020).

The SQUAREM accelerator is implemented on top of JAXopt, so the solution of the fixed-point equations can be differentiated efficiently by automatic implicit differentiation, i.e. via the implicit function theorem (see Blondel et al., 2022, for details).

Installation

pip install squarem-jaxopt

Quick Start

import jax
import jax.numpy as jnp
from jax import random

from jaxopt import FixedPointIteration, AndersonAcceleration
from squarem_jaxopt import SquaremAcceleration

# Switch JAX from its default 32-bit floats to 64-bit floats.
jax.config.update("jax_enable_x64", True)

# Number of scalar equations solved simultaneously.
N = 100_000

# One coefficient per equation, sampled uniformly with a fixed seed so the
# example is reproducible; shaped as a column vector.
a = random.uniform(key=random.PRNGKey(111), shape=(N, 1))


def fun(x: jax.Array) -> jax.Array:
    """Evaluate the fixed-point map ``(1 - a) + a * cos(x)`` elementwise.

    Args:
        x: Current iterate; must broadcast against the module-level ``a``.

    Returns:
        The next iterate. A fixed point ``x*`` satisfies ``fun(x*) == x*``.
    """
    return (1 - a) + a * jnp.cos(x)


# Every solver starts from the same all-zeros iterate, shaped like ``a``.
initial_guess = jnp.zeros_like(a)

# Baseline: plain fixed-point iteration without acceleration.
fxp_none = FixedPointIteration(fixed_point_fun=fun, verbose=False)
result_none = fxp_none.run(initial_guess)

# Anderson acceleration as shipped with JAXopt.
fxp_anderson = AndersonAcceleration(fixed_point_fun=fun, verbose=False)
result_anderson = fxp_anderson.run(initial_guess)

# SQUAREM acceleration provided by this package.
fxp_squarem = SquaremAcceleration(fixed_point_fun=fun, verbose=False)
result_squarem = fxp_squarem.run(initial_guess)

# Report iteration count, function evaluations and final error per solver.
rule = "=" * 60
rows = (
    ("FixedPointIteration", result_none),
    ("AndersonAcceleration", result_anderson),
    ("SquaremAcceleration", result_squarem),
)

print("\n" + rule)
print("ALGORITHM COMPARISON TABLE")
print(rule)
print(f"{'Algorithm':<25} {'Iterations':<12} {'Func Evals':<12} {'Error':<12}")
print("-" * 60)
for label, result in rows:
    state = result.state
    print(f"{label:<25} {state.iter_num:<12} {state.num_fun_eval:<12} {state.error:<12.2e}")
print(rule)