---
name: h2learn-snn-accelerator
description: 'High-efficiency hardware accelerator for BPTT-based Spiking Neural Network training. Design LUT-based processing elements, a dual-sparsity-aware backward engine, and pipeline optimization. Achieves 7.38x area saving and up to 10.20x speedup vs. an NVIDIA V100 GPU.'
---
# H2Learn: SNN Training Accelerator
## Description
A novel hardware architecture that achieves high efficiency for BPTT-based Spiking Neural Network (SNN) learning while preserving high accuracy. It exploits binary spike computation and gradient sparsity in specialized processing engines, achieving 7.38x area saving, 5.74-10.20x speedup, and 5.25-7.12x energy saving compared to an NVIDIA V100 GPU.
**Source:** arXiv:2107.11746v1
**Utility:** 0.91
## Activation Keywords
- SNN accelerator
- BPTT spiking neural network
- neuromorphic hardware training
- LUT-based SNN processing
- dual-sparsity backward engine
- spike-based gradient computation
- SNN hardware optimization
- efficient SNN training
## Core Concepts
### 1. BPTT for SNNs Challenge
**Problem:**
- Local synaptic plasticity rules → low accuracy
- BPTT-based SNN learning → high accuracy but computationally expensive
- General-purpose processors → low efficiency (ANN-tailored)
- Neuromorphic chips → cannot support BPTT (local rules only)
**Solution:**
H2Learn bridges the gap with specialized hardware for BPTT-based SNN training.
### 2. SNN BPTT Behavior Analysis
**Key Observations:**
| Phase | Computation | Characteristics |
|-------|-------------|-----------------|
| Forward pass | Binary spike-based | Sparse, binary operations |
| Backward pass | Gradient computation | Rich sparsity in gradients |
| Weight update | Spike-based accumulation | Binary operations |
**Exploited Properties:**
1. **Binary spike computation** → LUT-based processing
2. **Gradient sparsity** → Dual-sparsity-aware design (see the sketch below)
3. **Weight update** → Implicit accumulation
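Where does the backward pass's rich gradient sparsity come from? BPTT through the non-differentiable spike function typically uses a surrogate derivative that is nonzero only near the firing threshold, so most gradient entries vanish. A minimal sketch of that effect (illustrative, not the paper's code; the rectangular surrogate and its `width` parameter are assumed choices):

```python
import numpy as np

def surrogate_grad(v: np.ndarray, threshold: float = 1.0, width: float = 0.5) -> np.ndarray:
    """Rectangular surrogate derivative: nonzero only near the firing threshold."""
    return (np.abs(v - threshold) < width).astype(float) / (2 * width)

v = np.random.randn(10_000) * 2.0   # membrane potentials at one timestep
g = surrogate_grad(v)               # pseudo-derivative of spike output w.r.t. v
print(f"zero-gradient fraction: {np.mean(g == 0):.1%}")  # typically well over half
```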
### 3. Three-Engine Architecture
**Engine 1: Forward Engine**
```
Input spikes (binary) → LUT-based PE → Membrane potential → Output spikes
Key features:
- LUT-based processing elements (implicit accumulation)
- Fused computation for multiple input points
- Binary spike operations
```
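One way to read "LUT-based PE" and "implicit accumulation": because inputs are binary, the inner product over a group of k synapses reduces to the sum of the weights at spike positions, so all 2^k partial sums can be precomputed and fetched with the spike pattern as the address, replacing multiply-accumulate hardware with a table read. A minimal numpy sketch under that assumption (the group size `k` and table layout are illustrative, not the paper's exact microarchitecture):

```python
import numpy as np

def build_partial_sum_lut(weights: np.ndarray) -> np.ndarray:
    """Precompute sum(weights[spike positions]) for every k-bit spike pattern."""
    k = len(weights)
    lut = np.zeros(2 ** k)
    for pattern in range(2 ** k):
        bits = [(pattern >> b) & 1 for b in range(k)]
        lut[pattern] = sum(w for w, bit in zip(weights, bits) if bit)
    return lut

k = 4
weights = np.random.randn(k) * 0.1
lut = build_partial_sum_lut(weights)

spikes = np.array([1, 0, 1, 1])                   # binary input group
address = int(np.dot(spikes, 2 ** np.arange(k)))  # spike pattern as LUT address
assert np.isclose(lut[address], np.dot(spikes, weights))  # same result, no MACs
```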
**Engine 2: Backward Engine**
```
Gradient input → Dual-sparsity-aware processing → Gradient output
Key features:
- Exploit both input and output sparsity
- Skip zero-gradient computations
- Efficient sparse matrix operations
```
**Engine 3: Weight Update Engine**
```
Gradient + Spike trace → LUT-based PE → Weight delta
Key features:
- LUT-based processing (similar to Forward Engine)
- Implicit accumulation
- Binary spike-based update
```
### 4. Pipeline Optimization
**End-to-End Pipeline:**
```
Forward Engine  →  Backward Engine  →  Weight Update Engine
      ↓                   ↓                     ↓
Pipeline stage 1    Pipeline stage 2     Pipeline stage 3
Optimization:
- Minimize pipeline bubbles
- Overlap computation across layers
- Efficient memory access pattern
```
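As a back-of-envelope illustration of why avoiding bubbles matters (a sketch assuming equal stage latencies, which the real engines do not have): with S stages and B batches in flight, sequential execution costs S*B stage-times while a full pipeline costs S + B - 1:

```python
def pipeline_cycles(stages: int, batches: int):
    """Idealized stage-time counts: sequential vs. fully overlapped pipeline."""
    sequential = stages * batches      # each batch waits for the previous one
    pipelined = stages + batches - 1   # fill the pipe once, then one result per step
    return sequential, pipelined

seq, pipe = pipeline_cycles(stages=3, batches=100)
print(f"sequential: {seq}, pipelined: {pipe}, ideal speedup: {seq / pipe:.2f}x")
# sequential: 300, pipelined: 102, ideal speedup: 2.94x
```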
### 5. Performance Results
| Metric | vs NVIDIA V100 GPU |
|--------|-------------------|
| Area saving | 7.38x |
| Speedup | 5.74-10.20x |
| Energy saving | 5.25-7.12x |
**Benchmark datasets:**
- MNIST
- CIFAR-10
- DVS-Gesture
- N-MNIST
## Step-by-Step Instructions
### 1. LUT-Based Processing Element Design
```python
import numpy as np
from typing import Tuple


class LUTProcessingElement:
    """
    LUT-based processing element for spike computation.

    Benefits:
    - Implicit accumulation via LUT lookup
    - Fused computation for multiple input points
    - Binary spike operations

    Args:
        lut_size: Size of lookup table
        threshold: Spike threshold
    """

    def __init__(self, lut_size: int = 256, threshold: float = 1.0):
        self.lut_size = lut_size
        self.threshold = threshold
        # Pre-compute LUT for membrane potential
        self.lut = self._build_lut()

    def _build_lut(self) -> np.ndarray:
        """
        Build lookup table for membrane potential accumulation.

        Returns:
            lut: Lookup table covering [-threshold, threshold]
        """
        # LUT stores quantized membrane potential values
        return np.linspace(-self.threshold, self.threshold, self.lut_size)

    def _to_index(self, value: np.ndarray) -> np.ndarray:
        """Map a potential in [-threshold, threshold] to a centered LUT index."""
        index = (value / self.threshold + 1.0) * (self.lut_size // 2)
        return np.clip(index.astype(int), 0, self.lut_size - 1)

    def forward_pass(
        self,
        input_spikes: np.ndarray,
        weights: np.ndarray,
        membrane_potential: float
    ) -> Tuple[float, float]:
        """
        LUT-based forward pass computation.

        Args:
            input_spikes: Binary input spikes (0 or 1)
            weights: Synaptic weights
            membrane_potential: Current membrane potential

        Returns:
            new_potential: Updated membrane potential
            output_spike: Binary output spike (0.0 or 1.0)
        """
        # Implicit accumulation: binary spikes select which weights to sum
        weight_sum = np.sum(input_spikes * weights)
        # LUT lookup for the (quantized) membrane potential update
        lut_index = self._to_index(np.asarray(membrane_potential + weight_sum))
        new_potential = self.lut[lut_index]
        # Spike generation and reset (standard LIF behavior)
        output_spike = float(new_potential >= self.threshold)
        if output_spike:
            new_potential = 0.0
        return new_potential, output_spike

    def fused_computation(
        self,
        input_batch: np.ndarray,
        weights: np.ndarray
    ) -> np.ndarray:
        """
        Fused computation for multiple input points.

        Args:
            input_batch: Batch of binary input spikes
            weights: Synaptic weights

        Returns:
            membrane_potentials: Batch of (quantized) membrane potentials
        """
        # Batched implicit accumulation followed by a batched LUT lookup
        weight_sums = np.dot(input_batch, weights)
        return self.lut[self._to_index(weight_sums)]
```
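A quick usage sketch of the fused path (the shapes are illustrative):

```python
pe = LUTProcessingElement(lut_size=256, threshold=1.0)
batch = np.random.randint(0, 2, (8, 100)).astype(float)  # 8 samples of binary spikes
weights = np.random.randn(100, 100) * 0.1
potentials = pe.fused_computation(batch, weights)        # one LUT lookup per neuron
print(potentials.shape)  # (8, 100), quantized to the LUT's 256 levels
```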
### 2. Dual-Sparsity-Aware Backward Engine
```python
class DualSparsityBackwardEngine:
    """
    Dual-sparsity-aware backward engine for gradient computation.

    Exploits:
    - Input sparsity: many incoming gradients are zero
    - Output sparsity: many neurons did not fire, so no gradient flows through them

    Args:
        input_sparsity_threshold: Magnitude below which a gradient is treated as zero
        output_sparsity_threshold: Threshold above which a spike trace counts as active
    """

    def __init__(
        self,
        input_sparsity_threshold: float = 0.01,
        output_sparsity_threshold: float = 0.01
    ):
        self.input_threshold = input_sparsity_threshold
        self.output_threshold = output_sparsity_threshold

    def backward_pass(
        self,
        gradient_input: np.ndarray,
        weights: np.ndarray,
        spike_traces: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Dual-sparsity-aware backward pass for one layer.

        Args:
            gradient_input: Gradients w.r.t. the layer output, shape (batch, out)
            weights: Synaptic weights, shape (in, out)
            spike_traces: Pre-synaptic spikes from the forward pass, shape (batch, in)

        Returns:
            gradient_output: Gradients w.r.t. the layer input, shape (batch, in)
            weight_gradients: Weight gradients, shape (in, out)
        """
        # Input sparsity: skip near-zero incoming gradients
        input_mask = np.abs(gradient_input) > self.input_threshold
        sparse_gradient_input = gradient_input * input_mask
        # Output sparsity: skip pre-synaptic neurons that did not fire
        output_mask = spike_traces > self.output_threshold
        # Gradient w.r.t. the layer input, masked by spike activity
        gradient_output = np.dot(sparse_gradient_input, weights.T) * output_mask
        # Both sparsity masks are already folded into the weight gradients:
        # sparse_gradient_input carries the input mask, and the binary spike
        # traces zero out the contributions of silent neurons
        weight_gradients = np.dot(spike_traces.T, sparse_gradient_input)
        return gradient_output, weight_gradients

    def compute_sparsity_ratio(self, gradient_input: np.ndarray) -> float:
        """
        Compute the sparsity ratio of a gradient tensor.

        Args:
            gradient_input: Input gradients

        Returns:
            sparsity_ratio: Fraction of (near-)zero gradients
        """
        zero_count = np.sum(np.abs(gradient_input) <= self.input_threshold)
        return float(zero_count / gradient_input.size)
```
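A shape-level sanity check (illustrative dimensions): gradients flow from a 10-unit output back to a 100-unit input, and the weight gradients match the weight matrix:

```python
engine = DualSparsityBackwardEngine()
grad = np.random.randn(8, 10) * 0.1       # gradient w.r.t. layer output
weights = np.random.randn(100, 10) * 0.1  # (in, out)
pre_spikes = np.random.randint(0, 2, (8, 100)).astype(float)
grad_in, w_grad = engine.backward_pass(grad, weights, pre_spikes)
print(grad_in.shape, w_grad.shape)  # (8, 100) (100, 10)
```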
### 3. Weight Update Engine
```python
class WeightUpdateEngine:
    """
    LUT-based weight update engine.

    Similar to the Forward Engine:
    - LUT-based processing elements
    - Implicit accumulation
    - Binary spike-based update

    Args:
        learning_rate: Learning rate for weight update
        lut_size: Size of lookup table
    """

    def __init__(self, learning_rate: float = 0.01, lut_size: int = 256):
        self.learning_rate = learning_rate
        self.lut_size = lut_size
        # LUT of quantized weight deltas covering [-1, 1]
        self.delta_lut = self._build_delta_lut()

    def _build_delta_lut(self) -> np.ndarray:
        """
        Build LUT for weight delta computation.

        Returns:
            delta_lut: Lookup table of quantized deltas covering [-1, 1]
        """
        return np.linspace(-1.0, 1.0, self.lut_size)

    def update_weights(
        self,
        weights: np.ndarray,
        weight_gradients: np.ndarray,
        spike_traces: np.ndarray
    ) -> np.ndarray:
        """
        LUT-based weight update (a gradient descent step through the delta LUT).

        Args:
            weights: Current weights
            weight_gradients: Weight gradients
            spike_traces: Spike traces (kept for interface symmetry)

        Returns:
            new_weights: Updated weights
        """
        # Scale gradients, clip to the LUT's [-1, 1] range, and map them to a
        # centered index so that a zero gradient yields a (near-)zero delta
        scaled = np.clip(weight_gradients * self.learning_rate, -1.0, 1.0)
        delta_indices = np.round((scaled + 1.0) / 2.0 * (self.lut_size - 1)).astype(int)
        weight_deltas = self.delta_lut[delta_indices]
        # An even-sized LUT has no exact zero entry, so explicitly skip
        # zero gradients instead of applying the nearest quantized delta
        weight_deltas[scaled == 0.0] = 0.0
        # Descend along the (quantized) gradient
        return weights - weight_deltas
```
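And a minimal update step (illustrative shapes); note the LUT quantizes each delta to one of 256 levels:

```python
updater = WeightUpdateEngine(learning_rate=0.01, lut_size=256)
weights = np.random.randn(100, 10) * 0.1
w_grad = np.random.randn(100, 10)
spike_traces = np.random.randint(0, 2, (8, 100)).astype(float)
new_weights = updater.update_weights(weights, w_grad, spike_traces)
print(np.max(np.abs(new_weights - weights)))  # bounded by the LUT range [-1, 1]
```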
### 4. Pipeline Controller
```python
class H2LearnPipeline:
    """
    End-to-end pipeline controller for H2Learn.

    Coordinates:
    - Forward Engine
    - Backward Engine
    - Weight Update Engine

    Args:
        n_layers: Number of layers
        batch_size: Batch size
    """

    def __init__(self, n_layers: int = 3, batch_size: int = 32):
        self.n_layers = n_layers
        self.batch_size = batch_size
        # Initialize one engine of each kind per layer
        self.forward_engines = [
            LUTProcessingElement() for _ in range(n_layers)
        ]
        self.backward_engines = [
            DualSparsityBackwardEngine() for _ in range(n_layers)
        ]
        self.weight_update_engines = [
            WeightUpdateEngine() for _ in range(n_layers)
        ]
        # Cache of the last network input; it serves as the pre-synaptic
        # trace of layer 0 during the backward pass
        self._last_input = None
        # Pipeline state
        self.pipeline_state = {
            'forward_complete': False,
            'backward_complete': False,
            'update_complete': False
        }

    def forward_pipeline(
        self,
        input_batch: np.ndarray,
        weights: list
    ) -> Tuple[list, list]:
        """
        Execute the forward pass pipeline.

        Args:
            input_batch: Input spike batch
            weights: List of weight matrices

        Returns:
            membrane_potentials: List of membrane potentials per layer
            spike_traces: List of output spike traces per layer
        """
        membrane_potentials = []
        spike_traces = []
        self._last_input = input_batch
        current_input = input_batch
        for i, engine in enumerate(self.forward_engines):
            # Forward pass for layer i
            potentials = engine.fused_computation(current_input, weights[i])
            spikes = (potentials >= engine.threshold).astype(float)
            membrane_potentials.append(potentials)
            spike_traces.append(spikes)
            current_input = spikes
        self.pipeline_state['forward_complete'] = True
        return membrane_potentials, spike_traces

    def backward_pipeline(
        self,
        gradient_input: np.ndarray,
        weights: list,
        spike_traces: list
    ) -> Tuple[list, list]:
        """
        Execute the backward pass pipeline.

        Args:
            gradient_input: Gradient from the loss
            weights: List of weight matrices
            spike_traces: Output spike traces from the forward pass

        Returns:
            layer_gradients: List of layer gradients, ordered by layer
            weight_gradients: List of weight gradients, ordered by layer
        """
        layer_gradients = []
        weight_gradients = []
        current_gradient = gradient_input
        for i in reversed(range(self.n_layers)):
            # Pre-synaptic spikes of layer i: the previous layer's output,
            # or the cached network input for the first layer
            pre_spikes = spike_traces[i - 1] if i > 0 else self._last_input
            grad_out, w_grad = self.backward_engines[i].backward_pass(
                current_gradient, weights[i], pre_spikes
            )
            layer_gradients.append(grad_out)
            weight_gradients.append(w_grad)
            current_gradient = grad_out
        # The lists were filled from the last layer backwards; restore
        # layer order so that weight_gradients[i] matches weights[i]
        layer_gradients.reverse()
        weight_gradients.reverse()
        self.pipeline_state['backward_complete'] = True
        return layer_gradients, weight_gradients

    def update_pipeline(
        self,
        weights: list,
        weight_gradients: list,
        spike_traces: list
    ) -> list:
        """
        Execute the weight update pipeline.

        Args:
            weights: Current weights
            weight_gradients: Weight gradients, ordered by layer
            spike_traces: Spike traces

        Returns:
            new_weights: Updated weights
        """
        new_weights = []
        for i, engine in enumerate(self.weight_update_engines):
            updated = engine.update_weights(
                weights[i], weight_gradients[i], spike_traces[i]
            )
            new_weights.append(updated)
        self.pipeline_state['update_complete'] = True
        return new_weights

    def full_pipeline(
        self,
        input_batch: np.ndarray,
        weights: list,
        gradient_input: np.ndarray
    ) -> list:
        """
        Execute the full training pipeline (forward, backward, update).

        Args:
            input_batch: Input spike batch
            weights: Current weights
            gradient_input: Gradient from the loss

        Returns:
            new_weights: Updated weights
        """
        # Forward pass
        potentials, traces = self.forward_pipeline(input_batch, weights)
        # Backward pass
        layer_grads, weight_grads = self.backward_pipeline(
            gradient_input, weights, traces
        )
        # Weight update
        return self.update_pipeline(weights, weight_grads, traces)
```
### 5. Complete Training Loop
```python
def h2learn_training_loop(
    train_data: np.ndarray,
    train_labels: np.ndarray,
    n_layers: int = 3,
    epochs: int = 10,
    learning_rate: float = 0.01
) -> dict:
    """
    Complete H2Learn training loop.

    Args:
        train_data: Training data (binary spikes), shape (n_samples, n_features)
        train_labels: Training labels, shape (n_samples, n_classes)
        n_layers: Number of layers
        epochs: Training epochs
        learning_rate: Learning rate

    Returns:
        results: Training results
    """
    batch_size = 32
    n_features = train_data.shape[1]
    n_classes = train_labels.shape[1]
    pipeline = H2LearnPipeline(n_layers, batch_size)
    # Apply the requested learning rate to the update engines
    for engine in pipeline.weight_update_engines:
        engine.learning_rate = learning_rate
    # Initialize weights; the last layer maps to the label dimension
    weights = [
        np.random.randn(n_features, n_features) * 0.1 for _ in range(n_layers - 1)
    ] + [np.random.randn(n_features, n_classes) * 0.1]
    losses = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        n_batches = 0
        for batch_idx in range(0, len(train_data), batch_size):
            batch = train_data[batch_idx:batch_idx + batch_size]
            labels = train_labels[batch_idx:batch_idx + batch_size]
            # Forward pass
            potentials, traces = pipeline.forward_pipeline(batch, weights)
            # Compute loss (simple MSE on the last layer's spikes)
            output = traces[-1]
            loss = np.mean((output - labels) ** 2)
            epoch_loss += loss
            n_batches += 1
            # Gradient of the MSE loss w.r.t. the output
            gradient = 2 * (output - labels) / labels.size
            # Backward + update, reusing the traces from the forward pass above
            _, weight_grads = pipeline.backward_pipeline(gradient, weights, traces)
            weights = pipeline.update_pipeline(weights, weight_grads, traces)
        losses.append(epoch_loss / n_batches)
        print(f"Epoch {epoch}: Loss = {losses[-1]:.4f}")
    results = {
        'final_loss': losses[-1],
        'loss_history': losses,
        'final_weights': weights
    }
    return results
```
## Tools Used
- `numpy` - Numerical computations
- `typing` - Type annotations
- `exec` - Run simulation scripts
## Example Use Cases
### 1. Basic Forward Pass
```python
# Create LUT-based PE
pe = LUTProcessingElement(lut_size=256, threshold=1.0)
# Binary input spikes
input_spikes = np.array([1, 0, 1, 1, 0])
weights = np.random.randn(5) * 0.1
# Forward pass
membrane, output = pe.forward_pass(input_spikes, weights, 0.0)
print(f"Output spike: {output}")
```
### 2. Sparsity Analysis
```python
# Analyze gradient sparsity
backward_engine = DualSparsityBackwardEngine()
gradient = np.random.randn(100) * 0.1
gradient[np.abs(gradient) < 0.05] = 0  # zero out small magnitudes to make it sparse
sparsity = backward_engine.compute_sparsity_ratio(gradient)
print(f"Sparsity ratio: {sparsity:.2%}")
```
### 3. Full Pipeline Training
```python
# Generate synthetic data
train_data = np.random.randint(0, 2, (1000, 100)).astype(float)
train_labels = np.random.randint(0, 2, (1000, 10)).astype(float)
# Train with H2Learn pipeline
results = h2learn_training_loop(train_data, train_labels, epochs=10)
print(f"Final loss: {results['final_loss']}")
```
## Hardware Implementation Notes
**Area Optimization:**
- LUT-based PE → 7.38x area saving vs GPU
- Implicit accumulation → Reduced hardware complexity
- Fused computation → Fewer memory accesses
**Speedup Optimization:**
- Pipeline overlap → 5.74-10.20x speedup
- Dual-sparsity skip → Reduced computation
- LUT lookup → Faster than arithmetic
**Energy Optimization:**
- Sparse computation → 5.25-7.12x energy saving
- Binary operations → Reduced power consumption
- Pipeline efficiency → Minimized idle cycles
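A back-of-envelope sketch of where the computation savings come from (purely illustrative arithmetic, not figures from the paper): if a fraction `s_in` of incoming gradients and `s_out` of spike traces are zero, and skipped pairs cost nothing, the surviving MAC fraction is roughly `(1 - s_in) * (1 - s_out)`:

```python
def surviving_mac_fraction(input_sparsity: float, output_sparsity: float) -> float:
    """Fraction of MACs left after dual-sparsity skipping (independence assumed)."""
    return (1.0 - input_sparsity) * (1.0 - output_sparsity)

frac = surviving_mac_fraction(input_sparsity=0.80, output_sparsity=0.70)
print(f"remaining work: {frac:.1%}, ideal speedup: {1 / frac:.1f}x")
# remaining work: 6.0%, ideal speedup: 16.7x
```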
## Instructions for Agents
Follow these steps when applying this skill:
1. Design the LUT-based processing element (Step 1 above).
2. Build the dual-sparsity-aware backward engine (Step 2).
3. Add the LUT-based weight update engine (Step 3).
4. Wire the three engines into the pipeline controller (Step 4) and run the training loop (Step 5).
## Examples
### Example 1: Basic Application
**User:** I need to apply H2Learn: SNN Training Accelerator to my analysis.
**Agent:** I'll help you apply h2learn-snn-accelerator. First, let me understand your specific use case...
**Context:** Apply the methodology
### Example 2: Advanced Application
**User:** What are the key considerations for h2learn-snn-accelerator?
**Agent:** Let me search for the latest research and best practices...
## Related Skills
- `spikingjelly-framework` - Spiking Jelly framework
- `multi-plasticity-snn-training` - Multi-plasticity SNN training
- `decolle-snn-learning` - DECOLLE SNN learning
## References
- Liang, L. et al. (2021). "H2Learn: High-Efficiency Learning Accelerator for High-Accuracy Spiking Neural Networks." arXiv:2107.11746v1 [cs.NE].
---
**Created:** 2026-03-30 00:05
**Author:** Aerial (from arXiv:2107.11746v1)