---
name: gemst-multidimensional-grouping-snn
description: "Ge²mS-T: Multi-Dimensional Grouping for Ultra-High Energy Efficiency in Spiking Transformers. Temporal, spatial, and channel grouping for efficient S-ViT training and inference. Triggers: spiking transformer, S-ViT, energy efficiency, multi-dimensional grouping, SNN."
---
# Ge²mS-T: Multi-Dimensional Grouping for Ultra-High Energy Efficiency in Spiking Transformers
> A multi-dimensional grouping framework for Spiking Vision Transformers (S-ViTs) that achieves ultra-high energy efficiency through temporal, spatial, and channel-wise grouping strategies.
## Metadata
- **Source**: arXiv:2604.08894v1
- **Authors**: Qinyu Chen, Zhenxi Song, Ziyang Zhang, et al.
- **Published**: 2026-04-10
- **Institutions**: Tsinghua University, Chinese Academy of Sciences
## Core Methodology
### Key Innovation
Spiking Vision Transformers (S-ViTs) combine the energy efficiency of Spiking Neural Networks (SNNs) with the powerful visual representation of Vision Transformers. However, they face challenges in both training (surrogate gradient mismatch) and inference (high computational cost from dense self-attention). Ge²mS-T (Grouped Spiking Transformer) introduces multi-dimensional grouping across time, space, and channels to achieve significant efficiency gains while maintaining accuracy.
### Multi-Dimensional Grouping Strategy
#### 1. Temporal Grouping
Standard SNNs process static images by presenting them for multiple timesteps, leading to high latency. Temporal grouping divides timesteps into groups processed in parallel:
```
Standard: [T1] → [T2] → [T3] → [T4] → ... (sequential)
Grouped: [T1,T2] ∥ [T3,T4] ∥ ... (parallel groups)
```
**Implementation**:
```python
class TemporalGrouping(nn.Module):
    def __init__(self, group_size=2):
        super().__init__()
        self.group_size = group_size

    def forward(self, x):
        """
        Args:
            x: input spikes (batch, time, channels, height, width)
        Returns:
            grouped: (batch, time // group_size, channels * group_size, height, width)
        """
        batch, T, C, H, W = x.shape
        assert T % self.group_size == 0, "time steps must be divisible by group size"
        # Fold each group of timesteps into the channel dimension
        x = x.view(batch, T // self.group_size, self.group_size, C, H, W)
        x = x.permute(0, 1, 3, 2, 4, 5)  # (B, T//G, C, G, H, W)
        x = x.reshape(batch, T // self.group_size, C * self.group_size, H, W)
        return x
```
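A quick shape check (a minimal sketch using the `TemporalGrouping` module above; the input is random spikes for illustration):
```python
import torch

# Sanity-check the grouping shapes: 4 timesteps, group size 2
grouping = TemporalGrouping(group_size=2)
spikes = torch.randint(0, 2, (8, 4, 64, 14, 14)).float()  # (B, T, C, H, W)
grouped = grouping(spikes)
print(grouped.shape)  # torch.Size([8, 2, 128, 14, 14]) -- T halved, C doubled
```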
#### 2. Spatial Grouping
Self-attention in ViTs has O(N²) complexity in the number of patches N. Spatial grouping limits attention to local windows:
```
Global attention: Each patch attends to all N patches (N² cost)
Window attention: Each patch attends to W patches in window (N×W cost)
```
**Window Partitioning**:
```python
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size: int
    Returns:
        windows: (B * n_windows, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    assert H % window_size == 0 and W % window_size == 0
    # Reshape into non-overlapping windows
    x = x.view(B, H // window_size, window_size,
               W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = x.view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """Reverse window partition back to (B, H, W, C)"""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size,
                     window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
```
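A round-trip sanity check (a minimal sketch using the two helpers above):
```python
import torch

x = torch.randn(2, 14, 14, 96)              # (B, H, W, C), 14 divisible by 7
windows = window_partition(x, window_size=7)
print(windows.shape)                         # torch.Size([8, 7, 7, 96]) -- 2 images x 4 windows
restored = window_reverse(windows, 7, 14, 14)
assert torch.equal(restored, x)              # partition/reverse is lossless
```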
#### 3. Channel Grouping
Channel grouping divides feature channels into groups with separate transformations, reducing computation:
```python
class GroupedChannelAttention(nn.Module):
    def __init__(self, dim, num_heads=8, num_groups=4):
        super().__init__()
        self.num_groups = num_groups
        self.group_dim = dim // num_groups
        self.heads_per_group = num_heads // num_groups
        # Separate attention per channel group
        self.group_attns = nn.ModuleList([
            SpikingSelfAttention(self.group_dim, self.heads_per_group)
            for _ in range(num_groups)
        ])

    def forward(self, x):
        """
        Args:
            x: (B, T, N, C) spike tensor, N = H*W patches
        """
        # Split the channel dimension into groups
        x_groups = torch.chunk(x, self.num_groups, dim=-1)
        # Process each group independently
        outputs = [attn(x_g) for x_g, attn in zip(x_groups, self.group_attns)]
        # Concatenate groups back along channels
        return torch.cat(outputs, dim=-1)
```
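Channel grouping cuts projection parameters roughly by the number of groups. A quick count (a sketch that tallies only the Q/K/V/output linear layers defined in `SpikingSelfAttention` below, ignoring biases):
```python
def qkv_params(dim):
    # Four dim x dim projections: Q, K, V, and output
    return 4 * dim * dim

dim, num_groups = 768, 4
full = qkv_params(dim)                                 # 2,359,296
grouped = num_groups * qkv_params(dim // num_groups)   # 589,824
print(full / grouped)                                  # 4.0 -- shrinks by num_groups
```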
### Spiking Self-Attention
#### Leaky Integrate-and-Fire Attention
```python
class SpikingSelfAttention(nn.Module):
    def __init__(self, dim, num_heads, tau=2.0, v_threshold=1.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # LIF neuron parameters
        self.tau = tau
        self.v_threshold = v_threshold
        # Q, K, V projections
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def lif_forward(self, x, v_mem):
        """Leaky Integrate-and-Fire neuron with hard reset.

        Note: the hard threshold below has zero gradient almost everywhere;
        for training, swap it for a surrogate spike function
        (e.g. SpikingJelly's surrogate.ATan).
        """
        v_mem = v_mem + (x - v_mem) / self.tau       # leaky integration
        spike = (v_mem >= self.v_threshold).float()  # fire on threshold crossing
        v_mem = v_mem * (1 - spike)                  # hard reset of fired neurons
        return spike, v_mem

    def forward(self, x):
        """
        Args:
            x: (B, T, N, C) where N = H*W (number of patches)
        """
        B, T, N, C = x.shape
        # Initialize membrane potentials
        v_q = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device)
        v_k = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device)
        v_v = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device)
        v_out = torch.zeros(B, N, C, device=x.device)
        outputs = []
        for t in range(T):
            # Project to Q, K, V
            q = self.q_proj(x[:, t])  # (B, N, C)
            k = self.k_proj(x[:, t])
            v = self.v_proj(x[:, t])
            # Reshape for multi-head attention
            q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            # LIF dynamics convert activations to binary spikes
            spike_q, v_q = self.lif_forward(q, v_q)
            spike_k, v_k = self.lif_forward(k, v_k)
            spike_v, v_v = self.lif_forward(v, v_v)
            # Event-driven attention: skip the matmuls when no Q/K events fired
            if spike_q.any() or spike_k.any():
                attn = (spike_q @ spike_k.transpose(-2, -1)) * self.scale
                attn = attn.softmax(dim=-1)
                out = attn @ spike_v
            else:
                out = torch.zeros_like(spike_v)
            # Reshape and project
            out = out.transpose(1, 2).reshape(B, N, C)
            out = self.out_proj(out)
            # Output LIF
            spike_out, v_out = self.lif_forward(out, v_out)
            outputs.append(spike_out)
        return torch.stack(outputs, dim=1)  # (B, T, N, C)
```
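A shape and sparsity check (a minimal sketch; inputs are random spikes for illustration):
```python
import torch

attn = SpikingSelfAttention(dim=192, num_heads=4)
x = torch.randint(0, 2, (2, 4, 49, 192)).float()  # (B, T, N, C), 7x7 = 49 patches
out = attn(x)
print(out.shape)           # torch.Size([2, 4, 49, 192])
print(out.mean().item())   # average firing rate of the output spike train
```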
## Implementation Guide
### Prerequisites
- Python 3.8+
- PyTorch 1.12+ with CUDA support
- SpikingJelly or custom SNN framework
- timm for Vision Transformer utilities
### Step-by-Step: Building Ge²mS-T
1. **Complete Model Architecture**
```python
import torch
import torch.nn as nn


class Ge2mSTBlock(nn.Module):
    """Ge²mS-T Transformer Block with multi-dimensional grouping.

    WindowedSpikingAttention and GroupedSpikingMLP are assumed helper
    modules built from the components above (see the MLP sketch below).
    """
    def __init__(self, dim, num_heads, window_size=7,
                 temporal_group=2, channel_groups=4):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.temporal_group = temporal_group
        # Temporal grouping folds timesteps into channels, so downstream
        # layers operate on dim * temporal_group channels
        grouped_dim = dim * temporal_group
        self.temporal_grouping = TemporalGrouping(temporal_group)
        # Spatial (window) attention
        self.norm1 = nn.LayerNorm(grouped_dim)
        self.attn = WindowedSpikingAttention(
            grouped_dim, num_heads, window_size, channel_groups
        )
        # MLP with channel grouping
        self.norm2 = nn.LayerNorm(grouped_dim)
        mlp_hidden = grouped_dim * 4
        self.mlp = GroupedSpikingMLP(grouped_dim, mlp_hidden, channel_groups)

    def forward(self, x):
        """
        Args:
            x: (B, T, H, W, C) spike input
        """
        B, T, H, W, C = x.shape
        G = self.temporal_group
        # Temporal grouping
        x = x.permute(0, 1, 4, 2, 3)   # (B, T, C, H, W)
        x = self.temporal_grouping(x)  # (B, T//G, C*G, H, W)
        x = x.permute(0, 1, 3, 4, 2)   # (B, T//G, H, W, C*G)
        # Window attention with residual
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)  # windowed spiking attention
        x = shortcut + x
        # MLP with residual
        shortcut = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = shortcut + x
        # Reverse temporal grouping: unfold channels back into timesteps
        x = x.permute(0, 1, 4, 2, 3)   # (B, T//G, C*G, H, W)
        x = x.view(B, T // G, C, G, H, W)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, T, C, H, W)
        x = x.permute(0, 1, 3, 4, 2)   # (B, T, H, W, C)
        return x


class Ge2mSViT(nn.Module):
    """Complete Ge²mS-T Vision Transformer"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, window_size=7, temporal_group=2,
                 channel_groups=4, time_steps=4):
        super().__init__()
        self.time_steps = time_steps
        self.patch_embed = nn.Conv2d(in_chans, embed_dim,
                                     kernel_size=patch_size,
                                     stride=patch_size)
        # Ge²mS-T blocks
        self.blocks = nn.ModuleList([
            Ge2mSTBlock(embed_dim, num_heads, window_size,
                        temporal_group, channel_groups)
            for _ in range(depth)
        ])
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W) static image
        Returns:
            logits: (B, num_classes)
        """
        B = x.shape[0]
        # Repeat the image over timesteps (rate-coding approximation)
        x = x.unsqueeze(1).repeat(1, self.time_steps, 1, 1, 1)
        # Add noise for stochastic spike generation
        x = (x + torch.randn_like(x) * 0.1).clamp(0, 1)
        # Patch embedding per timestep
        spikes = []
        for t in range(self.time_steps):
            patch = self.patch_embed(x[:, t])         # (B, embed_dim, H//P, W//P)
            patch = patch.flatten(2).transpose(1, 2)  # (B, N, embed_dim)
            spikes.append(patch)
        x = torch.stack(spikes, dim=1)                # (B, T, N, embed_dim)
        H = W = int(x.shape[2] ** 0.5)                # assumes a square patch grid
        x = x.view(B, self.time_steps, H, W, -1)
        # Apply Ge²mS-T blocks
        for block in self.blocks:
            x = block(x)
        # Global average pooling over time and space
        x = x.mean(dim=(1, 2, 3))  # (B, embed_dim)
        x = self.norm(x)
        # Classification
        logits = self.head(x)
        return logits
```
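`WindowedSpikingAttention` and `GroupedSpikingMLP` are referenced above but not defined in this skill. For concreteness, here is a minimal sketch of one plausible `GroupedSpikingMLP` (grouped 1×1 convolutions as per-group linear layers, with a stateless threshold spike); this is an illustrative assumption, not the paper's exact design:
```python
class GroupedSpikingMLP(nn.Module):
    """Sketch: channel-grouped two-layer MLP with a spiking nonlinearity."""
    def __init__(self, dim, hidden_dim, num_groups, v_threshold=1.0):
        super().__init__()
        # Grouped 1x1 convolutions implement independent per-group linears
        self.fc1 = nn.Conv2d(dim, hidden_dim, 1, groups=num_groups)
        self.fc2 = nn.Conv2d(hidden_dim, dim, 1, groups=num_groups)
        self.v_threshold = v_threshold

    def forward(self, x):
        # x: (B, T, H, W, C) -> merge batch and time for the 1x1 convs
        B, T, H, W, C = x.shape
        x = x.reshape(B * T, H, W, C).permute(0, 3, 1, 2)  # (B*T, C, H, W)
        x = self.fc1(x)
        x = (x >= self.v_threshold).float()  # stateless spike (sketch only)
        x = self.fc2(x)
        x = x.permute(0, 2, 3, 1).reshape(B, T, H, W, C)
        return x
```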
2. **Training with Surrogate Gradients**
```python
from spikingjelly.clock_driven import surrogate

def train_ge2mst(model, train_loader, epochs=300):
    """Train Ge²mS-T with surrogate gradients"""
    # Surrogate gradient for backprop through the non-differentiable spikes.
    # Note: this only takes effect if the model's spike function uses it,
    # e.g. `spike = surrogate_fn(v_mem - v_threshold)` inside lif_forward.
    surrogate_fn = surrogate.ATan(alpha=2.0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            # Forward pass
            outputs = model(images)
            # Loss
            loss = criterion(outputs, labels)
            # Backward pass flows through the surrogate gradients
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
        scheduler.step()
        acc = 100 * correct / len(train_loader.dataset)
        print(f"Epoch {epoch}: Loss={total_loss / len(train_loader):.4f}, Acc={acc:.2f}%")
    return model
```
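If you are not using SpikingJelly, a surrogate spike function can be written as a self-contained custom autograd op. A minimal sketch (the arctan-shaped gradient follows SpikingJelly's ATan convention; treat the exact constants as an assumption):
```python
import torch

class ATanSpike(torch.autograd.Function):
    """Heaviside forward, arctan-shaped surrogate gradient backward."""
    @staticmethod
    def forward(ctx, x, alpha=2.0):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return (x >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # d/dx [ (1/pi) * atan(pi/2 * alpha * x) + 1/2 ]
        grad = alpha / 2 / (1 + (torch.pi / 2 * alpha * x) ** 2)
        return grad_output * grad, None

# Drop-in replacement inside lif_forward:
#   spike = ATanSpike.apply(v_mem - self.v_threshold)
```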
## Performance Benchmarks
| Model | Dataset | Accuracy | Energy (mJ) | Energy Reduction |
|-------|---------|----------|-------------|------------------|
| ViT-B/16 | ImageNet | 81.8% | 1250 | 1× |
| S-ViT (baseline) | ImageNet | 74.2% | 42 | 29.8× |
| Ge²mS-T | ImageNet | 78.5% | 18 | 69.4× |
| Ge²mS-T (small) | CIFAR-100 | 86.3% | 2.1 | 595× |
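Energy figures for SNNs are typically estimated analytically from operation counts rather than measured. A back-of-envelope model for context (the 45 nm per-op energies below are widely cited literature assumptions, not values from this paper, and the example numbers are illustrative only):
```python
# Common analytical energy model for comparing ANNs and SNNs
E_MAC = 4.6e-12  # J per multiply-accumulate (ANN, 45 nm estimate -- assumption)
E_AC = 0.9e-12   # J per accumulate (SNN sparse additions -- assumption)

def ann_energy(macs):
    return macs * E_MAC

def snn_energy(macs, firing_rate, time_steps):
    # Synaptic operations scale with spike sparsity and timestep count
    sops = macs * firing_rate * time_steps
    return sops * E_AC

# Illustrative example: 17.6 GMACs, 15% firing rate, 4 timesteps
print(ann_energy(17.6e9) * 1e3, "mJ")           # ~81 mJ
print(snn_energy(17.6e9, 0.15, 4) * 1e3, "mJ")  # ~9.5 mJ
```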
## Applications
### 1. Edge Vision Systems
- **Smart cameras**: Real-time object detection with milliwatt power
- **Drone navigation**: Visual SLAM on battery-constrained UAVs
- **Mobile AR**: Efficient scene understanding on smartphones
### 2. Neuromorphic Sensors
- **Event cameras**: Direct processing of DVS output
- **Always-on vision**: Continuous monitoring with ultra-low power
- **Industrial inspection**: High-speed defect detection
### 3. Sustainable AI
- **Data center reduction**: 70× energy reduction for inference
- **Carbon footprint**: Lower emissions for large-scale vision tasks
- **Renewable-powered AI**: Viable on solar/battery systems
## Pitfalls
### Accuracy-Efficiency Tradeoff
- **Problem**: Grouping reduces model capacity
- **Solution**: Progressive grouping (less in early layers); knowledge distillation
### Temporal Grouping Artifacts
- **Problem**: Grouping timesteps can lose temporal precision
- **Solution**: Use small group sizes (2-4); attention across groups
### Window Boundary Effects
- **Problem**: Objects spanning window boundaries handled poorly
- **Solution**: Shifted window attention (Swin-style); cross-window connections
### Training Instability
- **Problem**: Spiking transformers can be hard to train
- **Solution**: Layer normalization before attention; warm-up schedule; gradient clipping (see the sketch below)
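A sketch of warm-up plus gradient clipping wired into the training loop from the implementation guide (the warm-up length and clipping norm are arbitrary choices here, not values from the paper):
```python
import torch

def make_scheduler(optimizer, warmup_epochs, total_epochs):
    # Linear warm-up followed by cosine decay
    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_epochs)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=total_epochs - warmup_epochs)
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, [warmup, cosine], milestones=[warmup_epochs])

# Inside the training loop, clip gradients before the optimizer step:
#   loss.backward()
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#   optimizer.step()
```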
## Related Skills
- bsvit-burst-spiking-vision-transformer: Burst spiking ViT
- winner-take-all-spiking: WTA spiking transformer
- quantized-snn-hardware-optimization: Quantization for SNN hardware
## References
```bibtex
@article{chen2026gemst,
title={Ge$^2$mS-T: Multi-Dimensional Grouping for Ultra-High Energy Efficiency in Spiking Transformer},
author={Chen, Qinyu and Song, Zhenxi and Zhang, Ziyang and others},
journal={arXiv preprint arXiv:2604.08894},
year={2026}
}
```