Swin Transformer Flower Classification in Practice: A Complete PyTorch Guide to 95%+ Accuracy from Scratch
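Transformer architectures are reshaping computer vision, and image classification along with it. Swin Transformer, a vision Transformer proposed by Microsoft Research Asia, combines hierarchical feature extraction with shifted-window attention, reaching accuracy comparable to strong CNNs while keeping computation efficient. This guide starts from an empty environment and uses PyTorch to reach a validation accuracy above 95% on a flower classification dataset.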
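1. Environment Setup and Data Preprocessing
Before starting, set up a development environment suitable for running Swin Transformer. Creating a separate Python environment with Anaconda is recommended to avoid dependency conflicts: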
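```bash
conda create -n swin python=3.8
conda activate swin
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# scikit-learn and seaborn are needed for the evaluation code in section 4
pip install timm matplotlib opencv-python scikit-learn seaborn
```

The flower dataset is the classic flower_photos collection: 3,670 images in 5 classes (daisy, dandelion, roses, sunflowers, tulips). The loading code in this guide does not download or split the data. It expects the images to already sit under `./data/flower_photos/train` and `./data/flower_photos/val`, with one subfolder per class.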
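Since the loading code expects separate `train/` and `val/` folders, a split step is needed first. The following is a minimal sketch using only the standard library. The function name `split_dataset`, the 80/20 ratio, and the example paths are assumptions for illustration, not part of the original setup.

```python
import random
import shutil
from pathlib import Path


def split_dataset(src_dir, dst_dir, val_ratio=0.2, seed=0):
    """Copy each class folder in src_dir into dst_dir/train and dst_dir/val.

    Returns a dict mapping class name -> (num_train, num_val).
    """
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    counts = {}
    # Only directories are treated as classes; stray files such as a license are skipped
    for class_dir in sorted(p for p in src_dir.iterdir() if p.is_dir()):
        images = sorted(p for p in class_dir.iterdir() if p.is_file())
        rng.shuffle(images)
        n_val = int(len(images) * val_ratio)
        splits = {"val": images[:n_val], "train": images[n_val:]}
        for split, files in splits.items():
            out_dir = dst_dir / split / class_dir.name
            out_dir.mkdir(parents=True, exist_ok=True)
            for f in files:
                shutil.copy2(f, out_dir / f.name)
        counts[class_dir.name] = (len(splits["train"]), len(splits["val"]))
    return counts


# Example call (hypothetical paths; the raw archive is assumed to be unpacked
# into ./data/flower_photos_raw with one folder per class):
# split_dataset("./data/flower_photos_raw", "./data/flower_photos")
```

Copying rather than moving leaves the unpacked archive untouched, so the split can be regenerated with a different ratio or seed.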
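The transforms and dataset objects can be defined as follows, with random-crop and flip augmentation for training and a deterministic resize plus center crop for validation:

```python
import os
from torchvision import transforms, datasets

# Data augmentation and normalization (ImageNet mean and std)
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the datasets (one subfolder per class)
train_dataset = datasets.ImageFolder(
    root='./data/flower_photos/train',
    transform=train_transform
)
val_dataset = datasets.ImageFolder(
    root='./data/flower_photos/val',
    transform=val_transform
)
```

How the data loaders are configured has a direct effect on training throughput. Adjust batch_size to fit the available GPU memory, and enable pin_memory to speed up host-to-GPU transfers: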
```python
import os
from torch.utils.data import DataLoader

batch_size = 32  # adjust to the GPU
# Cap the number of worker processes at 8 and at the CPU core count
num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)
```

2. Swin Transformer Architecture
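The core innovations of Swin Transformer are its hierarchical design and its shifted-window mechanism. Unlike the original Vision Transformer, which keeps a single feature-map resolution throughout, Swin Transformer downsamples the feature map step by step across four stages, each with its own number of Transformer blocks. For Swin-Tiny with a 224×224 input, the stages look like this: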
| Stage | Feature map size | Window size | Attention heads | Channels |
|---|---|---|---|---|
| 1 | 56×56 | 7×7 | 3 | 96 |
| 2 | 28×28 | 7×7 | 6 | 192 |
| 3 | 14×14 | 7×7 | 12 | 384 |
| 4 | 7×7 | 7×7 | 24 | 768 |
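The table rows follow from a few base settings. As a rough sketch, assuming the Swin-Tiny values of a 224×224 input, patch size 4, and embedding dimension 96 (the function name `swin_stage_shapes` is made up for illustration), the per-stage geometry can be derived like this:

```python
def swin_stage_shapes(img_size=224, patch_size=4, embed_dim=96,
                      num_heads=(3, 6, 12, 24), window_size=7):
    """Return per-stage geometry for a Swin-style four-stage hierarchy."""
    stages = []
    for i, heads in enumerate(num_heads):
        # Patch embedding downsamples by patch_size; each later stage halves again
        resolution = img_size // (patch_size * 2 ** i)
        # Channels double every time the resolution halves
        channels = embed_dim * 2 ** i
        stages.append({
            "stage": i + 1,
            "resolution": resolution,
            "channels": channels,
            "heads": heads,
            "head_dim": channels // heads,
            "windows": (resolution // window_size) ** 2,
        })
    return stages


for s in swin_stage_shapes():
    print(s)
```

The head count doubles together with the channel count, so the per-head dimension stays constant across stages, while the number of 7×7 windows shrinks from 64 in stage 1 to a single window in stage 4.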
The key components of a PyTorch implementation are:
Window attention:
```python
import torch
import torch.nn as nn


class WindowAttention(nn.Module):
    """Multi-head self-attention inside one window, with a relative position bias."""

    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        # Relative position bias table: one learnable bias per head and relative offset
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        # Relative position index for every pair of tokens in the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        # On PyTorch >= 1.10, pass indexing="ij" to silence the meshgrid warning
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, N
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, N, N
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # N, N, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift offsets to start at 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # N, N
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape  # (num_windows * batch, tokens per window, channels)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)  # B_, nH, N, N

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1], -1)  # N, N, nH
        # Move the head dimension first so the bias broadcasts against attn
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, N, N
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # The mask blocks attention between tokens that were not neighbors before the shift
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x
```

Window partitioning (in shifted blocks it is applied after the feature map has been cyclically shifted):
```python
def window_partition(x, window_size):
    """Split (B, H, W, C) into non-overlapping windows of shape (num_windows*B, ws, ws, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: merge windows back into (B, H, W, C)."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
```

3. Model Training and Hyperparameter Tuning
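Starting from a pretrained Swin-Tiny model greatly improves both training efficiency and final accuracy. The pretrained weights are loaded through the timm library, and the model is then fine-tuned for the flower classification task: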
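```python
import timm

# num_classes=5 replaces the pretrained classifier with a new 5-way head
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True, num_classes=5)

# Freeze every layer except the classification head
for name, param in model.named_parameters():
    if "head" not in name:
        param.requires_grad = False
```

During training, pay particular attention to the following hyperparameters: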
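- Learning rate: start at 1e-4 and decay it with a cosine annealing schedule
- Optimizer: AdamW with a weight decay of 5e-2
- Batch size: 32 or 64, depending on GPU memory
- Epochs: 10-20 epochs are usually enough to converge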
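To make the learning-rate setting above concrete, here is a small sketch of the closed-form cosine annealing expression, using the values from the list (1e-4 down to a floor of 1e-6). The function name and the choice of `t_max=20` are illustrative assumptions.

```python
import math


def cosine_annealing_lr(epoch, base_lr=1e-4, eta_min=1e-6, t_max=20):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


for epoch in (0, 5, 10, 15, 20):
    print(f"epoch {epoch:2d}: lr = {cosine_annealing_lr(epoch):.2e}")
```

Past `t_max` the cosine term starts rising again, so with `t_max=10` and 20 epochs of training this expression returns to the base learning rate by epoch 20. Setting `t_max` to the total number of epochs gives a single smooth decay.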
The complete training loop:
```python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

num_epochs = 20
criterion = nn.CrossEntropyLoss()
# Only hand the optimizer parameters that still require gradients (the backbone is frozen)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=1e-4, weight_decay=5e-2)
# T_max matches the number of epochs so the learning rate decays once over the whole run
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

best_acc = 0.0
for epoch in range(num_epochs):
    # Training phase
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_loss = running_loss / len(train_loader)
    train_acc = 100. * correct / total

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    val_loss = val_loss / len(val_loader)
    val_acc = 100. * correct / total

    print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | '
          f'Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}%')

    # Keep the checkpoint with the best validation accuracy
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')

    scheduler.step()
```

With sensible hyperparameters, the model typically passes 90% validation accuracy within 5-10 epochs and levels off at around 95%.
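4. Model Evaluation and Deployment
Once training is finished, the model should be evaluated thoroughly. Beyond overall accuracy, look at the confusion matrix and the per-class precision and recall: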
```python
import torch
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

# ImageFolder assigns label indices in sorted folder order
class_names = val_dataset.classes

model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

# Per-class precision, recall and F1
print(classification_report(all_labels, all_preds, target_names=class_names))
```

For deployment, a simple prediction script can classify a single flower image:
```python
import torch
import torch.nn.functional as F
from PIL import Image


def predict_image(image_path, model, transform, class_names, device='cuda'):
    """Return (predicted class name, confidence) for a single image file."""
    img = Image.open(image_path).convert('RGB')
    img_t = transform(img).unsqueeze(0).to(device)  # add a batch dimension

    model.eval()
    with torch.no_grad():
        outputs = model(img_t)
        probs = F.softmax(outputs, dim=1)
        conf, preds = torch.max(probs, 1)
    return class_names[preds.item()], conf.item()


# Usage example
class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
transform = val_transform  # reuse the validation transform
image_path = 'test_flower.jpg'
# Pass the device the model actually lives on
pred, confidence = predict_image(image_path, model, transform, class_names, device=device)
print(f'Predicted: {pred} with confidence {confidence:.2f}')
```

For production deployment, converting the model to TorchScript or ONNX is recommended to improve inference efficiency:
```python
import torch

model.eval()  # trace and export in inference mode

# Convert to TorchScript
example_input = torch.rand(1, 3, 224, 224).to(device)
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("swin_flower.pt")

# Convert to ONNX
torch.onnx.export(
    model,
    example_input,
    "swin_flower.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=11
)
```

5. Performance Optimization and Troubleshooting
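Real projects run into a range of problems. Here are solutions to some common ones:
Problem 1: Accuracy fluctuates heavily early in training
Solutions:
- Use a smaller learning rate (for example 5e-5) for the initial fine-tuning
- Add a warmup phase that raises the learning rate gradually
- Try Layer-wise Learning Rate Decay (LLRD)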
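The warmup and LLRD suggestions above can be sketched as plain schedule functions. This is a minimal illustration, not the setup used in this guide: the function names, the 100 warmup steps, and the decay factor 0.75 are all assumptions.

```python
def warmup_lr(step, base_lr=5e-5, warmup_steps=100):
    """Linear warmup: ramp from base_lr / warmup_steps up to base_lr, then hold."""
    if step >= warmup_steps:
        return base_lr
    return base_lr * (step + 1) / warmup_steps


def layerwise_lrs(num_layers, base_lr=5e-5, decay=0.75):
    """LLRD: the last layer gets base_lr; each earlier layer is scaled down by `decay`."""
    return [base_lr * decay ** (num_layers - 1 - i) for i in range(num_layers)]


print([f"{warmup_lr(s):.2e}" for s in (0, 49, 99, 500)])
print([f"{lr:.2e}" for lr in layerwise_lrs(4)])
```

In a PyTorch setup, each value from `layerwise_lrs` would go into its own optimizer parameter group, so that layers close to the input (which hold the most general pretrained features) change more slowly than the classification head.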
Problem 2: The model overfits
Solutions:
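```python
from torch.optim import AdamW
from timm.data.auto_augment import rand_augment_transform

# Apply weight decay in the optimizer (0.05 matches the training setup; raise it if overfitting persists)
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)

# Add stronger data augmentation; RandAugment works on PIL images,
# so it is inserted at the front of the pipeline, before ToTensor
train_transform.transforms.insert(
    0, rand_augment_transform(config_str='rand-m9-mstd0.5', hparams={}))
```

Problem 3: Out of GPU memory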
Solutions:
- Enable gradient checkpointing
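```python
import timm

model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True, num_classes=5)
# Recompute activations during the backward pass instead of storing them.
# Assumes a timm release that provides set_grad_checkpointing; older releases
# may expose a use_checkpoint constructor argument instead.
model.set_grad_checkpointing(True)
```

Note that the `checkpoint_path` argument of `timm.create_model` loads weights from a checkpoint file and has nothing to do with gradient checkpointing.

- Use mixed-precision training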
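```python
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)

    optimizer.zero_grad()
    # Run the forward pass in mixed precision
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # Scale the loss so small gradients do not underflow in float16
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales the gradients, then steps the optimizer
    scaler.update()         # adjusts the scale factor for the next iteration
```

Performance optimization comparison: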
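| Optimization | Speed change | Memory reduction | Accuracy impact |
|---|---|---|---|
| Mixed-precision training | 1.5-2x faster | 20-30% | Essentially none |
| Gradient checkpointing | 10-20% slower | 30-50% | Essentially none (same computation, recomputed) |
| Smaller batch size | None | Significant | May decrease |
| Model quantization | 2-4x faster (inference) | 50-75% | Slight decrease |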
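The mixed-precision row depends on loss scaling, which is what `GradScaler` handles in the training snippet. As a rough standard-library illustration of why it is needed, the sketch below rounds a small value to float16 with and without scaling. The gradient value 1e-8 is made up, and the factor 2^16 is assumed here as the scaler's initial scale.

```python
import struct


def to_fp16(x):
    """Round a Python float to IEEE half precision and back."""
    return struct.unpack("e", struct.pack("e", x))[0]


grad = 1e-8         # a small gradient value, chosen for illustration
scale = 2.0 ** 16   # assumed loss-scale factor

unscaled = to_fp16(grad)                    # too small for float16: rounds to zero
recovered = to_fp16(grad * scale) / scale   # scaled value is representable, then unscaled

print(unscaled, recovered)
```

Without scaling the value is lost entirely, while the scaled copy survives the float16 round trip and comes back close to the original once divided by the scale factor.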
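With these techniques, training and inference efficiency can be improved substantially while preserving model accuracy, which makes Swin Transformer more practical for real-world use.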