
VarProx

An implementation of the variable projection method in Python + JAX for neural network training.

Overview

VarProx provides implementations of variable projection (VarPro) losses for training neural networks. At each loss evaluation, the variable projection method solves for the optimal final linear layer weights (in closed form for regression, via an inner solve for classification) while the feature-extraction layers are trained as usual, which typically improves convergence and the quality of the fit.

Warning

varprox is still in early development and the losses are quite inefficient at the moment. For regression, vploss_msereg should be fine. For classification, however, the binary cross-entropy loss (vploss_bce) is still very early and uses a naive BFGS solve to determine the optimal linear layer weights. As it stands, it should only be used for toy comparisons.

Installation

Using pip

pip install git+https://github.com/abhijit-c/varprox.git

Using uv

uv pip install git+https://github.com/abhijit-c/varprox.git

Development Installation

git clone https://github.com/abhijit-c/varprox.git
cd varprox
pip install -e .

Quick Start

Binary Classification Example

import equinox as eqx
import jax
import jax.numpy as jnp
from varprox.bce import vploss_bce

# Create a model (without final linear layer)
key = jax.random.PRNGKey(42)
model = eqx.nn.MLP(
    in_size=2,
    out_size=4,  # Feature dimension
    width_size=8,
    depth=2,
    activation=jax.nn.tanh,
    final_activation=lambda x: x,
    key=key,
)

# Generate some sample data
key, x_key, y_key = jax.random.split(key, 3)  # fresh keys, separate from the model-init key
x = jax.random.normal(x_key, (2, 100))  # Shape: (input_dim, batch_size)
y = jax.random.choice(y_key, jnp.array([-1.0, 1.0]), (1, 100))  # Shape: (1, batch_size)

# Compute loss and optimal linear layer
alpha1, alpha2 = 1e-6, 1e-6
loss, linear_layer = vploss_bce(model, x, y, alpha1, alpha2, use_bias=True, max_iter=100)

# Use in training loop with JAX transformations
@eqx.filter_jit
def train_step(model, x, y):
    def loss_fn(model):
        loss, _ = vploss_bce(model, x, y, alpha1, alpha2, use_bias=True, max_iter=100)
        return loss

    return eqx.filter_value_and_grad(loss_fn)(model)
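
Once the model is trained, the returned linear_layer can be combined with the feature model for inference. A minimal sketch, reusing model, x, and the linear_layer produced by vploss_bce above:

# Inference sketch: feature model followed by the fitted readout
features = jax.vmap(model, in_axes=1, out_axes=1)(x)  # (feature_dim, batch_size)
logits = linear_layer.weight @ features               # (1, batch_size)
if linear_layer.bias is not None:
    logits = logits + linear_layer.bias[:, None]
predictions = jnp.sign(logits)                        # predicted labels in {-1, +1}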

Regression Example

from varprox import vploss_msereg

# Same model setup as above
loss, linear_layer = vploss_msereg(model, x, y_continuous, alpha1, alpha2)
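
Here y_continuous stands in for continuous regression targets of shape (output_dim, batch_size). A hypothetical definition that matches the inputs from the classification example:

# Hypothetical continuous targets for the (2, 100) inputs above; shape (1, 100)
y_continuous = jnp.sin(3.0 * x[0, :] * x[1, :]).reshape(1, -1)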

Examples

Complete working examples are available in the examples/ directory:

  • examples/circle_classification.py: Binary classification of points inside/outside a unit circle
  • examples/peaks_approximation.py: Function approximation of the MATLAB peaks function

These are also available as Python notebooks.

Key Concepts

Variable Projection Method

The variable projection method separates neural network parameters into:

1. Nonlinear parameters: the feature-extraction layers (θ)
2. Linear parameters: the final linear layer weights (w)

For a given θ, the optimal w is computed analytically (see the sketch below), leading to:

  • Faster convergence
  • Better local minima
  • A reduced parameter space for optimization
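
Concretely, for the regularized least-squares loss used by vploss_msereg, writing \(\Phi_\theta\) for the matrix of model features on a batch, the inner problem has the closed-form solution

$$w^*(\theta) = \arg\min_w \|w \Phi_\theta - y\|_2^2 + \alpha_2 \|w\|_F^2 = y \Phi_\theta^\top \left(\Phi_\theta \Phi_\theta^\top + \alpha_2 I\right)^{-1},$$

and the outer optimization trains only θ on the reduced objective obtained by substituting \(w^*(\theta)\) back into the loss. For vploss_bce, no closed form exists, so the inner problem is solved approximately with BFGS.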

Architecture Requirements

Models used with VarProx should:

  • Not include a final linear layer (VarProx computes this optimally)
  • Output features that will be linearly combined by that layer

Data Format

  • Input data x: Shape (input_dim, batch_size)
  • Output data y: Shape (output_dim, batch_size)
  • Binary labels: Use {-1, +1} encoding for vploss_bce
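
If your data arrives in the more common (batch_size, dim) layout with {0, 1} labels, a minimal conversion sketch (the names X_raw and labels_01 are illustrative):

# Illustrative conversion to the VarProx conventions above.
# X_raw: (batch_size, input_dim), labels_01: (batch_size,) with values in {0, 1}
x = X_raw.T                                 # -> (input_dim, batch_size)
y = (2.0 * labels_01 - 1.0).reshape(1, -1)  # {0, 1} -> {-1, +1}, shape (1, batch_size)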

Citation

If you use VarProx in your research, please cite:

@software{varprox2025,
  title={VarProx: Variable Projection Method for Neural Networks in JAX},
  author={Abhijit Chowdhary},
  url={https://github.com/abhijit-c/varprox},
  year={2024}
}

API Reference

bce

bce_optimal_weights_bfgs(pred_model, y, alpha=DEFAULT_EPS, max_iter=256)

Solves the optimal weights for binary cross entropy loss using BFGS method.

Minimizes: \(\sum_i \log(1 + \exp(-y_i \cdot (W @ \text{pred\_model})_i)) + \alpha \|W\|_F^2\)

Parameters:

  • pred_model: Model predictions of shape (N_out, N_batch). Required.
  • y: Target data of shape (N_target, N_batch); values should be in {-1, +1}. Required.
  • alpha (float): Regularization parameter. Default: DEFAULT_EPS.
  • max_iter (int): Maximum number of BFGS iterations. Default: 256.

Returns:

  • W_optimal: The optimal weights matrix of shape (N_target, N_out).

Source code in src/varprox/bce.py
@eqx.filter_jit
def bce_optimal_weights_bfgs(
    pred_model, y, alpha: float = DEFAULT_EPS, max_iter: int = 256
):
    r"""
    Solves the optimal weights for binary cross entropy loss using BFGS method.

    Minimizes: $$\sum_i \log(1 + \exp(-y_i \cdot (W @ \text{pred\_model})_i)) + \alpha \|W\|_F^2$$

    Args:
        pred_model: Model predictions of shape (N_out, N_batch)
        y: Target data of shape (N_target, N_batch), values should be in {-1, +1}
        alpha: Regularization parameter
        max_iter: Maximum number of BFGS iterations

    Returns:
        W_optimal: The optimal weights matrix of shape (N_target, N_out)
    """
    N_target, N_batch = y.shape
    N_out, _ = pred_model.shape

    def objective(W_flat):
        W = W_flat.reshape(N_target, N_out)
        logits = W @ pred_model  # (N_target, N_batch)

        # Binary cross entropy: log(1 + exp(-y * logits))
        # Use log_sigmoid for numerical stability: log_sigmoid(-y * logits) = log(1 / (1 + exp(y * logits)))
        # But we want log(1 + exp(-y * logits)), so we use: -log_sigmoid(y * logits)
        bce_loss = jnp.sum(-jax.nn.log_sigmoid(y * logits))

        # L2 regularization
        reg_loss = alpha * jnp.sum(W**2)

        return bce_loss + reg_loss

    # Initialize weights (using a simple initialization)
    W_init_flat = jnp.eye(N_target, N_out).flatten()
    # Use BFGS optimization
    result = minimize(
        fun=objective,
        x0=W_init_flat,
        method="BFGS",
        options={"maxiter": max_iter},
    )

    W_optimal = result.x.reshape(N_target, N_out)
    return W_optimal
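
For reference, a small standalone sketch of fitting a readout on fixed features with this routine (the feature and label arrays here are fabricated for illustration):

import jax
import jax.numpy as jnp
from varprox.bce import bce_optimal_weights_bfgs

# Fabricated features (N_out, N_batch) and labels (N_target, N_batch) in {-1, +1}
key = jax.random.PRNGKey(0)
features = jax.random.normal(key, (4, 100))
labels = jnp.where(features[:1, :] > 0.0, 1.0, -1.0)

W = bce_optimal_weights_bfgs(features, labels, alpha=1e-6, max_iter=128)
print(W.shape)  # (1, 4)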

vploss_bce(model, x, y, alpha1=DEFAULT_EPS, alpha2=DEFAULT_EPS, use_bias=True, max_iter=256)

Binary Cross Entropy with the final linear layer learned by Variable Projection.

Uses BFGS to fit the final readout layer and returns both the loss and an eqx.nn.Linear layer with the optimal weights.

Parameters:

  • model: The model function to evaluate. Required.
  • x: Input data of shape (N_in, N_batch). Required.
  • y: Target data of shape (N_target, N_batch); values should be in {-1, +1}. Required.
  • alpha1 (float): Regularization parameter for model weights. Default: DEFAULT_EPS.
  • alpha2 (float): Regularization parameter for optimal weights. Default: DEFAULT_EPS.
  • use_bias (bool): Whether to include bias in the linear layer. Default: True.
  • max_iter (int): Maximum number of BFGS iterations. Default: 256.

Returns:

  • loss: The computed loss value.
  • linear_layer: eqx.nn.Linear layer with optimal weights and bias.

Source code in src/varprox/bce.py
@eqx.filter_jit
def vploss_bce(
    model,
    x,
    y,
    alpha1: float = DEFAULT_EPS,
    alpha2: float = DEFAULT_EPS,
    use_bias: bool = True,
    max_iter: int = 256,
):
    """Binary Cross Entropy with the final linear layer learned by Variable Projection.

    Uses BFGS to fit the final readout layer and returns both the loss and an
    eqx.nn.Linear layer with the optimal weights.

    Args:
        model: The model function to evaluate
        x: Input data of shape (N_in, N_batch)
        y: Target data of shape (N_target, N_batch), values should be in {-1, +1}
        alpha1: Regularization parameter for model weights
        alpha2: Regularization parameter for optimal weights
        use_bias: Whether to include bias in the linear layer
        max_iter: Maximum number of BFGS iterations

    Returns:
        loss: The computed loss value
        linear_layer: eqx.nn.Linear layer with optimal weights and bias
    """
    N_target, N_batch = y.shape
    N_in, _ = x.shape
    pred_model = jax.vmap(model, in_axes=1, out_axes=1)(x)

    # Add bias column if requested
    if use_bias:
        ones = jnp.ones((1, N_batch))
        pred_model_with_bias = jnp.concatenate([pred_model, ones], axis=0)
    else:
        pred_model_with_bias = pred_model

    # Compute optimal weights using BFGS
    W_optimal = jax.lax.stop_gradient(
        bce_optimal_weights_bfgs(pred_model_with_bias, y, alpha2, max_iter)
    )

    # Compute loss with both regularization terms
    logits = W_optimal @ pred_model_with_bias

    # Binary cross entropy loss
    bce_loss = jnp.sum(-jax.nn.log_sigmoid(y * logits))

    # Regularization term for linear weights
    reg_vp = alpha2 * jnp.sum(W_optimal**2)

    # Regularization term for model weights
    model_weights = eqx.filter(model, eqx.is_array)
    reg_model = alpha1 * sum(jnp.sum(w**2) for w in jax.tree.leaves(model_weights))

    loss = bce_loss + reg_vp + reg_model

    # Create VPLinear layer with optimal weights
    if use_bias:
        weight = W_optimal[:, :-1]  # All columns except last
        bias = W_optimal[:, -1]  # Last column is bias
    else:
        weight = W_optimal
        bias = None
    linear_layer = VPLinear(weight=weight, bias=bias, use_bias=use_bias)

    return loss, linear_layer
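
Because W_optimal is wrapped in jax.lax.stop_gradient, gradients of the returned loss flow only through the feature-extraction parameters, so the loss drops straight into a standard Equinox training step. A minimal sketch using optax (not a dependency of varprox; assumed here purely for illustration), reusing model, x, y, alpha1, and alpha2 from the Quick Start:

import optax

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def step(model, opt_state, x, y):
    def loss_fn(model):
        loss, _ = vploss_bce(model, x, y, alpha1, alpha2, use_bias=True, max_iter=100)
        return loss

    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss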

msereg

msereg_optimal_weights(pred_model, y, alpha=DEFAULT_EPS)

Solves: \(W = \arg\min_W \|W @ \text{pred\_model} - y\|_2^2 + \alpha \|W\|_F^2\)

Parameters:

  • pred_model: Model predictions of shape (N_out, N_batch). Required.
  • y: Target data of shape (N_target, N_batch). Required.
  • alpha (float): Regularization parameter. Default: DEFAULT_EPS.

Returns:

  • W_optimal: The optimal weights matrix of shape (N_target, N_out).

Source code in src/varprox/msereg.py
@eqx.filter_jit
def msereg_optimal_weights(pred_model, y, alpha: float = DEFAULT_EPS):
    r"""
    Solves: $$W = \arg\min_W \|W @ \text{pred\_model} - y\|_2^2 + \alpha \|W\|_F^2$$

    Args:
        pred_model: Model predictions of shape (N_out, N_batch)
        y: Target data of shape (N_target, N_batch)
        alpha: Regularization parameter

    Returns:
        W_optimal: The optimal weights matrix of shape (N_target, N_out)
    """
    # SVD decomposition: pred_model = U @ diag(s) @ V.T
    U, s, Vt = jnp.linalg.svd(pred_model, full_matrices=False)

    # Compute W = y @ V @ diag(s/(s^2 + alpha)) @ U.T
    s_reg = s / (s**2 + alpha)
    W_optimal = y @ Vt.T @ jnp.diag(s_reg) @ U.T

    return W_optimal
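
As a quick sanity check of the closed form, the returned matrix should satisfy the ridge normal equations W (A Aᵀ + αI) = y Aᵀ. A small sketch on fabricated data:

import jax
import jax.numpy as jnp
from varprox.msereg import msereg_optimal_weights

# Fabricated features (N_out, N_batch) and targets (N_target, N_batch)
A = jax.random.normal(jax.random.PRNGKey(0), (4, 50))
y = jax.random.normal(jax.random.PRNGKey(1), (2, 50))
alpha = 1e-3

W = msereg_optimal_weights(A, y, alpha)
# Residual of the normal equations; should be at round-off level
residual = W @ (A @ A.T + alpha * jnp.eye(4)) - y @ A.T
print(jnp.max(jnp.abs(residual)))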

vploss_msereg(model, x, y, alpha1=DEFAULT_EPS, alpha2=DEFAULT_EPS, use_bias=True)

MSEREG with the final linear layer learned by Variable Projection.

Uses least squares to fit the final readout layer and returns both the loss and an eqx.nn.Linear layer with the optimal weights.

Parameters:

  • model: The model function to evaluate. Required.
  • x: Input data of shape (N_in, N_batch). Required.
  • y: Target data of shape (N_target, N_batch). Required.
  • alpha1 (float): Regularization parameter for model weights. Default: DEFAULT_EPS.
  • alpha2 (float): Regularization parameter for optimal weights. Default: DEFAULT_EPS.
  • use_bias (bool): Whether to include bias in the linear layer. Default: True.

Returns:

  • loss: The computed loss value.
  • linear_layer: eqx.nn.Linear layer with optimal weights and bias.

Source code in src/varprox/msereg.py
@eqx.filter_jit
def vploss_msereg(
    model,
    x,
    y,
    alpha1: float = DEFAULT_EPS,
    alpha2: float = DEFAULT_EPS,
    use_bias: bool = True,
):
    """MSEREG with the final linear layer learned by Variable Projection.

    Uses least squares to fit the final readout layer and returns both the loss
    and an eqx.nn.Linear layer with the optimal weights.

    Args:
        model: The model function to evaluate
        x: Input data of shape (N_in, N_batch)
        y: Target data of shape (N_target, N_batch)
        alpha1: Regularization parameter for model weights
        alpha2: Regularization parameter for optimal weights
        use_bias: Whether to include bias in the linear layer

    Returns:
        loss: The computed loss value
        linear_layer: eqx.nn.Linear layer with optimal weights and bias
    """
    N_target, N_batch = y.shape
    N_in, _ = x.shape
    pred_model = jax.vmap(model, in_axes=1, out_axes=1)(x)

    # Add bias column if requested
    if use_bias:
        ones = jnp.ones((1, N_batch))
        pred_model_with_bias = jnp.concatenate([pred_model, ones], axis=0)
    else:
        pred_model_with_bias = pred_model

    # Compute optimal weights using least squares
    W_optimal = jax.lax.stop_gradient(
        msereg_optimal_weights(pred_model_with_bias, y, alpha2)
    )

    # Compute loss with both regularization terms
    residual = W_optimal @ pred_model_with_bias - y
    data_loss = jnp.sum(residual**2)

    # Regularization term for linear weights
    reg_vp = alpha2 * jnp.sum(W_optimal**2)

    # Regularization term for model weights
    model_weights = eqx.filter(model, eqx.is_array)
    reg_model = alpha1 * sum(jnp.sum(w**2) for w in jax.tree.leaves(model_weights))

    loss = data_loss + reg_vp + reg_model

    # Create VPLinear layer with optimal weights
    if use_bias:
        weight = W_optimal[:, :-1]  # All columns except last
        bias = W_optimal[:, -1]  # Last column is bias
    else:
        weight = W_optimal
        bias = None
    linear_layer = VPLinear(weight=weight, bias=bias, use_bias=use_bias)

    return loss, linear_layer
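
After training, the feature model and the fitted readout can be composed into a single predictor. A minimal per-sample sketch, reusing model, x, and the linear_layer returned above (with the Quick Start shapes, the readout weight is (1, 4), so VPLinear produces a scalar output per sample):

def full_model(xi):
    # xi has shape (input_dim,); the readout maps the features to a scalar here
    return linear_layer(model(xi))

y_pred = jax.vmap(full_model, in_axes=1)(x)  # shape (batch_size,)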

utils

find_minmax(*arrays)

Find the global minimum and maximum across multiple arrays.

Parameters:

  • *arrays: Variable number of JAX arrays. Default: ().

Returns:

  • tuple: (min_value, max_value).

Source code in src/varprox/utils.py
@jax.jit
def find_minmax(*arrays):
    """
    Find the global minimum and maximum across multiple arrays.

    Args:
        *arrays: Variable number of JAX arrays

    Returns:
        tuple: (min_value, max_value)
    """
    if len(arrays) == 0:
        raise ValueError("At least one array must be provided")

    # Concatenate all arrays into one flattened array
    all_values = jnp.concatenate([arr.flatten() for arr in arrays])

    # Use JAX's optimized min/max operations
    return all_values.min(), all_values.max()
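
For example (the arrays here are arbitrary):

import jax.numpy as jnp
from varprox.utils import find_minmax

lo, hi = find_minmax(jnp.array([1.0, 2.0]), jnp.array([[-3.0, 0.5]]))
print(lo, hi)  # -3.0 2.0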

vplinear

VPLinear

Bases: Module

Performs a linear transformation. This is a slightly modified eqx.nn.Linear that accepts initial weights and biases as arguments.

Source code in src/varprox/vplinear.py
class VPLinear(eqx.Module, strict=True):
    """Performs a linear transformation. This is a slightly modified eqx.nn.Linear
    that accepts initial weights and biases as arguments."""

    weight: Array
    bias: Array | None
    in_features: int | Literal["scalar"] = eqx.field(static=True)
    out_features: int | Literal["scalar"] = eqx.field(static=True)
    use_bias: bool = eqx.field(static=True)

    def __init__(
        self,
        weight: Array,
        bias: Array | None = None,
        use_bias: bool = True,
    ):
        """**Arguments:**

        - `weight`: The weight matrix for the linear transformation.
        - `bias`: The bias vector for the linear transformation. Can be None if use_bias is False.
        - `use_bias`: Whether to add on a bias as well.

        The input and output sizes are inferred from the weight matrix shape. For a
        weight matrix of shape `(out_features, in_features)`, the input should be a
        vector of shape `(in_features,)` and the output will be a vector of shape
        `(out_features,)`.

        Note that `in_features` also supports the string `"scalar"` as a special value.
        In this case the input to the layer should be of shape `()`.

        Likewise `out_features` can also be a string `"scalar"`, in which case the
        output from the layer will have shape `()`.
        """
        self.weight = weight
        self.bias = bias

        # Infer dimensions from weight matrix shape
        if weight.shape == (1, 1):
            self.in_features = "scalar"
            self.out_features = "scalar"
        elif weight.shape[1] == 1:
            self.in_features = "scalar"
            self.out_features = weight.shape[0]
        elif weight.shape[0] == 1:
            self.in_features = weight.shape[1]
            self.out_features = "scalar"
        else:
            self.out_features, self.in_features = weight.shape

        self.use_bias = use_bias

    def __call__(self, x: Array, *, key: PRNGKeyArray | None = None) -> Array:
        """**Arguments:**

        - `x`: The input. Should be a JAX array of shape `(in_features,)`. (Or shape
            `()` if `in_features="scalar"`.)
        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
            (Keyword only argument.)

        **Returns:**

        A JAX array of shape `(out_features,)`. (Or shape `()` if
        `out_features="scalar"`.)
        """

        if self.in_features == "scalar":
            if jnp.shape(x) != ():
                raise ValueError("x must have scalar shape")
            x = jnp.broadcast_to(x, (1,))
        x = self.weight @ x
        if self.bias is not None:
            x = x + self.bias
        if self.out_features == "scalar":
            assert jnp.shape(x) == (1,)
            x = jnp.squeeze(x)
        return x

__call__(x, *, key=None)

Arguments:

  • x: The input. Should be a JAX array of shape (in_features,). (Or shape () if in_features="scalar".)
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (out_features,). (Or shape () if out_features="scalar".)

Source code in src/varprox/vplinear.py
def __call__(self, x: Array, *, key: PRNGKeyArray | None = None) -> Array:
    """**Arguments:**

    - `x`: The input. Should be a JAX array of shape `(in_features,)`. (Or shape
        `()` if `in_features="scalar"`.)
    - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
        (Keyword only argument.)

    **Returns:**

    A JAX array of shape `(out_features,)`. (Or shape `()` if
    `out_features="scalar"`.)
    """

    if self.in_features == "scalar":
        if jnp.shape(x) != ():
            raise ValueError("x must have scalar shape")
        x = jnp.broadcast_to(x, (1,))
    x = self.weight @ x
    if self.bias is not None:
        x = x + self.bias
    if self.out_features == "scalar":
        assert jnp.shape(x) == (1,)
        x = jnp.squeeze(x)
    return x

__init__(weight, bias=None, use_bias=True)

Arguments:

  • weight: The weight matrix for the linear transformation.
  • bias: The bias vector for the linear transformation. Can be None if use_bias is False.
  • use_bias: Whether to add on a bias as well.

The input and output sizes are inferred from the weight matrix shape. For a weight matrix of shape (out_features, in_features), the input should be a vector of shape (in_features,) and the output will be a vector of shape (out_features,).

Note that in_features also supports the string "scalar" as a special value. In this case the input to the layer should be of shape ().

Likewise out_features can also be a string "scalar", in which case the output from the layer will have shape ().

Source code in src/varprox/vplinear.py
def __init__(
    self,
    weight: Array,
    bias: Array | None = None,
    use_bias: bool = True,
):
    """**Arguments:**

    - `weight`: The weight matrix for the linear transformation.
    - `bias`: The bias vector for the linear transformation. Can be None if use_bias is False.
    - `use_bias`: Whether to add on a bias as well.

    The input and output sizes are inferred from the weight matrix shape. For a
    weight matrix of shape `(out_features, in_features)`, the input should be a
    vector of shape `(in_features,)` and the output will be a vector of shape
    `(out_features,)`.

    Note that `in_features` also supports the string `"scalar"` as a special value.
    In this case the input to the layer should be of shape `()`.

    Likewise `out_features` can also be a string `"scalar"`, in which case the
    output from the layer will have shape `()`.
    """
    self.weight = weight
    self.bias = bias

    # Infer dimensions from weight matrix shape
    if weight.shape == (1, 1):
        self.in_features = "scalar"
        self.out_features = "scalar"
    elif weight.shape[1] == 1:
        self.in_features = "scalar"
        self.out_features = weight.shape[0]
    elif weight.shape[0] == 1:
        self.in_features = weight.shape[1]
        self.out_features = "scalar"
    else:
        self.out_features, self.in_features = weight.shape

    self.use_bias = use_bias
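
For completeness, a small sketch of constructing a VPLinear directly from an explicit weight matrix and bias (the values are arbitrary):

import jax.numpy as jnp
from varprox.vplinear import VPLinear

weight = jnp.array([[1.0, -2.0, 0.5]])  # (out_features, in_features) = (1, 3) -> scalar output
bias = jnp.array([0.1])
layer = VPLinear(weight=weight, bias=bias, use_bias=True)

print(layer(jnp.array([1.0, 0.0, 2.0])))  # 1*1.0 - 2*0.0 + 0.5*2.0 + 0.1 = 2.1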