【DriveGen 文件详解】04——evaluate.py
2026/6/4 15:33:25 网站建设 项目流程
DriveGen/ ├── configs/ │ └── default.yaml # 训练配置文件 ├── DriveGen/ │ ├── __init__.py # 包初始化 │ ├── models/ │ │ ├── __init__.py │ │ ├── embedding.py # Patch嵌入、时间步编码、位置编码 │ │ ├── attention.py # 空间注意力、时间注意力 │ │ ├── dit_block.py # AdaLN-Zero DiT Block │ │ └── stdit.py # STDiT 完整模型 │ ├── data/ │ │ ├── __init__.py │ │ └── dataset.py # 合成数据集 + nuScenes 适配器 │ ├── schedules/ │ │ ├── __init__.py │ │ └── noise_schedule.py # 线性/余弦噪声调度 │ └── utils/ │ ├── __init__.py │ ├── visualization.py # 视频保存、对比图、损失曲线 │ └── logger.py # 日志工具 ├── train.py # 训练脚本 ├── inference.py # 推理脚本(DDPM 采样 + CFG) ├── evaluate.py # 评估脚本(FID 计算) ├── requirements.txt # 依赖清单 ├── setup.py # 安装配置 └── README.md # 本文件

LQY-hh/DriveGen-Transformer-: 自动驾驶技术的发展离不开海量数据的支撑,但稀有场景(如极端天气、突发事故)的数据采集成本极高。**DriveGen** 旨在通过扩散模型生成高质量的驾驶场景视频,为自动驾驶算法提供无限的虚拟训练数据。 ### 核心价值https://github.com/LQY-hh/DriveGen-Transformer-

DriveGen 评估脚本说明文档

概述

evaluate.py是 DriveGen 项目的视频生成质量评估脚本,主要用于计算生成视频与真实视频之间的FID(Frechet Inception Distance)分数,以量化评估生成模型的性能。


核心功能

功能模块

说明

FID 计算

衡量生成分布与真实分布的距离,值越小表示生成质量越好

特征提取

使用简化的 CNN 网络从视频帧中提取特征向量

指标评估

自动生成评估报告,保存结果到文件


评估原理

FID(Frechet Inception Distance)

FID 是衡量生成模型质量的标准指标,其核心思想是:

  1. 特征提取:从真实图像和生成图像中提取高维特征

  2. 分布建模:假设特征服从多元高斯分布,计算均值和协方差

  3. 距离计算:计算两个高斯分布之间的 Frechet 距离

数学公式

d² = ||μ₁ - μ₂||² + Tr(σ₁ + σ₂ - 2√(σ₁σ₂))

其中:

  • μ₁, μ₂:两个分布的均值向量

  • σ₁, σ₂:两个分布的协方差矩阵

  • Tr:矩阵的迹


代码结构

1. 特征提取器

SimpleFeatureExtractor(evaluate.py#L126-L210):

一个轻量级 CNN 网络,用于从图像帧中提取 256 维特征向量。

# 网络结构 Conv2d(3, 32) → ReLU → MaxPool → Conv2d(32, 64) → ReLU → MaxPool → Conv2d(64, 128) → AdaptiveAvgPool → Linear(256)

设计说明

  • 输入:(B, 3, H, W),值域 [0, 1]

  • 输出:(B, 256)

  • 使用自适应平均池化处理不同尺寸输入

2. FID 计算函数

compute_frechet_distance(evaluate.py#L247-L320):

计算两个多元高斯分布之间的 Frechet 距离。

关键处理

  • 使用scipy.linalg.sqrtm计算矩阵平方根

  • 处理数值不稳定情况(复数结果)

  • 确保结果非负

compute_fid(evaluate.py#L346-L378):

封装完整的 FID 计算流程:

  1. 计算真实特征的均值和协方差

  2. 计算生成特征的均值和协方差

  3. 调用compute_frechet_distance计算距离

3. 特征提取流程

extract_features_from_dataset(evaluate.py#L381-L428):

从数据集中提取真实视频帧的特征。

extract_features_from_generated(evaluate.py#L431-L503):

使用模型生成视频并提取特征。

4. 主函数

main(evaluate.py#L506-L648):

执行完整的评估流程:

加载配置 → 创建组件 → 提取真实特征 → 生成并提取特征 → 计算 FID → 保存结果

使用方法

命令行参数

参数

简写

类型

默认值

说明

--checkpoint

-c

str

必须

模型检查点路径

--config

-

str

configs/default.yaml

配置文件路径

--num_samples

-n

int

配置文件值

评估样本数

--output_dir

-o

str

eval_results/

输出目录

--device

-

str

自动检测

计算设备

--seed

-

int

42

随机种子

--batch_size

-

int

8

特征提取批量大小

使用示例

# 基本用法 python evaluate.py --checkpoint checkpoints/best.pth # 指定样本数 python evaluate.py --checkpoint checkpoints/best.pth --num_samples 200 # 指定输出目录和设备 python evaluate.py --checkpoint checkpoints/best.pth --output_dir eval_results/ --device cuda

输出结果

评估完成后,会在输出目录生成evaluation_results.txt文件:

DriveGen 评估结果 ======================================== FID 分数: 12.3456 评估样本数: 100 真实帧数: 400 生成帧数: 400 特征维度: 256 检查点: checkpoints/best.pth 说明: FID (Frechet Inception Distance) 衡量生成分布与真实分布的距离。 值越小越好,0 表示完美匹配。 注意: 此评估使用简化的特征提取器,结果仅供参考。 实际应用中建议使用 InceptionV3 获取更准确的 FID。

关键设计特点

1. 简化特征提取

使用自定义 CNN 而非预训练的 InceptionV3,便于学习和部署,同时保持 FID 计算流程的完整性。

2. 数值稳定性

compute_frechet_distance中:

  • 处理协方差矩阵平方根可能出现的复数问题

  • 使用 SVD 分解作为备选方法

  • 确保最终结果非负

3. 批处理优化

特征提取采用批处理方式,支持大样本评估,提高计算效率。

4. 结果可追溯

自动保存评估参数和结果,便于实验复现和对比分析。


扩展建议

使用 InceptionV3(推荐)

为获得更准确的 FID 分数,建议使用预训练的 InceptionV3:

from torchvision.models import inception_v3 class InceptionFeatureExtractor(nn.Module): def __init__(self): super().__init__() self.model = inception_v3(pretrained=True, transform_input=False) self.model.fc = nn.Identity() # 移除分类层 def forward(self, x): return self.model(x)

增加更多评估指标

可以扩展支持以下指标:

  • IS(Inception Score):衡量生成样本的多样性和质量

  • LPIPS:感知相似度指标

  • SSIM/PSNR:像素级相似度指标


注意事项

  1. 特征提取器差异:本脚本使用简化的 CNN,与标准 InceptionV3 的 FID 结果不可直接比较

  2. 样本数量:FID 计算需要足够的样本数才能稳定,建议至少 100 个样本

  3. 计算资源:大规模评估可能需要较长时间和较多显存

  4. 结果解读:FID 只是评估指标之一,还需结合主观视觉评估

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询