From 8c8aac01c6cf09431f2af56c1e2b318ee59399f2 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Thu, 5 Feb 2026 10:54:27 +0000 Subject: [PATCH 01/29] Add pynative pp. --- docs/run_pynative_pp_guide.md | 230 ++++++++++++++++++ ds_pynative_pp.yaml | 261 +++++++++++++++++++++ run_pynative_pp.py | 423 ++++++++++++++++++++++++++++++++++ 3 files changed, 914 insertions(+) create mode 100755 docs/run_pynative_pp_guide.md create mode 100755 ds_pynative_pp.yaml create mode 100755 run_pynative_pp.py diff --git a/docs/run_pynative_pp_guide.md b/docs/run_pynative_pp_guide.md new file mode 100755 index 000000000..2fee229b8 --- /dev/null +++ b/docs/run_pynative_pp_guide.md @@ -0,0 +1,230 @@ +# Pipeline Parallel (PP) 训练指南 + +本文档介绍如何使用 `run_pynative_pp.py` 脚本进行 Pipeline Parallel 训练。 + +## 目录 + +- [环境依赖](#环境依赖) +- [配置说明](#配置说明) +- [运行方式](#运行方式) +- [参数说明](#参数说明) +- [示例配置](#示例配置) + +## 环境依赖 + +### 必需依赖 + +```bash +# MindSpore +pip install mindspore + +# hyper_parallel (流水线并行调度库) +# 确保 hyper_parallel 库已安装,通常位于 ../hyper-parallel 目录 +export PYTHONPATH=$PYTHONPATH:/data/hyper-parallel + +# MindFormers +# 当前仓库 +``` + +### 硬件要求 + +- Ascend NPU 设备 +- 设备数量需要能被 `pipeline_parallel` 整除 +- 推荐配置:8卡或16卡 + +## 配置说明 + +### YAML 配置文件 + +在 YAML 配置文件(如 `ds_pynative.yaml`)中添加以下配置: + +```yaml +# 并行配置 +parallelism: + tensor_parallel: 1 # 张量并行度 + pipeline_parallel: 4 # 流水线并行 stage 数量 + context_parallel: 1 # 上下文并行度 + data_parallel: 2 # 数据并行度(自动计算:总卡数 / PP / TP / CP) + virtual_pipeline_parallel: 2 # VPP 虚拟 stage 数量(可选,仅 VPP 模式使用) + +# 并行配置(旧版兼容) +parallel_config: + pipeline_stage: 4 # 与 parallelism.pipeline_parallel 对应 + micro_batch_num: 4 # micro batch 数量,影响流水线效率 + +# 训练配置 +training: + max_steps: 2000 # 最大训练步数 + global_batch_size: 8 # 全局 batch size + local_batch_size: 1 # 每个设备的 batch size +``` + +### 关键参数说明 + +| 参数 | 说明 | 推荐值 | +|------|------|--------| +| `pipeline_parallel` | PP stage 数量 | 2, 4, 8 | +| `micro_batch_num` | micro batch 数量 | >= pipeline_parallel | +| `virtual_pipeline_parallel` | VPP 虚拟 stage 数 | 2 | + +### 约束条件 + +1. **设备数量约束**:`总设备数 = PP × DP × TP × CP` +2. **层数约束**:`num_hidden_layers` 需要能被 `pipeline_parallel` 整除 +3. 
**micro batch 约束**:`micro_batch_num >= pipeline_parallel` 以保证流水线效率 + +## 运行方式 + +### 单机多卡运行 + +```bash +# 标准 PP 模式(8卡,4个 stage) +msrun --worker_num=8 --local_worker_num=8 \ + python run_pynative_pp.py --config ds_pynative.yaml --mode pp + +# VPP 模式(8卡,4个物理 stage,每个 rank 2个虚拟 stage) +msrun --worker_num=8 --local_worker_num=8 \ + python run_pynative_pp.py --config ds_pynative.yaml --mode vpp +``` + +### 多机多卡运行 + +```bash +# 主节点 +msrun --worker_num=16 --local_worker_num=8 \ + --master_addr=192.168.1.1 --master_port=8118 \ + --node_rank=0 \ + python run_pynative_pp.py --config ds_pynative.yaml --mode pp + +# 从节点 +msrun --worker_num=16 --local_worker_num=8 \ + --master_addr=192.168.1.1 --master_port=8118 \ + --node_rank=1 \ + python run_pynative_pp.py --config ds_pynative.yaml --mode pp +``` + +### 使用 mpirun 运行 + +```bash +# 8卡运行 +mpirun -n 8 python run_pynative_pp.py --config ds_pynative.yaml --mode pp +``` + +## 参数说明 + +### 命令行参数 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `--config` | str | `ds_pynative.yaml` | 配置文件路径 | +| `--mode` | str | `pp` | 训练模式:`pp` 或 `vpp` | + +### 训练模式对比 + +| 模式 | 说明 | 适用场景 | +|------|------|----------| +| `pp` | 标准流水线并行 | 通用场景,简单易用 | +| `vpp` | 虚拟流水线并行 | 减少流水线气泡,提高效率 | + +## 示例配置 + +### 8卡 PP 配置示例 + +```yaml +# ds_pynative_pp8.yaml +parallelism: + tensor_parallel: 1 + pipeline_parallel: 4 + data_parallel: 2 + +parallel_config: + micro_batch_num: 8 + +model: + model_config: + num_hidden_layers: 8 # 需要能被 pipeline_parallel 整除 + hidden_size: 512 + seq_length: 4096 + +training: + max_steps: 1000 + global_batch_size: 8 + local_batch_size: 1 +``` + +### 16卡 VPP 配置示例 + +```yaml +# ds_pynative_vpp16.yaml +parallelism: + tensor_parallel: 1 + pipeline_parallel: 4 + data_parallel: 4 + virtual_pipeline_parallel: 2 + +parallel_config: + micro_batch_num: 16 + +model: + model_config: + num_hidden_layers: 16 # 需要能被 (PP × VPP) 整除 + hidden_size: 1024 + seq_length: 4096 + +training: + max_steps: 2000 + global_batch_size: 16 + local_batch_size: 1 +``` + +## 常见问题 + +### Q1: 报错 "Unable to import hyper_parallel" + +**解决方案**:确保 hyper_parallel 库在 PYTHONPATH 中 + +```bash +export PYTHONPATH=$PYTHONPATH:/data/hyper-parallel +``` + +### Q2: 层数无法整除 stage 数量 + +**解决方案**:调整 `num_hidden_layers` 或 `pipeline_parallel` 使其能整除 + +### Q3: 流水线气泡过大 + +**解决方案**: +1. 增加 `micro_batch_num` +2. 使用 VPP 模式 +3. 减少 `pipeline_parallel` 数量 + +### Q4: 显存不足 + +**解决方案**: +1. 减小 `local_batch_size` +2. 增加 `pipeline_parallel` 数量 +3. 
启用梯度检查点(recompute) + +## 架构说明 + +### PP 模式架构 + +``` +Stage 0 (Rank 0-1) Stage 1 (Rank 2-3) Stage 2 (Rank 4-5) Stage 3 (Rank 6-7) +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ Embedding │ │ Layer 2-3 │ │ Layer 4-5 │ │ Layer 6-7 │ +│ Layer 0-1 │──▶│ │──▶│ │──▶│ Output Layer │ +│ │ │ │ │ │ │ Loss │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ └─────────────────┘ +``` + +### VPP 模式架构 + +``` +Rank 0-1 holds: Stage 0, Stage 4 +Rank 2-3 holds: Stage 1, Stage 5 +Rank 4-5 holds: Stage 2, Stage 6 +Rank 6-7 holds: Stage 3, Stage 7 + +Interleaved 1F1B Schedule reduces pipeline bubbles +``` diff --git a/ds_pynative_pp.yaml b/ds_pynative_pp.yaml new file mode 100755 index 000000000..3d19f2e96 --- /dev/null +++ b/ds_pynative_pp.yaml @@ -0,0 +1,261 @@ +# Pipeline Parallel Training Configuration +# 用于 run_pynative_pp.py 脚本的专用配置文件 + +# ============================================================================ +# 学习率调度器配置 +# ============================================================================ +lr_scheduler: + type: ConstantWarmUpLR + learning_rate: 1.e-5 + warmup_steps: 100 + total_steps: -1 + +# ============================================================================ +# 检查点配置 +# ============================================================================ +checkpoint: + save_path: "../ds3_pp_checkpoint" + save_max: 3 + save_interleaved_steps: 500 + no_save_optim: False + async_save: False + prefix: "deepseekv3_pp" + remove_redundancy: False + load_balanced: False + no_load_optim: True + load_worker_number: 1 + +# ============================================================================ +# 并行配置 (PP 专用) +# ============================================================================ +parallelism: + tensor_parallel: 1 + pipeline_parallel: 4 # PP stage 数量,需要能整除 num_hidden_layers + context_parallel: 1 + data_parallel: 2 # 每个 stage 内的数据并行度 + virtual_pipeline_parallel: 2 # VPP 虚拟 stage 数量(仅 VPP 模式使用) + # HSDP 配置 + hsdp: + shard_size: 2 + shard_dim: 0 + level: "level1" + enable_hierarchical_allgather: True + +# ============================================================================ +# 训练配置 +# ============================================================================ +training: + max_steps: 2000 + global_batch_size: 8 # 全局 batch size + local_batch_size: 1 # 每个设备的 batch size + save_steps: 500 + +# ============================================================================ +# 基础配置 +# ============================================================================ +seed: 1234 +output_dir: './output_pp' +load_checkpoint: '' +load_ckpt_format: 'safetensors' +auto_trans_ckpt: True +only_save_strategy: False +resume_training: False +use_parallel: True +run_mode: 'train' +print_separate_loss: True + +# ============================================================================ +# Trainer 配置 +# ============================================================================ +trainer: + type: CausalLanguageModelingTrainer + model_name: 'deepseekV3' + +# ============================================================================ +# Runner 配置 +# ============================================================================ +runner_config: + epochs: 1 + batch_size: 1 + sink_mode: True + sink_size: 1 + +# ============================================================================ +# 优化器配置 +# ============================================================================ +optimizer: + type: AdamW + betas: [0.9, 0.95] + eps: 1.e-8 + weight_decay: 0.0 + +# 
============================================================================ +# 学习率调度配置 +# ============================================================================ +lr_schedule: + type: ConstantWarmUpLR + learning_rate: 1.e-5 + warmup_steps: 100 + total_steps: -1 + +# ============================================================================ +# 数据集配置 +# ============================================================================ +train_dataset: &train_dataset + data_loader: + type: BlendedMegatronDatasetDataLoader + datasets_type: "GPTDataset" + sizes: + - 100 + - 0 + - 0 + config: + seed: 42 + split: "1, 0, 0" + seq_length: 4096 + eod_mask_loss: False + reset_position_ids: False + create_attention_mask: False + reset_attention_mask: False + create_compressed_eod_mask: False + eod_pad_length: 128 + eod: 1 + pad: -1 + data_path: + - '1' + - "/path/to/your/dataset_text_document" # 请修改为实际数据集路径 + input_columns: ["input_ids", "labels", "loss_mask", "position_ids"] + construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids"] + num_parallel_workers: 8 + python_multiprocessing: False + drop_remainder: True + numa_enable: False + prefetch_size: 1 + seed: 42 + +train_dataset_task: + type: CausalLanguageModelDataset + dataset_config: *train_dataset + +# ============================================================================ +# MindSpore Context 配置 +# ============================================================================ +context: + mode: 1 # 0--Graph Mode; 1--Pynative Mode (PP需要Pynative模式) + device_target: "Ascend" + max_call_depth: 10000 + max_device_memory: "58GB" + save_graphs: False + save_graphs_path: "./graph" + jit_config: + jit_level: "O0" + +# ============================================================================ +# 并行上下文配置 (旧版兼容) +# ============================================================================ +parallel_config: + data_parallel: 2 + model_parallel: 1 + pipeline_stage: 4 # PP stage 数量 + expert_parallel: 1 + micro_batch_num: 8 # micro batch 数量,建议 >= pipeline_stage * 2 + use_seq_parallel: False + gradient_aggregation_group: 4 + +micro_batch_interleave_num: 1 + +# ============================================================================ +# 并行模式配置 +# ============================================================================ +parallel: + parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel + gradients_mean: False + enable_alltoall: True + search_mode: "sharding_propagation" + enable_parallel_optimizer: False + strategy_ckpt_config: + save_file: "./ckpt_strategy_pp.ckpt" + only_trainable_params: False + parallel_optimizer_config: + gradient_accumulation_shard: False + parallel_optimizer_threshold: 8 + +# ============================================================================ +# 重计算配置 +# ============================================================================ +recompute_config: + recompute: True + select_recompute: False + mp_comm_recompute: False + +# ============================================================================ +# 模型配置 +# ============================================================================ +use_legacy: False +model: + model_config: + topk_group: 0 + n_group: 0 + model_type: deepseek_v3 + architectures: DeepseekV3ForCausalLM + offset: 0 + vocab_size: 129280 + seq_length: 4096 + hidden_size: 512 + intermediate_size: 3072 + num_hidden_layers: 8 # 需要能被 pipeline_parallel 整除 + max_position_embeddings: 163840 + hidden_act: 'silu' + num_attention_heads: 4 + rms_norm_eps: 1.e-6 + add_bias_linear: False + 
use_flash_attention: True + # MLA 配置 + multi_latent_attention: True + mla_qkv_concat: True + kv_lora_rank: 512 + q_lora_rank: 1536 + qk_rope_head_dim: 64 + v_head_dim: 192 + qk_nope_head_dim: 128 + qk_layernorm: True + attention_dropout: 0.0 + hidden_dropout: 0.0 + # 数据类型配置 + params_dtype: "float32" + compute_dtype: "bfloat16" + layernorm_compute_dtype: "float32" + softmax_compute_dtype: "float32" + rotary_dtype: "float32" + initializer_range: 0.01 + # MTP 配置 + num_nextn_predict_layers: 0 + mtp_loss_scaling_factor: 0.3 + # 位置编码配置 + position_embedding_type: "yarn" + scaling_factor: 40 + beta_fast: 32 + beta_slow: 1 + mscale: 1 + mscale_all_dim: 1 + rope_theta: 10000 + # MoE 配置 + hash_routed_layer: 0 + hash_size: 3 + router_dense_type: "float32" + gated_linear_unit: True + moe_intermediate_size: 2048 + routed_scaling_factor: 1.5 + first_k_dense_replace: 1 + n_routed_experts: 16 + num_experts_per_tok: 8 + n_shared_experts: 1 + num_copy_experts: 0 + moe_shared_expert_intermediate_size: 2048 + moe_grouped_gemm: True + moe_router_load_balancing_type: 'seq_aux_loss' + moe_aux_loss_coeff: 0.001 + scoring_func: 'sigmoid' + norm_topk_prob: True + moe_token_drop_policy: probs + moe_router_enable_expert_bias: True diff --git a/run_pynative_pp.py b/run_pynative_pp.py new file mode 100755 index 000000000..645ae7136 --- /dev/null +++ b/run_pynative_pp.py @@ -0,0 +1,423 @@ +# Copyright 2026 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Pipeline Parallel Training Script for MindFormers PyNative Mode.""" + +import numpy as np +import mindspore as ms +from mindspore import nn, Tensor, mint +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.common import set_seed + +from hyper_parallel import PipelineStage, ScheduleInterleaved1F1B +from hyper_parallel import DTensor, init_device_mesh +from hyper_parallel.core.hsdp import hsdp +from hyper_parallel import shard_module +from hyper_parallel.core.placement_types import Shard, Replicate +from hyper_parallel.core.sharding_plan import ShardingPlan + +from mindformers.tools.register import MindFormerConfig +from mindformers.models.build_model import build_network +from mindformers.tools.logger import logger + + +class PPStageModel(nn.Cell): + """ + Pipeline Parallel Stage Model wrapper. + + This class wraps a portion of the full model for a specific pipeline stage. + Each stage contains a subset of transformer layers. 
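+
+    A rough usage sketch (values are illustrative; model is assumed to be the
+    network returned by build_network, as in run_pp_training below):
+
+        # stage 0 of a 4-stage split over an 8-layer model
+        stage = PPStageModel(full_model=model, stage_id=0, num_stages=4,
+                             num_layers=8, is_first_stage=True, is_last_stage=False)
+        hidden_states = stage(input_ids)  # the first stage consumes token ids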
+ + Args: + full_model: The full model instance + stage_id: Current stage ID (0-indexed) + num_stages: Total number of pipeline stages + num_layers: Total number of transformer layers + is_first_stage: Whether this is the first stage (has embedding) + is_last_stage: Whether this is the last stage (has output layer) + """ + + def __init__( + self, + full_model, + stage_id: int, + num_stages: int, + num_layers: int, + is_first_stage: bool = False, + is_last_stage: bool = False + ): + super().__init__() + self.stage_id = stage_id + self.num_stages = num_stages + self.is_first_stage = is_first_stage + self.is_last_stage = is_last_stage + + # Calculate layer range for this stage + layers_per_stage = num_layers // num_stages + self.start_layer = stage_id * layers_per_stage + self.end_layer = (stage_id + 1) * layers_per_stage + if stage_id == num_stages - 1: + # Last stage gets remaining layers + self.end_layer = num_layers + + # Extract components from full model + self.gpt_model = full_model.model + + # Keep embedding only for first stage + if self.is_first_stage: + self.embedding = self.gpt_model.embedding + self.casual_mask = self.gpt_model.casual_mask + + # Keep rotary embedding (needed for all stages) + if hasattr(self.gpt_model, 'rotary_pos_emb'): + self.rotary_pos_emb = self.gpt_model.rotary_pos_emb + + # Extract layers for this stage + self.layers = nn.CellList() + for i in range(self.start_layer, self.end_layer): + self.layers.append(self.gpt_model.decoder.layers[i]) + + # Keep final layernorm and output layer only for last stage + if self.is_last_stage: + self.final_layernorm = self.gpt_model.decoder.final_layernorm + self.output_layer = self.gpt_model.output_layer + self.loss = self.gpt_model.loss + + logger.info( + f"Stage {stage_id}: layers [{self.start_layer}, {self.end_layer}), " + f"first={is_first_stage}, last={is_last_stage}" + ) + + def construct(self, hidden_states, attention_mask=None, rotary_pos_emb=None, + labels=None, loss_mask=None, input_ids=None): + """ + Forward pass for this pipeline stage. 
+ + Args: + hidden_states: Input tensor (input_ids for first stage, hidden states otherwise) + attention_mask: Attention mask tensor + rotary_pos_emb: Rotary position embedding + labels: Labels for loss computation (only used in last stage) + loss_mask: Loss mask tensor (only used in last stage) + input_ids: Original input ids (for hash router) + + Returns: + hidden_states or loss depending on stage + """ + aux_loss = None + + # First stage: process embedding + if self.is_first_stage: + input_ids = hidden_states + bs, seq_len = input_ids.shape + + # Generate attention mask + if attention_mask is None: + attention_mask = self.casual_mask(input_ids) + + # Embedding + hidden_states = self.embedding(input_ids, position_ids=None) + + # Generate rotary position embedding + if hasattr(self, 'rotary_pos_emb'): + rotary_pos_emb = self.rotary_pos_emb(seq_len, position_ids=None) + + # Process transformer layers for this stage + for layer in self.layers: + hidden_states, _, layer_aux_loss = layer( + hidden_states, + attention_mask, + rotary_pos_emb=rotary_pos_emb, + ) + if layer_aux_loss is not None: + aux_loss = layer_aux_loss if aux_loss is None else aux_loss + layer_aux_loss + + # Last stage: process output and loss + if self.is_last_stage: + hidden_states = self.final_layernorm(hidden_states) + logits, _ = self.output_layer(hidden_states, weight=None) + + if logits.ndim > 2: + logits = mint.permute(logits, (1, 0, 2)) + logits = mint.reshape(logits, (-1, logits.shape[-1])) + logits = logits.astype(ms.float32) + + if labels is not None: + loss = self.loss(logits, labels, loss_mask) + if aux_loss is not None: + loss = loss + aux_loss + return loss + return logits + + return hidden_states + + +def split_model_for_pp(model, num_stages, rank_id, device_num): + """ + Split model into pipeline stages. + + Args: + model: Full model instance + num_stages: Number of pipeline stages + rank_id: Current rank ID + device_num: Total number of devices + + Returns: + PPStageModel for the current stage + """ + device_num_per_stage = device_num // num_stages + stage_id = rank_id // device_num_per_stage + + num_layers = model.model.config.num_layers + + is_first_stage = (stage_id == 0) + is_last_stage = (stage_id == num_stages - 1) + + stage_model = PPStageModel( + full_model=model, + stage_id=stage_id, + num_stages=num_stages, + num_layers=num_layers, + is_first_stage=is_first_stage, + is_last_stage=is_last_stage + ) + + return stage_model, stage_id + + +def run_pp_training(config_path: str = 'ds_pynative.yaml'): + """ + Run pipeline parallel training. 
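+
+    Typically launched with msrun as shown in docs/run_pynative_pp_guide.md
+    (worker counts are illustrative and must match the parallel settings):
+
+        msrun --worker_num=8 --local_worker_num=8 python run_pynative_pp.py --config ds_pynative.yaml --mode pp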
+ + Args: + config_path: Path to the configuration yaml file + """ + # Initialize + ms.set_seed(0) + init("hccl") + + rank_id = get_rank() + device_num = get_group_size() + + # Load config + mf_config = MindFormerConfig(config_path) + logger.info(f"Loaded config: {mf_config}") + + # PP config from yaml + parallelism = getattr(mf_config, 'parallelism', None) + if parallelism is None: + num_stages = 4 # default + micro_batch_num = 4 + else: + num_stages = getattr(parallelism, 'pipeline_parallel', 4) + micro_batch_num = getattr(mf_config.parallel_config, 'micro_batch_num', 4) + + device_num_per_stage = device_num // num_stages + stage_id = rank_id // device_num_per_stage + + logger.info( + f"PP Config: num_stages={num_stages}, micro_batch_num={micro_batch_num}, " + f"rank_id={rank_id}, stage_id={stage_id}" + ) + + # Build full model + model = build_network(mf_config.model) + model.set_train(True) + logger.info(f"Built model: {type(model)}") + + # Split model for pipeline parallel + stage_model, stage_id = split_model_for_pp(model, num_stages, rank_id, device_num) + + # Setup device mesh for data parallel within stage + dp = device_num_per_stage + mp = 1 + mesh = init_device_mesh(mesh_shape=(dp, mp), alias_name=("dp", "mp")) + + # Define placements + in_placements = (Shard(0), Replicate()) + w_placements = (Replicate(), Replicate()) + out_placements = (Shard(0), Replicate()) + + # Apply sharding plan + model_stra = ShardingPlan( + plan={"weight": w_placements}, + input_plan={"input": in_placements}, + output_plan={"output": out_placements}, + ) + shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) + + # Apply HSDP + stage_model = hsdp(stage_model, dp, 0, "level1", True) + + # Create pipeline stage + pipeline_stage = PipelineStage(stage_model, stage_id, num_stages) + + # Create schedule + schedule = ScheduleInterleaved1F1B([pipeline_stage], micro_batch_num) + + # Prepare input data + seq_length = mf_config.model.model_config.seq_length + local_batch_size = getattr(mf_config.training, 'local_batch_size', 1) + + x_placements = (Shard(0), Shard(1)) + x = DTensor.from_local( + Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32), + mesh, x_placements + ) + + # Training loop + max_steps = getattr(mf_config.training, 'max_steps', 100) + + for step in range(max_steps): + stage_model.zero_grads() + + if stage_id == 0: + loss = schedule.run(x) + else: + loss = schedule.run() + + if stage_id == num_stages - 1: + logger.info(f"Step {step}, Loss: {loss}") + + logger.info("Training completed!") + + +def run_vpp_training(config_path: str = 'ds_pynative.yaml'): + """ + Run Virtual Pipeline Parallel (VPP) training with interleaved stages. + + This implements VPP where each rank holds multiple non-consecutive stages, + similar to the reference vpp_schedule.py. 
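+
+    With pipeline_parallel=4 and virtual_pipeline_parallel=2 (see the guide),
+    the ranks of stage group 0 hold virtual stages 0 and 4, group 1 holds
+    stages 1 and 5, and so on. A typical launch (values illustrative):
+
+        msrun --worker_num=8 --local_worker_num=8 python run_pynative_pp.py --config ds_pynative.yaml --mode vpp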
+ + Args: + config_path: Path to the configuration yaml file + """ + # Initialize + ms.set_seed(0) + init("hccl") + + rank_id = get_rank() + device_num = get_group_size() + + # Load config + mf_config = MindFormerConfig(config_path) + logger.info(f"Loaded config: {mf_config}") + + # VPP config + parallelism = getattr(mf_config, 'parallelism', None) + if parallelism is None: + num_stages = 4 + num_virtual_stages = 2 # Each rank holds 2 virtual stages + micro_batch_num = 4 + else: + num_stages = getattr(parallelism, 'pipeline_parallel', 4) + num_virtual_stages = getattr(parallelism, 'virtual_pipeline_parallel', 2) + micro_batch_num = getattr(mf_config.parallel_config, 'micro_batch_num', 4) + + total_stages = num_stages * num_virtual_stages + device_num_per_stage = device_num // num_stages + stage_id = rank_id // device_num_per_stage + + logger.info( + f"VPP Config: num_stages={num_stages}, virtual_stages={num_virtual_stages}, " + f"total_stages={total_stages}, micro_batch_num={micro_batch_num}" + ) + + # Build full model + model = build_network(mf_config.model) + model.set_train(True) + + num_layers = model.model.config.num_layers + + # Create multiple stage models for VPP + stage_models = [] + pipeline_stages = [] + + for v in range(num_virtual_stages): + virtual_stage_id = stage_id + v * num_stages + + is_first = (virtual_stage_id == 0) + is_last = (virtual_stage_id == total_stages - 1) + + stage_model = PPStageModel( + full_model=model, + stage_id=virtual_stage_id, + num_stages=total_stages, + num_layers=num_layers, + is_first_stage=is_first, + is_last_stage=is_last + ) + + # Setup device mesh + dp = device_num_per_stage + mp = 1 + mesh = init_device_mesh(mesh_shape=(dp, mp), alias_name=("dp", "mp")) + + # Apply HSDP + stage_model = hsdp(stage_model, dp, 0, "level1", True) + + stage_models.append(stage_model) + pipeline_stages.append(PipelineStage(stage_model, virtual_stage_id, total_stages)) + + # Create interleaved schedule + schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) + + # Prepare input + seq_length = mf_config.model.model_config.seq_length + local_batch_size = getattr(mf_config.training, 'local_batch_size', 1) + + dp = device_num_per_stage + mesh = init_device_mesh(mesh_shape=(dp, 1), alias_name=("dp", "mp")) + x_placements = (Shard(0), Shard(1)) + x = DTensor.from_local( + Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32), + mesh, x_placements + ) + + # Training loop + max_steps = getattr(mf_config.training, 'max_steps', 100) + + for step in range(max_steps): + for stage_model in stage_models: + stage_model.zero_grads() + + if stage_id == 0: + loss = schedule.run(x) + else: + loss = schedule.run() + + if stage_id == num_stages - 1: + logger.info(f"Step {step}, Loss: {loss}") + + logger.info("VPP Training completed!") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Pipeline Parallel Training") + parser.add_argument( + "--config", type=str, default="ds_pynative.yaml", + help="Path to config yaml file" + ) + parser.add_argument( + "--mode", type=str, default="pp", choices=["pp", "vpp"], + help="Training mode: pp (pipeline parallel) or vpp (virtual pipeline parallel)" + ) + args = parser.parse_args() + + if args.mode == "pp": + run_pp_training(args.config) + else: + run_vpp_training(args.config) -- Gitee From 095bf4f94594d4de41ad2217549f5bd41466b0f7 Mon Sep 17 00:00:00 2001 From: wangchengzhao Date: Thu, 5 Feb 2026 21:12:11 +0800 Subject: [PATCH 02/29] Update config. 
--- ds_pynative.yaml | 20 +++-- ds_pynative_pp.yaml | 210 +++++++++++++++++++++++++------------------- run_pynative_pp.py | 8 +- 3 files changed, 135 insertions(+), 103 deletions(-) diff --git a/ds_pynative.yaml b/ds_pynative.yaml index 15461272e..20b44a8f1 100644 --- a/ds_pynative.yaml +++ b/ds_pynative.yaml @@ -6,7 +6,7 @@ lr_scheduler: total_steps: -1 # -1 means it will load the total steps of the dataset checkpoint: - save_path: "../ds3_1dense1moe1mtp_concat_pynative" + save_path: "checkpoints/ds3_1dense1moe1mtp_concat_pynative" save_max: 1 save_interleaved_steps: 100000 # 500 no_save_optim: False @@ -26,7 +26,7 @@ parallelism: data_parallel: 1 training: - max_steps: 2000 + max_steps: 20000 global_batch_size: 1 local_batch_size: 1 save_steps: 2000 @@ -35,7 +35,7 @@ training: seed: 1234 -output_dir: './output' # path to save checkpoint and strategy +output_dir: './output_pynative' # path to save checkpoint and strategy load_checkpoint: '' load_ckpt_format: 'safetensors' # format of checkpoint files to load src_strategy_path_or_dir: '' @@ -91,7 +91,7 @@ train_dataset: &train_dataset type: BlendedMegatronDatasetDataLoader datasets_type: "GPTDataset" sizes: - - 100 # Number of training set data samples. + - 20000 # Number of training set data samples. - 0 # Number of test set data samples. Currently, configuration is not supported. - 0 # Number of eval set data samples. Currently, configuration is not supported. config: # GPTDataset Configs @@ -108,7 +108,7 @@ train_dataset: &train_dataset pad: -1 # The token id of `pad` in the dataset. data_path: # Megatron dataset sampling ratio and path. - '1' - - "/home/w00932055/dsv4/deepseek-datasets/mmap_deepseekv3_datasets_text_document" + - "/home/w00932055/deepseek-datasets/fineweb_edu_10BT_text_document" input_columns: ["input_ids", "labels", "loss_mask", "position_ids"] construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids"] num_parallel_workers: 8 @@ -117,13 +117,14 @@ train_dataset: &train_dataset numa_enable: False prefetch_size: 1 seed: 42 + train_dataset_task: type: CausalLanguageModelDataset dataset_config: *train_dataset # mindspore context init config context: - mode: 0 # 0--Graph Mode; 1--Pynative Mode + mode: 1 # 0--Graph Mode; 1--Pynative Mode device_target: "Ascend" max_call_depth: 10000 max_device_memory: "58GB" @@ -177,13 +178,14 @@ model: seq_length: 4096 hidden_size: 512 intermediate_size: 3072 - num_hidden_layers: 8 + num_hidden_layers: 12 max_position_embeddings: 163840 hidden_act: 'silu' # 'fusedswiglu' - num_attention_heads: 4 + num_attention_heads: 12 rms_norm_eps: 1.e-6 add_bias_linear: False use_flash_attention: True + # MLA 配置 multi_latent_attention: True mla_qkv_concat: True kv_lora_rank: 512 @@ -198,7 +200,7 @@ model: # # DSA (DeepSeek Sparse Attention) Configuration # experimental_attention_variant: 'dsa' # dsa_indexer_n_heads: 4 # Same as num_attention_heads - # dsa_indexer_head_dim: 192 # qk_rope_head_dim + qk_nope_head_dim + # dsa_indexer_head_dim: 80 # should greater than qk_rope_head_dim # dsa_indexer_topk: 256 # dsa_indexer_loss_coeff: 0.001 # dsa_indexer_use_sparse_loss: False diff --git a/ds_pynative_pp.yaml b/ds_pynative_pp.yaml index 3d19f2e96..8a72fe7b5 100755 --- a/ds_pynative_pp.yaml +++ b/ds_pynative_pp.yaml @@ -1,22 +1,18 @@ -# Pipeline Parallel Training Configuration -# 用于 run_pynative_pp.py 脚本的专用配置文件 +# add for pynative +# 可复用变量定义 +vars: + pp_stage: &pp_stage 2 -# ============================================================================ -# 学习率调度器配置 -# 
============================================================================ lr_scheduler: type: ConstantWarmUpLR learning_rate: 1.e-5 - warmup_steps: 100 - total_steps: -1 + warmup_steps: 0 + total_steps: -1 # -1 means it will load the total steps of the dataset -# ============================================================================ -# 检查点配置 -# ============================================================================ checkpoint: - save_path: "../ds3_pp_checkpoint" - save_max: 3 - save_interleaved_steps: 500 + save_path: "checkpoints/ds3_pp_checkpoint" + save_max: 1 + save_interleaved_steps: 100000 # 500 no_save_optim: False async_save: False prefix: "deepseekv3_pp" @@ -25,14 +21,11 @@ checkpoint: no_load_optim: True load_worker_number: 1 -# ============================================================================ -# 并行配置 (PP 专用) -# ============================================================================ parallelism: tensor_parallel: 1 - pipeline_parallel: 4 # PP stage 数量,需要能整除 num_hidden_layers + pipeline_parallel: *pp_stage # PP stage 数量,需要能整除 num_hidden_layers context_parallel: 1 - data_parallel: 2 # 每个 stage 内的数据并行度 + data_parallel: 1 # 每个 stage 内的数据并行度 virtual_pipeline_parallel: 2 # VPP 虚拟 stage 数量(仅 VPP 模式使用) # HSDP 配置 hsdp: @@ -41,89 +34,90 @@ parallelism: level: "level1" enable_hierarchical_allgather: True -# ============================================================================ -# 训练配置 -# ============================================================================ training: - max_steps: 2000 + max_steps: 20000 global_batch_size: 8 # 全局 batch size local_batch_size: 1 # 每个设备的 batch size - save_steps: 500 + save_steps: 2000 + +# original + -# ============================================================================ -# 基础配置 -# ============================================================================ seed: 1234 -output_dir: './output_pp' +output_dir: './output_pynative_pp' # path to save checkpoint and strategy load_checkpoint: '' -load_ckpt_format: 'safetensors' -auto_trans_ckpt: True +load_ckpt_format: 'safetensors' # format of checkpoint files to load +src_strategy_path_or_dir: '' +auto_trans_ckpt: True # If True, auto transform `load_checkpoint` to load in distributed model only_save_strategy: False resume_training: False use_parallel: True run_mode: 'train' print_separate_loss: True -# ============================================================================ -# Trainer 配置 -# ============================================================================ +# trainer config trainer: type: CausalLanguageModelingTrainer model_name: 'deepseekV3' -# ============================================================================ -# Runner 配置 -# ============================================================================ +# runner config runner_config: epochs: 1 batch_size: 1 sink_mode: True sink_size: 1 -# ============================================================================ -# 优化器配置 -# ============================================================================ +# optimizer optimizer: type: AdamW betas: [0.9, 0.95] eps: 1.e-8 weight_decay: 0.0 -# ============================================================================ -# 学习率调度配置 -# ============================================================================ +# Muon optimizer configuration example (uncomment to use): +# optimizer: +# type: Muon +# learning_rate: 2.e-2 +# weight_decay: 0.1 +# matched_adamw_rms: 0.2 +# momentum: 0.95 +# nesterov: True +# ns_steps: 5 +# adamw_betas: [0.95, 0.95] +# 
adamw_eps: 1.e-8 +# qk_clip_threshold: 100 + +# lr schedule lr_schedule: type: ConstantWarmUpLR learning_rate: 1.e-5 - warmup_steps: 100 - total_steps: -1 + warmup_steps: 0 + total_steps: -1 # -1 means it will load the total steps of the dataset -# ============================================================================ -# 数据集配置 -# ============================================================================ +# Dataset configuration train_dataset: &train_dataset data_loader: type: BlendedMegatronDatasetDataLoader datasets_type: "GPTDataset" sizes: - - 100 - - 0 - - 0 - config: - seed: 42 - split: "1, 0, 0" - seq_length: 4096 - eod_mask_loss: False - reset_position_ids: False - create_attention_mask: False - reset_attention_mask: False - create_compressed_eod_mask: False - eod_pad_length: 128 - eod: 1 - pad: -1 - data_path: + - 20000 # Number of training set data samples. + - 0 # Number of test set data samples. Currently, configuration is not supported. + - 0 # Number of eval set data samples. Currently, configuration is not supported. + config: # GPTDataset Configs + seed: 42 # Data sampling random seed + split: "1, 0, 0" # The usage ratio of training, testing and evaluation sets. Currently, configuration is not supported. + seq_length: 4096 # The sequence length of the data set returned. + eod_mask_loss: False # Whether to calculate loss at eod. + reset_position_ids: False # Whether to reset position_ids at eod. + create_attention_mask: False # Whether to return `attention_mask`. + reset_attention_mask: False # Whether to reset the `attention_mask` at eod and return a stepped `attention_mask`. + create_compressed_eod_mask: False # Whether to return the compressed `attention_mask`. + eod_pad_length: 128 # Set the length of the compressed `attention_mask`. + eod: 1 # The token id of `eod` in the dataset. + pad: -1 # The token id of `pad` in the dataset. + data_path: # Megatron dataset sampling ratio and path. - '1' - - "/path/to/your/dataset_text_document" # 请修改为实际数据集路径 + - "/home/w00932055/deepseek-datasets/fineweb_edu_10BT_text_document" input_columns: ["input_ids", "labels", "loss_mask", "position_ids"] construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids"] num_parallel_workers: 8 @@ -137,11 +131,9 @@ train_dataset_task: type: CausalLanguageModelDataset dataset_config: *train_dataset -# ============================================================================ -# MindSpore Context 配置 -# ============================================================================ +# mindspore context init config context: - mode: 1 # 0--Graph Mode; 1--Pynative Mode (PP需要Pynative模式) + mode: 1 # 0--Graph Mode; 1--Pynative Mode device_target: "Ascend" max_call_depth: 10000 max_device_memory: "58GB" @@ -150,25 +142,21 @@ context: jit_config: jit_level: "O0" -# ============================================================================ -# 并行上下文配置 (旧版兼容) -# ============================================================================ +# parallel config parallel_config: - data_parallel: 2 + data_parallel: 1 model_parallel: 1 - pipeline_stage: 4 # PP stage 数量 + pipeline_stage: *pp_stage # PP stage 数量 expert_parallel: 1 micro_batch_num: 8 # micro batch 数量,建议 >= pipeline_stage * 2 use_seq_parallel: False gradient_aggregation_group: 4 - +# when model parallel is greater than 1, we can set micro_batch_interleave_num=2, that may accelerate the train process. 
micro_batch_interleave_num: 1 -# ============================================================================ -# 并行模式配置 -# ============================================================================ +# parallel context config parallel: - parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel + parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel gradients_mean: False enable_alltoall: True search_mode: "sharding_propagation" @@ -180,17 +168,13 @@ parallel: gradient_accumulation_shard: False parallel_optimizer_threshold: 8 -# ============================================================================ -# 重计算配置 -# ============================================================================ +# recompute config recompute_config: recompute: True select_recompute: False mp_comm_recompute: False -# ============================================================================ -# 模型配置 -# ============================================================================ +# model config use_legacy: False model: model_config: @@ -203,10 +187,10 @@ model: seq_length: 4096 hidden_size: 512 intermediate_size: 3072 - num_hidden_layers: 8 # 需要能被 pipeline_parallel 整除 + num_hidden_layers: 12 # 需要能被 pipeline_parallel 整除 max_position_embeddings: 163840 - hidden_act: 'silu' - num_attention_heads: 4 + hidden_act: 'silu' # 'fusedswiglu' + num_attention_heads: 12 rms_norm_eps: 1.e-6 add_bias_linear: False use_flash_attention: True @@ -219,19 +203,27 @@ model: v_head_dim: 192 qk_nope_head_dim: 128 qk_layernorm: True + # QK-clip scaling for Muon optimizer (tracks max attention logit per head) + # Enable this when using Muon optimizer with qk_clip_threshold + # track_max_attention_logit: True + # # DSA (DeepSeek Sparse Attention) Configuration + # experimental_attention_variant: 'dsa' + # dsa_indexer_n_heads: 4 # Same as num_attention_heads + # dsa_indexer_head_dim: 80 # should greater than qk_rope_head_dim + # dsa_indexer_topk: 256 + # dsa_indexer_loss_coeff: 0.001 + # dsa_indexer_use_sparse_loss: False + # dsa_use_fused_ops: True # Use fused DSA operators (lightning_indexer + sparse_flash_attention) attention_dropout: 0.0 hidden_dropout: 0.0 - # 数据类型配置 params_dtype: "float32" compute_dtype: "bfloat16" layernorm_compute_dtype: "float32" softmax_compute_dtype: "float32" rotary_dtype: "float32" initializer_range: 0.01 - # MTP 配置 - num_nextn_predict_layers: 0 + num_nextn_predict_layers: 0 # 1 mtp_loss_scaling_factor: 0.3 - # 位置编码配置 position_embedding_type: "yarn" scaling_factor: 40 beta_fast: 32 @@ -239,9 +231,14 @@ model: mscale: 1 mscale_all_dim: 1 rope_theta: 10000 - # MoE 配置 + # mhc + mhc_residual: False + mhc_expansion_rate: 4 + mhc_sk_iter: 20 + mhc_gating_factor_init: 0.01 + # moe config hash_routed_layer: 0 - hash_size: 3 + hash_size: 3 # feasible when hash_router_layer > 0 router_dense_type: "float32" gated_linear_unit: True moe_intermediate_size: 2048 @@ -251,6 +248,9 @@ model: num_experts_per_tok: 8 n_shared_experts: 1 num_copy_experts: 0 + use_topk_router_with_load_balancing: False + moe_expected_ffn_experts: 2.0 # Default best value: top-k * FFN/(FFN + COPY) + moe_router_bias_update_rate: 0.001 moe_shared_expert_intermediate_size: 2048 moe_grouped_gemm: True moe_router_load_balancing_type: 'seq_aux_loss' @@ -259,3 +259,29 @@ model: norm_topk_prob: True moe_token_drop_policy: probs moe_router_enable_expert_bias: True + +# callbacks +callbacks: + - type: MFLossMonitor + # balance topk bias with callback + # - type: TopkBiasBalanceCallback + # - type: CheckpointMonitor + # 
prefix: "deepseekv3" + # save_checkpoint_steps: 100000 + # keep_checkpoint_max: 1 + # integrated_save: False + # async_save: False + # checkpoint_format: "safetensors" # format of checkpoint files to save + +# wrapper cell config +runner_wrapper: + type: MFTrainOneStepCell + scale_sense: 1.0 + use_clip_grad: True + +profile: False +profile_start_step: 5 +profile_stop_step: 7 +init_start_profile: False +profile_communication: False +profile_memory: True diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 645ae7136..f1dbe54dc 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -14,6 +14,10 @@ # ============================================================================ """Pipeline Parallel Training Script for MindFormers PyNative Mode.""" +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).parent.parent / 'hyper-parallel')) + import numpy as np import mindspore as ms from mindspore import nn, Tensor, mint @@ -25,7 +29,7 @@ from hyper_parallel import DTensor, init_device_mesh from hyper_parallel.core.hsdp import hsdp from hyper_parallel import shard_module from hyper_parallel.core.placement_types import Shard, Replicate -from hyper_parallel.core.sharding_plan import ShardingPlan +from hyper_parallel.core.shard.sharding_plan import ShardingPlan from mindformers.tools.register import MindFormerConfig from mindformers.models.build_model import build_network @@ -408,7 +412,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Pipeline Parallel Training") parser.add_argument( - "--config", type=str, default="ds_pynative.yaml", + "--config", type=str, default="ds_pynative_pp.yaml", help="Path to config yaml file" ) parser.add_argument( -- Gitee From d3073f7016733e9c1a5d211342bb746a276b8781 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Thu, 5 Feb 2026 13:39:00 +0000 Subject: [PATCH 03/29] Update pp stage. --- run_pynative_pp.py | 189 +++++++++++++++++++++++---------------------- 1 file changed, 98 insertions(+), 91 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index f1dbe54dc..cfeae1082 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -41,38 +41,38 @@ class PPStageModel(nn.Cell): Pipeline Parallel Stage Model wrapper. This class wraps a portion of the full model for a specific pipeline stage. - Each stage contains a subset of transformer layers. + Following the reference vpp_schedule.py pattern, each stage contains specific + transformer layers based on stage_id. 
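+    For example (illustrative numbers): with num_layers=8 and
+    total_num_stages=4, layers_per_stage is 8 // 4 = 2, so stage i holds
+    decoder layers [2*i, 2*i + 2), and the last stage also absorbs any
+    remainder layers.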
Args: full_model: The full model instance - stage_id: Current stage ID (0-indexed) - num_stages: Total number of pipeline stages + stage_id: Current virtual stage ID (0-indexed) + total_num_stages: Total number of virtual pipeline stages num_layers: Total number of transformer layers - is_first_stage: Whether this is the first stage (has embedding) - is_last_stage: Whether this is the last stage (has output layer) """ def __init__( self, full_model, stage_id: int, - num_stages: int, + total_num_stages: int, num_layers: int, - is_first_stage: bool = False, - is_last_stage: bool = False ): super().__init__() self.stage_id = stage_id - self.num_stages = num_stages - self.is_first_stage = is_first_stage - self.is_last_stage = is_last_stage + self.total_num_stages = total_num_stages + self.num_layers = num_layers - # Calculate layer range for this stage - layers_per_stage = num_layers // num_stages + # Determine if this is first/last stage + self.is_first_stage = (stage_id == 0) + self.is_last_stage = (stage_id == total_num_stages - 1) + + # Calculate which layer this stage holds + # Each stage holds one layer, similar to reference code + layers_per_stage = num_layers // total_num_stages self.start_layer = stage_id * layers_per_stage self.end_layer = (stage_id + 1) * layers_per_stage - if stage_id == num_stages - 1: - # Last stage gets remaining layers + if stage_id == total_num_stages - 1: self.end_layer = num_layers # Extract components from full model @@ -99,8 +99,8 @@ class PPStageModel(nn.Cell): self.loss = self.gpt_model.loss logger.info( - f"Stage {stage_id}: layers [{self.start_layer}, {self.end_layer}), " - f"first={is_first_stage}, last={is_last_stage}" + f"Stage {stage_id}/{total_num_stages}: layers [{self.start_layer}, {self.end_layer}), " + f"first={self.is_first_stage}, last={self.is_last_stage}" ) def construct(self, hidden_states, attention_mask=None, rotary_pos_emb=None, @@ -124,7 +124,7 @@ class PPStageModel(nn.Cell): # First stage: process embedding if self.is_first_stage: input_ids = hidden_states - bs, seq_len = input_ids.shape + _, seq_len = input_ids.shape # Generate attention mask if attention_mask is None: @@ -167,42 +167,52 @@ class PPStageModel(nn.Cell): return hidden_states -def split_model_for_pp(model, num_stages, rank_id, device_num): +def create_pp_stage_models(model, num_pp_stages, num_virtual_stages, rank_id, device_num): """ - Split model into pipeline stages. + Create pipeline stage models for the current rank. + + Following the reference vpp_schedule.py pattern: + - Each rank holds `num_virtual_stages` stage models + - Total virtual stages = num_pp_stages * num_virtual_stages + - rank i holds stages: [stage_index, stage_index + num_pp_stages, ...] 
Args: model: Full model instance - num_stages: Number of pipeline stages + num_pp_stages: Number of physical pipeline stages (ranks for PP) + num_virtual_stages: Number of virtual stages per rank rank_id: Current rank ID device_num: Total number of devices Returns: - PPStageModel for the current stage + List of (stage_model, virtual_stage_id) tuples """ - device_num_per_stage = device_num // num_stages - stage_id = rank_id // device_num_per_stage + device_num_per_stage = device_num // num_pp_stages + stage_index = rank_id // device_num_per_stage num_layers = model.model.config.num_layers + total_num_stages = num_pp_stages * num_virtual_stages - is_first_stage = (stage_id == 0) - is_last_stage = (stage_id == num_stages - 1) + stage_models = [] + for v in range(num_virtual_stages): + virtual_stage_id = stage_index + v * num_pp_stages - stage_model = PPStageModel( - full_model=model, - stage_id=stage_id, - num_stages=num_stages, - num_layers=num_layers, - is_first_stage=is_first_stage, - is_last_stage=is_last_stage - ) + stage_model = PPStageModel( + full_model=model, + stage_id=virtual_stage_id, + total_num_stages=total_num_stages, + num_layers=num_layers, + ) + stage_models.append((stage_model, virtual_stage_id)) - return stage_model, stage_id + return stage_models, stage_index def run_pp_training(config_path: str = 'ds_pynative.yaml'): """ - Run pipeline parallel training. + Run pipeline parallel training with VPP-style interleaved stages. + + Following the reference vpp_schedule.py pattern, each rank holds multiple + non-consecutive stages for better pipeline efficiency. Args: config_path: Path to the configuration yaml file @@ -221,18 +231,22 @@ def run_pp_training(config_path: str = 'ds_pynative.yaml'): # PP config from yaml parallelism = getattr(mf_config, 'parallelism', None) if parallelism is None: - num_stages = 4 # default + num_pp_stages = 4 # Number of physical PP stages (ranks) + num_virtual_stages = 2 # Each rank holds 2 virtual stages micro_batch_num = 4 else: - num_stages = getattr(parallelism, 'pipeline_parallel', 4) + num_pp_stages = getattr(parallelism, 'pipeline_parallel', 4) + num_virtual_stages = getattr(parallelism, 'virtual_pipeline_parallel', 2) micro_batch_num = getattr(mf_config.parallel_config, 'micro_batch_num', 4) - device_num_per_stage = device_num // num_stages - stage_id = rank_id // device_num_per_stage + total_num_stages = num_pp_stages * num_virtual_stages + device_num_per_stage = device_num // num_pp_stages + stage_index = rank_id // device_num_per_stage logger.info( - f"PP Config: num_stages={num_stages}, micro_batch_num={micro_batch_num}, " - f"rank_id={rank_id}, stage_id={stage_id}" + f"PP Config: num_pp_stages={num_pp_stages}, num_virtual_stages={num_virtual_stages}, " + f"total_stages={total_num_stages}, micro_batch_num={micro_batch_num}, " + f"rank_id={rank_id}, stage_index={stage_index}" ) # Build full model @@ -240,8 +254,10 @@ def run_pp_training(config_path: str = 'ds_pynative.yaml'): model.set_train(True) logger.info(f"Built model: {type(model)}") - # Split model for pipeline parallel - stage_model, stage_id = split_model_for_pp(model, num_stages, rank_id, device_num) + # Create stage models using the helper function + stage_model_list, stage_index = create_pp_stage_models( + model, num_pp_stages, num_virtual_stages, rank_id, device_num + ) # Setup device mesh for data parallel within stage dp = device_num_per_stage @@ -253,22 +269,24 @@ def run_pp_training(config_path: str = 'ds_pynative.yaml'): w_placements = (Replicate(), Replicate()) 
out_placements = (Shard(0), Replicate()) - # Apply sharding plan model_stra = ShardingPlan( plan={"weight": w_placements}, input_plan={"input": in_placements}, output_plan={"output": out_placements}, ) - shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) - # Apply HSDP - stage_model = hsdp(stage_model, dp, 0, "level1", True) + # Apply sharding and HSDP to each stage model, create pipeline stages + stage_models = [] + pipeline_stages = [] - # Create pipeline stage - pipeline_stage = PipelineStage(stage_model, stage_id, num_stages) + for stage_model, virtual_stage_id in stage_model_list: + shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) + stage_model = hsdp(stage_model, dp, 0, "level1", True) + stage_models.append(stage_model) + pipeline_stages.append(PipelineStage(stage_model, virtual_stage_id, total_num_stages)) - # Create schedule - schedule = ScheduleInterleaved1F1B([pipeline_stage], micro_batch_num) + # Create interleaved 1F1B schedule + schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) # Prepare input data seq_length = mf_config.model.model_config.seq_length @@ -284,14 +302,18 @@ def run_pp_training(config_path: str = 'ds_pynative.yaml'): max_steps = getattr(mf_config.training, 'max_steps', 100) for step in range(max_steps): - stage_model.zero_grads() + # Zero gradients for all stage models + for stage_model in stage_models: + stage_model.zero_grads() - if stage_id == 0: + # Run schedule - only first stage provides input + if stage_index == 0: loss = schedule.run(x) else: loss = schedule.run() - if stage_id == num_stages - 1: + # Log loss from last stage + if stage_index == num_pp_stages - 1: logger.info(f"Step {step}, Loss: {loss}") logger.info("Training completed!") @@ -321,58 +343,45 @@ def run_vpp_training(config_path: str = 'ds_pynative.yaml'): # VPP config parallelism = getattr(mf_config, 'parallelism', None) if parallelism is None: - num_stages = 4 + num_pp_stages = 4 num_virtual_stages = 2 # Each rank holds 2 virtual stages micro_batch_num = 4 else: - num_stages = getattr(parallelism, 'pipeline_parallel', 4) + num_pp_stages = getattr(parallelism, 'pipeline_parallel', 4) num_virtual_stages = getattr(parallelism, 'virtual_pipeline_parallel', 2) micro_batch_num = getattr(mf_config.parallel_config, 'micro_batch_num', 4) - total_stages = num_stages * num_virtual_stages - device_num_per_stage = device_num // num_stages - stage_id = rank_id // device_num_per_stage + total_num_stages = num_pp_stages * num_virtual_stages + device_num_per_stage = device_num // num_pp_stages + stage_index = rank_id // device_num_per_stage logger.info( - f"VPP Config: num_stages={num_stages}, virtual_stages={num_virtual_stages}, " - f"total_stages={total_stages}, micro_batch_num={micro_batch_num}" + f"VPP Config: num_pp_stages={num_pp_stages}, virtual_stages={num_virtual_stages}, " + f"total_stages={total_num_stages}, micro_batch_num={micro_batch_num}" ) # Build full model model = build_network(mf_config.model) model.set_train(True) - num_layers = model.model.config.num_layers + # Create stage models using the helper function + stage_model_list, stage_index = create_pp_stage_models( + model, num_pp_stages, num_virtual_stages, rank_id, device_num + ) + + # Setup device mesh + dp = device_num_per_stage + mp = 1 + mesh = init_device_mesh(mesh_shape=(dp, mp), alias_name=("dp", "mp")) - # Create multiple stage models for VPP + # Apply HSDP and create pipeline stages stage_models = [] pipeline_stages = [] - for v in range(num_virtual_stages): - 
virtual_stage_id = stage_id + v * num_stages - - is_first = (virtual_stage_id == 0) - is_last = (virtual_stage_id == total_stages - 1) - - stage_model = PPStageModel( - full_model=model, - stage_id=virtual_stage_id, - num_stages=total_stages, - num_layers=num_layers, - is_first_stage=is_first, - is_last_stage=is_last - ) - - # Setup device mesh - dp = device_num_per_stage - mp = 1 - mesh = init_device_mesh(mesh_shape=(dp, mp), alias_name=("dp", "mp")) - - # Apply HSDP + for stage_model, virtual_stage_id in stage_model_list: stage_model = hsdp(stage_model, dp, 0, "level1", True) - stage_models.append(stage_model) - pipeline_stages.append(PipelineStage(stage_model, virtual_stage_id, total_stages)) + pipeline_stages.append(PipelineStage(stage_model, virtual_stage_id, total_num_stages)) # Create interleaved schedule schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) @@ -381,8 +390,6 @@ def run_vpp_training(config_path: str = 'ds_pynative.yaml'): seq_length = mf_config.model.model_config.seq_length local_batch_size = getattr(mf_config.training, 'local_batch_size', 1) - dp = device_num_per_stage - mesh = init_device_mesh(mesh_shape=(dp, 1), alias_name=("dp", "mp")) x_placements = (Shard(0), Shard(1)) x = DTensor.from_local( Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32), @@ -396,12 +403,12 @@ def run_vpp_training(config_path: str = 'ds_pynative.yaml'): for stage_model in stage_models: stage_model.zero_grads() - if stage_id == 0: + if stage_index == 0: loss = schedule.run(x) else: loss = schedule.run() - if stage_id == num_stages - 1: + if stage_index == num_pp_stages - 1: logger.info(f"Step {step}, Loss: {loss}") logger.info("VPP Training completed!") -- Gitee From 25410727500fea7b16dfd543bccbc24401bdef6a Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Thu, 5 Feb 2026 13:55:41 +0000 Subject: [PATCH 04/29] Update pp --- run_pynative_pp.py | 201 +++++++++------------------------------------ 1 file changed, 41 insertions(+), 160 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index cfeae1082..ad91bc8fb 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Pipeline Parallel Training Script for MindFormers PyNative Mode.""" +"""Pipeline Parallel Training Script for MindFormers PyNative Mode. + +This script implements Virtual Pipeline Parallel (VPP) training following +the pattern from hyper-parallel/tests/mindspore/st/pipeline_parallel/vpp_schedule.py. 
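+
+A typical launch (the msrun worker count must match the parallel layout in the
+config, e.g. pipeline_parallel x data_parallel; the numbers below are illustrative):
+
+    msrun --worker_num=8 --local_worker_num=8 python run_pynative_pp.py --config ds_pynative_pp.yaml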
+ +Key features: +- Each rank holds multiple non-consecutive virtual stages +- Uses ScheduleInterleaved1F1B for efficient pipeline scheduling +- Supports HSDP + SHARD + PP combination +""" from pathlib import Path import sys @@ -22,7 +31,6 @@ import numpy as np import mindspore as ms from mindspore import nn, Tensor, mint from mindspore.communication.management import init, get_rank, get_group_size -from mindspore.common import set_seed from hyper_parallel import PipelineStage, ScheduleInterleaved1F1B from hyper_parallel import DTensor, init_device_mesh @@ -67,8 +75,7 @@ class PPStageModel(nn.Cell): self.is_first_stage = (stage_id == 0) self.is_last_stage = (stage_id == total_num_stages - 1) - # Calculate which layer this stage holds - # Each stage holds one layer, similar to reference code + # Calculate which layers this stage holds layers_per_stage = num_layers // total_num_stages self.start_layer = stage_id * layers_per_stage self.end_layer = (stage_id + 1) * layers_per_stage @@ -104,7 +111,7 @@ class PPStageModel(nn.Cell): ) def construct(self, hidden_states, attention_mask=None, rotary_pos_emb=None, - labels=None, loss_mask=None, input_ids=None): + labels=None, loss_mask=None): """ Forward pass for this pipeline stage. @@ -114,7 +121,6 @@ class PPStageModel(nn.Cell): rotary_pos_emb: Rotary position embedding labels: Labels for loss computation (only used in last stage) loss_mask: Loss mask tensor (only used in last stage) - input_ids: Original input ids (for hash router) Returns: hidden_states or loss depending on stage @@ -167,52 +173,21 @@ class PPStageModel(nn.Cell): return hidden_states -def create_pp_stage_models(model, num_pp_stages, num_virtual_stages, rank_id, device_num): +def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): """ - Create pipeline stage models for the current rank. + Run Virtual Pipeline Parallel (VPP) training. Following the reference vpp_schedule.py pattern: - - Each rank holds `num_virtual_stages` stage models + - num_pp_stages physical ranks for pipeline parallel + - Each rank holds num_virtual_stages virtual stages - Total virtual stages = num_pp_stages * num_virtual_stages - rank i holds stages: [stage_index, stage_index + num_pp_stages, ...] - Args: - model: Full model instance - num_pp_stages: Number of physical pipeline stages (ranks for PP) - num_virtual_stages: Number of virtual stages per rank - rank_id: Current rank ID - device_num: Total number of devices - - Returns: - List of (stage_model, virtual_stage_id) tuples - """ - device_num_per_stage = device_num // num_pp_stages - stage_index = rank_id // device_num_per_stage - - num_layers = model.model.config.num_layers - total_num_stages = num_pp_stages * num_virtual_stages - - stage_models = [] - for v in range(num_virtual_stages): - virtual_stage_id = stage_index + v * num_pp_stages - - stage_model = PPStageModel( - full_model=model, - stage_id=virtual_stage_id, - total_num_stages=total_num_stages, - num_layers=num_layers, - ) - stage_models.append((stage_model, virtual_stage_id)) - - return stage_models, stage_index - - -def run_pp_training(config_path: str = 'ds_pynative.yaml'): - """ - Run pipeline parallel training with VPP-style interleaved stages. - - Following the reference vpp_schedule.py pattern, each rank holds multiple - non-consecutive stages for better pipeline efficiency. 
+ Example with num_pp_stages=4, num_virtual_stages=2: + rank 0 holds: stage 0, stage 4 + rank 1 holds: stage 1, stage 5 + rank 2 holds: stage 2, stage 6 + rank 3 holds: stage 3, stage 7 Args: config_path: Path to the configuration yaml file @@ -226,7 +201,7 @@ def run_pp_training(config_path: str = 'ds_pynative.yaml'): # Load config mf_config = MindFormerConfig(config_path) - logger.info(f"Loaded config: {mf_config}") + logger.info(f"Loaded config from: {config_path}") # PP config from yaml parallelism = getattr(mf_config, 'parallelism', None) @@ -252,12 +227,8 @@ def run_pp_training(config_path: str = 'ds_pynative.yaml'): # Build full model model = build_network(mf_config.model) model.set_train(True) - logger.info(f"Built model: {type(model)}") - - # Create stage models using the helper function - stage_model_list, stage_index = create_pp_stage_models( - model, num_pp_stages, num_virtual_stages, rank_id, device_num - ) + num_layers = model.model.config.num_layers + logger.info(f"Built model with {num_layers} layers") # Setup device mesh for data parallel within stage dp = device_num_per_stage @@ -275,17 +246,29 @@ def run_pp_training(config_path: str = 'ds_pynative.yaml'): output_plan={"output": out_placements}, ) - # Apply sharding and HSDP to each stage model, create pipeline stages + # Create stage models for this rank + # Each rank holds num_virtual_stages models stage_models = [] pipeline_stages = [] - for stage_model, virtual_stage_id in stage_model_list: + for v in range(num_virtual_stages): + virtual_stage_id = stage_index + v * num_pp_stages + + stage_model = PPStageModel( + full_model=model, + stage_id=virtual_stage_id, + total_num_stages=total_num_stages, + num_layers=num_layers, + ) + + # Apply sharding and HSDP shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) stage_model = hsdp(stage_model, dp, 0, "level1", True) + stage_models.append(stage_model) pipeline_stages.append(PipelineStage(stage_model, virtual_stage_id, total_num_stages)) - # Create interleaved 1F1B schedule + # Create interleaved 1F1B schedule with all stages for this rank schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) # Prepare input data @@ -316,102 +299,7 @@ def run_pp_training(config_path: str = 'ds_pynative.yaml'): if stage_index == num_pp_stages - 1: logger.info(f"Step {step}, Loss: {loss}") - logger.info("Training completed!") - - -def run_vpp_training(config_path: str = 'ds_pynative.yaml'): - """ - Run Virtual Pipeline Parallel (VPP) training with interleaved stages. - - This implements VPP where each rank holds multiple non-consecutive stages, - similar to the reference vpp_schedule.py. 
- - Args: - config_path: Path to the configuration yaml file - """ - # Initialize - ms.set_seed(0) - init("hccl") - - rank_id = get_rank() - device_num = get_group_size() - - # Load config - mf_config = MindFormerConfig(config_path) - logger.info(f"Loaded config: {mf_config}") - - # VPP config - parallelism = getattr(mf_config, 'parallelism', None) - if parallelism is None: - num_pp_stages = 4 - num_virtual_stages = 2 # Each rank holds 2 virtual stages - micro_batch_num = 4 - else: - num_pp_stages = getattr(parallelism, 'pipeline_parallel', 4) - num_virtual_stages = getattr(parallelism, 'virtual_pipeline_parallel', 2) - micro_batch_num = getattr(mf_config.parallel_config, 'micro_batch_num', 4) - - total_num_stages = num_pp_stages * num_virtual_stages - device_num_per_stage = device_num // num_pp_stages - stage_index = rank_id // device_num_per_stage - - logger.info( - f"VPP Config: num_pp_stages={num_pp_stages}, virtual_stages={num_virtual_stages}, " - f"total_stages={total_num_stages}, micro_batch_num={micro_batch_num}" - ) - - # Build full model - model = build_network(mf_config.model) - model.set_train(True) - - # Create stage models using the helper function - stage_model_list, stage_index = create_pp_stage_models( - model, num_pp_stages, num_virtual_stages, rank_id, device_num - ) - - # Setup device mesh - dp = device_num_per_stage - mp = 1 - mesh = init_device_mesh(mesh_shape=(dp, mp), alias_name=("dp", "mp")) - - # Apply HSDP and create pipeline stages - stage_models = [] - pipeline_stages = [] - - for stage_model, virtual_stage_id in stage_model_list: - stage_model = hsdp(stage_model, dp, 0, "level1", True) - stage_models.append(stage_model) - pipeline_stages.append(PipelineStage(stage_model, virtual_stage_id, total_num_stages)) - - # Create interleaved schedule - schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) - - # Prepare input - seq_length = mf_config.model.model_config.seq_length - local_batch_size = getattr(mf_config.training, 'local_batch_size', 1) - - x_placements = (Shard(0), Shard(1)) - x = DTensor.from_local( - Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32), - mesh, x_placements - ) - - # Training loop - max_steps = getattr(mf_config.training, 'max_steps', 100) - - for step in range(max_steps): - for stage_model in stage_models: - stage_model.zero_grads() - - if stage_index == 0: - loss = schedule.run(x) - else: - loss = schedule.run() - - if stage_index == num_pp_stages - 1: - logger.info(f"Step {step}, Loss: {loss}") - - logger.info("VPP Training completed!") + logger.info("PP Training completed!") if __name__ == "__main__": @@ -422,13 +310,6 @@ if __name__ == "__main__": "--config", type=str, default="ds_pynative_pp.yaml", help="Path to config yaml file" ) - parser.add_argument( - "--mode", type=str, default="pp", choices=["pp", "vpp"], - help="Training mode: pp (pipeline parallel) or vpp (virtual pipeline parallel)" - ) args = parser.parse_args() - if args.mode == "pp": - run_pp_training(args.config) - else: - run_vpp_training(args.config) + run_pp_training(args.config) -- Gitee From 5e36a3ffb51822ecc036e8454ccce61cdc3bccc9 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Thu, 5 Feb 2026 14:05:30 +0000 Subject: [PATCH 05/29] Fix model share. 
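Building every virtual stage view from one shared model instance makes the
stages alias the same Parameter objects, so slicing, sharding, and gradient
handling per stage no longer keep the stages independent. The change below
therefore builds an independent model per virtual stage before extracting its
layers. A minimal illustration of the aliasing problem, assuming only standard
MindSpore nn (illustrative, not part of the training script):

    import mindspore.nn as nn

    class Toy(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(4, 4)

    shared = Toy()
    view_a, view_b = shared.dense, shared.dense   # two "stages" cut from one model
    print(view_a.weight is view_b.weight)         # True  -> parameters are aliased
    own_a, own_b = Toy().dense, Toy().dense       # one model instance per stage
    print(own_a.weight is own_b.weight)           # False -> parameters are independent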
--- run_pynative_pp.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index ad91bc8fb..f4db622f6 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -224,12 +224,6 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): f"rank_id={rank_id}, stage_index={stage_index}" ) - # Build full model - model = build_network(mf_config.model) - model.set_train(True) - num_layers = model.model.config.num_layers - logger.info(f"Built model with {num_layers} layers") - # Setup device mesh for data parallel within stage dp = device_num_per_stage mp = 1 @@ -248,12 +242,21 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): # Create stage models for this rank # Each rank holds num_virtual_stages models + # IMPORTANT: Each stage needs its own model instance to avoid shared parameter issues stage_models = [] pipeline_stages = [] for v in range(num_virtual_stages): virtual_stage_id = stage_index + v * num_pp_stages + # Build independent model for each virtual stage + model = build_network(mf_config.model) + model.set_train(True) + num_layers = model.model.config.num_layers + + if v == 0: + logger.info(f"Built model with {num_layers} layers") + stage_model = PPStageModel( full_model=model, stage_id=virtual_stage_id, -- Gitee From e64d0be2784b6a98aa8bcb0f89a8589888095282 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Thu, 5 Feb 2026 14:43:32 +0000 Subject: [PATCH 06/29] Update pp model input --- ds_pynative_pp.yaml | 2 ++ run_pynative_pp.py | 56 +++++++++++++++++++++++---------------------- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/ds_pynative_pp.yaml b/ds_pynative_pp.yaml index 8a72fe7b5..03ce393ce 100755 --- a/ds_pynative_pp.yaml +++ b/ds_pynative_pp.yaml @@ -194,6 +194,8 @@ model: rms_norm_eps: 1.e-6 add_bias_linear: False use_flash_attention: True + # PP 模式需要启用压缩 mask,使非第一个 stage 也能生成 attention mask + use_attn_mask_compression: True # MLA 配置 multi_latent_attention: True mla_qkv_concat: True diff --git a/run_pynative_pp.py b/run_pynative_pp.py index f4db622f6..9a63c08ab 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -85,12 +85,18 @@ class PPStageModel(nn.Cell): # Extract components from full model self.gpt_model = full_model.model + # Store seq_length for generating attention mask and rotary embedding + self.seq_length = self.gpt_model.seq_length + self.config = self.gpt_model.config + # Keep embedding only for first stage if self.is_first_stage: self.embedding = self.gpt_model.embedding - self.casual_mask = self.gpt_model.casual_mask - # Keep rotary embedding (needed for all stages) + # Keep casual_mask for all stages (needed for attention) + self.casual_mask = self.gpt_model.casual_mask + + # Keep rotary embedding for all stages if hasattr(self.gpt_model, 'rotary_pos_emb'): self.rotary_pos_emb = self.gpt_model.rotary_pos_emb @@ -110,38 +116,37 @@ class PPStageModel(nn.Cell): f"first={self.is_first_stage}, last={self.is_last_stage}" ) - def construct(self, hidden_states, attention_mask=None, rotary_pos_emb=None, - labels=None, loss_mask=None): + def construct(self, x): """ Forward pass for this pipeline stage. 
Args: - hidden_states: Input tensor (input_ids for first stage, hidden states otherwise) - attention_mask: Attention mask tensor - rotary_pos_emb: Rotary position embedding - labels: Labels for loss computation (only used in last stage) - loss_mask: Loss mask tensor (only used in last stage) + x: Input tensor + - For first stage: input_ids with shape (batch, seq_len) + - For other stages: hidden_states with shape (seq_len, batch, hidden_size) Returns: - hidden_states or loss depending on stage + - For last stage: loss scalar or logits + - For other stages: hidden_states tensor """ aux_loss = None - # First stage: process embedding - if self.is_first_stage: - input_ids = hidden_states - _, seq_len = input_ids.shape + # Generate attention mask (works for all stages with compression mode) + # For compression mode, casual_mask() returns pre-computed mask without needing input_ids + attention_mask = self.casual_mask() - # Generate attention mask - if attention_mask is None: - attention_mask = self.casual_mask(input_ids) + # Generate rotary position embedding for all stages + rotary_pos_emb = None + if hasattr(self, 'rotary_pos_emb'): + rotary_pos_emb = self.rotary_pos_emb(self.seq_length, position_ids=None) - # Embedding + # First stage: process embedding + if self.is_first_stage: + input_ids = x hidden_states = self.embedding(input_ids, position_ids=None) - - # Generate rotary position embedding - if hasattr(self, 'rotary_pos_emb'): - rotary_pos_emb = self.rotary_pos_emb(seq_len, position_ids=None) + else: + # For non-first stages, x is hidden_states from previous stage + hidden_states = x # Process transformer layers for this stage for layer in self.layers: @@ -163,11 +168,8 @@ class PPStageModel(nn.Cell): logits = mint.reshape(logits, (-1, logits.shape[-1])) logits = logits.astype(ms.float32) - if labels is not None: - loss = self.loss(logits, labels, loss_mask) - if aux_loss is not None: - loss = loss + aux_loss - return loss + if aux_loss is not None: + return logits, aux_loss return logits return hidden_states -- Gitee From 9323e3229cc73176d8fda25f438b96c476e6ea2f Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Thu, 5 Feb 2026 15:00:24 +0000 Subject: [PATCH 07/29] Add dp check. --- run_pynative_pp.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 9a63c08ab..a343c8609 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -280,6 +280,15 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): seq_length = mf_config.model.model_config.seq_length local_batch_size = getattr(mf_config.training, 'local_batch_size', 1) + # Ensure local_batch_size is large enough for sharding + # x_placements (Shard(0), Shard(1)) means batch dim is sharded across dp + # So local_batch_size must be >= dp to avoid zero-sized tensor + if local_batch_size < dp: + local_batch_size = dp * micro_batch_num + logger.warning( + f"local_batch_size too small for dp={dp}, adjusted to {local_batch_size}" + ) + x_placements = (Shard(0), Shard(1)) x = DTensor.from_local( Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32), -- Gitee From 8c6468f5b071b601f47574b1aa5e4064a5d10002 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Fri, 6 Feb 2026 01:52:39 +0000 Subject: [PATCH 08/29] Revert input format. 
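This rolls back the compressed-mask experiment: the yaml flag is removed and
the stage forward returns to explicit (attention_mask, rotary_pos_emb, labels,
loss_mask) arguments, with only the first stage deriving the mask from
input_ids. For reference, the mask convention the stages rely on is
0 = visible, 1 = masked, dtype uint8, the same convention later patches in
this series reconstruct for non-first stages. A minimal sketch of such a mask
using only mint ops (broadcast_to is used here instead of expand; treat it as
an illustrative equivalent, not the exact generator in MindFormers):

    import mindspore as ms
    from mindspore import mint

    def causal_attention_mask(bs, seq_len):
        # lower triangle visible (0), upper triangle masked (1)
        mask = 1.0 - mint.tril(mint.ones((1, seq_len, seq_len), dtype=ms.float16))
        return mask.unsqueeze(1).broadcast_to((bs, 1, seq_len, seq_len)).astype(ms.uint8)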
--- ds_pynative_pp.yaml | 2 -- run_pynative_pp.py | 65 +++++++++++++++++++-------------------------- 2 files changed, 27 insertions(+), 40 deletions(-) diff --git a/ds_pynative_pp.yaml b/ds_pynative_pp.yaml index 03ce393ce..8a72fe7b5 100755 --- a/ds_pynative_pp.yaml +++ b/ds_pynative_pp.yaml @@ -194,8 +194,6 @@ model: rms_norm_eps: 1.e-6 add_bias_linear: False use_flash_attention: True - # PP 模式需要启用压缩 mask,使非第一个 stage 也能生成 attention mask - use_attn_mask_compression: True # MLA 配置 multi_latent_attention: True mla_qkv_concat: True diff --git a/run_pynative_pp.py b/run_pynative_pp.py index a343c8609..f4db622f6 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -85,18 +85,12 @@ class PPStageModel(nn.Cell): # Extract components from full model self.gpt_model = full_model.model - # Store seq_length for generating attention mask and rotary embedding - self.seq_length = self.gpt_model.seq_length - self.config = self.gpt_model.config - # Keep embedding only for first stage if self.is_first_stage: self.embedding = self.gpt_model.embedding + self.casual_mask = self.gpt_model.casual_mask - # Keep casual_mask for all stages (needed for attention) - self.casual_mask = self.gpt_model.casual_mask - - # Keep rotary embedding for all stages + # Keep rotary embedding (needed for all stages) if hasattr(self.gpt_model, 'rotary_pos_emb'): self.rotary_pos_emb = self.gpt_model.rotary_pos_emb @@ -116,37 +110,38 @@ class PPStageModel(nn.Cell): f"first={self.is_first_stage}, last={self.is_last_stage}" ) - def construct(self, x): + def construct(self, hidden_states, attention_mask=None, rotary_pos_emb=None, + labels=None, loss_mask=None): """ Forward pass for this pipeline stage. Args: - x: Input tensor - - For first stage: input_ids with shape (batch, seq_len) - - For other stages: hidden_states with shape (seq_len, batch, hidden_size) + hidden_states: Input tensor (input_ids for first stage, hidden states otherwise) + attention_mask: Attention mask tensor + rotary_pos_emb: Rotary position embedding + labels: Labels for loss computation (only used in last stage) + loss_mask: Loss mask tensor (only used in last stage) Returns: - - For last stage: loss scalar or logits - - For other stages: hidden_states tensor + hidden_states or loss depending on stage """ aux_loss = None - # Generate attention mask (works for all stages with compression mode) - # For compression mode, casual_mask() returns pre-computed mask without needing input_ids - attention_mask = self.casual_mask() - - # Generate rotary position embedding for all stages - rotary_pos_emb = None - if hasattr(self, 'rotary_pos_emb'): - rotary_pos_emb = self.rotary_pos_emb(self.seq_length, position_ids=None) - # First stage: process embedding if self.is_first_stage: - input_ids = x + input_ids = hidden_states + _, seq_len = input_ids.shape + + # Generate attention mask + if attention_mask is None: + attention_mask = self.casual_mask(input_ids) + + # Embedding hidden_states = self.embedding(input_ids, position_ids=None) - else: - # For non-first stages, x is hidden_states from previous stage - hidden_states = x + + # Generate rotary position embedding + if hasattr(self, 'rotary_pos_emb'): + rotary_pos_emb = self.rotary_pos_emb(seq_len, position_ids=None) # Process transformer layers for this stage for layer in self.layers: @@ -168,8 +163,11 @@ class PPStageModel(nn.Cell): logits = mint.reshape(logits, (-1, logits.shape[-1])) logits = logits.astype(ms.float32) - if aux_loss is not None: - return logits, aux_loss + if labels is not None: + loss = 
self.loss(logits, labels, loss_mask) + if aux_loss is not None: + loss = loss + aux_loss + return loss return logits return hidden_states @@ -280,15 +278,6 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): seq_length = mf_config.model.model_config.seq_length local_batch_size = getattr(mf_config.training, 'local_batch_size', 1) - # Ensure local_batch_size is large enough for sharding - # x_placements (Shard(0), Shard(1)) means batch dim is sharded across dp - # So local_batch_size must be >= dp to avoid zero-sized tensor - if local_batch_size < dp: - local_batch_size = dp * micro_batch_num - logger.warning( - f"local_batch_size too small for dp={dp}, adjusted to {local_batch_size}" - ) - x_placements = (Shard(0), Shard(1)) x = DTensor.from_local( Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32), -- Gitee From 77332df93a5b2b4a2859dc08b6b8de827da22583 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Fri, 6 Feb 2026 01:52:58 +0000 Subject: [PATCH 09/29] Update pp data loader. --- run_pynative_pp.py | 105 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 85 insertions(+), 20 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index f4db622f6..7d2732f8e 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -41,6 +41,7 @@ from hyper_parallel.core.shard.sharding_plan import ShardingPlan from mindformers.tools.register import MindFormerConfig from mindformers.models.build_model import build_network +from mindformers.dataset.build_dataset import build_dataset from mindformers.tools.logger import logger @@ -110,31 +111,34 @@ class PPStageModel(nn.Cell): f"first={self.is_first_stage}, last={self.is_last_stage}" ) - def construct(self, hidden_states, attention_mask=None, rotary_pos_emb=None, - labels=None, loss_mask=None): + def construct(self, hidden_states_or_input_ids, labels=None, loss_mask=None): """ Forward pass for this pipeline stage. 
+ Pipeline data flow: + - First stage: receives (input_ids, labels, loss_mask), returns (hidden_states, labels, loss_mask) + - Middle stages: receives (hidden_states, labels, loss_mask), returns (hidden_states, labels, loss_mask) + - Last stage: receives (hidden_states, labels, loss_mask), returns loss + Args: - hidden_states: Input tensor (input_ids for first stage, hidden states otherwise) - attention_mask: Attention mask tensor - rotary_pos_emb: Rotary position embedding - labels: Labels for loss computation (only used in last stage) - loss_mask: Loss mask tensor (only used in last stage) + hidden_states_or_input_ids: input_ids for first stage, hidden_states for other stages + labels: Labels for loss computation (passed through all stages) + loss_mask: Loss mask tensor (passed through all stages) Returns: - hidden_states or loss depending on stage + Tuple of (hidden_states, labels, loss_mask) for non-last stages, or loss for last stage """ aux_loss = None + attention_mask = None + rotary_pos_emb = None - # First stage: process embedding + # First stage: process embedding from input_ids if self.is_first_stage: - input_ids = hidden_states + input_ids = hidden_states_or_input_ids _, seq_len = input_ids.shape # Generate attention mask - if attention_mask is None: - attention_mask = self.casual_mask(input_ids) + attention_mask = self.casual_mask(input_ids) # Embedding hidden_states = self.embedding(input_ids, position_ids=None) @@ -142,6 +146,20 @@ class PPStageModel(nn.Cell): # Generate rotary position embedding if hasattr(self, 'rotary_pos_emb'): rotary_pos_emb = self.rotary_pos_emb(seq_len, position_ids=None) + else: + # Non-first stages: input is hidden_states + hidden_states = hidden_states_or_input_ids + _, seq_len, _ = hidden_states.shape + + # Generate attention mask and rotary embedding for non-first stages + # Use seq_len to create proper masks + if hasattr(self, 'rotary_pos_emb'): + rotary_pos_emb = self.rotary_pos_emb(seq_len, position_ids=None) + + # Create causal attention mask for non-first stages + # Shape: [1, 1, seq_len, seq_len] + attention_mask = mint.ones((1, 1, seq_len, seq_len), dtype=hidden_states.dtype) + attention_mask = mint.tril(attention_mask) # Process transformer layers for this stage for layer in self.layers: @@ -153,7 +171,7 @@ class PPStageModel(nn.Cell): if layer_aux_loss is not None: aux_loss = layer_aux_loss if aux_loss is None else aux_loss + layer_aux_loss - # Last stage: process output and loss + # Last stage: process output and compute loss if self.is_last_stage: hidden_states = self.final_layernorm(hidden_states) logits, _ = self.output_layer(hidden_states, weight=None) @@ -170,7 +188,8 @@ class PPStageModel(nn.Cell): return loss return logits - return hidden_states + # Non-last stages: return tuple to pass to next stage + return hidden_states, labels, loss_mask def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): @@ -274,27 +293,73 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): # Create interleaved 1F1B schedule with all stages for this rank schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) - # Prepare input data + # Get training config seq_length = mf_config.model.model_config.seq_length local_batch_size = getattr(mf_config.training, 'local_batch_size', 1) + global_batch_size = getattr(mf_config.training, 'global_batch_size', 8) + + # Create dataset following Trainer pattern + train_dataset_config = getattr(mf_config, 'train_dataset', None) + if train_dataset_config is not None: + dataset_config = 
MindFormerConfig() + dataset_config.type = 'CausalLanguageModelDataset' + dataset_config.dataset_config = train_dataset_config + dataset_config.dataset_config.batch_size = global_batch_size // dp + dataset_config.dataset_config.seed = getattr(mf_config, 'seed', 42) + dataset = build_dataset(dataset_config) + dataset_iter = dataset.create_tuple_iterator() + logger.info(f"Created dataset with batch_size={global_batch_size // dp}") + else: + dataset = None + dataset_iter = None + logger.warning("No dataset config found, using dummy data") + # DTensor placements for input data x_placements = (Shard(0), Shard(1)) - x = DTensor.from_local( - Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32), - mesh, x_placements - ) # Training loop max_steps = getattr(mf_config.training, 'max_steps', 100) for step in range(max_steps): + # Get batch data from dataset + if dataset_iter is not None: + try: + batch = next(dataset_iter) + # Dataset returns: (input_ids, labels, loss_mask, position_ids) + input_ids = batch[0] + labels = batch[1] + loss_mask = batch[2] + # position_ids = batch[3] # Not used currently + except StopIteration: + # Reset iterator when exhausted + dataset_iter = dataset.create_tuple_iterator() + batch = next(dataset_iter) + input_ids = batch[0] + labels = batch[1] + loss_mask = batch[2] + + # Convert to DTensor for distributed training + input_ids = DTensor.from_local(input_ids, mesh, x_placements) + labels = DTensor.from_local(labels, mesh, x_placements) + loss_mask = DTensor.from_local( + loss_mask.astype(ms.float32), mesh, x_placements + ) + else: + # Fallback to dummy data if no dataset + input_ids = DTensor.from_local( + Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32), + mesh, x_placements + ) + labels = None + loss_mask = None + # Zero gradients for all stage models for stage_model in stage_models: stage_model.zero_grads() # Run schedule - only first stage provides input if stage_index == 0: - loss = schedule.run(x) + loss = schedule.run(input_ids, labels, loss_mask) else: loss = schedule.run() -- Gitee From 82cc219ef4011e4f852a090245b23d311d420d56 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Fri, 6 Feb 2026 07:28:34 +0000 Subject: [PATCH 10/29] Remove dp. 
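With data parallelism disabled here (dp = 1), every pipeline group reads the
full per-step batch and the 1F1B schedule is what splits the work into micro
batches. The bookkeeping that has to hold, shown with the script's fallback
defaults (and assuming the schedule slices the fetched batch into
micro_batch_num pieces, which is how micro_batch_num is wired up here):

    global_batch_size = 8            # training.global_batch_size default
    dp = 1                           # data parallel disabled in this patch
    micro_batch_num = 4              # parallel_config.micro_batch_num default

    per_dp_batch = global_batch_size // dp          # dataset batch size per step
    assert per_dp_batch % micro_batch_num == 0
    micro_batch_size = per_dp_batch // micro_batch_num
    print(per_dp_batch, micro_batch_size)           # 8 2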
--- run_pynative_pp.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 7d2732f8e..0c5faa1e2 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -244,7 +244,8 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): ) # Setup device mesh for data parallel within stage - dp = device_num_per_stage + # Set dp=1 to disable data parallel, only use PP + VPP + dp = 1 mp = 1 mesh = init_device_mesh(mesh_shape=(dp, mp), alias_name=("dp", "mp")) @@ -283,9 +284,10 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): num_layers=num_layers, ) - # Apply sharding and HSDP - shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) - stage_model = hsdp(stage_model, dp, 0, "level1", True) + # Apply sharding and HSDP only when dp > 1 + if dp > 1: + shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) + stage_model = hsdp(stage_model, dp, 0, "level1", True) stage_models.append(stage_model) pipeline_stages.append(PipelineStage(stage_model, virtual_stage_id, total_num_stages)) @@ -314,7 +316,7 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): dataset_iter = None logger.warning("No dataset config found, using dummy data") - # DTensor placements for input data + # DTensor placements for input data (only used when dp > 1) x_placements = (Shard(0), Shard(1)) # Training loop @@ -338,18 +340,20 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): labels = batch[1] loss_mask = batch[2] - # Convert to DTensor for distributed training - input_ids = DTensor.from_local(input_ids, mesh, x_placements) - labels = DTensor.from_local(labels, mesh, x_placements) - loss_mask = DTensor.from_local( - loss_mask.astype(ms.float32), mesh, x_placements - ) + # Convert to DTensor only when dp > 1 + if dp > 1: + input_ids = DTensor.from_local(input_ids, mesh, x_placements) + labels = DTensor.from_local(labels, mesh, x_placements) + loss_mask = DTensor.from_local( + loss_mask.astype(ms.float32), mesh, x_placements + ) + else: + loss_mask = loss_mask.astype(ms.float32) else: # Fallback to dummy data if no dataset - input_ids = DTensor.from_local( - Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32), - mesh, x_placements - ) + input_ids = Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32) + if dp > 1: + input_ids = DTensor.from_local(input_ids, mesh, x_placements) labels = None loss_mask = None -- Gitee From 5810a35b58387eca93b6fff8ce367965246a6505 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Fri, 6 Feb 2026 07:36:49 +0000 Subject: [PATCH 11/29] fix dp --- run_pynative_pp.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 0c5faa1e2..552a38829 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -357,9 +357,10 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): labels = None loss_mask = None - # Zero gradients for all stage models - for stage_model in stage_models: - stage_model.zero_grads() + # Zero gradients for all stage models (only needed when dp > 1 with HSDP) + if dp > 1: + for stage_model in stage_models: + stage_model.zero_grads() # Run schedule - only first stage provides input if stage_index == 0: -- Gitee From e8f22eb16c6672be8df2a1a1f62f6f1a0be7a677 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Fri, 6 Feb 2026 07:48:32 +0000 Subject: [PATCH 12/29] Fix Dtensor. 
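The placement rule this fix settles on: inputs are always wrapped as DTensor,
and the batch dimension is sharded only when there is real data parallelism.
Condensed into one helper (Shard and Replicate are the hyper_parallel
placement types used in this script; the exact import path is an assumption):

    from hyper_parallel import Shard, Replicate   # import path assumed

    def input_placements(dp):
        if dp > 1:
            return (Shard(0), Replicate())    # split the batch dim across dp ranks
        return (Replicate(), Replicate())     # dp == 1: full copy on every rank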
--- run_pynative_pp.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 552a38829..d66b65028 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -316,8 +316,12 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): dataset_iter = None logger.warning("No dataset config found, using dummy data") - # DTensor placements for input data (only used when dp > 1) - x_placements = (Shard(0), Shard(1)) + # DTensor placements for input data + # When dp=1, use Replicate; when dp>1, use Shard for data parallel + if dp > 1: + x_placements = (Shard(0), Replicate()) # Shard on batch dim for dp + else: + x_placements = (Replicate(), Replicate()) # No sharding when dp=1 # Training loop max_steps = getattr(mf_config.training, 'max_steps', 100) @@ -340,20 +344,18 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): labels = batch[1] loss_mask = batch[2] - # Convert to DTensor only when dp > 1 - if dp > 1: - input_ids = DTensor.from_local(input_ids, mesh, x_placements) - labels = DTensor.from_local(labels, mesh, x_placements) - loss_mask = DTensor.from_local( - loss_mask.astype(ms.float32), mesh, x_placements - ) - else: - loss_mask = loss_mask.astype(ms.float32) + # Always convert to DTensor (required by PP framework) + input_ids = DTensor.from_local(input_ids, mesh, x_placements) + labels = DTensor.from_local(labels, mesh, x_placements) + loss_mask = DTensor.from_local( + loss_mask.astype(ms.float32), mesh, x_placements + ) else: # Fallback to dummy data if no dataset - input_ids = Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32) - if dp > 1: - input_ids = DTensor.from_local(input_ids, mesh, x_placements) + input_ids = DTensor.from_local( + Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32), + mesh, x_placements + ) labels = None loss_mask = None -- Gitee From 96f1bf83a015671a8f014e3f302b1eb25f89aa91 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Fri, 6 Feb 2026 08:18:30 +0000 Subject: [PATCH 13/29] Fix dp 1. 
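shard_module and hsdp are now applied unconditionally because the hsdp wrapper
is what provides zero_grads() and gradient synchronization, even at dp = 1.
The per-step driver the training loop amounts to, condensed from the code in
this script:

    def train_step(stage_models, schedule, stage_index, num_pp_stages, *inputs):
        for stage_model in stage_models:      # every virtual stage on this rank
            stage_model.zero_grads()          # provided by the hsdp wrapper
        if stage_index == 0:
            loss = schedule.run(*inputs)      # only the first stage feeds data
        else:
            loss = schedule.run()
        # only the rank holding the last stage sees a meaningful loss
        return loss if stage_index == num_pp_stages - 1 else None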
--- run_pynative_pp.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index d66b65028..c80b82fd6 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -249,10 +249,15 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): mp = 1 mesh = init_device_mesh(mesh_shape=(dp, mp), alias_name=("dp", "mp")) - # Define placements - in_placements = (Shard(0), Replicate()) - w_placements = (Replicate(), Replicate()) - out_placements = (Shard(0), Replicate()) + # Define placements based on dp + # When dp=1, use Replicate; when dp>1, use Shard for data parallel + if dp > 1: + in_placements = (Shard(0), Replicate()) # Shard on batch dim + out_placements = (Shard(0), Replicate()) + else: + in_placements = (Replicate(), Replicate()) # No sharding + out_placements = (Replicate(), Replicate()) + w_placements = (Replicate(), Replicate()) # Weights always replicated model_stra = ShardingPlan( plan={"weight": w_placements}, @@ -284,10 +289,11 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): num_layers=num_layers, ) - # Apply sharding and HSDP only when dp > 1 - if dp > 1: - shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) - stage_model = hsdp(stage_model, dp, 0, "level1", True) + # Always apply sharding (required by PP framework) + shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) + + # Always apply HSDP (provides zero_grads and gradient sync) + stage_model = hsdp(stage_model, dp, 0, "level1", True) stage_models.append(stage_model) pipeline_stages.append(PipelineStage(stage_model, virtual_stage_id, total_num_stages)) @@ -359,10 +365,9 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): labels = None loss_mask = None - # Zero gradients for all stage models (only needed when dp > 1 with HSDP) - if dp > 1: - for stage_model in stage_models: - stage_model.zero_grads() + # Zero gradients for all stage models + for stage_model in stage_models: + stage_model.zero_grads() # Run schedule - only first stage provides input if stage_index == 0: -- Gitee From 4604a402c821fb6a2035b4d136df353c798fdf04 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Fri, 6 Feb 2026 09:08:02 +0000 Subject: [PATCH 14/29] Update input shard plan. --- run_pynative_pp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index c80b82fd6..98424f015 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -259,9 +259,10 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): out_placements = (Replicate(), Replicate()) w_placements = (Replicate(), Replicate()) # Weights always replicated + # Input plan for 3 inputs: (hidden_states_or_input_ids, labels, loss_mask) model_stra = ShardingPlan( plan={"weight": w_placements}, - input_plan={"input": in_placements}, + input_plan={"input": [in_placements, in_placements, in_placements]}, output_plan={"output": out_placements}, ) -- Gitee From 195070b305a20c49671f78cafa34cb19eddaa36b Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Sat, 7 Feb 2026 03:25:26 +0000 Subject: [PATCH 15/29] Update pp stage model. 
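This switches to one contiguous stage per rank (plain PP), in contrast to the
interleaved VPP layout used elsewhere in the series, where each rank holds
several non-consecutive slices. The two layer assignments, computed with the
same arithmetic as PPStageModel (8 layers, 2 ranks, 2 virtual stages per rank
as an example):

    num_layers, num_pp_stages, num_virtual_stages = 8, 2, 2

    # Plain PP: one contiguous block of layers per rank.
    block = num_layers // num_pp_stages
    pp = {r: list(range(r * block, (r + 1) * block)) for r in range(num_pp_stages)}

    # Interleaved VPP: rank r holds virtual stages r, r + num_pp_stages, ...
    total_stages = num_pp_stages * num_virtual_stages
    per_stage = num_layers // total_stages
    vpp = {r: [list(range(s * per_stage, (s + 1) * per_stage))
               for s in range(r, total_stages, num_pp_stages)]
           for r in range(num_pp_stages)}

    print(pp)    # {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}
    print(vpp)   # {0: [[0, 1], [4, 5]], 1: [[2, 3], [6, 7]]}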
--- run_pynative_pp.py | 105 +++++++++++++++++++-------------------------- 1 file changed, 45 insertions(+), 60 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 98424f015..94bfe80cf 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -14,12 +14,12 @@ # ============================================================================ """Pipeline Parallel Training Script for MindFormers PyNative Mode. -This script implements Virtual Pipeline Parallel (VPP) training following +This script implements Pipeline Parallel (PP) training following the pattern from hyper-parallel/tests/mindspore/st/pipeline_parallel/vpp_schedule.py. Key features: -- Each rank holds multiple non-consecutive virtual stages -- Uses ScheduleInterleaved1F1B for efficient pipeline scheduling +- Each rank holds one pipeline stage +- Uses ScheduleInterleaved1F1B for pipeline scheduling - Supports HSDP + SHARD + PP combination """ @@ -55,8 +55,8 @@ class PPStageModel(nn.Cell): Args: full_model: The full model instance - stage_id: Current virtual stage ID (0-indexed) - total_num_stages: Total number of virtual pipeline stages + stage_id: Current stage ID (0-indexed) + num_pp_stages: Number of pipeline parallel stages num_layers: Total number of transformer layers """ @@ -64,23 +64,24 @@ class PPStageModel(nn.Cell): self, full_model, stage_id: int, - total_num_stages: int, + num_pp_stages: int, num_layers: int, ): super().__init__() self.stage_id = stage_id - self.total_num_stages = total_num_stages + self.num_pp_stages = num_pp_stages self.num_layers = num_layers # Determine if this is first/last stage self.is_first_stage = (stage_id == 0) - self.is_last_stage = (stage_id == total_num_stages - 1) + self.is_last_stage = (stage_id == num_pp_stages - 1) # Calculate which layers this stage holds - layers_per_stage = num_layers // total_num_stages + # For PP: stage 0 gets layers 0 to num_layers//2, stage 1 gets num_layers//2 to num_layers + layers_per_stage = num_layers // num_pp_stages self.start_layer = stage_id * layers_per_stage self.end_layer = (stage_id + 1) * layers_per_stage - if stage_id == total_num_stages - 1: + if stage_id == num_pp_stages - 1: self.end_layer = num_layers # Extract components from full model @@ -107,7 +108,7 @@ class PPStageModel(nn.Cell): self.loss = self.gpt_model.loss logger.info( - f"Stage {stage_id}/{total_num_stages}: layers [{self.start_layer}, {self.end_layer}), " + f"Stage {stage_id}/{num_pp_stages}: layers [{self.start_layer}, {self.end_layer}), " f"first={self.is_first_stage}, last={self.is_last_stage}" ) @@ -194,19 +195,16 @@ class PPStageModel(nn.Cell): def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): """ - Run Virtual Pipeline Parallel (VPP) training. + Run Pipeline Parallel (PP) training. Following the reference vpp_schedule.py pattern: - num_pp_stages physical ranks for pipeline parallel - - Each rank holds num_virtual_stages virtual stages - - Total virtual stages = num_pp_stages * num_virtual_stages - - rank i holds stages: [stage_index, stage_index + num_pp_stages, ...] 
+ - Each rank holds one pipeline stage + - Layers are evenly distributed across stages - Example with num_pp_stages=4, num_virtual_stages=2: - rank 0 holds: stage 0, stage 4 - rank 1 holds: stage 1, stage 5 - rank 2 holds: stage 2, stage 6 - rank 3 holds: stage 3, stage 7 + Example with num_pp_stages=2, num_layers=8: + rank 0 (stage 0): layers 0-3 + rank 1 (stage 1): layers 4-7 Args: config_path: Path to the configuration yaml file @@ -225,26 +223,23 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): # PP config from yaml parallelism = getattr(mf_config, 'parallelism', None) if parallelism is None: - num_pp_stages = 4 # Number of physical PP stages (ranks) - num_virtual_stages = 2 # Each rank holds 2 virtual stages + num_pp_stages = 2 # Number of PP stages micro_batch_num = 4 else: - num_pp_stages = getattr(parallelism, 'pipeline_parallel', 4) - num_virtual_stages = getattr(parallelism, 'virtual_pipeline_parallel', 2) + num_pp_stages = getattr(parallelism, 'pipeline_parallel', 2) micro_batch_num = getattr(mf_config.parallel_config, 'micro_batch_num', 4) - total_num_stages = num_pp_stages * num_virtual_stages device_num_per_stage = device_num // num_pp_stages stage_index = rank_id // device_num_per_stage logger.info( - f"PP Config: num_pp_stages={num_pp_stages}, num_virtual_stages={num_virtual_stages}, " - f"total_stages={total_num_stages}, micro_batch_num={micro_batch_num}, " + f"PP Config: num_pp_stages={num_pp_stages}, " + f"micro_batch_num={micro_batch_num}, " f"rank_id={rank_id}, stage_index={stage_index}" ) # Setup device mesh for data parallel within stage - # Set dp=1 to disable data parallel, only use PP + VPP + # Set dp=1 to disable data parallel, only use PP dp = 1 mp = 1 mesh = init_device_mesh(mesh_shape=(dp, mp), alias_name=("dp", "mp")) @@ -266,41 +261,32 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): output_plan={"output": out_placements}, ) - # Create stage models for this rank - # Each rank holds num_virtual_stages models - # IMPORTANT: Each stage needs its own model instance to avoid shared parameter issues - stage_models = [] - pipeline_stages = [] + # Create stage model for this rank + # Each rank holds one pipeline stage + model = build_network(mf_config.model) + model.set_train(True) + num_layers = model.model.config.num_layers - for v in range(num_virtual_stages): - virtual_stage_id = stage_index + v * num_pp_stages + logger.info(f"Built model with {num_layers} layers") - # Build independent model for each virtual stage - model = build_network(mf_config.model) - model.set_train(True) - num_layers = model.model.config.num_layers - - if v == 0: - logger.info(f"Built model with {num_layers} layers") - - stage_model = PPStageModel( - full_model=model, - stage_id=virtual_stage_id, - total_num_stages=total_num_stages, - num_layers=num_layers, - ) + stage_model = PPStageModel( + full_model=model, + stage_id=stage_index, + num_pp_stages=num_pp_stages, + num_layers=num_layers, + ) - # Always apply sharding (required by PP framework) - shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) + # Apply sharding (required by PP framework) + shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) - # Always apply HSDP (provides zero_grads and gradient sync) - stage_model = hsdp(stage_model, dp, 0, "level1", True) + # Apply HSDP (provides zero_grads and gradient sync) + stage_model = hsdp(stage_model, dp, 0, "level1", True) - stage_models.append(stage_model) - pipeline_stages.append(PipelineStage(stage_model, 
virtual_stage_id, total_num_stages)) + # Create pipeline stage + pipeline_stage = PipelineStage(stage_model, stage_index, num_pp_stages) - # Create interleaved 1F1B schedule with all stages for this rank - schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) + # Create interleaved 1F1B schedule + schedule = ScheduleInterleaved1F1B([pipeline_stage], micro_batch_num) # Get training config seq_length = mf_config.model.model_config.seq_length @@ -366,9 +352,8 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): labels = None loss_mask = None - # Zero gradients for all stage models - for stage_model in stage_models: - stage_model.zero_grads() + # Zero gradients + stage_model.zero_grads() # Run schedule - only first stage provides input if stage_index == 0: -- Gitee From 512c13936c6d7a1719a733efd92432c8d6b68981 Mon Sep 17 00:00:00 2001 From: wangchengzhao Date: Sat, 7 Feb 2026 11:41:04 +0800 Subject: [PATCH 16/29] Revert "Update pp stage model." This reverts commit 195070b305a20c49671f78cafa34cb19eddaa36b. --- run_pynative_pp.py | 105 ++++++++++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 45 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 94bfe80cf..98424f015 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -14,12 +14,12 @@ # ============================================================================ """Pipeline Parallel Training Script for MindFormers PyNative Mode. -This script implements Pipeline Parallel (PP) training following +This script implements Virtual Pipeline Parallel (VPP) training following the pattern from hyper-parallel/tests/mindspore/st/pipeline_parallel/vpp_schedule.py. Key features: -- Each rank holds one pipeline stage -- Uses ScheduleInterleaved1F1B for pipeline scheduling +- Each rank holds multiple non-consecutive virtual stages +- Uses ScheduleInterleaved1F1B for efficient pipeline scheduling - Supports HSDP + SHARD + PP combination """ @@ -55,8 +55,8 @@ class PPStageModel(nn.Cell): Args: full_model: The full model instance - stage_id: Current stage ID (0-indexed) - num_pp_stages: Number of pipeline parallel stages + stage_id: Current virtual stage ID (0-indexed) + total_num_stages: Total number of virtual pipeline stages num_layers: Total number of transformer layers """ @@ -64,24 +64,23 @@ class PPStageModel(nn.Cell): self, full_model, stage_id: int, - num_pp_stages: int, + total_num_stages: int, num_layers: int, ): super().__init__() self.stage_id = stage_id - self.num_pp_stages = num_pp_stages + self.total_num_stages = total_num_stages self.num_layers = num_layers # Determine if this is first/last stage self.is_first_stage = (stage_id == 0) - self.is_last_stage = (stage_id == num_pp_stages - 1) + self.is_last_stage = (stage_id == total_num_stages - 1) # Calculate which layers this stage holds - # For PP: stage 0 gets layers 0 to num_layers//2, stage 1 gets num_layers//2 to num_layers - layers_per_stage = num_layers // num_pp_stages + layers_per_stage = num_layers // total_num_stages self.start_layer = stage_id * layers_per_stage self.end_layer = (stage_id + 1) * layers_per_stage - if stage_id == num_pp_stages - 1: + if stage_id == total_num_stages - 1: self.end_layer = num_layers # Extract components from full model @@ -108,7 +107,7 @@ class PPStageModel(nn.Cell): self.loss = self.gpt_model.loss logger.info( - f"Stage {stage_id}/{num_pp_stages}: layers [{self.start_layer}, {self.end_layer}), " + f"Stage {stage_id}/{total_num_stages}: layers [{self.start_layer}, 
{self.end_layer}), " f"first={self.is_first_stage}, last={self.is_last_stage}" ) @@ -195,16 +194,19 @@ class PPStageModel(nn.Cell): def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): """ - Run Pipeline Parallel (PP) training. + Run Virtual Pipeline Parallel (VPP) training. Following the reference vpp_schedule.py pattern: - num_pp_stages physical ranks for pipeline parallel - - Each rank holds one pipeline stage - - Layers are evenly distributed across stages + - Each rank holds num_virtual_stages virtual stages + - Total virtual stages = num_pp_stages * num_virtual_stages + - rank i holds stages: [stage_index, stage_index + num_pp_stages, ...] - Example with num_pp_stages=2, num_layers=8: - rank 0 (stage 0): layers 0-3 - rank 1 (stage 1): layers 4-7 + Example with num_pp_stages=4, num_virtual_stages=2: + rank 0 holds: stage 0, stage 4 + rank 1 holds: stage 1, stage 5 + rank 2 holds: stage 2, stage 6 + rank 3 holds: stage 3, stage 7 Args: config_path: Path to the configuration yaml file @@ -223,23 +225,26 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): # PP config from yaml parallelism = getattr(mf_config, 'parallelism', None) if parallelism is None: - num_pp_stages = 2 # Number of PP stages + num_pp_stages = 4 # Number of physical PP stages (ranks) + num_virtual_stages = 2 # Each rank holds 2 virtual stages micro_batch_num = 4 else: - num_pp_stages = getattr(parallelism, 'pipeline_parallel', 2) + num_pp_stages = getattr(parallelism, 'pipeline_parallel', 4) + num_virtual_stages = getattr(parallelism, 'virtual_pipeline_parallel', 2) micro_batch_num = getattr(mf_config.parallel_config, 'micro_batch_num', 4) + total_num_stages = num_pp_stages * num_virtual_stages device_num_per_stage = device_num // num_pp_stages stage_index = rank_id // device_num_per_stage logger.info( - f"PP Config: num_pp_stages={num_pp_stages}, " - f"micro_batch_num={micro_batch_num}, " + f"PP Config: num_pp_stages={num_pp_stages}, num_virtual_stages={num_virtual_stages}, " + f"total_stages={total_num_stages}, micro_batch_num={micro_batch_num}, " f"rank_id={rank_id}, stage_index={stage_index}" ) # Setup device mesh for data parallel within stage - # Set dp=1 to disable data parallel, only use PP + # Set dp=1 to disable data parallel, only use PP + VPP dp = 1 mp = 1 mesh = init_device_mesh(mesh_shape=(dp, mp), alias_name=("dp", "mp")) @@ -261,32 +266,41 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): output_plan={"output": out_placements}, ) - # Create stage model for this rank - # Each rank holds one pipeline stage - model = build_network(mf_config.model) - model.set_train(True) - num_layers = model.model.config.num_layers + # Create stage models for this rank + # Each rank holds num_virtual_stages models + # IMPORTANT: Each stage needs its own model instance to avoid shared parameter issues + stage_models = [] + pipeline_stages = [] - logger.info(f"Built model with {num_layers} layers") + for v in range(num_virtual_stages): + virtual_stage_id = stage_index + v * num_pp_stages - stage_model = PPStageModel( - full_model=model, - stage_id=stage_index, - num_pp_stages=num_pp_stages, - num_layers=num_layers, - ) + # Build independent model for each virtual stage + model = build_network(mf_config.model) + model.set_train(True) + num_layers = model.model.config.num_layers + + if v == 0: + logger.info(f"Built model with {num_layers} layers") + + stage_model = PPStageModel( + full_model=model, + stage_id=virtual_stage_id, + total_num_stages=total_num_stages, + num_layers=num_layers, 
+ ) - # Apply sharding (required by PP framework) - shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) + # Always apply sharding (required by PP framework) + shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) - # Apply HSDP (provides zero_grads and gradient sync) - stage_model = hsdp(stage_model, dp, 0, "level1", True) + # Always apply HSDP (provides zero_grads and gradient sync) + stage_model = hsdp(stage_model, dp, 0, "level1", True) - # Create pipeline stage - pipeline_stage = PipelineStage(stage_model, stage_index, num_pp_stages) + stage_models.append(stage_model) + pipeline_stages.append(PipelineStage(stage_model, virtual_stage_id, total_num_stages)) - # Create interleaved 1F1B schedule - schedule = ScheduleInterleaved1F1B([pipeline_stage], micro_batch_num) + # Create interleaved 1F1B schedule with all stages for this rank + schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) # Get training config seq_length = mf_config.model.model_config.seq_length @@ -352,8 +366,9 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): labels = None loss_mask = None - # Zero gradients - stage_model.zero_grads() + # Zero gradients for all stage models + for stage_model in stage_models: + stage_model.zero_grads() # Run schedule - only first stage provides input if stage_index == 0: -- Gitee From 57658f05e488014a0d888744b4b79278dec625fe Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Sat, 7 Feb 2026 04:03:18 +0000 Subject: [PATCH 17/29] Fix bugs --- run_pynative_pp.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 98424f015..73c3b1aab 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -83,28 +83,29 @@ class PPStageModel(nn.Cell): if stage_id == total_num_stages - 1: self.end_layer = num_layers - # Extract components from full model - self.gpt_model = full_model.model + # Extract components from full model using local variable + # to avoid retaining the full model as a registered sub-module + gpt_model = full_model.model # Keep embedding only for first stage if self.is_first_stage: - self.embedding = self.gpt_model.embedding - self.casual_mask = self.gpt_model.casual_mask + self.embedding = gpt_model.embedding + self.casual_mask = gpt_model.casual_mask # Keep rotary embedding (needed for all stages) - if hasattr(self.gpt_model, 'rotary_pos_emb'): - self.rotary_pos_emb = self.gpt_model.rotary_pos_emb + if hasattr(gpt_model, 'rotary_pos_emb'): + self.rotary_pos_emb = gpt_model.rotary_pos_emb # Extract layers for this stage self.layers = nn.CellList() for i in range(self.start_layer, self.end_layer): - self.layers.append(self.gpt_model.decoder.layers[i]) + self.layers.append(gpt_model.decoder.layers[i]) # Keep final layernorm and output layer only for last stage if self.is_last_stage: - self.final_layernorm = self.gpt_model.decoder.final_layernorm - self.output_layer = self.gpt_model.output_layer - self.loss = self.gpt_model.loss + self.final_layernorm = gpt_model.decoder.final_layernorm + self.output_layer = gpt_model.output_layer + self.loss = gpt_model.loss logger.info( f"Stage {stage_id}/{total_num_stages}: layers [{self.start_layer}, {self.end_layer}), " @@ -157,9 +158,14 @@ class PPStageModel(nn.Cell): rotary_pos_emb = self.rotary_pos_emb(seq_len, position_ids=None) # Create causal attention mask for non-first stages - # Shape: [1, 1, seq_len, seq_len] - attention_mask = mint.ones((1, 1, seq_len, seq_len), 
dtype=hidden_states.dtype) - attention_mask = mint.tril(attention_mask) + # Must match CausalMaskGenerate output: 0=visible, 1=masked, uint8 + # Shape: [bs, 1, seq_len, seq_len] + bs = hidden_states.shape[0] + causal_mask = mint.ones((1, seq_len, seq_len), dtype=ms.float16) + causal_mask = mint.tril(causal_mask) + causal_mask = 1.0 - causal_mask + attention_mask = causal_mask.unsqueeze(1).expand(bs, 1, seq_len, seq_len) + attention_mask = attention_mask.astype(ms.uint8) # Process transformer layers for this stage for layer in self.layers: -- Gitee From 7f0fd1b78b5b0a396eb9cf906101654e058f4823 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Sat, 7 Feb 2026 04:06:10 +0000 Subject: [PATCH 18/29] Add log. --- run_pynative_pp.py | 47 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 73c3b1aab..2a2a3c4f4 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -224,9 +224,14 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): rank_id = get_rank() device_num = get_group_size() + logger.info("=" * 60) + logger.info(f"[Rank {rank_id}] Current rank_id: {rank_id}, " + f"total ranks (device_num): {device_num}") + logger.info("=" * 60) + # Load config mf_config = MindFormerConfig(config_path) - logger.info(f"Loaded config from: {config_path}") + logger.info(f"[Rank {rank_id}] Loaded config from: {config_path}") # PP config from yaml parallelism = getattr(mf_config, 'parallelism', None) @@ -243,11 +248,13 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): device_num_per_stage = device_num // num_pp_stages stage_index = rank_id // device_num_per_stage - logger.info( - f"PP Config: num_pp_stages={num_pp_stages}, num_virtual_stages={num_virtual_stages}, " - f"total_stages={total_num_stages}, micro_batch_num={micro_batch_num}, " - f"rank_id={rank_id}, stage_index={stage_index}" - ) + logger.info(f"[Rank {rank_id}] PP Config:") + logger.info(f"[Rank {rank_id}] num_pp_stages = {num_pp_stages}") + logger.info(f"[Rank {rank_id}] num_virtual_stages = {num_virtual_stages}") + logger.info(f"[Rank {rank_id}] total_num_stages = {total_num_stages}") + logger.info(f"[Rank {rank_id}] micro_batch_num = {micro_batch_num}") + logger.info(f"[Rank {rank_id}] device_num_per_stage = {device_num_per_stage}") + logger.info(f"[Rank {rank_id}] stage_index = {stage_index}") # Setup device mesh for data parallel within stage # Set dp=1 to disable data parallel, only use PP + VPP @@ -287,7 +294,17 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): num_layers = model.model.config.num_layers if v == 0: - logger.info(f"Built model with {num_layers} layers") + logger.info(f"[Rank {rank_id}] Model info:") + logger.info(f"[Rank {rank_id}] model type = {type(model).__name__}") + logger.info(f"[Rank {rank_id}] total num_layers = {num_layers}") + if hasattr(model.model, 'config'): + mc = model.model.config + logger.info(f"[Rank {rank_id}] hidden_size = {getattr(mc, 'hidden_size', 'N/A')}") + logger.info(f"[Rank {rank_id}] num_heads = {getattr(mc, 'num_heads', 'N/A')}") + logger.info(f"[Rank {rank_id}] seq_length = {getattr(mc, 'seq_length', 'N/A')}") + logger.info(f"[Rank {rank_id}] vocab_size = {getattr(mc, 'vocab_size', 'N/A')}") + all_layer_names = [name for name, _ in model.model.decoder.layers.name_cells().items()] + logger.info(f"[Rank {rank_id}] all layers = {all_layer_names}") stage_model = PPStageModel( full_model=model, @@ -296,6 +313,12 @@ def run_pp_training(config_path: str = 
'ds_pynative_pp.yaml'): num_layers=num_layers, ) + # Log this virtual stage's layers + stage_layer_names = [type(l).__name__ for l in stage_model.layers] + logger.info(f"[Rank {rank_id}] Virtual stage {virtual_stage_id}: " + f"layers [{stage_model.start_layer}, {stage_model.end_layer}), " + f"types={stage_layer_names}") + # Always apply sharding (required by PP framework) shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) @@ -312,6 +335,14 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): seq_length = mf_config.model.model_config.seq_length local_batch_size = getattr(mf_config.training, 'local_batch_size', 1) global_batch_size = getattr(mf_config.training, 'global_batch_size', 8) + max_steps = getattr(mf_config.training, 'max_steps', 100) + + logger.info(f"[Rank {rank_id}] Training config:") + logger.info(f"[Rank {rank_id}] seq_length = {seq_length}") + logger.info(f"[Rank {rank_id}] local_batch_size = {local_batch_size}") + logger.info(f"[Rank {rank_id}] global_batch_size = {global_batch_size}") + logger.info(f"[Rank {rank_id}] max_steps = {max_steps}") + logger.info(f"[Rank {rank_id}] dp={dp}, mp={mp}") # Create dataset following Trainer pattern train_dataset_config = getattr(mf_config, 'train_dataset', None) @@ -337,8 +368,6 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): x_placements = (Replicate(), Replicate()) # No sharding when dp=1 # Training loop - max_steps = getattr(mf_config.training, 'max_steps', 100) - for step in range(max_steps): # Get batch data from dataset if dataset_iter is not None: -- Gitee From b896c84a073ff11fe6fdf678823e7bcf4765cf5a Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Sat, 7 Feb 2026 06:10:31 +0000 Subject: [PATCH 19/29] Update parameter. --- run_pynative_pp.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 2a2a3c4f4..747f34cbd 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -235,14 +235,11 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): # PP config from yaml parallelism = getattr(mf_config, 'parallelism', None) - if parallelism is None: - num_pp_stages = 4 # Number of physical PP stages (ranks) - num_virtual_stages = 2 # Each rank holds 2 virtual stages - micro_batch_num = 4 - else: - num_pp_stages = getattr(parallelism, 'pipeline_parallel', 4) - num_virtual_stages = getattr(parallelism, 'virtual_pipeline_parallel', 2) - micro_batch_num = getattr(mf_config.parallel_config, 'micro_batch_num', 4) + num_pp_stages = getattr(parallelism, 'pipeline_parallel', 4) + num_virtual_stages = getattr(parallelism, 'virtual_pipeline_parallel', 2) + micro_batch_num = getattr(mf_config.parallel_config, 'micro_batch_num', 4) + dp = getattr(parallelism, 'data_parallel', 1) + mp = getattr(parallelism, 'tensor_parallel', 1) total_num_stages = num_pp_stages * num_virtual_stages device_num_per_stage = device_num // num_pp_stages @@ -257,9 +254,6 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): logger.info(f"[Rank {rank_id}] stage_index = {stage_index}") # Setup device mesh for data parallel within stage - # Set dp=1 to disable data parallel, only use PP + VPP - dp = 1 - mp = 1 mesh = init_device_mesh(mesh_shape=(dp, mp), alias_name=("dp", "mp")) # Define placements based on dp -- Gitee From fede9b508ff3ea1bca4b7eb2e09e5cb1287b7441 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Sat, 7 Feb 2026 06:51:04 +0000 Subject: [PATCH 20/29] Update model. 
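Among other changes, the label and loss-mask preprocessing moves into the
first stage, mirroring GPTModel._preprocess_input_labels_and_masks as noted in
the code. A tiny worked example of what that preprocessing does (the pad and
ignore token ids below are illustrative values, not the model's configuration):

    import numpy as np

    pad_token_id, ignore_token_id = 0, -100      # illustrative values only
    input_ids = np.array([[5, 7, 0, 0]])         # trailing pad tokens
    labels    = np.array([[7, -100, 0, 0]])      # one ignored label position

    loss_mask = (input_ids != pad_token_id).astype(np.float32)
    loss_mask *= (labels != ignore_token_id).astype(np.float32)
    print(loss_mask)                              # [[1. 0. 0. 0.]]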
--- run_pynative_pp.py | 76 ++++++++++++++++++++++++++++++---------------- 1 file changed, 50 insertions(+), 26 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 747f34cbd..e82ce51b2 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -87,10 +87,16 @@ class PPStageModel(nn.Cell): # to avoid retaining the full model as a registered sub-module gpt_model = full_model.model - # Keep embedding only for first stage + # Keep causal mask for all stages (ensures consistent mask generation + # including pad token handling, dynamic seq, and mask compression) + self.casual_mask = gpt_model.casual_mask + + # First stage: keep embedding and preprocessing attributes + self.use_attn_mask_compression = gpt_model.use_attn_mask_compression if self.is_first_stage: self.embedding = gpt_model.embedding - self.casual_mask = gpt_model.casual_mask + self.pad_token_id = gpt_model.pad_token_id + self.ignore_token_id = gpt_model.ignore_token_id # Keep rotary embedding (needed for all stages) if hasattr(gpt_model, 'rotary_pos_emb'): @@ -106,28 +112,35 @@ class PPStageModel(nn.Cell): self.final_layernorm = gpt_model.decoder.final_layernorm self.output_layer = gpt_model.output_layer self.loss = gpt_model.loss + self.share_embeddings_and_output_weights = gpt_model.share_embeddings_and_output_weights + if self.share_embeddings_and_output_weights: + self.shared_embedding_or_output_weight = gpt_model.shared_embedding_or_output_weight logger.info( f"Stage {stage_id}/{total_num_stages}: layers [{self.start_layer}, {self.end_layer}), " f"first={self.is_first_stage}, last={self.is_last_stage}" ) - def construct(self, hidden_states_or_input_ids, labels=None, loss_mask=None): + def construct(self, hidden_states_or_input_ids, labels=None, loss_mask=None, input_ids=None): """ Forward pass for this pipeline stage. 
Pipeline data flow: - - First stage: receives (input_ids, labels, loss_mask), returns (hidden_states, labels, loss_mask) - - Middle stages: receives (hidden_states, labels, loss_mask), returns (hidden_states, labels, loss_mask) - - Last stage: receives (hidden_states, labels, loss_mask), returns loss + - First stage: receives (input_ids, labels, loss_mask), + returns (hidden_states, labels, loss_mask, input_ids) + - Middle stages: receives (hidden_states, labels, loss_mask, input_ids), + returns (hidden_states, labels, loss_mask, input_ids) + - Last stage: receives (hidden_states, labels, loss_mask, input_ids), returns loss Args: hidden_states_or_input_ids: input_ids for first stage, hidden_states for other stages labels: Labels for loss computation (passed through all stages) loss_mask: Loss mask tensor (passed through all stages) + input_ids: Original input token IDs (passed through for HashRoutedMoELayer) Returns: - Tuple of (hidden_states, labels, loss_mask) for non-last stages, or loss for last stage + Tuple of (hidden_states, labels, loss_mask, input_ids) for non-last stages, + or loss for last stage """ aux_loss = None attention_mask = None @@ -138,8 +151,18 @@ class PPStageModel(nn.Cell): input_ids = hidden_states_or_input_ids _, seq_len = input_ids.shape + # Preprocess labels and loss_mask (matches GPTModel._preprocess_input_labels_and_masks) + if loss_mask is None: + loss_mask = mint.not_equal(input_ids, self.pad_token_id).to(ms.float32) + if labels is not None: + label_mask = mint.not_equal(labels, self.ignore_token_id).to(ms.float32) + loss_mask = mint.mul(loss_mask, label_mask) + # Generate attention mask - attention_mask = self.casual_mask(input_ids) + if self.use_attn_mask_compression: + attention_mask = self.casual_mask() + else: + attention_mask = self.casual_mask(input_ids) # Embedding hidden_states = self.embedding(input_ids, position_ids=None) @@ -150,29 +173,26 @@ class PPStageModel(nn.Cell): else: # Non-first stages: input is hidden_states hidden_states = hidden_states_or_input_ids - _, seq_len, _ = hidden_states.shape + bs, seq_len, _ = hidden_states.shape - # Generate attention mask and rotary embedding for non-first stages - # Use seq_len to create proper masks + # Generate attention mask using the same CausalMaskGenerate as first stage + if self.use_attn_mask_compression: + attention_mask = self.casual_mask() + else: + dummy_tokens = mint.ones((bs, seq_len), dtype=ms.int32) + attention_mask = self.casual_mask(dummy_tokens) + + # Generate rotary position embedding if hasattr(self, 'rotary_pos_emb'): rotary_pos_emb = self.rotary_pos_emb(seq_len, position_ids=None) - # Create causal attention mask for non-first stages - # Must match CausalMaskGenerate output: 0=visible, 1=masked, uint8 - # Shape: [bs, 1, seq_len, seq_len] - bs = hidden_states.shape[0] - causal_mask = mint.ones((1, seq_len, seq_len), dtype=ms.float16) - causal_mask = mint.tril(causal_mask) - causal_mask = 1.0 - causal_mask - attention_mask = causal_mask.unsqueeze(1).expand(bs, 1, seq_len, seq_len) - attention_mask = attention_mask.astype(ms.uint8) - # Process transformer layers for this stage for layer in self.layers: hidden_states, _, layer_aux_loss = layer( hidden_states, attention_mask, rotary_pos_emb=rotary_pos_emb, + input_ids=input_ids, ) if layer_aux_loss is not None: aux_loss = layer_aux_loss if aux_loss is None else aux_loss + layer_aux_loss @@ -180,7 +200,10 @@ class PPStageModel(nn.Cell): # Last stage: process output and compute loss if self.is_last_stage: hidden_states = 
self.final_layernorm(hidden_states) - logits, _ = self.output_layer(hidden_states, weight=None) + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits, _ = self.output_layer(hidden_states, output_weight) if logits.ndim > 2: logits = mint.permute(logits, (1, 0, 2)) @@ -195,7 +218,7 @@ class PPStageModel(nn.Cell): return logits # Non-last stages: return tuple to pass to next stage - return hidden_states, labels, loss_mask + return hidden_states, labels, loss_mask, input_ids def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): @@ -266,10 +289,10 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): out_placements = (Replicate(), Replicate()) w_placements = (Replicate(), Replicate()) # Weights always replicated - # Input plan for 3 inputs: (hidden_states_or_input_ids, labels, loss_mask) + # Input plan for 4 inputs: (hidden_states_or_input_ids, labels, loss_mask, input_ids) model_stra = ShardingPlan( plan={"weight": w_placements}, - input_plan={"input": [in_placements, in_placements, in_placements]}, + input_plan={"input": [in_placements, in_placements, in_placements, in_placements]}, output_plan={"output": out_placements}, ) @@ -400,8 +423,9 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): stage_model.zero_grads() # Run schedule - only first stage provides input + # Pass input_ids as 4th arg so it flows through all stages for HashRoutedMoELayer if stage_index == 0: - loss = schedule.run(input_ids, labels, loss_mask) + loss = schedule.run(input_ids, labels, loss_mask, input_ids) else: loss = schedule.run() -- Gitee From 0e7a9dd687f1ca76d6350c732c343a522eb85163 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Sat, 7 Feb 2026 07:38:16 +0000 Subject: [PATCH 21/29] PP refactor. 
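
Move pipeline awareness into GPTModel and TransformerBlock through pre_process/post_process and pp_rank/pp_size/vp_stage/num_vp_stages, so each (virtual) stage only builds its own slice of layers; run_pynative_pp.py shrinks to a thin PPStageAdapter around GPTModel.

Taken on its own, the per-stage layer split implemented in TransformerBlock is the arithmetic below (standalone sketch; `stage_layer_range` is an illustrative name, not a function in the codebase):

```python
def stage_layer_range(num_layers, pp_rank, pp_size, vp_stage=0, num_vp_stages=1):
    # Interleaved (VPP) numbering: virtual stage id = pp_rank + vp_stage * pp_size.
    total_virtual_stages = pp_size * num_vp_stages
    virtual_stage_id = pp_rank + vp_stage * pp_size
    layers_per_stage = num_layers // total_virtual_stages
    offset = virtual_stage_id * layers_per_stage
    if virtual_stage_id == total_virtual_stages - 1:
        return offset, num_layers  # last virtual stage absorbs any remainder
    return offset, offset + layers_per_stage

# 16 layers, pp_size=4, 2 virtual stages per rank:
# rank 0 holds [0, 2) and [8, 10); rank 3 holds [6, 8) and [14, 16).
assert stage_layer_range(16, 0, 4, vp_stage=1, num_vp_stages=2) == (8, 10)
```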
--- .../modeling_deepseek_v3_pynative.py | 19 +- .../pynative/base_models/gpt/gpt_model.py | 57 +++- .../transformers/transformer_block.py | 52 +++- run_pynative_pp.py | 255 +++++------------- 4 files changed, 172 insertions(+), 211 deletions(-) diff --git a/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py b/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py index ebf66c477..1ea056270 100644 --- a/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py +++ b/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py @@ -28,7 +28,16 @@ from .configuration_deepseek_v3 import DeepseekV3Config class DeepseekV3ForCausalLMPyNative(TrainModelMixin, DeepseekV3PreTrainedModel): """DeepseekV3 model for training""" - def __init__(self, config: DeepseekV3Config): + def __init__( + self, + config: DeepseekV3Config, + pre_process: bool = True, + post_process: bool = True, + pp_rank: int = 0, + pp_size: int = 1, + vp_stage: int = None, + num_vp_stages: int = None, + ): super().__init__(config, auto_prefix=False) transformer_config = self.convert_to_transformer_config(config, is_mla_model=True) if transformer_config.num_moe_experts: @@ -43,11 +52,17 @@ class DeepseekV3ForCausalLMPyNative(TrainModelMixin, DeepseekV3PreTrainedModel): transformer_layer_spec=transformer_layer_spec, vocab_size=transformer_config.vocab_size, max_sequence_length=transformer_config.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, position_embedding_type=transformer_config.position_embedding_type, rotary_percent=1.0, rotary_base=transformer_config.rotary_base, rope_scaling=False, - mtp_block_spec=mtp_block_spec + mtp_block_spec=mtp_block_spec, + pp_rank=pp_rank, + pp_size=pp_size, + vp_stage=vp_stage, + num_vp_stages=num_vp_stages, ) def construct( diff --git a/mindformers/pynative/base_models/gpt/gpt_model.py b/mindformers/pynative/base_models/gpt/gpt_model.py index 1b1ac69cd..26a63ef3f 100644 --- a/mindformers/pynative/base_models/gpt/gpt_model.py +++ b/mindformers/pynative/base_models/gpt/gpt_model.py @@ -87,6 +87,10 @@ class GPTModel(nn.Cell): rope_scaling_factor: float = 8.0, seq_len_interpolation_factor: Optional[float] = None, mtp_block_spec: ModuleSpec = None, + pp_rank: int = 0, + pp_size: int = 1, + vp_stage: int = None, + num_vp_stages: int = None, ): super().__init__() @@ -99,6 +103,12 @@ class GPTModel(nn.Cell): self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.use_attn_mask_compression = config.use_attn_mask_compression or config.use_eod_attn_mask_compression + # Pipeline parallelism parameters + self.pp_rank = pp_rank + self.pp_size = pp_size + self.vp_stage = vp_stage + self.num_vp_stages = num_vp_stages + if hasattr(self.config, 'position_embedding_type'): # By default, use the position_embedding_type configuration in TransformerConfig. self.position_embedding_type = self.config.position_embedding_type @@ -190,8 +200,10 @@ class GPTModel(nn.Cell): self.decoder = TransformerBlock( config=self.config, spec=transformer_layer_spec, - # The corresponding Megatron v0.12.0 module's forward pass has this logic disabled by default, - # so it won't cause significant impact. 
+ pp_rank=self.pp_rank, + pp_size=self.pp_size, + vp_stage=self.vp_stage, + num_vp_stages=self.num_vp_stages, ) # Output @@ -250,7 +262,7 @@ class GPTModel(nn.Cell): """ if not self.config.use_eod_reset: position_ids = None - elif position_ids is None: + elif position_ids is None and self.pre_process: raise ValueError("When use eod_reset, position_ids should not be None.") if actual_seq_len is not None: actual_seq_len = self.reshape(actual_seq_len, (-1,)) @@ -259,8 +271,19 @@ class GPTModel(nn.Cell): # which indicates the partial seq_lens of eod sequences for compression mask. # Check mindformers.dataset.blended_datasets.gpt_dataset._get_eod_attention_mask() for implement details. - labels, attention_mask, loss_mask = self._preprocess_input_labels_and_masks( - input_ids, labels, attention_mask, loss_mask) + if self.pre_process: + labels, attention_mask, loss_mask = self._preprocess_input_labels_and_masks( + input_ids, labels, attention_mask, loss_mask) + else: + # Non-first stages: generate attention mask from decoder_input shape + if attention_mask is None: + if self.use_attn_mask_compression: + attention_mask = self.casual_mask() + elif decoder_input is not None: + bs = decoder_input.shape[1] if decoder_input.ndim == 3 else decoder_input.shape[0] + seq_len = decoder_input.shape[0] if decoder_input.ndim == 3 else decoder_input.shape[1] + dummy_tokens = mint.ones((bs, seq_len), dtype=dtype.int32) + attention_mask = self.casual_mask(dummy_tokens) hidden_states, _, aux_loss = self.language_model( input_ids, @@ -324,7 +347,12 @@ class GPTModel(nn.Cell): scenarios such as prefix tuning. The default value is None. actual_seq_len (Tensor, optional): Actual sequence length tensor. Default is None. """ - bs, seq_len = input_ids.shape + # Derive seq_len: from input_ids on first stage, from decoder_input on middle stages + if input_ids is not None: + bs, seq_len = input_ids.shape + else: + seq_len = decoder_input.shape[0] if decoder_input.ndim == 3 else decoder_input.shape[1] + # Encoder embedding if decoder_input is not None: pass @@ -346,7 +374,10 @@ class GPTModel(nn.Cell): f"And attn_mask.ndim should be 4, but got {attn_mask.ndim}") # prefix_key_values shape num_layers*(2, B, prefix_len, kv_num*kv_channel) - bs, seq_len = input_ids.shape + if input_ids is not None: + bs, seq_len = input_ids.shape + else: + bs = decoder_input.shape[1] if decoder_input.ndim == 3 else decoder_input.shape[0] prefix_length = prefix_keys_values[0].shape[2] prefix_mask = self.zeros((bs, 1, seq_len, prefix_length), attn_mask.dtype) # (B, 1, S, S) -> (B, 1, S, S+prefix_len) @@ -401,13 +432,15 @@ class GPTModel(nn.Cell): loss_mask: Tensor = None): """Preprocess input_ids and generate labels and masks if they are None. 
""" - if loss_mask is None: - loss_mask = self.cast(self.not_equal(input_ids, self.pad_token_id), dtype.float32) - label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), dtype.float32) - loss_mask = self.mul(loss_mask, label_mask) + if input_ids is not None: + if loss_mask is None: + loss_mask = self.cast(self.not_equal(input_ids, self.pad_token_id), dtype.float32) + if labels is not None: + label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), dtype.float32) + loss_mask = self.mul(loss_mask, label_mask) if self.use_attn_mask_compression: attention_mask = self.casual_mask() - elif attention_mask is None: + elif attention_mask is None and input_ids is not None: attention_mask = self.casual_mask(input_ids) return labels, attention_mask, loss_mask diff --git a/mindformers/pynative/transformers/transformer_block.py b/mindformers/pynative/transformers/transformer_block.py index 9de09968d..6709ca589 100644 --- a/mindformers/pynative/transformers/transformer_block.py +++ b/mindformers/pynative/transformers/transformer_block.py @@ -96,13 +96,37 @@ class TransformerBlock(nn.Cell): config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec], post_layer_norm: bool = True, + pp_rank: int = 0, + pp_size: int = 1, + vp_stage: int = None, + num_vp_stages: int = None, ): super().__init__() self.config = config self.submodules = _get_block_submodules(config, spec) self.post_layer_norm = post_layer_norm - self.num_layers = config.num_layers + + # Pipeline parallelism parameters + self.pp_rank = pp_rank + self.pp_size = pp_size + self.vp_stage = vp_stage if vp_stage is not None else 0 + self.num_vp_stages = num_vp_stages if num_vp_stages is not None else 1 + + # Compute layer count and offset for this PP stage + if pp_size <= 1: + self.num_layers = config.num_layers + self._layer_offset = 0 + else: + total_virtual_stages = pp_size * self.num_vp_stages + virtual_stage_id = pp_rank + self.vp_stage * pp_size + layers_per_stage = config.num_layers // total_virtual_stages + self._layer_offset = virtual_stage_id * layers_per_stage + # Last virtual stage gets any remaining layers + if virtual_stage_id == total_virtual_stages - 1: + self.num_layers = config.num_layers - self._layer_offset + else: + self.num_layers = layers_per_stage cp = config.context_parallel_size if config.context_parallel_size is not None else 1 if config.sequence_parallel and cp > 1: logger.warning("The context parallel way conflicts with sequence parallel way. " @@ -111,15 +135,29 @@ class TransformerBlock(nn.Cell): self._build_layers(config) + def _is_last_stage(self): + """Check if this is the last virtual pipeline stage.""" + if self.pp_size <= 1: + return True + total_virtual_stages = self.pp_size * self.num_vp_stages + virtual_stage_id = self.pp_rank + self.vp_stage * self.pp_size + return virtual_stage_id == total_virtual_stages - 1 + def _build_layers(self, config: TransformerConfig): - """build transformer layers.""" - # Transformer layers. 
+ """Build transformer layers for this PP stage only.""" + # Transformer layers - build only layers for this stage self.layers = nn.CellList() - for layer_id in range(config.num_layers): - layer = build_module(self.submodules.layer_specs[layer_id], config=config, layer_number=layer_id) + for i in range(self.num_layers): + global_layer_id = self._layer_offset + i + layer = build_module( + self.submodules.layer_specs[global_layer_id], + config=config, + layer_number=global_layer_id, + ) self.layers.append(layer) - if self.post_layer_norm: + # Only build final_layernorm on the last stage + if self.post_layer_norm and self._is_last_stage(): self.final_layernorm = build_module(self.submodules.layer_norm, dim=config.hidden_size, eps=config.layernorm_epsilon, @@ -172,7 +210,7 @@ class TransformerBlock(nn.Cell): aux_loss = layer_aux_loss if aux_loss is None else aux_loss + layer_aux_loss # final layernorm. - if self.post_layer_norm: + if self.final_layernorm is not None: hidden_states = self.final_layernorm(hidden_states) return hidden_states, aux_loss diff --git a/run_pynative_pp.py b/run_pynative_pp.py index e82ce51b2..66c1dc5c9 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -29,7 +29,7 @@ sys.path.append(str(Path(__file__).parent.parent / 'hyper-parallel')) import numpy as np import mindspore as ms -from mindspore import nn, Tensor, mint +from mindspore import nn, Tensor from mindspore.communication.management import init, get_rank, get_group_size from hyper_parallel import PipelineStage, ScheduleInterleaved1F1B @@ -40,185 +40,49 @@ from hyper_parallel.core.placement_types import Shard, Replicate from hyper_parallel.core.shard.sharding_plan import ShardingPlan from mindformers.tools.register import MindFormerConfig -from mindformers.models.build_model import build_network +from mindformers.models.build_config import get_model_config +from mindformers.models.deepseek3.modeling_deepseek_v3_pynative import DeepseekV3ForCausalLMPyNative from mindformers.dataset.build_dataset import build_dataset from mindformers.tools.logger import logger -class PPStageModel(nn.Cell): +class PPStageAdapter(nn.Cell): """ - Pipeline Parallel Stage Model wrapper. - - This class wraps a portion of the full model for a specific pipeline stage. - Following the reference vpp_schedule.py pattern, each stage contains specific - transformer layers based on stage_id. - - Args: - full_model: The full model instance - stage_id: Current virtual stage ID (0-indexed) - total_num_stages: Total number of virtual pipeline stages - num_layers: Total number of transformer layers + Thin adapter that maps pipeline schedule positional args to GPTModel keyword args. 
+ + Pipeline data flow: + - First stage: receives (input_ids, labels, loss_mask, input_ids), + returns (hidden_states, labels, loss_mask, input_ids) + - Middle stages: receives (hidden_states, labels, loss_mask, input_ids), + returns (hidden_states, labels, loss_mask, input_ids) + - Last stage: receives (hidden_states, labels, loss_mask, input_ids), returns loss """ - def __init__( - self, - full_model, - stage_id: int, - total_num_stages: int, - num_layers: int, - ): + def __init__(self, gpt_model, is_first_stage: bool, is_last_stage: bool): super().__init__() - self.stage_id = stage_id - self.total_num_stages = total_num_stages - self.num_layers = num_layers - - # Determine if this is first/last stage - self.is_first_stage = (stage_id == 0) - self.is_last_stage = (stage_id == total_num_stages - 1) - - # Calculate which layers this stage holds - layers_per_stage = num_layers // total_num_stages - self.start_layer = stage_id * layers_per_stage - self.end_layer = (stage_id + 1) * layers_per_stage - if stage_id == total_num_stages - 1: - self.end_layer = num_layers - - # Extract components from full model using local variable - # to avoid retaining the full model as a registered sub-module - gpt_model = full_model.model - - # Keep causal mask for all stages (ensures consistent mask generation - # including pad token handling, dynamic seq, and mask compression) - self.casual_mask = gpt_model.casual_mask - - # First stage: keep embedding and preprocessing attributes - self.use_attn_mask_compression = gpt_model.use_attn_mask_compression - if self.is_first_stage: - self.embedding = gpt_model.embedding - self.pad_token_id = gpt_model.pad_token_id - self.ignore_token_id = gpt_model.ignore_token_id - - # Keep rotary embedding (needed for all stages) - if hasattr(gpt_model, 'rotary_pos_emb'): - self.rotary_pos_emb = gpt_model.rotary_pos_emb - - # Extract layers for this stage - self.layers = nn.CellList() - for i in range(self.start_layer, self.end_layer): - self.layers.append(gpt_model.decoder.layers[i]) - - # Keep final layernorm and output layer only for last stage - if self.is_last_stage: - self.final_layernorm = gpt_model.decoder.final_layernorm - self.output_layer = gpt_model.output_layer - self.loss = gpt_model.loss - self.share_embeddings_and_output_weights = gpt_model.share_embeddings_and_output_weights - if self.share_embeddings_and_output_weights: - self.shared_embedding_or_output_weight = gpt_model.shared_embedding_or_output_weight - - logger.info( - f"Stage {stage_id}/{total_num_stages}: layers [{self.start_layer}, {self.end_layer}), " - f"first={self.is_first_stage}, last={self.is_last_stage}" - ) + self.gpt_model = gpt_model + self.is_first_stage = is_first_stage + self.is_last_stage = is_last_stage def construct(self, hidden_states_or_input_ids, labels=None, loss_mask=None, input_ids=None): - """ - Forward pass for this pipeline stage. 
- - Pipeline data flow: - - First stage: receives (input_ids, labels, loss_mask), - returns (hidden_states, labels, loss_mask, input_ids) - - Middle stages: receives (hidden_states, labels, loss_mask, input_ids), - returns (hidden_states, labels, loss_mask, input_ids) - - Last stage: receives (hidden_states, labels, loss_mask, input_ids), returns loss - - Args: - hidden_states_or_input_ids: input_ids for first stage, hidden_states for other stages - labels: Labels for loss computation (passed through all stages) - loss_mask: Loss mask tensor (passed through all stages) - input_ids: Original input token IDs (passed through for HashRoutedMoELayer) - - Returns: - Tuple of (hidden_states, labels, loss_mask, input_ids) for non-last stages, - or loss for last stage - """ - aux_loss = None - attention_mask = None - rotary_pos_emb = None - - # First stage: process embedding from input_ids if self.is_first_stage: input_ids = hidden_states_or_input_ids - _, seq_len = input_ids.shape - - # Preprocess labels and loss_mask (matches GPTModel._preprocess_input_labels_and_masks) - if loss_mask is None: - loss_mask = mint.not_equal(input_ids, self.pad_token_id).to(ms.float32) - if labels is not None: - label_mask = mint.not_equal(labels, self.ignore_token_id).to(ms.float32) - loss_mask = mint.mul(loss_mask, label_mask) - - # Generate attention mask - if self.use_attn_mask_compression: - attention_mask = self.casual_mask() - else: - attention_mask = self.casual_mask(input_ids) - - # Embedding - hidden_states = self.embedding(input_ids, position_ids=None) - - # Generate rotary position embedding - if hasattr(self, 'rotary_pos_emb'): - rotary_pos_emb = self.rotary_pos_emb(seq_len, position_ids=None) + output = self.gpt_model( + input_ids=input_ids, labels=labels, loss_mask=loss_mask, + ) else: - # Non-first stages: input is hidden_states - hidden_states = hidden_states_or_input_ids - bs, seq_len, _ = hidden_states.shape - - # Generate attention mask using the same CausalMaskGenerate as first stage - if self.use_attn_mask_compression: - attention_mask = self.casual_mask() - else: - dummy_tokens = mint.ones((bs, seq_len), dtype=ms.int32) - attention_mask = self.casual_mask(dummy_tokens) - - # Generate rotary position embedding - if hasattr(self, 'rotary_pos_emb'): - rotary_pos_emb = self.rotary_pos_emb(seq_len, position_ids=None) - - # Process transformer layers for this stage - for layer in self.layers: - hidden_states, _, layer_aux_loss = layer( - hidden_states, - attention_mask, - rotary_pos_emb=rotary_pos_emb, - input_ids=input_ids, + output = self.gpt_model( + input_ids=input_ids, decoder_input=hidden_states_or_input_ids, + labels=labels, loss_mask=loss_mask, ) - if layer_aux_loss is not None: - aux_loss = layer_aux_loss if aux_loss is None else aux_loss + layer_aux_loss - # Last stage: process output and compute loss if self.is_last_stage: - hidden_states = self.final_layernorm(hidden_states) - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - logits, _ = self.output_layer(hidden_states, output_weight) - - if logits.ndim > 2: - logits = mint.permute(logits, (1, 0, 2)) - logits = mint.reshape(logits, (-1, logits.shape[-1])) - logits = logits.astype(ms.float32) - - if labels is not None: - loss = self.loss(logits, labels, loss_mask) - if aux_loss is not None: - loss = loss + aux_loss - return loss - return logits + # post_process=True: GPTModel returns (loss, logits, hidden_states) + loss = output[0] + return loss - # Non-last 
stages: return tuple to pass to next stage - return hidden_states, labels, loss_mask, input_ids + # post_process=False: GPTModel returns hidden_states + return output, labels, loss_mask, input_ids def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): @@ -302,48 +166,59 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): stage_models = [] pipeline_stages = [] + # Build model config once (shared across virtual stages) + model_config = get_model_config(mf_config.model) + num_layers = model_config.num_hidden_layers + for v in range(num_virtual_stages): virtual_stage_id = stage_index + v * num_pp_stages - - # Build independent model for each virtual stage - model = build_network(mf_config.model) + is_first = (virtual_stage_id == 0) + is_last = (virtual_stage_id == total_num_stages - 1) + + # Build model with only the layers needed for this stage + model = DeepseekV3ForCausalLMPyNative( + config=model_config, + pre_process=is_first, + post_process=is_last, + pp_rank=stage_index, + pp_size=num_pp_stages, + vp_stage=v, + num_vp_stages=num_virtual_stages, + ) model.set_train(True) - num_layers = model.model.config.num_layers + # Log model info for first virtual stage if v == 0: logger.info(f"[Rank {rank_id}] Model info:") logger.info(f"[Rank {rank_id}] model type = {type(model).__name__}") logger.info(f"[Rank {rank_id}] total num_layers = {num_layers}") - if hasattr(model.model, 'config'): - mc = model.model.config - logger.info(f"[Rank {rank_id}] hidden_size = {getattr(mc, 'hidden_size', 'N/A')}") - logger.info(f"[Rank {rank_id}] num_heads = {getattr(mc, 'num_heads', 'N/A')}") - logger.info(f"[Rank {rank_id}] seq_length = {getattr(mc, 'seq_length', 'N/A')}") - logger.info(f"[Rank {rank_id}] vocab_size = {getattr(mc, 'vocab_size', 'N/A')}") - all_layer_names = [name for name, _ in model.model.decoder.layers.name_cells().items()] - logger.info(f"[Rank {rank_id}] all layers = {all_layer_names}") - - stage_model = PPStageModel( - full_model=model, - stage_id=virtual_stage_id, - total_num_stages=total_num_stages, - num_layers=num_layers, - ) + mc = model.model.config + logger.info(f"[Rank {rank_id}] hidden_size = {getattr(mc, 'hidden_size', 'N/A')}") + logger.info(f"[Rank {rank_id}] num_heads = {getattr(mc, 'num_heads', 'N/A')}") + logger.info(f"[Rank {rank_id}] seq_length = {getattr(mc, 'seq_length', 'N/A')}") + logger.info(f"[Rank {rank_id}] vocab_size = {getattr(mc, 'vocab_size', 'N/A')}") # Log this virtual stage's layers - stage_layer_names = [type(l).__name__ for l in stage_model.layers] - logger.info(f"[Rank {rank_id}] Virtual stage {virtual_stage_id}: " - f"layers [{stage_model.start_layer}, {stage_model.end_layer}), " - f"types={stage_layer_names}") + stage_num_layers = model.model.decoder.num_layers + layer_offset = model.model.decoder._layer_offset + stage_layer_names = [type(l).__name__ for l in model.model.decoder.layers] + logger.info( + f"[Rank {rank_id}] Virtual stage {virtual_stage_id}: " + f"layers [{layer_offset}, {layer_offset + stage_num_layers}), " + f"first={is_first}, last={is_last}, types={stage_layer_names}" + ) + + # Wrap with PPStageAdapter for pipeline schedule interface + adapter = PPStageAdapter(model.model, is_first, is_last) # Always apply sharding (required by PP framework) - shard_module(stage_model, device_mesh=mesh, sharding_plan=model_stra) + shard_module(adapter, device_mesh=mesh, sharding_plan=model_stra) # Always apply HSDP (provides zero_grads and gradient sync) - stage_model = hsdp(stage_model, dp, 0, "level1", True) + adapter = 
hsdp(adapter, dp, 0, "level1", True) - stage_models.append(stage_model) - pipeline_stages.append(PipelineStage(stage_model, virtual_stage_id, total_num_stages)) + stage_models.append(adapter) + pipeline_stages.append(PipelineStage(adapter, virtual_stage_id, total_num_stages)) # Create interleaved 1F1B schedule with all stages for this rank schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) -- Gitee From 580c5839334815b32915d12e85a2489587bd8fb4 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Mon, 9 Feb 2026 09:42:01 +0000 Subject: [PATCH 22/29] Update pp train. --- run_pynative_pp.py | 105 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 102 insertions(+), 3 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 66c1dc5c9..37be5db99 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -39,11 +39,17 @@ from hyper_parallel import shard_module from hyper_parallel.core.placement_types import Shard, Replicate from hyper_parallel.core.shard.sharding_plan import ShardingPlan +from mindspore import mint +from mindspore.ops import clip_by_global_norm + from mindformers.tools.register import MindFormerConfig from mindformers.models.build_config import get_model_config from mindformers.models.deepseek3.modeling_deepseek_v3_pynative import DeepseekV3ForCausalLMPyNative from mindformers.dataset.build_dataset import build_dataset from mindformers.tools.logger import logger +from mindformers.core import build_lr, build_optim, build_callback +from mindformers.pynative.callback.callback import CallbackHandler +from mindformers.pynative.trainer.train_state import TrainerState class PPStageAdapter(nn.Cell): @@ -220,6 +226,11 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): stage_models.append(adapter) pipeline_stages.append(PipelineStage(adapter, virtual_stage_id, total_num_stages)) + # Collect all trainable params across all virtual stages for this rank + all_trainable_params = [] + for stage_model in stage_models: + all_trainable_params.extend(stage_model.trainable_params()) + # Create interleaved 1F1B schedule with all stages for this rank schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) @@ -259,8 +270,64 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): else: x_placements = (Replicate(), Replicate()) # No sharding when dp=1 + # Build LR scheduler + lr_config = getattr(mf_config, 'lr_scheduler', None) + if lr_config is None: + lr_config = getattr(mf_config, 'lr_schedule', None) + lr_config.total_steps = max_steps + lr = build_lr(lr_config) + + # Build parameter groups (weight decay grouping) + optimizer_config = getattr(mf_config, 'optimizer', None) + weight_decay = getattr(optimizer_config, 'weight_decay', 0.0) + + decay_params = [] + no_decay_params = [] + for param in all_trainable_params: + if len(param.shape) == 1 or param.name.endswith(".bias"): + no_decay_params.append(param) + else: + decay_params.append(param) + + param_groups = [ + {"weight_decay": weight_decay, "params": decay_params}, + {"weight_decay": 0.0, "params": no_decay_params}, + ] + + # Build optimizer + optimizer = build_optim( + optimizer_config, + default_args={"params": param_groups, "learning_rate": lr} + ) + + # Build callbacks + callback_list = [] + callback_config = getattr(mf_config, 'callbacks', []) + for cb_cfg in callback_config: + callback_list.append(build_callback(cb_cfg)) + + callback_handler = CallbackHandler( + callbacks=callback_list, + model=stage_models[0], + train_dataset=dataset, + eval_dataset=None, + 
optimizer=optimizer, + lr_scheduler=lr, + ) + + # Initialize training state + state = TrainerState( + max_steps=max_steps, + save_steps=getattr(mf_config.training, 'save_steps', 2000), + global_batch_size=global_batch_size, + ) + + callback_handler.on_train_begin(mf_config, state) + # Training loop for step in range(max_steps): + callback_handler.on_step_begin(mf_config, state) + # Get batch data from dataset if dataset_iter is not None: try: @@ -304,11 +371,43 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): else: loss = schedule.run() - # Log loss from last stage + # Collect gradients (HSDP hooks store gradients in param.grad) + grads = [] + for param in all_trainable_params: + if param.grad is not None: + grads.append(param.grad) + else: + grads.append(mint.zeros_like(param)) + + # Overflow detection + overflow = False + for grad in grads: + if grad.isinf().any() or grad.isnan().any(): + overflow = True + break + if overflow: + logger.warning(f"Step {step}: gradient overflow detected, skipping update") + state.global_step += 1 + continue + + # Gradient clipping + grads = clip_by_global_norm(grads) + + # Optimizer step + optimizer(grads) + + # Loss logging (loss only available on last PP stage) if stage_index == num_pp_stages - 1: - logger.info(f"Step {step}, Loss: {loss}") + loss_value = loss[-1] if isinstance(loss, list) else loss + else: + loss_value = None + + # Update state and callbacks + state.global_step += 1 + callback_handler.on_step_end(mf_config, state, loss=loss_value, grad_norm=None) - logger.info("PP Training completed!") + callback_handler.on_train_end(mf_config, state) + logger.info(f"[Rank {rank_id}] PP Training completed! Total steps: {state.global_step}") if __name__ == "__main__": -- Gitee From 59b01aec9a089433cf377e5ba58f4e57217a5e42 Mon Sep 17 00:00:00 2001 From: wangchengzhao Date: Mon, 9 Feb 2026 18:14:36 +0800 Subject: [PATCH 23/29] Remove DTensor. 
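
Feed the pipeline schedule plain mindspore.Tensor inputs and leave the shard_module / DTensor.from_local wrapping commented out for now.

After this change, the first-stage inputs in the dummy-data fallback are just host tensors, e.g. (minimal sketch; shapes follow the fallback path in this file):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor

local_batch_size, seq_length = 1, 4096
# Plain Tensor fed straight into schedule.run(); no device-mesh placement wrapper.
input_ids = Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32)
labels = None
loss_mask = None
```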
--- run_pynative_pp.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 37be5db99..9ff9b8e5b 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -218,7 +218,7 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): adapter = PPStageAdapter(model.model, is_first, is_last) # Always apply sharding (required by PP framework) - shard_module(adapter, device_mesh=mesh, sharding_plan=model_stra) + # shard_module(adapter, device_mesh=mesh, sharding_plan=model_stra) # Always apply HSDP (provides zero_grads and gradient sync) adapter = hsdp(adapter, dp, 0, "level1", True) @@ -346,17 +346,15 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): loss_mask = batch[2] # Always convert to DTensor (required by PP framework) - input_ids = DTensor.from_local(input_ids, mesh, x_placements) - labels = DTensor.from_local(labels, mesh, x_placements) - loss_mask = DTensor.from_local( - loss_mask.astype(ms.float32), mesh, x_placements - ) + # input_ids = DTensor.from_local(input_ids, mesh, x_placements) + # labels = DTensor.from_local(labels, mesh, x_placements) + # loss_mask = DTensor.from_local( + # loss_mask.astype(ms.float32), mesh, x_placements + # ) else: # Fallback to dummy data if no dataset - input_ids = DTensor.from_local( - Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32), - mesh, x_placements - ) + input_ids = Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32) + # input_ids = DTensor.from_local(input_ids, mesh, x_placements) labels = None loss_mask = None -- Gitee From 9a6b3132b34d74045278dc2ff4d2cdec35808f81 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Mon, 9 Feb 2026 10:45:31 +0000 Subject: [PATCH 24/29] Update callback. --- run_pynative_pp.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 9ff9b8e5b..0d36c2123 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -47,8 +47,9 @@ from mindformers.models.build_config import get_model_config from mindformers.models.deepseek3.modeling_deepseek_v3_pynative import DeepseekV3ForCausalLMPyNative from mindformers.dataset.build_dataset import build_dataset from mindformers.tools.logger import logger -from mindformers.core import build_lr, build_optim, build_callback +from mindformers.core import build_lr, build_optim from mindformers.pynative.callback.callback import CallbackHandler +from mindformers.pynative.callback.loss_callback import LossCallback from mindformers.pynative.trainer.train_state import TrainerState @@ -300,11 +301,8 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): default_args={"params": param_groups, "learning_rate": lr} ) - # Build callbacks - callback_list = [] - callback_config = getattr(mf_config, 'callbacks', []) - for cb_cfg in callback_config: - callback_list.append(build_callback(cb_cfg)) + # Build callbacks (use pynative-compatible LossCallback instead of MFLossMonitor) + callback_list = [LossCallback()] callback_handler = CallbackHandler( callbacks=callback_list, -- Gitee From be4ae51f17f8aa1fa60c80c641a2735f646c9e07 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Mon, 9 Feb 2026 10:58:34 +0000 Subject: [PATCH 25/29] Fix optimizer. 
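
Gather gradients in the optimizer's own parameter order instead of the model order, because building weight-decay groups reorders the flattened parameter list.

A toy illustration of the mismatch (plain Python; the names are illustrative only):

```python
model_params = ["embedding.weight", "layer0.bias", "layer0.weight", "norm.bias"]
decay = [p for p in model_params if not p.endswith(".bias")]
no_decay = [p for p in model_params if p.endswith(".bias")]
optimizer_order = decay + no_decay      # grouped optimizers flatten params group by group
assert optimizer_order != model_params  # so grads must follow optimizer.parameters
```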
--- run_pynative_pp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 0d36c2123..eaf0ee959 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -367,9 +367,11 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): else: loss = schedule.run() - # Collect gradients (HSDP hooks store gradients in param.grad) + # Collect gradients in optimizer's internal parameter order + # (param_groups reorder params: decay group first, then no_decay group, + # which differs from the original model parameter order) grads = [] - for param in all_trainable_params: + for param in optimizer.parameters: if param.grad is not None: grads.append(param.grad) else: -- Gitee From bda50b00e0376acccbf84c4d03892077bca4d94b Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Mon, 9 Feb 2026 11:01:47 +0000 Subject: [PATCH 26/29] Add grad norm. --- run_pynative_pp.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index eaf0ee959..05a111e90 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -388,6 +388,12 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): state.global_step += 1 continue + # Compute grad norm (before clipping) + grad_norm = 0.0 + for grad in grads: + grad_norm += mint.sum(mint.square(grad)) + grad_norm = mint.sqrt(grad_norm) + # Gradient clipping grads = clip_by_global_norm(grads) @@ -402,7 +408,7 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): # Update state and callbacks state.global_step += 1 - callback_handler.on_step_end(mf_config, state, loss=loss_value, grad_norm=None) + callback_handler.on_step_end(mf_config, state, loss=loss_value, grad_norm=grad_norm) callback_handler.on_train_end(mf_config, state) logger.info(f"[Rank {rank_id}] PP Training completed! Total steps: {state.global_step}") -- Gitee From fc42107f506a8288c79266954233461e9d3bdb7c Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Tue, 10 Feb 2026 13:54:16 +0000 Subject: [PATCH 27/29] Add parameter size print. --- run_pynative_pp.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 05a111e90..947b22044 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -40,7 +40,7 @@ from hyper_parallel.core.placement_types import Shard, Replicate from hyper_parallel.core.shard.sharding_plan import ShardingPlan from mindspore import mint -from mindspore.ops import clip_by_global_norm +from mindspore.ops import clip_by_global_norm, AllReduce from mindformers.tools.register import MindFormerConfig from mindformers.models.build_config import get_model_config @@ -232,6 +232,24 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): for stage_model in stage_models: all_trainable_params.extend(stage_model.trainable_params()) + # Log total parameter count for this rank (across all virtual stages) + rank_total_params = sum(p.size for p in all_trainable_params) + logger.info( + f"[Rank {rank_id}] Total trainable params on this rank: {rank_total_params:,}" + ) + + # Compute total model parameter count across all PP stages via AllReduce + local_params_tensor = Tensor([rank_total_params], dtype=ms.float64) + global_params_tensor = AllReduce()(local_params_tensor) + # Each PP stage has device_num_per_stage replicas holding the same params, + # so divide by device_num_per_stage to get the actual total unique params. 
+ total_model_params = int(global_params_tensor.asnumpy()[0]) // device_num_per_stage + if rank_id == 0: + logger.info( + f"Total model parameters (across all PP stages): {total_model_params:,} " + f"({total_model_params / 1e9:.2f}B)" + ) + # Create interleaved 1F1B schedule with all stages for this rank schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) -- Gitee From 825225c108be4cbd16c04a278af34b5b76db6efc Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Tue, 10 Feb 2026 14:19:32 +0000 Subject: [PATCH 28/29] Add model print. --- run_pynative_pp.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index 947b22044..bc6588747 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -227,6 +227,19 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): stage_models.append(adapter) pipeline_stages.append(PipelineStage(adapter, virtual_stage_id, total_num_stages)) + # Print model structure and parameter details for each virtual stage + for v, stage_model in enumerate(stage_models): + virtual_stage_id = stage_index + v * num_pp_stages + logger.info(f"[Rank {rank_id}] ===== Virtual stage {virtual_stage_id} model structure =====") + logger.info(f"\n{stage_model}") + logger.info(f"[Rank {rank_id}] ===== Virtual stage {virtual_stage_id} parameter details =====") + for param in stage_model.get_parameters(): + logger.info( + f"[Rank {rank_id}] name={param.name}, " + f"shape={param.shape}, " + f"trainable={param.requires_grad}" + ) + # Collect all trainable params across all virtual stages for this rank all_trainable_params = [] for stage_model in stage_models: @@ -239,7 +252,7 @@ def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'): ) # Compute total model parameter count across all PP stages via AllReduce - local_params_tensor = Tensor([rank_total_params], dtype=ms.float64) + local_params_tensor = Tensor([rank_total_params], dtype=ms.float32) global_params_tensor = AllReduce()(local_params_tensor) # Each PP stage has device_num_per_stage replicas holding the same params, # so divide by device_num_per_stage to get the actual total unique params. -- Gitee From 4313ba7180b8224652330b86459fe5185ace2ce3 Mon Sep 17 00:00:00 2001 From: WangChengzhao Date: Tue, 10 Feb 2026 14:23:57 +0000 Subject: [PATCH 29/29] Add input output log. 
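
Log the shapes and dtypes of every tensor entering and leaving PPStageAdapter at debug level, to make it easier to trace what each stage exchanges with the schedule.

The helper added below walks nested tuples/lists recursively; a self-contained sketch of the same idea (logger swapped for print, numpy arrays standing in for tensors):

```python
import numpy as np

def log_tensor(name, t):
    # Recursively print shape/dtype for tensors, None, or nested tuples/lists.
    if t is None:
        print(f"  {name}: None")
    elif isinstance(t, (tuple, list)):
        for i, item in enumerate(t):
            log_tensor(f"{name}[{i}]", item)
    else:
        print(f"  {name}: shape={t.shape}, dtype={t.dtype}")

log_tensor("output", (np.zeros((4096, 1, 512), dtype=np.float16), None))
```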
--- run_pynative_pp.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/run_pynative_pp.py b/run_pynative_pp.py index bc6588747..d56974022 100755 --- a/run_pynative_pp.py +++ b/run_pynative_pp.py @@ -71,7 +71,23 @@ class PPStageAdapter(nn.Cell): self.is_first_stage = is_first_stage self.is_last_stage = is_last_stage + def _log_tensor(self, name, t): + """Log shape and dtype of a tensor for debugging.""" + if t is None: + logger.debug(f" {name}: None") + elif isinstance(t, (tuple, list)): + for i, item in enumerate(t): + self._log_tensor(f"{name}[{i}]", item) + else: + logger.debug(f" {name}: shape={t.shape}, dtype={t.dtype}") + def construct(self, hidden_states_or_input_ids, labels=None, loss_mask=None, input_ids=None): + # Log inputs + logger.debug(f"PPStageAdapter input (first={self.is_first_stage}, last={self.is_last_stage}):") + self._log_tensor("hidden_states_or_input_ids", hidden_states_or_input_ids) + self._log_tensor("labels", labels) + self._log_tensor("loss_mask", loss_mask) + self._log_tensor("input_ids", input_ids) if self.is_first_stage: input_ids = hidden_states_or_input_ids output = self.gpt_model( @@ -83,9 +99,14 @@ class PPStageAdapter(nn.Cell): labels=labels, loss_mask=loss_mask, ) + # Log output + logger.debug(f"PPStageAdapter output (first={self.is_first_stage}, last={self.is_last_stage}):") + self._log_tensor("output", output) + if self.is_last_stage: # post_process=True: GPTModel returns (loss, logits, hidden_states) loss = output[0] + logger.debug(f" loss: shape={loss.shape}, dtype={loss.dtype}") return loss # post_process=False: GPTModel returns hidden_states -- Gitee