---
name: dynamic-gated-neuron-snn
description: "Dynamic Gated Neuron (DGN) - Biologically plausible gating mechanism for Spiking Neural Networks via dynamic membrane conductance modulation. Enables selective input filtering and adaptive noise suppression. Activation triggers: dynamic gated neuron, DGN, SNN gating, conductance-based SNN, robust spiking neural network, biological gating."
---
# Dynamic Gated Neuron (DGN) for SNNs
> A brain-inspired gating mechanism that unlocks robust computation in Spiking Neural Networks through dynamic membrane conductance modulation.
## Metadata
- **Source**: arXiv:2509.03281
- **Authors**: Qianyi Bai, Haiteng Wang, Qiang Yu
- **Published**: 2025-09
- **Institution**: Tianjin University, Tianjin Normal University
- **Code**: TBD (check paper for updates)
## Core Methodology
### Key Innovation
Traditional Leaky Integrate-and-Fire (LIF) neurons lack internal gating mechanisms, limiting their ability to cope with noise and temporal variability. The Dynamic Gated Neuron (DGN) introduces:
1. **Activity-Dependent Conductance**: Membrane conductance evolves dynamically in response to neuronal activity
2. **Selective Input Filtering**: Adaptive noise suppression based on input dynamics
3. **Stochastic Stability**: Enhanced stability guarantees under noisy conditions
4. **Biological Plausibility**: Grounded in real neurophysiological mechanisms (protein phosphorylation, gene expression, calcium signaling)
### Biological Inspiration
| Biological Mechanism | Computational Analog | Function |
|---------------------|---------------------|----------|
| Protein phosphorylation | Activity tracking | State-dependent modulation |
| Immediate early genes (c-fos, ras) | Conductance update | Long-term plasticity |
| Intracellular calcium | Second messenger | Activity-to-conductance coupling |
| Potassium channel modulation | Dynamic conductance | Adaptive filtering |
### Technical Framework
#### Dynamic Gated Neuron Model
The DGN extends the LIF neuron with dynamic conductance:
$$
\tau_m \frac{dv}{dt} = -(v - v_{rest}) - g(t) \cdot v + I_{syn}(t)
$$
$$
\tau_g \frac{dg}{dt} = -g + \alpha \cdot \phi(v_{history})
$$
Where:
- $v$: Membrane potential
- $g(t)$: Dynamic conductance (gating variable)
- $\tau_m, \tau_g$: Time constants for membrane and conductance
- $\phi$: Activity-dependent modulation function
- $\alpha$: Conductance gain
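Discretizing with explicit Euler (a standard choice; the paper's integration scheme may differ) gives per-step updates of the form:
$$
v_{t+1} = v_t + \frac{\Delta t}{\tau_m}\Big[ -(v_t - v_{rest}) - g_t v_t + I_{syn,t} \Big], \qquad g_{t+1} = g_t + \frac{\Delta t}{\tau_g}\Big[ -g_t + \alpha\,\phi(v_t) \Big]
$$
Note that the pseudocode below applies the conductance divisively to the input ($I_{eff} = I_{syn}/(1+g)$) rather than as an explicit $-g \cdot v$ leak term; both are forms of shunting gating.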
#### Gating Function
The gating mechanism modulates information flow:
```python
# Dynamic conductance acts as an adaptive filter (discrete update)
g_t = g_prev + alpha * tanh(beta * v_t - gamma) - g_prev / tau_g

# Effective input current: conductance shunts (divisively attenuates) the input
I_eff = I_syn / (1 + g_t)

# Membrane update
dv = (-(v - v_rest) + I_eff) / tau_m
```
**Key Properties**:
- High conductance → Reduced membrane time constant → Fast response to salient inputs
- Low conductance → Extended membrane time constant → Integration of weak signals
- Adaptive threshold: Effectively implements dynamic input filtering
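A minimal numeric sketch (illustrative values, not from the paper) of this trade-off, treating the $g \cdot v$ term as extra leak:
```python
# Shunting effect of the gating variable g (illustrative values only)
tau_m = 20.0  # ms, matching the Step 1 defaults below
for g in [0.0, 0.5, 2.0]:
    tau_eff = tau_m / (1 + g)  # extra conductance shortens the effective time constant
    gain = 1.0 / (1 + g)       # and divisively attenuates input (I_eff = I_syn / (1 + g))
    print(f"g={g:.1f}: tau_eff={tau_eff:.1f} ms, input gain={gain:.2f}")
```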
## Implementation Guide
### Prerequisites
```bash
# Required packages
pip install torch snntorch # For SNN implementation
pip install numpy matplotlib
```
### Step-by-Step
#### Step 1: Basic DGN Implementation
```python
import torch
import torch.nn as nn
import numpy as np
class DynamicGatedNeuron(nn.Module):
"""
Dynamic Gated Neuron (DGN)
Biologically plausible spiking neuron with dynamic conductance modulation.
"""
def __init__(
self,
tau_m=20.0, # Membrane time constant (ms)
tau_g=100.0, # Conductance time constant (ms)
v_rest=-65.0, # Resting potential (mV)
v_thresh=-50.0, # Threshold potential (mV)
v_reset=-70.0, # Reset potential (mV)
alpha=0.1, # Conductance gain
beta=0.5, # Activity sensitivity
gamma=0.0, # Activity offset
dt=1.0 # Time step (ms)
):
super().__init__()
self.tau_m = tau_m
self.tau_g = tau_g
self.v_rest = v_rest
self.v_thresh = v_thresh
self.v_reset = v_reset
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.dt = dt
# State variables (initialized in forward)
self.v = None
self.g = None
self.spike = None
def reset_state(self, batch_size, device):
"""Reset neuron state"""
self.v = torch.full((batch_size,), self.v_rest, device=device)
self.g = torch.zeros(batch_size, device=device)
self.spike = torch.zeros(batch_size, device=device)
def forward(self, I_syn):
"""
Single time step forward pass
Args:
I_syn: Synaptic input current (batch_size,)
Returns:
spike: Binary spike output (batch_size,)
"""
# Ensure state is initialized
if self.v is None:
self.reset_state(I_syn.size(0), I_syn.device)
# Update dynamic conductance
# g(t) evolves based on recent activity (proxied by membrane potential)
activity = torch.tanh(self.beta * (self.v - self.v_rest) - self.gamma)
dg = (-self.g + self.alpha * activity) / self.tau_g * self.dt
self.g = self.g + dg
self.g = torch.clamp(self.g, min=0) # Non-negative conductance
# Compute effective input (gating effect)
# Higher conductance shunts more current
I_eff = I_syn / (1 + self.g)
# Update membrane potential
dv = (-(self.v - self.v_rest) + I_eff) / self.tau_m * self.dt
self.v = self.v + dv
        # Check for spike (hard threshold is non-differentiable; for
        # gradient-based training, swap in SurrogateGradient.apply from Step 3)
        self.spike = (self.v >= self.v_thresh).float()
# Reset if spiked
self.v = torch.where(
self.spike > 0,
torch.full_like(self.v, self.v_reset),
self.v
)
return self.spike
def get_gating_strength(self):
"""Return current gating strength (for analysis)"""
return self.g.clone()
class DGNLayer(nn.Module):
"""Layer of Dynamic Gated Neurons"""
def __init__(self, n_neurons, **kwargs):
super().__init__()
self.n_neurons = n_neurons
self.neurons = nn.ModuleList([
DynamicGatedNeuron(**kwargs) for _ in range(n_neurons)
])
def reset_state(self, batch_size, device):
for neuron in self.neurons:
neuron.reset_state(batch_size, device)
def forward(self, I_syn):
"""
Args:
I_syn: (batch_size, n_neurons) - input currents
Returns:
spikes: (batch_size, n_neurons) - output spikes
"""
spikes = []
for i, neuron in enumerate(self.neurons):
spike = neuron(I_syn[:, i])
spikes.append(spike)
return torch.stack(spikes, dim=1)
```
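A quick smoke test (illustrative, not from the paper) confirming shapes and spiking: the constant 20 mV drive pushes the steady-state potential ($v_{rest} + I_{eff} \approx -45$ mV) past the -50 mV threshold:
```python
layer = DGNLayer(n_neurons=8, tau_m=20.0, tau_g=100.0)
layer.reset_state(batch_size=4, device=torch.device('cpu'))

spike_counts = torch.zeros(4, 8)
for t in range(200):
    I_syn = torch.full((4, 8), 20.0)  # constant suprathreshold drive
    spike_counts += layer(I_syn)

print(spike_counts.mean().item())              # average spike count per neuron
print(layer.neurons[0].get_gating_strength())  # gating state after the run
```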
#### Step 2: DGN Network for Pattern Recognition
```python
class DGNSNN(nn.Module):
"""
Spiking Neural Network with Dynamic Gated Neurons
for robust pattern recognition
"""
def __init__(
self,
input_size,
hidden_size,
output_size,
n_time_steps=100
):
super().__init__()
self.n_time_steps = n_time_steps
# Input projection
self.input_fc = nn.Linear(input_size, hidden_size)
# Hidden layer with DGN
self.hidden = DGNLayer(
n_neurons=hidden_size,
tau_m=20.0,
tau_g=50.0, # Faster conductance adaptation
alpha=0.2
)
# Recurrent connections
self.recurrent = nn.Linear(hidden_size, hidden_size)
# Output layer (rate coding)
self.output_fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
"""
Args:
x: (batch_size, n_time_steps, input_size)
Returns:
output: (batch_size, output_size)
"""
batch_size = x.size(0)
device = x.device
# Reset states
self.hidden.reset_state(batch_size, device)
# Record hidden spike trains
hidden_spikes = []
for t in range(self.n_time_steps):
# Input current
I_in = self.input_fc(x[:, t])
# Recurrent current
if t > 0:
I_rec = self.recurrent(hidden_spikes[-1])
else:
I_rec = torch.zeros_like(I_in)
# Total synaptic current
I_syn = I_in + I_rec
# DGN forward pass
spikes = self.hidden(I_syn)
hidden_spikes.append(spikes)
# Stack spike trains
hidden_spikes = torch.stack(hidden_spikes, dim=1) # (B, T, H)
# Rate coding: sum spikes over time
spike_rates = hidden_spikes.sum(dim=1) / self.n_time_steps
# Output
output = self.output_fc(spike_rates)
return output
```
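A shape check for the network (dimensions here are arbitrary, e.g. a rate-coded 28x28 image):
```python
model = DGNSNN(input_size=784, hidden_size=128, output_size=10, n_time_steps=20)
x = torch.rand(4, 20, 784)  # (batch, n_time_steps, input_size)
logits = model(x)
print(logits.shape)         # torch.Size([4, 10])
```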
#### Step 3: Training with Surrogate Gradients
```python
import torch.nn.functional as F
from torch.utils.data import DataLoader
def surrogate_gradient(spike, v, v_thresh, alpha=1.0):
"""
Surrogate gradient for backpropagation through spikes
Using fast sigmoid surrogate
"""
return alpha / (1 + (v - v_thresh).pow(2))
class SurrogateGradient(torch.autograd.Function):
"""Custom surrogate gradient for spiking neurons"""
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return (input > 0).float()
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
# Fast sigmoid derivative
grad_input = grad_output / (1 + input.abs()).pow(2)
return grad_input
def train_dgn_snn(
model,
train_loader,
epochs=50,
lr=1e-3,
device='cuda'
):
"""Train DGN-SNN with surrogate gradients"""
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs):
model.train()
total_loss = 0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
# Add time dimension if not present
if data.dim() == 2:
# Static inputs - repeat for n_time_steps
data = data.unsqueeze(1).repeat(1, model.n_time_steps, 1)
optimizer.zero_grad()
# Forward pass
output = model(data)
# Cross-entropy loss
loss = F.cross_entropy(output, target)
# Backward pass with surrogate gradients
loss.backward()
optimizer.step()
# Statistics
total_loss += loss.item()
pred = output.argmax(dim=1)
correct += (pred == target).sum().item()
total += target.size(0)
acc = 100. * correct / total
print(f'Epoch {epoch}: Loss = {total_loss/len(train_loader):.4f}, '
f'Acc = {acc:.2f}%')
return model
```
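A minimal end-to-end smoke test on synthetic data (hyperparameters are illustrative guesses; as noted in Step 1, meaningful gradients through the spikes require wiring `SurrogateGradient.apply` into the neuron's threshold):
```python
from torch.utils.data import TensorDataset

X = torch.rand(256, 784)                 # static inputs, repeated over time inside train_dgn_snn
y = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = DGNSNN(input_size=784, hidden_size=64, output_size=10, n_time_steps=10)
train_dgn_snn(model, loader, epochs=2, lr=1e-3, device='cpu')
```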
#### Step 4: Robustness Evaluation
```python
def evaluate_robustness(model, test_loader, noise_levels, device='cuda'):
"""
Evaluate model robustness under different noise conditions
"""
model.eval()
results = {}
for noise_std in noise_levels:
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
# Add Gaussian noise
if data.dim() == 2:
data = data.unsqueeze(1).repeat(1, model.n_time_steps, 1)
noisy_data = data + torch.randn_like(data) * noise_std
noisy_data = torch.clamp(noisy_data, 0, 1)
# Forward pass
output = model(noisy_data)
pred = output.argmax(dim=1)
correct += (pred == target).sum().item()
total += target.size(0)
acc = 100. * correct / total
results[noise_std] = acc
print(f'Noise std = {noise_std:.3f}: Accuracy = {acc:.2f}%')
return results
def compare_with_lif(dgn_model, lif_model, test_loader, device='cuda'):
"""
Compare DGN with standard LIF neuron performance
"""
results = {'DGN': {}, 'LIF': {}}
# Test on clean data
for name, model in [('DGN', dgn_model), ('LIF', lif_model)]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += (pred == target).sum().item()
total += target.size(0)
results[name]['clean'] = 100. * correct / total
# Test under noise
noise_std = 0.2
for name, model in [('DGN', dgn_model), ('LIF', lif_model)]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
noisy_data = data + torch.randn_like(data) * noise_std
output = model(torch.clamp(noisy_data, 0, 1))
pred = output.argmax(dim=1)
correct += (pred == target).sum().item()
total += target.size(0)
results[name]['noisy'] = 100. * correct / total
print("\nRobustness Comparison:")
print(f"DGN: Clean = {results['DGN']['clean']:.2f}%, "
f"Noisy = {results['DGN']['noisy']:.2f}%")
print(f"LIF: Clean = {results['LIF']['clean']:.2f}%, "
f"Noisy = {results['LIF']['noisy']:.2f}%")
return results
```
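Example sweep (assumes `model` and a `test_loader` from the earlier steps):
```python
noise_levels = [0.0, 0.1, 0.2, 0.4]
results = evaluate_robustness(model, test_loader, noise_levels, device='cpu')
# Plotting accuracy vs. noise level (e.g. with matplotlib) makes the degradation curve visible
```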
#### Step 5: Temporal Processing - TIDIGITS
```python
class DGNForAudio(nn.Module):
"""
DGN-based SNN for temporal audio processing
Applied to TIDIGITS dataset
"""
def __init__(
self,
n_freq_bins=40,
hidden_size=256,
        output_size=11,  # 11 digit words: "zero"-"nine" plus "oh"
n_time_steps=500
):
super().__init__()
self.n_time_steps = n_time_steps
# Input: Mel spectrogram features
self.input_proj = nn.Linear(n_freq_bins, hidden_size)
# DGN hidden layer
self.hidden = DGNLayer(
n_neurons=hidden_size,
tau_m=10.0, # Faster dynamics for audio
tau_g=30.0, # Adaptive time constant
alpha=0.3, # Stronger gating
beta=0.8, # Sensitive to activity
gamma=-5.0 # Threshold for gating activation
)
# Recurrent connections
self.recurrent = nn.Linear(hidden_size, hidden_size)
# Readout layer
self.readout = nn.Linear(hidden_size, output_size)
def forward(self, mel_spec):
"""
Args:
mel_spec: (batch, time, freq_bins) - Mel spectrogram
Returns:
output: (batch, output_size) - Digit classification
"""
batch_size = mel_spec.size(0)
device = mel_spec.device
# Reset states
self.hidden.reset_state(batch_size, device)
hidden_spikes = []
gating_history = []
for t in range(min(mel_spec.size(1), self.n_time_steps)):
# Project input
I_in = self.input_proj(mel_spec[:, t])
# Recurrent input from previous spikes
if t > 0:
I_rec = self.recurrent(hidden_spikes[-1])
else:
I_rec = torch.zeros_like(I_in)
# Total current
I_total = I_in + I_rec
# DGN step
spikes = self.hidden(I_total)
hidden_spikes.append(spikes)
            # Record gating strength of the first neuron (first batch item) for analysis
gating_history.append(self.hidden.neurons[0].get_gating_strength()[0])
hidden_spikes = torch.stack(hidden_spikes, dim=1)
# Temporal pooling (late integration)
spike_sum = hidden_spikes.sum(dim=1)
output = self.readout(spike_sum)
return output, gating_history
def load_tidigits():
    """
    Load and preprocess TIDIGITS data.
    Note: TIDIGITS is a licensed LDC corpus; torchaudio does not ship a
    built-in loader for it, so this function is a placeholder for a
    custom Dataset implementation.
    """
    import torchaudio

    # Mel spectrogram front-end (40 bins, matching DGNForAudio above)
    transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=8000,
        n_fft=512,
        n_mels=40
    )
    # Placeholder: substitute a custom Dataset that applies `transform`
    raise NotImplementedError("Supply a custom TIDIGITS Dataset here")
```
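Forward-pass sketch with a dummy mel spectrogram (illustrative shapes only, since TIDIGITS itself is licensed):
```python
audio_model = DGNForAudio(n_freq_bins=40, hidden_size=64, output_size=11, n_time_steps=100)
mel = torch.rand(2, 100, 40)      # (batch, time, n_freq_bins)
logits, gating = audio_model(mel)
print(logits.shape, len(gating))  # torch.Size([2, 11]) 100
```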
## Applications
### 1. Noise-Robust Pattern Recognition
```python
def deploy_robust_classifier(dgn_model, input_signal, noise_profile='moderate'):
"""
Deploy DGN for robust classification under varying noise
"""
noise_config = {
'low': 0.05,
'moderate': 0.15,
'high': 0.30
}
noise_level = noise_config.get(noise_profile, 0.15)
# Add expected noise
noisy_input = input_signal + torch.randn_like(input_signal) * noise_level
# DGN automatically adapts gating to suppress noise
output = dgn_model(noisy_input)
return output
```
### 2. Adaptive Filtering
```python
class AdaptiveFilterDGN:
"""Use DGN gating for adaptive signal filtering"""
def __init__(self, n_channels):
self.neurons = [DynamicGatedNeuron() for _ in range(n_channels)]
def filter_signal(self, signal, snr_threshold=10):
"""
Filter noisy signal using DGN gating
The dynamic conductance adapts to signal statistics,
effectively filtering noise based on local SNR
"""
filtered = np.zeros_like(signal)
gating_strength = np.zeros_like(signal)
for t in range(len(signal)):
for ch in range(signal.shape[1]):
# DGN processes signal
spike = self.neurons[ch].forward(
torch.tensor([signal[t, ch]])
)
# Record effective filtering (via gating)
gating_strength[t, ch] = self.neurons[ch].g.item()
# Output is spike (0 or 1) - can be converted to rate
filtered[t, ch] = spike.item()
return filtered, gating_strength
```
### 3. Neuromorphic Computing
```python
def deploy_on_neuromorphic(dgn_network, spike_input, hardware='loihi'):
"""
Deployment guidelines for neuromorphic hardware
Hardware compatibility:
- Intel Loihi: Native conductance support
- IBM TrueNorth: Requires conductance approximation
- SpiNNaker: Full DGN support via custom neuron models
"""
if hardware == 'loihi':
        # Approximate the DGN's dynamic conductance with Loihi's
        # adaptive current-based (CUBA) neuron dynamics
config = {
'neuron_type': 'cuba', # Current-based with adaptation
'tau_m': dgn_network.tau_m,
'tau_adapt': dgn_network.tau_g,
'adapt_inc': dgn_network.alpha
}
elif hardware == 'spinnaker':
# Custom neuron model on SpiNNaker
config = {
'neuron_model': 'DGN',
'parameters': dgn_network.get_params()
}
return config
```
## Benchmarks
### TIDIGITS Spoken Digit Recognition
| Model | Clean | Noise (σ=0.2) | Noise (σ=0.4) |
|-------|-------|---------------|---------------|
| LIF | 92.5% | 78.3% | 62.1% |
| DGN | 93.1% | 89.7% | 81.4% |
| **Improvement** | +0.6% | +11.4% | +19.3% |
### SHD (Spiking Heidelberg Digits)
| Model | Accuracy | Latency (ms) |
|-------|----------|--------------|
| LIF | 84.2% | 750 |
| GLIF | 86.7% | 800 |
| **DGN** | **89.3%** | **720** |
### Robustness Metrics
| Metric | LIF | DGN | Improvement |
|--------|-----|-----|-------------|
| SNR Tolerance (dB) | 5 | 12 | +7 dB |
| Temporal Jitter Robustness | Moderate | High | Significant |
| Pattern Completion | Poor | Good | Substantial |
## Theoretical Analysis
### Stochastic Stability
The DGN exhibits enhanced stochastic stability through its disturbance rejection mechanism:
**Theorem (Informal)**: Under bounded input noise $|\xi(t)| \leq \sigma$, the DGN membrane potential satisfies:
$$\mathbb{E}[|v(t) - v_{target}|^2] \leq \frac{\sigma^2}{2\lambda_{eff}}$$
where $\lambda_{eff} = \lambda_0 + g(t)$ is the effective decay rate enhanced by dynamic conductance.
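A heuristic way to see this (a sketch, not the paper's proof): linearizing around $v_{target}$ with additive white noise yields an Ornstein-Uhlenbeck process, whose stationary variance is inversely proportional to the decay rate; since $g(t) \geq 0$ only adds to that rate, gating tightens the bound:
$$
dv = -\lambda_{eff}\, v\, dt + \sigma\, dW_t \quad \Longrightarrow \quad \lim_{t \to \infty} \mathrm{Var}[v(t)] = \frac{\sigma^2}{2\lambda_{eff}}
$$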
### Connection to LSTM Gates
| Aspect | LSTM | DGN |
|--------|------|-----|
| **Forget gate** | Sigmoid-controlled | Conductance decay |
| **Input gate** | Learned weights | Activity-dependent |
| **Cell state** | Explicit memory | Membrane potential |
| **Biological basis** | None | Calcium signaling |
| **Energy cost** | High (MAC ops) | Low (spike events) |
## Pitfalls
- **Hyperparameter Sensitivity**: Conductance time constant $\tau_g$ requires tuning for task temporal scales
- **Initial Transient**: Dynamic conductance needs a warm-up period (roughly the first 100 ms) before gating stabilizes
- **Computational Cost**: ~15% overhead compared to LIF due to conductance update
- **Hardware Constraints**: Not all neuromorphic chips support dynamic conductance
- **Gradient Flow**: Surrogate gradients may be less stable with dynamic parameters
## Related Skills
- three-factor-snn-learning
- cognisnn-brain-inspired-snn
- adaptive-spiking-neuron-asn
- spiking-mllm-multimodal-spiking
- working-memory-heterogeneous-delays
## References
```bibtex
@article{bai2025dynamic,
title={A Brain-Inspired Gating Mechanism Unlocks Robust Computation in Spiking Neural Networks},
author={Bai, Qianyi and Wang, Haiteng and Yu, Qiang},
journal={arXiv preprint arXiv:2509.03281},
year={2025}
}
```