别再纠结选CNN还是Transformer了!手把手教你用PyTorch复现CoAtNet(附代码)
2026/6/9 3:34:59 网站建设 项目流程

从零实现CoAtNet:融合CNN与Transformer的实战指南

当计算机视觉领域在CNN与Transformer之间摇摆不定时,CoAtNet给出了一个优雅的解决方案。这个将卷积神经网络(CNN)的局部感知优势与Transformer的全局建模能力相结合的架构,正在改变我们处理视觉任务的方式。本文将带你从PyTorch代码层面深入理解这一混合架构的精妙之处,特别适合那些已经了解基础理论但渴望动手实践的开发者。我们将从环境搭建开始,逐步构建完整的模型,并分享实际训练中的关键技巧。

1. 环境配置与数据准备

在开始构建CoAtNet之前,确保你的开发环境满足以下要求:

  • PyTorch 1.8+(建议使用1.10版本)
  • CUDA 11.1+(如果使用GPU加速)
  • torchvision 0.9+
  • 至少16GB内存(处理ImageNet等大型数据集时)

核心依赖安装命令:

pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 pip install numpy tqdm matplotlib

对于数据准备,我们使用标准的ImageNet数据集结构:

imagenet/ ├── train/ │ ├── n01440764/ │ │ ├── n01440764_10026.JPEG │ │ └── ... │ └── ... └── val/ ├── n01440764/ │ ├── ILSVRC2012_val_00000293.JPEG │ └── ... └── ...

提示:如果完整ImageNet数据集过大,可以从torchvision.datasets.ImageNet开始,或使用CIFAR-10/100作为调试数据集

数据增强策略对模型性能至关重要,我们采用以下组合:

from torchvision import transforms train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) val_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ])

2. MBConv模块的PyTorch实现

MBConv(Mobile Inverted Bottleneck Convolution)是CoAtNet的基础构建块,它结合了深度可分离卷积和残差连接的优势。让我们分解其实现细节:

MBConv的核心结构:

  1. 扩展层(1x1卷积,扩展通道数)
  2. 深度卷积(3x3或5x5)
  3. SE(Squeeze-and-Excitation)注意力模块
  4. 投影层(1x1卷积,压缩通道数)
  5. 残差连接(当输入输出维度匹配时)
import torch import torch.nn as nn import torch.nn.functional as F class MBConv(nn.Module): def __init__(self, in_channels, out_channels, expansion=4, kernel_size=3, stride=1): super().__init__() hidden_dim = in_channels * expansion self.use_residual = in_channels == out_channels and stride == 1 layers = [] # 扩展层 if expansion != 1: layers.append(nn.Conv2d(in_channels, hidden_dim, 1, bias=False)) layers.append(nn.BatchNorm2d(hidden_dim)) layers.append(nn.SiLU()) # Swish激活 # 深度卷积 layers.append(nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, kernel_size//2, groups=hidden_dim, bias=False)) layers.append(nn.BatchNorm2d(hidden_dim)) layers.append(nn.SiLU()) # SE模块 layers.append(SELayer(hidden_dim)) # 投影层 layers.append(nn.Conv2d(hidden_dim, out_channels, 1, bias=False)) layers.append(nn.BatchNorm2d(out_channels)) self.block = nn.Sequential(*layers) self.stride = stride def forward(self, x): if self.use_residual: return x + self.block(x) return self.block(x) class SELayer(nn.Module): def __init__(self, channel, reduction=4): super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction), nn.SiLU(), nn.Linear(channel // reduction, channel), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x)

注意:MBConv中的SiLU激活函数(也称为Swish)比ReLU在深层网络中表现更好,这是EfficientNet系列的重要发现

MBConv与标准卷积的对比:

特性MBConv标准卷积
参数量少 (depthwise分离)
计算量低 (扩展-压缩策略)
感受野局部固定局部固定
残差连接可选
通道交互深度分离+SE完全交互

3. 相对自注意力机制的实现

相对自注意力是CoAtNet的另一核心组件,它通过引入位置偏置来增强标准自注意力的位置感知能力。与ViT的绝对位置编码不同,相对自注意力考虑查询和键之间的相对位置关系。

相对自注意力的关键改进:

  1. 用相对位置偏置替代绝对位置编码
  2. 保持平移等变性(与CNN兼容)
  3. 计算效率优化(避免全局注意力)
