# post_continual_learning **Repository Path**: enzeyu/post_continual_learning ## Basic Information - **Project Name**: post_continual_learning - **Description**: sssssssssssssssss - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2025-03-20 - **Last Updated**: 2026-05-18 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 自演进算法 本仓库实现了一个面向类增量学习场景的端-边协同训练流程。 ## 主要思想 实验将 CIFAR-100 按类别顺序划分为 10 个连续任务,每个任务包含 10 个类别。每到一个新任务,端侧设备先从当前任务数据中筛选少量有代表性的样本并上传给服务器;服务器基于这些代表样本评估模型通道重要性,对当前任务模型进行结构化裁剪;随后将裁剪后的模型下发给端侧,在完整的当前任务训练集上进行本地训练。 核心目标是在边缘服务器每个增量任务上生成一个更轻量的任务模型给终端设备训练,同时尽量保留当前任务所需的关键通道。`target_mask_density=0.7` 表示裁剪时每个可裁剪卷积层约保留 70% 的通道,约剪掉 30% 重要性较低的通道。 ![思想](./figures/idea.png) ## 实验设置 | 项目 | 设置 | | --- | --- | | 数据集 | CIFAR-100 | | 任务数 | 10 | | 每个任务类别数 | 10 | | 类别划分 | 第 1 个任务为类别 0-9,第 2 个任务为类别 10-19,依次到类别 90-99 | | 默认模型 | ResNet-18 | | 端侧数量 | 1 | | 全局轮数 | 10 | | 每个任务持续轮数 | 1 | | 本地训练轮数 | 100 | | batch size | 64 | | 本地学习率 | 0.001 | | 服务器学习率 | 0.2 | | 随机候选样本数 | 每类 30 个 | | 原型筛选样本数 | 每类 10 个 | | 目标保留密度 | 0.7 | | 随机种子 | 2021 | 除命令中显式给出的参数外,其余参数来自 `option.py` 的默认值。 ## 数据集 本实验使用 `datasets/cifar100.py` 中的 `iCIFAR100`。代码默认从以下路径读取 CIFAR-100: ```text ./dataset/CIFAR100 ``` `main.py` 中设置 `download=False`,因此运行前需要保证 CIFAR-100 已经存在于该目录下。当前仓库中的默认数据路径为: ```text /mnt/data/enzeyu/post_continual_learning/dataset/CIFAR100 ``` 数据预处理针对 CIFAR-100 + ResNet-18: - 训练集:随机裁剪、随机水平翻转、颜色扰动、转 Tensor、按 CIFAR-100 均值方差归一化。 - 测试集:resize 到 `img_size=32`、转 Tensor、按 CIFAR-100 均值方差归一化。 ## 模型 默认模型为 ResNet-18,对 CIFAR-100 输入做了适配: - 首层卷积使用 `3x3`、stride 为 1、padding 为 1,适合 32x32 图像。 - ResNet-18 作为特征提取器,输出 512 维特征。 - `models/myNetwork.py` 在特征提取器后接一个全连接分类头。 - 每个任务只训练当前 10 个类别,因此分类头输出维度为 `numclass=10`。 端侧样本筛选阶段额外使用 `torchvision.models.mobilenet_v3_small` 作为特征提取模型,并从 `models/mobilenet_v3_small-047dcff4.pth` 加载权重。 ## 算法流程 1. 初始化参数、随机种子、日志和保存目录。 2. 加载 CIFAR-100,并按照 `tasks_global=10` 和 `numclass=10` 构造 10 个任务的数据加载器。 3. 初始化服务器模型 `ServerModel`,其中包含 ResNet-18 特征提取器和 10 类分类头。 4. 对每个全局 epoch 计算当前任务编号。由于 `tasks_epoch=1`,每个 epoch 对应一个新任务。 5. 端侧阶段一:对当前任务的每个类别先随机采样 30 个候选样本。 6. 使用 MobileNetV3-small 提取候选样本特征,计算每个类别的特征原型。 7. 对每个类别选取距离原型最近的 10 个样本,组成代表样本 dataloader,并上传给服务器。 8. 服务器阶段:使用代表样本对 ResNet-18 各卷积层注册前向和反向钩子,计算通道重要性: ```text importance = mean(abs(activation)) * mean(abs(gradient)) ``` 9. 根据通道重要性从低到高选择待剪枝通道,并通过 `torch_pruning.DependencyGraph` 执行结构化通道裁剪。`target_mask_density=0.7` 表示每层约保留 70% 通道。 10. 保存当前任务裁剪后的模型到 `model_saved_check/cifar100_1_10_test/Task__model.pkl`。 11. 端侧阶段二:端侧接收裁剪后的模型,在当前任务完整训练数据上使用 SGD 训练 100 个本地 epoch。 12. 在当前任务测试集上计算并记录测试损失和准确率。 ## 运行方法 在仓库根目录执行: ```bash python main.py --dataset cifar100 --numclass 10 --tasks_global 10 --target_mask_density 0.7 ``` 如果需要指定 GPU,可以额外设置 `--device`,例如: ```bash python main.py --dataset cifar100 --numclass 10 --tasks_global 10 --target_mask_density 0.7 --device 0 ``` ## 输出结果 运行后主要产生以下输出: - `train.log`:记录参数配置、每个任务的裁剪信息、模型参数量、模型大小、训练损失、训练准确率和测试准确率。 - `model_saved_check/cifar100_1_10_test/`:保存每个任务对应的裁剪模型。 - 控制台输出:包括日志初始化、数据准备、模型参数统计等信息。 模型保存文件示例: ```text model_saved_check/cifar100_1_10_test/Task_1_model.pkl model_saved_check/cifar100_1_10_test/Task_2_model.pkl ... model_saved_check/cifar100_1_10_test/Task_10_model.pkl ``` ## 代码结构 ```text main.py # 实验入口,组织数据加载、端侧筛选、服务器裁剪和本地训练 option.py # 命令行参数和默认实验配置 datasets/cifar100.py # CIFAR-100 类增量数据集封装 ends/end.py # 端侧逻辑:任务数据加载、代表样本筛选、本地训练和测试 models/serverModel.py # 服务器模型封装 models/resnet.py # ResNet-18 特征提取器 models/myNetwork.py # 特征提取器 + 分类头 myutils/server_utils.py # 通道重要性计算和结构化裁剪 myutils/global_utils.py # 随机种子、参数统计、模型大小统计等工具函数 ``` ## 依赖说明 - `torch` - `torchvision` - `numpy` - `scikit-learn` - `torch-pruning` - `matplotlib` - `Pillow`