用PyTorch实战U2-Net:从零构建显著性目标检测模型
在计算机视觉领域,显著性目标检测(Salient Object Detection)一直是个有趣且实用的研究方向。想象一下,当你看到一张照片时,视线会不自觉地被某些区域吸引——这就是显著性检测要解决的问题。传统的U-Net架构虽然表现不错,但2020年提出的U2-Net通过其独特的嵌套U型结构,在保持轻量化的同时显著提升了检测精度。本文将带你用PyTorch从零开始实现这个强大的模型。
1. 环境准备与数据加载
首先确保你的开发环境已经安装了PyTorch 1.7+和Torchvision。推荐使用Python 3.8+环境:
conda create -n u2net python=3.8
conda activate u2net
pip install torch torchvision opencv-python pillow matplotlib

我们将使用DUTS数据集,这是显著性检测领域最常用的基准数据集之一。它包含:
- 训练集:10,553张图像及对应的掩码
- 测试集:5,019张图像及掩码
from torch.utils.data import Dataset
import cv2
import os


class DUTSDataset(Dataset):
    """Image/mask pairs read from the DUTS training split.

    Images are read from ``<root_dir>/DUTS-TR-Image`` and masks from
    ``<root_dir>/DUTS-TR-Mask``. The mask of an image is the file with the
    same stem and a ``.png`` extension.

    Args:
        root_dir: Directory that contains the two ``DUTS-TR-*`` folders.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)``; it must return a mapping
            with ``'image'`` and ``'mask'`` keys so that both are transformed
            together.
    """

    # File extensions treated as images when scanning the image folder.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, root_dir, transform=None):
        self.image_dir = os.path.join(root_dir, 'DUTS-TR-Image')
        self.mask_dir = os.path.join(root_dir, 'DUTS-TR-Mask')
        self.transform = transform
        # os.listdir() returns entries in arbitrary order and also lists
        # stray non-image files; filter and sort so that an index always
        # maps to the same sample.
        self.image_list = sorted(
            name for name in os.listdir(self.image_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of images found in the image folder."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair for sample ``idx``.

        The image is loaded and converted from BGR to RGB; the mask is loaded
        as a single-channel grayscale image. If a transform was given, its
        ``'image'`` and ``'mask'`` results are returned instead.

        Raises:
            FileNotFoundError: If the image or its mask cannot be read.
        """
        image_name = self.image_list[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # Swap only the real extension; str.replace() would also rewrite a
        # '.jpg' occurring in the middle of the name and ignores '.jpeg'.
        mask_name = os.path.splitext(image_name)[0] + '.png'
        mask_path = os.path.join(self.mask_dir, mask_name)

        # cv2.imread() signals failure by returning None instead of raising.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask: {mask_path}')

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask


# --- Article text following this listing (translated) ---
# Note: during preprocessing the image and its mask must go through the same
# transforms. The Albumentations library is recommended for efficient data
# augmentation.
2. U2-Net核心架构解析
U2-Net的精妙之处在于其"U中的U"结构——每个编码器-解码器块本身又是一个小型的U-Net。这种设计使得每个阶段内部都能提取不同尺度的特征。
2.1 RSU模块实现
RSU(Residual U-block)是U2-Net的基本构建块。我们先实现RSU-7(7层结构):
import torch
import torch.nn as nn


class RSU7(nn.Module):
    """Residual U-block with five pooling stages (the article's RSU-7).

    An input convolution maps ``in_ch`` to ``out_ch`` channels. A small
    encoder/decoder with ``mid_ch`` channels then processes that feature map,
    and its result is added back to the input-convolution output (residual
    connection).

    The output has ``out_ch`` channels and the same spatial size as the
    input. Because the encoder pools five times with floor rounding, each
    spatial side of the input must be at least 32 pixels.

    NOTE(review): despite the ``rebnconv`` names, every layer here is a plain
    ``nn.Conv2d`` with no batch norm or ReLU. Confirm whether that is
    intended.

    Args:
        in_ch: Number of input channels.
        mid_ch: Number of channels inside the encoder/decoder.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()
        self.rebnconvin = nn.Conv2d(in_ch, out_ch, 3, padding=1)

        # Encoder
        self.rebnconv1 = nn.Conv2d(out_ch, mid_ch, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.rebnconv2 = nn.Conv2d(mid_ch, mid_ch, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.rebnconv3 = nn.Conv2d(mid_ch, mid_ch, 3, padding=1)
        self.pool3 = nn.MaxPool2d(2, stride=2)
        self.rebnconv4 = nn.Conv2d(mid_ch, mid_ch, 3, padding=1)
        self.pool4 = nn.MaxPool2d(2, stride=2)
        self.rebnconv5 = nn.Conv2d(mid_ch, mid_ch, 3, padding=1)
        self.pool5 = nn.MaxPool2d(2, stride=2)

        # Bottleneck
        self.rebnconv6 = nn.Conv2d(mid_ch, mid_ch, 3, padding=1)

        # Decoder: each conv sees the upsampled deeper feature concatenated
        # with the encoder feature of the same resolution (2 * mid_ch).
        self.rebnconv7 = nn.Conv2d(mid_ch * 2, mid_ch, 3, padding=1)
        self.rebnconv8 = nn.Conv2d(mid_ch * 2, mid_ch, 3, padding=1)
        self.rebnconv9 = nn.Conv2d(mid_ch * 2, mid_ch, 3, padding=1)
        self.rebnconv10 = nn.Conv2d(mid_ch * 2, mid_ch, 3, padding=1)
        self.rebnconv11 = nn.Conv2d(mid_ch * 2, mid_ch, 3, padding=1)

        # The last conv sees the decoder output (mid_ch) concatenated with
        # the input-conv output (out_ch), so it takes mid_ch + out_ch
        # channels rather than 2 * mid_ch.
        self.rebnconvout = nn.Conv2d(mid_ch + out_ch, out_ch, 3, padding=1)

    @staticmethod
    def _upsample_like(src, ref):
        """Bilinearly resize ``src`` to the spatial size of ``ref``."""
        # Resizing to an explicit size (instead of a fixed x2 factor) keeps
        # the shapes aligned when pooling rounded an odd size down.
        return nn.functional.interpolate(
            src, size=ref.shape[2:], mode='bilinear', align_corners=False)

    def forward(self, x):
        """Return a tensor with ``out_ch`` channels at the input resolution."""
        hxin = self.rebnconvin(x)

        # Encoder: every stage runs at half the resolution of the previous.
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))
        hx5 = self.rebnconv5(self.pool4(hx4))

        # Bottleneck
        hx6 = self.rebnconv6(self.pool5(hx5))

        # Decoder: upsample to the skip feature's size *before* the concat,
        # otherwise the two tensors differ in spatial size.
        hx7 = self.rebnconv7(torch.cat((self._upsample_like(hx6, hx5), hx5), 1))
        hx8 = self.rebnconv8(torch.cat((self._upsample_like(hx7, hx4), hx4), 1))
        hx9 = self.rebnconv9(torch.cat((self._upsample_like(hx8, hx3), hx3), 1))
        hx10 = self.rebnconv10(torch.cat((self._upsample_like(hx9, hx2), hx2), 1))
        hx11 = self.rebnconv11(torch.cat((self._upsample_like(hx10, hx1), hx1), 1))

        # hx11 is already at the input resolution; fuse it with hxin and add
        # the residual.
        return self.rebnconvout(torch.cat((hx11, hxin), 1)) + hxin


# --- Article text following this listing (translated) ---
# 2.2 The complete U2-Net architecture
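现在我们可以组装完整的U2-Net模型。注意:下面的代码除了RSU7之外还用到了RSU6、RSU5、RSU4和RSU4F,它们的结构与RSU7类似(下采样层数依次减少;RSU4F不做下采样,改用空洞卷积),需要先按同样的方式定义好才能运行: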
class U2NET(nn.Module):
    """U2-Net: a six-stage encoder/decoder whose stages are RSU blocks.

    Besides ``RSU7`` this class uses ``RSU6``, ``RSU5``, ``RSU4`` and
    ``RSU4F``. They are not defined in this listing and must already be in
    scope when the class is instantiated.

    Args:
        in_ch: Number of input image channels.
        out_ch: Number of channels of every predicted map.

    ``forward`` returns a list of seven sigmoid-activated maps at the input
    resolution: the six side outputs (shallowest decoder stage first)
    followed by the fused map as the last element.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        # Encoder
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2)
        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2)
        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2)
        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2)
        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2)
        self.stage6 = RSU4F(512, 256, 512)

        # Decoder: each stage takes the upsampled deeper feature
        # concatenated with the encoder feature of the same level.
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side-output convolutions (one prediction per decoder level)
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # Fusion layer: the six side outputs carry out_ch channels each, so
        # the concatenation has 6 * out_ch channels (6 only when out_ch == 1).
        self.fuse = nn.Conv2d(6 * out_ch, out_ch, 1)

    @staticmethod
    def _resize(src, size):
        """Bilinearly resize ``src`` to the spatial ``size``."""
        return nn.functional.interpolate(
            src, size=size, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Return ``[side1, ..., side6, fused]``, each passed through sigmoid."""
        # Encoder
        hx1 = self.stage1(x)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        # Decoder: resize the deeper feature to the skip's size, then concat.
        hx5d = self.stage5d(torch.cat((self._resize(hx6, hx5.shape[2:]), hx5), 1))
        hx4d = self.stage4d(torch.cat((self._resize(hx5d, hx4.shape[2:]), hx4), 1))
        hx3d = self.stage3d(torch.cat((self._resize(hx4d, hx3.shape[2:]), hx3), 1))
        hx2d = self.stage2d(torch.cat((self._resize(hx3d, hx2.shape[2:]), hx2), 1))
        hx1d = self.stage1d(torch.cat((self._resize(hx2d, hx1.shape[2:]), hx1), 1))

        # Side outputs. The first one is used as produced by stage1d; the
        # other five are resized to the input size.
        full_size = x.shape[2:]
        outputs = [self.side1(hx1d)]
        deeper = (
            (self.side2, hx2d),
            (self.side3, hx3d),
            (self.side4, hx4d),
            (self.side5, hx5d),
            (self.side6, hx6),
        )
        for side_conv, feature in deeper:
            outputs.append(self._resize(side_conv(feature), full_size))

        # Fused output, appended last
        outputs.append(self.fuse(torch.cat(outputs, 1)))

        return [torch.sigmoid(o) for o in outputs]


# --- Article text following this listing (translated) ---
# 3. Training strategy and loss function
U2-Net使用深度监督(deep supervision)策略:六个侧输出和最终的融合输出都用同一张真值掩码进行监督。总损失是这七个输出各自的BCE损失之和:
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels):
    """Sum the binary cross-entropy of seven predictions against one target.

    Args:
        d0, d1, d2, d3, d4, d5, d6: Predicted probability maps.
        labels: Target map. If it has exactly one dimension fewer than
            ``d0``, a channel dimension is inserted at position 1. It is cast
            to the dtype of ``d0`` before the loss is computed.

    Returns:
        A tuple ``(loss0, loss)``: the loss of ``d0`` alone and the sum of
        all seven losses.
    """
    # `reduction='mean'` is the current spelling of the deprecated
    # `size_average=True`.
    bce_loss = nn.BCELoss(reduction='mean')

    # BCELoss requires input and target to have identical shapes, so add the
    # channel dimension when the target comes without one.
    if labels.dim() == d0.dim() - 1:
        labels = labels.unsqueeze(1)
    labels = labels.to(dtype=d0.dtype)

    losses = [bce_loss(d, labels) for d in (d0, d1, d2, d3, d4, d5, d6)]
    loss0 = losses[0]
    loss = loss0
    for side_loss in losses[1:]:
        loss = loss + side_loss

    return loss0, loss


# --- Article text following this listing (translated) ---
# Key points to keep in mind during training:
- 使用Adam优化器,初始学习率设为1e-3
- 采用学习率衰减策略,验证损失不再下降时降低学习率
- 批量大小根据GPU内存设置,通常8-16为宜
- 训练约100个epoch可以达到较好效果
from torch.optim import lr_scheduler


def train_model(model, dataloaders, criterion, optimizer, num_epochs=100,
                checkpoint_path='best_model.pth'):
    """Train ``model`` and checkpoint the weights with the best validation loss.

    Each epoch runs a ``'train'`` phase and a ``'val'`` phase. After the
    validation phase the weights are saved if the loss improved, and a
    ``ReduceLROnPlateau`` scheduler is stepped with the validation loss.

    Args:
        model: Network whose forward pass returns a sequence of outputs.
        dataloaders: Mapping with ``'train'`` and ``'val'`` loaders that
            yield ``(inputs, masks)`` batches.
        criterion: Called as ``criterion(*outputs, masks)``; must return
            ``(loss0, loss)``, where ``loss`` is back-propagated.
        optimizer: Optimizer that updates the model parameters.
        num_epochs: Number of epochs to run.
        checkpoint_path: File the best ``state_dict`` is written to.

    Returns:
        The same ``model`` object, holding the weights of the last epoch.
    """
    # Run on whatever device the model already lives on instead of relying
    # on a module-level `device` variable.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    best_loss = float('inf')
    # The scheduler's `verbose` flag is deprecated in recent PyTorch
    # releases, so learning-rate changes are reported manually below.
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=5)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_loss0 = 0.0

            for inputs, masks in dataloaders[phase]:
                inputs = inputs.to(device)
                masks = masks.to(device).float()

                optimizer.zero_grad()

                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss0, loss = criterion(*outputs, masks)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Weight by batch size so the epoch value is a per-sample mean.
                running_loss += loss.item() * inputs.size(0)
                running_loss0 += loss0.item() * inputs.size(0)

            num_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / num_samples
            epoch_loss0 = running_loss0 / num_samples

            print(f'{phase} Loss: {epoch_loss:.4f} | Side1 Loss: {epoch_loss0:.4f}')

            if phase == 'val':
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    torch.save(model.state_dict(), checkpoint_path)

                previous_lrs = [group['lr'] for group in optimizer.param_groups]
                scheduler.step(epoch_loss)
                for index, group in enumerate(optimizer.param_groups):
                    if group['lr'] != previous_lrs[index]:
                        print(f'Reducing learning rate of group {index} '
                              f'to {group["lr"]:.4e}.')

    return model


# --- Article text following this listing (translated) ---
# 4. Model evaluation and inference optimization
训练完成后,我们需要评估模型性能并优化推理过程:
4.1 评估指标实现
显著性检测常用三个指标:最大F-measure、平均绝对误差(MAE)和S-measure。下面代码中的 compute_precision_recall、compute_f_measure 和 compute_s_measure 是需要自行实现的辅助函数,本文未给出其实现。
import numpy as np


def evaluate_model(model, dataloader):
    """Average F-measure, MAE and S-measure of ``model`` over ``dataloader``.

    The last element of the model output is used as the prediction.
    Prediction and mask are scaled by 255 and cast to ``uint8`` before the
    metrics are computed.

    ``compute_precision_recall``, ``compute_f_measure`` and
    ``compute_s_measure`` are not defined in this listing and must be in
    scope.

    Args:
        model: Network whose forward pass returns a sequence of maps.
        dataloader: Yields ``(inputs, masks)`` batches.
            NOTE(review): the indexing below assumes masks shaped
            (batch, 1, H, W) with values in [0, 1] - confirm against the
            dataset/transform in use.

    Returns:
        Dict with the per-image means under ``'F-measure'``, ``'MAE'`` and
        ``'S-measure'``.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()
    # Run on the device the model lives on instead of relying on a
    # module-level `device` variable.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_f = 0.0
    total_mae = 0.0
    total_s = 0.0
    count = 0

    with torch.no_grad():
        for inputs, masks in dataloader:
            outputs = model(inputs.to(device))[-1].cpu().numpy()
            # The masks are only needed as NumPy arrays, so they never have
            # to be moved to the model's device.
            masks = masks.cpu().numpy()

            for pred_map, gt_map in zip(outputs[:, 0], masks[:, 0]):
                pred = (pred_map * 255).astype('uint8')
                gt = (gt_map * 255).astype('uint8')

                # F-measure
                prec, recall = compute_precision_recall(pred, gt)
                f_measure = compute_f_measure(prec, recall)

                # MAE
                mae = np.mean(np.abs(pred / 255. - gt / 255.))

                # S-measure
                s_score = compute_s_measure(pred, gt)

                total_f += f_measure
                total_mae += mae
                total_s += s_score
                count += 1

    if count == 0:
        raise ValueError('dataloader yielded no samples to evaluate')

    return {
        'F-measure': total_f / count,
        'MAE': total_mae / count,
        'S-measure': total_s / count
    }


# --- Article text following this listing (translated) ---
# 4.2 Inference optimization tips
实际部署时,可以采用以下优化策略:
- 模型剪枝:移除对最终输出贡献较小的通道
- 量化:将FP32模型转换为INT8,减少模型大小并加速推理
- ONNX导出:转换为ONNX格式以便跨平台部署
# Model quantization example.
# NOTE(review): according to the PyTorch quantization docs, dynamic
# quantization covers layers such as nn.Linear and the recurrent layers,
# while nn.Conv2d is supported by static quantization only. For this
# all-convolutional model the call below is therefore expected to leave the
# convolutions in FP32; a static (calibrated) workflow would be needed for
# real INT8 inference - confirm before relying on it.
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Conv2d}, dtype=torch.qint8)

# ONNX export example.
# Export in eval mode so layers that behave differently during training are
# traced in their inference behavior.
model.eval()

# The model returns seven maps (six side outputs plus the fused map, fused
# last). Name every one of them and mark the batch axis of each as dynamic;
# with a single name only the first output would be covered.
onnx_output_names = ['side1', 'side2', 'side3', 'side4', 'side5', 'side6', 'fused']
onnx_dynamic_axes = {'input': {0: 'batch'}}
for output_name in onnx_output_names:
    onnx_dynamic_axes[output_name] = {0: 'batch'}

# Create the dummy input on the same device as the model's parameters.
dummy_input = torch.randn(1, 3, 320, 320, device=next(model.parameters()).device)
torch.onnx.export(model, dummy_input, "u2net.onnx",
                  input_names=['input'],
                  output_names=onnx_output_names,
                  dynamic_axes=onnx_dynamic_axes)

# --- Article text following this listing (translated) ---
# 4.3 A practical application example
下面是一个完整的图像显著性检测流程:
import numpy as np


def detect_saliency(image_path, model):
    """Predict a saliency map for an image file and outline the salient regions.

    The image is resized to 320x320 and scaled to [0, 1] before inference.
    The last element of the model output is used as the prediction, resized
    back to the original image size and binarized with Otsu's threshold to
    find the external contours. The model is used in whatever train/eval
    mode the caller left it in.

    Args:
        image_path: Path of the image to analyse.
        model: Network whose forward pass returns a sequence of maps.

    Returns:
        A tuple ``(result, saliency_map)``: a BGR copy of the original image
        with the contours drawn in green, and the ``uint8`` saliency map at
        the original resolution.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    # Preprocessing. cv2.imread() returns None instead of raising on failure.
    original_bgr = cv2.imread(image_path)
    if original_bgr is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')
    orig_h, orig_w = original_bgr.shape[:2]

    # Resize and normalize. The network input gets its own name so the
    # original image stays available for drawing the result.
    rgb = cv2.cvtColor(original_bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (320, 320)).astype(np.float32) / 255.
    net_input = torch.from_numpy(resized).permute(2, 0, 1).unsqueeze(0)

    # Inference on the device the model lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    with torch.no_grad():
        pred = model(net_input.to(device))[-1].cpu().numpy()

    # Post-processing: back to an 8-bit map at the original resolution.
    saliency_map = (pred[0, 0] * 255).astype('uint8')
    saliency_map = cv2.resize(saliency_map, (orig_w, orig_h))

    # Binarize with Otsu's threshold.
    _, binary_map = cv2.threshold(
        saliency_map, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Extract and draw the salient regions.
    contours, _ = cv2.findContours(
        binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Draw on a copy of the image as loaded (already BGR and full size).
    result = original_bgr.copy()
    cv2.drawContours(result, contours, -1, (0, 255, 0), 2)

    return result, saliency_map


# --- Article text following this listing (translated) ---
# In real projects U2-Net often delivers pleasantly surprising results. On one
# set of complex street-scene images the model accurately picked out key
# elements such as pedestrians and vehicles, while ignoring secondary content
# such as background buildings and the sky. This precise saliency detection
# makes it shine in applications such as image editing, visual-attention
# analysis and object tracking.