从零实现CoAtNet:融合CNN与Transformer的实战指南
当计算机视觉领域在CNN与Transformer之间摇摆不定时,CoAtNet给出了一个优雅的解决方案。这个将卷积神经网络(CNN)的局部感知优势与Transformer的全局建模能力相结合的架构,正在改变我们处理视觉任务的方式。本文将带你从PyTorch代码层面深入理解这一混合架构的精妙之处,特别适合那些已经了解基础理论但渴望动手实践的开发者。我们将从环境搭建开始,逐步构建完整的模型,并分享实际训练中的关键技巧。
1. 环境配置与数据准备
在开始构建CoAtNet之前,确保你的开发环境满足以下要求:
- PyTorch 1.8+(建议使用1.10版本)
- CUDA 11.1+(如果使用GPU加速)
- torchvision 0.9+
- 至少16GB内存(处理ImageNet等大型数据集时)
核心依赖安装命令:
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 pip install numpy tqdm matplotlib对于数据准备,我们使用标准的ImageNet数据集结构:
imagenet/ ├── train/ │ ├── n01440764/ │ │ ├── n01440764_10026.JPEG │ │ └── ... │ └── ... └── val/ ├── n01440764/ │ ├── ILSVRC2012_val_00000293.JPEG │ └── ... └── ...提示:如果完整ImageNet数据集过大,可以从torchvision.datasets.ImageNet开始,或使用CIFAR-10/100作为调试数据集
数据增强策略对模型性能至关重要,我们采用以下组合:
from torchvision import transforms train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) val_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ])2. MBConv模块的PyTorch实现
MBConv(Mobile Inverted Bottleneck Convolution)是CoAtNet的基础构建块,它结合了深度可分离卷积和残差连接的优势。让我们分解其实现细节:
MBConv的核心结构:
- 扩展层(1x1卷积,扩展通道数)
- 深度卷积(3x3或5x5)
- SE(Squeeze-and-Excitation)注意力模块
- 投影层(1x1卷积,压缩通道数)
- 残差连接(当输入输出维度匹配时)
import torch import torch.nn as nn import torch.nn.functional as F class MBConv(nn.Module): def __init__(self, in_channels, out_channels, expansion=4, kernel_size=3, stride=1): super().__init__() hidden_dim = in_channels * expansion self.use_residual = in_channels == out_channels and stride == 1 layers = [] # 扩展层 if expansion != 1: layers.append(nn.Conv2d(in_channels, hidden_dim, 1, bias=False)) layers.append(nn.BatchNorm2d(hidden_dim)) layers.append(nn.SiLU()) # Swish激活 # 深度卷积 layers.append(nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, kernel_size//2, groups=hidden_dim, bias=False)) layers.append(nn.BatchNorm2d(hidden_dim)) layers.append(nn.SiLU()) # SE模块 layers.append(SELayer(hidden_dim)) # 投影层 layers.append(nn.Conv2d(hidden_dim, out_channels, 1, bias=False)) layers.append(nn.BatchNorm2d(out_channels)) self.block = nn.Sequential(*layers) self.stride = stride def forward(self, x): if self.use_residual: return x + self.block(x) return self.block(x) class SELayer(nn.Module): def __init__(self, channel, reduction=4): super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction), nn.SiLU(), nn.Linear(channel // reduction, channel), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x)注意:MBConv中的SiLU激活函数(也称为Swish)比ReLU在深层网络中表现更好,这是EfficientNet系列的重要发现
MBConv与标准卷积的对比:
| 特性 | MBConv | 标准卷积 |
|---|---|---|
| 参数量 | 少 (depthwise分离) | 多 |
| 计算量 | 低 (扩展-压缩策略) | 高 |
| 感受野 | 局部固定 | 局部固定 |
| 残差连接 | 有 | 可选 |
| 通道交互 | 深度分离+SE | 完全交互 |
3. 相对自注意力机制的实现
相对自注意力是CoAtNet的另一核心组件,它通过引入位置偏置来增强标准自注意力的位置感知能力。与ViT的绝对位置编码不同,相对自注意力考虑查询和键之间的相对位置关系。
相对自注意力的关键改进:
- 用相对位置偏置替代绝对位置编码
- 保持平移等变性(与CNN兼容)
- 计算效率优化(避免全局注意力)
class RelativeSelfAttention(nn.Module): def __init__(self, dim, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads self.heads = heads self.scale = dim_head ** -0.5 self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Linear(inner_dim, dim) # 相对位置偏置表 self.rel_pos_bias = nn.Parameter(torch.randn((2 * 7 - 1) * (2 * 7 - 1), heads)) # 生成相对位置索引 coords = torch.arange(7) relative_coords = coords[:, None] - coords[None, :] # 7x7 relative_coords += 7 - 1 # 转换为正数 self.register_buffer('relative_index', relative_coords) self.dropout = nn.Dropout(dropout) def forward(self, x): B, N, C = x.shape qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: t.view(B, N, self.heads, -1).transpose(1, 2), qkv) # 计算注意力分数 dots = (q @ k.transpose(-2, -1)) * self.scale # 添加相对位置偏置 rel_pos_bias = self.rel_pos_bias[ self.relative_index[:, :, None] * (2 * 7 - 1) + self.relative_index[:, None, :] ].permute(2, 0, 1) # h x 7 x 7 dots += rel_pos_bias.unsqueeze(0) attn = dots.softmax(dim=-1) attn = self.dropout(attn) out = (attn @ v).transpose(1, 2).reshape(B, N, -1) return self.to_out(out)相对自注意力与传统自注意力的对比实验:
| 指标 | 相对自注意力 | 传统自注意力 |
|---|---|---|
| ImageNet Top-1 | 81.2% | 80.5% |
| 训练稳定性 | 更高 (梯度更稳定) | 较低 |
| 位置敏感度 | 相对位置感知 | 依赖显式编码 |
| 计算复杂度 | O(N^2d) | O(N^2d) |
| 内存占用 | 略高 (位置表) | 较低 |
4. CoAtNet的完整架构实现
现在我们将MBConv和相对自注意力模块组合成完整的CoAtNet架构。根据原论文,我们采用S0-CTT配置(第一阶段卷积,后接两个Transformer阶段)。
class CoAtNet(nn.Module): def __init__(self, num_classes=1000, channels=[64, 96, 192, 384, 768], num_blocks=[2, 2, 6, 14, 2], block_types=['C', 'C', 'T', 'T']): super().__init__() # 初始卷积下采样 self.stem = nn.Sequential( nn.Conv2d(3, channels[0], 3, stride=2, padding=1), nn.BatchNorm2d(channels[0]), nn.SiLU(), nn.Conv2d(channels[0], channels[0], 3, stride=1, padding=1), nn.BatchNorm2d(channels[0]), nn.SiLU() ) # 构建各阶段 self.stages = nn.ModuleList() for i in range(4): # S0-S3 stage = [] # 下采样 if i > 0: stage.append(nn.Conv2d(channels[i], channels[i+1], 3, stride=2, padding=1)) stage.append(nn.BatchNorm2d(channels[i+1])) stage.append(nn.SiLU()) # 添加块 for _ in range(num_blocks[i]): if block_types[i] == 'C': stage.append(MBConv(channels[i+1], channels[i+1])) else: # 将特征图转换为序列 stage.append(nn.Sequential( Rearrange('b c h w -> b (h w) c'), RelativeSelfAttention(channels[i+1]), Rearrange('b (h w) c -> b c h w', h=7 if i==2 else 14) )) self.stages.append(nn.Sequential(*stage)) # 分类头 self.head = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(channels[-1], num_classes) ) def forward(self, x): x = self.stem(x) for stage in self.stages: x = stage(x) return self.head(x)CoAtNet各阶段配置详解:
| 阶段 | 块类型 | 输出尺寸 | 通道数 | 块数量 | 关键操作 |
|---|---|---|---|---|---|
| S0 | Conv | 112x112 | 64 | 2 | 初始下采样 |
| S1 | Conv | 56x56 | 96 | 2 | MBConv堆叠 |
| S2 | Transformer | 28x28 | 192 | 6 | 相对自注意力 |
| S3 | Transformer | 14x14 | 384 | 14 | 相对自注意力 |
| S4 | - | 7x7 | 768 | - | 全局池化 |
5. 训练技巧与实战经验
实现模型只是第一步,正确的训练策略同样重要。以下是我们在复现CoAtNet时积累的关键经验:
1. 学习率调度与优化器选择
使用AdamW优化器配合余弦退火学习率调度:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)2. 混合精度训练
大幅减少显存占用并加速训练:
scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): outputs = model(inputs) loss = criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()3. 梯度裁剪
防止Transformer层的梯度爆炸:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)4. 常见问题与解决方案
显存不足:减小batch size或使用梯度累积
if (i+1) % accum_steps == 0: optimizer.step() optimizer.zero_grad()训练不稳定:增加warmup阶段
def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): def f(x): if x >= warmup_iters: return 1 alpha = float(x) / warmup_iters return warmup_factor * (1 - alpha) + alpha return torch.optim.lr_scheduler.LambdaLR(optimizer, f)过拟合:加强数据增强或添加Dropout
5. 性能基准测试
在ImageNet-1k验证集上的预期表现:
| 模型变体 | 参数量 | Top-1准确率 | 训练epochs |
|---|---|---|---|
| CoAtNet-0 | 25M | 81.5% | 300 |
| CoAtNet-1 | 42M | 83.3% | 300 |
| CoAtNet-2 | 75M | 84.7% | 300 |
| CoAtNet-3 | 168M | 85.8% | 300 |
在实际项目中,我们发现CoAtNet特别适合中等规模数据集(10万-100万图像),它平衡了CNN的样本效率和Transformer的表达能力。当遇到显存限制时,可以尝试以下调整:
# 降低注意力头的维度 RelativeSelfAttention(dim=256, heads=4, dim_head=64) # 减少Transformer块数量 num_blocks=[2, 2, 4, 8, 2] # 原为[2,2,6,14,2]