---
name: energy-based-autoregressive-neural-dynamics
description: "Energy-based Autoregressive Generation (EAG) framework for neural population dynamics using energy-based transformer in latent space with strictly proper scoring rules. Activation triggers: energy-based model, neural population dynamics, autoregressive generation, brain modeling, transformer dynamics."
---
# Energy-based Autoregressive Generation for Neural Population Dynamics
> EAG is a framework that uses an energy-based transformer to learn temporal dynamics in latent space, trained with strictly proper scoring rules, for efficient generation of neural population activity.
## Metadata
- **Source**: arXiv:2511.17606 [cs.LG]
- **Authors**: Ningling Ge, Sicheng Dai, Yu Zhu, Shan Yu
- **Published**: 2025-11-18
- **Code**: Available (see paper for the repository URL)
## Core Methodology
### Key Innovation
Modeling neural population dynamics faces a fundamental trade-off between computational efficiency and generation fidelity. EAG addresses this by combining:
- **Energy-based modeling** for capturing complex distributions
- **Autoregressive generation** for temporal coherence
- **Strictly proper scoring rules** for efficient training without adversarial objectives
- **Transformer architecture** in latent space for long-range dependencies
### Technical Framework
#### 1. Energy-Based Model Foundation
- **Energy Function**: $E_\theta(x)$ assigns lower energy to realistic data
- **Boltzmann Distribution**: $p_\theta(x) \propto \exp(-E_\theta(x))$
- **Advantage**: Can model complex, multi-modal distributions without explicit normalization
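As a minimal illustration (my own sketch, not the paper's architecture), an energy network defines an unnormalized log-density over observations; only energy differences are meaningful because the normalizing constant is never computed:
```python
import torch
import torch.nn as nn

# Toy MLP energy function E_theta(x): lower energy = more plausible x.
# The Boltzmann density p(x) ∝ exp(-E(x)) is unnormalized, so the model
# only ranks observations by relative plausibility.
energy_net = nn.Sequential(nn.Linear(8, 64), nn.GELU(), nn.Linear(64, 1))

x = torch.randn(4, 8)                  # a batch of hypothetical observations
energy = energy_net(x).squeeze(-1)     # [4] energy values
unnorm_log_prob = -energy              # log p(x) up to an additive constant
```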
#### 2. Autoregressive Temporal Dynamics
- **Factorization**: $p(x_{1:T}) = \prod_{t=1}^T p(x_t | x_{<t})$
- **Causal Masking**: Ensures temporal causality in predictions
- **Recurrent State**: Maintains history information efficiently
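A short sketch of the causal mask that enforces the factorization above; this is the same kind of mask built later in `LatentTransformer`:
```python
import torch

# Additive attention mask for p(x_t | x_{<t}): position t may attend only to
# positions <= t. The -inf entries above the diagonal block attention to
# future timesteps when added to the attention logits.
seq_len = 5
causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
print(causal_mask)  # zeros on/below the diagonal, -inf strictly above
```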
#### 3. Strictly Proper Scoring Rules (SPSR)
- **Training Objective**: Minimize expected scoring rule loss
- **Proper Scoring**: True distribution minimizes expected score
- **Strictly Proper**: Unique minimum at true distribution
- **Examples**: Energy score, Kernel score, Continuous Ranked Probability Score
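To make the idea concrete, here is a minimal sample-based energy score (an illustrative implementation, not the paper's): it rewards samples close to the observation while penalizing collapsed, low-diversity sample sets.
```python
import torch

def sample_energy_score(samples: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Monte Carlo energy score.

    samples: [num_samples, dim] draws from the predictive distribution
    target:  [dim] observed value
    Score = E||X - y|| - 0.5 * E||X - X'|| (one half of the 2E||X-y|| - E||X-X'||
    convention used later in this document). Lower is better, and in expectation
    the score is minimized only by the true distribution (strictly proper).
    """
    accuracy = torch.norm(samples - target, dim=-1).mean()
    pairwise = torch.cdist(samples, samples)      # [S, S] pairwise distances
    diversity = 0.5 * pairwise.mean()
    return accuracy - diversity

samples = torch.randn(100, 8)   # hypothetical predictive samples
target = torch.randn(8)
print(sample_energy_score(samples, target))
```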
#### 4. Latent Space Transformer
- **Encoder**: Maps observations to latent representations
- **Transformer**: Models dynamics in latent space
- **Decoder**: Maps predictions back to observation space
- **Advantage**: Lower dimensionality, smoother dynamics
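The encode → model dynamics → decode pipeline can be sketched in a few lines. This is a toy shape check only; a GRU stands in for the transformer purely to keep the sketch short, and the dimensions are placeholders:
```python
import torch
import torch.nn as nn

# Illustrative shapes for the latent-space pipeline: high-dimensional population
# activity is compressed to a lower-dimensional latent, dynamics are modeled
# there, then decoded back to observation space.
obs_dim, latent_dim, seq_len, batch = 128, 16, 50, 4

encoder = nn.Linear(obs_dim, latent_dim)
dynamics = nn.GRU(latent_dim, latent_dim, batch_first=True)  # stand-in for the transformer
decoder = nn.Linear(latent_dim, obs_dim)

x = torch.randn(batch, seq_len, obs_dim)   # observed spike counts / rates
z = encoder(x)                              # [batch, seq_len, latent_dim]
z_pred, _ = dynamics(z)                     # latent dynamics
x_hat = decoder(z_pred)                     # back to observation space
```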
## Implementation Guide
### Prerequisites
- Python 3.8+
- PyTorch 2.0+
- NumPy, SciPy
- Optional: einops for tensor manipulation
### Step-by-Step Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
import math


class EnergyBasedAutoregressiveModel(nn.Module):
    """
    EAG: Energy-based Autoregressive Generation for Neural Population Dynamics
    """
    def __init__(
        self,
        obs_dim: int,
        latent_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 4,
        num_heads: int = 8,
        dropout: float = 0.1,
        scoring_rule: str = 'energy'
    ):
        super().__init__()
        self.obs_dim = obs_dim
        self.latent_dim = latent_dim
        self.scoring_rule = scoring_rule

        # Encoder: observation -> latent
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, latent_dim)
        )
        # Latent dynamics transformer
        self.latent_transformer = LatentTransformer(
            latent_dim, hidden_dim, num_layers, num_heads, dropout
        )
        # Energy network for conditional distribution
        self.energy_net = EnergyNetwork(latent_dim, hidden_dim)
        # Decoder: latent -> observation
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, obs_dim)
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode observation to latent space"""
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent to observation"""
        return self.decoder(z)

    def compute_energy(self, z_curr: torch.Tensor, z_hist: torch.Tensor) -> torch.Tensor:
        """
        Compute energy for current latent given history
        Args:
            z_curr: [batch, latent_dim] current latent state
            z_hist: [batch, seq_len, latent_dim] history
        Returns:
            energy: [batch] energy values
        """
        return self.energy_net(z_curr, z_hist)

    def forward(
        self,
        observations: torch.Tensor,
        num_samples: int = 1
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with autoregressive generation
        Args:
            observations: [batch, seq_len, obs_dim] observed sequence
            num_samples: number of samples for prediction
        Returns:
            predictions: [batch, seq_len, obs_dim]
            energy: [batch, seq_len] energy values
        """
        batch_size, seq_len, _ = observations.shape

        # Encode to latent
        z = self.encode(observations)  # [batch, seq_len, latent_dim]

        energies = []
        predictions = []
        # Autoregressive generation
        for t in range(seq_len):
            z_hist = z[:, :t] if t > 0 else None
            z_curr = z[:, t]
            # Compute energy
            energy_t = self.compute_energy(z_curr, z_hist)
            energies.append(energy_t)
            # Sample and decode (simplified - in practice use Langevin or similar)
            pred = self.decode(z_curr)
            predictions.append(pred)

        predictions = torch.stack(predictions, dim=1)
        energies = torch.stack(energies, dim=1)
        return predictions, energies

class LatentTransformer(nn.Module):
    """
    Transformer for modeling latent dynamics
    """
    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float
    ):
        super().__init__()
        self.embedding = nn.Linear(latent_dim, hidden_dim)
        # Causal transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        # Custom causal mask is applied in forward()
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            norm=nn.LayerNorm(hidden_dim)
        )
        self.output_proj = nn.Linear(hidden_dim, latent_dim)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: [batch, seq_len, latent_dim]
        Returns:
            [batch, seq_len, latent_dim]
        """
        # Create causal mask
        seq_len = z.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=z.device) * float('-inf'),
            diagonal=1
        )
        # Embed and transform
        h = self.embedding(z)
        h = self.transformer(h, mask=causal_mask)
        return self.output_proj(h)

class EnergyNetwork(nn.Module):
    """
    Energy network for conditional distribution
    """
    def __init__(self, latent_dim: int, hidden_dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(latent_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(
        self,
        z_curr: torch.Tensor,
        z_hist: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute energy for current latent given history
        Args:
            z_curr: [batch, latent_dim]
            z_hist: [batch, hist_len, latent_dim] or None
        Returns:
            energy: [batch]
        """
        if z_hist is None or z_hist.size(1) == 0:
            # No history - use zero history
            hist_repr = torch.zeros_like(z_curr)
        else:
            # Aggregate history (mean pooling)
            hist_repr = z_hist.mean(dim=1)
        # Concatenate and compute energy
        combined = torch.cat([z_curr, hist_repr], dim=-1)
        return self.mlp(combined).squeeze(-1)

# Strictly Proper Scoring Rules
def energy_score(pred: torch.Tensor, target: torch.Tensor, num_samples: int = 100) -> torch.Tensor:
    """
    Energy scoring rule for energy-based models
    E_S(p, y) = 2 * E_{X~p}[||X - y||] - E_{X,X'~p}[||X - X'||]
    """
    # Simplified implementation: only the accuracy term is used here. In
    # practice, draw samples from the energy model and subtract the pairwise
    # diversity term E[||X - X'||] as in the formula above.
    diff_pred_target = torch.norm(pred - target, dim=-1)
    return 2 * diff_pred_target.mean()


def kernel_score(pred: torch.Tensor, target: torch.Tensor, kernel='rbf', sigma=1.0) -> torch.Tensor:
    """
    Kernel scoring rule using an RBF kernel (simplified)
    """
    def rbf_kernel(x, y, sigma):
        dist = torch.sum((x.unsqueeze(1) - y.unsqueeze(0)) ** 2, dim=-1)
        return torch.exp(-dist / (2 * sigma ** 2))

    if kernel == 'rbf':
        k_pred_target = rbf_kernel(pred, target, sigma)
        return -k_pred_target.mean()
    return energy_score(pred, target)
```
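A quick smoke test of the sketch above with random data (dimensions are arbitrary placeholders, not values from the paper; this assumes the classes defined above are in scope):
```python
# Instantiate the model sketch with toy dimensions and run one forward pass.
model = EnergyBasedAutoregressiveModel(obs_dim=96, latent_dim=16)
neural_activity = torch.randn(8, 40, 96)     # [batch, seq_len, obs_dim]
predictions, energies = model(neural_activity)
print(predictions.shape, energies.shape)      # torch.Size([8, 40, 96]) torch.Size([8, 40])
```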
### Training with Strictly Proper Scoring Rules
```python
class EAGTrainer:
    """
    Trainer for Energy-based Autoregressive Generation
    """
    def __init__(self, model: EnergyBasedAutoregressiveModel, lr: float = 1e-4):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    def train_step(self, batch: torch.Tensor) -> dict:
        """
        Single training step
        Args:
            batch: [batch, seq_len, obs_dim] neural population activity
        Returns:
            Dictionary of losses
        """
        batch_size, seq_len, obs_dim = batch.shape

        # Forward pass
        predictions, energies = self.model(batch)

        # Compute scoring rule loss
        if self.model.scoring_rule == 'energy':
            loss = energy_score(predictions, batch)
        elif self.model.scoring_rule == 'kernel':
            loss = kernel_score(predictions, batch)
        else:
            # MSE as fallback
            loss = F.mse_loss(predictions, batch)

        # Add energy regularization (encourage lower energy for data)
        energy_reg = energies.mean()
        total_loss = loss + 0.01 * energy_reg

        # Backward
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return {
            'loss': loss.item(),
            'energy_reg': energy_reg.item(),
            'total_loss': total_loss.item()
        }
    def generate(
        self,
        initial_obs: torch.Tensor,
        num_steps: int,
        temperature: float = 1.0
    ) -> torch.Tensor:
        """
        Generate future neural activity
        Args:
            initial_obs: [batch, init_len, obs_dim] initial observations
            num_steps: number of steps to generate
            temperature: sampling temperature
        Returns:
            generated: [batch, num_steps, obs_dim]
        """
        self.model.eval()
        generated = []
        with torch.no_grad():
            # Encode initial observations
            z_hist = self.model.encode(initial_obs)  # [batch, init_len, latent]
            for _ in range(num_steps):
                # Sample next latent using Langevin dynamics (simplified)
                z_next = self.sample_langevin(z_hist, temperature)
                # Decode
                obs_next = self.model.decode(z_next)
                generated.append(obs_next)
                # Update history
                z_hist = torch.cat([z_hist, z_next.unsqueeze(1)], dim=1)
        return torch.stack(generated, dim=1)

    def sample_langevin(
        self,
        z_hist: torch.Tensor,
        temperature: float = 1.0,
        num_steps: int = 20,
        step_size: float = 0.01
    ) -> torch.Tensor:
        """
        Langevin dynamics sampling from the energy-based model
        Args:
            z_hist: [batch, hist_len, latent_dim] history
            temperature: sampling temperature
            num_steps: Langevin steps
            step_size: step size for Langevin
        Returns:
            z_sample: [batch, latent_dim] sampled latent
        """
        batch_size = z_hist.size(0)
        latent_dim = z_hist.size(-1)

        # Initialize from Gaussian
        z = torch.randn(batch_size, latent_dim, device=z_hist.device) * temperature
        for _ in range(num_steps):
            # Gradients are required here even when the caller runs under no_grad
            with torch.enable_grad():
                z = z.detach().requires_grad_(True)
                energy = self.model.compute_energy(z, z_hist)
                # Gradient of the energy w.r.t. the current sample
                grad = torch.autograd.grad(energy.sum(), z)[0]
            # Langevin update: descend the energy plus injected Gaussian noise
            z = z.detach() - step_size * grad + torch.randn_like(z) * math.sqrt(2 * step_size)
        return z
```
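An illustrative end-to-end training loop using the trainer sketch, with synthetic data; batch size, dimensions, and epoch count are placeholders, and the classes from the blocks above are assumed to be in scope:
```python
from torch.utils.data import DataLoader, TensorDataset

# Synthetic neural activity: [num_trials, seq_len, obs_dim]
model = EnergyBasedAutoregressiveModel(obs_dim=96, latent_dim=16, scoring_rule='energy')
trainer = EAGTrainer(model, lr=1e-4)

data = torch.randn(256, 40, 96)
loader = DataLoader(TensorDataset(data), batch_size=32)

for epoch in range(5):
    for (batch,) in loader:
        stats = trainer.train_step(batch)
    print(f"epoch {epoch}: total_loss={stats['total_loss']:.4f}")

# Roll the model forward from a short observed prefix
prefix = data[:4, :10]                           # [batch, init_len, obs_dim]
future = trainer.generate(prefix, num_steps=20)  # [4, 20, 96]
```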
## Applications
### 1. Neural Population Forecasting
- **Spontaneous Activity**: Predict future firing patterns
- **Multi-session Data**: Generalize across recording sessions
- **Cross-subject**: Transfer learned dynamics between animals
### 2. Brain-Computer Interfaces
- **Motor BCIs**: Predict intended movements from neural activity
- **Closed-loop Control**: Real-time neural state estimation
- **Error Correction**: Detect and correct decoding errors
### 3. Scientific Discovery
- **Circuit Mechanisms**: Understand what drives neural dynamics
- **Intervention Effects**: Predict impact of perturbations
- **Model Comparison**: Evaluate different mechanistic hypotheses
### 4. Synthetic Data Generation
- **Data Augmentation**: Generate realistic training data
- **Privacy Preservation**: Share synthetic instead of real data
- **Rare Event Sampling**: Oversample underrepresented conditions
## Pitfalls
1. **Training Instability**: Energy-based models can be unstable
- *Mitigation*: Use proper scoring rules, add regularization, and apply gradient clipping (see the sketch after this list)
2. **Computational Cost**: Langevin sampling is expensive at inference
- *Mitigation*: Use fewer steps, amortized inference, or distill to autoregressive model
3. **Mode Collapse**: Can fail to capture all data modes
- *Mitigation*: Annealed Langevin, multiple chains, diversity penalties
4. **Long-term Prediction**: Error accumulation in autoregressive generation
- *Mitigation*: Scheduled sampling, teacher forcing curriculum, latent consistency losses
5. **Scaling Challenges**: Difficult for very large populations
- *Mitigation*: Factorized latent spaces, hierarchical models, sparse interactions
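For the first pitfall, gradient clipping is a one-line addition to the training step. A hedged sketch (the clipping threshold is an arbitrary choice, not a value from the paper):
```python
import torch

def clipped_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                 loss: torch.Tensor, max_norm: float = 1.0) -> float:
    """Backward pass + global gradient-norm clipping + optimizer update.

    Clipping the total gradient norm damps the update spikes that
    energy-based training is prone to; returns the pre-clip norm so it
    can be logged as a stability diagnostic.
    """
    optimizer.zero_grad()
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return float(grad_norm)
```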
## Related Skills
- neural-population-decoding: Decoding methods for neural populations
- autoregressive-flow-matching-neural-dynamics: Flow matching for neural dynamics
- brain-dit-fmri-foundation-model: fMRI foundation models
- neuromorphic-continual-nuclear-ics: Continual learning for neural interfaces
## References
```bibtex
@article{ge2025energy,
title={Energy-based Autoregressive Generation for Neural Population Dynamics},
author={Ge, Ningling and Dai, Sicheng and Zhu, Yu and Yu, Shan},
journal={arXiv preprint arXiv:2511.17606},
year={2025}
}
```
## Further Reading
- Energy-Based Models: LeCun et al., "A Tutorial on Energy-Based Learning"
- Autoregressive Models: Vaswani et al., "Attention is All You Need"
- Proper Scoring Rules: Gneiting & Raftery, "Strictly Proper Scoring Rules"
- Neural Population Dynamics: Saxena et al., "Towards Community-Driven Neural Latents Benchmarks"