实验管理与模型版本控制:从"炼丹笔记"到可复现的工程体系
一、AI 实验的"混沌状态":哪个模型效果最好?没人记得清
AI 工程师的日常:调了一晚上超参数,模型准确率从 92% 涨到了 93.5%。第二天想复现这个结果,却发现忘了记录学习率是多少、用了哪个数据集版本、随机种子是多少。更常见的情况是:团队中有 5 个人在各自跑实验,每个人用自己的命名规则(model_v2_final_really_final.pt),没有人知道哪个模型是线上在用的。
实验管理的核心问题是"可复现性"——给定相同的代码、数据和配置,能否得到相同的结果?传统软件工程通过 Git 解决了代码的可复现性,但 AI 实验还涉及数据版本、模型权重、超参数和随机种子,这些都不是 Git 能单独管理的。
二、实验管理的架构与数据流
实验管理系统需要追踪四个维度的信息:代码版本(Git commit)、数据版本(DVC/数据哈希)、配置参数(超参数 YAML)和产出物(模型权重、评测指标)。每次实验是一个完整的快照,包含这四个维度的状态。
```mermaid
flowchart TD
    A[实验启动] --> B[记录代码版本<br/>Git commit SHA]
    A --> C[记录数据版本<br/>DVC / 数据哈希]
    A --> D[记录配置参数<br/>超参数 YAML]
    A --> E[分配实验 ID<br/>exp_20260609_143012_a1b2c3]
    B --> F[实验运行]
    C --> F
    D --> F
    F --> G[记录训练指标<br/>Loss / Accuracy 曲线]
    F --> H[保存模型权重<br/>checkpoint / best_model]
    F --> I[记录评测结果<br/>Benchmark 分数]
    G --> J[实验元数据存储<br/>MLflow / 自建数据库]
    H --> J
    I --> J
    J --> K[实验对比<br/>A/B 指标对比]
    J --> L[模型注册<br/>Staging → Production]
    J --> M[复现验证<br/>重新运行实验]
    subgraph versioning["版本控制层次"]
        N[代码版本 → Git]
        O[数据版本 → DVC]
        P[配置版本 → YAML + Git]
        Q[模型版本 → Registry]
    end
```
关键设计原则:
- 自动记录:实验参数和指标自动记录,不依赖人工
- 不可篡改:实验记录一旦创建不可修改,只能追加
- 可查询:支持按参数、指标、时间范围查询实验
- 可复现:从实验记录可以还原完整的运行环境
三、实验管理系统的实现
# experiment_manager.py — AI experiment management system
# Design intent: automatically track each experiment's code version, data
# version, configuration parameters and artifacts, and provide experiment
# comparison, model registration and reproduction support.
#
# On-disk layout written by this module:
#   <storage_dir>/<experiment_id>/config.json       snapshot taken at creation
#   <storage_dir>/<experiment_id>/experiment.json   full, current record
#   <storage_dir>/<experiment_id>/metrics.json      periodic metric checkpoints
#   <registry_dir>/<model_name>/v<N>/register.json  one record per version
#   <registry_dir>/<model_name>/latest.json         mirror of the newest version
import json
import hashlib
import subprocess
import os
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Optional, Any, Tuple


def _write_json(path: Path, payload: Any, ensure_ascii: bool = True) -> None:
    """Atomically write *payload* to *path* as indented UTF-8 JSON.

    The data is written to a sibling ``.tmp`` file and then moved into place
    with ``os.replace``, so a crash mid-write cannot leave a truncated record
    that would later break ``json.load``. The explicit UTF-8 encoding matters
    because some callers pass ``ensure_ascii=False`` (non-ASCII experiment
    names), which fails on platforms whose locale encoding is not UTF-8.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=ensure_ascii)
    os.replace(tmp_path, path)


@dataclass
class ExperimentConfig:
    """Configuration of a single experiment run."""

    experiment_name: str
    model_architecture: str
    hyperparameters: Dict[str, Any]
    dataset_name: str
    dataset_version: str = ""
    seed: int = 42
    tags: List[str] = field(default_factory=list)


@dataclass
class ExperimentMetrics:
    """Metric histories plus final/custom scalar metrics of one experiment."""

    train_loss: List[float] = field(default_factory=list)
    val_loss: List[float] = field(default_factory=list)
    val_accuracy: List[float] = field(default_factory=list)
    # Filled from the last history entry by ExperimentTracker.complete_experiment.
    final_train_loss: float = 0.0
    final_val_loss: float = 0.0
    final_val_accuracy: float = 0.0
    # Latest value per custom metric name (see ExperimentTracker.log_metrics).
    custom_metrics: Dict[str, float] = field(default_factory=dict)


@dataclass
class Experiment:
    """A complete experiment record."""

    experiment_id: str
    config: ExperimentConfig
    code_version: str = ""  # Git commit SHA (first 12 chars) or "unknown"
    data_hash: str = ""  # dataset fingerprint (see _compute_data_hash)
    status: str = "running"  # running / completed / failed
    created_at: str = ""
    completed_at: str = ""
    metrics: Optional[ExperimentMetrics] = None
    artifact_paths: Dict[str, str] = field(default_factory=dict)
    notes: str = ""


class ExperimentTracker:
    """Track experiments as JSON records under ``storage_dir/<experiment_id>/``."""

    def __init__(
        self,
        storage_dir: str = "./experiments",
        hash_data_contents: bool = False,
    ):
        """Create the tracker, creating *storage_dir* if needed.

        Args:
            storage_dir: Root directory holding one sub-directory per experiment.
            hash_data_contents: When True and ``dataset_name`` names an
                existing file or directory, ``data_hash`` is computed from the
                data itself rather than from the name. Off by default because
                it reads every byte of the dataset when an experiment starts.
        """
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)
        self.hash_data_contents = hash_data_contents
        self.current_experiment: Optional[Experiment] = None

    def create_experiment(self, config: ExperimentConfig) -> Experiment:
        """Create, persist and activate a new experiment.

        Writes ``config.json`` (creation-time snapshot) and ``experiment.json``
        right away. This makes a running or crashed experiment loadable, not
        only a completed one.
        """
        exp_dir = self._reserve_experiment_dir(config.experiment_name)
        experiment = Experiment(
            experiment_id=exp_dir.name,
            config=config,
            code_version=self._get_git_commit(),
            data_hash=self._compute_data_hash(config.dataset_name),
            created_at=datetime.now().isoformat(),
            metrics=ExperimentMetrics(),
        )
        _write_json(exp_dir / "config.json", asdict(experiment), ensure_ascii=False)
        self.current_experiment = experiment
        self._save_experiment()
        return experiment

    def log_metrics(
        self,
        step: int,
        train_loss: Optional[float] = None,
        val_loss: Optional[float] = None,
        val_accuracy: Optional[float] = None,
        custom: Optional[Dict[str, float]] = None,
    ):
        """Record metrics for one training step of the active experiment.

        Args:
            step: Training step. Metrics are flushed to ``metrics.json`` when
                ``step`` is a multiple of 100.
            train_loss: Appended to the history when not None. A real value of
                ``0.0`` is recorded; the previous truthiness check dropped it.
                The same rule applies to ``val_loss`` and ``val_accuracy``.
            val_loss: See ``train_loss``.
            val_accuracy: See ``train_loss``.
            custom: Extra named metrics. The latest value for each name is
                stored in ``custom_metrics``; the previous implementation
                ignored this argument entirely.

        Raises:
            RuntimeError: If there is no active experiment.
        """
        if not self.current_experiment:
            raise RuntimeError("No active experiment")
        metrics = self.current_experiment.metrics
        if train_loss is not None:
            metrics.train_loss.append(train_loss)
        if val_loss is not None:
            metrics.val_loss.append(val_loss)
        if val_accuracy is not None:
            metrics.val_accuracy.append(val_accuracy)
        if custom:
            metrics.custom_metrics.update(custom)
        # Periodic checkpoint of the metrics so a crash loses at most 100 steps.
        if step % 100 == 0:
            self._save_metrics()

    def log_artifact(self, name: str, path: str):
        """Record the path of an artifact (model weights, eval results, ...).

        Raises:
            RuntimeError: If there is no active experiment.
        """
        if not self.current_experiment:
            raise RuntimeError("No active experiment")
        self.current_experiment.artifact_paths[name] = path

    def complete_experiment(self, status: str = "completed"):
        """Finish the active experiment with *status* and persist it.

        The ``final_*`` metrics are set to the last entry of each history.
        Both ``metrics.json`` and ``experiment.json`` are rewritten, so neither
        file is left stale.

        Raises:
            RuntimeError: If there is no active experiment.
        """
        if not self.current_experiment:
            raise RuntimeError("No active experiment")
        self.current_experiment.status = status
        self.current_experiment.completed_at = datetime.now().isoformat()
        metrics = self.current_experiment.metrics
        if metrics.train_loss:
            metrics.final_train_loss = metrics.train_loss[-1]
        if metrics.val_loss:
            metrics.final_val_loss = metrics.val_loss[-1]
        if metrics.val_accuracy:
            metrics.final_val_accuracy = metrics.val_accuracy[-1]
        self._save_metrics()
        self._save_experiment()

    def compare_experiments(self, exp_ids: List[str]) -> Dict[str, Dict]:
        """Return a summary per experiment ID; unknown IDs are omitted."""
        comparisons = {}
        for exp_id in exp_ids:
            exp = self._load_experiment(exp_id)
            if not exp:
                continue
            comparisons[exp_id] = {
                "name": exp.config.experiment_name,
                "hyperparameters": exp.config.hyperparameters,
                "final_val_accuracy": exp.metrics.final_val_accuracy if exp.metrics else 0,
                "final_val_loss": exp.metrics.final_val_loss if exp.metrics else 0,
                "code_version": exp.code_version,
                "status": exp.status,
            }
        return comparisons

    def find_best_experiment(
        self,
        metric: str = "final_val_accuracy",
        dataset: Optional[str] = None,
        higher_is_better: bool = True,
    ) -> Optional[Experiment]:
        """Return the completed experiment with the best value of *metric*.

        Args:
            metric: A scalar ``ExperimentMetrics`` field (e.g. ``final_val_loss``)
                or a key of ``custom_metrics``.
            dataset: If given, only experiments on this dataset are considered.
            higher_is_better: Pass False for metrics to minimise, such as
                losses. Previously every metric was maximised, so asking for
                ``final_val_loss`` returned the worst run.

        Experiments that do not have the metric are skipped. Ties keep the
        first experiment in ID order, i.e. the earliest one. Returns None if
        nothing qualifies.
        """
        best_exp: Optional[Experiment] = None
        best_score: Optional[float] = None
        for exp_dir in sorted(self.storage_dir.iterdir()):
            if not exp_dir.is_dir():
                continue
            exp = self._load_experiment(exp_dir.name)
            if not exp or exp.status != "completed":
                continue
            if dataset and exp.config.dataset_name != dataset:
                continue
            score = self._metric_value(exp, metric)
            if score is None:
                continue
            if best_score is None:
                improved = True
            elif higher_is_better:
                improved = score > best_score
            else:
                improved = score < best_score
            if improved:
                best_score = score
                best_exp = exp
        return best_exp

    def reproduce_experiment(self, exp_id: str) -> Dict:
        """Return the information needed to reproduce experiment *exp_id*."""
        exp = self._load_experiment(exp_id)
        if not exp:
            return {"error": f"Experiment {exp_id} not found"}
        return {
            "experiment_id": exp_id,
            "reproduction_steps": [
                f"1. Checkout code: git checkout {exp.code_version}",
                f"2. Verify data: expected hash {exp.data_hash}",
                "3. Install dependencies: pip install -r requirements.txt",
                "4. Run with config:",
                json.dumps(exp.config.hyperparameters, indent=2),
                f"5. Set seed: {exp.config.seed}",
            ],
            "config": asdict(exp.config),
            "expected_results": {
                "final_val_accuracy": exp.metrics.final_val_accuracy if exp.metrics else None,
                "final_val_loss": exp.metrics.final_val_loss if exp.metrics else None,
            },
        }

    @staticmethod
    def _metric_value(exp: Experiment, metric: str) -> Optional[float]:
        """Return scalar *metric* of *exp*, or None when it has no such value.

        Scalar ``ExperimentMetrics`` fields are checked first, then
        ``custom_metrics``. List-valued histories are not comparable and are
        treated as missing.
        """
        metrics = exp.metrics
        if metrics is None:
            return None
        value = getattr(metrics, metric, None)
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value
        return metrics.custom_metrics.get(metric)

    def _reserve_experiment_dir(self, name: str) -> Path:
        """Atomically create and return a fresh, uniquely named experiment dir.

        IDs have only second resolution. If two experiments with the same name
        start within one second, the later one gets a ``_<n>`` suffix. The old
        ``mkdir(exist_ok=True)`` silently overwrote the earlier record.
        """
        base_id = self._generate_experiment_id(name)
        candidate = base_id
        attempt = 0
        while True:
            exp_dir = self.storage_dir / candidate
            try:
                exp_dir.mkdir(parents=True, exist_ok=False)
                return exp_dir
            except FileExistsError:
                attempt += 1
                candidate = f"{base_id}_{attempt}"

    def _generate_experiment_id(self, name: str) -> str:
        """Build ``exp_<YYYYmmdd_HHMMSS>_<first 6 hex of md5(name)>``."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        name_hash = hashlib.md5(name.encode()).hexdigest()[:6]
        return f"exp_{timestamp}_{name_hash}"

    def _get_git_commit(self) -> str:
        """Return the first 12 chars of ``git rev-parse HEAD``, or ``"unknown"``.

        This is best effort. If git is not installed (OSError) or the working
        directory is not a repository (SubprocessError), ``"unknown"`` is
        returned. Uncommitted changes are not reflected in the SHA.
        """
        try:
            result = subprocess.run(
                ["git", "rev-parse", "HEAD"],
                capture_output=True,
                text=True,
                check=True,
            )
        except (OSError, subprocess.SubprocessError):
            return "unknown"
        return result.stdout.strip()[:12]

    def _compute_data_hash(self, dataset_name: str) -> str:
        """Return a 12-hex-char fingerprint of the dataset.

        By default this hashes the dataset *name*, which identifies the dataset
        but not its version. With ``hash_data_contents=True`` and an existing
        file/directory path, every file's relative path and MD5 are folded into
        the hash instead, so any change to the data changes the fingerprint.
        """
        data_path = Path(dataset_name)
        # "" would become Path("."): never hash the working directory by accident.
        if not (self.hash_data_contents and dataset_name and data_path.exists()):
            return hashlib.md5(dataset_name.encode()).hexdigest()[:12]
        if data_path.is_dir():
            root = data_path
            files = [p for p in data_path.rglob("*") if p.is_file()]
        else:
            root = data_path.parent
            files = [data_path]
        # Sort by portable relative path so the hash is platform independent.
        files.sort(key=lambda p: p.relative_to(root).as_posix())
        digest = hashlib.md5()
        for file_path in files:
            file_digest = hashlib.md5()
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(1 << 20), b""):
                    file_digest.update(chunk)
            # NUL separator + fixed-size digest keeps (name, content) pairs unambiguous.
            digest.update(file_path.relative_to(root).as_posix().encode("utf-8"))
            digest.update(b"\0")
            digest.update(file_digest.digest())
        return digest.hexdigest()[:12]

    def _save_metrics(self):
        """Write the active experiment's metrics to ``metrics.json``."""
        if not self.current_experiment:
            return
        exp_dir = self.storage_dir / self.current_experiment.experiment_id
        _write_json(exp_dir / "metrics.json", asdict(self.current_experiment.metrics))

    def _save_experiment(self):
        """Write the full active experiment record to ``experiment.json``."""
        if not self.current_experiment:
            return
        exp_dir = self.storage_dir / self.current_experiment.experiment_id
        _write_json(
            exp_dir / "experiment.json",
            asdict(self.current_experiment),
            ensure_ascii=False,
        )

    def _load_experiment(self, exp_id: str) -> Optional[Experiment]:
        """Load ``<exp_id>/experiment.json``; return None if it does not exist."""
        exp_path = self.storage_dir / exp_id / "experiment.json"
        if not exp_path.exists():
            return None
        with open(exp_path, encoding="utf-8") as f:
            data = json.load(f)
        config = ExperimentConfig(**data["config"])
        metrics = ExperimentMetrics(**data["metrics"]) if data.get("metrics") else None
        return Experiment(
            experiment_id=data["experiment_id"],
            config=config,
            code_version=data.get("code_version", ""),
            data_hash=data.get("data_hash", ""),
            status=data.get("status", "unknown"),
            created_at=data.get("created_at", ""),
            completed_at=data.get("completed_at", ""),
            metrics=metrics,
            artifact_paths=data.get("artifact_paths", {}),
            notes=data.get("notes", ""),
        )


class ModelRegistry:
    """Model registry: manages model versions and their lifecycle stages.

    Each version lives in ``registry_dir/<model_name>/v<N>/register.json``.
    ``latest.json`` mirrors the record of the most recently registered version.
    """

    def __init__(self, registry_dir: str = "./model_registry"):
        self.registry_dir = Path(registry_dir)
        self.registry_dir.mkdir(parents=True, exist_ok=True)

    def register_model(
        self,
        model_name: str,
        experiment_id: str,
        stage: str = "staging",  # staging / production / archived
        metrics: Optional[Dict] = None,
    ) -> str:
        """Register a new version of *model_name* and return ``"<name>/v<N>"``."""
        version = self._get_next_version(model_name)
        model_record = {
            "model_name": model_name,
            "version": version,
            "experiment_id": experiment_id,
            "stage": stage,
            "registered_at": datetime.now().isoformat(),
            "metrics": metrics or {},
        }
        record_dir = self.registry_dir / model_name / version
        record_dir.mkdir(parents=True, exist_ok=True)
        _write_json(record_dir / "register.json", model_record)
        # Move the "latest version" pointer to the new record.
        _write_json(self.registry_dir / model_name / "latest.json", model_record)
        return f"{model_name}/{version}"

    def promote_model(
        self,
        model_name: str,
        version: str,
        stage: str,
        archive_existing: bool = False,
    ):
        """Move *version* of *model_name* to *stage* (e.g. staging → production).

        Args:
            archive_existing: When promoting to ``"production"``, first mark
                every other production version as ``"archived"``. This keeps
                exactly one production version, which makes rollbacks to an
                older version visible to ``get_production_model``.

        Raises:
            ValueError: If the version is not registered.
        """
        record_path = self.registry_dir / model_name / version / "register.json"
        if not record_path.exists():
            raise ValueError(f"Model {model_name}/{version} not found")
        if archive_existing and stage == "production":
            for _, version_dir in self._numbered_versions(model_name):
                if version_dir.name == version:
                    continue
                other = self._read_record(version_dir / "register.json")
                if other is not None and other.get("stage") == "production":
                    other["stage"] = "archived"
                    other["archived_at"] = datetime.now().isoformat()
                    self._write_record(model_name, version_dir.name, other)
        record = self._read_record(record_path)
        record["stage"] = stage
        record["promoted_at"] = datetime.now().isoformat()
        self._write_record(model_name, version, record)

    def get_production_model(self, model_name: str) -> Optional[Dict]:
        """Return the highest-numbered version in ``"production"``, or None.

        Versions are compared numerically. The previous lexicographic sort
        ranked ``v9`` above ``v10``.
        """
        for _, version_dir in reversed(self._numbered_versions(model_name)):
            record = self._read_record(version_dir / "register.json")
            if record is not None and record.get("stage") == "production":
                return record
        return None

    def _get_next_version(self, model_name: str) -> str:
        """Return ``v<max existing N + 1>``, or ``"v1"`` if none exist."""
        versions = self._numbered_versions(model_name)
        return f"v{versions[-1][0] + 1}" if versions else "v1"

    def _numbered_versions(self, model_name: str) -> List[Tuple[int, Path]]:
        """Return ``(N, dir)`` for each ``v<N>`` directory, sorted by N."""
        model_dir = self.registry_dir / model_name
        if not model_dir.is_dir():
            return []
        found = [
            (int(d.name[1:]), d)
            for d in model_dir.iterdir()
            if d.is_dir() and d.name.startswith("v") and d.name[1:].isdecimal()
        ]
        return sorted(found)

    @staticmethod
    def _read_record(path: Path) -> Optional[Dict]:
        """Load a JSON record, or return None if *path* does not exist."""
        if not path.exists():
            return None
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def _write_record(self, model_name: str, version: str, record: Dict) -> None:
        """Persist *record* for *version*, keeping ``latest.json`` in sync."""
        _write_json(self.registry_dir / model_name / version / "register.json", record)
        latest_path = self.registry_dir / model_name / "latest.json"
        latest = self._read_record(latest_path)
        if latest is not None and latest.get("version") == version:
            _write_json(latest_path, record)


# 四、实验管理的 Trade-offs
自动记录的侵入性:自动记录实验参数需要修改训练代码,增加代码侵入性。如果追踪框架与训练框架耦合过紧,迁移成本很高。建议使用装饰器或回调接口,将追踪逻辑与训练逻辑解耦。
存储成本:每次实验保存完整的配置、指标和模型权重,存储开销随实验数量线性增长。以训练 10 个 epoch 的 BERT 模型为例,单个 checkpoint 约 1.5GB。若每个 epoch 都保存一次,单次实验就占 15GB,10 次实验即 150GB。建议只保存最佳 checkpoint 和最后一个 checkpoint,中间 checkpoint 在实验完成后删除,这样每次实验的权重存储可降到约 3GB。
团队协作的标准化:不同开发者使用不同的命名规则和参数格式,导致实验记录难以对比。需要团队统一实验配置的格式(如使用 JSON Schema 校验)和命名规则(如{模型}_{数据集}_{日期})。
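复现的随机性:即使记录了随机种子,GPU 浮点运算仍有不确定性,例如并行归约的累加顺序不固定,部分 cuDNN 算子默认选用非确定性实现。因此在不同硬件上(有时甚至同一硬件上的多次运行),训练结果都可能不完全一致。完全复现需要固定 CUDA 版本、cuDNN 版本和硬件型号,并开启框架的确定性模式(如 PyTorch 的 torch.use_deterministic_algorithms(True)),这在跨团队协作中几乎不可能。建议将"复现"定义为"指标在统计上等价"而非"数值完全一致"。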
五、总结
实验管理系统通过追踪代码版本、数据版本、配置参数和产出物,将 AI 实验从"炼丹笔记"推向可复现的工程体系。自动记录确保信息完整,不可篡改保证数据可信,可查询支持实验对比,可复现验证结果可靠。但自动记录的侵入性、存储成本、团队标准化和复现随机性是需要权衡的因素。在实际落地中,建议使用 MLflow 等成熟框架而非自建系统,统一团队实验配置格式,定期清理过期 checkpoint,将复现标准定义为统计等价。实验管理的目标不是"记录一切",而是"让每一次实验都有价值、可追溯、可复现"。