
Batched solving

When you have many independent QPs of the same shape, you almost never want to loop over them in Python. Wrap the solver in jax.vmap so the whole batch is solved as one vectorised computation instead of N separate calls. Then wrap the result in jax.jit so it is traced and compiled once per batch shape, and that cost is amortised over every later batch of the same shape.
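Here is a minimal sketch of the pattern without qpax. The snippet solves a batch of small unconstrained least-squares problems through the normal equations, first in a Python loop and then with `jax.jit(jax.vmap(...))`. The helper `solve_normal_eq`, the batch size and the problem shapes are illustrative assumptions and are not part of qpax.

```python
import jax
import jax.numpy as jnp
import numpy as np


def solve_normal_eq(F, g):
    # Unconstrained least-squares solution of one problem: (F^T F) x = F^T g.
    return jnp.linalg.solve(F.T @ F, F.T @ g)


rng = np.random.default_rng(0)
Fs = jnp.asarray(rng.standard_normal((256, 10, 5)))
gs = jnp.asarray(rng.standard_normal((256, 10)))

# Python loop: one small dispatch per problem.
xs_loop = jnp.stack([solve_normal_eq(F, g) for F, g in zip(Fs, gs)])

# vmap maps over the leading axis of both arguments; jit compiles the batched
# computation once and reuses it for later batches of the same shape.
batched_solve = jax.jit(jax.vmap(solve_normal_eq, in_axes=(0, 0)))
xs_vmap = batched_solve(Fs, gs)

print(xs_vmap.shape)  # (256, 5)
```

Both versions compute the same solutions. The vmapped one issues a single compiled call for the whole batch, whereas the loop dispatches once per problem. The actual speedup depends on the hardware and the batch size.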

The problem

The example draws \(N = 10{,}000\) random non-negative least-squares (NNLS) problems, each with \(F_i \in \mathbb{R}^{10 \times 5}\) and \(g_i \in \mathbb{R}^{10}\), and solves all of them in parallel. Each one is

\[ \min_{x_i}\;\tfrac{1}{2}\,\|F_i x_i - g_i\|_2^{2} \quad \text{s.t.}\quad x_i \ge 0, \qquad i = 1, \dots, N, \]

cast to standard QP form. vmap is applied twice: once to convert the batch of \((F_i, g_i)\) to batched QP data, and once to solve the batched QPs.
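Expanding the objective makes the conversion done by `nnls_to_qp` explicit. This assumes that `qpax.solve_qp(Q, q, A, b, G, h)` solves the standard form

\[ \min_{x}\;\tfrac{1}{2}\,x^\top Q x + q^\top x \quad \text{s.t.}\quad A x = b,\; G x \le h. \]

The NNLS objective expands to

\[ \tfrac{1}{2}\,\|F_i x_i - g_i\|_2^{2} = \tfrac{1}{2}\,x_i^\top \big(F_i^\top F_i\big)\, x_i + \big(-F_i^\top g_i\big)^\top x_i + \tfrac{1}{2}\,g_i^\top g_i . \]

Dropping the constant \(\tfrac{1}{2}\,g_i^\top g_i\), which does not change the minimiser, gives

\[ Q_i = F_i^\top F_i, \qquad q_i = -F_i^\top g_i, \qquad G_i = -I, \qquad h_i = 0. \]

There are no equality constraints (\(A_i\) has zero rows and \(b_i\) is empty), because \(x_i \ge 0\) is the same as \(-I\,x_i \le 0\).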

Code

```python
"""Solve a batch of non-negative least-squares problems with qpax.

Each problem is

    minimize_x 1/2 ||F x - g||^2
    subject to x >= 0
"""

import jax.numpy as jnp
import numpy as np
from jax import jit, vmap

import qpax

N_VARS = 5
N_ROWS = 10
N_QPS = 10_000


@jit
def nnls_to_qp(F, g):
    """Convert one NNLS problem to standard QP form."""
    n_vars = F.shape[1]
    # Objective 1/2 x^T Q x + q^T x (the constant 1/2 g^T g is dropped).
    Q = F.T @ F
    q = -F.T @ g
    # No equality constraints: A has zero rows and b is empty.
    A = jnp.zeros((0, n_vars))
    b = jnp.zeros(0)
    # x >= 0 written as G x <= h with G = -I and h = 0.
    G = -jnp.eye(n_vars)
    h = jnp.zeros(n_vars)
    return Q, q, A, b, G, h


def solve_one_qp(Q, q, A, b, G, h):
    return qpax.solve_qp(Q, q, A, b, G, h)


# Create a reproducible batch of random NNLS problems.
rng = np.random.default_rng(0)
Fs = jnp.array(rng.standard_normal((N_QPS, N_ROWS, N_VARS)))
gs = jnp.array(rng.standard_normal((N_QPS, N_ROWS)))

# Convert the whole batch to QP form (every output gets a leading batch axis).
batch_nnls_to_qp = vmap(nnls_to_qp, in_axes=(0, 0))
Qs, qs, As, bs, Gs, hs = batch_nnls_to_qp(Fs, gs)

# Solve all QPs in parallel. Keep the primal solutions and the convergence
# diagnostics; the other solver outputs are discarded.
batch_solve_qp = jit(vmap(solve_one_qp, in_axes=(0, 0, 0, 0, 0, 0)))
xs, _, _, _, converged, iters = batch_solve_qp(Qs, qs, As, bs, Gs, hs)

converged = np.asarray(converged)
iters = np.asarray(iters)

print("xs.shape:", xs.shape)
print(f"converged: {int(converged.sum())}/{N_QPS}")
print(f"median iterations: {np.median(iters):.1f}")
print(f"max iterations: {iters.max()}")
```
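NNLS has simple optimality conditions, so the batch can also be checked independently of the solver's own `converged` flag. A minimiser of \(\tfrac{1}{2}\|Fx - g\|_2^2\) subject to \(x \ge 0\) satisfies three conditions: \(x \ge 0\), \(\nabla = F^\top(Fx - g) \ge 0\), and \(x \odot \nabla = 0\). The sketch below uses a hypothetical helper, `nnls_kkt_residual`, which is not part of qpax. It measures the largest violation of these conditions for each problem and vmaps the check over the batch. It is demonstrated on problems with \(F = I\), whose exact solution is \(\max(g, 0)\).

```python
import jax
import jax.numpy as jnp
import numpy as np


def nnls_kkt_residual(F, g, x):
    """Largest KKT violation of x for min 1/2||F x - g||^2 s.t. x >= 0.

    Returns 0 at an exact optimum (hypothetical helper, not part of qpax).
    """
    grad = F.T @ (F @ x - g)                   # gradient of the objective at x
    primal = jnp.max(jnp.maximum(-x, 0.0))     # violation of x >= 0
    dual = jnp.max(jnp.maximum(-grad, 0.0))    # an optimum needs grad >= 0
    comp = jnp.max(jnp.abs(x * grad))          # complementarity: x_j * grad_j = 0
    return jnp.maximum(jnp.maximum(primal, dual), comp)


# Map over the leading (batch) axis of F, g and x; compile once per batch shape.
batch_kkt = jax.jit(jax.vmap(nnls_kkt_residual))

# Demo on problems with F = I, whose exact NNLS solution is max(g, 0).
rng = np.random.default_rng(0)
n_probs, n = 100, 5
gs = jnp.asarray(rng.standard_normal((n_probs, n)))
Fs = jnp.broadcast_to(jnp.eye(n), (n_probs, n, n))
x_exact = jnp.maximum(gs, 0.0)

print(batch_kkt(Fs, gs, x_exact).max())  # close to 0: all conditions hold
print(batch_kkt(Fs, gs, gs).max())       # > 0: negative entries violate x >= 0
```

For the script above, `batch_kkt(Fs, gs, xs)` returns one residual per problem. Converged problems should give small values, roughly on the scale of the solver tolerance. The threshold you accept is a judgment call.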