---
name: gradient-free-continual-learning-snn
description: "Inter-areal predictive coding for gradient-free continual learning in spiking neural networks. Brain-inspired learning rule using feedback connections to transmit prediction errors without backpropagation. Keywords: gradient-free learning, continual learning, predictive coding, inter-areal, SNN, catastrophic forgetting, bio-inspired."
---
# Gradient-Free Continual Learning in SNNs via Inter-Areal Predictive Coding
> Brain-inspired inter-areal predictive coding framework enabling continual learning in spiking neural networks without backpropagation, using feedback connections to transmit prediction errors and prevent catastrophic forgetting.
## Metadata
- **Source**: arXiv:2604.16496v1
- **Authors**: Zhenyu Zhao, Yiting Dong, Wenhao Zhang, Bo Xu
- **Published**: 2026-04-14
- **Category**: Neural and Evolutionary Computing (cs.NE)
## Core Methodology
### Key Innovation
This work introduces **inter-areal predictive coding** for gradient-free continual learning in spiking neural networks (SNNs). Unlike backpropagation-based continual learning methods, which rely on global error signals and often on stored data or parameter-importance estimates, this approach uses biologically plausible feedback connections between cortical areas to transmit prediction errors, enabling continual learning without catastrophic forgetting while maintaining energy efficiency.
### Technical Framework
**1. Inter-Areal Architecture**
- Hierarchical cortical-like structure with multiple processing areas
- Feedforward connections for sensory-to-motor processing
- Feedback connections for transmitting prediction errors
- Area-specific learning using local prediction errors
**2. Predictive Coding Learning**
- Each area predicts activity of lower-level areas via feedback
- Prediction errors drive local synaptic updates
- No global gradient computation required (a schematic form of the update is sketched below)
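Schematically (our notation, not the paper's exact formulation), with $x_l$ the spiking activity of area $l$, the local error and the Hebbian-like updates it drives can be written as:
```latex
% Local prediction error at area l (feedback predicts x_l from x_{l+1}):
\varepsilon_l = x_l - W_{\mathrm{fb}}^{(l)}\, x_{l+1}
% Hebbian-like local updates driven by that error (no global gradient):
\Delta W_{\mathrm{fb}}^{(l)} \propto \varepsilon_l\, x_{l+1}^{\top}, \qquad
\Delta W_{\mathrm{ff}}^{(l)} \propto \varepsilon_l\, x_{l-1}^{\top}
```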
**3. Continual Learning Mechanisms**
- Error-based plasticity prevents interference between tasks
- Local learning rules enable task-specific adaptation
- No explicit replay or regularization needed
## Key Findings
### 1. Task-Agnostic Continual Learning
- Learns 10+ sequential tasks without forgetting
- No task identity information required at inference
- Comparable to state-of-the-art gradient-based methods
### 2. Energy Efficiency
- 90%+ reduction in memory usage vs. gradient-based continual learning
- Local updates enable online learning on neuromorphic hardware
- Compatible with event-driven processing
### 3. Biological Plausibility
- Implements feedback pathways found in biological cortex
- Local learning rules consistent with neurophysiology
- Area-to-area communication mirrors cortical hierarchy
## Implementation Guide
### Prerequisites
- Python 3.8+
- PyTorch or custom SNN framework
- snnTorch for spiking neuron models
- NumPy for numerical operations
- torchvision for the split-MNIST example in Step 4
### Step-by-Step Implementation
**Step 1: Inter-Areal Network Architecture**
```python
import torch
import torch.nn as nn
import snntorch as snn
class InterArealSNN(nn.Module):
"""
Hierarchical SNN with inter-areal predictive coding
"""
def __init__(self, area_sizes, beta=0.9):
"""
Args:
area_sizes: List of neuron counts per area [input, area1, area2, ..., output]
beta: Leaky integration constant
"""
super().__init__()
self.n_areas = len(area_sizes) - 1
self.area_sizes = area_sizes
# Create areas
self.areas = nn.ModuleList()
for i in range(self.n_areas):
            area = nn.ModuleDict({
                # LIF neurons; membrane state is handled explicitly in forward()
                'lif': snn.Leaky(beta=beta),
                # Feedforward to next area
                'ff': nn.Linear(area_sizes[i], area_sizes[i+1]),
                # Feedback for predictive coding: predicts this area's output
                # (size area_sizes[i+1]) from the next area's output (size area_sizes[i+2])
                'fb': nn.Linear(area_sizes[i+2], area_sizes[i+1]) if i < self.n_areas - 1 else None
            })
            self.areas.append(area)
        # Cache of the most recent activities, set by forward(return_errors=True);
        # used by the gradient-free learning rule (no surrogate gradient needed)
        self.last_activities = None
def forward(self, x, return_errors=False):
"""
Forward pass with optional error computation
Args:
x: Input spikes (batch, time, input_features)
return_errors: Whether to compute prediction errors
Returns:
output: Final area activity
errors: List of prediction errors per area (if return_errors=True)
"""
batch_size, time_steps, _ = x.shape
# Initialize membrane potentials
mems = [area['lif'].init_leaky() for area in self.areas]
# Activity history
activities = [[] for _ in range(self.n_areas)]
for t in range(time_steps):
current_input = x[:, t, :]
# Forward pass through areas
for i, area in enumerate(self.areas):
if i == 0:
# First area receives input
ff_input = area['ff'](current_input)
else:
# Subsequent areas receive from previous
ff_input = area['ff'](activities[i-1][-1])
# LIF dynamics
spk, mems[i] = area['lif'](ff_input, mems[i])
activities[i].append(spk)
        # Stack spike trains over time: (batch, time, neurons) per area
        stacked = [torch.stack(a, dim=1) for a in activities]
        output = stacked[-1]
        if return_errors:
            # Cache activities with the input prepended, for the learning rule:
            # last_activities[i] is the input to area i, [i+1] its output
            self.last_activities = [x] + stacked
            errors = self.compute_prediction_errors(stacked)
            return output, errors
        return output
    def compute_prediction_errors(self, activities):
        """
        Compute prediction errors for each area, top-down
        Args:
            activities: List of stacked activity tensors, one per area,
                        each of shape (batch, time, neurons)
        Returns:
            errors: List of prediction errors; errors[i] is the error in
                    area i's activity, for i = 0 .. n_areas - 2
        """
        errors = []
        # From second-to-last area down to the first
        for i in range(self.n_areas - 2, -1, -1):
            # Predict area i's activity from the area above it via feedback
            prediction = self.areas[i]['fb'](activities[i + 1])
            error = activities[i] - prediction
            errors.insert(0, error)  # Insert at beginning
        return errors
```
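To sanity-check shapes, here is a quick smoke test of the class above on random Bernoulli spike trains (batch size, time steps, and layer sizes are illustrative, not from the paper):
```python
# Quick shape check for InterArealSNN (assumes the imports and class above)
torch.manual_seed(0)
net = InterArealSNN(area_sizes=[784, 256, 128, 10], beta=0.9)
x = (torch.rand(8, 50, 784) < 0.1).float()  # (batch, time, features)

out, errs = net(x, return_errors=True)
print(out.shape)                       # torch.Size([8, 50, 10])
print([tuple(e.shape) for e in errs])  # [(8, 50, 256), (8, 50, 128)]
```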
**Step 2: Predictive Coding Learning Rule**
```python
class PredictiveCodingLearner:
"""
Gradient-free learning using predictive coding
"""
def __init__(self, model, learning_rate=0.001, fb_learning_rate=0.0001):
self.model = model
self.lr = learning_rate
self.fb_lr = fb_learning_rate
    def learn_step(self, errors):
        """
        Update weights from local prediction errors (no backpropagation).
        Uses a simplified Hebbian-like rule: each update is the outer product
        of a local error with the corresponding pre-/post-synaptic activity.
        The paper's full rule is more involved; this is a minimal sketch.
        Args:
            errors: List of prediction errors per area
        """
        # Activities cached by the model during forward(return_errors=True);
        # acts[i] is the input to area i, acts[i+1] its spiking output
        acts = self.model.last_activities
        with torch.no_grad():
            for i in range(self.model.n_areas - 1):
                area = self.model.areas[i]
                err = errors[i]  # error in area i's activity: (batch, time, n_i)
                n = err.shape[0] * err.shape[1]
                # Feedforward update: ΔW_ff ∝ error_i ⊗ presynaptic activity
                delta_ff = torch.einsum('btp,btq->pq', err, acts[i]) / n
                area['ff'].weight.add_(self.lr * delta_ff)
                # Feedback update: ΔW_fb ∝ error_i ⊗ higher-area activity
                if area['fb'] is not None:
                    delta_fb = torch.einsum('btp,btq->pq', err, acts[i + 2]) / n
                    area['fb'].weight.add_(self.fb_lr * delta_fb)
    def continual_learning_step(self, batch, task_id=None):
        """
        One training step for continual learning
        Args:
            batch: (inputs, targets) tuple
            task_id: Optional task identifier (unused; learning is task-agnostic)
        Returns:
            loss: Monitoring loss (not used for weight updates)
        """
        inputs, targets = batch
        # Forward pass with error computation; no autograd graph is needed
        # because the predictive coding update below is gradient-free
        with torch.no_grad():
            output, errors = self.model(inputs, return_errors=True)
            # Output loss, for monitoring only
            loss = nn.functional.cross_entropy(output.mean(dim=1), targets)
        # Predictive coding update (gradient-free)
        self.learn_step(errors)
        return loss.item()
```
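And a single end-to-end learning step on random data, assuming the two classes above (purely illustrative; real runs use the trainer in Step 3):
```python
# One gradient-free learning step on random data (smoke test)
model = InterArealSNN([784, 256, 128, 10])
learner = PredictiveCodingLearner(model, learning_rate=1e-3)

spikes = (torch.rand(8, 50, 784) < 0.1).float()
targets = torch.randint(0, 10, (8,))
loss = learner.continual_learning_step((spikes, targets))
print(f"monitoring loss: {loss:.4f}")  # no backward() call anywhere
```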
**Step 3: Continual Learning Framework**
```python
class ContinualSNNTrainer:
"""
Trainer for continual learning with SNNs
"""
def __init__(self, model, learner, device='cuda'):
self.model = model.to(device)
self.learner = learner
self.device = device
self.task_history = []
def train_task(self, task_data, task_epochs=10):
"""
Train on a single task
Args:
task_data: DataLoader for current task
task_epochs: Number of epochs per task
"""
print(f"Training on task {len(self.task_history) + 1}...")
        for epoch in range(task_epochs):
            total_loss = 0.0
            correct = 0
            total = 0
            for batch_idx, (inputs, targets) in enumerate(task_data):
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                # Convert to spike trains (rate coding example)
                spike_inputs = self._encode_inputs(inputs)
                # Continual learning step
                loss = self.learner.continual_learning_step((spike_inputs, targets))
                total_loss += loss
                # Monitor training accuracy periodically
                if batch_idx % 100 == 0:
                    correct += self._evaluate_batch(spike_inputs, targets)
                    total += targets.size(0)
            avg_loss = total_loss / len(task_data)
            train_acc = correct / total if total > 0 else float('nan')
            print(f"  Epoch {epoch+1}/{task_epochs}: Loss={avg_loss:.4f}, Acc={train_acc:.4f}")
self.task_history.append(task_data)
print(f"Task {len(self.task_history)} completed.")
def evaluate_all_tasks(self, test_loaders):
"""
Evaluate on all tasks seen so far
Args:
test_loaders: List of test loaders for each task
Returns:
accuracies: List of accuracies per task
avg_accuracy: Average accuracy across tasks
"""
accuracies = []
for task_idx, test_loader in enumerate(test_loaders):
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in test_loader:
inputs = inputs.to(self.device)
targets = targets.to(self.device)
spike_inputs = self._encode_inputs(inputs)
output = self.model(spike_inputs)
predictions = output.mean(dim=1).argmax(dim=1)
correct += (predictions == targets).sum().item()
total += targets.size(0)
acc = correct / total
accuracies.append(acc)
print(f"Task {task_idx+1}: Accuracy={acc:.4f}")
avg_acc = sum(accuracies) / len(accuracies)
print(f"Average Accuracy: {avg_acc:.4f}")
return accuracies, avg_acc
    def _encode_inputs(self, inputs, time_steps=100):
        """Convert analog inputs in [0, 1] to Bernoulli spike trains (rate coding)"""
        # Spike probability at each time step is proportional to the input value
        spike_prob = inputs.unsqueeze(1).repeat(1, time_steps, 1)
        spike_trains = (torch.rand_like(spike_prob) < spike_prob).float()
        return spike_trains
def _evaluate_batch(self, inputs, targets):
"""Quick batch evaluation"""
with torch.no_grad():
output = self.model(inputs)
predictions = output.mean(dim=1).argmax(dim=1)
return (predictions == targets).sum().item()
```
**Step 4: Task Sequence Example**
```python
def run_continual_learning_benchmark():
"""
Example: Training on sequence of tasks
"""
# Network configuration
area_sizes = [784, 256, 128, 10] # MNIST example
model = InterArealSNN(area_sizes, beta=0.9)
learner = PredictiveCodingLearner(model, learning_rate=0.001)
    trainer = ContinualSNNTrainer(model, learner,
                                  device='cuda' if torch.cuda.is_available() else 'cpu')
# Load task datasets (example: split MNIST)
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.view(-1))
])
# Create task splits (e.g., 5 tasks of 2 digits each)
full_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
tasks = []
test_tasks = []
for task_id in range(5):
# Task: classify digits task_id*2 and task_id*2+1
digit_a, digit_b = task_id*2, task_id*2+1
        # Filter by label via .targets (avoids decoding every image)
        train_mask = (full_dataset.targets == digit_a) | (full_dataset.targets == digit_b)
        task_dataset = torch.utils.data.Subset(
            full_dataset, train_mask.nonzero(as_tuple=True)[0].tolist())
        task_loader = torch.utils.data.DataLoader(task_dataset, batch_size=64, shuffle=True)
        tasks.append(task_loader)
        test_mask = (test_dataset.targets == digit_a) | (test_dataset.targets == digit_b)
        test_task = torch.utils.data.Subset(
            test_dataset, test_mask.nonzero(as_tuple=True)[0].tolist())
        test_loader = torch.utils.data.DataLoader(test_task, batch_size=64)
        test_tasks.append(test_loader)
# Train sequentially on tasks
for task_idx, task_loader in enumerate(tasks):
print(f"\n{'='*50}")
print(f"TASK {task_idx+1}: Digits {task_idx*2} & {task_idx*2+1}")
print('='*50)
trainer.train_task(task_loader, task_epochs=5)
# Evaluate on all tasks seen so far
print(f"\nEvaluation after Task {task_idx+1}:")
accuracies, avg_acc = trainer.evaluate_all_tasks(test_tasks[:task_idx+1])
# Check for forgetting
if task_idx > 0:
prev_acc = accuracies[0]
print(f"Task 1 retention: {prev_acc:.4f} (should remain high)")
print(f"\n{'='*50}")
print("CONTINUAL LEARNING COMPLETE")
print(f"Final average accuracy: {avg_acc:.4f}")
print('='*50)
# Run benchmark
# run_continual_learning_benchmark()
```
## Applications
### 1. Robot Lifelong Learning
- Continuous skill acquisition without forgetting
- Online adaptation to new environments
### 2. Edge AI Devices
- Learning on resource-constrained devices
- No cloud dependency for model updates
### 3. Personalized AI
- Continuous user adaptation
- Privacy-preserving local learning
### 4. Neuromorphic Systems
- Deployment on brain-inspired hardware
- Event-driven continual learning
## Pitfalls
### 1. Feedback Connection Design
- **Issue**: Improper feedback weights can destabilize learning
- **Mitigation**: Initialize feedback weights carefully, use smaller learning rates
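One simple precaution, shown below as an illustrative pattern (not prescribed by the paper), is to shrink the feedback weights of an `InterArealSNN` at initialization so early top-down predictions stay weak:
```python
# Illustrative: damp feedback weights at initialization so early
# predictions are near zero and error signals stay bounded
with torch.no_grad():
    for area in model.areas:
        if area['fb'] is not None:
            area['fb'].weight.mul_(0.1)  # gain of 0.1 is a starting point; tune it
            area['fb'].bias.zero_()
```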
### 2. Temporal Dynamics
- **Issue**: SNN temporal dynamics can interfere with error propagation
- **Mitigation**: Tune membrane time constants, use proper encoding schemes
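For the leaky neurons used above, `beta` maps to a membrane time constant through the simulation step via the standard LIF discretization, which gives a principled way to tune it:
```python
import math

# Standard LIF discretization: beta = exp(-dt / tau_m).
# Larger tau_m -> beta closer to 1 -> longer membrane memory.
dt, tau_m = 1.0, 10.0            # illustrative values, same time units
beta = math.exp(-dt / tau_m)     # ~0.905, close to the beta=0.9 used above
print(f"beta = {beta:.3f}")
```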
### 3. Task Similarity
- **Issue**: Very similar tasks may still interfere
- **Mitigation**: Use task-specific modulation or gating mechanisms
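A minimal sketch of one such gating mechanism (our illustration, not from the paper): fixed per-task binary masks over each area's output, so that tasks use partially disjoint sub-populations:
```python
import torch

# Illustrative per-task gating masks; area i's output has area_sizes[i+1] units.
# Inside the simulation loop one would apply: spk = spk * masks[task_id][i]
def make_task_masks(area_sizes, n_tasks, keep=0.5, seed=0):
    g = torch.Generator().manual_seed(seed)
    return [[(torch.rand(s, generator=g) < keep).float() for s in area_sizes[1:]]
            for _ in range(n_tasks)]

masks = make_task_masks([784, 256, 128, 10], n_tasks=5)
print(masks[0][0].sum().item())  # ~128 of 256 units active for task 0, area 0
```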
### 4. Scalability
- **Issue**: Large networks may need hierarchical organization
- **Mitigation**: Modular architecture with area specialization
## Related Skills
- neuromodulated-synaptic-plasticity
- continual-learning-snn
- brain-inspired-snn-pattern-analysis
- spike-agreement-dependent-plasticity
## References
```bibtex
@article{zhao2026gradientfree,
title={Gradient-Free Continual Learning in Spiking Neural Networks via Inter-Areal Predictive Coding},
author={Zhao, Zhenyu and Dong, Yiting and Zhang, Wenhao and Xu, Bo},
journal={arXiv preprint arXiv:2604.16496},
year={2026}
}
```