A comprehensive collection of utility functions to work with PyTorch sparse tensors, ensuring memory efficiency and supporting various sparsity-preserving tensor operations with automatic differentiation. This package addresses fundamental gaps in PyTorch's sparse tensor ecosystem, providing essential operations that preserve sparsity in gradients during backpropagation.
Memory-Efficient Sparse Matrix Multiplication
- `sparse_mm`: Memory-efficient sparse matrix multiplication with batch support
  - Preserves sparsity in gradients during backpropagation
  - Workaround for PyTorch issue #41128
  - Supports both COO and CSR formats with optional batching
Sparse Linear System Solvers
- `sparse_triangular_solve`: Sparse triangular solver with batch support
  - Discussion reference: PyTorch issue #87358
- `sparse_generic_solve`: Generic sparse linear solver with pluggable backends
  - Tested and benchmarked with CG, BICGSTAB, LSMR and MINRES solvers
- `sparse_solve_c4t`: Wrappers around CuPy sparse solvers
  - Discussion reference: PyTorch issue #69538
  - Tested and benchmarked with: CG, CGS, MINRES, GMRES, spsolve and spsolve_triangular CuPy solvers
- `sparse_solve_j4t`: Wrappers around JAX sparse solvers (see the sketch after this list)
  - Tested with: CG and BICGSTAB JAX solvers
- `sparse_generic_lstsq`: Generic sparse linear least-squares solver
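As a minimal sketch of the JAX-backed wrapper (the `torchsparsegradutils.jax` import path and the default iterative backend are assumptions; check the docstring):

import torch
from torchsparsegradutils.jax import sparse_solve_j4t  # import path assumed

# Tiny SPD system stored as sparse COO
indices = torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]])
values = torch.tensor([2.0, 0.5, 0.5, 2.0], requires_grad=True)
A = torch.sparse_coo_tensor(indices, values, (2, 2)).coalesce()
b = torch.randn(2, 1)

x = sparse_solve_j4t(A, b)  # solves A x = b via a JAX iterative solver (e.g. CG)
x.sum().backward()          # gradient flows back to the sparse values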
Pure PyTorch Implementations
- BICGSTAB: Biconjugate Gradient Stabilized method (ported from pykrylov)
- CG: Conjugate Gradient method (ported from cornellius-gp/linear_operator)
- LSMR: Least Squares Minimal Residual method (ported from pytorch-minimize)
- MINRES: Minimal Residual method (ported from cornellius-gp/linear_operator)
- SparseMultivariateNormal: Structured Gaussian distribution
  - Implements reparameterised sampling (rsample)
  - Supports leading batch dimension
  - Supports COO and CSR sparse tensors
  - Covariance or precision matrices with LL^T or LDL^T parameterizations (see the sketch after this list)
  - LDL^T parameterization offers numerical stability without SPD constraints
- SparseMultivariateNormalNative:
  - Implements reparameterised sampling (rsample)
  - Uses native `torch.sparse.mm` only
  - Only supports unbatched CSR tensors
  - Covariance LL^T parameterization
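To make the parameterizations concrete: for a covariance Σ = L D Lᵀ (LDL^T, with L unit lower triangular and D a positive diagonal), a reparameterised sample is x = μ + L D^{1/2} ε with ε ~ N(0, I); for LL^T the diagonal scaling is absorbed into L. A dense toy sketch of this sampling rule (dense tensors purely for illustration):

import torch

n = 4
loc = torch.zeros(n)
L = torch.eye(n) + torch.tril(0.1 * torch.randn(n, n), diagonal=-1)  # unit lower triangular
d = torch.rand(n) + 0.5                                              # positive diagonal of D

eps = torch.randn(n)                 # standard normal noise
x_ldlt = loc + L @ (d.sqrt() * eps)  # x = mu + L D^{1/2} eps, so cov(x) = L D L^T

# For LL^T, the diagonal is carried by the triangular factor itself:
L_chol = torch.linalg.cholesky(L @ torch.diag(d) @ L.T)
x_llt = loc + L_chol @ eps           # same covariance, LL^T factorisation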
Pairwise Encoder
- Encode local neighborhood relationships in nD spatial volumes
- Multi-channel/class support
- Configurable neighborhood radius and sparsity patterns
- Outputs sparse unbatched/batched COO or CSR matrices for downstream processing
- Optimised for medical imaging and volumetric data applications
Indexed Matrix Multiplication
- `segment_mm`: Segmented matrix multiplication compatible with DGL/PyG
- `gather_mm`: Gather-based matrix multiplication for graph operations
- Pure PyTorch implementations as alternatives to `dgl.ops.segment_mm`, `pyg_lib.ops.segment_matmul`, and `dgl.ops.gather_mm`
- Supports PyTorch >= 2.4 with nested tensor operations
The package can be installed using pip:
pip install torchsparsegradutils
For the latest features and development work:
pip install git+https://github.com/cai4cai/torchsparsegradutils
For full functionality, install optional dependencies:
# For CuPy sparse solver support (GPU acceleration)
pip install cupy-cuda12x # Replace with your CUDA version
# For JAX sparse solver support
pip install "jax[cpu]" # CPU version
pip install "jax[cuda12]" # GPU version (replace with your CUDA version)
# For benchmarking and testing
pip install scipy matplotlib pandas tqdm pytest
- Python: ≥ 3.10
- PyTorch: ≥ 2.5 (≥ 2.4 for indexed operations)
- Operating Systems: Linux, macOS, Windows
- Hardware: CPU and CUDA GPU support
Our comprehensive benchmark suite demonstrates significant performance improvements across various sparse operations. All benchmarks were conducted on an NVIDIA GeForce RTX 4090 with PyTorch 2.8.0+cu128, using the Rothberg/cfd2 matrix from the SuiteSparse Matrix Collection.
import torch
from torchsparsegradutils import sparse_mm
# Create sparse matrix in COO format as a leaf tensor, so its sparse gradient is retained
indices = torch.tensor([[0, 1, 1], [2, 0, 2]], dtype=torch.int64)
values = torch.tensor([3., 4., 5.])
A = torch.sparse_coo_tensor(indices, values, (2, 3)).coalesce().requires_grad_()
# Dense matrix
B = torch.randn(3, 4, requires_grad=True)
# Memory-efficient sparse matrix multiplication with gradient support
C = sparse_mm(A, B)
loss = C.sum()
loss.backward() # Gradients preserved in sparse format
print(f"A.grad: {A.grad}") # Sparse gradient
print(f"B.grad: {B.grad}") # Dense gradient
import torch
from torchsparsegradutils import sparse_triangular_solve, sparse_generic_solve
from torchsparsegradutils.utils import linear_cg
# Create sparse triangular matrix (placeholder: supply your own lower-triangular sparse CSR matrix)
A = create_sparse_triangular_matrix()
b = torch.randn(A.shape[0], 1, requires_grad=True)  # dense right-hand side of shape (n, 1)
# Triangular solve (fast for triangular systems)
x1 = sparse_triangular_solve(A, b, upper=False)
# Generic solve with different backends
x2 = sparse_generic_solve(A, b, solve=linear_cg, tol=1e-6)
# Using CuPy backend (if available)
from torchsparsegradutils.cupy import sparse_solve_c4t
x3 = sparse_solve_c4t(A, b, solve="cg", tol=1e-6)
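The `solve` argument of `sparse_generic_solve` is a pluggable callable. A minimal sketch of a custom backend, continuing the snippet above (densifying is for illustration only; a `transpose_solve` callable may also be needed for the backward pass, so check the docstring):

def dense_fallback_solve(A, b):
    # Illustrative backend: densify and solve directly.
    # A real backend would keep A sparse (e.g. CG, MINRES, LSMR).
    return torch.linalg.solve(A.to_dense(), b)

x4 = sparse_generic_solve(A, b, solve=dense_fallback_solve)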
import torch
from torchsparsegradutils.distributions import SparseMultivariateNormal
from torchsparsegradutils.utils.random_sparse import rand_sparse_tri
# Create parameters
batch_size, event_size = 2, 1000
loc = torch.zeros(batch_size, event_size)
# Example 1: LDL^T parameterization (numerically stable for precision matrices)
# Create sparse lower triangular matrix (unit triangular, no diagonal)
scale_tril = rand_sparse_tri(
    (batch_size, event_size, event_size),
    nnz=5000,  # 5000 non-zeros for 1M entries (0.5% sparsity)
    layout=torch.sparse_csr,
    upper=False,
    unit_triangular=True  # Unit triangular for LDL^T
)
scale_tril.requires_grad_()  # leaf sparse tensor, so scale_tril.grad is populated below
# Diagonal component for LDL^T parameterization
diagonal = torch.ones(batch_size, event_size) * 0.5
# Create distribution with LDL^T parameterization
dist_ldlt = SparseMultivariateNormal(
loc=loc,
diagonal=diagonal,
scale_tril=scale_tril # Unit lower triangular
)
# Example 2: LL^T parameterization (standard Cholesky)
scale_tril_chol = rand_sparse_tri(
(batch_size, event_size, event_size),
nnz=5000,
layout=torch.sparse_csr,
upper=False,
unit_triangular=False # Include diagonal for LL^T
)
# Create distribution with LL^T parameterization
dist_chol = SparseMultivariateNormal(
loc=loc,
scale_tril=scale_tril_chol # Lower triangular with diagonal
)
# Example 3: Precision matrix parameterization (more stable with LDL^T)
precision_tril = rand_sparse_tri(
(batch_size, event_size, event_size),
nnz=5000,
layout=torch.sparse_csr,
upper=False,
unit_triangular=True
)
precision_diagonal = torch.ones(batch_size, event_size) * 2.0
dist_precision = SparseMultivariateNormal(
loc=loc,
diagonal=precision_diagonal,
precision_tril=precision_tril # Unit triangular precision factor
)
# Sample with gradient support
samples = dist_ldlt.rsample((100,)) # 100 samples
# Gradient computation preserves sparsity
loss = samples.sum()
loss.backward()
print(f"Sparse gradient shape: {scale_tril.grad.shape}")
print(f"Sparse gradient nnz: {scale_tril.grad._nnz()}")
print(f"Using LDL^T parameterization: {dist_ldlt.is_ldlt_parameterization}")
import torch
from torchsparsegradutils.encoders import PairwiseEncoder
from torchsparsegradutils.distributions import SparseMultivariateNormal  # used in the final example below
# Create 3D volume encoder (channels, height, depth, width)
volume_shape = (4, 64, 64, 64) # 4 channels, 64x64x64 spatial
encoder = PairwiseEncoder(
radius=2.0,
volume_shape=volume_shape,
layout=torch.sparse_csr
)
# Generate values for each spatial relationship offset
num_offsets = len(encoder.offsets)
values = torch.randn(num_offsets, *volume_shape)
# Generate sparse encoding matrix
sparse_matrix = encoder(values)
print(f"Encoded volume shape: {sparse_matrix.shape}")
print(f"Sparsity: {sparse_matrix._nnz() / sparse_matrix.numel():.3%}")
print(f"Number of spatial offsets: {num_offsets}")
# Use in sparse multivariate normal
flat_size = 4 * 64 * 64 * 64 # Total flattened size
dist = SparseMultivariateNormal(
loc=torch.zeros(flat_size),
scale_tril=sparse_matrix
)
The encoder creates sparse matrices that encode pairwise spatial relationships within a specified radius. Different channel relationship types affect how channels interact:
- `indep`: Independent channels (only spatial neighbors within the same channel)
- `intra`: Intra-channel relationships (spatial neighbors within the same channel)
- `inter`: Inter-channel relationships (spatial neighbors across all channels)
3D Spatial Grid (3×3×3×3) with Different Channel Relations:
Each color represents a different spatial offset (relative position) in the 3D neighborhood. The sparse matrix encodes these relationships efficiently, enabling:
- Local spatial modeling for volumetric data (medical imaging, 3D computer vision)
- Multi-channel feature interaction in convolutional architectures
- Sparse graph construction from regular grids
- Memory-efficient neighborhood encoding for large volumes
Key Parameters:
- `radius`: Spatial neighborhood radius (1.0 = immediate neighbors, 2.0 = extended neighborhood)
- `volume_shape`: `(channels, height, depth, width)` for 4D volumes
- `channel_voxel_relation`: Controls cross-channel connectivity patterns (see the sketch after this list)
- `layout`: Output sparse format (`torch.sparse_coo` or `torch.sparse_csr`)
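A sketch of how `channel_voxel_relation` changes connectivity; the keyword name follows the parameter list above, and the accepted values are assumed to be the `indep`/`intra`/`inter` strings described earlier:

import torch
from torchsparsegradutils.encoders import PairwiseEncoder

# Same small volume, three connectivity patterns
for relation in ("indep", "intra", "inter"):
    enc = PairwiseEncoder(
        radius=1.0,
        volume_shape=(2, 8, 8, 8),  # 2 channels, 8x8x8 spatial
        layout=torch.sparse_coo,
        channel_voxel_relation=relation,
    )
    # More offsets as cross-channel connectivity grows
    print(relation, len(enc.offsets))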
import torch
from torchsparsegradutils import segment_mm, gather_mm
# Segment matrix multiplication (compatible with DGL/PyG)
a = torch.randn(15, 10, requires_grad=True) # Node features
b = torch.randn(3, 10, 5, requires_grad=True) # Edge type embeddings
seglen_a = torch.tensor([5, 6, 4]) # Segment lengths
# Performs: a[0:5] @ b[0], a[5:11] @ b[1], a[11:15] @ b[2]
result = segment_mm(a, b, seglen_a)
# Gather matrix multiplication
indices = torch.tensor([0, 0, 1, 1, 2])
a_gathered = torch.randn(5, 10, requires_grad=True)
result = gather_mm(a_gathered, b, indices)
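The segment semantics in the comment above can be checked against a plain-PyTorch reference; for `gather_mm` this assumes `dgl.ops.gather_mm` semantics, where row i is `a_gathered[i] @ b[indices[i]]`:

# Reference check for segment_mm: block-wise products per segment
expected_seg = torch.cat([a[0:5] @ b[0], a[5:11] @ b[1], a[11:15] @ b[2]])
assert torch.allclose(segment_mm(a, b, seglen_a), expected_seg)

# Reference check for gather_mm: row i multiplies a_gathered[i] @ b[indices[i]]
expected_gat = torch.stack([a_gathered[i] @ b[j] for i, j in enumerate(indices)])
assert torch.allclose(gather_mm(a_gathered, b, indices), expected_gat)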
import torch
from torch.distributions import MultivariateNormal
from torchsparsegradutils.utils import mean_hotelling_t2_test, cov_nagao_test
# Generate sample data from known distribution
torch.manual_seed(42)
true_mean = torch.tensor([[0.0, 0.0]])
true_cov = torch.eye(2).unsqueeze(0)
n = 1000
# Generate samples and compute statistics
dist = MultivariateNormal(true_mean.squeeze(0), true_cov.squeeze(0))
samples = dist.sample((n,)).unsqueeze(1)
sample_mean = samples.mean(0)
sample_cov = torch.cov(samples.squeeze(1).T).unsqueeze(0)
# Test if sample mean is consistent with hypothesized mean (should pass)
result, t2_stat, threshold = mean_hotelling_t2_test(
sample_mean, true_mean, sample_cov, n, confidence_level=0.95
)
print(f"Mean test passed: {result.item()}") # True
# Test if sample covariance is consistent with hypothesized covariance (should pass)
result, t_n_stat, threshold = cov_nagao_test(
sample_cov, true_cov, n, confidence_level=0.95
)
print(f"Covariance test passed: {result.item()}") # True
# Test against wrong parameters (should fail)
wrong_mean = true_mean + 1.0 # Significantly different mean
result, _, _ = mean_hotelling_t2_test(
sample_mean, wrong_mean, sample_cov, n, confidence_level=0.95
)
print(f"Wrong mean test passed: {result.item()}") # False
# Run all tests
python -m pytest
# Run specific test modules
python -m pytest torchsparsegradutils/tests/test_sparse_matmul.py
python -m pytest torchsparsegradutils/tests/test_distributions.py
# Run with coverage
python -m pytest --cov=torchsparsegradutils
The package includes comprehensive benchmarks for performance evaluation:
# Sparse matrix multiplication benchmarks
python -m torchsparsegradutils.benchmarks.sparse_mm_rand
python -m torchsparsegradutils.benchmarks.batched_sparse_mm_rand
# Triangular solver benchmarks
python -m torchsparsegradutils.benchmarks.sparse_triangular_solve_rand
# Generic solver benchmarks
python -m torchsparsegradutils.benchmarks.sparse_generic_solve_suite
# SuiteSparse matrix benchmarks
python -m torchsparsegradutils.benchmarks.sparse_mm_suite
Results are automatically saved to `torchsparsegradutils/benchmarks/results/` as CSV files.
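For example, results can then be inspected with pandas (the exact CSV filenames depend on the benchmark run; this loading snippet is hypothetical):

import glob
import pandas as pd

# Load whichever result files the benchmark runs produced
for path in glob.glob("torchsparsegradutils/benchmarks/results/*.csv"):
    df = pd.read_csv(path)
    print(path, len(df), list(df.columns))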
Sparse Random Matrix Generators
- `rand_sparse(size, nnz, layout=torch.sparse_coo, **kwargs)`: Generate random sparse matrices with specified layout and properties
  - Supports COO and CSR
  - Supports batch dimension
- `rand_sparse_tri(size, nnz, layout=torch.sparse_coo, upper=True, strict=False, **kwargs)`: Generate random sparse triangular matrices
  - Supports COO and CSR
  - Supports batch dimension
  - Strict triangular (no diagonal) or non-strict (with diagonal values)
  - Option to produce well-conditioned matrices and regulate diagonal values
- `make_spd_sparse(n, layout, value_dtype, index_dtype, device, sparsity_ratio=0.5, nz=None)`: Generate sparse symmetric positive definite (SPD) matrices
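A minimal sketch using the signatures listed above. The `rand_sparse` and `make_spd_sparse` import paths are assumed to match `torchsparsegradutils.utils.random_sparse`, the module `rand_sparse_tri` was imported from earlier; see the docstrings for the exact `nnz` semantics with batched sizes:

import torch
from torchsparsegradutils.utils.random_sparse import (  # paths assumed
    rand_sparse, rand_sparse_tri, make_spd_sparse,
)

# Batched random COO matrix: batch of 2 matrices of size 100x100
A = rand_sparse((2, 100, 100), nnz=500, layout=torch.sparse_coo)

# Strictly lower-triangular random CSR matrix (no diagonal entries)
L = rand_sparse_tri((100, 100), nnz=300, layout=torch.sparse_csr,
                    upper=False, strict=True)

# Sparse SPD matrix, e.g. for testing the CG/MINRES solvers
S = make_spd_sparse(100, torch.sparse_csr, torch.float32, torch.int64,
                    torch.device("cpu"), sparsity_ratio=0.5)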
Sparse Matrix Operations
- `sparse_block_diag(*sparse_tensors)`: Create block diagonal sparse matrix from multiple sparse tensors
- `sparse_block_diag_split(sparse_block_diag_tensor, *shapes)`: Split block diagonal sparse matrix into original sparse tensors
- `sparse_eye(size, layout=torch.sparse_coo, **kwargs)`: Create batched or unbatched sparse identity matrices
- `stack_csr(tensors, dim=0)`: Stack CSR tensors along batch dimension (like `torch.stack` for CSR)
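A minimal sketch of these helpers, assuming they are exported from `torchsparsegradutils.utils` like the other utilities imported from there in the examples above:

import torch
from torchsparsegradutils.utils import (  # export paths assumed
    sparse_block_diag, sparse_block_diag_split, sparse_eye, stack_csr,
)

A = sparse_eye(3, layout=torch.sparse_coo)             # 3x3 sparse identity
B = sparse_eye(2, layout=torch.sparse_coo)
D = sparse_block_diag(A, B)                            # 5x5 block-diagonal matrix
A2, B2 = sparse_block_diag_split(D, A.shape, B.shape)  # recover the blocks

# Stack unbatched CSR tensors into a batched CSR tensor
C = torch.eye(3).to_sparse_csr()
batched = stack_csr([C, C])                            # shape (2, 3, 3)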
Sparse Format Conversion
- `convert_coo_to_csr_indices_values(coo_indices, num_rows, values=None)`: Convert COO indices and values to CSR format, with support for batch dimension
- `convert_coo_to_csr(sparse_coo_tensor)`: Convert COO sparse tensor to CSR format with batch support
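A minimal conversion sketch (export path assumed, as above):

import torch
from torchsparsegradutils.utils import convert_coo_to_csr  # export path assumed

coo = torch.sparse_coo_tensor(
    torch.tensor([[0, 1], [1, 0]]),  # row, col indices
    torch.tensor([1.0, 2.0]),
    (2, 2),
).coalesce()
csr = convert_coo_to_csr(coo)
print(csr.layout)  # torch.sparse_csr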
Statistical Distribution Validation
- `mean_hotelling_t2_test(sample_mean, true_mean, sample_cov, n, confidence_level=0.95)`: One-sample Hotelling T² test for multivariate mean equality using confidence regions
  - Tests whether hypothesized mean vector lies within confidence region around sample mean
  - Uses F-distribution for threshold calculation with proper degrees of freedom
  - Higher confidence levels create larger (more permissive) acceptance regions
- `cov_nagao_test(emp_cov, ref_cov, n, confidence_level=0.95)`: Nagao's test for covariance matrix equality using confidence regions
  - Tests whether hypothesized covariance matrix is consistent with empirical covariance
  - Uses χ² distribution with appropriate degrees of freedom
  - Standardizes covariance matrices for improved numerical stability
We welcome contributions! Please see our contributing guidelines:
- Issues: Report bugs and request features via GitHub Issues
- Pull Requests: Submit improvements via GitHub PRs
- Testing: Ensure all tests pass and add tests for new functionality
- Documentation: Update docstrings and examples for new features
- Benchmarks: Include performance benchmarks for new operations
git clone https://github.com/cai4cai/torchsparsegradutils
cd torchsparsegradutils
pip install -e ".[dev]" # Install in development mode
pre-commit install # Install pre-commit hooks
For a consistent development environment with GPU support and all dependencies pre-installed, use VS Code Dev Containers:
Prerequisites:
- Docker with NVIDIA Container Toolkit (for GPU support)
- VS Code with the Dev Containers extension
Quick Start:
- Clone the repository and open in VS Code:
  git clone https://github.com/cai4cai/torchsparsegradutils
  cd torchsparsegradutils
  code .
- When prompted, click "Reopen in Container", or use the Command Palette:
  - Press `Ctrl+Shift+P` (or `Cmd+Shift+P` on macOS)
  - Type "Dev Containers: Reopen in Container"
  - Press Enter
Available Configurations:
- `.devcontainer/Dockerfile.stable` (default): Uses stable PyTorch with CUDA 12.8 support
- `.devcontainer/Dockerfile.nightly`: Uses nightly PyTorch builds for latest features

To switch configurations, modify the `dockerfile` field in `.devcontainer/devcontainer.json`:
"build": {
"dockerfile": "./Dockerfile.nightly", // or "./Dockerfile.stable"
"context": "."
}
What's Included:
- CUDA 12.8: Full GPU development support with NVIDIA drivers
- Pre-installed Dependencies: PyTorch, CuPy, JAX, SciPy, and all development tools
- VS Code Extensions: Python, Pylance, Jupyter, GitHub Copilot, and code formatting tools
- Development Tools: pytest, black, flake8, pre-commit hooks
- Python Environment: Python 3.10+ with all optional dependencies
Benefits:
- ✅ Consistent Environment: Same setup across different machines
- ✅ GPU Support: Pre-configured CUDA environment
- ✅ Zero Setup: All dependencies and tools pre-installed
- ✅ Isolated: No conflicts with host system packages
- ✅ VS Code Integration: Seamless debugging, IntelliSense, and testing
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
- PyTorch Team: For the foundational sparse tensor implementations
- SciPy/CuPy Teams: For high-performance sparse linear algebra routines
- JAX Team: For cross-platform sparse operations and XLA compilation
- Open Source Libraries: We port and adapt algorithms from:
- pykrylov (BICGSTAB)
- cornellius-gp/linear_operator (CG, MINRES)
- pytorch-minimize (LSMR)
If you use this package in your research, please cite:
@software{torchsparsegradutils,
title={torchsparsegradutils: Sparsity-preserving gradient utility tools for PyTorch},
author={Barfoot, Theodore and Glocker, Ben and Vercauteren, Tom},
url={https://github.com/cai4cai/torchsparsegradutils},
year={2024}
}
Issue: PyTorch automatically converts `int32` indices to `int64` when creating sparse COO tensors, but preserves `int32` for sparse CSR tensors. This affects memory usage and performance for algorithms that benefit from `int32` indices (such as `sparse_mm`).
Impact:
- Memory: `int64` indices use 2× more memory than `int32`
- Performance: Some sparse operations may run faster with `int32` indices
- Cross-format consistency: Different behavior between COO and CSR formats
Example:
import torch
# Demonstrate the issue
indices_int32 = torch.tensor([[0, 1], [1, 0]], dtype=torch.int32)
values = torch.tensor([1.0, 2.0])
print(f"Original indices dtype: {indices_int32.dtype}") # torch.int32
# COO: int32 -> int64 conversion happens
coo_tensor = torch.sparse_coo_tensor(indices_int32, values, (2, 2)).coalesce()
print(f"COO indices dtype: {coo_tensor.indices().dtype}") # torch.int64 (converted!)
# CSR: int32 is preserved
crow_indices = torch.tensor([0, 1, 2], dtype=torch.int32)
col_indices = torch.tensor([1, 0], dtype=torch.int32)
csr_tensor = torch.sparse_csr_tensor(crow_indices, col_indices, values, (2, 2))
print(f"CSR crow_indices dtype: {csr_tensor.crow_indices().dtype}") # torch.int32 (preserved!)
print(f"CSR col_indices dtype: {csr_tensor.col_indices().dtype}") # torch.int32 (preserved!)
Workarounds:
- Use CSR format when `int32` indices are important for performance
- Account for extra memory when using COO format with large sparse matrices
- Test performance with both dtypes to determine if the conversion impacts your use case
Status: This is a known PyTorch behavior. Our test suite documents and validates this behavior to catch any future changes in PyTorch's handling of sparse tensor index dtypes.
Issue: CSR sparse tensors generated by `PairwiseEncoder` consume significantly more memory during backward passes compared to COO format, particularly in integration tests with `SparseMultivariateNormal`.
Impact:
- Memory Consumption: CSR integration tests can use 2-3x more memory than equivalent COO tests during `.backward()`
- Training Stability: May cause out-of-memory errors during training with large spatial volumes
- Development: Affects integration testing with large tensor configurations
Suspected Cause: The issue may be related to CSR permutation operations within `PairwiseEncoder` that create additional intermediate tensors during gradient computation.
Current Status: Under investigation. The memory spike occurs specifically during backpropagation through the sparse matrix operations.
Workarounds:
- Use COO format for `PairwiseEncoder` when memory is constrained during training
- Reduce batch sizes or spatial dimensions when using CSR format
- Monitor memory usage carefully when integrating `PairwiseEncoder` with gradient-based optimization
Example:
# More memory-efficient approach for large tensors
import torch
from torchsparsegradutils.encoders import PairwiseEncoder

encoder = PairwiseEncoder(
    radius=2.0,
    volume_shape=(4, 64, 64, 64),
    layout=torch.sparse_coo  # Use COO instead of CSR for memory efficiency
)
Issue: Large gradient magnitudes can occur when using LL^T parameterization with precision matrices in `SparseMultivariateNormal`, leading to training instability.
Impact:
- Gradient Explosion: Gradients can become extremely large (>1e6) during backpropagation
- Training Instability: May cause NaN values or divergent optimization
- Numerical Issues: Poor conditioning of the precision matrix can amplify gradient problems
Affected Configurations:
- LL^T parameterization (`scale_tril` parameter) combined with precision matrix formulation
- Both 2D and 3D spatial configurations show this behavior
- More pronounced with larger spatial dimensions and higher sparsity
Root Cause: The LL^T precision parameterization can lead to poor numerical conditioning, especially when the triangular matrix has small diagonal values or high condition number.
Recommended Solution: Use LDL^T parameterization instead, which provides better numerical stability:
# Problematic: LL^T precision parameterization
dist_unstable = SparseMultivariateNormal(
loc=loc,
precision_tril=scale_tril # LL^T with precision - can cause large gradients
)
# Better: LDL^T parameterization with separate diagonal
dist_stable = SparseMultivariateNormal(
loc=loc,
diagonal=diagonal, # Separate diagonal component for stability
precision_tril=unit_triangular_matrix # Unit triangular (LDL^T)
)
Benefits of LDL^T Parameterization:
- Numerical Stability: Separates diagonal scaling from triangular structure
- Gradient Stability: More stable gradients during backpropagation
- No SPD Constraints: Doesn't require strict positive definiteness
- Better Conditioning: Diagonal component can be controlled independently
Status: This is a known limitation of the LL^T precision formulation. LDL^T parameterization is the recommended approach for precision matrices.
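A dense toy illustration of the conditioning argument above (names are illustrative): with LL^T, small diagonal entries make the triangular factor itself ill-conditioned, which is exactly what amplifies gradients; with LDL^T the factor stays unit-diagonal and well conditioned, and the scale is isolated in D.

import torch

n, eps = 50, 1e-3
off = torch.tril(0.1 * torch.randn(n, n), diagonal=-1)

# LL^T: tiny diagonal values make the triangular factor ill-conditioned
L_llt = off + eps * torch.eye(n)
print(f"cond(L) for LL^T:  {torch.linalg.cond(L_llt):.2e}")   # large

# LDL^T: unit-triangular factor is well conditioned; scaling lives in D
L_unit = off + torch.eye(n)
d = torch.full((n,), eps)  # diagonal factor, controlled independently
print(f"cond(L) for LDL^T: {torch.linalg.cond(L_unit):.2e}")  # modest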