# jnn-learning **Repository Path**: socrates2017/jnn-learning ## Basic Information - **Project Name**: jnn-learning - **Description**: java神经网络学习项目,通俗易懂,让您深入理解神经网络的算法机制 - **Primary Language**: Java - **License**: MIT - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 1 - **Forks**: 0 - **Created**: 2026-05-25 - **Last Updated**: 2026-05-27 ## Categories & Tags **Categories**: Uncategorized **Tags**: 神经网络, 反向传播算法, 深度学习 ## README # jnn-learning [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE) [![Java](https://img.shields.io/badge/Java-8-blue.svg)](https://www.oracle.com/java/) [![Maven](https://img.shields.io/badge/Maven-3.x-red.svg)](https://maven.apache.org/) > **用纯 Java 手撕神经网络——零框架依赖,每一行代码都在告诉你:深度学习到底在干什么。** --- ## 为什么你需要这个项目? 你学过神经网络的原理,看过无数推导公式,但内心深处一直有个声音: > **"我知道反向传播怎么算,但如果让我从零写一个,我真的能写出来吗?"** PyTorch 一行 `torch.nn.Linear()` 就搞定了前向传播,TensorFlow 一个 `model.fit()` 就完成了训练。框架太强大,强大到让你**看不见**底层发生了什么。梯度是怎么流的?权重是何时更新的?批量训练和单样本训练的本质区别是什么? `jnn-learning` 的答案很简单:**我们把神经网络拆开,一个神经元一个神经元地拼给你看。** ## 构建网络示例 ```java // 构建一个 2-2-1 的网络,用 Sigmoid 激活函数 NeuronNet neuronNet = new NeuronNet(); // 隐藏层:2 输入 → 2 神经元 NeuronLayer inputNeuronLayer = new NeuronLayer(neuronNet, null, 0, 2, 2, SigmoidFunctionBatch.forwardFunction(), SigmoidFunctionBatch.backwardFunction()); // 输出层:2 神经元 → 1 输出节点 NeuronLayer outputNeuronLayer = new NeuronLayer(neuronNet, inputNeuronLayer, 0, 2, 1, SigmoidFunctionBatch.forwardFunction(), SigmoidFunctionBatch.backwardFunction()); inputNeuronLayer.nextNeuronLayer = outputNeuronLayer; neuronNet.addNeuronLayer(inputNeuronLayer); neuronNet.addNeuronLayer(outputNeuronLayer); log.info("{}", neuronNet.toString()); //输出网络架构 NeuronNet { learningRate=0.001, batchSize=0, debug=false topology=2-1 totalParams=9, layers=2 Layer[0] (input) neurons=2, inputs=2, activation=com.zrzhen.com.nn.forwardBackwardFuncion.SigmoidFunctionBatch$1, mode=batch, params=6 Layer[1] (input) neurons=1, inputs=2, activation=com.zrzhen.com.nn.forwardBackwardFuncion.SigmoidFunctionBatch$1, mode=batch, params=3 Structure: Input(2) --> Output(1)} ``` **没有 `import torch`,没有 `tf.keras`,只有 Java 对象、float 数组和纯粹的链式传播。** ## 核心亮点 | 亮点 | 它意味着什么 | | --- | --- | | 🧱 **零框架依赖** | 不依赖 PyTorch / TensorFlow / DL4J,纯 JDK + 几个轻量库(Jackson、Logback) | | 🔍 **完全透明** | 每个 Neuron 都有独立的 `weight[]`、`bias`、`output`、`grad`,你可以打印任何一个神经元在任意时刻的状态 | | 🎯 **策略模式激活函数** | 通过 `DyFunction` 泛型接口注入前向/反向计算逻辑,切换 Sigmoid ↔ Tanh 只需改一行参数 | | ⚡ **双模式训练** | 单样本 SGD 逐个调参 + Mini-batch 矩阵批量计算,两种路径完整实现,对比学习效果拔群 | | 🔗 **层间双向链表** | 每层持有 `preNeuronLayer` 和 `nextNeuronLayer`,前向沿 next 走,反向沿 pre 回,结构即算法 | | 💾 **JSON 模型持久化** | 训练好的权重保存为可读 JSON,每个神经元的 bias 和 weight[] 一目了然 | | 📐 **延迟更新机制** | 梯度缓存到 `weightNew`/`biasNew`,等所有层算完再统一提交,避免参数不一致 | ## 适合谁? | 你是... | 这个项目能帮你... | | --- | --- | | **AI 入门者** | 从零理解神经网络的工作原理,而不是只会调 API | | **Java 开发者** | 用你最熟悉的语言理解深度学习,无需额外学 Python | | **计算机专业学生** | 课程设计 / 毕业设计的绝佳参考,代码量适中(核心 ~2000 行),注释清晰 | | **面试准备者** | 能够手写反向传播,面试官问你 "从零实现 MLP" 时从容应对 | | **教育工作者** | 可作为《机器学习》或《人工智能导论》课程的配套实验项目 | ## 快速上手 ### 环境要求 - JDK 1.8+ - Maven 3.x - IDE 推荐:IntelliJ IDEA 或 Eclipse ### 三步跑通 ```bash # 1. 克隆项目 git clone https://github.com/Socrates2017/jnn-learning.git cd jnn-learning # 2. 编译 mvn compile # 3. 运行 XOR 示例(最经典的神经网络入门问题) mvn -Dtest=XorTest test ``` 运行后你会在日志中看到完整的训练过程: ``` epoch=1000, loss=0.25000000 epoch=2000, loss=0.12500000 ... epoch=9900, loss=0.00000123 ← 收敛! predict(0,0)=0.002, predict(0,1)=0.998, predict(1,0)=0.997, predict(1,1)=0.003 ``` > **推荐在 IDE 中直接运行测试类**(`src/test/java/com/zrzhen/com/nn/` 下),Logback 日志输出更直观。 --- ## 技术栈 | 技术 | 版本 | 用途 | | --- | --- | --- | | Java | 1.8 | 编程语言 | | Maven | 3.x | 构建工具 | | JUnit Jupiter | 5.8.0 | 测试框架 | | Jackson | 2.9.5 | JSON 序列化(模型保存/加载) | | Logback | 1.1.7 | 日志框架 | | Apache Commons Lang3 | 3.8.1 | 工具类 | ## 目录结构 ```text . ├── data/ # 训练数据、测试数据、模型参数 │ ├── xor/ # XOR 异或问题 │ ├── classify/ # 简单二分类 │ ├── classify2/ # MLP 二分类 │ └── iris/ # Iris 鸢尾花分类 ├── src/ │ ├── main/java/com/zrzhen/com/nn/ │ │ ├── NeuronNet.java # 网络容器 │ │ ├── NeuronLayer.java # 神经层 │ │ ├── NeuronLayerBase.java # 神经层基础数据 │ │ ├── Neuron.java # 神经元 │ │ ├── NeuronBase.java # 神经元基础数据 │ │ ├── Tensor.java # 输入数据封装 │ │ ├── DyFunction.java # 前向/反向计算策略基类 │ │ ├── BackwardRsp.java # 反向传播返回值 │ │ ├── forwardBackwardFuncion/ # 激活函数实现 │ │ │ ├── Sigmoid.java # Sigmoid 数学工具 │ │ │ ├── SigmoidFunction.java # Sigmoid 单样本策略 │ │ │ ├── SigmoidFunctionBatch.java # Sigmoid 批量策略 │ │ │ ├── Tanh.java # Tanh 数学工具 │ │ │ └── TanhFunction.java # Tanh 单样本策略 │ │ ├── data/ │ │ │ └── TrainData.java # 训练数据 POJO │ │ └── util/ │ │ ├── DataUtil.java # 数据加载/保存 │ │ ├── FileUtil.java # 文件读写 │ │ └── JsonUtil.java # JSON 序列化 │ ├── main/resources/ │ │ └── logback.xml # 日志配置 │ └── test/java/com/zrzhen/com/nn/ # 示例训练与验证用例 ├── docs/ │ └── DEVELOPER_GUIDE.md # 开发者指南(700+ 行) └── pom.xml ``` ## 核心架构 ### 类层次 ``` NeuronBase 神经元基础数据: bias, weight[] └── Neuron 完整神经元: output, grad, loss, state, 前向/反向/更新 NeuronLayerBase 神经层基础数据: neuronList, weight[][], bias[] └── NeuronLayer 完整神经层: 前向/反向传播, 批量计算, 层连接 DyFunction 前向/反向计算策略基类 ├── SigmoidFunction Sigmoid 单样本策略 ├── SigmoidFunctionBatch Sigmoid 批量策略 └── TanhFunction Tanh 单样本策略 NeuronNet 网络容器: 层组织, 训练, 预测, 模型序列化 Tensor 输入数据封装: float[] 或 float[][] BackwardRsp 反向传播返回值: bias, weight[] TrainData 训练数据: input[] + label[] ``` ### 模块说明 | 模块 | 说明 | | --- | --- | | `NeuronNet` | 网络容器,负责组织层、初始化权重、训练(单样本/批量)、预测、模型加载与导出、网络结构描述。 | | `NeuronLayer` | 神经层,保存本层权重矩阵 `weight[neuron][input]`、偏置数组、输出、梯度和与前后层的双向连接。 | | `NeuronLayerBase` | 神经层基础数据类,保存神经元列表、二维权重矩阵和偏置数组。 | | `Neuron` | 单个神经元,保存自己的权重、偏置、输出、梯度、训练状态(状态机 1→2→3)和缓存的新参数。 | | `NeuronBase` | 神经元基础数据类,仅保存偏置和权重数组,用于模型序列化。 | | `Tensor` | 输入数据封装,支持单条数据 `float[]` 和批量数据 `float[][]`。 | | `DyFunction` | 泛型抽象策略基类,通过 `apply(req, req2)` 接口注入前向和反向计算逻辑。 | | `BackwardRsp` | 反向传播返回值,包含更新后的 `bias` 和 `weight[]`。 | | `SigmoidFunction` | Sigmoid 单样本前向传播与反向传播策略。 | | `SigmoidFunctionBatch` | Sigmoid 批量(Mini-batch)前向传播与反向传播策略。 | | `TanhFunction` | Tanh 单样本前向传播与反向传播策略。 | | `Sigmoid` / `Tanh` | 激活函数数学工具类,提供 `computeValue()` 和 `computeDerivativeByOutput()`。 | | `TrainData` | 训练数据 POJO,包含 `input[]` 和 `label[]`。 | | `DataUtil` | 数据加载工具,从 CSV 文本解析训练/测试数据。 | | `FileUtil` | UTF-8 文件读写工具。 | | `JsonUtil` | 基于 Jackson 的 JSON 序列化与反序列化工具,用于模型保存与加载。 | ## 设计思路 项目的核心理念是:**把神经网络训练过程拆成几个直观对象,而不是一开始就抽象成矩阵框架。** ### 1. 网络由层组成,层由神经元组成 `NeuronNet` 维护 `List`,每一层通过 `preNeuronLayer` 和 `nextNeuronLayer` 与相邻层关联。前向传播沿 `next` 方向逐层计算输出,反向传播沿 `pre` 方向逐层计算梯度。 ### 2. 每个神经元拥有自己的参数 单样本训练路径中,`Neuron` 内部保存 `weight`、`bias`、`output`、`loss`、`grad` 等状态。前向传播时计算输出,反向传播时根据标签或后一层梯度计算本神经元梯度,再生成新的权重和偏置并缓存到 `weightNew` / `biasNew`,避免影响其他神经元的梯度计算。 ### 3. 层级参数用于批量训练 批量训练路径更接近矩阵计算:`NeuronLayer` 直接保存二维权重矩阵 `weight[neuron][input]`、批量输出 `outputBatch`、批量梯度 `gradBatch`、批量新权重 `weightNewBatch`。这样可以避免逐个 `Neuron` 分散保存批量中间结果。 ### 4. 前向和反向计算以策略形式注入 `DyFunction` 被用作轻量策略接口。构造 `NeuronLayer` 时注入对应的 `forwardFunction` / `forwardFunctionLayer` 和 `backwardFunction` / `backwardFunctionLayer`,因此同一套网络结构可以换用 `Sigmoid`、`Tanh` 或批量训练版本。 ### 5. 模型参数与代码解耦 训练后的参数通过 `getWeight()` 或 `getWeightLayer()` 转成 `NeuronBase` 列表,再序列化为 JSON 保存到 `data/**/model*.txt`。预测时通过 `loadWeightByPath()` 恢复权重和偏置。加载时通过 `validateWeightList()` 校验层结构和权重维度的匹配性。 ### 6. 延迟更新机制 反向传播时,计算得到的新参数先写入 `weightNew` / `biasNew` 缓存,等全部层计算完成后,再通过 `updateWeight()` / `updateWeightBatch()` 统一提交到 `weight` / `bias`,避免梯度计算过程中参数被部分更新导致不一致。 ### 7. 神经元状态机 `Neuron` 使用 `state` 字段管理生命周期,防止重复计算: | state | 含义 | | --- | --- | | 1 | 未训练或已训练完毕(权重已更新) | | 2 | 已计算 output,未计算新权重 | | 3 | 已计算新权重和偏置(缓存中) | `NeuronLayer` 使用更细粒度的状态机: | state | 含义 | | --- | --- | | 1 | 未训练或已训练完毕 | | 2 | 已开始计算 output | | 3 | 已完成所有 output 计算 | | 4 | 已开始计算新权重 | | 5 | 已完成计算新权重 | | 6 | 已开始更新权重 | ## 训练流程 ### 单样本训练 由 `NeuronNet.trainOneData(Tensor input, float[] label)` 驱动: ``` 输入 Tensor + 标签 │ ▼ 逐层 forward,计算各神经元输出 │ ▼ 从输出层反向 backward,计算 loss 和 grad │ ▼ 逐层 updateWeight,将缓存的新权重和偏置提交 ``` ### 批量训练 由 `NeuronNet.train(float[][] inputDataBatch, float[][] destValueBatch)` 驱动: ``` 批量输入 inputDataBatch + 批量标签 destValueBatch │ ▼ 逐层 forwardBatch,计算 outputBatch │ ▼ 逐层 backwardBatch,累积 batch 梯度 │ ▼ updateWeightBatch,按平均梯度更新参数 ``` ## 数据格式 训练数据和测试数据是简单 CSV 文本,每行一条样本: ```text input_1,input_2,...,label_1,label_2,... ``` 例如 XOR 数据: ```text 0,0,0 0,1,1 1,0,1 1,1,0 ``` 例如二分类数据(x > 0 为正类): ```text -73.83198,-61.928238,0.00 53.444916,75.09889,1.00 ``` 读取时需要调用 `DataUtil.loadData(path, inputSize, labelSize)` 并显式指定输入维度和标签维度。 ## 模型格式 训练好的模型以 JSON 格式保存,每层一个数组,每个神经元包含 `bias` 和 `weight[]`: ```json [ [ {"bias":-2.8127565, "weight":[-4.776751, 4.6256166]}, {"bias":-2.0995603, "weight":[3.6940491, -3.9325764]} ], [ {"bias":-2.6467807, "weight":[5.432912, 5.3824787]} ] ] ``` ## 示例说明 | 示例 | 测试文件 | 网络结构 | 激活函数 | 训练方式 | 说明 | | --- | --- | --- | --- | --- | --- | | XOR | `XorTest` | 2-2-1 | Sigmoid | 单样本 SGD | 两层网络学习异或问题 | | 简单二分类 | `ClassifyTest` | 2-1 | Sigmoid | 单样本 SGD | 单层感知机学习二维点分类 | | MLP 二分类 | `MlpClassifyTest` | 2-2-1 | Sigmoid | 单样本 SGD | 两层 MLP 进行二维分类 | | 批量 MLP 二分类 | `MlpClassifyBatchTest` | 2-3-2-1 | Sigmoid | Mini-batch | 使用批量策略做 mini-batch 训练 | | Iris | `IrisTest` | 4-2-1 | Tanh | 单样本 SGD | 处理 Iris 鸢尾花数据 | | 网络构建 | `NeuronNetTest` | 2-3-2-1 | - | - | 仅验证网络结构构建 | ## 网络结构可视化 `NeuronNet.toString()` 可以打印当前网络的完整结构信息,输出示例(2-2-1 XOR 网络): ```text NeuronNet { learningRate=0.5, batchSize=0, debug=false topology=2-1 totalParams=3, layers=1 Layer[0] (output) neurons=1, inputs=2, activation=SigmoidFunction, mode=single, params=3 Structure: Output(1) } ``` 多层网络输出示例(2-3-1 网络): ```text NeuronNet { learningRate=0.01, batchSize=0, debug=false topology=2-3-1 totalParams=13, layers=2 Layer[0] (input) neurons=2, inputs=2, activation=SigmoidFunction, mode=single, params=6 Layer[1] (output) neurons=3, inputs=2, activation=SigmoidFunction, mode=single, params=9 Structure: Input(2) --> Output(3) } ``` 输出字段说明: | 字段 | 说明 | | --- | --- | | `learningRate` | 学习率 | | `batchSize` | 批量训练的批次大小 | | `topology` | 网络拓扑结构,如 `2-3-2-1` | | `totalParams` | 总参数量(权重 + 偏置) | | `Layer[i]` | 各层角色、神经元数、输入维度、激活函数、训练模式(single/batch)、参数量 | ## 开发者指南 想要深入理解源码、添加新激活函数(如 ReLU)或扩展功能?请阅读 [开发者指南](docs/DEVELOPER_GUIDE.md),涵盖 700+ 行的完整源码导读、训练流程图和扩展示例。 ## 贡献 欢迎提交 Issue 和 Pull Request! - **Issue**:报告 Bug、提出功能建议或问题讨论 - **PR**:修复 Bug 或实现新功能,请保持代码风格一致(4 空格缩进,UTF-8 编码) - 提交 PR 前请确保 `mvn compile` 通过,示例测试(`XorTest`、`MlpClassifyBatchTest`)能正常运行 ## License 本项目基于 [MIT License](LICENSE) 开源。 ---

jnn-learning — 因为理解原理,比会调 API 更酷。🚀