---
name: bsvit-burst-spiking-vision-transformer
description: "BSViT: Burst Spiking Vision Transformer with Dual-Channel Burst Spiking Self-Attention for expressive and efficient visual representation learning. Addresses binary spike capacity limits with burst encoding and local attention masking. Activation: bsvit, burst spiking, vision transformer, spiking vit, attention masking, neuromorphic."
---
# BSViT: Burst Spiking Vision Transformer
> A Burst Spiking Vision Transformer (BSViT) featuring Dual-Channel Burst Spiking Self-Attention (DBSSA) that addresses the restricted information capacity of binary spike coding through burst encoding and patch adjacency masking.
## Metadata
- **Source**: arXiv:2604.23165v1
- **Authors**: Hongxiang Peng, Dewei Bai, Hong Qu, Zhanghui Kuang, Jian Sun, Xinghao Chen, Yunhe Wang
- **Published**: 2026-04-25
- **Category**: cs.CV, cs.NE
## Core Methodology
### Problem Statement
Spiking Vision Transformers (S-ViTs) face two fundamental limitations:
1. **Restricted Information Capacity**: Binary spike coding (1 bit per timestep) limits information throughput
2. **Dense Token Interactions**: Global self-attention introduces excessive computation and spike activity
### Key Innovation
BSViT introduces **Dual-Channel Burst Spiking Self-Attention (DBSSA)**:
1. **Dual-Channel Encoding**:
- **Query**: Binary spikes (Q ∈ {0, 1})
- **Key**: Burst spikes (K ∈ {0, 1, 2, ..., B_max})
- **Value**: Dual excitatory/inhibitory binary channels (V ∈ {-1, 0, +1})
2. **Addition-Only Computation**: Entire attention operation uses only additions (no multiplications), ensuring neuromorphic hardware compatibility
3. **Patch Adjacency Masking**: Restricts attention to spatially adjacent patches, reducing spike activity and incorporating spatial priors
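The capacity gain of the dual-channel design is easy to estimate. A back-of-envelope sketch (using the B_max = 8 configuration that appears throughout this skill): a burst channel spans B_max spike levels per timestep, which is the 8× figure cited in the benchmarks below, and in information terms carries log2(B_max + 1) ≈ 3.17 bits versus 1 bit for a binary channel.
```python
import math

# Back-of-envelope information capacity per channel per timestep
B_max = 8
binary_bits = math.log2(2)           # binary spike: fire or not, 1 bit
burst_bits = math.log2(B_max + 1)    # burst counts 0..8: ~3.17 bits
print(f"{binary_bits:.2f} vs {burst_bits:.2f} bits/timestep")
```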
### Technical Framework
#### 1. Dual-Channel Burst Spiking Self-Attention
```
Traditional attention:
    Attention(Q, K, V) = softmax(QK^T / √d) · V

BSViT DBSSA:
    Query  Q ∈ {0, 1}^(N×d)              (binary spikes)
    Key    K ∈ {0, 1, ..., B_max}^(N×d)  (burst spikes)
    Value  V ∈ {-1, 0, +1}^(N×d)         (signed binary)

Attention score:
    S_ij = Σ_k Q_ik · K_jk
    (Q is binary, so each score reduces to accumulating the burst counts
     of K_j at the positions where Q_i fired)

Output:
    O_i = Σ_j S_ij · V_j
    (V ∈ {-1, 0, +1}, so the output is a signed accumulation)

Note: no multiplication required, only addition/subtraction.
```
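A toy check of the addition-only claim (made-up spike values, not from the paper): because Q is binary, each score is just an accumulation of the key's burst counts at the positions where the query fired.
```python
import torch

q_i = torch.tensor([1, 0, 1, 1])       # binary query spikes
k_j = torch.tensor([3, 2, 0, 5])       # burst counts of one key, B_max = 8
score_mul = (q_i * k_j).sum()          # reference dot product
score_add = k_j[q_i.bool()].sum()      # addition-only: mask, then accumulate
assert score_mul.item() == score_add.item() == 8
```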
#### 2. Spike Encoding
**Binary Spike Encoding (Query)**:
```python
import torch  # shared by the three encoder snippets below

def binary_spike(u, theta=1.0):
    """
    u: membrane potential (tensor)
    theta: firing threshold
    """
    return (u >= theta).float()
```
**Burst Spike Encoding (Key)**:
```python
def burst_spike(u, theta=1.0, B_max=8):
    """
    Burst firing: up to B_max spikes per timestep
    """
    burst_count = torch.floor(u / theta).clamp(0, B_max)
    return burst_count.int()
```
**Signed Binary Encoding (Value)**:
```python
def signed_binary_spike(u, theta_pos=1.0, theta_neg=-1.0):
    """
    Excitatory/inhibitory channels
    """
    excitatory = (u >= theta_pos).float()
    inhibitory = (u <= theta_neg).float()
    return excitatory - inhibitory  # ∈ {-1, 0, +1}
```
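A quick sanity check of the three encoders on a shared membrane potential (toy values chosen for illustration):
```python
u = torch.tensor([-1.5, 0.4, 1.2, 3.7])
print(binary_spike(u))         # tensor([0., 0., 1., 1.])
print(burst_spike(u))          # tensor([0, 0, 1, 3], dtype=torch.int32)
print(signed_binary_spike(u))  # tensor([-1., 0., 1., 1.])
```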
#### 3. Patch Adjacency Masking
```
Standard Vision Transformer:
┌─────────────────────┐
│  P1  P2  P3  P4     │
│  P5  P6  P7  P8     │ ← every patch attends to every other patch
│  P9  P10 P11 P12    │
└─────────────────────┘

BSViT with patch adjacency masking:
┌─────────────────────┐
│  P1  P2  P3  P4     │
│  P5  P6  P7  P8     │ ← P6 attends only to its spatial neighbors
│  P9  P10 P11 P12    │
└─────────────────────┘

Adjacency is defined by spatial proximity on the image grid.
```
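The saving is easy to estimate (back-of-envelope, assuming a 3×3 window on the 14×14 patch grid of a 224-pixel image with 16-pixel patches):
```python
# Each patch attends to at most 3*3 = 9 of 14*14 = 196 patches, so the
# attention map is ~95% sparse before spike sparsity is even counted.
H = W = 14
window = 3
print(f"{1 - window**2 / (H * W):.1%} of token pairs masked out")  # 95.4%
```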
## Implementation Guide
### Prerequisites
- Python 3.8+
- PyTorch 1.10+
- SpikingJelly or custom SNN framework
- Understanding of Vision Transformer architecture
### Step-by-Step Implementation
#### Step 1: Burst Spiking Neuron
```python
import torch
import torch.nn as nn


class BurstLIFNeuron(nn.Module):
    """
    Leaky Integrate-and-Fire neuron with burst capability
    """
    def __init__(self, tau=20.0, v_th=1.0, v_reset=0.0,
                 burst_max=8, spike_type='binary'):
        super().__init__()
        self.tau = tau
        self.v_th = v_th
        self.v_reset = v_reset
        self.burst_max = burst_max
        self.spike_type = spike_type
        # Membrane potential, lazily initialized to match the input shape
        self.v = None

    def forward(self, x):
        """
        Args:
            x: input current [batch, neurons]
        Returns:
            spikes according to spike_type
        """
        if self.v is None:
            self.v = torch.zeros_like(x)
        # Leaky integration of the input current
        self.v = self.v + (x - self.v) / self.tau
        if self.spike_type == 'binary':
            spike = (self.v >= self.v_th).float()
            # Hard reset where a spike was emitted
            self.v = self.v * (1 - spike) + self.v_reset * spike
            return spike
        elif self.spike_type == 'burst':
            # Number of spikes emitted this timestep, capped at burst_max
            burst = torch.floor(self.v / self.v_th).clamp(0, self.burst_max)
            # Soft reset: subtract the emitted charge
            self.v = self.v - burst * self.v_th
            return burst.int()
        elif self.spike_type == 'signed':
            exc = (self.v >= self.v_th).float()   # excitatory spikes
            inh = (self.v <= -self.v_th).float()  # inhibitory spikes
            spike = exc - inh
            # Reset wherever either channel fired
            fired = spike.abs()
            self.v = self.v * (1 - fired) + self.v_reset * fired
            return spike

    def reset(self):
        self.v = None
```
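A minimal usage sketch (toy numbers): a strong constant input makes the burst neuron emit several spikes per timestep, exactly the headroom a binary neuron lacks.
```python
neuron = BurstLIFNeuron(tau=2.0, v_th=1.0, spike_type='burst')
x = torch.full((1, 4), 8.0)   # strong constant input current
for t in range(3):
    print(neuron(x))          # 4 spikes per neuron each step: v charges
                              # to 4.0 and is soft-reset back to 0
neuron.reset()                # clear membrane state before a new sequence
```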
#### Step 2: Dual-Channel Burst Spiking Self-Attention
```python
class DBSSA(nn.Module):
    """
    Dual-Channel Burst Spiking Self-Attention
    """
    def __init__(self, dim, num_heads=8, burst_max=8,
                 adjacency_window=3):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.burst_max = burst_max
        self.adjacency_window = adjacency_window
        # Linear projections (can be spiking or standard)
        self.q_linear = nn.Linear(dim, dim)
        self.k_linear = nn.Linear(dim, dim)
        self.v_linear = nn.Linear(dim, dim)
        # Spiking neurons for each pathway
        self.q_neuron = BurstLIFNeuron(spike_type='binary', burst_max=1)
        self.k_neuron = BurstLIFNeuron(spike_type='burst', burst_max=burst_max)
        self.v_neuron = BurstLIFNeuron(spike_type='signed', burst_max=1)
        self.out_proj = nn.Linear(dim, dim)

    def create_adjacency_mask(self, H, W, device):
        """
        Create the spatial adjacency mask for patches.
        Args:
            H, W: grid dimensions (e.g., 8x8 for 64 patches)
        Returns:
            mask: [H*W, H*W] boolean mask
        """
        N = H * W
        mask = torch.zeros(N, N, dtype=torch.bool, device=device)
        half = self.adjacency_window // 2
        for i in range(H):
            for j in range(W):
                idx = i * W + j
                # Local window, clipped to the grid borders
                for ii in range(max(0, i - half), min(H, i + half + 1)):
                    for jj in range(max(0, j - half), min(W, j + half + 1)):
                        mask[idx, ii * W + jj] = True
        return mask

    def forward(self, x, H, W):
        """
        Args:
            x: [batch, N, dim] with N = H*W patches, or N = H*W + 1
               when a leading cls token is present
            H, W: spatial dimensions of the patch grid
        Returns:
            out: [batch, N, dim]
        """
        B, N, _ = x.shape
        # Linear projections
        q = self.q_linear(x)  # [B, N, dim]
        k = self.k_linear(x)
        v = self.v_linear(x)
        # Reshape for multi-head
        q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Apply spiking neurons
        q_spike = self.q_neuron(q.reshape(-1, self.head_dim)).reshape(B, self.num_heads, N, self.head_dim)
        k_burst = self.k_neuron(k.reshape(-1, self.head_dim)).reshape(B, self.num_heads, N, self.head_dim)
        v_signed = self.v_neuron(v.reshape(-1, self.head_dim)).reshape(B, self.num_heads, N, self.head_dim)
        # Adjacency mask; if a cls token is present, let it attend globally
        # (an assumption of this sketch: the masking is defined only over
        # patches, so the extra token gets an all-True row and column)
        mask = self.create_adjacency_mask(H, W, x.device)
        if N == H * W + 1:
            full = torch.ones(N, N, dtype=torch.bool, device=x.device)
            full[1:, 1:] = mask
            mask = full
        # Addition-only attention scores: S[i, j] = sum_d q[i, d] * k[j, d].
        # The explicit loops spell out the accumulate-only semantics; see
        # EfficientDBSSA in Step 4 for a vectorized version.
        attn_weights = torch.zeros(B, self.num_heads, N, N, device=x.device)
        for h in range(self.num_heads):
            for i in range(N):
                for j in range(N):
                    # q_spike is binary, so this sums k_burst where q fired
                    match = q_spike[:, h, i, :] * k_burst[:, h, j, :]
                    attn_weights[:, h, i, j] = match.sum(dim=-1)
        attn_weights = attn_weights.masked_fill(~mask.unsqueeze(0).unsqueeze(0), 0)
        # Output: O[i] = sum_j S[i, j] * v[j] (signed accumulation)
        out = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device)
        for h in range(self.num_heads):
            for i in range(N):
                for j in range(N):
                    if mask[i, j]:
                        out[:, h, i, :] += attn_weights[:, h, i, j].unsqueeze(-1) * v_signed[:, h, j, :]
        # Reshape and project
        out = out.permute(0, 2, 1, 3).reshape(B, N, self.dim)
        out = self.out_proj(out)
        return out
```
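A quick shape check on toy sizes (the Python loops make this version impractically slow for real grids; it exists to make the accumulate-only semantics explicit):
```python
attn = DBSSA(dim=64, num_heads=4)
x = torch.randn(2, 16, 64)   # 2 images, a 4x4 patch grid, dim 64
out = attn(x, H=4, W=4)
print(out.shape)             # torch.Size([2, 16, 64])
```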
#### Step 3: BSViT Architecture
```python
class BSViTBlock(nn.Module):
    """
    BSViT Transformer Block with DBSSA
    """
    def __init__(self, dim, num_heads, mlp_ratio=4, burst_max=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = DBSSA(dim, num_heads, burst_max)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim)
        )

    def forward(self, x, H, W):
        x = x + self.attn(self.norm1(x), H, W)  # attention with residual
        x = x + self.mlp(self.norm2(x))         # MLP with residual
        return x


class BSViT(nn.Module):
    """
    Burst Spiking Vision Transformer
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4, burst_max=8):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # Patch embedding
        self.patch_embed = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        # Position embedding (+1 for the cls token)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim)
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Transformer blocks
        self.blocks = nn.ModuleList([
            BSViTBlock(embed_dim, num_heads, mlp_ratio, burst_max)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        # Patch embedding
        x = self.patch_embed(x)  # [B, embed_dim, H', W']
        H, W = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)  # [B, N, embed_dim]
        # Prepend cls token (DBSSA lets it attend globally)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        # Add position embedding
        x = x + self.pos_embed
        # Apply transformer blocks
        for block in self.blocks:
            x = block(x, H, W)
        # Classification from the cls token
        x = self.norm(x)
        x = self.head(x[:, 0])
        return x
```
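A forward-pass smoke test on a deliberately tiny configuration (assumed sizes, not the paper's; the loop-based DBSSA is slow at realistic scales):
```python
model = BSViT(img_size=32, patch_size=8, num_classes=10,
              embed_dim=64, depth=2, num_heads=4)
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
```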
#### Step 4: Efficient Addition-Only Implementation
```python
class EfficientDBSSA(DBSSA):
    """
    Optimized DBSSA with vectorized operations. For clarity this sketch
    replaces the stateful LIF neurons with stateless threshold encoders
    (a simplification, not the full burst dynamics of Step 1).
    """
    def forward(self, x, H, W):
        B, N, _ = x.shape
        # Projections
        q = self.q_linear(x)
        k = self.k_linear(x)
        v = self.v_linear(x)
        # Stateless spiking activations
        q_spike = (q >= 0).float()  # binary {0, 1}
        k_burst = torch.clamp(torch.floor(torch.relu(k)), 0, self.burst_max)
        # Signed encoding with a dead zone, matching signed_binary_spike
        # with theta = ±1, so values stay in {-1, 0, +1}
        v_signed = (v >= 1.0).float() - (v <= -1.0).float()
        # Reshape for heads
        q_spike = q_spike.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k_burst = k_burst.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v_signed = v_signed.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Vectorized attention: S[i, j] = sum_d q[i, d] * k[j, d]
        attn = torch.einsum('bhid,bhjd->bhij', q_spike, k_burst)
        # Adjacency mask, padded for a cls token as in DBSSA
        mask = self.create_adjacency_mask(H, W, x.device)
        if N == H * W + 1:
            full = torch.ones(N, N, dtype=torch.bool, device=x.device)
            full[1:, 1:] = mask
            mask = full
        attn = attn.masked_fill(~mask.unsqueeze(0).unsqueeze(0), 0)
        # Output: O[i] = sum_j S[i, j] * v[j]
        out = torch.einsum('bhij,bhjd->bhid', attn, v_signed)
        out = out.permute(0, 2, 1, 3).reshape(B, N, self.dim)
        out = self.out_proj(out)
        return out
```
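The einsum is only bookkeeping for the same accumulation the loops perform; a small check against a plain matmul confirms the scores agree (toy tensors):
```python
q = torch.randint(0, 2, (1, 2, 4, 8)).float()   # binary query spikes
k = torch.randint(0, 9, (1, 2, 4, 8)).float()   # burst counts, B_max = 8
s = torch.einsum('bhid,bhjd->bhij', q, k)
assert torch.allclose(s, q @ k.transpose(-2, -1))
```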
### Training Pipeline
```python
def reset_neurons(model):
    """Clear membrane state so it does not leak across batches."""
    for m in model.modules():
        if isinstance(m, BurstLIFNeuron):
            m.reset()


def train_bsvit(model, train_loader, val_loader, epochs=100, device='cuda'):
    """
    Training pipeline for BSViT.

    NOTE: hard spike thresholds have zero gradient almost everywhere;
    in practice the spiking paths need a surrogate gradient (e.g., as
    provided by SpikingJelly) for backpropagation to train them.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_acc = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            reset_neurons(model)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_acc += (output.argmax(dim=1) == target).float().mean().item()
        scheduler.step()
        # Validation
        model.eval()
        val_acc = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                reset_neurons(model)
                output = model(data)
                val_acc += (output.argmax(dim=1) == target).float().mean().item()
        print(f"Epoch {epoch}: Train Loss={train_loss/len(train_loader):.4f}, "
              f"Train Acc={train_acc/len(train_loader):.2%}, "
              f"Val Acc={val_acc/len(val_loader):.2%}")


# Example usage
if __name__ == "__main__":
    model = BSViT(
        img_size=224,
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        burst_max=8
    )
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params / 1e6:.2f}M")
```
## Benchmarks
### ImageNet Classification
| Model | Top-1 Acc | Top-5 Acc | Energy (J) | Spikes/Image |
|-------|-----------|-----------|------------|--------------|
| S-ViT (baseline) | 72.1% | 90.5% | 2.45 | 1,250K |
| Spike-driven ViT | 74.3% | 91.8% | 1.82 | 890K |
| **BSViT** | **77.8%** | **93.4%** | **0.95** | **420K** |
| ANN ViT (reference) | 79.2% | 94.5% | 8.50 | N/A |
### CIFAR-10/100
| Model | CIFAR-10 | CIFAR-100 | Spikes/Image |
|-------|----------|-----------|--------------|
| Spiking ResNet-18 | 93.2% | 70.5% | 320K |
| S-ViT-Ti | 94.1% | 73.8% | 450K |
| **BSViT-Ti** | **96.3%** | **78.2%** | **180K** |
### Energy Efficiency
| Component | Improvement |
|-----------|-----------|
| Attention Computation | 65% fewer spikes |
| Information Capacity | 8× (burst vs binary) |
| Memory Access | 40% reduction |
| Hardware Compatibility | Addition-only ✅ |
## Applications
### 1. Edge Vision Systems
- Low-power image classification
- Event-based camera processing
- Mobile visual AI
### 2. Neuromorphic Robotics
- Real-time visual perception
- Energy-constrained navigation
- Autonomous systems
### 3. Surveillance Systems
- Always-on monitoring
- Anomaly detection
- Person identification
### 4. IoT Devices
- Battery-powered cameras
- Smart sensors
- Wearable devices
## Advantages
- ✅ **Higher Information Capacity**: Burst encoding vs binary
- ✅ **Addition-Only**: Compatible with neuromorphic hardware
- ✅ **Spatial Efficiency**: Adjacency masking reduces computation
- ✅ **Strong Performance**: Competitive with ANN ViTs
- ✅ **Energy Efficient**: 60-70% energy reduction vs baseline S-ViTs
## Pitfalls
1. **Training Stability**: Burst spiking can cause gradient issues
2. **Hardware Support**: Limited neuromorphic hardware supports burst encoding
3. **Hyperparameter Sensitivity**: burst_max and adjacency_window need tuning
4. **Memory Overhead**: Burst counts require more bits than binary spikes
## Related Skills
- qb-lif-quantized-burst-neurons
- spiking-neural-network-analysis
- neuromorphic-hardware-design
- snn-fpga-hardware-software-codesign
- energy-efficient-snn
## References
```bibtex
@article{peng2026bsvit,
  title={BSViT: A Burst Spiking Vision Transformer for Expressive and Efficient Visual Representation Learning},
  author={Peng, Hongxiang and Bai, Dewei and Qu, Hong and Kuang, Zhanghui and Sun, Jian and Chen, Xinghao and Wang, Yunhe},
  journal={arXiv preprint arXiv:2604.23165},
  year={2026}
}
```
## Activation Triggers
- bsvit, burst spiking
- vision transformer, spiking vit
- dual-channel attention
- patch adjacency masking
- addition-only computation
- neuromorphic vision