Batched differentiation¶
solve_qp_primal is fully vmap-compatible, so the batched-solve pattern
extends to gradients with no extra plumbing: compose
jit(vmap(value_and_grad(...))) and a single compiled call returns a
batch of losses and a batch of gradients.
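To see the composition in isolation, here is a minimal sketch that swaps the QP solver for a hypothetical stand-in, `solve_unconstrained`, which minimizes an unconstrained quadratic with `jnp.linalg.solve`. That function is not part of qpax; it is only an assumption that keeps the snippet runnable without the solver. The transformation stack is the same one used with `solve_qp_primal`.

```python
import jax
import jax.numpy as jnp
from jax import jit, vmap


def solve_unconstrained(Q, q):
    # Hypothetical stand-in for a QP solve (not a qpax function):
    # the minimizer of 0.5 * x^T Q x + q^T x for positive-definite Q.
    return jnp.linalg.solve(Q, -q)


def loss(Q, q):
    x = solve_unconstrained(Q, q)
    return jnp.sum((x - 1.0) ** 2)


# value_and_grad differentiates one problem, vmap maps it over the leading
# batch axis of every argument, and jit compiles the batched function.
batch_loss_and_grad = jit(vmap(jax.value_and_grad(loss, argnums=(0, 1))))

n_batch, n_vars = 4, 3
key_Q, key_q = jax.random.split(jax.random.PRNGKey(0))
M = jax.random.normal(key_Q, (n_batch, n_vars, n_vars))
# M @ M^T + I per problem, so every Q is symmetric positive definite.
Qs = jnp.einsum("bij,bkj->bik", M, M) + jnp.eye(n_vars)
qs = jax.random.normal(key_q, (n_batch, n_vars))

losses, (dl_dQ, dl_dq) = batch_loss_and_grad(Qs, qs)
print(losses.shape, dl_dQ.shape, dl_dq.shape)
```

Each gradient carries the same leading batch axis as its input, so `dl_dQ` lines up with `Qs` entry by entry.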
The problem¶
The example draws \(N = 256\) random non-negative least-squares (NNLS) problems, each with 10 rows and 5 variables, converts them to standard QP form, and defines a per-problem tracking loss evaluated at the QP solution:
\[
\min_{x_i}\;\tfrac{1}{2}\,\|F_i x_i - g_i\|_2^{2}
\quad \text{s.t.}\quad x_i \ge 0,
\qquad
L_i(x_i) \;=\; \|x_i - \mathbf{1}\|_2^{2}.
\]
A single batched call then evaluates every \(L_i\) and every \(\nabla_{Q_i, q_i, A_i, b_i, G_i, h_i}\, L_i\) in parallel.
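The conversion to QP form comes from expanding the squared residual: \(\tfrac{1}{2}\|F x - g\|_2^2 = \tfrac{1}{2} x^\top (F^\top F) x - (F^\top g)^\top x + \tfrac{1}{2}\|g\|_2^2\), which gives \(Q = F^\top F\) and \(q = -F^\top g\), while \(x \ge 0\) becomes \(G x \le h\) with \(G = -I\) and \(h = 0\). The sketch below checks this numerically with NumPy for one random problem. It assumes the standard form is "minimize \(\tfrac{1}{2} x^\top Q x + q^\top x\) subject to \(A x = b\), \(G x \le h\)", which is the convention the example code is consistent with.

```python
import numpy as np

rng = np.random.default_rng(0)
n_rows, n_vars = 10, 5
F = rng.standard_normal((n_rows, n_vars))
g = rng.standard_normal(n_rows)

# QP data implied by expanding 0.5 * ||F x - g||^2.
Q = F.T @ F
q = -F.T @ g
G = -np.eye(n_vars)  # G x <= h encodes x >= 0
h = np.zeros(n_vars)

# Compare the two objectives at a random non-negative point.
x = np.abs(rng.standard_normal(n_vars))
nnls_obj = 0.5 * np.sum((F @ x - g) ** 2)
qp_obj = 0.5 * x @ Q @ x + q @ x
const = 0.5 * g @ g  # independent of x, so the QP can drop it

objectives_match = bool(np.isclose(nnls_obj, qp_obj + const))
point_is_feasible = bool(np.all(G @ x <= h))
print(objectives_match, point_is_feasible)
```

The constant term does not change the minimizer, which is why the QP omits it.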
Code¶
"""Differentiate through a batch of QP solves with qpax.

Each problem is a non-negative least-squares problem

    minimize_x ||F x - g||^2
    subject to x >= 0

after converting it to standard QP form. For each QP in the batch, this
example differentiates the loss

    L(x) = ||x - 1||^2

with respect to the QP data.
"""

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, vmap

import qpax

N_VARS = 5
N_ROWS = 10
N_QPS = 256


def loss(Q, q, A, b, G, h):
    x = qpax.solve_qp_primal(Q, q, A, b, G, h)
    x_target = jnp.ones_like(x)
    return jnp.sum((x - x_target) ** 2)


def nnls_to_qp(F, g):
    n_vars = F.shape[1]
    Q = F.T @ F
    q = -F.T @ g
    A = jnp.zeros((0, n_vars))
    b = jnp.zeros(0)
    G = -jnp.eye(n_vars)
    h = jnp.zeros(n_vars)
    return Q, q, A, b, G, h


# Create a batch of random NNLS problems.
Fs = jnp.array(np.random.randn(N_QPS, N_ROWS, N_VARS))
gs = jnp.array(np.random.randn(N_QPS, N_ROWS))

# Convert them to QP form.
batch_nnls_to_qp = vmap(nnls_to_qp, in_axes=(0, 0))
Qs, qs, As, bs, Gs, hs = batch_nnls_to_qp(Fs, gs)

# Evaluate the loss and its gradients for all QPs in parallel.
loss_and_grad = jax.value_and_grad(loss, argnums=(0, 1, 2, 3, 4, 5))
batch_loss_and_grad = jit(vmap(loss_and_grad, in_axes=(0, 0, 0, 0, 0, 0)))
losses, derivs = batch_loss_and_grad(Qs, qs, As, bs, Gs, hs)
dl_dQ, dl_dq, dl_dA, dl_db, dl_dG, dl_dh = derivs

print("losses.shape:", losses.shape)
print("mean loss:", float(jnp.mean(losses)))
print("dl_dQ.shape:", dl_dQ.shape)
print("dl_dq.shape:", dl_dq.shape)
print("dl_dG.shape:", dl_dG.shape)
print("dl_dh.shape:", dl_dh.shape)
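The script prints shapes for four of the six gradients. Because a JAX gradient has the same shape as the input it is taken with respect to, the shapes of the batched QP data tell you what to expect from all six, including the empty equality-constraint blocks. The sketch below repeats `nnls_to_qp` from the example and inspects those shapes without calling the solver; the smaller batch size is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp
from jax import vmap

N_VARS, N_ROWS, N_QPS = 5, 10, 8  # smaller batch than the example, for brevity


def nnls_to_qp(F, g):
    # Same conversion as in the example above.
    n_vars = F.shape[1]
    Q = F.T @ F
    q = -F.T @ g
    A = jnp.zeros((0, n_vars))
    b = jnp.zeros(0)
    G = -jnp.eye(n_vars)
    h = jnp.zeros(n_vars)
    return Q, q, A, b, G, h


key_F, key_g = jax.random.split(jax.random.PRNGKey(0))
Fs = jax.random.normal(key_F, (N_QPS, N_ROWS, N_VARS))
gs = jax.random.normal(key_g, (N_QPS, N_ROWS))

# vmap broadcasts outputs that do not depend on the batched inputs
# (A, b, G, h here), so all six arrays get a leading batch axis.
batched = vmap(nnls_to_qp)(Fs, gs)
shapes = {name: arr.shape for name, arr in zip("QqAbGh", batched)}
print(shapes)
```

With no equality constraints, `A` and `b` keep a zero-length constraint axis, so the corresponding gradients `dl_dA` and `dl_db` are empty arrays with a batch dimension rather than missing values.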