Home
Differentiable, batched, single-precision quadratic programming in JAX
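qpax solves and differentiates (batched) convex quadratic programs of the form
\[
\begin{aligned}
\underset{x}{\text{minimize}} \quad & \tfrac{1}{2} x^{\top} Q x + q^{\top} x \\
\text{subject to} \quad & A x = b, \\
& G x \leq h,
\end{aligned}
\]
with decision variable \(x \in \mathbb{R}^{n}\) and problem data \(Q \succeq 0\), \(q \in \mathbb{R}^{n}\), \(A \in \mathbb{R}^{m \times n}\), \(b \in \mathbb{R}^{m}\), \(G \in \mathbb{R}^{p \times n}\), and \(h \in \mathbb{R}^{p}\).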
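To make the notation concrete, here is a minimal sketch in JAX that builds a small instance and evaluates it at a candidate point. It assumes the standard form: minimize \(\tfrac{1}{2} x^{\top} Q x + q^{\top} x\) subject to \(A x = b\) and \(G x \leq h\). The helper names `qp_objective` and `qp_residuals` are illustrative only and are not part of qpax.

```python
import jax.numpy as jnp

def qp_objective(x, Q, q):
    """Quadratic objective 0.5 * x^T Q x + q^T x (hypothetical helper)."""
    return 0.5 * x @ Q @ x + q @ x

def qp_residuals(x, A, b, G, h):
    """Equality residual A x - b and inequality violation max(G x - h, 0)."""
    eq_residual = A @ x - b
    ineq_violation = jnp.maximum(G @ x - h, 0.0)
    return eq_residual, ineq_violation

# Small instance with n = 2 variables, m = 1 equality, p = 2 inequalities.
Q = jnp.array([[2.0, 0.0], [0.0, 2.0]])    # positive semidefinite
q = jnp.array([-2.0, -5.0])
A = jnp.array([[1.0, 1.0]])                # x1 + x2 = 1
b = jnp.array([1.0])
G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])  # -x <= 0, i.e. x >= 0
h = jnp.array([0.0, 0.0])

# A feasible candidate point; the first inequality is active here.
x = jnp.array([0.0, 1.0])
obj = qp_objective(x, Q, q)
eq_res, ineq_viol = qp_residuals(x, A, b, G, h)
print(obj, eq_res, ineq_viol)
```

A point is feasible when both residual vectors are zero (up to solver tolerance), which is a convenient sanity check on any solution returned by a QP solver.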
Features
- Differentiable: Backpropagate through QPs and obtain smooth, informative subgradients even at active inequality constraints.
- Single precision: Runs in f32, enabling larger batch sizes and higher throughput on GPU.
- Batchable: Solves and differentiates many QPs in parallel with shared structure via jax.vmap.
- Infeasibility avoidance: Avoids generating infeasible problems by solving an always-feasible elastic QP that returns informative gradients toward feasibility.
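The batching and differentiation pattern described above can be sketched in plain JAX. The example below is not qpax's solver: it assumes a strongly convex QP with equality constraints only (no inequalities), so that the solution comes from a single linear KKT solve. The function `solve_eq_qp` is a hypothetical stand-in used only to show how `jax.vmap` and `jax.grad` compose around a QP solve.

```python
import jax
import jax.numpy as jnp

def solve_eq_qp(Q, q, A, b):
    """Solve min 0.5 x^T Q x + q^T x  s.t.  A x = b via its KKT system.

    Illustrative stand-in only: ignores inequality constraints.
    """
    n, m = Q.shape[0], A.shape[0]
    kkt = jnp.block([[Q, A.T], [A, jnp.zeros((m, m))]])
    rhs = jnp.concatenate([-q, b])
    sol = jnp.linalg.solve(kkt, rhs)
    return sol[:n]  # primal part; sol[n:] holds the equality multipliers

# Shared structure (Q, A, b) with a batch of three linear cost vectors.
# JAX arrays default to float32, so everything here runs in single precision.
Q = 2.0 * jnp.eye(2)
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
qs = jnp.array([[-2.0, -5.0], [0.0, 0.0], [1.0, -1.0]])

# Batch over q only; Q, A, and b are broadcast to every problem.
batched_solve = jax.vmap(solve_eq_qp, in_axes=(None, 0, None, None))
xs = batched_solve(Q, qs, A, b)  # shape (3, 2)

def loss(q):
    """Scalar loss on the QP solution, differentiated with respect to q."""
    x = solve_eq_qp(Q, q, A, b)
    return jnp.sum(x ** 2)

# Gradient of the loss through the solve, for every problem in the batch.
grads = jax.vmap(jax.grad(loss))(qs)  # shape (3, 2)
print(xs)
print(grads)
```

With inequality constraints the solution map is no longer a single linear solve and can be non-smooth where constraints switch between active and inactive, which is the case the Differentiable feature above addresses.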