"""
Sparse Non-negative Matrix Factorization (NMF) module with GPU acceleration.
This module provides GPU-accelerated NMF that works directly with sparse matrices,
avoiding the need to convert large sparse matrices to dense format (which can be
memory-prohibitive for very sparse data).
The implementation uses PyTorch sparse tensors and processes data in batches to
maintain memory efficiency while leveraging GPU acceleration.
Two modes are available:
1. **Standalone NMF**: Use SparseNMF for dimensionality reduction, then pass to autoencoder
2. **Joint Model**: Use SparseNMF_Autoencoder for end-to-end training (recommended)
Examples:
---------
Standalone NMF:
>>> from AoU.phenome.sparseNMF import SparseNMF
>>> from scipy.sparse import csr_matrix
>>>
>>> # Create a sparse matrix (e.g., from gene associations)
>>> X_sparse = csr_matrix((data, (rows, cols)), shape=(n_samples, n_features))
>>>
>>> # Fit NMF model
>>> nmf = SparseNMF(n_components=256, max_iter=500, device='cuda:0')
>>> X_reduced = nmf.fit_transform(X_sparse)
>>>
>>> # X_reduced is now (n_samples, 256) dense array ready for autoencoder
Joint Model (Recommended):
>>> from AoU.phenome.sparseNMF import train_joint_model
>>>
>>> # Train joint model end-to-end
>>> z, model = train_joint_model(
... X_sparse,
... n_samples=114171,
... n_features=31263,
... nmf_components=256,
... latent_dim=2,
... device='cuda:0',
... n_epochs=100
... )
>>>
>>> # z is now (n_samples, 2) 2D embeddings
"""
import numpy as np
import pandas as pd
from typing import Optional, Union, Tuple
from scipy.sparse import spmatrix, csr_matrix, coo_matrix
import torch
import torch.nn as nn
from tqdm import tqdm
def _compute_recon_values_chunked(
W_rows: torch.Tensor,
H: torch.Tensor,
col_idx: torch.Tensor,
chunk_size: int = 50000,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""
Compute reconstructed values (W_rows * H_cols).sum(dim=1) in chunks to avoid OOM.
Uses a memory-efficient approach that processes the sum in sub-chunks to avoid
creating large intermediate tensors.
Parameters
----------
W_rows : torch.Tensor
W matrix rows of shape (nnz, n_components).
H : torch.Tensor
H matrix of shape (n_components, n_features).
col_idx : torch.Tensor
Column indices of shape (nnz,).
chunk_size : int
Maximum number of entries to process at once.
device : torch.device, optional
Device to use for clearing cache. If None, uses W_rows.device.
Returns
-------
torch.Tensor
Reconstructed values of shape (nnz,).
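Examples
--------
A minimal CPU sketch (illustrative only) checking the chunked result against a
direct gather-and-sum over the same positions:

>>> import torch
>>> W_rows = torch.rand(10, 4)
>>> H = torch.rand(4, 7)
>>> col_idx = torch.randint(0, 7, (10,))
>>> out = _compute_recon_values_chunked(W_rows, H, col_idx)
>>> ref = (W_rows * H[:, col_idx].t()).sum(dim=1)
>>> torch.allclose(out, ref)
True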
"""
nnz = W_rows.shape[0]
n_components = H.shape[0]
if device is None:
device = W_rows.device
# For very large n_components, use much smaller chunks
# Estimate memory: chunk_size * n_components * 4 bytes (float32) * 2 (for intermediate)
# Try to keep chunks under ~50MB for large n_components
estimated_memory_per_entry = n_components * 4 * 2 # bytes (with intermediate)
max_chunk_memory = 50 * 1024 * 1024 # 50MB (more conservative)
adaptive_chunk_size = min(chunk_size, max_chunk_memory // estimated_memory_per_entry)
adaptive_chunk_size = max(500, adaptive_chunk_size) # At least 500 entries
# For very large n_components, use even smaller chunks
if n_components > 512:
adaptive_chunk_size = min(adaptive_chunk_size, 5000)
if nnz <= adaptive_chunk_size:
# Small enough to process all at once, but still use memory-efficient approach
# Use unique columns to reduce memory
unique_cols, inverse_idx = torch.unique(col_idx, return_inverse=True)
H_cols_unique = H[:, unique_cols].t() # (n_unique, n_components)
H_cols = H_cols_unique[inverse_idx] # (nnz, n_components)
# For large n_components, compute sum in sub-chunks to avoid large intermediate tensor
if n_components > 512:
# Process sum in sub-chunks across components dimension
result = torch.zeros(nnz, device=W_rows.device, dtype=W_rows.dtype)
component_chunk_size = 256
for comp_start in range(0, n_components, component_chunk_size):
comp_end = min(comp_start + component_chunk_size, n_components)
comp_idx = slice(comp_start, comp_end)
result += (W_rows[:, comp_idx] * H_cols[:, comp_idx]).sum(dim=1)
if device.type == 'cuda' and comp_start % (component_chunk_size * 4) == 0:
torch.cuda.empty_cache()
else:
result = (W_rows * H_cols).sum(dim=1)
del H_cols_unique, H_cols, unique_cols, inverse_idx
if device.type == 'cuda':
torch.cuda.empty_cache()
return result
# Process in chunks
X_recon_values = torch.zeros(nnz, device=W_rows.device, dtype=W_rows.dtype)
for chunk_start in range(0, nnz, adaptive_chunk_size):
chunk_end = min(chunk_start + adaptive_chunk_size, nnz)
chunk_idx = slice(chunk_start, chunk_end)
W_rows_chunk = W_rows[chunk_idx]
col_idx_chunk = col_idx[chunk_idx]
# Get unique columns to reduce memory
unique_cols, inverse_idx = torch.unique(col_idx_chunk, return_inverse=True)
H_cols_unique = H[:, unique_cols].t() # (n_unique, n_components)
H_cols_chunk = H_cols_unique[inverse_idx] # (chunk_size, n_components)
# For large n_components, compute sum in sub-chunks
if n_components > 512:
chunk_result = torch.zeros(chunk_end - chunk_start, device=W_rows.device, dtype=W_rows.dtype)
component_chunk_size = 256
for comp_start in range(0, n_components, component_chunk_size):
comp_end = min(comp_start + component_chunk_size, n_components)
comp_idx = slice(comp_start, comp_end)
chunk_result += (W_rows_chunk[:, comp_idx] * H_cols_chunk[:, comp_idx]).sum(dim=1)
X_recon_values[chunk_idx] = chunk_result
del chunk_result
else:
X_recon_values[chunk_idx] = (W_rows_chunk * H_cols_chunk).sum(dim=1)
# Clear intermediate tensors aggressively
del H_cols_unique, H_cols_chunk, W_rows_chunk, col_idx_chunk, unique_cols, inverse_idx
# Clear cache after every chunk when memory is tight
if device.type == 'cuda':
torch.cuda.empty_cache()
# Final cache clear
if device.type == 'cuda':
torch.cuda.empty_cache()
return X_recon_values
class SparseNMF:
"""
GPU-accelerated Non-negative Matrix Factorization for sparse matrices.
This class implements NMF using multiplicative update rules while working
directly with sparse matrices, avoiding memory-intensive dense conversions.
When R² loss is enabled (r2_weight > 0) or the loss is restricted to non-zero
entries (nonzero_mse_weight > 0), optimization switches to a gradient-based Adam
optimizer, since multiplicative updates cannot express these objectives.
Parameters
----------
n_components : int, default 256
Number of components (latent dimensions) to extract.
max_iter : int, default 500
Maximum number of iterations for optimization.
device : str, default 'cuda'
Device to use for computation ('cuda', 'cuda:0', 'cpu', etc.).
If 'cuda' is specified but CUDA is not available, falls back to CPU.
batch_size : int, optional
Batch size for processing rows. If None, auto-determines based on
GPU memory and matrix size.
verbose : bool, default True
Whether to print progress information.
random_state : int, optional
Random seed for initialization. If None, uses random initialization.
tol : float, default 1e-4
Tolerance for convergence checking. Training stops if change in
reconstruction error is below this threshold.
patience : int, optional
Number of iterations to wait without improvement before early stopping.
If None, only uses tolerance-based convergence. If specified, stops training
if the error doesn't improve for `patience` consecutive iterations.
mse_weight : float, default 1.0
Weight for MSE (Mean Squared Error) loss component.
total_loss = mse_weight * MSE + r2_weight * (1 - R²)
r2_weight : float, default 0.0
Weight for R² (coefficient of determination) loss component.
When > 0, switches to gradient-based optimization.
R² loss is computed as (1 - R²) so that minimizing loss maximizes R².
learning_rate : float, default 0.01
Learning rate for gradient-based optimizer (used when r2_weight > 0).
nonzero_mse_weight : float, default 0.0
Controls whether MSE loss includes zeros or only non-zero values.
When > 0: MSE computed only on non-zero positions (ignores zeros).
When 0: MSE computed on all positions including zeros (learns sparsity patterns).
When > 0, forces gradient descent (multiplicative updates always include zeros).
nonzero_r2_weight : float, default 0.0
Controls whether R² loss includes zeros or only non-zero values.
When > 0: R² computed only on non-zero positions (ignores zeros).
When 0: R² computed on all positions including zeros (learns sparsity patterns).
Only affects training when r2_weight > 0.
For final reporting, both R² (all values) and R² (non-zero only) are always computed.
Attributes
----------
W : torch.Tensor
Basis matrix of shape (n_samples, n_components).
H : torch.Tensor
Coefficient matrix of shape (n_components, n_features).
reconstruction_error_ : float
Final reconstruction error (Frobenius norm).
r2_score_ : float
R² (coefficient of determination) score on all values (including zeros),
computed on a sample of the data. Values closer to 1.0 indicate better
reconstruction quality.
r2_score_nonzero_ : float
R² (coefficient of determination) score on non-zero values only,
computed on a sample of the data. More meaningful for highly sparse data.
n_iter_ : int
Number of iterations completed.
Examples
--------
>>> from AoU.phenome.sparseNMF import SparseNMF
>>> from scipy.sparse import random
>>>
>>> # Create a sparse matrix
>>> X = random(1000, 5000, density=0.01, format='csr')
>>>
>>> # Fit NMF with MSE only (default, uses multiplicative updates)
>>> nmf = SparseNMF(n_components=128, max_iter=200, device='cuda:0')
>>> X_reduced = nmf.fit_transform(X)
>>>
>>> # Fit NMF with weighted MSE + R² loss (uses gradient descent)
>>> nmf = SparseNMF(
... n_components=128, max_iter=200, device='cuda:0',
... mse_weight=0.5, r2_weight=0.5 # Equal weighting
... )
>>> X_reduced = nmf.fit_transform(X)
>>> print(f"R² score: {nmf.r2_score_:.4f}")
"""
def __init__(
self,
n_components: int = 256,
max_iter: int = 500,
device: str = 'cuda',
batch_size: Optional[int] = None,
verbose: bool = True,
random_state: Optional[int] = None,
tol: float = 1e-4,
mse_weight: float = 1.0,
r2_weight: float = 0.0,
learning_rate: float = 0.01,
nonzero_mse_weight: float = 0.0,
nonzero_r2_weight: float = 0.0,
patience: Optional[int] = None,
):
self.n_components = n_components
self.max_iter = max_iter
self.batch_size = batch_size
self.verbose = verbose
self.random_state = random_state
self.tol = tol
self.mse_weight = mse_weight
self.r2_weight = r2_weight
self.learning_rate = learning_rate
self.nonzero_mse_weight = nonzero_mse_weight
self.nonzero_r2_weight = nonzero_r2_weight
self.patience = patience
# Set device
if device.startswith('cuda') and not torch.cuda.is_available():
if verbose:
print("CUDA not available, falling back to CPU")
self.device = torch.device('cpu')
else:
self.device = torch.device(device)
# Model parameters (set during fit)
self.W = None
self.H = None
self.reconstruction_error_ = None
self.r2_score_ = None
self.r2_score_nonzero_ = None
self.n_iter_ = None
# Set random seed if provided
if random_state is not None:
torch.manual_seed(random_state)
np.random.seed(random_state)
def _sparse_to_torch(self, X_sparse: spmatrix) -> torch.Tensor:
"""
Convert scipy sparse matrix to PyTorch sparse tensor.
Parameters
----------
X_sparse : scipy.sparse matrix
Sparse matrix to convert.
Returns
-------
torch.Tensor
Sparse COO tensor on the specified device.
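Examples
--------
A small CPU sketch (illustrative; in practice this conversion happens internally
during fitting):

>>> from scipy.sparse import random as sparse_random
>>> nmf = SparseNMF(n_components=8, device='cpu', verbose=False)
>>> X = sparse_random(5, 10, density=0.3, format='csr', random_state=0)
>>> X_torch = nmf._sparse_to_torch(X)
>>> X_torch.is_sparse, tuple(X_torch.shape)
(True, (5, 10))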
"""
# Convert to COO format for efficient conversion
if not isinstance(X_sparse, coo_matrix):
coo = X_sparse.tocoo()
else:
coo = X_sparse
indices = torch.from_numpy(np.vstack([coo.row, coo.col])).long()
values = torch.from_numpy(coo.data).float()
shape = torch.Size(coo.shape)
# Create sparse tensor on device
sparse_tensor = torch.sparse_coo_tensor(
indices, values, shape, device=self.device
)
return sparse_tensor
def _sparse_matmul(self, sparse_A: torch.Tensor, dense_B: torch.Tensor) -> torch.Tensor:
"""
Efficient sparse-dense matrix multiplication.
Parameters
----------
sparse_A : torch.Tensor (sparse)
Sparse matrix.
dense_B : torch.Tensor (dense)
Dense matrix.
Returns
-------
torch.Tensor
Result of sparse_A @ dense_B.
"""
return torch.sparse.mm(sparse_A, dense_B)
def _auto_batch_size(self, n_samples: int) -> int:
"""
Automatically determine batch size based on GPU memory and matrix size.
Parameters
----------
n_samples : int
Number of samples (rows) in the matrix.
Returns
-------
int
Recommended batch size.
"""
if self.device.type == 'cuda':
try:
gpu_memory_gb = torch.cuda.get_device_properties(self.device).total_memory / 1e9
# Conservative estimate: use ~20% of GPU memory for batches
# Rough estimate: each batch needs ~batch_size * n_components * 4 bytes
estimated_batch_memory_gb = (self.n_components * 4) / 1e9 # per sample
max_batch_size = int((gpu_memory_gb * 0.2) / estimated_batch_memory_gb)
batch_size = max(1000, min(10000, max_batch_size, n_samples))
except Exception:
batch_size = 5000
else:
batch_size = 2000 # Smaller batches for CPU
return batch_size
def _compute_recon_values_chunked(
self,
W_rows: torch.Tensor,
H: torch.Tensor,
col_idx: torch.Tensor,
chunk_size: int = 50000,
) -> torch.Tensor:
"""
Compute reconstructed values (W_rows * H_cols).sum(dim=1) in chunks to avoid OOM.
Parameters
----------
W_rows : torch.Tensor
W matrix rows of shape (nnz, n_components).
H : torch.Tensor
H matrix of shape (n_components, n_features).
col_idx : torch.Tensor
Column indices of shape (nnz,).
chunk_size : int
Maximum number of entries to process at once.
Returns
-------
torch.Tensor
Reconstructed values of shape (nnz,).
"""
return _compute_recon_values_chunked(W_rows, H, col_idx, chunk_size, self.device)
def _compute_r2_loss(
self,
X_values: torch.Tensor,
X_recon_values: torch.Tensor,
X_mean: torch.Tensor,
) -> torch.Tensor:
"""
Compute R² loss (1 - R²) for optimization.
R² = 1 - (SS_res / SS_tot)
R² loss = 1 - R² = SS_res / SS_tot
Parameters
----------
X_values : torch.Tensor
Original values (from sparse matrix).
X_recon_values : torch.Tensor
Reconstructed values.
X_mean : torch.Tensor
Mean of original values.
Returns
-------
torch.Tensor
R² loss value (1 - R²).
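Examples
--------
A tiny worked example (CPU, illustrative): a perfect reconstruction gives
SS_res = 0, so R² = 1 and the returned loss is 0.

>>> import torch
>>> x = torch.tensor([1.0, 2.0, 3.0])
>>> nmf = SparseNMF(device='cpu', verbose=False)
>>> loss = nmf._compute_r2_loss(x, x.clone(), x.mean())
>>> round(float(loss), 6)
0.0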
"""
SS_res = torch.sum((X_values - X_recon_values) ** 2)
SS_tot = torch.sum((X_values - X_mean) ** 2)
r2 = 1 - (SS_res / (SS_tot + 1e-10))
return 1 - r2 # R² loss (to minimize)
def _fit_multiplicative_updates(
self,
X_sparse: spmatrix,
X_csr: spmatrix,
W: torch.Tensor,
H: torch.Tensor,
n_samples: int,
n_features: int,
) -> tuple:
"""
Fit NMF using multiplicative update rules (efficient for MSE-only loss).
"""
prev_error = None
best_error = None
patience_counter = 0
error_report_interval = max(1, min(10, self.max_iter // 20))
iterator = tqdm(range(self.max_iter), disable=not self.verbose, desc="NMF iterations (multiplicative)")
for iteration in iterator:
# Update H: H = H * (W^T @ X) / (W^T @ W @ H)
WTX = torch.zeros(n_features, self.n_components, device=self.device)
for i in range(0, n_samples, self.batch_size):
end = min(i + self.batch_size, n_samples)
X_batch_sparse = X_csr[i:end]
if X_batch_sparse.nnz > 0:
X_batch_torch = self._sparse_to_torch(X_batch_sparse)
WTX += self._sparse_matmul(X_batch_torch.t(), W[i:end])
WTW = torch.mm(W.t(), W)
WTWH = torch.mm(WTW, H)
H = H * (WTX.t() / (WTWH + 1e-10))
H = torch.clamp(H, min=1e-10)
# Update W: W = W * (X @ H^T) / (W @ H @ H^T)
for i in range(0, n_samples, self.batch_size):
end = min(i + self.batch_size, n_samples)
X_batch_sparse = X_csr[i:end]
if X_batch_sparse.nnz > 0:
X_batch_torch = self._sparse_to_torch(X_batch_sparse)
W_batch = W[i:end]
XHT_batch = self._sparse_matmul(X_batch_torch, H.t())
WHHT_batch = torch.mm(torch.mm(W_batch, H), H.t())
W[i:end] = W_batch * (XHT_batch / (WHHT_batch + 1e-10))
W = torch.clamp(W, min=1e-10)
# Monitor progress
if self.verbose and ((iteration + 1) % error_report_interval == 0 or iteration == 0):
error, prev_error, converged = self._check_convergence(
X_csr, W, H, n_samples, prev_error, iterator, iteration,
patience=self.patience, best_error=best_error, patience_counter=patience_counter
)
if converged:
break
# Check patience-based early stopping
if self.patience is not None:
if best_error is None or error < best_error:
best_error = error
patience_counter = 0
else:
patience_counter += error_report_interval
if patience_counter >= self.patience:
if self.verbose:
print(f"\nEarly stopping: no improvement for {self.patience} iterations (best error: {best_error:.6f})")
break
return W, H, iteration
def _fit_gradient_descent(
self,
X_sparse: spmatrix,
X_csr: spmatrix,
W: torch.Tensor,
H: torch.Tensor,
n_samples: int,
n_features: int,
) -> tuple:
"""
Fit NMF using gradient descent (supports weighted MSE + R² loss).
"""
import torch.nn.functional as F
# Clear CUDA cache aggressively to free up memory before training
if self.device.type == 'cuda':
torch.cuda.empty_cache()
# Force garbage collection and clear cache multiple times
import gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Make W and H require gradients
W = W.requires_grad_(True)
H = H.requires_grad_(True)
# Use Adam optimizer
optimizer = torch.optim.Adam([W, H], lr=self.learning_rate)
# Patience tracking for gradient descent
best_loss = None
patience_counter = 0
# Compute global mean for R² calculation
# If nonzero_r2_weight > 0, use mean of non-zero values only
# If nonzero_r2_weight = 0, use mean of all values (including zeros)
if self.r2_weight > 0:
if self.nonzero_r2_weight > 0:
X_mean = torch.tensor(X_sparse.data.mean(), device=self.device, dtype=torch.float32)
else:
# Mean of all values including zeros (sparsity-aware)
total_sum = X_sparse.data.sum()
total_elements = X_sparse.shape[0] * X_sparse.shape[1]
X_mean = torch.tensor(total_sum / total_elements, device=self.device, dtype=torch.float32)
else:
# Not used if r2_weight = 0, but initialize to avoid errors
X_mean = torch.tensor(X_sparse.data.mean(), device=self.device, dtype=torch.float32)
prev_loss = None
loss_report_interval = max(1, min(10, self.max_iter // 20))
iterator = tqdm(range(self.max_iter), disable=not self.verbose, desc="NMF iterations (gradient)")
# Determine effective MSE weight: use nonzero_mse_weight as the weight if mse_weight is 0
# This allows users to set mse_weight=0 and nonzero_mse_weight=1.0 to train with non-zero MSE only
effective_mse_weight = self.mse_weight if self.mse_weight > 0 else self.nonzero_mse_weight
# Similarly for R² weight
effective_r2_weight = self.r2_weight if self.r2_weight > 0 else self.nonzero_r2_weight
for iteration in iterator:
optimizer.zero_grad()
total_loss = torch.tensor(0.0, device=self.device, requires_grad=False)
total_mse = torch.tensor(0.0, device=self.device)
total_r2_loss = torch.tensor(0.0, device=self.device)
n_batches = 0
# Process in batches for memory efficiency
for i in range(0, n_samples, self.batch_size):
end = min(i + self.batch_size, n_samples)
X_batch_sparse = X_csr[i:end]
if X_batch_sparse.nnz == 0:
continue
# Convert batch to torch sparse
X_batch_torch = self._sparse_to_torch(X_batch_sparse)
coo = X_batch_torch.coalesce()
X_values = coo.values()
row_idx = coo.indices()[0]
col_idx = coo.indices()[1]
W_batch = W[i:end]
# Pre-compute X_recon_values on non-zero positions if needed (for MSE or R²)
X_recon_values_nonzero = None
if self.nonzero_mse_weight > 0 or self.nonzero_r2_weight > 0:
# Compute reconstructed values on non-zero positions (chunked to avoid OOM)
W_rows = W_batch[row_idx]
# Use much smaller chunk size when nonzero_mse_weight is enabled to be more conservative
# Especially important for large n_components
chunk_size = 5000 if self.nonzero_mse_weight > 0 else 20000
X_recon_values_nonzero = _compute_recon_values_chunked(W_rows, H, col_idx, chunk_size, self.device)
# MSE loss - controlled by nonzero_mse_weight
if effective_mse_weight > 0:
if self.nonzero_mse_weight > 0:
# Compute MSE on non-zero values only
if X_recon_values_nonzero is None:
W_rows = W_batch[row_idx]
chunk_size = 5000 # Very conservative chunk size for nonzero_mse_weight
X_recon_values = _compute_recon_values_chunked(W_rows, H, col_idx, chunk_size, self.device)
else:
X_recon_values = X_recon_values_nonzero
mse_loss = F.mse_loss(X_recon_values, X_values)
else:
# Compute MSE on all values (including zeros) - need full dense matrix
X_batch_dense = X_batch_torch.to_dense()
X_recon_batch = torch.mm(W_batch, H)
mse_loss = F.mse_loss(X_recon_batch, X_batch_dense)
total_loss = total_loss + effective_mse_weight * mse_loss
total_mse = total_mse + mse_loss.detach()
# R² loss - controlled by nonzero_r2_weight
if effective_r2_weight > 0:
if self.nonzero_r2_weight > 0:
# Compute R² on non-zero values only
if X_recon_values_nonzero is None:
W_rows = W_batch[row_idx]
chunk_size = 5000 # Very conservative chunk size
X_recon_values = _compute_recon_values_chunked(W_rows, H, col_idx, chunk_size, self.device)
else:
X_recon_values = X_recon_values_nonzero
r2_loss = self._compute_r2_loss(X_values, X_recon_values, X_mean)
else:
# Compute R² on all values (including zeros) - need full dense matrix
X_batch_dense = X_batch_torch.to_dense()
X_recon_batch = torch.mm(W_batch, H)
X_batch_flat = X_batch_dense.flatten()
X_recon_flat = X_recon_batch.flatten()
# Use mean of all values in this batch (including zeros) for R² calculation
X_mean_batch = X_batch_flat.mean()
r2_loss = self._compute_r2_loss(X_batch_flat, X_recon_flat, X_mean_batch)
total_loss = total_loss + effective_r2_weight * r2_loss
total_r2_loss = total_r2_loss + r2_loss.detach()
# Clear intermediate tensors and cache after each batch when using nonzero_mse_weight
# (memory-intensive mode) - helps prevent OOM
if self.nonzero_mse_weight > 0 or self.nonzero_r2_weight > 0:
# Clear cache periodically to prevent memory fragmentation
if self.device.type == 'cuda' and n_batches % 5 == 0:
torch.cuda.empty_cache()
n_batches += 1
if n_batches == 0:
continue
# Average loss across batches
total_loss = total_loss / n_batches
# Backward pass
total_loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_([W, H], max_norm=1.0)
# Optimizer step
optimizer.step()
# Enforce non-negativity
with torch.no_grad():
W.clamp_(min=1e-10)
H.clamp_(min=1e-10)
# Monitor progress
if self.verbose and ((iteration + 1) % loss_report_interval == 0 or iteration == 0):
avg_mse = (total_mse / n_batches).item()
# Compute R² for display if either r2_weight > 0 or nonzero_r2_weight > 0
should_show_r2 = self.r2_weight > 0 or self.nonzero_r2_weight > 0
avg_r2_loss = (total_r2_loss / n_batches).item() if should_show_r2 else 0
current_r2 = 1 - avg_r2_loss # Convert loss back to R²
postfix = {'loss': f'{total_loss.item():.6f}', 'MSE': f'{avg_mse:.6f}'}
if should_show_r2:
# Label R² based on whether it's computed on non-zero values only
r2_label = 'R²(nonzero)' if self.nonzero_r2_weight > 0 else 'R²'
postfix[r2_label] = f'{current_r2:.4f}'
if prev_loss is not None:
loss_change = abs(prev_loss - total_loss.item()) / (prev_loss + 1e-10)
postfix['change'] = f'{loss_change:.2%}'
if loss_change < self.tol:
if self.verbose:
print(f"\nConverged at iteration {iteration + 1}")
break
# Check patience-based early stopping
if self.patience is not None:
current_loss = total_loss.item()
if best_loss is None or current_loss < best_loss:
best_loss = current_loss
patience_counter = 0
else:
patience_counter += loss_report_interval
# Add patience counter to progress bar
postfix['patience'] = f'{patience_counter}/{self.patience}'
if patience_counter >= self.patience:
if self.verbose:
print(f"\nEarly stopping: no improvement for {self.patience} iterations (best loss: {best_loss:.6f})")
break
iterator.set_postfix(postfix)
prev_loss = total_loss.item()
# Detach from computation graph
W = W.detach()
H = H.detach()
return W, H, iteration
def _check_convergence(
self,
X_csr: spmatrix,
W: torch.Tensor,
H: torch.Tensor,
n_samples: int,
prev_error: Optional[float],
iterator,
iteration: int,
patience: Optional[int] = None,
best_error: Optional[float] = None,
patience_counter: int = 0,
) -> tuple:
"""Check convergence and update progress bar."""
sample_size = min(1000, n_samples)
sample_indices = torch.randperm(n_samples, device=self.device)[:sample_size]
X_sample_sparse = X_csr[sample_indices.cpu().numpy()]
if X_sample_sparse.nnz > 0:
X_sample_torch = self._sparse_to_torch(X_sample_sparse)
W_sample = W[sample_indices]
X_recon_sample = torch.mm(W_sample, H)
X_sample_dense = X_sample_torch.to_dense()
error = torch.mean((X_sample_dense - X_recon_sample) ** 2).item()
if prev_error is not None:
error_change = abs(prev_error - error) / (prev_error + 1e-10)
postfix = {'error': f'{error:.6f}', 'change': f'{error_change:.2%}'}
if patience is not None:
postfix['patience'] = f'{patience_counter}/{patience}'
iterator.set_postfix(postfix)
if error_change < self.tol:
if self.verbose:
print(f"\nConverged at iteration {iteration + 1}")
return error, error, True
else:
postfix = {'error': f'{error:.6f}', 'change': 'N/A'}
if patience is not None:
postfix['patience'] = f'{patience_counter}/{patience}'
iterator.set_postfix(postfix)
return error, error, False
return prev_error, prev_error, False
def _compute_final_metrics(
self,
X_csr: spmatrix,
W: torch.Tensor,
H: torch.Tensor,
n_samples: int,
) -> None:
"""Compute final reconstruction error and R² score (both all values and non-zero only)."""
if self.verbose:
print("Computing final reconstruction error and R²...")
sample_size = min(5000, n_samples)
sample_indices = torch.randperm(n_samples, device=self.device)[:sample_size]
X_sample_sparse = X_csr[sample_indices.cpu().numpy()]
if X_sample_sparse.nnz > 0:
X_sample_torch = self._sparse_to_torch(X_sample_sparse)
W_sample = W[sample_indices]
X_recon_sample = torch.mm(W_sample, H)
X_sample_dense = X_sample_torch.to_dense()
self.reconstruction_error_ = torch.mean((X_sample_dense - X_recon_sample) ** 2).item()
# Compute R² on all values (including zeros)
X_sample_flat = X_sample_dense.flatten()
X_recon_flat = X_recon_sample.flatten()
X_mean_all = X_sample_flat.mean()
SS_res_all = torch.sum((X_sample_flat - X_recon_flat) ** 2).item()
SS_tot_all = torch.sum((X_sample_flat - X_mean_all) ** 2).item()
# R² can be negative when model performs worse than baseline (predicting mean)
# For highly sparse data, NMF with non-negative constraints cannot produce exact zeros.
# The model W @ H produces a dense matrix with small positive values everywhere,
# causing large errors on zero positions. This makes R² (all values) very negative
# even though R² (non-zero only) may be good.
if SS_tot_all > 0:
self.r2_score_ = 1 - (SS_res_all / SS_tot_all)
# Clamp extremely negative values to prevent numerical issues
# R² below -1000 is likely a numerical artifact or indicates severe model failure
if self.r2_score_ < -1000:
self.r2_score_ = -1000.0
else:
self.r2_score_ = 0.0
# Compute R² on non-zero values only (sparse-aware)
coo = X_sample_torch.coalesce()
X_values = coo.values() # Only non-zero values
row_idx = coo.indices()[0]
col_idx = coo.indices()[1]
# Get reconstructed values at non-zero positions only
W_rows = W_sample[row_idx]
X_recon_values = _compute_recon_values_chunked(W_rows, H, col_idx)
# R² on non-zero values only
X_mean_nonzero = X_values.mean()
SS_res_nonzero = torch.sum((X_values - X_recon_values) ** 2).item()
SS_tot_nonzero = torch.sum((X_values - X_mean_nonzero) ** 2).item()
self.r2_score_nonzero_ = 1 - (SS_res_nonzero / SS_tot_nonzero) if SS_tot_nonzero > 0 else 0.0
else:
self.reconstruction_error_ = 0.0
self.r2_score_ = 0.0
self.r2_score_nonzero_ = 0.0
if self.verbose:
print(f"Final reconstruction error (MSE): {self.reconstruction_error_:.6f}")
print(f"Final R² (all values): {self.r2_score_:.6f}")
if self.r2_score_ < -10:
print(f" ⚠️ Warning: Very negative R² (all values) is expected for highly sparse data.")
print(f" NMF with non-negative constraints cannot produce exact zeros - the model")
print(f" predicts small positive values everywhere, causing large errors on zero")
print(f" positions. R² (non-zero only) is the relevant metric for sparse data.")
print(f"Final R² (non-zero values only): {self.r2_score_nonzero_:.6f}")
print(f" (computed on {sample_size:,} sample rows out of {n_samples:,} total)")
class SparseNMF_Autoencoder(nn.Module):
"""
Joint SparseNMF + Autoencoder model for end-to-end training.
This model combines sparse NMF with an autoencoder in a single trainable architecture.
The key advantage is that sparse operations are used throughout NMF, and only the
reduced W matrix (n_samples × n_components) is converted to dense for the autoencoder.
Architecture:
Sparse X (n_samples, n_features)
→ NMF: X ≈ W @ H (sparse operations)
→ Dense W (n_samples, n_components)
→ Autoencoder: W → z (n_samples, latent_dim)
Parameters
----------
n_samples : int
Number of samples (rows) in the input matrix.
Can be inferred from X_sparse.shape[0] in train_joint_model.
n_features : int
Number of features (columns) in the input matrix.
Can be inferred from X_sparse.shape[1] in train_joint_model.
nmf_components : int, default 256
Number of NMF components (intermediate dimension).
latent_dim : int, default 2
Final latent dimension (output of autoencoder).
hidden_dims : tuple of int, default (256, 128, 64, 16)
Hidden layer dimensions for autoencoder. Matches two-step approach default.
activation : str, default "relu"
Activation function for autoencoder.
dropout : float, default 0.0
Dropout rate for autoencoder.
use_vae : bool, default False
Whether to use Variational Autoencoder.
use_feature_attention : bool, default False
If True, learn attention weights for each NMF component based on reconstruction importance.
Disabled by default for joint training - use two-step approach for feature attention.
feature_attention_weight : float, default 1.0
Weight for mixing original input with attended input (0 = no attention, 1 = full attention).
feature_attention_temperature : float, default 1.0
Temperature for attention weights. Use higher values (1.0+) for joint training stability.
normalize_nmf_components : bool, default False
Whether to L2-normalize NMF components (W) before passing to autoencoder.
When True, matches `normalize_input=True` in two-step approach. This is crucial
for proper clustering and prevents radial patterns. When False, preserves the
original magnitude information in W.
device : str, default 'cuda'
Device to use for computation.
Examples
--------
>>> from AoU.phenome.sparseNMF import SparseNMF_Autoencoder, train_joint_model
>>> from scipy.sparse import csr_matrix
>>>
>>> # Train the joint model end-to-end (dimensions inferred automatically)
>>> z, model = train_joint_model(
... X_sparse, # n_samples and n_features inferred from X_sparse.shape
... nmf_components=256,
... latent_dim=2,
... device='cuda:0',
... n_epochs=100
... )
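>>>
>>> # Or construct the model directly (illustrative sketch; X_sparse assumed to exist)
>>> model = SparseNMF_Autoencoder(
...     n_samples=X_sparse.shape[0],
...     n_features=X_sparse.shape[1],
...     nmf_components=256,
...     latent_dim=2,
...     device='cuda:0',
... )
>>> z, model = train_joint_model(X_sparse, model=model, device='cuda:0', n_epochs=100)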
"""
def __init__(
self,
n_samples: int,
n_features: int,
nmf_components: int = 256,
latent_dim: int = 2,
hidden_dims: tuple = (256, 128, 64, 16), # Match two-step approach default
activation: str = "relu",
dropout: float = 0.0,
use_vae: bool = False,
use_feature_attention: bool = False,
feature_attention_weight: float = 1.0,
feature_attention_temperature: float = 1.0, # Higher = more gradual (safer for joint training)
normalize_nmf_components: bool = False, # Whether to L2-normalize W before autoencoder
device: str = 'cuda',
random_state: Optional[int] = None,
):
super().__init__()
self.n_samples = n_samples
self.n_features = n_features
self.nmf_components = nmf_components
self.latent_dim = latent_dim
self.use_vae = use_vae
self.use_feature_attention = use_feature_attention
self.feature_attention_weight = feature_attention_weight
self.feature_attention_temperature = feature_attention_temperature
self.normalize_nmf_components = normalize_nmf_components
# Set device
if device.startswith('cuda') and not torch.cuda.is_available():
self.device = torch.device('cpu')
else:
self.device = torch.device(device)
# Set random seed
if random_state is not None:
torch.manual_seed(random_state)
np.random.seed(random_state)
# NMF parameters (learnable)
# W: (n_samples, nmf_components) - this is the only dense matrix
# Better initialization: use Xavier uniform for better gradient flow
W_init = torch.empty(n_samples, nmf_components, device=self.device)
nn.init.xavier_uniform_(W_init, gain=0.1)
self.W = nn.Parameter(W_init + 0.01) # Ensure positive
# H: (nmf_components, n_features)
H_init = torch.empty(nmf_components, n_features, device=self.device)
nn.init.xavier_uniform_(H_init, gain=0.1)
self.H = nn.Parameter(H_init + 0.01) # Ensure positive
# Feature attention mechanism (matches two-step autoencoder)
# Learns which NMF components are important for reconstruction
if use_feature_attention:
self.feature_attention_net = nn.Sequential(
nn.Linear(nmf_components, max(64, nmf_components // 4)),
nn.ReLU(),
nn.Linear(max(64, nmf_components // 4), nmf_components),
# No sigmoid here - we'll apply sigmoid with temperature in forward()
)
# Autoencoder
# Activation function
if activation == "relu":
self.activation = nn.LeakyReLU(0.1)  # note: "relu" currently maps to LeakyReLU(0.1), same as "leaky_relu"
elif activation == "leaky_relu":
self.activation = nn.LeakyReLU(0.1)
elif activation == "gelu":
self.activation = nn.GELU()
elif activation == "silu" or activation == "swish":
self.activation = nn.SiLU()
elif activation == "tanh":
self.activation = nn.Tanh()
elif activation == "sigmoid":
self.activation = nn.Sigmoid()
else:
raise ValueError(f"Unsupported activation: {activation}")
# Build encoder
encoder_layers = []
prev_dim = nmf_components
for hidden_dim in hidden_dims:
encoder_layers.append(nn.Linear(prev_dim, hidden_dim))
encoder_layers.append(nn.BatchNorm1d(hidden_dim))
encoder_layers.append(self.activation)
if dropout > 0:
encoder_layers.append(nn.Dropout(dropout))
prev_dim = hidden_dim
self.encoder_base = nn.Sequential(*encoder_layers)
# Encoder output
if use_vae:
self.encoder_mu = nn.Linear(prev_dim, latent_dim)
self.encoder_logvar = nn.Linear(prev_dim, latent_dim)
else:
self.encoder_latent = nn.Linear(prev_dim, latent_dim)
# Build decoder
decoder_layers = []
prev_dim = latent_dim
for hidden_dim in reversed(hidden_dims):
decoder_layers.append(nn.Linear(prev_dim, hidden_dim))
decoder_layers.append(nn.BatchNorm1d(hidden_dim))
decoder_layers.append(self.activation)
if dropout > 0:
decoder_layers.append(nn.Dropout(dropout))
prev_dim = hidden_dim
decoder_layers.append(nn.Linear(prev_dim, nmf_components))
self.decoder = nn.Sequential(*decoder_layers)
def reparameterize(self, mu, logvar):
"""Reparameterization trick for VAE."""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def _sparse_to_torch(self, X_sparse: spmatrix) -> torch.Tensor:
"""Convert scipy sparse matrix to PyTorch sparse tensor."""
if not isinstance(X_sparse, coo_matrix):
coo = X_sparse.tocoo()
else:
coo = X_sparse
indices = torch.from_numpy(np.vstack([coo.row, coo.col])).long()
values = torch.from_numpy(coo.data).float()
shape = torch.Size(coo.shape)
sparse_tensor = torch.sparse_coo_tensor(
indices, values, shape, device=self.device
)
return sparse_tensor
def forward(self, X_sparse_torch: torch.Tensor):
"""
Forward pass through the joint model.
Parameters
----------
X_sparse_torch : torch.sparse.FloatTensor
Sparse input matrix (n_samples, n_features)
Returns
-------
z : torch.Tensor
Latent embeddings (n_samples, latent_dim)
W_recon : torch.Tensor
Reconstructed W from autoencoder (n_samples, nmf_components)
X_recon : None
Placeholder for the NMF reconstruction. The full dense (n_samples, n_features)
matrix is never materialized; W @ H is evaluated on-demand in the loss function.
W : torch.Tensor
Current (clamped) W matrix (n_samples, nmf_components)
H : torch.Tensor
Current (clamped) H matrix (nmf_components, n_features)
mu, logvar : torch.Tensor
VAE mean and log-variance, appended to the returned tuple only when use_vae=True
"""
# Ensure non-negativity
W = torch.clamp(self.W, min=1e-10)
H = torch.clamp(self.H, min=1e-10)
# NMF reconstruction: X ≈ W @ H
# Don't compute full dense matrix - we'll compute only needed values in loss
# X_recon = torch.mm(W, H) # (n_samples, n_features) - TOO LARGE!
# Instead, we'll compute W @ H only for non-zero elements in loss function
X_recon = None # Will be computed on-demand in loss function
# Optionally normalize W before autoencoder (matches normalize_input=True in two-step)
# This is crucial for proper clustering and prevents radial patterns when enabled
if self.normalize_nmf_components:
W_for_encoder = torch.nn.functional.normalize(W, p=2, dim=1)
else:
W_for_encoder = W
# Apply feature attention if enabled (matches two-step autoencoder)
if self.use_feature_attention:
attention_logits = self.feature_attention_net(W_for_encoder)
attention_weights = torch.sigmoid(attention_logits / self.feature_attention_temperature)
attended_W = W_for_encoder * attention_weights
W_for_encoder = (1.0 - self.feature_attention_weight) * W_for_encoder + self.feature_attention_weight * attended_W
# Pass through autoencoder
encoded_base = self.encoder_base(W_for_encoder)
if self.use_vae:
mu = self.encoder_mu(encoded_base)
logvar = self.encoder_logvar(encoded_base)
z = self.reparameterize(mu, logvar)
W_recon = self.decoder(z)
return z, W_recon, None, W, H, mu, logvar # X_recon=None, pass H instead
else:
z = self.encoder_latent(encoded_base)
W_recon = self.decoder(z)
return z, W_recon, None, W, H # X_recon=None, pass H instead
def encode(self, X_sparse_torch: Optional[torch.Tensor] = None):
"""
Encode to latent space.
Parameters
----------
X_sparse_torch : torch.sparse.FloatTensor, optional
Unused; kept for interface compatibility. Encoding always reads the stored W parameter.
Returns
-------
z : torch.Tensor
Latent embeddings (n_samples, latent_dim)
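Examples
--------
Illustrative sketch (assumes `model` is a trained SparseNMF_Autoencoder):

>>> _ = model.eval()  # VAE models return the mean at inference time
>>> with torch.no_grad():
...     z = model.encode()
>>> z_np = z.cpu().numpy()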
"""
W = torch.clamp(self.W, min=1e-10)
# Optionally normalize W before encoding (matches normalize_input=True in two-step)
if self.normalize_nmf_components:
W_for_encoder = torch.nn.functional.normalize(W, p=2, dim=1)
else:
W_for_encoder = W
# Apply feature attention if enabled (matches two-step autoencoder)
if self.use_feature_attention:
attention_logits = self.feature_attention_net(W_for_encoder)
attention_weights = torch.sigmoid(attention_logits / self.feature_attention_temperature)
attended_W = W_for_encoder * attention_weights
W_for_encoder = (1.0 - self.feature_attention_weight) * W_for_encoder + self.feature_attention_weight * attended_W
encoded_base = self.encoder_base(W_for_encoder)
if self.use_vae:
mu = self.encoder_mu(encoded_base)
logvar = self.encoder_logvar(encoded_base)
if self.training:
z = self.reparameterize(mu, logvar)
else:
z = mu # Use mean at inference
return z
else:
return self.encoder_latent(encoded_base)
def compute_joint_loss(
model: SparseNMF_Autoencoder,
X_sparse_torch: torch.Tensor,
z: torch.Tensor,
W_recon: torch.Tensor,
X_recon: Optional[torch.Tensor],
W: torch.Tensor,
H: torch.Tensor,
mu: Optional[torch.Tensor] = None,
logvar: Optional[torch.Tensor] = None,
nmf_weight: float = 1.0,
ae_weight: float = 1.0,
kl_weight: float = 0.01,
use_sparse_loss: bool = True,
use_contrastive: bool = True,
contrastive_weight: float = 0.25,
contrastive_temperature: float = 0.5,
use_cosine_loss: bool = True,
dimension_reg_weight: float = 0.1,
) -> tuple:
"""
Compute combined loss for joint model.
Parameters
----------
model : SparseNMF_Autoencoder
The model instance.
X_sparse_torch : torch.sparse.FloatTensor
Sparse input matrix.
z : torch.Tensor
Latent embeddings.
W_recon : torch.Tensor
Reconstructed W from autoencoder.
X_recon : torch.Tensor, optional
Pre-computed NMF reconstruction. May be None; when use_sparse_loss=False it is
computed on the fly as W @ H.
W : torch.Tensor
Current W matrix.
H : torch.Tensor
Current H matrix.
mu : torch.Tensor, optional
VAE mean (if using VAE).
logvar : torch.Tensor, optional
VAE log variance (if using VAE).
nmf_weight : float, default 1.0
Weight for NMF reconstruction loss.
ae_weight : float, default 1.0
Weight for autoencoder reconstruction loss.
kl_weight : float, default 0.01
Weight for KL divergence (VAE only).
use_sparse_loss : bool, default True
If True, compute NMF loss only on non-zero elements (more memory efficient).
use_contrastive : bool, default True
Whether to add an InfoNCE-style contrastive term that aligns pairwise similarities
in z space with those in W space.
contrastive_weight : float, default 0.25
Weight for the contrastive term.
contrastive_temperature : float, default 0.5
Temperature for the contrastive softmax (lower = sharper distinctions).
use_cosine_loss : bool, default True
If True, use cosine loss (1 - cosine similarity) instead of MSE for the
autoencoder reconstruction term.
dimension_reg_weight : float, default 0.1
Weight for the dimension regularizer that discourages collapse to a 1D line.
Returns
-------
total_loss : torch.Tensor
Total combined loss.
loss_dict : dict
Dictionary of individual loss components.
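Examples
--------
Illustrative sketch (assumes `model` is a non-VAE SparseNMF_Autoencoder and
`X_sparse_torch` is the matching torch sparse tensor):

>>> z, W_recon, X_recon, W, H = model(X_sparse_torch)
>>> total_loss, losses = compute_joint_loss(
...     model, X_sparse_torch, z, W_recon, X_recon, W, H,
...     nmf_weight=1.0, ae_weight=1.0, use_sparse_loss=True,
... )
>>> total_loss.backward()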
"""
import torch.nn.functional as F
losses = {}
# 1. NMF reconstruction loss
if use_sparse_loss:
# Only compute loss on non-zero elements (memory efficient)
coo = X_sparse_torch.coalesce()
X_values = coo.values()
row_idx = coo.indices()[0]
col_idx = coo.indices()[1]
# Compute W @ H only for non-zero positions (memory efficient)
# X_recon[row, col] = sum_k W[row, k] * H[k, col]
# We can compute this efficiently using batched operations
W_rows = W[row_idx] # (nnz, nmf_components)
X_recon_values = _compute_recon_values_chunked(W_rows, H, col_idx) # (nnz,)
nmf_loss = F.mse_loss(X_recon_values, X_values)
else:
# Compute full dense loss (memory intensive - not recommended for large matrices)
if X_recon is None:
# Compute on the fly
X_recon = torch.mm(W, H)
X_dense = X_sparse_torch.to_dense()
nmf_loss = F.mse_loss(X_recon, X_dense)
losses['nmf'] = nmf_loss
# 2. Autoencoder reconstruction loss
if use_cosine_loss:
# Use cosine loss (works better with normalized inputs)
# Normalize both W and W_recon for cosine similarity
W_norm = torch.nn.functional.normalize(W, p=2, dim=1)
W_recon_norm = torch.nn.functional.normalize(W_recon, p=2, dim=1)
# Cosine loss: 1 - cosine_similarity (mean over batch)
cosine_sim = (W_norm * W_recon_norm).sum(dim=1)
ae_loss = (1.0 - cosine_sim).mean()
else:
# Standard MSE loss
ae_loss = F.mse_loss(W_recon, W)
losses['ae'] = ae_loss
# 3. KL divergence (VAE only)
kl_loss = None
if model.use_vae and mu is not None and logvar is not None:
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
losses['kl'] = kl_loss
# 4. Contrastive loss (InfoNCE-style, matches autoencoder implementation)
# Memory-efficient: sample subset for large datasets to avoid O(n^2) memory
contrastive_loss = None
if use_contrastive:
n_samples = W.shape[0]
if n_samples < 2:
contrastive_loss = torch.tensor(0.0, device=W.device)
else:
# For large datasets, sample a subset to avoid O(n^2) memory
# Sample up to 1000 samples for contrastive loss computation
max_contrastive_samples = min(1000, n_samples)
if n_samples > max_contrastive_samples:
# Randomly sample indices
indices = torch.randperm(n_samples, device=W.device)[:max_contrastive_samples]
W_sample = W[indices]
z_sample = z[indices]
else:
W_sample = W
z_sample = z
batch_size = W_sample.shape[0]
# Normalize for cosine similarity (matches autoencoder)
W_norm = torch.nn.functional.normalize(W_sample, p=2, dim=1)
z_norm = torch.nn.functional.normalize(z_sample, p=2, dim=1)
# Compute pairwise similarities in W space (original)
W_sim = torch.mm(W_norm, W_norm.t()) # (batch, batch)
# Compute pairwise similarities in z space (latent)
z_sim = torch.mm(z_norm, z_norm.t()) / contrastive_temperature # (batch, batch)
# Convert W similarities to soft targets (probability distribution)
W_targets = torch.nn.functional.softmax(W_sim / contrastive_temperature, dim=1)
# Remove diagonal (self-similarity)
mask = ~torch.eye(batch_size, dtype=torch.bool, device=W.device)
# Compute cross-entropy: z similarities should match W similarities
z_log_softmax = torch.nn.functional.log_softmax(z_sim, dim=1)
# KL divergence: sum over pairs, mean over batch
contrastive_loss = -torch.sum(W_targets * z_log_softmax * mask.float()) / batch_size
losses['contrastive'] = contrastive_loss
# 5. Dimension regularization (prevents collapse to 1D line)
dim_reg_loss = None
if dimension_reg_weight > 0:
# Encourage both dimensions to be used
# Penalize if variance in one dimension is much smaller than the other
z_var = z.var(dim=0) # Variance per dimension
if z_var.min() > 0:
dim_ratio = z_var.max() / (z_var.min() + 1e-8)
# Penalize if ratio is too large (one dimension dominates)
dim_reg_loss = torch.clamp(dim_ratio - 5.0, min=0.0) # Only penalize if ratio > 5
else:
dim_reg_loss = torch.tensor(1.0, device=z.device) # Heavy penalty if one dim has zero variance
losses['dim_reg'] = dim_reg_loss
# Total loss
total_loss = nmf_weight * nmf_loss + ae_weight * ae_loss
if kl_loss is not None:
total_loss += kl_weight * kl_loss
if contrastive_loss is not None:
total_loss += contrastive_weight * contrastive_loss
if dim_reg_loss is not None:
total_loss += dimension_reg_weight * dim_reg_loss
return total_loss, losses
def train_joint_model(
X_sparse: spmatrix,
model: Optional[SparseNMF_Autoencoder] = None,
n_samples: Optional[int] = None,
n_features: Optional[int] = None,
nmf_components: int = 256,
latent_dim: int = 2,
hidden_dims: tuple = (256, 128, 64, 16), # Match two-step approach default
activation: str = "relu",
dropout: float = 0.0,
use_vae: bool = True,
use_feature_attention: bool = False, # Disabled by default - problematic for joint training from scratch
feature_attention_weight: float = 1.0,
feature_attention_temperature: float = 1.0, # Higher temp for joint model (more gradual)
normalize_nmf_components: bool = False, # Whether to L2-normalize W before autoencoder
device: str = 'cuda',
n_epochs: int = 200,
learning_rate: float = 0.0005,
nmf_weight: float = 1.0,
ae_weight: float = 1.0,
kl_weight: float = 0.01,
use_contrastive: bool = False, # Matches working two-step config
contrastive_weight: float = 0.25,
contrastive_temperature: float = 0.5,
use_cosine_loss: bool = True,
dimension_reg_weight: float = 0.0, # Disabled - can cause instability
weight_decay: float = 1e-4,
batch_size: Optional[int] = None,
verbose: bool = True,
random_state: Optional[int] = None,
save_path: Optional[str] = None,
force: bool = False,
) -> tuple:
"""
Train the joint SparseNMF + Autoencoder model.
Parameters
----------
X_sparse : scipy.sparse matrix
Sparse input matrix (n_samples, n_features).
model : SparseNMF_Autoencoder, optional
Pre-initialized model. If None, creates a new model.
n_samples : int, optional
Number of samples. If None, automatically inferred from X_sparse.shape[0].
Only specify if you want to validate the shape matches.
n_features : int, optional
Number of features. If None, automatically inferred from X_sparse.shape[1].
Only specify if you want to validate the shape matches.
nmf_components : int, default 256
Number of NMF components.
latent_dim : int, default 2
Final latent dimension.
hidden_dims : tuple, default (256, 128, 64, 16)
Autoencoder hidden dimensions. Matches two-step approach default.
activation : str, default "relu"
Activation function.
dropout : float, default 0.0
Dropout rate.
use_vae : bool, default True
Whether to use VAE.
use_feature_attention : bool, default False
If True, learn attention weights for each NMF component. Disabled by default for
joint training because randomly initialized attention can destabilize training.
Use two-step approach if you need feature attention.
feature_attention_weight : float, default 1.0
Weight for mixing original input with attended input (0 = no attention, 1 = full).
feature_attention_temperature : float, default 1.0
Temperature for attention weights. Higher values for joint training stability.
normalize_nmf_components : bool, default False
Whether to L2-normalize NMF components (W) before passing to autoencoder.
When True, matches `normalize_input=True` in two-step approach. This is crucial
for proper clustering and prevents radial patterns. When False, preserves the
original magnitude information in W.
device : str, default 'cuda'
Device to use.
n_epochs : int, default 200
Number of training epochs.
learning_rate : float, default 0.0005
Learning rate.
nmf_weight : float, default 1.0
Weight for NMF loss.
ae_weight : float, default 1.0
Weight for AE loss.
kl_weight : float, default 0.01
Weight for KL loss (VAE only).
use_contrastive : bool, default False
Whether to use contrastive loss. Default False matches working two-step config.
contrastive_weight : float, default 0.25
Weight for contrastive loss term.
contrastive_temperature : float, default 0.5
Temperature for contrastive loss (lower = sharper distinctions).
use_cosine_loss : bool, default True
Whether to use cosine loss for autoencoder reconstruction.
Works better with normalized inputs (matches two-step approach).
dimension_reg_weight : float, default 0.0
Weight for dimension regularization (prevents collapse to a 1D line).
Disabled by default because it can destabilize joint training.
weight_decay : float, default 1e-4
L2 weight regularization for optimizer.
batch_size : int, optional
Batch size for autoencoder training. Default 256.
verbose : bool, default True
Whether to print progress.
random_state : int, optional
Random seed.
save_path : str, optional
Path to save model and embeddings.
force : bool, default False
If True, retrain even if save_path exists.
Returns
-------
z : np.ndarray
Final latent embeddings (n_samples, latent_dim).
model : SparseNMF_Autoencoder
Trained model.
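Examples
--------
Illustrative sketch with caching to disk (the save path is hypothetical):

>>> z, model = train_joint_model(
...     X_sparse,
...     nmf_components=256,
...     latent_dim=2,
...     device='cuda:0',
...     n_epochs=100,
...     save_path='joint_model.pt',  # reused on the next call unless force=True
... )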
"""
import os
import pickle
from pathlib import Path
import torch.nn.functional as F
# Set random seed
if random_state is not None:
torch.manual_seed(random_state)
np.random.seed(random_state)
# Infer dimensions from X_sparse (always, even if provided - for validation)
inferred_n_samples = X_sparse.shape[0]
inferred_n_features = X_sparse.shape[1]
if n_samples is None:
n_samples = inferred_n_samples
elif n_samples != inferred_n_samples:
raise ValueError(
f"n_samples={n_samples} doesn't match X_sparse.shape[0]={inferred_n_samples}"
)
if n_features is None:
n_features = inferred_n_features
elif n_features != inferred_n_features:
raise ValueError(
f"n_features={n_features} doesn't match X_sparse.shape[1]={inferred_n_features}"
)
# Check for saved model
if save_path is not None and not force:
save_path_obj = Path(save_path)
if save_path_obj.exists():
if verbose:
print(f"Loading saved model from {save_path}...")
try:
checkpoint = torch.load(save_path, map_location=device)
model = SparseNMF_Autoencoder(
n_samples=n_samples,
n_features=n_features,
nmf_components=nmf_components,
latent_dim=latent_dim,
hidden_dims=hidden_dims,
activation=activation,
dropout=dropout,
use_vae=use_vae,
use_feature_attention=use_feature_attention,
feature_attention_weight=feature_attention_weight,
feature_attention_temperature=feature_attention_temperature,
normalize_nmf_components=normalize_nmf_components,
device=device,
random_state=random_state,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# Load embeddings if available
embeddings_path = save_path_obj.with_suffix('.npy')
if embeddings_path.exists():
z = np.load(embeddings_path)
if verbose:
print(f"Loaded embeddings from {embeddings_path}")
return z, model
except Exception as e:
if verbose:
print(f"Error loading saved model: {e}. Training new model...")
# Create model if not provided
if model is None:
model = SparseNMF_Autoencoder(
n_samples=n_samples,
n_features=n_features,
nmf_components=nmf_components,
latent_dim=latent_dim,
hidden_dims=hidden_dims,
activation=activation,
dropout=dropout,
use_vae=use_vae,
use_feature_attention=use_feature_attention,
feature_attention_weight=feature_attention_weight,
feature_attention_temperature=feature_attention_temperature,
normalize_nmf_components=normalize_nmf_components,
device=device,
random_state=random_state,
)
model.to(device)
model.train()
# Convert sparse matrix to PyTorch sparse tensor
if verbose:
print(f"Converting sparse matrix to PyTorch sparse tensor...")
coo = X_sparse.tocoo()
indices = torch.from_numpy(np.vstack([coo.row, coo.col])).long()
values = torch.from_numpy(coo.data).float()
X_sparse_torch = torch.sparse_coo_tensor(
indices, values, X_sparse.shape, device=device
)
# Optimizer with weight decay for regularization
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Determine batch size for autoencoder training
# Use smaller batches to match two-step approach and prevent OOM
if batch_size is None:
ae_batch_size = 256 # Match autoencoder default
else:
ae_batch_size = batch_size
# Training loop with batching (like two-step approach)
if verbose:
print(f"Training joint model for {n_epochs} epochs...")
print(f" NMF components: {nmf_components}")
print(f" Latent dim: {latent_dim}")
print(f" Autoencoder batch size: {ae_batch_size}")
print(f" Device: {device}")
# Create data loader for batching (shuffle for better training)
from torch.utils.data import TensorDataset, DataLoader
# Create dummy dataset (we'll use indices to slice W)
dataset = TensorDataset(torch.arange(n_samples, device=device))
dataloader = DataLoader(dataset, batch_size=ae_batch_size, shuffle=True)
iterator = tqdm(range(n_epochs), disable=not verbose, desc="Training")
for epoch in iterator:
epoch_loss = 0.0
epoch_losses = {}
n_batches = 0
# Process in batches (like two-step autoencoder)
for batch_idx, (batch_indices_tensor,) in enumerate(dataloader):
optimizer.zero_grad()
batch_indices = batch_indices_tensor.long()
batch_size_actual = len(batch_indices)
# Get batch of W (this is the key - we only process a batch at a time)
# Don't clamp here - let gradients flow, clamp only after optimizer.step()
W_batch = model.W[batch_indices] # (batch_size, nmf_components)
H = model.H
# Get corresponding sparse rows
# Extract batch rows from sparse matrix
X_batch_sparse = X_sparse[batch_indices.cpu().numpy()]
if X_batch_sparse.nnz == 0:
continue # Skip empty batches
# Convert batch sparse to torch
coo_batch = X_batch_sparse.tocoo()
indices_batch = torch.from_numpy(np.vstack([coo_batch.row, coo_batch.col])).long().to(device)
values_batch = torch.from_numpy(coo_batch.data).float().to(device)
X_batch_sparse_torch = torch.sparse_coo_tensor(
indices_batch, values_batch, X_batch_sparse.shape, device=device
)
# Optionally normalize W_batch before the autoencoder (matches normalize_input=True in the two-step approach)
if model.normalize_nmf_components:
W_for_encoder = torch.nn.functional.normalize(W_batch, p=2, dim=1)
else:
W_for_encoder = W_batch
# Apply feature attention if enabled (matches two-step autoencoder)
if model.use_feature_attention:
attention_logits = model.feature_attention_net(W_for_encoder)
attention_weights = torch.sigmoid(attention_logits / model.feature_attention_temperature)
attended_W = W_for_encoder * attention_weights
W_for_encoder = (1.0 - model.feature_attention_weight) * W_for_encoder + model.feature_attention_weight * attended_W
# Forward pass through autoencoder only (NMF reconstruction computed separately)
encoded_base = model.encoder_base(W_for_encoder)
if model.use_vae:
mu = model.encoder_mu(encoded_base)
logvar = model.encoder_logvar(encoded_base)
z_batch = model.reparameterize(mu, logvar)
W_recon_batch = model.decoder(z_batch)
mu_batch, logvar_batch = mu, logvar
else:
z_batch = model.encoder_latent(encoded_base)
W_recon_batch = model.decoder(z_batch)
mu_batch, logvar_batch = None, None
# Compute NMF reconstruction for this batch (sparse-aware)
# W_batch @ H for non-zero positions only
coo = X_batch_sparse_torch.coalesce()
X_values = coo.values()
row_idx_batch = coo.indices()[0] # Local row indices (0 to batch_size-1)
col_idx = coo.indices()[1]
# Map local row indices back to global for W
W_rows = W_batch[row_idx_batch] # (nnz_batch, nmf_components)
X_recon_values = _compute_recon_values_chunked(W_rows, H, col_idx) # (nnz_batch,)
# Compute losses for this batch
# 1. NMF reconstruction loss
nmf_loss_batch = F.mse_loss(X_recon_values, X_values)
# 2. Autoencoder reconstruction loss
if use_cosine_loss:
W_batch_norm = torch.nn.functional.normalize(W_batch, p=2, dim=1)
W_recon_batch_norm = torch.nn.functional.normalize(W_recon_batch, p=2, dim=1)
cosine_sim = (W_batch_norm * W_recon_batch_norm).sum(dim=1)
ae_loss_batch = (1.0 - cosine_sim).mean()
else:
ae_loss_batch = F.mse_loss(W_recon_batch, W_batch)
# 3. KL divergence (VAE only)
kl_loss_batch = None
if model.use_vae and mu_batch is not None and logvar_batch is not None:
kl_loss_batch = -0.5 * torch.sum(1 + logvar_batch - mu_batch.pow(2) - logvar_batch.exp(), dim=1).mean()
# 4. Contrastive loss (on batch - matches autoencoder implementation)
contrastive_loss_batch = None
if use_contrastive and batch_size_actual >= 2:
# Normalize for cosine similarity (matches autoencoder)
W_batch_norm = torch.nn.functional.normalize(W_batch, p=2, dim=1)
z_batch_norm = torch.nn.functional.normalize(z_batch, p=2, dim=1)
# Compute pairwise similarities in W space (original)
W_sim = torch.mm(W_batch_norm, W_batch_norm.t()) # (batch, batch)
# Compute pairwise similarities in z space (latent)
z_sim = torch.mm(z_batch_norm, z_batch_norm.t()) / contrastive_temperature # (batch, batch)
# Convert W similarities to soft targets (probability distribution)
W_targets = torch.nn.functional.softmax(W_sim / contrastive_temperature, dim=1)
# Remove diagonal (self-similarity)
mask = ~torch.eye(batch_size_actual, dtype=torch.bool, device=device)
# Compute cross-entropy: z similarities should match W similarities
z_log_softmax = torch.nn.functional.log_softmax(z_sim, dim=1)
# Cross-entropy between similarity distributions, averaged over the
# batch_size_actual * (batch_size_actual - 1) off-diagonal pairs
contrastive_loss_batch = -torch.sum(W_targets * z_log_softmax * mask.float()) / (batch_size_actual * (batch_size_actual - 1))
# 5. Dimension regularization (on full z, but compute on batch for efficiency)
dim_reg_loss_batch = None
if dimension_reg_weight > 0:
z_var = z_batch.var(dim=0)
if z_var.min() > 0:
dim_ratio = z_var.max() / (z_var.min() + 1e-8)
dim_reg_loss_batch = torch.clamp(dim_ratio - 5.0, min=0.0)
else:
dim_reg_loss_batch = torch.tensor(1.0, device=device)
# Total batch loss
batch_loss = nmf_weight * nmf_loss_batch + ae_weight * ae_loss_batch
if kl_loss_batch is not None:
batch_loss += kl_weight * kl_loss_batch
if contrastive_loss_batch is not None:
batch_loss += contrastive_weight * contrastive_loss_batch
if dim_reg_loss_batch is not None:
batch_loss += dimension_reg_weight * dim_reg_loss_batch
# Backward pass
batch_loss.backward()
# Gradient clipping to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Accumulate losses
epoch_loss += batch_loss.item()
if 'nmf' not in epoch_losses:
epoch_losses = {'nmf': 0.0, 'ae': 0.0}
epoch_losses['nmf'] += nmf_loss_batch.item()
epoch_losses['ae'] += ae_loss_batch.item()
if kl_loss_batch is not None:
if 'kl' not in epoch_losses:
epoch_losses['kl'] = 0.0
epoch_losses['kl'] += kl_loss_batch.item()
if contrastive_loss_batch is not None:
if 'contrastive' not in epoch_losses:
epoch_losses['contrastive'] = 0.0
epoch_losses['contrastive'] += contrastive_loss_batch.item()
if dim_reg_loss_batch is not None:
if 'dim_reg' not in epoch_losses:
epoch_losses['dim_reg'] = 0.0
epoch_losses['dim_reg'] += dim_reg_loss_batch.item()
# Update parameters after each batch (the optimizer step must happen inside the batch loop)
optimizer.step()
# Clamp parameters to maintain non-negativity after each batch update
with torch.no_grad():
model.W.clamp_(min=1e-8)
model.H.clamp_(min=1e-8)
n_batches += 1
# Update progress bar every epoch
if verbose:
avg_loss = epoch_loss / n_batches if n_batches > 0 else 0.0
loss_str = f"Loss={avg_loss:.6f}"
for key, val in epoch_losses.items():
avg_val = val / n_batches if n_batches > 0 else 0.0
loss_str += f", {key}={avg_val:.6f}"
iterator.set_postfix_str(loss_str)
# Check gradient norms for debugging (every 10 epochs)
if (epoch + 1) % 10 == 0:
total_norm = 0.0
param_count = 0
for p in model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
param_count += 1
total_norm = total_norm ** (1. / 2)
if verbose >= 2:
print(f"\n Gradient norm: {total_norm:.6f} (from {param_count} parameters)")
# Get final embeddings
model.eval()
with torch.no_grad():
z = model.encode()
z_np = z.cpu().numpy()
# Save model and embeddings
if save_path is not None:
save_path_obj = Path(save_path)
save_path_obj.parent.mkdir(parents=True, exist_ok=True)
if verbose:
print(f"Saving model to {save_path}...")
torch.save({
'model_state_dict': model.state_dict(),
'nmf_components': nmf_components,
'latent_dim': latent_dim,
'hidden_dims': hidden_dims,
'activation': activation,
'dropout': dropout,
'use_vae': use_vae,
}, save_path)
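# To reload this checkpoint later (sketch): ckpt = torch.load(save_path); rebuild the model
# with the stored hyper-parameters, then call model.load_state_dict(ckpt['model_state_dict']).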
# Save embeddings
embeddings_path = save_path_obj.with_suffix('.npy')
np.save(embeddings_path, z_np)
if verbose:
print(f"Saved embeddings to {embeddings_path}")
return z_np, model
def train_sparse_nmf(
X_sparse: Optional[spmatrix] = None,
n_components: int = 256,
max_iter: int = 500,
device: str = 'cuda',
batch_size: Optional[int] = None,
verbose: bool = True,
random_state: Optional[int] = None,
mse_weight: float = 1.0,
r2_weight: float = 0.0,
learning_rate: float = 0.01,
nonzero_mse_weight: float = 0.0,
nonzero_r2_weight: float = 0.0,
normalize_inputs: bool = False, # Whether to L2-normalize input X before training
normalize_outputs: bool = False, # Whether to L2-normalize output W matrix
patience: Optional[int] = None,
embeddings_save_path: Optional[str] = None,
model_save_path: Optional[str] = None,
force: bool = False,
) -> Tuple[np.ndarray, SparseNMF]:
"""
Train sparse NMF model with automatic saving of embeddings and model.
This is a convenience wrapper that handles model creation, training, and saving.
If save paths are provided and files exist, loads from disk unless force=True.
When loading from disk, X_sparse is not required.
Parameters
----------
X_sparse : scipy.sparse matrix, optional
Sparse input matrix of shape (n_samples, n_features).
Required when training a new model (force=True or files don't exist).
Optional when loading from disk (files exist and force=False).
n_components : int, default 256
Number of components to extract.
max_iter : int, default 500
Maximum number of iterations.
device : str, default 'cuda'
Device to use ('cuda', 'cpu', etc.).
batch_size : int, optional
Batch size for processing. If None, auto-determines.
verbose : bool, default True
Whether to print progress.
random_state : int, optional
Random seed for reproducibility.
mse_weight : float, default 1.0
Weight for MSE loss component.
r2_weight : float, default 0.0
Weight for R² loss component. When > 0, uses gradient descent.
learning_rate : float, default 0.01
Learning rate for gradient descent (used when r2_weight > 0 or nonzero_mse_weight > 0).
nonzero_mse_weight : float, default 0.0
Controls whether MSE loss includes zeros or only non-zero values.
When > 0: MSE computed only on non-zero positions (ignores zeros).
When 0: MSE computed on all positions including zeros (learns sparsity patterns).
nonzero_r2_weight : float, default 0.0
Controls whether R² loss includes zeros or only non-zero values.
When > 0: R² computed only on non-zero positions (ignores zeros).
When 0: R² computed on all positions including zeros (learns sparsity patterns).
Only affects training when r2_weight > 0.
normalize_inputs : bool, default False
Whether to L2-normalize the input X matrix before training.
When True, each row of X will be normalized to unit length before NMF.
This helps balance dense vs sparse rows and prevents dense datasets from
dominating the NMF factors. Useful when datasets have varying sparsity levels.
normalize_outputs : bool, default False
Whether to L2-normalize the output W matrix (embeddings) before returning.
When True, each row of the output will have unit length. This is useful if you
plan to pass the embeddings to an autoencoder with `normalize_input=True` or
`use_cosine_loss=True`. When False, preserves the original magnitude information.
patience : int, optional
Number of iterations to wait without improvement before early stopping.
If None, only uses tolerance-based convergence. If specified, stops training
if the error doesn't improve for `patience` consecutive iterations.
embeddings_save_path : str, optional
Path to save the transformed embeddings (W matrix) as a .npy file.
If provided and file exists and force=False, loads embeddings from disk.
model_save_path : str, optional
Path to save the model (W and H matrices) as a .pkl file.
If provided and file exists and force=False, loads model from disk.
force : bool, default False
If True, force retraining even if save files exist.
If False and both save files exist, loads from disk instead of retraining.
Returns
-------
X_reduced : np.ndarray
Transformed matrix of shape (n_samples, n_components).
model : SparseNMF
Trained SparseNMF model.
Examples
--------
>>> from AoU.phenome.sparseNMF import train_sparse_nmf
>>>
>>> # Train and save model
>>> X_reduced, model = train_sparse_nmf(
... X_sparse,
... n_components=256,
... device='cuda:0',
... embeddings_save_path='embeddings.npy',
... model_save_path='model.pkl'
... )
>>>
>>> # Load from disk (if files exist and force=False)
>>> # X_sparse is not required when loading from disk
>>> X_reduced, model = train_sparse_nmf(
... embeddings_save_path='embeddings.npy',
... model_save_path='model.pkl'
... )
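A hedged sketch using the optional preprocessing flags (values are illustrative):
>>> X_reduced, model = train_sparse_nmf(
...     X_sparse,
...     n_components=256,
...     normalize_inputs=True,       # L2-normalize each row of X before fitting
...     normalize_outputs=True,      # L2-normalize the returned embeddings
...     nonzero_mse_weight=1.0,      # score MSE only on non-zero entries
...     patience=20,                 # stop after 20 iterations without improvement
... )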
"""
from pathlib import Path
import pickle
import os
# Check if we can load from disk
can_load = False
if not force and embeddings_save_path is not None and model_save_path is not None:
if os.path.exists(embeddings_save_path) and os.path.exists(model_save_path):
can_load = True
if verbose:
print(f"Loading embeddings and model from disk...")
print(f" Embeddings: {embeddings_save_path}")
print(f" Model: {model_save_path}")
if can_load:
# Load embeddings
X_reduced = np.load(embeddings_save_path)
# Embeddings saved with normalize_outputs=True are already unit-length;
# only re-normalize if the mean row norm is not already close to 1.0
if normalize_outputs:
# Check if embeddings are already normalized (mean L2 norm per row should be ~1.0)
row_norms = np.linalg.norm(X_reduced, axis=1)
mean_norm = np.mean(row_norms)
# If mean norm is close to 1.0 (within 0.01), assume already normalized
if abs(mean_norm - 1.0) > 0.01:
from AoU.utils import l2_normalize
X_reduced = l2_normalize(X_reduced)
if verbose:
print(f" Normalized embeddings to unit length")
elif verbose:
print(f" Embeddings already normalized (mean norm: {mean_norm:.4f})")
# Load model
with open(model_save_path, 'rb') as f:
model_dict = pickle.load(f)
# Reconstruct model object
model = SparseNMF(
n_components=model_dict['n_components'],
max_iter=max_iter, # Use provided max_iter (might differ from saved)
device=device,
batch_size=batch_size,
verbose=verbose,
random_state=random_state,
mse_weight=mse_weight,
r2_weight=r2_weight,
learning_rate=learning_rate,
nonzero_mse_weight=model_dict.get('nonzero_mse_weight', nonzero_mse_weight),
nonzero_r2_weight=model_dict.get('nonzero_r2_weight', nonzero_r2_weight),
)
# Restore model state
model.W = torch.from_numpy(model_dict['W']).to(device)
model.H = torch.from_numpy(model_dict['H']).to(device)
model.reconstruction_error_ = model_dict.get('reconstruction_error', 0.0)
model.r2_score_ = model_dict.get('r2_score', 0.0)
model.r2_score_nonzero_ = model_dict.get('r2_score_nonzero', 0.0)
model.n_iter_ = model_dict.get('n_iter', 0)
if verbose:
print(f" Loaded embeddings shape: {X_reduced.shape}")
print(f" Model components: {model.n_components}")
print(f" Reconstruction error: {model.reconstruction_error_:.6f}")
print(f" R² score (all values): {model.r2_score_:.6f}")
print(f" R² score (non-zero only): {model.r2_score_nonzero_:.6f}")
return X_reduced, model
# Train new model
if X_sparse is None:
raise ValueError(
"X_sparse is required when training a new model. "
"Either provide X_sparse or ensure embeddings_save_path and model_save_path "
"point to existing files to load from disk."
)
if verbose:
print("Training new NMF model...")
# Normalize inputs if requested (before training)
if normalize_inputs:
from sklearn.preprocessing import normalize
from scipy import sparse as sp
if sp.issparse(X_sparse):
X_sparse = X_sparse.tocsr()
X_sparse = sp.csr_matrix(normalize(X_sparse, norm='l2', axis=1))
else:
X_sparse = normalize(X_sparse, norm='l2', axis=1)
if verbose:
print("L2 normalized input matrix (each row has unit length)")
model = SparseNMF(
n_components=n_components,
max_iter=max_iter,
device=device,
batch_size=batch_size,
verbose=verbose,
random_state=random_state,
mse_weight=mse_weight,
r2_weight=r2_weight,
learning_rate=learning_rate,
nonzero_mse_weight=nonzero_mse_weight,
nonzero_r2_weight=nonzero_r2_weight,
patience=patience,
)
X_reduced = model.fit_transform(X_sparse)
# Normalize outputs if requested
if normalize_outputs:
from AoU.utils import l2_normalize
X_reduced = l2_normalize(X_reduced)
if verbose:
print(f"Normalized output embeddings to unit length")
# Save embeddings if path provided
if embeddings_save_path is not None:
Path(embeddings_save_path).parent.mkdir(parents=True, exist_ok=True)
np.save(embeddings_save_path, X_reduced)
if verbose:
print(f"Saved embeddings to {embeddings_save_path}")
# Save model if path provided
if model_save_path is not None:
Path(model_save_path).parent.mkdir(parents=True, exist_ok=True)
model_dict = {
'W': model.W.cpu().numpy(),
'H': model.H.cpu().numpy(),
'n_components': model.n_components,
'reconstruction_error': model.reconstruction_error_,
'r2_score': model.r2_score_,
'r2_score_nonzero': model.r2_score_nonzero_,
'n_iter': model.n_iter_,
'nonzero_mse_weight': model.nonzero_mse_weight,
'nonzero_r2_weight': model.nonzero_r2_weight,
}
with open(model_save_path, 'wb') as f:
pickle.dump(model_dict, f)
if verbose:
print(f"Saved model to {model_save_path}")
return X_reduced, model
def sparse_nmf(
X_sparse: spmatrix,
n_components: int = 256,
max_iter: int = 500,
device: str = 'cuda',
batch_size: Optional[int] = None,
verbose: bool = True,
random_state: Optional[int] = None,
mse_weight: float = 1.0,
r2_weight: float = 0.0,
learning_rate: float = 0.01,
nonzero_mse_weight: float = 0.0,
nonzero_r2_weight: float = 0.0,
) -> np.ndarray:
"""
Convenience function for sparse NMF (without saving).
Parameters
----------
X_sparse : scipy.sparse matrix
Sparse input matrix of shape (n_samples, n_features).
n_components : int, default 256
Number of components to extract.
max_iter : int, default 500
Maximum number of iterations.
device : str, default 'cuda'
Device to use ('cuda', 'cpu', etc.).
batch_size : int, optional
Batch size for processing. If None, auto-determines.
verbose : bool, default True
Whether to print progress.
random_state : int, optional
Random seed for reproducibility.
mse_weight : float, default 1.0
Weight for MSE loss component.
r2_weight : float, default 0.0
Weight for R² loss component. When > 0, uses gradient descent.
learning_rate : float, default 0.01
Learning rate for gradient descent (used when r2_weight > 0 or nonzero_mse_weight > 0).
nonzero_mse_weight : float, default 0.0
Controls whether MSE loss includes zeros or only non-zero values.
When > 0: MSE computed only on non-zero positions (ignores zeros).
When 0: MSE computed on all positions including zeros (learns sparsity patterns).
nonzero_r2_weight : float, default 0.0
Controls whether R² loss includes zeros or only non-zero values.
When > 0: R² computed only on non-zero positions (ignores zeros).
When 0: R² computed on all positions including zeros (learns sparsity patterns).
Only affects training when r2_weight > 0.
Returns
-------
np.ndarray
Transformed matrix of shape (n_samples, n_components).
Examples
--------
>>> from AoU.phenome.sparseNMF import sparse_nmf
>>> # MSE only (default, fast multiplicative updates)
>>> X_reduced = sparse_nmf(X, n_components=256, device='cuda:0')
>>>
>>> # Weighted MSE + R² loss
>>> X_reduced = sparse_nmf(X, n_components=256, mse_weight=0.5, r2_weight=0.5)
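>>>
>>> # Hedged sketch: score MSE only on observed (non-zero) entries (illustrative weights)
>>> X_reduced = sparse_nmf(X, n_components=256, nonzero_mse_weight=1.0, learning_rate=0.01)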
"""
nmf = SparseNMF(
n_components=n_components,
max_iter=max_iter,
device=device,
batch_size=batch_size,
verbose=verbose,
random_state=random_state,
mse_weight=mse_weight,
r2_weight=r2_weight,
learning_rate=learning_rate,
nonzero_mse_weight=nonzero_mse_weight,
nonzero_r2_weight=nonzero_r2_weight,
)
return nmf.fit_transform(X_sparse)
def trace_attention_to_genes(attention_weights_nmf, nmf_H, normalize=True):
"""
Trace attention weights from NMF components back to original gene features.
This uses the principled approach of matrix multiplication through the
linear transformation (NMF H matrix). Since NMF decomposition is:
X ≈ W @ H
where W is (n_samples, n_components) and H is (n_components, n_genes),
the attention on NMF components can be propagated to genes via:
gene_attention = attention_weights_nmf @ H
This is mathematically sound because:
1. Attention weights represent importance of each NMF component
2. H matrix maps each NMF component to original gene features
3. Matrix multiplication properly aggregates attention across components
This mirrors gradient-based attribution through a linear layer: the gradient
of a linear map is its weight matrix, so multiplying by H is the standard way
to propagate importance/attention back through a linear transformation.
Parameters
----------
attention_weights_nmf : np.ndarray
Attention weights on NMF components, shape (n_samples, n_nmf_components).
Typically obtained from extract_attention_weights().
nmf_H : np.ndarray or torch.Tensor
NMF H matrix mapping components to genes, shape (n_nmf_components, n_genes).
This is the coefficient matrix from NMF decomposition (X ≈ W @ H).
Can be obtained from SparseNMF.H or SparseNMF_Autoencoder.H.
normalize : bool, default True
If True, normalize attention weights per sample so they sum to 1.
This makes the weights interpretable as a probability distribution.
Returns
-------
gene_attention_weights : np.ndarray
Attention weights for each sample and gene, shape (n_samples, n_genes).
Higher values indicate genes that are more important for that sample.
Examples
--------
>>> from AoU.phenome.sparseNMF import extract_attention_weights, trace_attention_to_genes
>>>
>>> # Extract attention on NMF components
>>> attention_nmf = extract_attention_weights(model, X_nmf)
>>>
>>> # Trace back to genes
>>> gene_attention = trace_attention_to_genes(attention_nmf, nmf.H)
>>>
>>> # Get top genes for a sample
>>> sample_idx = 0
>>> top_genes = np.argsort(gene_attention[sample_idx])[-10:][::-1]
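A toy sanity check of the propagation (hypothetical 2-component, 3-gene H):
>>> import numpy as np
>>> A = np.array([[0.7, 0.3]])                  # attention over 2 NMF components
>>> H_toy = np.array([[1.0, 0.0, 1.0],
...                   [0.0, 2.0, 0.0]])         # component-to-gene loadings
>>> gene_att = trace_attention_to_genes(A, H_toy)  # shape (1, 3); each row sums to ~1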
"""
# Convert to numpy if needed (handle both CPU and GPU tensors)
if isinstance(nmf_H, torch.Tensor):
nmf_H = nmf_H.detach().cpu().numpy()
# Matrix multiplication: (n_samples, n_components) @ (n_components, n_genes)
# = (n_samples, n_genes)
gene_attention = attention_weights_nmf @ nmf_H
# Normalize per sample if requested (so each sample's attention sums to 1)
if normalize:
# Use efficient in-place normalization
row_sums = gene_attention.sum(axis=1, keepdims=True)
# Clip to avoid division by zero (set minimum to 1.0)
np.clip(row_sums, a_min=1.0, a_max=None, out=row_sums)
# In-place division (much faster than creating new array)
np.divide(gene_attention, row_sums, out=gene_attention)
return gene_attention
def extract_and_aggregate_attention(
model,
X_nmf,
nmf_H,
batch_size=256,
device=None,
normalize=True,
gene_feature_names=None,
nmf_feature_names=None,
sample_names=None,
metadata=None,
verbose=True,
nonzero_threshold=None,
save_dir=None,
force=False,
return_attention_matrices=False,
):
"""
Extract attention weights for all samples, trace to genes, and aggregate statistics.
This is a high-level wrapper that:
1. Extracts attention weights for all samples on NMF components
2. Traces attention weights back to original gene features
3. Aggregates statistics across samples for both gene and NMF features
Supports both feature attention and transformer attention modes.
Optimized for very large datasets (1M+ samples): uses GPU operations, larger batches,
memory-efficient processing, and vectorized aggregations.
Parameters
----------
model : Autoencoder
Trained autoencoder model with either feature attention or transformer attention enabled.
X_nmf : np.ndarray or torch.Tensor
NMF-transformed embeddings, shape (n_samples, n_nmf_components).
nmf_H : np.ndarray or torch.Tensor
NMF H matrix mapping components to genes, shape (n_nmf_components, n_genes).
batch_size : int, default 256
Batch size for extracting attention weights. Auto-increased for large datasets.
device : str, optional
Device to use. If None, uses model's device.
normalize : bool, default True
If True, normalize attention weights per sample before aggregation.
gene_feature_names : array-like, optional
Names for gene features. If None, uses integer indices.
If metadata is provided, this will be inferred from metadata['var'].index.
nmf_feature_names : array-like, optional
Names for NMF features. If None, uses integer indices.
sample_names : array-like, optional
Names/IDs for samples. If None, uses integer indices.
Should match the order of samples in X_nmf.
Can be obtained from metadata['obs'].index or metadata['obs']['sourceId'].
If metadata is provided, this will be inferred from metadata['obs'].
metadata : dict or AnnData-like object, optional
Metadata object with 'var' and 'obs' keys (e.g., AnnData object).
If provided, gene_feature_names and sample_names will be inferred from:
- metadata['var'].index for gene names
- metadata['obs']['obs_id'] if available, otherwise metadata['obs'].index for sample names
This parameter takes precedence over gene_feature_names and sample_names if provided.
verbose : bool, default True
If True, print progress and summary information.
nonzero_threshold : float, optional
Threshold for counting "nonzero" attention. If None, uses percentile-based threshold:
the 1st percentile of all attention values (identifies bottom 1% as noise).
If normalize=True, all values are > 0 after normalization, so a threshold is needed
to separate meaningful attention from near-uniform background noise.
save_dir : str, optional
Directory path to save the aggregated dataframes as parquet files.
If provided, saves:
- gene_aggregated_df as 'gene_attention_aggregated.parquet'
- nmf_aggregated_df as 'nmf_attention_aggregated.parquet'
If None, dataframes are not saved.
force : bool, default False
If True, overwrite existing parquet files. If False and files exist, raises an error.
return_attention_matrices : bool, default False
If True, also return the pre-aggregated attention matrices:
- gene_attention_matrix: shape (n_samples, n_genes) - attention scores for each sample-gene pair
- nmf_attention_matrix: shape (n_samples, n_nmf_components) - attention scores for each sample-NMF component pair
Note: When True, this requires storing the full matrices in memory, which may be memory-intensive
for very large datasets. If loading from existing files, matrices cannot be returned unless
they were previously saved.
Returns
-------
gene_aggregated_df : pd.DataFrame
Aggregated statistics per gene feature with columns:
- feature_index: Gene feature index
- feature_name: Gene feature name (if provided)
- mean_attention: Mean attention across all samples
- min_attention: Minimum attention across all samples
- max_attention: Maximum attention across all samples
- n_samples_nonzero: Number of samples with nonzero attention for this gene
- pct_samples_nonzero: Percentage of samples with nonzero attention
- max_attention_sample: Sample name/ID with highest attention for this gene (if sample_names provided)
nmf_aggregated_df : pd.DataFrame
Aggregated statistics per NMF feature with columns:
- feature_index: NMF feature index
- feature_name: NMF feature name (if provided)
- mean_attention: Mean attention across all samples
- min_attention: Minimum attention across all samples
- max_attention: Maximum attention across all samples
- n_samples_nonzero: Number of samples with nonzero attention for this NMF feature
- pct_samples_nonzero: Percentage of samples with nonzero attention
- max_attention_sample: Sample name/ID with highest attention for this NMF feature (if sample_names provided)
gene_attention_matrix : np.ndarray, optional
Pre-aggregated attention matrix, shape (n_samples, n_genes).
Only returned if return_attention_matrices=True.
Contains continuous attention scores for each sample-gene pair.
nmf_attention_matrix : np.ndarray, optional
Pre-aggregated attention matrix, shape (n_samples, n_nmf_components).
Only returned if return_attention_matrices=True.
Contains continuous attention scores for each sample-NMF component pair.
Examples
--------
>>> from AoU.phenome.sparseNMF import extract_and_aggregate_attention
>>>
>>> gene_df, nmf_df = extract_and_aggregate_attention(
... model, X_nmf, nmf.H,
... gene_feature_names=gene_names,
... sample_names=sample_names
... )
>>>
>>> # Get top genes by mean attention
>>> top_genes = gene_df.nlargest(10, 'mean_attention')
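A hedged sketch that caches results and keeps the per-sample matrices
(the save path and the `adata` metadata object are illustrative):
>>> gene_df, nmf_df, gene_mat, nmf_mat = extract_and_aggregate_attention(
...     model, X_nmf, nmf.H,
...     metadata=adata,                    # assumed AnnData-like object
...     save_dir='attention_results/',
...     return_attention_matrices=True,
... )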
"""
import pandas as pd
import time
from pathlib import Path
# Extract gene_feature_names and sample_names from metadata if provided
if metadata is not None:
# Extract gene names from metadata['var'].index
if gene_feature_names is None:
try:
gene_feature_names = metadata['var'].index.values
except (KeyError, AttributeError, TypeError):
if verbose:
print(f" Warning: Could not extract gene_feature_names from metadata['var'].index, using integer indices")
gene_feature_names = None
# Extract sample names from metadata['obs']
if sample_names is None:
try:
obs = metadata['obs']
if 'obs_id' in obs.columns:
sample_names = obs['obs_id'].values
else:
sample_names = obs.index.values
except (KeyError, AttributeError, TypeError):
if verbose:
print(f" Warning: Could not extract sample_names from metadata['obs'], using integer indices")
sample_names = None
# Check if files exist BEFORE doing any computation
if save_dir is not None:
save_path = Path(save_dir)
gene_file = save_path / 'gene_attention_aggregated.parquet'
nmf_file = save_path / 'nmf_attention_aggregated.parquet'
# If both files exist and force=False, just load and return them
if not force and gene_file.exists() and nmf_file.exists():
if verbose:
print(f"Loading existing aggregated attention data from {save_dir}...")
gene_aggregated_df = pd.read_parquet(gene_file)
nmf_aggregated_df = pd.read_parquet(nmf_file)
if verbose:
print(f" ✓ Loaded: {gene_file}")
print(f" ✓ Loaded: {nmf_file}")
print(f" Gene features: {len(gene_aggregated_df):,} genes")
print(f" NMF features: {len(nmf_aggregated_df):,} components")
# If return_attention_matrices is True, try to load the matrices
if return_attention_matrices:
gene_matrix_file = save_path / 'gene_attention_matrix.npy'
nmf_matrix_file = save_path / 'nmf_attention_matrix.npy'
if gene_matrix_file.exists() and nmf_matrix_file.exists():
if verbose:
print(f" Loading attention matrices...")
gene_attention_matrix = np.load(gene_matrix_file)
nmf_attention_matrix = np.load(nmf_matrix_file)
if verbose:
print(f" ✓ Loaded: {gene_matrix_file} (shape: {gene_attention_matrix.shape})")
print(f" ✓ Loaded: {nmf_matrix_file} (shape: {nmf_attention_matrix.shape})")
return gene_aggregated_df, nmf_aggregated_df, gene_attention_matrix, nmf_attention_matrix
else:
if verbose:
print(f" Warning: Attention matrices not found. Set force=True to recompute with matrices.")
# Return None for matrices if they don't exist
return gene_aggregated_df, nmf_aggregated_df, None, None
else:
return gene_aggregated_df, nmf_aggregated_df
start_time = time.time()
# Initialize matrices to None (will be set if return_attention_matrices=True)
gene_attention_matrix = None
nmf_attention_matrix = None
if verbose:
n_samples = len(X_nmf)
print(f"Extracting attention weights for all samples and aggregating statistics...")
print(f" Processing {n_samples:,} samples...")
if gene_feature_names is not None:
print(f" Gene features: {len(gene_feature_names):,} genes")
# Step 1: Extract attention weights for all samples (optimized, batched)
# Auto-increase batch size for very large datasets
n_samples = len(X_nmf)
if n_samples > 500000: # 500k+ samples
effective_batch_size = max(8192, batch_size * 16)
elif n_samples > 100000: # 100k+ samples
effective_batch_size = max(4096, batch_size * 8)
elif n_samples > 10000:
effective_batch_size = max(2048, batch_size * 4)
else:
effective_batch_size = batch_size
if verbose:
print(f" Using batch size: {effective_batch_size:,} for attention extraction")
attention_weights_nmf = extract_attention_weights(
model, X_nmf, batch_size=effective_batch_size, device=device, verbose=verbose
)
if verbose:
print(f" ✓ Extracted attention weights: {attention_weights_nmf.shape}")
# Step 2: Trace to genes (GPU-accelerated if possible, chunked for memory efficiency)
# Determine if we should use GPU for large matrices
# torch is already imported at module level
use_gpu_for_aggregation = False
# Get the model's device (model and input must be on same device)
try:
model_device = next(model.parameters()).device
except:
model_device = None
# Try to detect device from model or user input
if device is None:
if hasattr(model, 'device'):
device = model.device
elif model_device is not None:
device = model_device
else: # pragma: no cover (only hit by parameterless models w/o .device — defensive fallback)
device = None
else:
# User specified a device, but ensure it matches model's device
device = torch.device(device) if isinstance(device, str) else device
if model_device is not None and device != model_device:
if verbose:
print(f" Warning: Model is on {model_device}, but device={device} specified. Using model's device {model_device}.")
device = model_device
# Check if GPU is available and use it
if device is not None:
device_str = str(device)
if 'cuda' in device_str:
try:
use_gpu_for_aggregation = torch.cuda.is_available()
if use_gpu_for_aggregation:
device = torch.device(device_str) # Ensure it's a torch device
except:
pass
else: # pragma: no cover (CPU-only CI: device is None + cuda probe both unreachable)
# Try to use default CUDA device if available
try:
if torch.cuda.is_available():
device = torch.device('cuda:0')
use_gpu_for_aggregation = True
except:
pass
# Final check: if we still don't have a device but CUDA is available, use it
if not use_gpu_for_aggregation: # pragma: no cover (CPU-only CI: cuda init fallback never fires)
try:
if torch.cuda.is_available():
if model_device is not None and 'cuda' in str(model_device):
device = model_device
use_gpu_for_aggregation = True
else:
device = torch.device('cuda:0')
use_gpu_for_aggregation = True
except:
pass
# Convert to torch tensors for GPU processing if beneficial.
# NB: ``use_gpu_for_aggregation`` is True only when CUDA is
# available — the entire block below is unreachable on the
# CPU-only CI runners that build the coverage badge. Marking with
# ``# pragma: no cover`` so the badge reflects the code we
# actually exercise; the CPU path that handles the same logic
# starts at the matching ``else`` on line ~3195.
if use_gpu_for_aggregation: # pragma: no cover
if verbose:
print(f" ✓ Using GPU for aggregation (device: {device})")
# Convert to torch tensors on GPU
if isinstance(attention_weights_nmf, np.ndarray):
attention_weights_nmf_torch = torch.from_numpy(attention_weights_nmf).float().to(device)
else:
attention_weights_nmf_torch = attention_weights_nmf.float().to(device) if isinstance(attention_weights_nmf, torch.Tensor) else attention_weights_nmf
if isinstance(nmf_H, torch.Tensor):
nmf_H_torch = nmf_H.float().to(device)
else:
nmf_H_torch = torch.from_numpy(nmf_H).float().to(device)
# Matrix multiplication on GPU (chunked to avoid OOM for very large matrices)
n_samples, n_nmf_features = attention_weights_nmf_torch.shape
n_genes = nmf_H_torch.shape[1]
if verbose:
print(f" ✓ Tracing to genes (GPU, chunked): {attention_weights_nmf_torch.shape} @ {nmf_H_torch.shape}")
print(f" Output will be ({n_samples}, {n_genes}) - chunking to avoid OOM")
# Estimate memory needed for full matrix: n_samples * n_genes * 4 bytes (float32)
# If > 50GB, we must chunk
estimated_memory_gb = (n_samples * n_genes * 4) / (1024**3)
must_chunk = estimated_memory_gb > 50 # Chunk if > 50GB
if must_chunk:
# Chunked matrix multiplication - never create full matrix
if verbose:
print(f" Chunking matrix multiplication (estimated {estimated_memory_gb:.1f} GB)")
# Calculate safe chunk size based on available GPU memory
# Each chunk creates (chunk_size, n_genes) matrix = chunk_size * n_genes * 4 bytes
# We want to keep chunks under ~5GB to leave room for other operations and intermediate tensors
max_chunk_memory_gb = 5.0 # Maximum memory per chunk in GB (conservative)
max_chunk_size = int((max_chunk_memory_gb * (1024**3)) / (n_genes * 4))
# But don't make chunks too small (inefficient) or too large
# For 90k genes: 5GB / (90k * 4 bytes) ≈ 13,800 samples
chunk_size = max(1000, min(max_chunk_size, 20000)) # Between 1k and 20k samples
if verbose:
print(f" Using chunk size: {chunk_size:,} samples (max {max_chunk_size:,} based on {max_chunk_memory_gb}GB limit)")
# Initialize aggregators (we'll accumulate as we process chunks)
gene_max_sample_idx = torch.zeros(n_genes, dtype=torch.long, device=device)
gene_mean = torch.zeros(n_genes, dtype=torch.float32, device=device)
gene_min = torch.full((n_genes,), float('inf'), dtype=torch.float32, device=device)
gene_max = torch.full((n_genes,), float('-inf'), dtype=torch.float32, device=device)
gene_nonzero = torch.zeros(n_genes, dtype=torch.int32, device=device)
# If return_attention_matrices=True, allocate full matrix (memory-intensive!)
if return_attention_matrices:
if verbose:
print(f" Warning: return_attention_matrices=True requires storing full matrix ({estimated_memory_gb:.1f} GB)")
gene_attention_weights_torch = torch.zeros((n_samples, n_genes), dtype=torch.float32, device=device)
else:
gene_attention_weights_torch = None
# Process in chunks - compute gene attention, normalize, and aggregate all in one pass
all_values_for_percentile = []
chunk_iter = tqdm(range(0, n_samples, chunk_size), desc=" Processing chunks", disable=not verbose)
for chunk_start in chunk_iter:
chunk_end = min(chunk_start + chunk_size, n_samples)
chunk_attention = attention_weights_nmf_torch[chunk_start:chunk_end]
# Compute gene attention for this chunk only
chunk_gene_attention = torch.mm(chunk_attention, nmf_H_torch) # (chunk_size, n_genes)
# Normalize if requested (in-place on chunk)
if normalize:
row_sums = chunk_gene_attention.sum(dim=1, keepdim=True)
row_sums = torch.clamp(row_sums, min=1.0)
chunk_gene_attention = chunk_gene_attention / row_sums
# Store chunk in full matrix if requested
if return_attention_matrices:
gene_attention_weights_torch[chunk_start:chunk_end] = chunk_gene_attention
# Update running statistics
chunk_max, chunk_max_idx = torch.max(chunk_gene_attention, dim=0)
chunk_max_idx = chunk_max_idx + chunk_start
update_mask = chunk_max > gene_max
gene_max[update_mask] = chunk_max[update_mask]
gene_max_sample_idx[update_mask] = chunk_max_idx[update_mask]
gene_mean += chunk_gene_attention.sum(dim=0)
gene_min = torch.minimum(gene_min, chunk_gene_attention.min(dim=0)[0])
# Sample values for percentile
if nonzero_threshold is None:
sample_size = max(1000, chunk_gene_attention.numel() // 100)
flat_chunk = chunk_gene_attention.flatten()
sample_indices = torch.randint(0, flat_chunk.numel(), (min(sample_size, flat_chunk.numel()),), device=device)
all_values_for_percentile.append(flat_chunk[sample_indices])
# Free chunk immediately (unless we're storing it)
if not return_attention_matrices:
del chunk_gene_attention
if device.type == 'cuda':
torch.cuda.empty_cache()
gene_mean = gene_mean / n_samples
# Compute threshold from sampled values (use CPU numpy for memory efficiency)
if nonzero_threshold is None:
if all_values_for_percentile:
# Concatenate the sampled values on CPU and take the 1st percentile as the threshold
all_samples_np = np.concatenate([chunk.cpu().numpy() for chunk in all_values_for_percentile])
threshold = np.percentile(all_samples_np, 1.0)
del all_values_for_percentile, all_samples_np
else:
threshold = gene_mean.min().item() * 0.1
else:
threshold = nonzero_threshold
# Second pass: count nonzero (recompute chunks)
if verbose:
print(f" ✓ Computing nonzero counts (second pass)")
chunk_iter = tqdm(range(0, n_samples, chunk_size), desc=" Counting nonzero", disable=not verbose)
for chunk_start in chunk_iter:
chunk_end = min(chunk_start + chunk_size, n_samples)
chunk_attention = attention_weights_nmf_torch[chunk_start:chunk_end]
chunk_gene_attention = torch.mm(chunk_attention, nmf_H_torch)
# Normalize if needed
if normalize:
row_sums = chunk_gene_attention.sum(dim=1, keepdim=True)
row_sums = torch.clamp(row_sums, min=1.0)
chunk_gene_attention = chunk_gene_attention / row_sums
gene_nonzero += (chunk_gene_attention > threshold).sum(dim=0).int()
del chunk_gene_attention
if device.type == 'cuda':
torch.cuda.empty_cache()
gene_pct_nonzero = (gene_nonzero.float() / n_samples) * 100.0
# Convert to CPU numpy for DataFrame construction (free GPU memory immediately)
gene_max_sample_idx_np = gene_max_sample_idx.cpu().numpy()
gene_mean_np = gene_mean.cpu().numpy()
gene_min_np = gene_min.cpu().numpy()
gene_max_np = gene_max.cpu().numpy()
gene_nonzero_np = gene_nonzero.cpu().numpy()
gene_pct_nonzero_np = gene_pct_nonzero.cpu().numpy()
# Free GPU tensors immediately
del gene_max_sample_idx, gene_mean, gene_min, gene_max, gene_nonzero, gene_pct_nonzero
if device.type == 'cuda':
torch.cuda.empty_cache()
# Move attention_weights_nmf to CPU for NMF aggregation to free GPU memory
if verbose:
print(f" ✓ Moving attention weights to CPU for NMF aggregation (freeing GPU memory)")
attention_weights_nmf_cpu = attention_weights_nmf_torch.cpu().numpy()
# Preserve attention_weights_nmf if return_attention_matrices=True
if return_attention_matrices:
nmf_attention_matrix = attention_weights_nmf_cpu.copy()
del attention_weights_nmf_torch
if device.type == 'cuda':
torch.cuda.empty_cache()
# Convert gene_attention_weights to CPU if we stored it
if return_attention_matrices:
if verbose:
print(f" ✓ Moving gene attention matrix to CPU")
gene_attention_matrix = gene_attention_weights_torch.cpu().numpy()
del gene_attention_weights_torch
if device.type == 'cuda':
torch.cuda.empty_cache()
# Note: gene_attention_matrix stays None if return_attention_matrices=False
# (it was initialized to None at function start)
# Reassign to original variable names for consistency with rest of code
gene_max_sample_idx = gene_max_sample_idx_np
gene_mean = gene_mean_np
gene_min = gene_min_np
gene_max = gene_max_np
gene_nonzero = gene_nonzero_np
gene_pct_nonzero = gene_pct_nonzero_np
# Skip to NMF aggregation (gene aggregation is done)
skip_gene_aggregation = True
use_cpu_for_nmf = True # Flag to use CPU for NMF aggregation
else:
# Full matrix multiplication (fits in memory)
use_fp16 = torch.cuda.is_available() and torch.cuda.get_device_capability(device)[0] >= 7 # Volta+
if use_fp16:
with torch.cuda.amp.autocast():
gene_attention_weights_torch = torch.mm(
attention_weights_nmf_torch.half(),
nmf_H_torch.half()
).float()
else:
gene_attention_weights_torch = torch.mm(attention_weights_nmf_torch, nmf_H_torch)
# Normalize if requested (GPU)
if normalize:
if verbose:
print(f" ✓ Normalizing gene attention weights (GPU)")
row_sums = gene_attention_weights_torch.sum(dim=1, keepdim=True)
row_sums = torch.clamp(row_sums, min=1.0)
gene_attention_weights_torch = gene_attention_weights_torch / row_sums
n_samples, n_genes = gene_attention_weights_torch.shape
if verbose:
print(f" ✓ Traced to genes: {gene_attention_weights_torch.shape}")
# Step 3: Aggregate gene features on GPU (optimized for large matrices)
if verbose:
print(f" ✓ Aggregating gene features (GPU, optimized)")
# Use vectorized operations - much faster than chunked loops for most GPUs
# For very large matrices (>1M samples), still chunk to avoid OOM
skip_gene_aggregation = False
if n_samples > 1000000:
chunk_size = 100000 # Process 100k samples at a time
gene_max_sample_idx = torch.zeros(n_genes, dtype=torch.long, device=device)
gene_mean = torch.zeros(n_genes, dtype=torch.float32, device=device)
gene_min = torch.full((n_genes,), float('inf'), dtype=torch.float32, device=device)
gene_max = torch.full((n_genes,), float('-inf'), dtype=torch.float32, device=device)
# Process in chunks with progress
chunk_iter = tqdm(range(0, n_samples, chunk_size), desc=" Aggregating", disable=not verbose)
for chunk_start in chunk_iter:
chunk_end = min(chunk_start + chunk_size, n_samples)
chunk = gene_attention_weights_torch[chunk_start:chunk_end]
# Update running statistics
chunk_max, chunk_max_idx = torch.max(chunk, dim=0)
chunk_max_idx = chunk_max_idx + chunk_start
update_mask = chunk_max > gene_max
gene_max[update_mask] = chunk_max[update_mask]
gene_max_sample_idx[update_mask] = chunk_max_idx[update_mask]
gene_mean += chunk.sum(dim=0)
gene_min = torch.minimum(gene_min, chunk.min(dim=0)[0])
gene_mean = gene_mean / n_samples
else:
# For smaller datasets, use full vectorized operations (much faster)
gene_max, gene_max_sample_idx = torch.max(gene_attention_weights_torch, dim=0)
gene_mean = gene_attention_weights_torch.mean(dim=0)
gene_min = gene_attention_weights_torch.min(dim=0)[0]
# Compute threshold using chunked percentile (memory-efficient)
if nonzero_threshold is None:
# Use approximate percentile: sample and compute
# For very large arrays, compute percentile on a sample
sample_size = min(100000, n_samples * n_genes)
if sample_size < n_samples * n_genes:
# Sample random indices
flat_indices = torch.randint(0, n_samples * n_genes, (sample_size,), device=device)
sample_values = gene_attention_weights_torch.flatten()[flat_indices]
threshold = torch.quantile(sample_values, 0.01).item()
else:
threshold = torch.quantile(gene_attention_weights_torch, 0.01).item()
else:
threshold = nonzero_threshold
# Count nonzero (vectorized if possible, chunked for very large)
if n_samples > 1000000:
gene_nonzero = torch.zeros(n_genes, dtype=torch.int32, device=device)
chunk_size = 100000
chunk_iter = tqdm(range(0, n_samples, chunk_size), desc=" Counting nonzero", disable=not verbose)
for chunk_start in chunk_iter:
chunk_end = min(chunk_start + chunk_size, n_samples)
chunk = gene_attention_weights_torch[chunk_start:chunk_end]
gene_nonzero += (chunk > threshold).sum(dim=0).int()
else:
gene_nonzero = (gene_attention_weights_torch > threshold).sum(dim=0).int()
gene_pct_nonzero = (gene_nonzero.float() / n_samples) * 100.0
# Preserve matrices if return_attention_matrices=True (before converting to CPU)
if return_attention_matrices:
if verbose:
print(f" ✓ Preserving attention matrices for return")
gene_attention_matrix = gene_attention_weights_torch.cpu().numpy()
nmf_attention_matrix = attention_weights_nmf_torch.cpu().numpy()
else:
gene_attention_matrix = None
nmf_attention_matrix = None
# Convert to CPU numpy for DataFrame construction
gene_max_sample_idx = gene_max_sample_idx.cpu().numpy()
gene_mean = gene_mean.cpu().numpy()
gene_min = gene_min.cpu().numpy()
gene_max = gene_max.cpu().numpy()
gene_nonzero = gene_nonzero.cpu().numpy()
gene_pct_nonzero = gene_pct_nonzero.cpu().numpy()
# Clean up GPU memory
del gene_attention_weights_torch
if device.type == 'cuda':
torch.cuda.empty_cache()
# Step 4: Aggregate NMF features
# Check if we need to use CPU (memory was freed in chunked path)
if 'use_cpu_for_nmf' in locals() and use_cpu_for_nmf:
# Use CPU numpy array (already moved above)
if verbose:
print(f" ✓ Aggregating NMF features (CPU)")
nmf_max_sample_idx = np.argmax(attention_weights_nmf_cpu, axis=0)
nmf_mean = np.mean(attention_weights_nmf_cpu, axis=0)
nmf_min = np.min(attention_weights_nmf_cpu, axis=0)
nmf_max = np.max(attention_weights_nmf_cpu, axis=0)
# Compute NMF threshold (sampled for memory efficiency)
if nonzero_threshold is None:
sample_size = min(100000, attention_weights_nmf_cpu.size)
flat_nmf = attention_weights_nmf_cpu.flatten()
sample_indices = np.random.choice(flat_nmf.size, size=min(sample_size, flat_nmf.size), replace=False)
nmf_threshold = np.percentile(flat_nmf[sample_indices], 1.0)
del flat_nmf
else:
nmf_threshold = nonzero_threshold
nmf_nonzero = np.sum(attention_weights_nmf_cpu > nmf_threshold, axis=0)
nmf_pct_nonzero = (nmf_nonzero / n_samples) * 100.0
# Clean up
del attention_weights_nmf_cpu
else:
# Use GPU (original path for smaller datasets)
nmf_max_sample_idx = torch.argmax(attention_weights_nmf_torch, dim=0).cpu().numpy()
nmf_mean = attention_weights_nmf_torch.mean(dim=0).cpu().numpy()
nmf_min = attention_weights_nmf_torch.min(dim=0)[0].cpu().numpy()
nmf_max = attention_weights_nmf_torch.max(dim=0)[0].cpu().numpy()
# Compute NMF threshold
if nonzero_threshold is None:
sample_size = min(100000, n_samples * n_nmf_features)
if sample_size < n_samples * n_nmf_features:
flat_indices = torch.randint(0, n_samples * n_nmf_features, (sample_size,), device=device)
sample_values = attention_weights_nmf_torch.flatten()[flat_indices]
nmf_threshold = torch.quantile(sample_values, 0.01).item()
else:
nmf_threshold = torch.quantile(attention_weights_nmf_torch, 0.01).item()
else:
nmf_threshold = nonzero_threshold
nmf_nonzero = (attention_weights_nmf_torch > nmf_threshold).sum(dim=0).int().cpu().numpy()
nmf_pct_nonzero = (nmf_nonzero / n_samples) * 100.0
# Preserve NMF matrix if return_attention_matrices=True
# Note: gene_attention_matrix was already preserved earlier (in the chunked or full-matrix GPU path above)
if return_attention_matrices:
# Get NMF matrix from torch tensor (it should exist here)
# gene_attention_matrix should already be set from earlier in the function
nmf_attention_matrix = attention_weights_nmf_torch.cpu().numpy()
else:
# Ensure it's None (already initialized to None at function start)
nmf_attention_matrix = None
# Clean up GPU memory
if 'gene_attention_weights_torch' in locals():
del gene_attention_weights_torch
del attention_weights_nmf_torch, nmf_H_torch
if 'cuda' in str(device):
torch.cuda.empty_cache()
else:
# CPU path: use GPU for matrix multiplication if available, CPU for aggregation
# This hybrid approach is much faster than pure CPU
use_gpu_for_matmul = torch.cuda.is_available()  # torch is imported at module level
if isinstance(nmf_H, torch.Tensor):
nmf_H_numpy = nmf_H.detach().cpu().numpy()
nmf_H_torch = nmf_H.to(device) if use_gpu_for_matmul and device is not None else None
else:
nmf_H_numpy = nmf_H
nmf_H_torch = torch.from_numpy(nmf_H).float().to(device) if use_gpu_for_matmul and device is not None else None
# For very large matrices, do everything in chunks
n_samples, n_nmf_features = attention_weights_nmf.shape
n_genes = nmf_H_numpy.shape[1]
if verbose:
if use_gpu_for_matmul:
print(f" ✓ Tracing to genes (GPU matmul, CPU aggregation): {attention_weights_nmf.shape} @ {nmf_H_numpy.shape}")
else:
print(f" ✓ Tracing to genes (CPU, chunked, no full matrix): {attention_weights_nmf.shape} @ {nmf_H_numpy.shape}")
# Initialize aggregators (we'll accumulate as we process chunks)
gene_max_sample_idx = np.zeros(n_genes, dtype=np.int64)
gene_mean = np.zeros(n_genes, dtype=np.float32)
gene_min = np.full(n_genes, np.inf, dtype=np.float32)
gene_max = np.full(n_genes, -np.inf, dtype=np.float32)
gene_nonzero = np.zeros(n_genes, dtype=np.int32)
# If return_attention_matrices=True, allocate full matrix (memory-intensive!)
if return_attention_matrices:
estimated_memory_gb = (n_samples * n_genes * 4) / (1024**3)
if verbose:
print(f" Warning: return_attention_matrices=True requires storing full matrix ({estimated_memory_gb:.1f} GB)")
gene_attention_weights = np.zeros((n_samples, n_genes), dtype=np.float32)
else:
gene_attention_weights = None
# Process in chunks - compute gene attention, normalize, and aggregate all in one pass
# Use larger chunks for better performance
chunk_size = min(100000, n_samples) # Increased from 50k to 100k
all_values_for_percentile = [] # Collect samples for percentile
chunk_iter = tqdm(range(0, n_samples, chunk_size), desc=" Processing chunks", disable=not verbose)
for chunk_start in chunk_iter:
chunk_end = min(chunk_start + chunk_size, n_samples)
chunk_attention = attention_weights_nmf[chunk_start:chunk_end]
# Compute gene attention for this chunk only
# Use GPU for matrix multiplication if available (much faster)
if use_gpu_for_matmul and nmf_H_torch is not None:
chunk_attention_torch = torch.from_numpy(chunk_attention).float().to(device)
chunk_gene_attention_torch = torch.mm(chunk_attention_torch, nmf_H_torch)
chunk_gene_attention = chunk_gene_attention_torch.cpu().numpy()
del chunk_attention_torch, chunk_gene_attention_torch
if device.type == 'cuda':
torch.cuda.empty_cache()
else:
chunk_gene_attention = chunk_attention @ nmf_H_numpy # (chunk_size, n_genes)
# Normalize if requested (in-place on chunk)
if normalize:
row_sums = chunk_gene_attention.sum(axis=1, keepdims=True)
np.clip(row_sums, a_min=1.0, a_max=None, out=row_sums)
np.divide(chunk_gene_attention, row_sums, out=chunk_gene_attention)
# Store chunk in full matrix if requested
if return_attention_matrices:
gene_attention_weights[chunk_start:chunk_end] = chunk_gene_attention
# Update running statistics
chunk_max_idx = np.argmax(chunk_gene_attention, axis=0) + chunk_start
chunk_max = np.max(chunk_gene_attention, axis=0)
update_mask = chunk_max > gene_max
gene_max[update_mask] = chunk_max[update_mask]
gene_max_sample_idx[update_mask] = chunk_max_idx[update_mask]
gene_mean += chunk_gene_attention.sum(axis=0)
gene_min = np.minimum(gene_min, np.min(chunk_gene_attention, axis=0))
# Sample values for percentile (don't store all)
if nonzero_threshold is None:
# Sample 1% of values from this chunk for percentile estimation
sample_size = max(1000, chunk_gene_attention.size // 100)
flat_chunk = chunk_gene_attention.flatten()
sample_indices = np.random.choice(flat_chunk.size, size=min(sample_size, flat_chunk.size), replace=False)
all_values_for_percentile.append(flat_chunk[sample_indices])
# Free chunk immediately (unless we're storing it)
if not return_attention_matrices:
del chunk_gene_attention
gene_mean = gene_mean / n_samples
# Compute threshold from sampled values (much more memory-efficient)
if nonzero_threshold is None:
if all_values_for_percentile:
all_samples = np.concatenate(all_values_for_percentile)
threshold = np.percentile(all_samples, 1.0)
del all_values_for_percentile, all_samples
else:
# Fallback: use a simple heuristic
threshold = gene_mean.min() * 0.1
else:
threshold = nonzero_threshold
# Second pass: count nonzero (recompute chunks, but don't store)
if verbose:
print(f" ✓ Computing nonzero counts (second pass)")
chunk_iter = tqdm(range(0, n_samples, chunk_size), desc=" Counting nonzero", disable=not verbose)
for chunk_start in chunk_iter:
chunk_end = min(chunk_start + chunk_size, n_samples)
chunk_attention = attention_weights_nmf[chunk_start:chunk_end]
# Use GPU for matrix multiplication if available
if use_gpu_for_matmul and nmf_H_torch is not None:
chunk_attention_torch = torch.from_numpy(chunk_attention).float().to(device)
chunk_gene_attention_torch = torch.mm(chunk_attention_torch, nmf_H_torch)
chunk_gene_attention = chunk_gene_attention_torch.cpu().numpy()
del chunk_attention_torch, chunk_gene_attention_torch
if device.type == 'cuda':
torch.cuda.empty_cache()
else:
chunk_gene_attention = chunk_attention @ nmf_H_numpy
# Normalize if needed (same as first pass)
if normalize:
row_sums = chunk_gene_attention.sum(axis=1, keepdims=True)
np.clip(row_sums, a_min=1.0, a_max=None, out=row_sums)
np.divide(chunk_gene_attention, row_sums, out=chunk_gene_attention)
gene_nonzero += np.sum(chunk_gene_attention > threshold, axis=0).astype(np.int32)
del chunk_gene_attention # Free immediately
gene_pct_nonzero = (gene_nonzero / n_samples) * 100.0
if verbose:
print(f" ✓ Traced to genes: ({n_samples}, {n_genes})")
# Step 4: Aggregate NMF features (vectorized NumPy)
nmf_max_sample_idx = np.argmax(attention_weights_nmf, axis=0)
nmf_mean = np.mean(attention_weights_nmf, axis=0)
nmf_min = np.min(attention_weights_nmf, axis=0)
nmf_max = np.max(attention_weights_nmf, axis=0)
# Compute NMF threshold (sampled)
if nonzero_threshold is None:
sample_size = min(100000, attention_weights_nmf.size)
flat_nmf = attention_weights_nmf.flatten()
sample_indices = np.random.choice(flat_nmf.size, size=min(sample_size, flat_nmf.size), replace=False)
nmf_threshold = np.percentile(flat_nmf[sample_indices], 1.0)
del flat_nmf
else:
nmf_threshold = nonzero_threshold
nmf_nonzero = np.sum(attention_weights_nmf > nmf_threshold, axis=0)
nmf_pct_nonzero = (nmf_nonzero / n_samples) * 100.0
# Preserve matrices if return_attention_matrices=True
if return_attention_matrices:
gene_attention_matrix = gene_attention_weights
nmf_attention_matrix = attention_weights_nmf.copy()
else:
gene_attention_matrix = None
nmf_attention_matrix = None
# Free the large matrix as soon as we're done with it
del gene_attention_weights
if verbose:
print(f" ✓ Aggregated statistics")
# Step 5: Build DataFrames efficiently (pre-allocate arrays)
# Convert sample names to array once if needed
sample_names_array = None
if sample_names is not None:
sample_names_array = np.asarray(sample_names)
# Build gene DataFrame (use list of arrays for faster construction)
gene_cols = [
np.arange(n_genes, dtype=np.int32), # feature_index
gene_mean.astype(np.float32), # mean_attention
gene_min.astype(np.float32), # min_attention
gene_max.astype(np.float32), # max_attention
gene_nonzero.astype(np.int32), # n_samples_nonzero
gene_pct_nonzero.astype(np.float32), # pct_samples_nonzero
]
gene_col_names = ['feature_index', 'mean_attention', 'min_attention', 'max_attention',
'n_samples_nonzero', 'pct_samples_nonzero']
if gene_feature_names is not None:
gene_cols.append(np.asarray(gene_feature_names))
gene_col_names.append('feature_name')
if sample_names_array is not None:
gene_cols.append(sample_names_array[gene_max_sample_idx])
gene_col_names.append('max_attention_sample')
# Create DataFrame from dict (faster than column-by-column)
gene_aggregated_df = pd.DataFrame(dict(zip(gene_col_names, gene_cols)))
# Sort by mean attention (descending) - use numpy argsort for speed, then reindex
sort_idx = np.argsort(-gene_mean) # Negative for descending
gene_aggregated_df = gene_aggregated_df.iloc[sort_idx].reset_index(drop=True)
# Build NMF DataFrame (same approach)
nmf_cols = [
np.arange(n_nmf_features, dtype=np.int32), # feature_index
nmf_mean.astype(np.float32), # mean_attention
nmf_min.astype(np.float32), # min_attention
nmf_max.astype(np.float32), # max_attention
nmf_nonzero.astype(np.int32), # n_samples_nonzero
nmf_pct_nonzero.astype(np.float32), # pct_samples_nonzero
]
nmf_col_names = ['feature_index', 'mean_attention', 'min_attention', 'max_attention',
'n_samples_nonzero', 'pct_samples_nonzero']
# Always add feature_name column for NMF features
if nmf_feature_names is not None:
nmf_cols.append(np.asarray(nmf_feature_names))
else:
# Generate default names: factor1, factor2, factor3, ...
nmf_cols.append(np.array([f'factor{i+1}' for i in range(n_nmf_features)]))
nmf_col_names.append('feature_name')
if sample_names_array is not None:
nmf_cols.append(sample_names_array[nmf_max_sample_idx])
nmf_col_names.append('max_attention_sample')
nmf_aggregated_df = pd.DataFrame(dict(zip(nmf_col_names, nmf_cols)))
# Sort by mean attention (descending)
sort_idx = np.argsort(-nmf_mean) # Negative for descending
nmf_aggregated_df = nmf_aggregated_df.iloc[sort_idx].reset_index(drop=True)
# Step 6: Add label column from metadata if available
if metadata is not None and 'obs' in metadata:
try:
obs = metadata['obs']
if 'label' in obs.columns and 'max_attention_sample' in gene_aggregated_df.columns:
# Create a mapping from sample ID to label
# Use obs_id if available, otherwise use index
if 'obs_id' in obs.columns:
sample_to_label = dict(zip(obs['obs_id'], obs['label']))
else:
sample_to_label = dict(zip(obs.index, obs['label']))
# Add label column to gene_aggregated_df
gene_aggregated_df['label'] = gene_aggregated_df['max_attention_sample'].map(sample_to_label)
# Add label column to nmf_aggregated_df if it has max_attention_sample
if 'max_attention_sample' in nmf_aggregated_df.columns:
nmf_aggregated_df['label'] = nmf_aggregated_df['max_attention_sample'].map(sample_to_label)
if verbose:
n_labeled_genes = gene_aggregated_df['label'].notna().sum()
n_labeled_nmf = nmf_aggregated_df['label'].notna().sum() if 'label' in nmf_aggregated_df.columns else 0
print(f" ✓ Added label column: {n_labeled_genes:,} genes, {n_labeled_nmf:,} NMF components")
except Exception as e:
if verbose:
print(f" Warning: Could not add label column from metadata: {e}")
elapsed = time.time() - start_time
if verbose:
print(f"\n✓ Completed! Processed {n_samples:,} samples in {elapsed:.2f}s")
print(f" Gene features aggregated: {len(gene_aggregated_df):,} genes")
print(f" NMF features aggregated: {len(nmf_aggregated_df):,} components")
print(f"\nSummary Statistics:")
print(f" Gene attention - Mean: {gene_mean.mean():.6f}, Max: {gene_mean.max():.6f}")
print(f" NMF attention - Mean: {nmf_mean.mean():.6f}, Max: {nmf_mean.max():.6f}")
print(f" Genes with >0 attention in >1% samples: {(gene_pct_nonzero > 1).sum():,}")
print(f" NMF components with >0 attention in >1% samples: {(nmf_pct_nonzero > 1).sum():,}")
# Save dataframes as parquet files if save_dir is provided
if save_dir is not None:
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)
gene_file = save_path / 'gene_attention_aggregated.parquet'
nmf_file = save_path / 'nmf_attention_aggregated.parquet'
# Check if files exist (edge case: if only one file exists, or if force was changed)
if not force:
if gene_file.exists():
raise FileExistsError(
f"File {gene_file} already exists. Set force=True to overwrite."
)
if nmf_file.exists():
raise FileExistsError(
f"File {nmf_file} already exists. Set force=True to overwrite."
)
# Save as parquet
if verbose:
print(f"\n Saving dataframes to {save_dir}...")
gene_aggregated_df.to_parquet(gene_file, index=False, engine='pyarrow')
nmf_aggregated_df.to_parquet(nmf_file, index=False, engine='pyarrow')
if verbose:
print(f" ✓ Saved: {gene_file}")
print(f" ✓ Saved: {nmf_file}")
# Save attention matrices if return_attention_matrices=True and they exist
if return_attention_matrices and gene_attention_matrix is not None and nmf_attention_matrix is not None:
gene_matrix_file = save_path / 'gene_attention_matrix.npy'
nmf_matrix_file = save_path / 'nmf_attention_matrix.npy'
if verbose:
print(f" Saving attention matrices...")
np.save(gene_matrix_file, gene_attention_matrix)
np.save(nmf_matrix_file, nmf_attention_matrix)
if verbose:
print(f" ✓ Saved: {gene_matrix_file} (shape: {gene_attention_matrix.shape})")
print(f" ✓ Saved: {nmf_matrix_file} (shape: {nmf_attention_matrix.shape})")
# Return with or without matrices based on return_attention_matrices
if return_attention_matrices:
return gene_aggregated_df, nmf_aggregated_df, gene_attention_matrix, nmf_attention_matrix
else:
return gene_aggregated_df, nmf_aggregated_df
def plot_nmf_factor_distributions(
W: Union[np.ndarray, torch.Tensor],
n_factors_to_plot: Optional[int] = None,
figsize: Optional[Tuple[int, int]] = None,
bins: int = 50,
kde: bool = True,
title: Optional[str] = None,
factor_names: Optional[list] = None,
max_cols: int = 4,
max_samples: int = 50000,
sharex: bool = True,
sharey: bool = False,
log_x: bool = False,
log_y: bool = False,
filter_zeros: bool = False,
zero_threshold: float = 1e-10,
return_fig: bool = False,
) -> Optional["plt.Figure"]: # type: ignore
"""
Visualize the distribution of each NMF factor across all samples using a faceted plot.
Creates a grid of subplots where each subplot shows the distribution (histogram + KDE)
of one NMF factor across all samples. This helps identify factors with different
distributions, sparsity patterns, or outliers.
Optimized for large datasets by sampling when n_samples > max_samples.
Parameters
----------
W : np.ndarray or torch.Tensor
NMF factor matrix of shape (n_samples, n_components).
Each column represents one NMF factor, each row represents one sample.
n_factors_to_plot : int, optional
Number of factors to plot. If None, plots all factors.
If specified, plots the first n_factors_to_plot factors.
figsize : tuple, optional
Figure size in inches (width, height). If None, auto-calculated based on grid size.
bins : int, default 50
Number of bins for histogram.
kde : bool, default True
If True, overlay a kernel density estimate (KDE) curve on the histogram.
Automatically disabled for very large datasets (>100k samples) for performance.
title : str, optional
Overall plot title. If None, auto-generated.
factor_names : list, optional
Custom names for factors. If None, uses "Factor 1", "Factor 2", etc.
max_cols : int, default 4
Maximum number of columns in the subplot grid.
max_samples : int, default 50000
Maximum number of samples to use for plotting. If n_samples > max_samples,
randomly samples max_samples for faster computation. Statistics (mean, std, etc.)
        are still computed on the full dataset.
sharex : bool, default True
If True, share x-axis across all subplots. Makes it easier to compare value ranges.
sharey : bool, default False
If True, share y-axis across all subplots. Makes it easier to compare densities.
log_x : bool, default False
If True, use logarithmic scale for x-axis. Useful for visualizing sparse distributions
        with long tails. Zeros are clipped to a small positive floor (at least zero_threshold) before taking logs.
log_y : bool, default False
If True, use logarithmic scale for y-axis (density). Useful when density values
span multiple orders of magnitude.
filter_zeros : bool, default False
If True, exclude zero values from the distribution plot. Useful for focusing on
        the non-zero tail of sparse distributions. Statistics are still computed on the full dataset.
zero_threshold : float, default 1e-10
Values below this threshold are considered "zero" when filter_zeros=True.
Useful for filtering out numerical noise.
return_fig : bool, default False
If True, return the matplotlib figure for further customization.
If False, display the plot and return None.
Returns
-------
fig : matplotlib.figure.Figure or None
The matplotlib figure if return_fig=True, otherwise None.
Examples
--------
>>> from AoU.phenome.sparseNMF import train_sparse_nmf, plot_nmf_factor_distributions
>>>
>>> # Train NMF
>>> X_nmf, nmf_model = train_sparse_nmf(X_sparse, n_components=256)
>>>
>>> # Visualize first 12 factors (fast, even for large datasets)
>>> plot_nmf_factor_distributions(X_nmf, n_factors_to_plot=12)
>>>
>>> # Customize with factor names
>>> factor_names = [f"Component {i+1}" for i in range(12)]
>>> fig = plot_nmf_factor_distributions(
... X_nmf,
... n_factors_to_plot=12,
... factor_names=factor_names,
... kde=True,
... return_fig=True
... )
>>> fig.savefig("nmf_distributions.png")
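
    Focus on the non-zero tail of very sparse factors with a log-scale x-axis
    (an illustrative call, reusing X_nmf from the example above):
    >>> plot_nmf_factor_distributions(
    ...     X_nmf,
    ...     n_factors_to_plot=12,
    ...     log_x=True,
    ...     filter_zeros=True
    ... )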
"""
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError(
"plot_nmf_factor_distributions requires matplotlib. "
"Install with: pip install matplotlib"
)
# Convert to numpy if needed
if isinstance(W, torch.Tensor):
W_np = W.cpu().numpy()
else:
W_np = np.asarray(W)
if W_np.ndim != 2:
raise ValueError(f"W must be 2D array (n_samples, n_components), got shape {W_np.shape}")
# W_np shape: (n_samples, n_factors)
# Each row is a sample, each column is a factor
# We plot the distribution of each factor (column) across all samples (rows)
n_samples, n_components = W_np.shape
# Sample data if too large (for plotting only, stats use full data)
use_sampling = n_samples > max_samples
if use_sampling:
# Random sample for plotting
sample_indices = np.random.choice(n_samples, size=max_samples, replace=False)
W_plot = W_np[sample_indices]
n_plot_samples = max_samples
else:
W_plot = W_np
n_plot_samples = n_samples
# Determine how many factors to plot
if n_factors_to_plot is None:
n_factors_to_plot = n_components
else:
n_factors_to_plot = min(n_factors_to_plot, n_components)
# Calculate grid dimensions
n_cols = min(max_cols, n_factors_to_plot)
n_rows = int(np.ceil(n_factors_to_plot / n_cols))
# Auto-calculate figsize if not provided
if figsize is None:
width = n_cols * 3.5
height = n_rows * 2.5
figsize = (width, height)
# Pre-compute statistics for all factors (vectorized, uses full dataset)
# axis=0 means across samples (rows), giving one stat per factor (column)
factor_stats = {
'mean': W_np.mean(axis=0), # Shape: (n_components,)
'std': W_np.std(axis=0), # Shape: (n_components,)
'median': np.median(W_np, axis=0), # Shape: (n_components,)
}
# Disable KDE for very large datasets (too slow)
if kde and n_samples > 100000:
kde = False
# Create figure and axes
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False,
sharex=sharex, sharey=sharey)
axes = axes.flatten()
# Generate factor names if not provided
if factor_names is None:
factor_names = [f"Factor {i+1}" for i in range(n_components)]
elif len(factor_names) < n_components:
# Extend with default names if not enough provided
factor_names = list(factor_names) + [f"Factor {i+1}" for i in range(len(factor_names), n_components)]
# Pre-compute ranges - use filtered data if filter_zeros is True
if filter_zeros:
# Compute ranges on non-zero values only
factor_ranges = []
for i in range(n_factors_to_plot):
factor_values = W_plot[:, i]
factor_values_nonzero = factor_values[factor_values > zero_threshold]
if len(factor_values_nonzero) > 0:
factor_ranges.append((factor_values_nonzero.min(), factor_values_nonzero.max()))
else:
# All zeros, use full range as fallback
factor_ranges.append((W_np[:, i].min(), W_np[:, i].max()))
else:
# Use full data range
factor_ranges = [(W_np[:, i].min(), W_np[:, i].max()) for i in range(n_factors_to_plot)]
# Plot each factor (vectorized where possible)
for i in range(n_factors_to_plot):
ax = axes[i]
# W_plot[:, i] gets all samples (rows) for factor i (column i)
# This gives us the distribution of factor i across all samples
factor_values_plot = W_plot[:, i] # Sampled data for plotting
factor_values_full = W_np[:, i] # Full data for statistics
# Filter zeros if requested (for plotting only, stats use full data)
if filter_zeros:
factor_values_plot = factor_values_plot[factor_values_plot > zero_threshold]
if len(factor_values_plot) == 0:
# All zeros, skip plotting
ax.text(0.5, 0.5, 'All zeros', transform=ax.transAxes,
ha='center', va='center', fontsize=12)
ax.set_title(factor_names[i], fontsize=10, fontweight='bold')
continue
        # Handle log scale for x-axis
        if log_x:
            # Clip values to a small positive floor so zeros do not become -inf
            # under log10 (the "small epsilon" promised in the docstring). When
            # filter_zeros=True the floor comes from the non-zero range instead.
            x_floor = max(factor_ranges[i][0], zero_threshold)
            factor_values_plot_log = np.log10(np.clip(factor_values_plot, x_floor, None))
# Compute histogram on log scale
counts, bin_edges = np.histogram(factor_values_plot_log, bins=bins, density=True)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
# Convert back to original scale for display
bin_centers_display = 10**bin_centers
bin_widths = 10**bin_edges[1:] - 10**bin_edges[:-1]
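            # On a log axis the bins are equal-width in log space, so their widths
            # differ in the original scale; the per-bin widths computed above keep
            # the bar plot consistent with ax.set_xscale('log') below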
else:
# Compute histogram (fast, uses sampled data)
# If filtering zeros, use the filtered range for better visualization
if filter_zeros and len(factor_values_plot) > 0:
# Use filtered range for better visualization
counts, bin_edges = np.histogram(
factor_values_plot,
bins=bins,
density=True,
range=(factor_ranges[i][0], factor_ranges[i][1])
)
else:
counts, bin_edges = np.histogram(factor_values_plot, bins=bins, density=True)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
bin_centers_display = bin_centers
bin_widths = bin_edges[1] - bin_edges[0]
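            # On a linear axis every bin has the same width, so a scalar is enough here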
# Plot histogram using bar plot (faster than ax.hist)
ax.bar(bin_centers_display, counts, width=bin_widths,
alpha=0.6, color='steelblue', edgecolor='black', linewidth=0.5)
# Set log scale if requested
if log_x:
ax.set_xscale('log')
if log_y:
ax.set_yscale('log')
# Add KDE if requested (works with filtered data too)
if kde and len(factor_values_plot) > 1: # Need at least 2 points for KDE
try:
from scipy import stats
# Use smaller sample for KDE if dataset is very large
kde_sample_size = min(10000, len(factor_values_plot))
if len(factor_values_plot) > kde_sample_size:
kde_indices = np.random.choice(len(factor_values_plot), size=kde_sample_size, replace=False)
kde_values = factor_values_plot[kde_indices]
else:
kde_values = factor_values_plot
# Compute KDE
kde_obj = stats.gaussian_kde(kde_values)
if log_x:
# For log scale, create range in log space then convert
x_min = max(factor_ranges[i][0], zero_threshold if filter_zeros else 1e-10)
x_max = factor_ranges[i][1]
x_range_log = np.linspace(np.log10(x_min), np.log10(x_max), 200)
x_range = 10**x_range_log
else:
x_range = np.linspace(factor_ranges[i][0], factor_ranges[i][1], 200)
kde_values_plot = kde_obj(x_range)
ax.plot(x_range, kde_values_plot, 'r-', linewidth=2, label='KDE', alpha=0.8)
except (ImportError, np.linalg.LinAlgError):
# scipy not available or singular matrix, skip KDE
pass
# Get pre-computed statistics (from full dataset)
mean_val = factor_stats['mean'][i]
std_val = factor_stats['std'][i]
median_val = factor_stats['median'][i]
# Add vertical lines for mean and median
ax.axvline(mean_val, color='green', linestyle='--', linewidth=1.5, alpha=0.7, label=f'Mean: {mean_val:.3f}')
ax.axvline(median_val, color='orange', linestyle='--', linewidth=1.5, alpha=0.7, label=f'Median: {median_val:.3f}')
# Set title and labels
ax.set_title(factor_names[i], fontsize=10, fontweight='bold')
ax.set_xlabel('Value', fontsize=9)
ax.set_ylabel('Density', fontsize=9)
ax.grid(True, alpha=0.3, linestyle=':', linewidth=0.5)
# Simplified legend (faster rendering)
        ax.legend(fontsize=7, loc='best', framealpha=0.8)
# Add text box with statistics (computed from full dataset)
stats_text = f'n={n_samples:,}'
if use_sampling:
stats_text += f' (plotted: {n_plot_samples:,})'
if filter_zeros:
n_nonzero = (factor_values_full > zero_threshold).sum()
pct_nonzero = 100 * n_nonzero / n_samples
stats_text += f'\nnon-zero: {n_nonzero:,} ({pct_nonzero:.1f}%)'
stats_text += f'\nμ={mean_val:.3f}\nσ={std_val:.3f}\nmed={median_val:.3f}'
ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
fontsize=8, verticalalignment='top',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
# Hide unused subplots
for i in range(n_factors_to_plot, len(axes)):
axes[i].axis('off')
# Set overall title
if title is None:
title = f"NMF Factor Distributions (showing {n_factors_to_plot} of {n_components} factors)"
if use_sampling:
title += f" [sampled {n_plot_samples:,} of {n_samples:,} samples]"
fig.suptitle(title, fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout(rect=[0, 0, 1, 0.98]) # Leave room for suptitle
if return_fig:
return fig
else:
plt.show()
return None
def compute_attention_correlation(
gene_attention_matrix: np.ndarray,
X: Union[np.ndarray, spmatrix],
obs_mask: Optional[np.ndarray] = None,
stratify_by_unique_values: bool = True,
verbose: bool = True,
) -> pd.DataFrame:
"""
Compute correlations between gene attention matrix and original data matrix.
Computes Pearson and Spearman correlations between gene attention weights and
the original gene expression/association matrix. Can stratify samples by the
number of unique values per sample (useful for binary vs. continuous data).
Parameters
----------
gene_attention_matrix : np.ndarray
Gene attention matrix, shape (n_samples, n_genes).
Contains continuous attention scores for each sample-gene pair.
X : np.ndarray or scipy.sparse matrix
Original data matrix, shape (n_samples, n_genes).
Can be sparse (scipy.sparse) or dense (numpy array).
obs_mask : np.ndarray, optional
Boolean mask to select subset of samples. If provided, X[obs_mask] is used.
Should match the samples in gene_attention_matrix.
stratify_by_unique_values : bool, default True
If True, stratify samples by number of unique values per sample:
- "2_unique": Binary data (2 unique values)
- "3_unique": Ternary data (3 unique values)
- "4+_unique": Continuous data (4+ unique values)
If False, only computes global correlations.
verbose : bool, default True
If True, print summary information.
Returns
-------
correlation_results_df : pd.DataFrame
DataFrame with columns:
- stratum: "2_unique", "3_unique", "4+_unique", or "all"
- subset: "all" (all values) or "nonzero" (only nonzero values in X)
        - n_samples: Number of samples when subset="all", or number of nonzero values when subset="nonzero"
- pearson: Pearson correlation coefficient
- spearman: Spearman rank correlation coefficient
- spearman_p: Spearman correlation p-value
Examples
--------
>>> from AoU.phenome.sparseNMF import compute_attention_correlation
>>> import numpy as np
>>>
>>> # Compute correlations
>>> results_df = compute_attention_correlation(
... gene_attention_matrix=gene_attention_matrix,
... X=X,
... obs_mask=obs_mask,
... stratify_by_unique_values=True
... )
>>>
>>> # Display results
>>> print(results_df)
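
    To pull out a single value, e.g. the Spearman correlation restricted to the
    non-zero entries of X (illustrative; assumes the call above succeeded):
    >>> row = results_df[(results_df['stratum'] == 'all') & (results_df['subset'] == 'nonzero')]
    >>> print(row['spearman'].values[0])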
"""
import pandas as pd
from scipy.stats import spearmanr
    # Densify X (applying obs_mask first if provided); sparse inputs are converted with toarray()
if obs_mask is not None:
if hasattr(X, "toarray"):
X_dense = X[obs_mask].toarray()
else:
X_dense = X[obs_mask]
else:
if hasattr(X, "toarray"):
X_dense = X.toarray()
else:
X_dense = X
# Verify shapes match
if gene_attention_matrix.shape != X_dense.shape:
raise ValueError(
f"Shape mismatch: gene_attention_matrix {gene_attention_matrix.shape} "
f"vs X {X_dense.shape}. Ensure obs_mask matches if provided."
)
n_samples, n_genes = X_dense.shape
flat_gene_attention = gene_attention_matrix.flatten()
flat_x = X_dense.flatten()
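    # The flattened vectors pair each (sample, gene) attention weight with the
    # corresponding entry of X; they feed the global correlations computed below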
correlation_results = []
if stratify_by_unique_values:
# Compute unique counts per sample (row)
num_unique_per_sample = np.apply_along_axis(lambda r: len(np.unique(r)), 1, X_dense)
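        # e.g. a row containing only 0s and 1s (with both present) has 2 unique
        # values and therefore lands in the "2_unique" stratum below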
strata = {
"2_unique": np.where(num_unique_per_sample == 2)[0],
"3_unique": np.where(num_unique_per_sample == 3)[0],
"4+_unique": np.where(num_unique_per_sample >= 4)[0],
}
if verbose:
print(f"Stratifying samples by unique values:")
for label, idx in strata.items():
print(f" {label}: {len(idx):,} samples")
# Compute correlations for each stratum
for label, idx in strata.items():
if len(idx) == 0:
# No samples in this group
correlation_results.append({
"stratum": label,
"subset": "all",
"n_samples": 0,
"pearson": np.nan,
"spearman": np.nan,
"spearman_p": np.nan,
})
correlation_results.append({
"stratum": label,
"subset": "nonzero",
"n_samples": 0,
"pearson": np.nan,
"spearman": np.nan,
"spearman_p": np.nan,
})
continue
# Get rows (samples) in this stratum
orig_sub = X_dense[idx, :]
attn_sub = gene_attention_matrix[idx, :]
# -- Correlation for all values --
flat_orig = orig_sub.flatten()
flat_attn = attn_sub.flatten()
try:
pearson_corr_all = np.corrcoef(flat_attn, flat_orig)[0, 1]
except Exception:
pearson_corr_all = np.nan
try:
spearman_corr_all, spearman_p_all = spearmanr(flat_attn, flat_orig)
except Exception:
spearman_corr_all, spearman_p_all = np.nan, np.nan
correlation_results.append({
"stratum": label,
"subset": "all",
"n_samples": len(idx),
"pearson": pearson_corr_all,
"spearman": spearman_corr_all,
"spearman_p": spearman_p_all,
})
# -- Correlation for ONLY values where X is nonzero --
nonzero_mask = flat_orig != 0
if np.any(nonzero_mask):
flat_orig_nz = flat_orig[nonzero_mask]
flat_attn_nz = flat_attn[nonzero_mask]
try:
pearson_corr_nz = np.corrcoef(flat_attn_nz, flat_orig_nz)[0, 1]
except Exception:
pearson_corr_nz = np.nan
try:
spearman_corr_nz, spearman_p_nz = spearmanr(flat_attn_nz, flat_orig_nz)
except Exception:
spearman_corr_nz, spearman_p_nz = np.nan, np.nan
else:
pearson_corr_nz, spearman_corr_nz, spearman_p_nz = np.nan, np.nan, np.nan
correlation_results.append({
"stratum": label,
"subset": "nonzero",
"n_samples": int(np.sum(nonzero_mask)),
"pearson": pearson_corr_nz,
"spearman": spearman_corr_nz,
"spearman_p": spearman_p_nz,
})
# Global correlation over all samples
# -- ALL values --
try:
pearson_corr = np.corrcoef(flat_gene_attention, flat_x)[0, 1]
except Exception:
pearson_corr = np.nan
try:
spearman_corr, spearman_p = spearmanr(flat_gene_attention, flat_x)
except Exception:
spearman_corr, spearman_p = np.nan, np.nan
correlation_results.append({
"stratum": "all",
"subset": "all",
"n_samples": n_samples,
"pearson": pearson_corr,
"spearman": spearman_corr,
"spearman_p": spearman_p,
})
# -- Only nonzero values in X --
nonzero_mask_all = flat_x != 0
if np.any(nonzero_mask_all):
flat_attn_nz_all = flat_gene_attention[nonzero_mask_all]
flat_x_nz_all = flat_x[nonzero_mask_all]
try:
pearson_corr_nz_all = np.corrcoef(flat_attn_nz_all, flat_x_nz_all)[0, 1]
except Exception:
pearson_corr_nz_all = np.nan
try:
spearman_corr_nz_all, spearman_p_nz_all = spearmanr(flat_attn_nz_all, flat_x_nz_all)
except Exception:
spearman_corr_nz_all, spearman_p_nz_all = np.nan, np.nan
else:
pearson_corr_nz_all, spearman_corr_nz_all, spearman_p_nz_all = np.nan, np.nan, np.nan
correlation_results.append({
"stratum": "all",
"subset": "nonzero",
"n_samples": int(np.sum(nonzero_mask_all)),
"pearson": pearson_corr_nz_all,
"spearman": spearman_corr_nz_all,
"spearman_p": spearman_p_nz_all,
})
# Convert to DataFrame
correlation_results_df = pd.DataFrame(correlation_results)
if verbose:
print(f"\nCorrelation Analysis Summary:")
print(f" Total samples: {n_samples:,}")
print(f" Total genes: {n_genes:,}")
print(f" Global correlation (all values):")
print(f" Pearson: {correlation_results_df.loc[(correlation_results_df['stratum'] == 'all') & (correlation_results_df['subset'] == 'all'), 'pearson'].values[0]:.4f}")
print(f" Spearman: {correlation_results_df.loc[(correlation_results_df['stratum'] == 'all') & (correlation_results_df['subset'] == 'all'), 'spearman'].values[0]:.4f}")
return correlation_results_df