class RelativeSelfAttention(nn.Module): def __init__(self, dim, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads self.heads = heads self.scale = dim_head ** -0.5 self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Linear(inner_dim, dim) # 相对位置偏置表 self.rel_pos_bias = nn.Parameter(torch.randn((2 * 7 - 1) * (2 * 7 - 1), heads)) # 生成相对位置索引 coords = torch.arange(7) relative_coords = coords[:, None] - coords[None, :] # 7x7 relative_coords += 7 - 1 # 转换为正数 self.register_buffer('relative_index', relative_coords) self.dropout = nn.Dropout(dropout) def forward(self, x): B, N, C = x.shape qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: t.view(B, N, self.heads, -1).transpose(1, 2), qkv) # 计算注意力分数 dots = (q @ k.transpose(-2, -1)) * self.scale # 添加相对位置偏置 rel_pos_bias = self.rel_pos_bias[ self.relative_index[:, :, None] * (2 * 7 - 1) + self.relative_index[:, None, :] ].permute(2, 0, 1) # h x 7 x 7 dots += rel_pos_bias.unsqueeze(0) attn = dots.softmax(dim=-1) attn = self.dropout(attn) out = (attn @ v).transpose(1, 2).reshape(B, N, -1) return self.to_out(out)

相对自注意力与传统自注意力的对比实验:

指标相对自注意力传统自注意力
ImageNet Top-181.2%80.5%
训练稳定性更高 (梯度更稳定)较低
位置敏感度相对位置感知依赖显式编码
计算复杂度O(N^2d)O(N^2d)
内存占用略高 (位置表)较低

4. CoAtNet的完整架构实现

现在我们将MBConv和相对自注意力模块组合成完整的CoAtNet架构。根据原论文,我们采用S0-CTT配置(第一阶段卷积,后接两个Transformer阶段)。

class CoAtNet(nn.Module): def __init__(self, num_classes=1000, channels=[64, 96, 192, 384, 768], num_blocks=[2, 2, 6, 14, 2], block_types=['C', 'C', 'T', 'T']): super().__init__() # 初始卷积下采样 self.stem = nn.Sequential( nn.Conv2d(3, channels[0], 3, stride=2, padding=1), nn.BatchNorm2d(channels[0]), nn.SiLU(), nn.Conv2d(channels[0], channels[0], 3, stride=1, padding=1), nn.BatchNorm2d(channels[0]), nn.SiLU() ) # 构建各阶段 self.stages = nn.ModuleList() for i in range(4): # S0-S3 stage = [] # 下采样 if i > 0: stage.append(nn.Conv2d(channels[i], channels[i+1], 3, stride=2, padding=1)) stage.append(nn.BatchNorm2d(channels[i+1])) stage.append(nn.SiLU()) # 添加块 for _ in range(num_blocks[i]): if block_types[i] == 'C': stage.append(MBConv(channels[i+1], channels[i+1])) else: # 将特征图转换为序列 stage.append(nn.Sequential( Rearrange('b c h w -> b (h w) c'), RelativeSelfAttention(channels[i+1]), Rearrange('b (h w) c -> b c h w', h=7 if i==2 else 14) )) self.stages.append(nn.Sequential(*stage)) # 分类头 self.head = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(channels[-1], num_classes) ) def forward(self, x): x = self.stem(x) for stage in self.stages: x = stage(x) return self.head(x)

CoAtNet各阶段配置详解:

阶段块类型输出尺寸通道数块数量关键操作
S0Conv112x112642初始下采样
S1Conv56x56962MBConv堆叠
S2Transformer28x281926相对自注意力
S3Transformer14x1438414相对自注意力
S4-7x7768-全局池化

5. 训练技巧与实战经验

实现模型只是第一步,正确的训练策略同样重要。以下是我们在复现CoAtNet时积累的关键经验:

1. 学习率调度与优化器选择

使用AdamW优化器配合余弦退火学习率调度:

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

2. 混合精度训练

大幅减少显存占用并加速训练:

scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): outputs = model(inputs) loss = criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()

3. 梯度裁剪

防止Transformer层的梯度爆炸:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

4. 常见问题与解决方案

  • 显存不足:减小batch size或使用梯度累积

    if (i+1) % accum_steps == 0: optimizer.step() optimizer.zero_grad()
  • 训练不稳定:增加warmup阶段

    def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): def f(x): if x >= warmup_iters: return 1 alpha = float(x) / warmup_iters return warmup_factor * (1 - alpha) + alpha return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
  • 过拟合:加强数据增强或添加Dropout

5. 性能基准测试

在ImageNet-1k验证集上的预期表现:

模型变体参数量Top-1准确率训练epochs
CoAtNet-025M81.5%300
CoAtNet-142M83.3%300
CoAtNet-275M84.7%300
CoAtNet-3168M85.8%300

在实际项目中,我们发现CoAtNet特别适合中等规模数据集(10万-100万图像),它平衡了CNN的样本效率和Transformer的表达能力。当遇到显存限制时,可以尝试以下调整:

# 降低注意力头的维度 RelativeSelfAttention(dim=256, heads=4, dim_head=64) # 减少Transformer块数量 num_blocks=[2, 2, 4, 8, 2] # 原为[2,2,6,14,2]

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询