Skip to content

Differentiate a QP

qpax exposes a separate entry point, qpax.solve_qp_primal, that returns only the primal x and carries a custom VJP. That makes it a drop-in differentiable layer: wrap it in jax.grad (or jax.value_and_grad) and gradients flow back into every QP parameter \((Q, q, A, b, G, h)\).

The problem

The example sets up a small non-negative least-squares problem and casts it to standard QP form:

\[ \min_{x}\;\tfrac{1}{2}\,\|F x - g\|_2^{2} \quad \text{s.t.}\quad x \ge 0. \]

It then defines a tracking loss on the solution,

\[ L(x) \;=\; \|x - \mathbf{1}\|_2^{2}, \]

and computes \(\nabla_{Q,q,A,b,G,h}\, L\) with a single jax.value_and_grad call.

Code

"""Differentiate through a QP solve with qpax.

This example solves a non-negative least-squares problem

    minimize_x ||F x - g||^2
    subject to x >= 0

after converting it to standard QP form, then differentiates the loss

    L(x) = ||x - 1||^2

with respect to the QP data.
"""

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit

import qpax

N_VARS = 5
N_ROWS = 10


def loss(Q, q, A, b, G, h):
    x = qpax.solve_qp_primal(Q, q, A, b, G, h)
    x_target = jnp.ones_like(x)
    return jnp.sum((x - x_target) ** 2)


def nnls_to_qp(F, g):
    n_vars = F.shape[1]
    Q = F.T @ F
    q = -F.T @ g
    A = jnp.zeros((0, n_vars))
    b = jnp.zeros(0)
    G = -jnp.eye(n_vars)
    h = jnp.zeros(n_vars)
    return Q, q, A, b, G, h


# Create one random NNLS problem.
F = jnp.array(np.random.randn(N_ROWS, N_VARS))
g = jnp.array(np.random.randn(N_ROWS))

# Convert it to QP form.
Q, q, A, b, G, h = nnls_to_qp(F, g)

# Evaluate the loss and its gradients.
loss_and_grad = jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3, 4, 5)))
loss_value, derivs = loss_and_grad(Q, q, A, b, G, h)
dl_dQ, dl_dq, dl_dA, dl_db, dl_dG, dl_dh = derivs

print("loss:", float(loss_value))
print("dl_dQ.shape:", dl_dQ.shape)
print("dl_dq.shape:", dl_dq.shape)
print("dl_dG.shape:", dl_dG.shape)
print("dl_dh.shape:", dl_dh.shape)