# Predictor **Repository Path**: juventus_bupt/predictor ## Basic Information - **Project Name**: Predictor - **Description**: No description available - **Primary Language**: Unknown - **License**: AGPL-3.0 - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2025-12-10 - **Last Updated**: 2025-12-10 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 预测器训练指南(SigLIP 原生特征) 本仓库实现了一个极简的“未来帧特征预测”模型,用于基于预测的多模态 token 剪枝的前置部分。当前只训练一个小型 MLP 预测器,直接以 SigLIP 视觉塔的原生特征(不经过 Llava projector)为目标。 - 预测器模型文件:`predictive_model.py` - 数据加载与特征提取:`data_loader.py` - 训练脚本(Transformers Trainer + 可选 WandB):`train_predictor.py` ## 环境依赖 - Python 3.9/3.10 - PyTorch(建议 GPU) - Transformers >= 4.44(包含 SigLIP 模型) - 其他:Pillow、numpy、wandb(可选) 示例安装: - `pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121`(根据你的 CUDA 版本调整) - `pip install transformers pillow numpy wandb` ## 数据准备 训练脚本按“目录列表”读取视频帧:每一行是一个帧目录,目录中包含该视频的所有帧图像(建议按文件名排序,后缀 jpg/png/webp/bmp 均可)。 示例目录结构: - `frames/video_0001/000000.jpg ...` - `frames/video_0002/000000.jpg ...` 准备两个文本文件: - `data/train_dirs.txt`(每行一个目录路径) - `data/val_dirs.txt`(可选) 例如: ``` /abs/path/to/frames/video_0001 /abs/path/to/frames/video_0002 ``` ## 训练启动 最小示例(使用 SigLIP so400m 384,历史帧 x=4,MSE 损失): ``` python train_predictor.py \ --train_dirs_file data/train_dirs.txt \ --val_dirs_file data/val_dirs.txt \ --vision_model_id google/siglip-so400m-patch14-384 \ --x_frames 4 --stride 1 \ --batch_size 32 --num_train_epochs 3 \ --lr 1e-3 \ --output_dir outputs/predictor ``` 可选参数(常用): - `--image_size 384`:SigLIP 输入分辨率 - `--drop_cls`:丢弃 CLS token 后再做均值池化(推荐开启) - `--layer_index N`:从指定层的 hidden_states 取特征(默认使用 last_hidden_state) - `--hidden_dim`、`--num_layers`、`--dropout`、`--temporal_pool {mean,last}`:MLP 预测器结构 - `--fp16`:半精度训练(如显卡支持) ## WandB 可视化 传入 `--wandb_project your_project` 即可开启日志上报: ``` python train_predictor.py ... --wandb_project my_proj --wandb_run_name siglip_pred_mlp ``` - 首次使用需要设置 `WANDB_API_KEY` 环境变量或 `wandb login` - 上报的指标包含:训练 loss、验证 MSE、验证 (1 − cosine) ## 脚本说明 - `train_predictor.py` - 加载 SigLIP 视觉塔(`SiglipVisionModel.from_pretrained`),仅用于提特征,不参与训练 - 数据集 `VideoDirWindowDataset` 基于“滑动窗口 (x 帧 -> 下一帧)”自动从目录生成样本 - 预测器目标是 SigLIP 原生特征维度(常见 1152),损失为 MSE - 使用 Transformers `Trainer` 进行训练与评估,并保存到 `--output_dir` - `data_loader.py` - `extract_siglip_frame_embeddings(...)`:直接从 SigLIP 取 (T, N, D) 的序列特征,可选去 CLS;对 token 维均值池化为 (T, D) - `VideoDirWindowDataset`:返回 `(hist, next)`,形状分别为 `(x, D)` 与 `(D)` - 自测:`python data_loader.py`(会创建临时随机数据并打印 batch 形状) - `predictive_model.py` - `PredictorModel`:最小 MLP 预测器,输入 `(B, T, d)` 输出 `(B, T-1, d)` 与 loss - 自测:`python predictive_model.py`(打印前向形状与 loss) ## 训练产物与加载 训练完成后,目录 `--output_dir` 下会保存: - `pytorch_model.bin`:模型权重(`PredictorForNextFrame` 的 `state_dict`) - `predictor_config.json`:预测器配置 加载示例(推理): ``` import json, torch from train_predictor import PredictorForNextFrame from predictive_model import PredictorConfig with open('outputs/predictor/predictor_config.json','r',encoding='utf-8') as f: cfg = PredictorConfig(**json.load(f)) model = PredictorForNextFrame(cfg).eval() # 懒加载:确保 d_model 已知(可先构造一次假输入,或手动设置 cfg.d_model) cfg.d_model = cfg.d_model or 1152 # 例如 SigLIP so400m 的隐藏维 _ = model(torch.zeros(1, cfg.context, cfg.d_model)) # 构建内部 MLP state = torch.load('outputs/predictor/pytorch_model.bin', map_location='cpu') model.load_state_dict(state, strict=True) # hist: (B, x, d) hist = torch.randn(2, cfg.context, cfg.d_model) out = model(hist) print(out['logits'].shape) # (B, d) ``` ## 注意事项 - 性能建议:GPU 运行;由于数据集中会调用视觉塔做前向,`dataloader_num_workers` 设为 0 更稳。 - 预处理:当前实现使用简化的 Resize + 中心裁剪 + [-1,1] 归一化,与官方 SigLIP 处理一致性较高;如需绝对对齐,可改为使用 `AutoImageProcessor`。 - 提前缓存:如需更快训练,可在数据阶段将帧级特征 (T, D) 缓存到磁盘,训练时直接读取;本仓库暂未内置缓存逻辑。 如需我把缓存或分布式/混合精度等增强功能补上,告诉我你的偏好与环境即可。 或直接给出 LLaVA-OneVision 权重(脚本会从中提取 SigLIP 视觉塔并冻结,仅训练预测器): ``` python train_predictor.py \ --train_dirs_file data/train_dirs.txt \ --val_dirs_file data/val_dirs.txt \ --llava_model_id /path/to/llava-onevision-qwen2-7b-ov-hf \ --x_frames 4 --stride 1 \ --batch_size 32 --num_train_epochs 3 \ --lr 1e-3 \ --output_dir outputs/predictor ```