# VarProx

An implementation of the variable projection method in Python + JAX for neural network training.

## Overview

VarProx provides implementations of variable projection (VarPro) losses for training neural networks. The variable projection method solves optimally for the final linear layer's weights while the feature-extraction layers are trained, leading to improved convergence and performance.
> **Warning**
>
> `varprox` is still in early development and the losses are quite inefficient at the moment. For regression, `vploss_msereg` should be fine. For classification, however, the binary cross-entropy loss (`vploss_bce`) is still very early and uses a naive BFGS solve to determine the optimal linear-layer weights. As it stands, it should only be used for toy comparisons.
## Installation

### Using pip
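Assuming the package is published on PyPI under the name `varprox`:

```bash
pip install varprox
```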
### Using uv
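With a uv-managed project, the equivalent (again assuming the `varprox` package name):

```bash
uv add varprox
```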
### Development Installation
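For development, a standard editable install from the repository:

```bash
git clone https://github.com/abhijit-c/varprox.git
cd varprox
pip install -e .
```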
## Quick Start

### Binary Classification Example
```python
import equinox as eqx
import jax
import jax.numpy as jnp

from varprox.bce import vploss_bce

# Create a model (without final linear layer)
key = jax.random.PRNGKey(42)
model = eqx.nn.MLP(
    in_size=2,
    out_size=4,  # Feature dimension
    width_size=8,
    depth=2,
    activation=jax.nn.tanh,
    final_activation=lambda x: x,
    key=key,
)

# Generate some sample data
x = jax.random.normal(key, (2, 100))  # Shape: (input_dim, batch_size)
y = jax.random.choice(key, jnp.array([-1.0, 1.0]), (1, 100))  # Shape: (1, batch_size)

# Compute loss and optimal linear layer
alpha1, alpha2 = 1e-6, 1e-6
loss, linear_layer = vploss_bce(model, x, y, alpha1, alpha2, use_bias=True, max_iter=100)

# Use in training loop with JAX transformations
@eqx.filter_jit
def train_step(model, x, y):
    def loss_fn(model):
        loss, _ = vploss_bce(model, x, y, alpha1, alpha2, True, 100)
        return loss

    return eqx.filter_value_and_grad(loss_fn)(model)
```
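The gradients from `train_step` can then be fed to any JAX optimizer. A minimal loop using `optax` (not part of VarProx, but a common pairing with Equinox):

```python
import optax

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

for step in range(1000):
    # train_step returns (loss, grads) from eqx.filter_value_and_grad
    loss, grads = train_step(model, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
```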
### Regression Example

```python
from varprox import vploss_msereg

# Same model setup as above
loss, linear_layer = vploss_msereg(model, x, y_continuous, alpha1, alpha2)
```
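To make predictions afterwards, the returned readout can be composed with the feature extractor. A sketch assuming both are applied column-wise over the batch via `jax.vmap` (the library's internal convention may differ):

```python
# Apply model and learned readout per column, since data is (dim, batch_size)
features = jax.vmap(model, in_axes=1, out_axes=1)(x)          # (feature_dim, batch_size)
y_pred = jax.vmap(linear_layer, in_axes=1, out_axes=1)(features)  # (output_dim, batch_size)
```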
## Examples

Complete working examples are available in the `examples/` directory:

- `examples/circle_classification.py`: Binary classification of points inside/outside a unit circle
- `examples/peaks_approximation.py`: Function approximation of the MATLAB peaks function

These are also available as Python notebooks.
## Key Concepts

### Variable Projection Method

The variable projection method separates the neural network's parameters into:

1. Nonlinear parameters: the feature-extraction layers (θ)
2. Linear parameters: the final linear layer's weights (w)

For a given θ, the optimal w is computed exactly (in closed form for the regularized MSE loss, via an inner BFGS solve for BCE), leading to:

- Faster convergence
- Better local minima
- A reduced parameter space for optimization
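Concretely, for the regularized MSE loss the inner problem has a closed form. A sketch in our own notation (with $\Phi(\theta)$ the matrix of extracted features, $Y$ the targets, and $\alpha$ the ridge parameter; these symbols are not the library's):

$$
w^\star(\theta) = \arg\min_w \|w\,\Phi(\theta) - Y\|_2^2 + \alpha \|w\|_F^2
= Y\,\Phi(\theta)^\top \bigl(\Phi(\theta)\,\Phi(\theta)^\top + \alpha I\bigr)^{-1}.
$$

The outer optimization then trains $\theta$ alone on the reduced loss obtained by substituting $w^\star(\theta)$ back in.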
### Architecture Requirements

Models used with VarProx should:

- Not include a final linear layer (VarProx computes this optimally)
- Output features that will be linearly combined
Data Format
- Input data
x
: Shape(input_dim, batch_size)
- Output data
y
: Shape(output_dim, batch_size)
- Binary labels: Use {-1, +1} encoding for
vploss_bce
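Datasets often store binary labels as {0, 1}; a one-line conversion to the expected encoding (the `labels01` array here is made up for illustration):

```python
import jax.numpy as jnp

labels01 = jnp.array([[0.0, 1.0, 1.0, 0.0]])  # hypothetical {0, 1} labels, shape (1, batch_size)
y = 2.0 * labels01 - 1.0                      # now in {-1.0, +1.0}, as vploss_bce expects
```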
## Citation

If you use VarProx in your research, please cite:

```bibtex
@software{varprox2025,
  title={VarProx: Variable Projection Method for Neural Networks in JAX},
  author={Abhijit Chowdhary},
  url={https://github.com/abhijit-c/varprox},
  year={2024}
}
```
## API Reference

### bce

#### `bce_optimal_weights_bfgs(pred_model, y, alpha=DEFAULT_EPS, max_iter=256)`

Solves for the optimal weights of the binary cross-entropy loss using the BFGS method.

Minimizes:

$$\sum_i \log\bigl(1 + \exp(-y_i \cdot (W @ \text{pred\_model})_i)\bigr) + \alpha \|W\|_F^2$$
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pred_model` | | Model predictions of shape (N_out, N_batch) | *required* |
| `y` | | Target data of shape (N_target, N_batch); values should be in {-1, +1} | *required* |
| `alpha` | `float` | Regularization parameter | `DEFAULT_EPS` |
| `max_iter` | `int` | Maximum number of BFGS iterations | `256` |

Returns:

| Name | Type | Description |
|---|---|---|
| `W_optimal` | | The optimal weights matrix of shape (N_target, N_out) |

Source code in `src/varprox/bce.py`
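A brief usage sketch with random stand-in data in place of real model predictions:

```python
import jax
import jax.numpy as jnp
from varprox.bce import bce_optimal_weights_bfgs

key = jax.random.PRNGKey(0)
features = jax.random.normal(key, (4, 100))                      # (N_out, N_batch)
y = jnp.where(jax.random.normal(key, (1, 100)) > 0, 1.0, -1.0)   # labels in {-1, +1}

# Inner VarPro solve: readout weights minimizing the regularized BCE objective
W = bce_optimal_weights_bfgs(features, y, alpha=1e-6, max_iter=256)
# W has shape (N_target, N_out) == (1, 4)
```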
#### `vploss_bce(model, x, y, alpha1=DEFAULT_EPS, alpha2=DEFAULT_EPS, use_bias=True, max_iter=256)`

Binary cross-entropy with the final linear layer learned by variable projection.

Uses BFGS to fit the final readout layer and returns both the loss and an `eqx.nn.Linear` layer with the optimal weights.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | | The model function to evaluate | *required* |
| `x` | | Input data of shape (N_in, N_batch) | *required* |
| `y` | | Target data of shape (N_target, N_batch); values should be in {-1, +1} | *required* |
| `alpha1` | `float` | Regularization parameter for model weights | `DEFAULT_EPS` |
| `alpha2` | `float` | Regularization parameter for optimal weights | `DEFAULT_EPS` |
| `use_bias` | `bool` | Whether to include bias in the linear layer | `True` |
| `max_iter` | `int` | Maximum number of BFGS iterations | `256` |

Returns:

| Name | Type | Description |
|---|---|---|
| `loss` | | The computed loss value |
| `linear_layer` | | `eqx.nn.Linear` layer with optimal weights and bias |

Source code in `src/varprox/bce.py`
### msereg

#### `msereg_optimal_weights(pred_model, y, alpha=DEFAULT_EPS)`

Solves:

$$W = \arg\min_W \|W @ \text{pred\_model} - y\|_2^2 + \alpha \|W\|_F^2$$
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pred_model` | | Model predictions of shape (N_out, N_batch) | *required* |
| `y` | | Target data of shape (N_target, N_batch) | *required* |
| `alpha` | `float` | Regularization parameter | `DEFAULT_EPS` |

Returns:

| Name | Type | Description |
|---|---|---|
| `W_optimal` | | The optimal weights matrix of shape (N_target, N_out) |

Source code in `src/varprox/msereg.py`
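For reference, the closed form behind this solve fits in a few lines of JAX; a hedged reimplementation sketch, not the library's source:

```python
import jax.numpy as jnp

def ridge_readout(pred_model, y, alpha=1e-6):
    # Normal equations for ||W @ pred_model - y||_2^2 + alpha * ||W||_F^2:
    #   W (phi phi^T + alpha I) = y phi^T
    phi = pred_model                                     # (N_out, N_batch)
    gram = phi @ phi.T + alpha * jnp.eye(phi.shape[0])   # (N_out, N_out), symmetric
    return jnp.linalg.solve(gram, (y @ phi.T).T).T       # (N_target, N_out)
```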
#### `vploss_msereg(model, x, y, alpha1=DEFAULT_EPS, alpha2=DEFAULT_EPS, use_bias=True)`

Regularized MSE (msereg) with the final linear layer learned by variable projection.

Uses least squares to fit the final readout layer and returns both the loss and an `eqx.nn.Linear` layer with the optimal weights.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | | The model function to evaluate | *required* |
| `x` | | Input data of shape (N_in, N_batch) | *required* |
| `y` | | Target data of shape (N_target, N_batch) | *required* |
| `alpha1` | `float` | Regularization parameter for model weights | `DEFAULT_EPS` |
| `alpha2` | `float` | Regularization parameter for optimal weights | `DEFAULT_EPS` |
| `use_bias` | `bool` | Whether to include bias in the linear layer | `True` |

Returns:

| Name | Type | Description |
|---|---|---|
| `loss` | | The computed loss value |
| `linear_layer` | | `eqx.nn.Linear` layer with optimal weights and bias |

Source code in `src/varprox/msereg.py`
### utils

#### `find_minmax(*arrays)`

Find the global minimum and maximum across multiple arrays.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*arrays` | | Variable number of JAX arrays | `()` |

Returns:

| Name | Type | Description |
|---|---|---|
| `tuple` | | `(min_value, max_value)` |

Source code in `src/varprox/utils.py`
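An illustrative call with made-up values:

```python
import jax.numpy as jnp
from varprox.utils import find_minmax

# Global extremes across both arrays
lo, hi = find_minmax(jnp.array([1.0, 5.0]), jnp.array([-2.0, 3.0]))
# lo == -2.0, hi == 5.0
```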
### vplinear

#### `VPLinear`

Bases: `Module`

Performs a linear transformation. This is a slightly modified `eqx.nn.Linear` that accepts initial weights and biases as arguments.

Source code in `src/varprox/vplinear.py`
#### `__call__(x, *, key=None)`

Arguments:

- `x`: The input. Should be a JAX array of shape `(in_features,)`. (Or shape `()` if `in_features="scalar"`.)
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword-only argument.)

Returns:

A JAX array of shape `(out_features,)`. (Or shape `()` if `out_features="scalar"`.)

Source code in `src/varprox/vplinear.py`
#### `__init__(weight, bias=None, use_bias=True)`

Arguments:

- `weight`: The weight matrix for the linear transformation.
- `bias`: The bias vector for the linear transformation. Can be `None` if `use_bias` is `False`.
- `use_bias`: Whether to add on a bias as well.

The input and output sizes are inferred from the weight matrix shape. For a weight matrix of shape `(out_features, in_features)`, the input should be a vector of shape `(in_features,)` and the output will be a vector of shape `(out_features,)`.

Note that `in_features` also supports the string `"scalar"` as a special value. In this case the input to the layer should be of shape `()`. Likewise `out_features` can also be the string `"scalar"`, in which case the output from the layer will have shape `()`.