diff --git a/docs/run_pynative_pp_guide.md b/docs/run_pynative_pp_guide.md
new file mode 100755
index 0000000000000000000000000000000000000000..2fee229b84c94a843dd47a7bfe6fc7205e63673b
--- /dev/null
+++ b/docs/run_pynative_pp_guide.md
@@ -0,0 +1,230 @@
+# Pipeline Parallel (PP) Training Guide
+
+This document describes how to run Pipeline Parallel training with the `run_pynative_pp.py` script.
+
+## Contents
+
+- [Dependencies](#dependencies)
+- [Configuration](#configuration)
+- [Running](#running)
+- [Parameters](#parameters)
+- [Example Configurations](#example-configurations)
+
+## Dependencies
+
+### Required
+
+```bash
+# MindSpore
+pip install mindspore
+
+# hyper_parallel (pipeline-parallel scheduling library)
+# Make sure hyper_parallel is installed; it usually lives in ../hyper-parallel
+export PYTHONPATH=$PYTHONPATH:/data/hyper-parallel
+
+# MindFormers
+# this repository
+```
+
+### Hardware
+
+- Ascend NPU devices
+- The device count must be divisible by `pipeline_parallel`
+- Recommended: 8 or 16 devices
+
+## Configuration
+
+### YAML configuration file
+
+Add the following to the YAML configuration file (e.g. `ds_pynative_pp.yaml`):
+
+```yaml
+# Parallelism
+parallelism:
+  tensor_parallel: 1            # tensor parallel degree
+  pipeline_parallel: 4          # number of pipeline stages
+  context_parallel: 1           # context parallel degree
+  data_parallel: 2              # data parallel degree (total devices / PP / TP / CP)
+  virtual_pipeline_parallel: 2  # number of virtual stages per rank (optional, VPP only)
+
+# Parallelism (legacy-compatible)
+parallel_config:
+  pipeline_stage: 4             # must match parallelism.pipeline_parallel
+  micro_batch_num: 4            # number of micro batches; drives pipeline efficiency
+
+# Training
+training:
+  max_steps: 2000               # maximum number of training steps
+  global_batch_size: 8          # global batch size
+  local_batch_size: 1           # per-device batch size
+```
+
+### Key parameters
+
+| Parameter | Description | Recommended |
+|------|------|--------|
+| `pipeline_parallel` | number of PP stages | 2, 4, 8 |
+| `micro_batch_num` | number of micro batches | >= pipeline_parallel |
+| `virtual_pipeline_parallel` | number of virtual (VPP) stages | 2 |
+
+### Constraints
+
+1. **Device count**: `total devices = PP × DP × TP × CP`
+2. **Layer count**: `num_hidden_layers` must be divisible by `pipeline_parallel` (by `PP × VPP` when VPP is enabled; see the layer-partitioning sketch under Example Configurations)
+3. **Micro batches**: `micro_batch_num >= pipeline_parallel`, otherwise the pipeline bubble dominates
+
+## Running
+
+### Single node, multiple devices
+
+```bash
+# 8 devices; plain PP vs. VPP is selected by parallelism.virtual_pipeline_parallel
+# in the yaml (1 = plain pipeline, >1 = interleaved VPP)
+msrun --worker_num=8 --local_worker_num=8 \
+    run_pynative_pp.py --config ds_pynative_pp.yaml
+```
+
+### Multiple nodes
+
+```bash
+# master node
+msrun --worker_num=16 --local_worker_num=8 \
+    --master_addr=192.168.1.1 --master_port=8118 \
+    --node_rank=0 \
+    run_pynative_pp.py --config ds_pynative_pp.yaml
+
+# worker node
+msrun --worker_num=16 --local_worker_num=8 \
+    --master_addr=192.168.1.1 --master_port=8118 \
+    --node_rank=1 \
+    run_pynative_pp.py --config ds_pynative_pp.yaml
+```
+
+### Running with mpirun
+
+```bash
+# 8 devices
+mpirun -n 8 python run_pynative_pp.py --config ds_pynative_pp.yaml
+```
+
+## Parameters
+
+### Command-line arguments
+
+| Argument | Type | Default | Description |
+|------|------|--------|------|
+| `--config` | str | `ds_pynative_pp.yaml` | path to the configuration file |
+
+### Schedule selection
+
+The script always drives a `ScheduleInterleaved1F1B` over the virtual stages held by each rank; with a single stage per rank it degenerates to plain pipeline parallelism.
+
+| `virtual_pipeline_parallel` | Behaviour | When to use |
+|------|------|----------|
+| `1` | plain pipeline parallelism, one stage per rank | general use, simplest setup |
+| `> 1` | interleaved VPP schedule | smaller pipeline bubble, better efficiency |
+
+## Example Configurations
+
+### 8-device PP example
+
+```yaml
+# ds_pynative_pp8.yaml
+parallelism:
+  tensor_parallel: 1
+  pipeline_parallel: 4
+  data_parallel: 2
+
+parallel_config:
+  micro_batch_num: 8
+
+model:
+  model_config:
+    num_hidden_layers: 8   # must be divisible by pipeline_parallel
+    hidden_size: 512
+    seq_length: 4096
+
+training:
+  max_steps: 1000
+  global_batch_size: 8
+  local_batch_size: 1
+```
+
+### 16-device VPP example
+
+```yaml
+# ds_pynative_vpp16.yaml
+parallelism:
+  tensor_parallel: 1
+  pipeline_parallel: 4
+  data_parallel: 4
+  virtual_pipeline_parallel: 2
+
+parallel_config:
+  micro_batch_num: 16
+
+model:
+  model_config:
+    num_hidden_layers: 16  # must be divisible by (PP × VPP)
+    hidden_size: 1024
+    seq_length: 4096
+
+training:
+  max_steps: 2000
+  global_batch_size: 16
+  local_batch_size: 1
+```
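+### Layer partitioning example
+
+The sketch below mirrors the offset arithmetic in `TransformerBlock._build_layers`
+(the helper name `stage_layout` is illustrative, not part of the codebase): rank `r`
+holds virtual stages `r, r + PP, ...`, and each virtual stage builds one contiguous
+slice of layers.
+
+```python
+def stage_layout(num_layers: int, pp_size: int, num_vp_stages: int):
+    """Return {pp_rank: [(virtual_stage_id, first_layer, last_layer_exclusive)]}."""
+    total = pp_size * num_vp_stages
+    per_stage = num_layers // total  # the last virtual stage absorbs any remainder
+    plan = {r: [] for r in range(pp_size)}
+    for vp_stage in range(num_vp_stages):
+        for pp_rank in range(pp_size):
+            vsid = pp_rank + vp_stage * pp_size
+            start = vsid * per_stage
+            stop = num_layers if vsid == total - 1 else start + per_stage
+            plan[pp_rank].append((vsid, start, stop))
+    return plan
+
+# 16 layers, PP=4, VPP=2: rank 0 builds stage 0 (layers 0-1) and stage 4 (layers 8-9)
+print(stage_layout(16, 4, 2)[0])  # [(0, 0, 2), (4, 8, 10)]
+```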
+## FAQ
+
+### Q1: Error "Unable to import hyper_parallel"
+
+**Solution**: make sure the hyper_parallel library is on PYTHONPATH
+
+```bash
+export PYTHONPATH=$PYTHONPATH:/data/hyper-parallel
+```
+
+### Q2: The layer count is not divisible by the stage count
+
+**Solution**: adjust `num_hidden_layers` or `pipeline_parallel` so that they divide evenly
+
+### Q3: The pipeline bubble is too large
+
+**Solution**:
+1. Increase `micro_batch_num`
+2. Use the VPP schedule
+3. Decrease `pipeline_parallel`
+
+### Q4: Out of device memory
+
+**Solution**:
+1. Decrease `local_batch_size`
+2. Increase `pipeline_parallel`
+3. Enable recomputation (gradient checkpointing)
+
+## Architecture
+
+### PP layout
+
+```
+Stage 0 (Rank 0-1)    Stage 1 (Rank 2-3)    Stage 2 (Rank 4-5)    Stage 3 (Rank 6-7)
+┌─────────────────┐   ┌─────────────────┐   ┌─────────────────┐   ┌─────────────────┐
+│ Embedding       │   │ Layer 2-3       │   │ Layer 4-5       │   │ Layer 6-7       │
+│ Layer 0-1       │──▶│                 │──▶│                 │──▶│ Output Layer    │
+│                 │   │                 │   │                 │   │ Loss            │
+└─────────────────┘   └─────────────────┘   └─────────────────┘   └─────────────────┘
+```
+
+### VPP layout
+
+```
+Rank 0-1 holds: Stage 0, Stage 4
+Rank 2-3 holds: Stage 1, Stage 5
+Rank 4-5 holds: Stage 2, Stage 6
+Rank 6-7 holds: Stage 3, Stage 7
+
+The interleaved 1F1B schedule reduces pipeline bubbles
+```
diff --git a/ds_pynative.yaml b/ds_pynative.yaml
index 15461272e37878dd30e7265018c60fc5b0a42488..20b44a8f1dbcd7ea5603460a268e5f80c326ba5e 100644
--- a/ds_pynative.yaml
+++ b/ds_pynative.yaml
@@ -6,7 +6,7 @@ lr_scheduler:
   total_steps: -1 # -1 means it will load the total steps of the dataset
 
 checkpoint:
-  save_path: "../ds3_1dense1moe1mtp_concat_pynative"
+  save_path: "checkpoints/ds3_1dense1moe1mtp_concat_pynative"
   save_max: 1
   save_interleaved_steps: 100000 # 500
   no_save_optim: False
@@ -26,7 +26,7 @@ parallelism:
   data_parallel: 1
 
 training:
-  max_steps: 2000
+  max_steps: 20000
   global_batch_size: 1
   local_batch_size: 1
   save_steps: 2000
@@ -35,7 +35,7 @@ training:
 
 seed: 1234
 
-output_dir: './output' # path to save checkpoint and strategy
+output_dir: './output_pynative' # path to save checkpoint and strategy
 load_checkpoint: ''
 load_ckpt_format: 'safetensors' # format of checkpoint files to load
 src_strategy_path_or_dir: ''
@@ -91,7 +91,7 @@ train_dataset: &train_dataset
     type: BlendedMegatronDatasetDataLoader
     datasets_type: "GPTDataset"
     sizes:
-      - 100 # Number of training set data samples.
+      - 20000 # Number of training set data samples.
       - 0 # Number of test set data samples. Currently, configuration is not supported.
       - 0 # Number of eval set data samples. Currently, configuration is not supported.
     config: # GPTDataset Configs
@@ -108,7 +108,8 @@ train_dataset: &train_dataset
       pad: -1 # The token id of `pad` in the dataset.
       data_path: # Megatron dataset sampling ratio and path.
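+        # entries pair a sampling weight ('1') with a dataset file prefix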
         - '1'
-        - "/home/w00932055/dsv4/deepseek-datasets/mmap_deepseekv3_datasets_text_document"
+        - "/home/w00932055/deepseek-datasets/fineweb_edu_10BT_text_document"
   input_columns: ["input_ids", "labels", "loss_mask", "position_ids"]
   construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids"]
   num_parallel_workers: 8
@@ -117,13 +117,14 @@ train_dataset: &train_dataset
   numa_enable: False
   prefetch_size: 1
   seed: 42
+
 train_dataset_task:
   type: CausalLanguageModelDataset
   dataset_config: *train_dataset
 
 # mindspore context init config
 context:
-  mode: 0 # 0--Graph Mode; 1--Pynative Mode
+  mode: 1 # 0--Graph Mode; 1--Pynative Mode
   device_target: "Ascend"
   max_call_depth: 10000
   max_device_memory: "58GB"
@@ -177,13 +178,14 @@ model:
     seq_length: 4096
     hidden_size: 512
     intermediate_size: 3072
-    num_hidden_layers: 8
+    num_hidden_layers: 12
     max_position_embeddings: 163840
     hidden_act: 'silu' # 'fusedswiglu'
-    num_attention_heads: 4
+    num_attention_heads: 12
     rms_norm_eps: 1.e-6
     add_bias_linear: False
     use_flash_attention: True
+    # MLA configuration
     multi_latent_attention: True
     mla_qkv_concat: True
     kv_lora_rank: 512
@@ -198,7 +200,7 @@ model:
     # # DSA (DeepSeek Sparse Attention) Configuration
     # experimental_attention_variant: 'dsa'
     # dsa_indexer_n_heads: 4 # Same as num_attention_heads
-    # dsa_indexer_head_dim: 192 # qk_rope_head_dim + qk_nope_head_dim
+    # dsa_indexer_head_dim: 80 # should be greater than qk_rope_head_dim
     # dsa_indexer_topk: 256
     # dsa_indexer_loss_coeff: 0.001
     # dsa_indexer_use_sparse_loss: False
diff --git a/ds_pynative_pp.yaml b/ds_pynative_pp.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..8a72fe7b594def373bb72cdd5be11a2c0e9a47bd
--- /dev/null
+++ b/ds_pynative_pp.yaml
@@ -0,0 +1,287 @@
+# add for pynative
+# Reusable variable definitions
+vars:
+  pp_stage: &pp_stage 2
+
+lr_scheduler:
+  type: ConstantWarmUpLR
+  learning_rate: 1.e-5
+  warmup_steps: 0
+  total_steps: -1 # -1 means it will load the total steps of the dataset
+
+checkpoint:
+  save_path: "checkpoints/ds3_pp_checkpoint"
+  save_max: 1
+  save_interleaved_steps: 100000 # 500
+  no_save_optim: False
+  async_save: False
+  prefix: "deepseekv3_pp"
+  remove_redundancy: False
+  load_balanced: False
+  no_load_optim: True
+  load_worker_number: 1
+
+parallelism:
+  tensor_parallel: 1
+  pipeline_parallel: *pp_stage # number of PP stages; must evenly divide num_hidden_layers
+  context_parallel: 1
+  data_parallel: 1 # data parallel degree within each stage
+  virtual_pipeline_parallel: 2 # number of virtual stages per rank (VPP only)
+  # HSDP configuration
+  hsdp:
+    shard_size: 2
+    shard_dim: 0
+    level: "level1"
+    enable_hierarchical_allgather: True
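+
+# Device-to-stage mapping used by run_pynative_pp.py:
+#   stage_index = rank_id // (device_num // pipeline_parallel)
+# e.g. 4 devices with pipeline_parallel=2 and data_parallel=2 put ranks 0-1 on
+# stage 0 and ranks 2-3 on stage 1.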
+
+training:
+  max_steps: 20000
+  global_batch_size: 8 # global batch size
+  local_batch_size: 1 # per-device batch size
+  save_steps: 2000
+
+# original
+
+
+seed: 1234
+output_dir: './output_pynative_pp' # path to save checkpoint and strategy
+load_checkpoint: ''
+load_ckpt_format: 'safetensors' # format of checkpoint files to load
+src_strategy_path_or_dir: ''
+auto_trans_ckpt: True # If True, auto transform `load_checkpoint` to load in distributed model
+only_save_strategy: False
+resume_training: False
+use_parallel: True
+run_mode: 'train'
+print_separate_loss: True
+
+# trainer config
+trainer:
+  type: CausalLanguageModelingTrainer
+  model_name: 'deepseekV3'
+
+# runner config
+runner_config:
+  epochs: 1
+  batch_size: 1
+  sink_mode: True
+  sink_size: 1
+
+# optimizer
+optimizer:
+  type: AdamW
+  betas: [0.9, 0.95]
+  eps: 1.e-8
+  weight_decay: 0.0
+
+# Muon optimizer configuration example (uncomment to use):
+# optimizer:
+#   type: Muon
+#   learning_rate: 2.e-2
+#   weight_decay: 0.1
+#   matched_adamw_rms: 0.2
+#   momentum: 0.95
+#   nesterov: True
+#   ns_steps: 5
+#   adamw_betas: [0.95, 0.95]
+#   adamw_eps: 1.e-8
+#   qk_clip_threshold: 100
+
+# lr schedule
+lr_schedule:
+  type: ConstantWarmUpLR
+  learning_rate: 1.e-5
+  warmup_steps: 0
+  total_steps: -1 # -1 means it will load the total steps of the dataset
+
+# Dataset configuration
+train_dataset: &train_dataset
+  data_loader:
+    type: BlendedMegatronDatasetDataLoader
+    datasets_type: "GPTDataset"
+    sizes:
+      - 20000 # Number of training set data samples.
+      - 0 # Number of test set data samples. Currently, configuration is not supported.
+      - 0 # Number of eval set data samples. Currently, configuration is not supported.
+    config: # GPTDataset Configs
+      seed: 42 # Data sampling random seed
+      split: "1, 0, 0" # The usage ratio of training, testing and evaluation sets. Currently, configuration is not supported.
+      seq_length: 4096 # The sequence length of the data set returned.
+      eod_mask_loss: False # Whether to calculate loss at eod.
+      reset_position_ids: False # Whether to reset position_ids at eod.
+      create_attention_mask: False # Whether to return `attention_mask`.
+      reset_attention_mask: False # Whether to reset the `attention_mask` at eod and return a stepped `attention_mask`.
+      create_compressed_eod_mask: False # Whether to return the compressed `attention_mask`.
+      eod_pad_length: 128 # Set the length of the compressed `attention_mask`.
+      eod: 1 # The token id of `eod` in the dataset.
+      pad: -1 # The token id of `pad` in the dataset.
+      data_path: # Megatron dataset sampling ratio and path.
+        - '1'
+        - "/home/w00932055/deepseek-datasets/fineweb_edu_10BT_text_document"
+  input_columns: ["input_ids", "labels", "loss_mask", "position_ids"]
+  construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids"]
+  num_parallel_workers: 8
+  python_multiprocessing: False
+  drop_remainder: True
+  numa_enable: False
+  prefetch_size: 1
+  seed: 42
+
+train_dataset_task:
+  type: CausalLanguageModelDataset
+  dataset_config: *train_dataset
+
+# mindspore context init config
+context:
+  mode: 1 # 0--Graph Mode; 1--Pynative Mode
+  device_target: "Ascend"
+  max_call_depth: 10000
+  max_device_memory: "58GB"
+  save_graphs: False
+  save_graphs_path: "./graph"
+  jit_config:
+    jit_level: "O0"
+
+# parallel config
+parallel_config:
+  data_parallel: 1
+  model_parallel: 1
+  pipeline_stage: *pp_stage # number of PP stages
+  expert_parallel: 1
+  micro_batch_num: 8 # number of micro batches; recommended >= pipeline_stage * 2
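+  # Rough 1F1B bubble fraction: (pp_stage - 1) / (micro_batch_num + pp_stage - 1);
+  # interleaving over virtual stages shrinks it by roughly a factor of
+  # virtual_pipeline_parallel.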
+  use_seq_parallel: False
+  gradient_aggregation_group: 4
+# when model parallel is greater than 1, setting micro_batch_interleave_num=2 may accelerate training.
+micro_batch_interleave_num: 1
+
+# parallel context config
+parallel:
+  parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel
+  gradients_mean: False
+  enable_alltoall: True
+  search_mode: "sharding_propagation"
+  enable_parallel_optimizer: False
+  strategy_ckpt_config:
+    save_file: "./ckpt_strategy_pp.ckpt"
+    only_trainable_params: False
+  parallel_optimizer_config:
+    gradient_accumulation_shard: False
+    parallel_optimizer_threshold: 8
+
+# recompute config
+recompute_config:
+  recompute: True
+  select_recompute: False
+  mp_comm_recompute: False
+
+# model config
+use_legacy: False
+model:
+  model_config:
+    topk_group: 0
+    n_group: 0
+    model_type: deepseek_v3
+    architectures: DeepseekV3ForCausalLM
+    offset: 0
+    vocab_size: 129280
+    seq_length: 4096
+    hidden_size: 512
+    intermediate_size: 3072
+    num_hidden_layers: 12 # must be divisible by pipeline_parallel
+    max_position_embeddings: 163840
+    hidden_act: 'silu' # 'fusedswiglu'
+    num_attention_heads: 12
+    rms_norm_eps: 1.e-6
+    add_bias_linear: False
+    use_flash_attention: True
+    # MLA configuration
+    multi_latent_attention: True
+    mla_qkv_concat: True
+    kv_lora_rank: 512
+    q_lora_rank: 1536
+    qk_rope_head_dim: 64
+    v_head_dim: 192
+    qk_nope_head_dim: 128
+    qk_layernorm: True
+    # QK-clip scaling for Muon optimizer (tracks max attention logit per head)
+    # Enable this when using Muon optimizer with qk_clip_threshold
+    # track_max_attention_logit: True
+    # # DSA (DeepSeek Sparse Attention) Configuration
+    # experimental_attention_variant: 'dsa'
+    # dsa_indexer_n_heads: 4 # Same as num_attention_heads
+    # dsa_indexer_head_dim: 80 # should be greater than qk_rope_head_dim
+    # dsa_indexer_topk: 256
+    # dsa_indexer_loss_coeff: 0.001
+    # dsa_indexer_use_sparse_loss: False
+    # dsa_use_fused_ops: True # Use fused DSA operators (lightning_indexer + sparse_flash_attention)
+    attention_dropout: 0.0
+    hidden_dropout: 0.0
+    params_dtype: "float32"
+    compute_dtype: "bfloat16"
+    layernorm_compute_dtype: "float32"
+    softmax_compute_dtype: "float32"
+    rotary_dtype: "float32"
+    initializer_range: 0.01
+    num_nextn_predict_layers: 0 # 1
+    mtp_loss_scaling_factor: 0.3
+    position_embedding_type: "yarn"
+    scaling_factor: 40
+    beta_fast: 32
+    beta_slow: 1
+    mscale: 1
+    mscale_all_dim: 1
+    rope_theta: 10000
+    # mhc
+    mhc_residual: False
+    mhc_expansion_rate: 4
+    mhc_sk_iter: 20
+    mhc_gating_factor_init: 0.01
+    # moe config
+    hash_routed_layer: 0
+    hash_size: 3 # feasible when hash_router_layer > 0
+    router_dense_type: "float32"
+    gated_linear_unit: True
+    moe_intermediate_size: 2048
+    routed_scaling_factor: 1.5
+    first_k_dense_replace: 1
+    n_routed_experts: 16
+    num_experts_per_tok: 8
+    n_shared_experts: 1
+    num_copy_experts: 0
+    use_topk_router_with_load_balancing: False
+    moe_expected_ffn_experts: 2.0 # Default best value: top-k * FFN/(FFN + COPY)
+    moe_router_bias_update_rate: 0.001
+    moe_shared_expert_intermediate_size: 2048
+    moe_grouped_gemm: True
+    moe_router_load_balancing_type: 'seq_aux_loss'
+    moe_aux_loss_coeff: 0.001
+    scoring_func: 'sigmoid'
+    norm_topk_prob: True
+    moe_token_drop_policy: probs
+    moe_router_enable_expert_bias: True
+
+# callbacks
+callbacks:
+  - type: MFLossMonitor
+  # balance topk bias with callback
+  # - type: TopkBiasBalanceCallback
+  # - type: CheckpointMonitor
+  #   prefix: "deepseekv3"
+  #   save_checkpoint_steps: 100000
+  #   keep_checkpoint_max: 1
+  #   integrated_save: False
+  #   async_save: False
+  #   checkpoint_format: "safetensors" # format of checkpoint files to save
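+  # Note: run_pynative_pp.py builds its own LossCallback through CallbackHandler
+  # and does not read this list; it applies to the standard Trainer entry points.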
+
+# wrapper cell config
+runner_wrapper:
+  type: MFTrainOneStepCell
+  scale_sense: 1.0
+  use_clip_grad: True
+
+profile: False
+profile_start_step: 5
+profile_stop_step: 7
+init_start_profile: False
+profile_communication: False
+profile_memory: True
diff --git a/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py b/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py
index ebf66c4776c6cdd81586549368c11d0164de795f..1ea0562708036c5c4e84d523920f31622d33af16 100644
--- a/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py
+++ b/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py
@@ -28,7 +28,16 @@ from .configuration_deepseek_v3 import DeepseekV3Config
 class DeepseekV3ForCausalLMPyNative(TrainModelMixin, DeepseekV3PreTrainedModel):
     """DeepseekV3 model for training"""
 
-    def __init__(self, config: DeepseekV3Config):
+    def __init__(
+        self,
+        config: DeepseekV3Config,
+        pre_process: bool = True,
+        post_process: bool = True,
+        pp_rank: int = 0,
+        pp_size: int = 1,
+        vp_stage: int = None,
+        num_vp_stages: int = None,
+    ):
         super().__init__(config, auto_prefix=False)
         transformer_config = self.convert_to_transformer_config(config, is_mla_model=True)
         if transformer_config.num_moe_experts:
@@ -43,11 +52,17 @@ class DeepseekV3ForCausalLMPyNative(TrainModelMixin, DeepseekV3PreTrainedModel):
             transformer_layer_spec=transformer_layer_spec,
             vocab_size=transformer_config.vocab_size,
             max_sequence_length=transformer_config.max_position_embeddings,
+            pre_process=pre_process,
+            post_process=post_process,
             position_embedding_type=transformer_config.position_embedding_type,
             rotary_percent=1.0,
             rotary_base=transformer_config.rotary_base,
             rope_scaling=False,
-            mtp_block_spec=mtp_block_spec
+            mtp_block_spec=mtp_block_spec,
+            pp_rank=pp_rank,
+            pp_size=pp_size,
+            vp_stage=vp_stage,
+            num_vp_stages=num_vp_stages,
         )
 
     def construct(
diff --git a/mindformers/pynative/base_models/gpt/gpt_model.py b/mindformers/pynative/base_models/gpt/gpt_model.py
index 1b1ac69cd04d273c4f4e30d3a543bf67a65d25ce..26a63ef3f0e5f20b1d5f6940897132a72c2a47a7 100644
--- a/mindformers/pynative/base_models/gpt/gpt_model.py
+++ b/mindformers/pynative/base_models/gpt/gpt_model.py
@@ -87,6 +87,10 @@ class GPTModel(nn.Cell):
             rope_scaling_factor: float = 8.0,
             seq_len_interpolation_factor: Optional[float] = None,
             mtp_block_spec: ModuleSpec = None,
+            pp_rank: int = 0,
+            pp_size: int = 1,
+            vp_stage: int = None,
+            num_vp_stages: int = None,
     ):
         super().__init__()
 
@@ -99,6 +103,12 @@ class GPTModel(nn.Cell):
         self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
         self.use_attn_mask_compression = config.use_attn_mask_compression or config.use_eod_attn_mask_compression
 
+        # Pipeline parallelism parameters
+        self.pp_rank = pp_rank
+        self.pp_size = pp_size
+        self.vp_stage = vp_stage
+        self.num_vp_stages = num_vp_stages
+
         if hasattr(self.config, 'position_embedding_type'):
             # By default, use the position_embedding_type configuration in TransformerConfig.
             self.position_embedding_type = self.config.position_embedding_type
@@ -190,8 +200,11 @@ class GPTModel(nn.Cell):
         self.decoder = TransformerBlock(
             config=self.config,
             spec=transformer_layer_spec,
-            # The corresponding Megatron v0.12.0 module's forward pass has this logic disabled by default,
-            # so it won't cause significant impact.
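+            # Hand the block its PP placement so it builds only its local layer slice.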
+ pp_rank=self.pp_rank, + pp_size=self.pp_size, + vp_stage=self.vp_stage, + num_vp_stages=self.num_vp_stages, ) # Output @@ -250,7 +262,7 @@ class GPTModel(nn.Cell): """ if not self.config.use_eod_reset: position_ids = None - elif position_ids is None: + elif position_ids is None and self.pre_process: raise ValueError("When use eod_reset, position_ids should not be None.") if actual_seq_len is not None: actual_seq_len = self.reshape(actual_seq_len, (-1,)) @@ -259,8 +271,19 @@ class GPTModel(nn.Cell): # which indicates the partial seq_lens of eod sequences for compression mask. # Check mindformers.dataset.blended_datasets.gpt_dataset._get_eod_attention_mask() for implement details. - labels, attention_mask, loss_mask = self._preprocess_input_labels_and_masks( - input_ids, labels, attention_mask, loss_mask) + if self.pre_process: + labels, attention_mask, loss_mask = self._preprocess_input_labels_and_masks( + input_ids, labels, attention_mask, loss_mask) + else: + # Non-first stages: generate attention mask from decoder_input shape + if attention_mask is None: + if self.use_attn_mask_compression: + attention_mask = self.casual_mask() + elif decoder_input is not None: + bs = decoder_input.shape[1] if decoder_input.ndim == 3 else decoder_input.shape[0] + seq_len = decoder_input.shape[0] if decoder_input.ndim == 3 else decoder_input.shape[1] + dummy_tokens = mint.ones((bs, seq_len), dtype=dtype.int32) + attention_mask = self.casual_mask(dummy_tokens) hidden_states, _, aux_loss = self.language_model( input_ids, @@ -324,7 +347,12 @@ class GPTModel(nn.Cell): scenarios such as prefix tuning. The default value is None. actual_seq_len (Tensor, optional): Actual sequence length tensor. Default is None. """ - bs, seq_len = input_ids.shape + # Derive seq_len: from input_ids on first stage, from decoder_input on middle stages + if input_ids is not None: + bs, seq_len = input_ids.shape + else: + seq_len = decoder_input.shape[0] if decoder_input.ndim == 3 else decoder_input.shape[1] + # Encoder embedding if decoder_input is not None: pass @@ -346,7 +374,10 @@ class GPTModel(nn.Cell): f"And attn_mask.ndim should be 4, but got {attn_mask.ndim}") # prefix_key_values shape num_layers*(2, B, prefix_len, kv_num*kv_channel) - bs, seq_len = input_ids.shape + if input_ids is not None: + bs, seq_len = input_ids.shape + else: + bs = decoder_input.shape[1] if decoder_input.ndim == 3 else decoder_input.shape[0] prefix_length = prefix_keys_values[0].shape[2] prefix_mask = self.zeros((bs, 1, seq_len, prefix_length), attn_mask.dtype) # (B, 1, S, S) -> (B, 1, S, S+prefix_len) @@ -401,13 +432,15 @@ class GPTModel(nn.Cell): loss_mask: Tensor = None): """Preprocess input_ids and generate labels and masks if they are None. 
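+        On non-first pipeline stages input_ids and labels may be None; mask
+        construction is then skipped here and the caller derives the attention
+        mask from decoder_input instead.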
""" - if loss_mask is None: - loss_mask = self.cast(self.not_equal(input_ids, self.pad_token_id), dtype.float32) - label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), dtype.float32) - loss_mask = self.mul(loss_mask, label_mask) + if input_ids is not None: + if loss_mask is None: + loss_mask = self.cast(self.not_equal(input_ids, self.pad_token_id), dtype.float32) + if labels is not None: + label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), dtype.float32) + loss_mask = self.mul(loss_mask, label_mask) if self.use_attn_mask_compression: attention_mask = self.casual_mask() - elif attention_mask is None: + elif attention_mask is None and input_ids is not None: attention_mask = self.casual_mask(input_ids) return labels, attention_mask, loss_mask diff --git a/mindformers/pynative/transformers/transformer_block.py b/mindformers/pynative/transformers/transformer_block.py index 9de09968dc9e0318445d04cee0ef8ac265972fef..6709ca5891a9f6d74651e9c8961aafa2c2a4e3ad 100644 --- a/mindformers/pynative/transformers/transformer_block.py +++ b/mindformers/pynative/transformers/transformer_block.py @@ -96,13 +96,37 @@ class TransformerBlock(nn.Cell): config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec], post_layer_norm: bool = True, + pp_rank: int = 0, + pp_size: int = 1, + vp_stage: int = None, + num_vp_stages: int = None, ): super().__init__() self.config = config self.submodules = _get_block_submodules(config, spec) self.post_layer_norm = post_layer_norm - self.num_layers = config.num_layers + + # Pipeline parallelism parameters + self.pp_rank = pp_rank + self.pp_size = pp_size + self.vp_stage = vp_stage if vp_stage is not None else 0 + self.num_vp_stages = num_vp_stages if num_vp_stages is not None else 1 + + # Compute layer count and offset for this PP stage + if pp_size <= 1: + self.num_layers = config.num_layers + self._layer_offset = 0 + else: + total_virtual_stages = pp_size * self.num_vp_stages + virtual_stage_id = pp_rank + self.vp_stage * pp_size + layers_per_stage = config.num_layers // total_virtual_stages + self._layer_offset = virtual_stage_id * layers_per_stage + # Last virtual stage gets any remaining layers + if virtual_stage_id == total_virtual_stages - 1: + self.num_layers = config.num_layers - self._layer_offset + else: + self.num_layers = layers_per_stage cp = config.context_parallel_size if config.context_parallel_size is not None else 1 if config.sequence_parallel and cp > 1: logger.warning("The context parallel way conflicts with sequence parallel way. " @@ -111,15 +135,29 @@ class TransformerBlock(nn.Cell): self._build_layers(config) + def _is_last_stage(self): + """Check if this is the last virtual pipeline stage.""" + if self.pp_size <= 1: + return True + total_virtual_stages = self.pp_size * self.num_vp_stages + virtual_stage_id = self.pp_rank + self.vp_stage * self.pp_size + return virtual_stage_id == total_virtual_stages - 1 + def _build_layers(self, config: TransformerConfig): - """build transformer layers.""" - # Transformer layers. 
+        """Build transformer layers for this PP stage only."""
+        # Transformer layers - build only layers for this stage
         self.layers = nn.CellList()
-        for layer_id in range(config.num_layers):
-            layer = build_module(self.submodules.layer_specs[layer_id], config=config, layer_number=layer_id)
+        for i in range(self.num_layers):
+            global_layer_id = self._layer_offset + i
+            layer = build_module(
+                self.submodules.layer_specs[global_layer_id],
+                config=config,
+                layer_number=global_layer_id,
+            )
             self.layers.append(layer)
 
-        if self.post_layer_norm:
+        # Only build final_layernorm on the last stage
+        if self.post_layer_norm and self._is_last_stage():
             self.final_layernorm = build_module(self.submodules.layer_norm,
                                                 dim=config.hidden_size,
                                                 eps=config.layernorm_epsilon,
@@ -172,7 +210,7 @@
             aux_loss = layer_aux_loss if aux_loss is None else aux_loss + layer_aux_loss
 
         # final layernorm.
-        if self.post_layer_norm:
+        if self.post_layer_norm and self._is_last_stage():  # final_layernorm exists only on the last stage
             hidden_states = self.final_layernorm(hidden_states)
 
         return hidden_states, aux_loss
diff --git a/run_pynative_pp.py b/run_pynative_pp.py
new file mode 100755
index 0000000000000000000000000000000000000000..d5697402289adc3fef73ebf12a51f3cfff8ff6db
--- /dev/null
+++ b/run_pynative_pp.py
@@ -0,0 +1,479 @@
+# Copyright 2026 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Pipeline Parallel Training Script for MindFormers PyNative Mode.
+
+This script implements Virtual Pipeline Parallel (VPP) training following
+the pattern from hyper-parallel/tests/mindspore/st/pipeline_parallel/vpp_schedule.py.
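+
+Launch example (assuming one msrun worker per device):
+    msrun --worker_num=8 --local_worker_num=8 run_pynative_pp.py --config ds_pynative_pp.yaml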
+ +Key features: +- Each rank holds multiple non-consecutive virtual stages +- Uses ScheduleInterleaved1F1B for efficient pipeline scheduling +- Supports HSDP + SHARD + PP combination +""" + +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).parent.parent / 'hyper-parallel')) + +import numpy as np +import mindspore as ms +from mindspore import nn, Tensor +from mindspore.communication.management import init, get_rank, get_group_size + +from hyper_parallel import PipelineStage, ScheduleInterleaved1F1B +from hyper_parallel import DTensor, init_device_mesh +from hyper_parallel.core.hsdp import hsdp +from hyper_parallel import shard_module +from hyper_parallel.core.placement_types import Shard, Replicate +from hyper_parallel.core.shard.sharding_plan import ShardingPlan + +from mindspore import mint +from mindspore.ops import clip_by_global_norm, AllReduce + +from mindformers.tools.register import MindFormerConfig +from mindformers.models.build_config import get_model_config +from mindformers.models.deepseek3.modeling_deepseek_v3_pynative import DeepseekV3ForCausalLMPyNative +from mindformers.dataset.build_dataset import build_dataset +from mindformers.tools.logger import logger +from mindformers.core import build_lr, build_optim +from mindformers.pynative.callback.callback import CallbackHandler +from mindformers.pynative.callback.loss_callback import LossCallback +from mindformers.pynative.trainer.train_state import TrainerState + + +class PPStageAdapter(nn.Cell): + """ + Thin adapter that maps pipeline schedule positional args to GPTModel keyword args. + + Pipeline data flow: + - First stage: receives (input_ids, labels, loss_mask, input_ids), + returns (hidden_states, labels, loss_mask, input_ids) + - Middle stages: receives (hidden_states, labels, loss_mask, input_ids), + returns (hidden_states, labels, loss_mask, input_ids) + - Last stage: receives (hidden_states, labels, loss_mask, input_ids), returns loss + """ + + def __init__(self, gpt_model, is_first_stage: bool, is_last_stage: bool): + super().__init__() + self.gpt_model = gpt_model + self.is_first_stage = is_first_stage + self.is_last_stage = is_last_stage + + def _log_tensor(self, name, t): + """Log shape and dtype of a tensor for debugging.""" + if t is None: + logger.debug(f" {name}: None") + elif isinstance(t, (tuple, list)): + for i, item in enumerate(t): + self._log_tensor(f"{name}[{i}]", item) + else: + logger.debug(f" {name}: shape={t.shape}, dtype={t.dtype}") + + def construct(self, hidden_states_or_input_ids, labels=None, loss_mask=None, input_ids=None): + # Log inputs + logger.debug(f"PPStageAdapter input (first={self.is_first_stage}, last={self.is_last_stage}):") + self._log_tensor("hidden_states_or_input_ids", hidden_states_or_input_ids) + self._log_tensor("labels", labels) + self._log_tensor("loss_mask", loss_mask) + self._log_tensor("input_ids", input_ids) + if self.is_first_stage: + input_ids = hidden_states_or_input_ids + output = self.gpt_model( + input_ids=input_ids, labels=labels, loss_mask=loss_mask, + ) + else: + output = self.gpt_model( + input_ids=input_ids, decoder_input=hidden_states_or_input_ids, + labels=labels, loss_mask=loss_mask, + ) + + # Log output + logger.debug(f"PPStageAdapter output (first={self.is_first_stage}, last={self.is_last_stage}):") + self._log_tensor("output", output) + + if self.is_last_stage: + # post_process=True: GPTModel returns (loss, logits, hidden_states) + loss = output[0] + logger.debug(f" loss: shape={loss.shape}, dtype={loss.dtype}") + return loss + + 
+        # post_process=False: GPTModel returns hidden_states
+        return output, labels, loss_mask, input_ids
+
+
+def run_pp_training(config_path: str = 'ds_pynative_pp.yaml'):
+    """
+    Run Virtual Pipeline Parallel (VPP) training.
+
+    Following the reference vpp_schedule.py pattern:
+    - num_pp_stages physical ranks for pipeline parallel
+    - Each rank holds num_virtual_stages virtual stages
+    - Total virtual stages = num_pp_stages * num_virtual_stages
+    - rank i holds stages: [stage_index, stage_index + num_pp_stages, ...]
+
+    Example with num_pp_stages=4, num_virtual_stages=2:
+        rank 0 holds: stage 0, stage 4
+        rank 1 holds: stage 1, stage 5
+        rank 2 holds: stage 2, stage 6
+        rank 3 holds: stage 3, stage 7
+
+    Args:
+        config_path: Path to the configuration yaml file
+    """
+    # Initialize
+    ms.set_seed(0)
+    init("hccl")
+
+    rank_id = get_rank()
+    device_num = get_group_size()
+
+    logger.info("=" * 60)
+    logger.info(f"[Rank {rank_id}] Current rank_id: {rank_id}, "
+                f"total ranks (device_num): {device_num}")
+    logger.info("=" * 60)
+
+    # Load config
+    mf_config = MindFormerConfig(config_path)
+    logger.info(f"[Rank {rank_id}] Loaded config from: {config_path}")
+
+    # PP config from yaml
+    parallelism = getattr(mf_config, 'parallelism', None)
+    num_pp_stages = getattr(parallelism, 'pipeline_parallel', 4)
+    num_virtual_stages = getattr(parallelism, 'virtual_pipeline_parallel', 2)
+    micro_batch_num = getattr(mf_config.parallel_config, 'micro_batch_num', 4)
+    dp = getattr(parallelism, 'data_parallel', 1)
+    mp = getattr(parallelism, 'tensor_parallel', 1)
+
+    total_num_stages = num_pp_stages * num_virtual_stages
+    device_num_per_stage = device_num // num_pp_stages
+    stage_index = rank_id // device_num_per_stage
+
+    logger.info(f"[Rank {rank_id}] PP Config:")
+    logger.info(f"[Rank {rank_id}]   num_pp_stages = {num_pp_stages}")
+    logger.info(f"[Rank {rank_id}]   num_virtual_stages = {num_virtual_stages}")
+    logger.info(f"[Rank {rank_id}]   total_num_stages = {total_num_stages}")
+    logger.info(f"[Rank {rank_id}]   micro_batch_num = {micro_batch_num}")
+    logger.info(f"[Rank {rank_id}]   device_num_per_stage = {device_num_per_stage}")
+    logger.info(f"[Rank {rank_id}]   stage_index = {stage_index}")
+
+    # Setup device mesh for data parallel within stage
+    mesh = init_device_mesh(mesh_shape=(dp, mp), alias_name=("dp", "mp"))
+
+    # Define placements based on dp
+    # When dp=1, use Replicate; when dp>1, use Shard for data parallel
+    if dp > 1:
+        in_placements = (Shard(0), Replicate())  # Shard on batch dim
+        out_placements = (Shard(0), Replicate())
+    else:
+        in_placements = (Replicate(), Replicate())  # No sharding
+        out_placements = (Replicate(), Replicate())
+    w_placements = (Replicate(), Replicate())  # Weights always replicated
+
+    # Input plan for 4 inputs: (hidden_states_or_input_ids, labels, loss_mask, input_ids)
+    model_stra = ShardingPlan(
+        plan={"weight": w_placements},
+        input_plan={"input": [in_placements, in_placements, in_placements, in_placements]},
+        output_plan={"output": out_placements},
+    )
+
+    # Create stage models for this rank
+    # Each rank holds num_virtual_stages models
+    # IMPORTANT: Each stage needs its own model instance to avoid shared parameter issues
+    stage_models = []
+    pipeline_stages = []
+
+    # Build model config once (shared across virtual stages)
+    model_config = get_model_config(mf_config.model)
+    num_layers = model_config.num_hidden_layers
+
+    for v in range(num_virtual_stages):
+        virtual_stage_id = stage_index + v * num_pp_stages
+        is_first = (virtual_stage_id == 0)
+        is_last = (virtual_stage_id == total_num_stages - 1)
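+        # e.g. num_pp_stages=2, num_virtual_stages=2: rank 0 builds virtual
+        # stages 0 and 2, rank 1 builds virtual stages 1 and 3.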
+
+        # Build model with only the layers needed for this stage
+        model = DeepseekV3ForCausalLMPyNative(
+            config=model_config,
+            pre_process=is_first,
+            post_process=is_last,
+            pp_rank=stage_index,
+            pp_size=num_pp_stages,
+            vp_stage=v,
+            num_vp_stages=num_virtual_stages,
+        )
+        model.set_train(True)
+
+        # Log model info for first virtual stage
+        if v == 0:
+            logger.info(f"[Rank {rank_id}] Model info:")
+            logger.info(f"[Rank {rank_id}]   model type = {type(model).__name__}")
+            logger.info(f"[Rank {rank_id}]   total num_layers = {num_layers}")
+            mc = model.model.config
+            logger.info(f"[Rank {rank_id}]   hidden_size = {getattr(mc, 'hidden_size', 'N/A')}")
+            logger.info(f"[Rank {rank_id}]   num_heads = {getattr(mc, 'num_heads', 'N/A')}")
+            logger.info(f"[Rank {rank_id}]   seq_length = {getattr(mc, 'seq_length', 'N/A')}")
+            logger.info(f"[Rank {rank_id}]   vocab_size = {getattr(mc, 'vocab_size', 'N/A')}")
+
+        # Log this virtual stage's layers
+        stage_num_layers = model.model.decoder.num_layers
+        layer_offset = model.model.decoder._layer_offset
+        stage_layer_names = [type(l).__name__ for l in model.model.decoder.layers]
+        logger.info(
+            f"[Rank {rank_id}] Virtual stage {virtual_stage_id}: "
+            f"layers [{layer_offset}, {layer_offset + stage_num_layers}), "
+            f"first={is_first}, last={is_last}, types={stage_layer_names}"
+        )
+
+        # Wrap with PPStageAdapter for pipeline schedule interface
+        adapter = PPStageAdapter(model.model, is_first, is_last)
+
+        # Tensor sharding via shard_module is currently disabled: weights stay
+        # replicated, and HSDP below handles gradient synchronization.
+        # shard_module(adapter, device_mesh=mesh, sharding_plan=model_stra)
+
+        # Always apply HSDP (provides zero_grads and gradient sync)
+        adapter = hsdp(adapter, dp, 0, "level1", True)
+
+        stage_models.append(adapter)
+        pipeline_stages.append(PipelineStage(adapter, virtual_stage_id, total_num_stages))
+
+    # Print model structure and parameter details for each virtual stage
+    for v, stage_model in enumerate(stage_models):
+        virtual_stage_id = stage_index + v * num_pp_stages
+        logger.info(f"[Rank {rank_id}] ===== Virtual stage {virtual_stage_id} model structure =====")
+        logger.info(f"\n{stage_model}")
+        logger.info(f"[Rank {rank_id}] ===== Virtual stage {virtual_stage_id} parameter details =====")
+        for param in stage_model.get_parameters():
+            logger.info(
+                f"[Rank {rank_id}]   name={param.name}, "
+                f"shape={param.shape}, "
+                f"trainable={param.requires_grad}"
+            )
+
+    # Collect all trainable params across all virtual stages for this rank
+    all_trainable_params = []
+    for stage_model in stage_models:
+        all_trainable_params.extend(stage_model.trainable_params())
+
+    # Log total parameter count for this rank (across all virtual stages)
+    rank_total_params = sum(p.size for p in all_trainable_params)
+    logger.info(
+        f"[Rank {rank_id}] Total trainable params on this rank: {rank_total_params:,}"
+    )
+
+    # Compute total model parameter count across all PP stages via AllReduce
+    local_params_tensor = Tensor([rank_total_params], dtype=ms.float32)
+    global_params_tensor = AllReduce()(local_params_tensor)
+    # Each PP stage has device_num_per_stage replicas holding the same params,
+    # so divide by device_num_per_stage to get the actual total unique params.
+ total_model_params = int(global_params_tensor.asnumpy()[0]) // device_num_per_stage + if rank_id == 0: + logger.info( + f"Total model parameters (across all PP stages): {total_model_params:,} " + f"({total_model_params / 1e9:.2f}B)" + ) + + # Create interleaved 1F1B schedule with all stages for this rank + schedule = ScheduleInterleaved1F1B(pipeline_stages, micro_batch_num) + + # Get training config + seq_length = mf_config.model.model_config.seq_length + local_batch_size = getattr(mf_config.training, 'local_batch_size', 1) + global_batch_size = getattr(mf_config.training, 'global_batch_size', 8) + max_steps = getattr(mf_config.training, 'max_steps', 100) + + logger.info(f"[Rank {rank_id}] Training config:") + logger.info(f"[Rank {rank_id}] seq_length = {seq_length}") + logger.info(f"[Rank {rank_id}] local_batch_size = {local_batch_size}") + logger.info(f"[Rank {rank_id}] global_batch_size = {global_batch_size}") + logger.info(f"[Rank {rank_id}] max_steps = {max_steps}") + logger.info(f"[Rank {rank_id}] dp={dp}, mp={mp}") + + # Create dataset following Trainer pattern + train_dataset_config = getattr(mf_config, 'train_dataset', None) + if train_dataset_config is not None: + dataset_config = MindFormerConfig() + dataset_config.type = 'CausalLanguageModelDataset' + dataset_config.dataset_config = train_dataset_config + dataset_config.dataset_config.batch_size = global_batch_size // dp + dataset_config.dataset_config.seed = getattr(mf_config, 'seed', 42) + dataset = build_dataset(dataset_config) + dataset_iter = dataset.create_tuple_iterator() + logger.info(f"Created dataset with batch_size={global_batch_size // dp}") + else: + dataset = None + dataset_iter = None + logger.warning("No dataset config found, using dummy data") + + # DTensor placements for input data + # When dp=1, use Replicate; when dp>1, use Shard for data parallel + if dp > 1: + x_placements = (Shard(0), Replicate()) # Shard on batch dim for dp + else: + x_placements = (Replicate(), Replicate()) # No sharding when dp=1 + + # Build LR scheduler + lr_config = getattr(mf_config, 'lr_scheduler', None) + if lr_config is None: + lr_config = getattr(mf_config, 'lr_schedule', None) + lr_config.total_steps = max_steps + lr = build_lr(lr_config) + + # Build parameter groups (weight decay grouping) + optimizer_config = getattr(mf_config, 'optimizer', None) + weight_decay = getattr(optimizer_config, 'weight_decay', 0.0) + + decay_params = [] + no_decay_params = [] + for param in all_trainable_params: + if len(param.shape) == 1 or param.name.endswith(".bias"): + no_decay_params.append(param) + else: + decay_params.append(param) + + param_groups = [ + {"weight_decay": weight_decay, "params": decay_params}, + {"weight_decay": 0.0, "params": no_decay_params}, + ] + + # Build optimizer + optimizer = build_optim( + optimizer_config, + default_args={"params": param_groups, "learning_rate": lr} + ) + + # Build callbacks (use pynative-compatible LossCallback instead of MFLossMonitor) + callback_list = [LossCallback()] + + callback_handler = CallbackHandler( + callbacks=callback_list, + model=stage_models[0], + train_dataset=dataset, + eval_dataset=None, + optimizer=optimizer, + lr_scheduler=lr, + ) + + # Initialize training state + state = TrainerState( + max_steps=max_steps, + save_steps=getattr(mf_config.training, 'save_steps', 2000), + global_batch_size=global_batch_size, + ) + + callback_handler.on_train_begin(mf_config, state) + + # Training loop + for step in range(max_steps): + callback_handler.on_step_begin(mf_config, state) + + # Get 
batch data from dataset + if dataset_iter is not None: + try: + batch = next(dataset_iter) + # Dataset returns: (input_ids, labels, loss_mask, position_ids) + input_ids = batch[0] + labels = batch[1] + loss_mask = batch[2] + # position_ids = batch[3] # Not used currently + except StopIteration: + # Reset iterator when exhausted + dataset_iter = dataset.create_tuple_iterator() + batch = next(dataset_iter) + input_ids = batch[0] + labels = batch[1] + loss_mask = batch[2] + + # Always convert to DTensor (required by PP framework) + # input_ids = DTensor.from_local(input_ids, mesh, x_placements) + # labels = DTensor.from_local(labels, mesh, x_placements) + # loss_mask = DTensor.from_local( + # loss_mask.astype(ms.float32), mesh, x_placements + # ) + else: + # Fallback to dummy data if no dataset + input_ids = Tensor(np.ones((local_batch_size, seq_length)), dtype=ms.int32) + # input_ids = DTensor.from_local(input_ids, mesh, x_placements) + labels = None + loss_mask = None + + # Zero gradients for all stage models + for stage_model in stage_models: + stage_model.zero_grads() + + # Run schedule - only first stage provides input + # Pass input_ids as 4th arg so it flows through all stages for HashRoutedMoELayer + if stage_index == 0: + loss = schedule.run(input_ids, labels, loss_mask, input_ids) + else: + loss = schedule.run() + + # Collect gradients in optimizer's internal parameter order + # (param_groups reorder params: decay group first, then no_decay group, + # which differs from the original model parameter order) + grads = [] + for param in optimizer.parameters: + if param.grad is not None: + grads.append(param.grad) + else: + grads.append(mint.zeros_like(param)) + + # Overflow detection + overflow = False + for grad in grads: + if grad.isinf().any() or grad.isnan().any(): + overflow = True + break + if overflow: + logger.warning(f"Step {step}: gradient overflow detected, skipping update") + state.global_step += 1 + continue + + # Compute grad norm (before clipping) + grad_norm = 0.0 + for grad in grads: + grad_norm += mint.sum(mint.square(grad)) + grad_norm = mint.sqrt(grad_norm) + + # Gradient clipping + grads = clip_by_global_norm(grads) + + # Optimizer step + optimizer(grads) + + # Loss logging (loss only available on last PP stage) + if stage_index == num_pp_stages - 1: + loss_value = loss[-1] if isinstance(loss, list) else loss + else: + loss_value = None + + # Update state and callbacks + state.global_step += 1 + callback_handler.on_step_end(mf_config, state, loss=loss_value, grad_norm=grad_norm) + + callback_handler.on_train_end(mf_config, state) + logger.info(f"[Rank {rank_id}] PP Training completed! Total steps: {state.global_step}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Pipeline Parallel Training") + parser.add_argument( + "--config", type=str, default="ds_pynative_pp.yaml", + help="Path to config yaml file" + ) + args = parser.parse_args() + + run_pp_training(args.config)