# 基于CNN原型网络的小样本学习方法 **Repository Path**: levijia/VPCN-FSL ## Basic Information - **Project Name**: 基于CNN原型网络的小样本学习方法 - **Description**: 本项目采用**原型网络(Prototypical Networks)**作为核心算法,结合**卷积神经网络(CNN)**作为特征提取器,实现蔬菜病虫害的小样本分类任务。 - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: main - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2026-02-28 - **Last Updated**: 2026-03-05 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 蔬菜病虫害基于CNN原型网络的小样本学习项目 ## 项目简介 本项目实现了基于CNN原型网络(Prototypical Networks)的蔬菜病虫害小样本分类系统。通过元学习框架,能够在仅有少量样本的情况下快速学习新类别,适用于实际农业场景中病虫害种类繁多、标注数据稀缺的问题。 ## 核心特性 - **原型网络算法**:基于欧氏距离的类别原型计算 - **4层CNN特征提取器**:轻量级网络,高效提取图像特征 - **小样本学习**:支持 N-way K-shot 任务配置 - **完整训练流程**:包含训练、验证、测试全流程 - **一键数据集准备**:自动下载和划分PlantVillage数据集 ## 环境要求 ### 硬件要求 - **CPU**: Intel i5及以上(推荐GPU加速) - **GPU**: NVIDIA GTX 1060及以上(可选,强烈推荐) - **内存**: 8GB以上 - **存储**: 至少20GB可用空间 ### 软件环境 - Python >= 3.8 - PyTorch >= 1.12.0 - CUDA >= 11.0(如果使用GPU) ## 安装步骤 ### 1. 克隆项目 ```bash git clone cd cnn_demo ``` ### 2. 创建虚拟环境(推荐) ```bash # 创建虚拟环境 python -m venv venv # 激活虚拟环境 # Windows: venv\Scripts\activate # Linux/macOS: source venv/bin/activate ``` ### 3. 安装依赖 ```bash pip install -r requirements.txt ``` 依赖列表: ``` torch>=1.12.0 torchvision>=0.13.0 numpy>=1.21.0 pillow>=9.0.0 matplotlib>=3.5.0 tqdm>=4.64.0 scikit-learn>=1.1.0 opencv-python>=4.5.0 ``` ## 数据集准备 ### 方式1: 使用一键下载脚本(推荐) #### 使用Kaggle下载 ```bash # 1. 配置Kaggle API # 访问 https://www.kaggle.com/settings # 点击 "Create New API Token" 下载 kaggle.json # 将 kaggle.json 放到 ~/.kaggle/ 目录 # 2. 下载并自动划分数据集 python download_dataset.py --method kaggle --split ``` #### 使用GitHub下载 ```bash python download_dataset.py --method github --split ``` #### 使用已下载的数据集 ```bash python download_dataset.py --method existing --existing_dir /path/to/your/dataset --split ``` ### 方式2: 手动准备 1. 下载PlantVillage数据集 2. 按以下结构组织数据: ``` data/ ├── base_classes/ # 基础类别(用于训练) │ ├── Tomato_Bacterial_spot/ │ │ ├── img_001.jpg │ │ ├── img_002.jpg │ │ └── ... │ ├── Tomato_Early_blight/ │ └── ... ├── novel_classes/ # 新类别(用于测试) │ ├── Tomato_Late_blight/ │ └── ... └── meta_test/ # 元测试集 ├── Tomato_healthy/ └── ... ``` ### 数据集划分比例 - **基础类别 (base_classes)**: 50% 的类别,用于训练 - **新类别 (novel_classes)**: 25% 的类别,用于测试小样本学习能力 - **元测试集 (meta_test)**: 25% 的类别,用于验证 ## 配置说明 编辑 `config.py` 文件可以调整以下参数: ```python # 设备配置 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 数据路径 BASE_CLASSES_DIR = './data/split/base_classes' NOVEL_CLASSES_DIR = './data/split/novel_classes' META_TEST_DIR = './data/split/meta_test' # 模型配置 IMAGE_SIZE = 224 FEATURE_DIM = 256 # Few-shot学习配置 N_WAY = 5 # 每个episode的类别数 K_SHOT = 5 # 每类的支持样本数 Q_QUERY = 15 # 每类的查询样本数 # 训练配置 NUM_TRAIN_EPISODES = 10000 NUM_VAL_EPISODES = 1000 NUM_EPOCHS = 100 BATCH_SIZE = 4 LEARNING_RATE = 0.001 ``` ## 训练模型 ### 基础训练 ```bash python train.py ``` ### 训练输出 训练过程中会显示: - 每个epoch的训练损失和准确率 - 验证集准确率 - 最佳模型保存信息 示例输出: ``` Using device: cuda 加载数据集... 训练集类别数: 19 训练集样本数: 45456 验证集类别数: 10 验证集样本数: 8856 创建模型... 模型总参数量: 3,456,256 可训练参数量: 3,456,256 开始训练... 配置: 5-way 5-shot 训练Episodes: 10000 验证Episodes: 1000 训练轮数: 100 -------------------------------------------------- Epoch 1/100: loss: 0.4523, acc: 0.8234 Epoch 1/100: Train Loss: 0.4892, Train Acc: 0.8156 Val Acc: 0.8342 Best model saved with acc: 0.8342 -------------------------------------------------- ... Training completed! Best validation accuracy: 0.8765 ``` ### 模型保存 最佳模型会自动保存到 `./checkpoints/best_model.pth` ## 测试模型 ```bash python test.py ``` ### 测试输出 ``` Using device: cuda 加载测试数据集... 测试集类别数: 10 测试集样本数: 8856 开始测试... -------------------------------------------------- Loaded model from epoch 87 Validation accuracy: 0.8765 Test Accuracy: 0.8543 ``` ## 项目结构 ``` cnn_demo/ ├── README.md # 本文档 ├── requirements.txt # 依赖列表 ├── config.py # 配置文件 ├── train.py # 训练脚本 ├── test.py # 测试脚本 ├── download_dataset.py # 数据集下载脚本 ├── models/ # 模型定义 │ ├── __init__.py │ ├── feature_extractor.py # CNN特征提取器 │ └── prototypical_network.py # 原型网络 ├── data/ # 数据处理 │ ├── __init__.py │ └── episode.py # Episode生成器 ├── utils/ # 工具函数 │ ├── __init__.py │ ├── train.py # 训练函数 │ └── evaluate.py # 评估函数 ├── checkpoints/ # 模型检查点 │ └── best_model.pth └── data/ # 数据目录 └── split/ ├── base_classes/ ├── novel_classes/ └── meta_test/ ``` ## 性能指标 ### 预期准确率 | 配置 | 预期准确率 | |------|-----------| | 5-way 1-shot | 65% - 75% | | 5-way 5-shot | 80% - 88% | | 5-way 10-shot | 85% - 92% | ### 训练时间 - **GPU (RTX 3090)**: 约 2-4 小时 - **GPU (GTX 1060)**: 约 4-6 小时 - **CPU (Intel i7)**: 约 20-40 小时(不推荐) ## 常见问题 ### 1. 显存不足 **解决方案**: - 减小 `BATCH_SIZE` (config.py) - 减小 `IMAGE_SIZE` (config.py) - 使用梯度累积 ### 2. 数据集下载失败 **解决方案**: - 检查网络连接 - 尝试使用不同的下载方式(kaggle/github/existing) - 手动下载后使用 `--method existing` 指定路径 ### 3. 训练loss不下降 **解决方案**: - 检查学习率设置 - 确认数据增强是否合理 - 检查模型结构是否正确 ### 4. 模型加载失败 **解决方案**: - 确认模型文件路径正确 - 检查模型文件是否完整 - 确认PyTorch版本兼容性 ## 高级用法 ### 1. 调整Few-shot配置 修改 `config.py` 中的参数: ```python # 5-way 1-shot N_WAY = 5 K_SHOT = 1 Q_QUERY = 15 # 10-way 5-shot N_WAY = 10 K_SHOT = 5 Q_QUERY = 10 ``` ### 2. 使用不同的特征提取器 修改 `models/feature_extractor.py`,可以替换为更强大的backbone: ```python # 使用ResNet from torchvision.models import resnet18 feature_extractor = resnet18(pretrained=True) ``` ### 3. 调整数据增强 修改 `train.py` 中的数据增强配置: ```python train_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomVerticalFlip(p=0.5), transforms.RandomRotation(degrees=15), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) ``` ### 4. 使用预训练模型 在 `train.py` 中加载预训练权重: ```python # 加载ImageNet预训练权重 feature_extractor = CNNFeatureExtractor() # 假设有预训练权重 feature_extractor.load_state_dict(torch.load('pretrained_feature_extractor.pth')) ``` ## 远程部署 ### 使用智星云等云GPU平台 1. **创建实例** - 选择GPU: RTX 3090 / 4090 / A100 - 系统: Ubuntu 20.04 或 22.04 - 内存: 32GB+ - 存储: 100GB+ 2. **连接实例** ```bash ssh -i your_key.pem root@your_instance_ip ``` 3. **上传项目** ```bash # 使用SCP上传 scp -i your_key.pem -r ./cnn_demo root@your_instance_ip:/root/ ``` 4. **安装依赖并训练** ```bash cd /root/cnn_demo pip install -r requirements.txt python download_dataset.py --method kaggle --split nohup python train.py > train.log 2>&1 & ``` 5. **监控训练** ```bash tail -f train.log nvidia-smi ``` 6. **下载结果** ```bash scp -i your_key.pem root@your_instance_ip:/root/cnn_demo/checkpoints/best_model.pth ./ ``` ## 扩展方向 ### 1. 算法改进 - 使用更强的特征提取器(ResNet, EfficientNet) - 引入注意力机制 - 改进距离度量方式 ### 2. 数据增强 - 使用Mixup、CutMix等高级增强方法 - 使用生成模型(GAN)生成更多样本 ### 3. 多模态学习 - 结合图像和文本信息 - 使用专家知识辅助分类 ### 4. 增量学习 - 支持新类别的持续学习 - 避免灾难性遗忘 ## 参考文献 1. **Prototypical Networks for Few-shot Learning** - Snell, J., Swersky, K., & Zemel, R. - NIPS 2017 - [arXiv:1703.05175](https://arxiv.org/abs/1703.05175) 2. **PlantVillage Dataset** - [GitHub](https://github.com/spMohanty/PlantVillage-Dataset) ## 许可证 本项目遵循 MIT 许可证。 ## 联系方式 如有问题或建议,请通过以下方式联系: - GitHub Issues: [项目地址] - Email: your.email@example.com --- **版本**: 1.0 **更新日期**: 2026-02-28