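Stop Agonizing over CNN vs. Transformer! A Hands-On Guide to Reproducing CoAtNet in PyTorch (with Code and Pitfalls to Avoid)
While the computer vision community was still arguing over whether CNNs or Transformers are better, CoAtNet had already shown experimentally that you can have both. This hybrid architecture, which set a new state of the art on ImageNet, combines MBConv blocks with relative self-attention to obtain both the translation equivariance of convolution and the global modeling capacity of attention. This article walks you through implementing a complete CoAtNet model from scratch, focusing on the engineering details the paper does not spell out.
1. Environment Setup and Data Preparation
1.1 Hardware and Software
A GPU with at least 16 GB of memory (for example an RTX 3090 or A100) is recommended, because self-attention is memory-hungry. Required software versions: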
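
    torch==1.12.0+cu113   # must support mixed-precision training
    torchvision==0.13.0
    timm==0.6.7           # used for loading pretrained weights

Note: on Colab, add !pip install -U torch torchvision timm at the top of the notebook and enable GPU acceleration. Be aware that -U installs the latest releases rather than the versions pinned above.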
1.2 Dataset Handling
Taking ImageNet-1K as the example, loading the data with torchvision.datasets.ImageFolder requires attention to the following:

    from torchvision import transforms

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

Common pitfalls:
- The image size must match the model input exactly (224×224 by default)
- The normalization parameters must match those used for the pretrained weights
- Multi-GPU training requires a DistributedSampler
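
The last pitfall can be illustrated with a small sketch. It is not the article's training setup: the `TensorDataset` is a stand-in for the `ImageFolder` dataset, `make_loader` is a hypothetical helper, and `num_replicas` and `rank` are passed explicitly only so the example runs without an initialized process group (under `torch.distributed` they are normally inferred).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Stand-in dataset (hypothetical); in practice this would be the ImageFolder dataset.
dataset = TensorDataset(torch.arange(100).float().unsqueeze(1),
                        torch.zeros(100, dtype=torch.long))


def make_loader(dataset, rank, world_size, batch_size=10, epoch=0):
    # Each rank receives a disjoint shard of the (shuffled) index list.
    sampler = DistributedSampler(dataset, num_replicas=world_size,
                                 rank=rank, shuffle=True, seed=0)
    # Call set_epoch at the start of every epoch so the shuffle order changes.
    sampler.set_epoch(epoch)
    # Do not also pass shuffle=True to the DataLoader; the sampler handles it.
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader, sampler


loader0, sampler0 = make_loader(dataset, rank=0, world_size=2)
loader1, sampler1 = make_loader(dataset, rank=1, world_size=2)
```

Forgetting `set_epoch` is a common source of silently repeated sample order across epochs.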
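2. Model Architecture in Detail
2.1 Dissecting the MBConv Block
The MBConv block used here differs from the one in MobileNetV2 mainly in that:
- The expansion ratio is fixed at 4 (MobileNetV2 defaults to 6)
- Swish (SiLU) replaces ReLU6 as the activation
- An SE (Squeeze-Excitation) module is added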
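
For readers unfamiliar with the SE module in the last bullet, here is a minimal sketch of the idea: pool each channel to a single value, pass it through a small bottleneck, and use the result to rescale the channels. The class name `SEBlock` and the `reduction=4` default are assumptions made for illustration; `torchvision.ops.SqueezeExcitation` offers a library version.

```python
import torch
import torch.nn as nn


class SEBlock(nn.Module):
    """Squeeze-and-Excitation: reweight channels using globally pooled statistics."""

    def __init__(self, channels, reduction=4):  # reduction=4 is an assumed value
        super().__init__()
        squeezed = max(1, channels // reduction)
        self.pool = nn.AdaptiveAvgPool2d(1)          # squeeze: (B, C, H, W) -> (B, C, 1, 1)
        self.fc1 = nn.Conv2d(channels, squeezed, 1)  # bottleneck
        self.act = nn.SiLU()
        self.fc2 = nn.Conv2d(squeezed, channels, 1)  # back to C channels
        self.gate = nn.Sigmoid()                     # per-channel weight in (0, 1)

    def forward(self, x):
        scale = self.gate(self.fc2(self.act(self.fc1(self.pool(x)))))
        return x * scale                             # broadcast over H and W


se = SEBlock(64)
out = se(torch.randn(2, 64, 8, 8))
```

The output has the same shape as the input, so the block can be dropped between the depthwise convolution and the projection without changing anything else.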
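
    import torch.nn as nn
    from torchvision.ops import SqueezeExcitation


    class MBConv(nn.Module):
        def __init__(self, in_channels, out_channels, stride=1):
            super().__init__()
            hidden_dim = in_channels * 4  # expansion ratio 4
            self.conv = nn.Sequential(
                # 1x1 expansion
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),  # Swish activation
                # 3x3 depthwise convolution
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # SE module; torchvision's version needs the squeeze width
                # explicitly (hidden_dim // 4 is an assumed value)
                SqueezeExcitation(hidden_dim, hidden_dim // 4, activation=nn.SiLU),
                # 1x1 projection
                nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            # The residual is only valid when input and output shapes match;
            # adding it unconditionally fails for stride 2 or changed channels.
            self.use_residual = in_channels == out_channels and stride == 1

        def forward(self, x):
            out = self.conv(x)
            return x + out if self.use_residual else out

2.2 Key Points of the Relative Self-Attention Implementation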
The absolute positional encoding of a standard Transformer breaks translation equivariance, so CoAtNet uses a relative position bias instead:
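
    import torch
    import torch.nn as nn


    class RelativeAttention(nn.Module):
        def __init__(self, dim, heads=8, window=7):
            super().__init__()
            self.heads = heads
            self.scale = (dim // heads) ** -0.5
            # Relative position bias table: one row per 2-D offset, one column per head
            self.rel_pos_bias = nn.Parameter(torch.randn((2 * window - 1) ** 2, heads))
            pos = torch.arange(window)
            grid = torch.stack(torch.meshgrid(pos, pos, indexing='ij'), dim=-1).view(-1, 2)  # (N, 2)
            rel_pos = grid[:, None, :] - grid[None, :, :]  # (N, N, 2) offsets (dy, dx)
            rel_pos += window - 1  # shift to non-negative indices
            # Flatten (dy, dx) into a single row index of the bias table
            rel_pos_index = rel_pos[..., 0] * (2 * window - 1) + rel_pos[..., 1]
            self.register_buffer('rel_pos_index', rel_pos_index)

        def forward(self, q, k, v):
            B, N, C = q.shape  # N must equal window * window
            q = q.view(B, N, self.heads, -1).transpose(1, 2)
            k = k.view(B, N, self.heads, -1).transpose(1, 2)
            v = v.view(B, N, self.heads, -1).transpose(1, 2)
            # Attention scores: (B, heads, N, N)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            # Look up the bias as (heads, N, N) and broadcast it over the batch
            bias = self.rel_pos_bias[self.rel_pos_index.view(-1)].view(N, N, -1).permute(2, 0, 1)
            attn = attn + bias
            attn = attn.softmax(dim=-1)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
            return x

Note: a real implementation has to handle arbitrary input sizes; the code above simplifies the position-index generation to a fixed 7×7 grid.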
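
To make the index arithmetic concrete for grids other than 7×7, here is a small pure-Python sketch. The function name `relative_position_index` is hypothetical, and this is only one straightforward way to build the lookup: every (query, key) pair is mapped to the row of the bias table that corresponds to their 2-D offset.

```python
def relative_position_index(height, width):
    """Return an N x N table of flat bias-table indices, with N = height * width.

    Entry [i][j] identifies the 2-D offset between query token i and key
    token j. The bias table needs (2*height - 1) * (2*width - 1) rows.
    """
    coords = [(y, x) for y in range(height) for x in range(width)]
    span_w = 2 * width - 1  # number of distinct horizontal offsets
    index = []
    for qy, qx in coords:
        row = []
        for ky, kx in coords:
            dy = qy - ky + (height - 1)  # shifted into 0 .. 2*height - 2
            dx = qx - kx + (width - 1)   # shifted into 0 .. 2*width - 2
            row.append(dy * span_w + dx)
        index.append(row)
    return index


idx = relative_position_index(2, 3)
```

In a PyTorch module the result would typically be wrapped in `torch.tensor(...)` and stored with `register_buffer`, then recomputed (or the bias table interpolated) whenever the feature-map size changes.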
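3. Building and Initializing the Full Model
3.1 Multi-Stage Network Design
CoAtNet uses a five-stage architecture (S0-S4):
- S0: standard convolution layers
- S1-S2: MBConv blocks
- S3-S4: relative self-attention blocks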
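
The model code in this section relies on a `TransformerStage` module whose definition is not shown. The following is a minimal sketch under stated assumptions, not the article's confirmed design: it uses `nn.MultiheadAttention` as a stand-in for the relative self-attention block (so it has no relative position bias), and the max-pool downsampling, 1×1 projection, pre-norm layout, and 4× FFN width are all assumed choices. `TransformerBlock` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """Pre-norm attention + FFN over a token sequence of shape (B, N, C)."""

    def __init__(self, dim, heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # Stand-in for relative self-attention (assumption): plain multi-head attention
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.ffn(self.norm2(x))


class TransformerStage(nn.Module):
    """Maps (B, C_in, H, W) to (B, C_out, H/2, W/2)."""

    def __init__(self, in_channels, out_channels, depth, heads):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # halve the resolution
        self.proj = nn.Conv2d(in_channels, out_channels, 1)           # change channel width
        self.blocks = nn.ModuleList([TransformerBlock(out_channels, heads) for _ in range(depth)])

    def forward(self, x):
        x = self.proj(self.pool(x))
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)              # (B, H*W, C) token sequence
        for block in self.blocks:
            x = block(x)
        return x.transpose(1, 2).reshape(B, C, H, W)  # back to a feature map


stage = TransformerStage(192, 384, depth=2, heads=8)
out = stage(torch.randn(1, 192, 28, 28))
```

Returning a feature map rather than a token sequence keeps the stage compatible with the global average pooling over dimensions 2 and 3 used in the model's forward pass.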

    import torch.nn as nn


    class CoAtNet(nn.Module):
        def __init__(self, num_classes=1000):
            super().__init__()
            # Stage 0: downsample to 1/2
            self.s0 = nn.Sequential(
                nn.Conv2d(3, 64, 3, 2, 1),
                nn.BatchNorm2d(64),
                nn.SiLU(),
            )
            # Stages 1-2: MBConv
            self.s1 = MBConv(64, 96, stride=2)
            self.s2 = MBConv(96, 192, stride=2)
            # Stages 3-4: Transformer. TransformerStage is assumed to map
            # (B, C_in, H, W) to (B, C_out, H/2, W/2).
            self.s3 = TransformerStage(192, 384, depth=6, heads=8)
            self.s4 = TransformerStage(384, 768, depth=14, heads=16)
            self.head = nn.Linear(768, num_classes)

        def forward(self, x):
            x = self.s0(x)  # 112x112
            x = self.s1(x)  # 56x56
            x = self.s2(x)  # 28x28
            x = self.s3(x)  # 14x14
            x = self.s4(x)  # 7x7
            x = x.mean(dim=[2, 3])  # global average pooling
            return self.head(x)

3.2 Weight Initialization Tips
A hybrid architecture needs a different initialization for each layer type:

    import torch.nn as nn


    def init_weights(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)


    # model is an already constructed CoAtNet instance
    model.apply(init_weights)

4. Training Optimization and Debugging Tips
4.1 Mixed-Precision Training Setup
PyTorch's AMP (automatic mixed precision) can cut memory usage substantially:

    from torch.cuda.amp import autocast, GradScaler

    # model, criterion, optimizer and train_loader are assumed to be defined already
    scaler = GradScaler()
    for inputs, targets in train_loader:
        inputs, targets = inputs.cuda(), targets.cuda()
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

4.2 Learning Rate Schedule
Cosine annealing combined with a linear warmup is recommended:
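
    # SequentialLR must be imported as well; it was missing from the original import
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

    warmup_epochs = 5
    total_epochs = 300
    scheduler1 = LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs)
    scheduler2 = CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs)
    # Call scheduler.step() once per epoch, since the milestones are counted in epochs
    scheduler = SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_epochs])

4.3 Common Errors and Fixes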
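| Error | Likely cause | Fix |
|---|---|---|
| CUDA OOM | Attention matrix too large | Reduce the batch size or use gradient checkpointing |
| NaN loss | Exploding gradients (no gradient clipping) | Add nn.utils.clip_grad_norm_(model.parameters(), 1.0); under AMP, call scaler.unscale_(optimizer) first |
| Low accuracy | Improper weight initialization | Check that each layer's initialization matches the paper |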
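
Combining gradient clipping with AMP needs one extra step: the gradients must be unscaled before their norm is measured, otherwise the threshold is compared against loss-scaled values. The sketch below shows one way to arrange this; the tiny `nn.Linear` model, the `train_step` helper, and the CPU fallback (which simply disables AMP) are assumptions added so the example runs anywhere.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # on CPU the scaler and autocast are disabled

model = nn.Linear(16, 4).to(device)  # stand-in model (hypothetical)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_step(inputs, targets, max_norm=1.0):
    optimizer.zero_grad()
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # restore true gradient magnitudes before clipping
    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)      # the step is skipped if gradients contain inf/NaN
    scaler.update()
    return loss.item(), float(grad_norm)


loss, grad_norm = train_step(torch.randn(8, 16, device=device),
                             torch.randint(0, 4, (8,), device=device))
```

Logging the returned `grad_norm` over time is a cheap way to spot the spikes that usually precede a NaN loss.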
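If stages S3/S4 run out of memory, the following can help:
- Use torch.utils.checkpoint to recompute attention segment by segment
- Reduce the number of attention heads (heads)
- Reduce the number of tokens entering attention, for example with a smaller input resolution (a smaller patch size would have the opposite effect, since it increases the token count)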
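
To see why the token count dominates, it helps to work out the size of a single attention map, which has shape B × heads × N × N with N the number of tokens. The helper below is a back-of-the-envelope sketch; the function name, the batch size of 64, and the 2 bytes per value (fp16) are assumptions, and activations other than the attention map are ignored.

```python
def attention_map_bytes(image_size, stride, batch_size, heads, bytes_per_value=2):
    """Bytes for one layer's attention map, with N = (image_size // stride) ** 2 tokens."""
    tokens = (image_size // stride) ** 2
    return batch_size * heads * tokens * tokens * bytes_per_value


# S3 works at 1/16 of the input resolution and S4 at 1/32
# (14x14 and 7x7 for a 224x224 input, per the stage comments in the model code).
for size in (224, 384):
    s3 = attention_map_bytes(size, 16, batch_size=64, heads=8)
    s4 = attention_map_bytes(size, 32, batch_size=64, heads=16)
    print(f"{size}px input: S3 {s3 / 2**20:.1f} MiB, S4 {s4 / 2**20:.1f} MiB per layer")
```

Because the map grows with N², doubling the input side length multiplies its size by 16, while halving the batch size or the head count only halves it.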
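5. Fine-Tuning and Deployment Advice
5.1 Transfer Learning Tips
When fine-tuning on a small dataset:
- Freeze the parameters of the first three stages (S0-S2)
- Train only the Transformer stages and the classification head
- Use a smaller learning rate than in pretraining (typically 1e-5 to 1e-4)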
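
    for name, param in model.named_parameters():
        # Match on the stage prefix so that parameters whose names merely
        # contain "s0", "s1" or "s2" somewhere are not frozen by accident
        if name.startswith(('s0.', 's1.', 's2.')):
            param.requires_grad = False

5.2 Quantized Deployment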
Export a quantized model with TorchScript:
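
    import torch
    import torch.nn as nn

    model.eval()
    # Dynamic quantization covers nn.Linear (and recurrent layers) but not
    # nn.Conv2d, so only the Linear layers are listed here
    quantized_model = torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )
    traced_script = torch.jit.trace(quantized_model, torch.rand(1, 3, 224, 224))
    traced_script.save('coatnet_quantized.pt')

Quantization can shrink the model by up to roughly 4× (int8 weights take a quarter of the space of fp32) and speed up inference by around 2-3×, which makes it suitable for mobile deployment. Actual gains depend on how many layers are quantized and on the target hardware, so measure them on your own model.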
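
As a rough check on the size figure, the overall reduction can be estimated from the share of parameters that actually get quantized. The helper below is a simplified sketch: the function name is hypothetical, and it ignores quantization scales, zero points, and serialization overhead.

```python
def size_reduction(quantized_params, other_params):
    """Estimated fp32-to-mixed size ratio when only some parameters become int8."""
    fp32_bytes = 4 * (quantized_params + other_params)   # everything stored as fp32
    mixed_bytes = 1 * quantized_params + 4 * other_params  # int8 where quantized
    return fp32_bytes / mixed_bytes


print(size_reduction(100, 0))  # every parameter quantized
print(size_reduction(80, 20))  # 80% of the parameters quantized
```

The full 4× is only reached when every parameter is quantized; with 80% coverage the estimate drops to 2.5×, which is why convolution-heavy early stages limit the benefit of dynamic quantization.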