From 94152fea525f141eb95095a67cd5a00c11421a51 Mon Sep 17 00:00:00 2001
From: yanglong_unimelb
Date: Tue, 17 Mar 2026 15:41:57 +0800
Subject: [PATCH] [bugfix] 3B hsdp & mHC fix

---
 ds_pynative.yaml                              | 30 ++++----
 .../modeling_deepseek_v3_pynative.py          |  2 +
 .../base_models/gpt/gpt_layer_specs.py        | 16 ++--
 mindformers/pynative/trainer/trainer.py       | 51 +++++++------
 .../transformers/transformer_block.py         |  2 +
 .../transformers/transformer_layer.py         | 75 ++++++++++---------
 6 files changed, 96 insertions(+), 80 deletions(-)

diff --git a/ds_pynative.yaml b/ds_pynative.yaml
index ed512c1ed..c68c7c0ff 100644
--- a/ds_pynative.yaml
+++ b/ds_pynative.yaml
@@ -24,9 +24,13 @@ parallelism:
   pipeline_parallel: 1
   context_parallel: 1
   data_parallel: 1
+  hsdp:
+    shard_size: 4
+    optimizer_level: "level3"
+    threshold: 0
 
 training:
-  max_steps: 20000
+  max_steps: 10000
   global_batch_size: 1
   local_batch_size: 1
   save_steps: 2000
@@ -175,22 +179,22 @@ model:
   offset: 0
   vocab_size: 129280
   seq_length: 4096
-  hidden_size: 512
-  intermediate_size: 3072
-  num_hidden_layers: 8
+  hidden_size: 1280
+  intermediate_size: 896
+  num_hidden_layers: 12
   max_position_embeddings: 163840
   hidden_act: 'silu' # 'fusedswiglu'
-  num_attention_heads: 8
+  num_attention_heads: 16
   rms_norm_eps: 1.e-6
   add_bias_linear: False
   use_flash_attention: True
   multi_latent_attention: True
   mla_qkv_concat: True
   kv_lora_rank: 512
-  q_lora_rank: 1536
+  q_lora_rank: 2048
   qk_rope_head_dim: 64
-  v_head_dim: 192
-  qk_nope_head_dim: 128
+  v_head_dim: 128
+  qk_nope_head_dim: 64
   qk_layernorm: True
   # QK-clip scaling for Muon optimizer (tracks max attention logit per head)
   # Enable this when using Muon optimizer with qk_clip_threshold
@@ -243,17 +247,17 @@ model:
   hash_size: 3 # feasible when hash_router_layer > 0
   router_dense_type: "float32"
   gated_linear_unit: True
-  moe_intermediate_size: 2048
+  moe_intermediate_size: 896
   routed_scaling_factor: 1.5
   first_k_dense_replace: 1
-  n_routed_experts: 16 # FFN + COPY
-  num_experts_per_tok: 8
-  n_shared_experts: 1
+  n_routed_experts: 64 # FFN + COPY
+  num_experts_per_tok: 6
+  n_shared_experts: 2
   num_copy_experts: 0
   use_topk_router_with_load_balancing: False
   moe_expected_ffn_experts: 8.0 # Default best value: top-k * FFN/(FFN + COPY)
   moe_router_bias_update_rate: 0.001
-  moe_shared_expert_intermediate_size: 2048
+  moe_shared_expert_intermediate_size: 896
   moe_grouped_gemm: True
   moe_router_load_balancing_type: 'seq_aux_loss'
   moe_aux_loss_coeff: 0.001 # 0.001
diff --git a/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py b/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py
index ebf66c477..4026b5583 100644
--- a/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py
+++ b/mindformers/models/deepseek3/modeling_deepseek_v3_pynative.py
@@ -21,6 +21,7 @@ from mindformers.pynative.base_models.gpt.gpt_layer_specs import get_gpt_decoder
     get_gpt_layer_local_spec # get_gpt_mtp_block_spec, get_gpt_layer_local_spec
 from mindformers.parallel_core.utils.model_mixin import TrainModelMixin
+from hyper_parallel import hsdp
 
 from .configuration_deepseek_v3 import DeepseekV3Config
 
@@ -49,6 +50,7 @@ class DeepseekV3ForCausalLMPyNative(TrainModelMixin, DeepseekV3PreTrainedModel):
             rope_scaling=False,
             mtp_block_spec=mtp_block_spec
         )
+        hsdp(self.model, shard_size=4, threshold=0, optimizer_level="level3")
 
     def construct(
         self,
diff --git a/mindformers/pynative/base_models/gpt/gpt_layer_specs.py b/mindformers/pynative/base_models/gpt/gpt_layer_specs.py
index 4f83f85f0..50edda3a0 100644
--- a/mindformers/pynative/base_models/gpt/gpt_layer_specs.py
+++ b/mindformers/pynative/base_models/gpt/gpt_layer_specs.py
@@ -153,12 +153,12 @@ def get_gpt_layer_local_spec(
             self_attention=self_attention,
             pre_mlp_layernorm=get_norm_cls(normalization, fused_norm),
             mlp=mlp,
-            # mhc_rmsnorm_mla=get_norm_cls("RMSNorm", fused_norm),
+            mhc_rmsnorm_mla=get_norm_cls("RMSNorm", fused_norm),
             mhc_rmsnorm_moe=get_norm_cls("RMSNorm", fused_norm),
             mhc_reduce_rmsnorm=get_norm_cls("RMSNorm", fused_norm),
-            # mhc_res_mapping_mla=get_mhc_module("res"),
-            # mhc_pre_mapping_mla=get_mhc_module("pre"),
-            # mhc_post_mapping_mla=get_mhc_module("post"),
+            mhc_res_mapping_mla=get_mhc_module("res"),
+            mhc_pre_mapping_mla=get_mhc_module("pre"),
+            mhc_post_mapping_mla=get_mhc_module("post"),
             mhc_res_mapping_moe=get_mhc_module("res"),
             mhc_pre_mapping_moe=get_mhc_module("pre"),
             mhc_post_mapping_moe=get_mhc_module("post"),
@@ -181,12 +181,12 @@ def get_gpt_layer_local_spec(
             ),
             pre_mlp_layernorm=get_norm_cls(normalization, fused_norm),
             mlp=mlp,
-            # mhc_rmsnorm_mla=get_norm_cls("RMSNorm", fused_norm),
+            mhc_rmsnorm_mla=get_norm_cls("RMSNorm", fused_norm),
             mhc_rmsnorm_moe=get_norm_cls("RMSNorm", fused_norm),
             mhc_reduce_rmsnorm=get_norm_cls("RMSNorm", fused_norm),
-            # mhc_res_mapping_mla=get_mhc_module("res"),
-            # mhc_pre_mapping_mla=get_mhc_module("pre"),
-            # mhc_post_mapping_mla=get_mhc_module("post"),
+            mhc_res_mapping_mla=get_mhc_module("res"),
+            mhc_pre_mapping_mla=get_mhc_module("pre"),
+            mhc_post_mapping_mla=get_mhc_module("post"),
             mhc_res_mapping_moe=get_mhc_module("res"),
             mhc_pre_mapping_moe=get_mhc_module("pre"),
             mhc_post_mapping_moe=get_mhc_module("post"),
diff --git a/mindformers/pynative/trainer/trainer.py b/mindformers/pynative/trainer/trainer.py
index 274e968a9..9ff9e7b9a 100644
--- a/mindformers/pynative/trainer/trainer.py
+++ b/mindformers/pynative/trainer/trainer.py
@@ -127,6 +127,12 @@ class Trainer:
         self.model = self._create_model(model, getattr(self.config, "model_config", None))
         self._compute_parameters()
+        # Initialize parallel config and wrappers
+        if self.use_parallel:
+            self.model = self._init_parallel_config(
+                self.model, getattr(self.config, 'parallelism', None)
+            )
+
         # Create datasets
         self.train_dataset = self._create_dataset(
             train_dataset, getattr(self.config, "train_dataset", None)
         )
@@ -222,9 +228,9 @@ class Trainer:
         if model_config is None:
             raise ValueError("Either model instance or model_config must be provided.")
 
-        self.config.model = MindFormerConfig()
-        self.config.model.model_config = self.config.model_config
-        del self.config.model_config
+        src_model_config = model_config
+        model_config = MindFormerConfig()
+        model_config.model_config = src_model_config
 
         logger.info("Building model from config...")
         model = build_model(model_config)
@@ -294,7 +300,8 @@ class Trainer:
             data_parallel * self.config.training.local_batch_size
         )
         logger.info(
-            f"Calculate gradient_accumulation_steps={self.gradient_accumulation_steps}"
+            f"Calculate global_batch_size={self.global_batch_size}, "
+            f"gradient_accumulation_steps={self.gradient_accumulation_steps}."
         )
 
     def _create_dataset(self, dataset, dataset_config: Optional[Dict]) -> Optional[Any]:
@@ -445,12 +452,6 @@ class Trainer:
         if checkpoint_path is None:
             checkpoint_path = getattr(self.config.checkpoint, "load_path", None)
 
-        # Initialize parallel config and wrappers
-        if self.use_parallel:
-            self.model = self._init_parallel_config(
-                self.model, getattr(self.config, 'parallelism', None)
-            )
-
         # Initialize training state
         self.state = TrainerState(
             max_steps=getattr(self.config.training, "max_steps", 1000),
@@ -490,7 +491,7 @@ class Trainer:
             return model
 
         logger.info("Initializing parallel config...")
-        return get_hsdp_model(self.model, parallelism.hsdp)
+        return get_hsdp_model(self.model, parallelism.hsdp, data_parallel=self._get_data_parallel())
 
     def _load_checkpoint(self, checkpoint_path: str):
         """
@@ -586,8 +587,12 @@ class Trainer:
             self.compute_loss,
             None,
             self.optimizer.parameters,
-            has_aux=True
+            has_aux=False
         )
+
+        # TODO: check whether this fallback can be removed
+        if self.compute_loss_func is None:
+            self.model.set_train(True)
 
         # Training loop
         step = self.state.global_step
@@ -718,7 +723,8 @@ class Trainer:
             # Assume first element is loss
             loss = outputs[0]
 
-        return loss, outputs
+        # Return value must be a Tensor or DTensor
+        return loss
 
     def training_step(self, model, inputs: Dict[str, Any]):
         """
@@ -731,16 +737,15 @@ class Trainer:
         Returns:
             Loss value
         """
-        loss_sync_allreduce = AllReduce(ReduceOp.SUM)
-        # Forward and compute loss
-        (loss, _), grads = self.grad_fn(model, inputs)
+        loss, grads = self.grad_fn(model, inputs)
 
         if self.use_parallel:
-            print(f"rank {get_rank()} local loss: {loss}")
+            loss_sync_allreduce = AllReduce(ReduceOp.SUM)
+            print(f"rank {get_rank()} local loss_: {loss}")
             loss = loss_sync_allreduce(loss)
             # loss = all_reduce(loss)
             # TODO: loss / self.dp_size
-            loss = loss / self.world_size
+            loss /= self.world_size
 
         # Check Overflow
         overflow = False # all_finite(grads)
@@ -751,6 +756,7 @@ class Trainer:
                 break
             grad_norm += mint.sum(mint.square(grad))
         grad_norm = mint.sqrt(grad_norm)
+        print(f"local norm_: {grad_norm}")
 
         if overflow:
             raise RuntimeError("train process overflow.")
@@ -761,9 +767,6 @@ class Trainer:
         # function 2: legacy
         # grads, grad_norm = self.clip_grad(grads)
 
-        # Backward pass
-        # In real implementation with MindSpore:
-
         # Optimizer step
         self.optimizer(grads)
 
@@ -773,6 +776,10 @@ class Trainer:
         return self.config.parallelism.data_parallel
 
 
-def get_hsdp_model(model, hsdp_args):
+def get_hsdp_model(model, hsdp_args, data_parallel):
+    shard_size = hsdp_args.get('shard_size', 1)
+    if shard_size == 1:
+        grad_scale = 1 / data_parallel
+        hsdp_args['grad_scale'] = grad_scale
     model = hsdp(model, **hsdp_args)
     return model
diff --git a/mindformers/pynative/transformers/transformer_block.py b/mindformers/pynative/transformers/transformer_block.py
index 9de09968d..2bb3b127f 100644
--- a/mindformers/pynative/transformers/transformer_block.py
+++ b/mindformers/pynative/transformers/transformer_block.py
@@ -12,6 +12,7 @@ from mindformers.parallel_core.transformer_config import TransformerConfig
 from mindformers.parallel_core.utils.spec_utils import ModuleSpec, build_module
 from mindformers.pynative.transformers.transformer_layer import BaseTransformerLayer
 from mindformers.tools.logger import logger
+from hyper_parallel import hsdp
 
 
 @dataclass
@@ -117,6 +118,7 @@ class TransformerBlock(nn.Cell):
         self.layers = nn.CellList()
         for layer_id in range(config.num_layers):
             layer = build_module(self.submodules.layer_specs[layer_id], config=config, layer_number=layer_id)
+            hsdp(layer, shard_size=4, threshold=0, optimizer_level="level3")
             self.layers.append(layer)
 
         if self.post_layer_norm:
diff --git a/mindformers/pynative/transformers/transformer_layer.py b/mindformers/pynative/transformers/transformer_layer.py
index 40bb092ab..fb9229b34 100644
--- a/mindformers/pynative/transformers/transformer_layer.py
+++ b/mindformers/pynative/transformers/transformer_layer.py
@@ -44,12 +44,12 @@ class TransformerLayerSubmodules:
     pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
     mlp: Union[ModuleSpec, type] = IdentityOp
 
-    # mhc_rmsnorm_mla: Union[ModuleSpec, type] = IdentityOp
+    mhc_rmsnorm_mla: Union[ModuleSpec, type] = IdentityOp
     mhc_rmsnorm_moe: Union[ModuleSpec, type] = IdentityOp
     mhc_reduce_rmsnorm: Union[ModuleSpec, type] = IdentityOp
-    # mhc_res_mapping_mla: Union[ModuleSpec, type] = IdentityOp
-    # mhc_pre_mapping_mla: Union[ModuleSpec, type] = IdentityOp
-    # mhc_post_mapping_mla: Union[ModuleSpec, type] = IdentityOp
+    mhc_res_mapping_mla: Union[ModuleSpec, type] = IdentityOp
+    mhc_pre_mapping_mla: Union[ModuleSpec, type] = IdentityOp
+    mhc_post_mapping_mla: Union[ModuleSpec, type] = IdentityOp
     mhc_res_mapping_moe: Union[ModuleSpec, type] = IdentityOp
     mhc_pre_mapping_moe: Union[ModuleSpec, type] = IdentityOp
     mhc_post_mapping_moe: Union[ModuleSpec, type] = IdentityOp
@@ -150,11 +150,11 @@ class TransformerLayer(nn.Cell, BaseTransformerLayer):
 
         if self.apply_manifold_constrained_hyper_connections:
             # RMSNorm dimension should be hidden_size * expansion_rate because input is expanded
-            # self.mhc_rmsnorm_mla = build_module(
-            #     submodules.mhc_rmsnorm_mla,
-            #     dim=config.hidden_size,
-            #     eps=config.layernorm_epsilon
-            # )
+            self.mhc_rmsnorm_mla = build_module(
+                submodules.mhc_rmsnorm_mla,
+                dim=config.hidden_size,
+                eps=config.layernorm_epsilon
+            )
             self.mhc_rmsnorm_moe = build_module(
                 submodules.mhc_rmsnorm_moe,
                 dim=config.hidden_size,
@@ -165,9 +165,9 @@ class TransformerLayer(nn.Cell, BaseTransformerLayer):
                 dim=config.hidden_size,
                 eps=config.layernorm_epsilon
             )
-            # self.mhc_res_mapping_mla = build_module(submodules.mhc_res_mapping_mla, config=self.config)
-            # self.mhc_pre_mapping_mla = build_module(submodules.mhc_pre_mapping_mla, config=self.config)
-            # self.mhc_post_mapping_mla = build_module(submodules.mhc_post_mapping_mla, config=self.config)
+            self.mhc_res_mapping_mla = build_module(submodules.mhc_res_mapping_mla, config=self.config)
+            self.mhc_pre_mapping_mla = build_module(submodules.mhc_pre_mapping_mla, config=self.config)
+            self.mhc_post_mapping_mla = build_module(submodules.mhc_post_mapping_mla, config=self.config)
             self.mhc_res_mapping_moe = build_module(submodules.mhc_res_mapping_moe, config=self.config)
             self.mhc_pre_mapping_moe = build_module(submodules.mhc_pre_mapping_moe, config=self.config)
             self.mhc_post_mapping_moe = build_module(submodules.mhc_post_mapping_moe, config=self.config)
@@ -218,22 +218,23 @@ class TransformerLayer(nn.Cell, BaseTransformerLayer):
         def expand_hidden_states(hidden_states):
             hidden_states = hidden_states.unsqueeze(axis=2)
             return hidden_states.repeat(1, 1, self.mhc_expansion_rate, 1).astype(float32)
-        print(f"-------------- Transformer Layer {self.layer_number} ----------------")
-        print(f"@ Original hidden_states AMax: {ops.abs(hidden_states).max()} {ops.abs(hidden_states).mean()}")
 
         # Before layernorm
-        # if self.apply_manifold_constrained_hyper_connections:
-        #     mhc_hidden_states = expand_hidden_states(hidden_states).astype(float32) # (S, B, H) -> (S, B, N, H)
-        #     mhc_hidden_states = self.mhc_rmsnorm_mla(mhc_hidden_states)
+        if self.apply_manifold_constrained_hyper_connections:
+            if self.layer_number == 0:
+                mhc_hidden_states = expand_hidden_states(hidden_states) # (S, B, H) -> (S, B, N, H)
+            else:
+                mhc_hidden_states = hidden_states.astype(float32) # (S, B, N, H)
+            mhc_hidden_states = self.mhc_rmsnorm_mla(mhc_hidden_states)
 
-        #     mhc_res = self.mhc_res_mapping_mla(mhc_hidden_states) # (S, B, N, N)
-        #     mhc_pre = self.mhc_pre_mapping_mla(mhc_hidden_states) # (S, B, 1, N)
-        #     mhc_post = self.mhc_post_mapping_mla(mhc_hidden_states) # (S, B, 1, N)
+            mhc_res = self.mhc_res_mapping_mla(mhc_hidden_states) # (S, B, N, N)
+            mhc_pre = self.mhc_pre_mapping_mla(mhc_hidden_states) # (S, B, 1, N)
+            mhc_post = self.mhc_post_mapping_mla(mhc_hidden_states) # (S, B, 1, N)
 
-        #     mhc_res_hidden = self.bmm(mhc_res, mhc_hidden_states) # (S, B, N, N) @ (S, B, N, H) -> (S, B, N, H)
-        #     mhc_pre_hidden = self.bmm(mhc_pre, mhc_hidden_states) # (S, B, 1, N) @ (S, B, N, H) -> (S, B, 1, H)
-        #     mhc_post = mhc_post.squeeze(axis=2).unsqueeze(-1) # (S, B, 1, N) -> (S, B, N) -> (S, B, N, 1)
-        #     hidden_states = mhc_pre_hidden.squeeze(axis=2) # (S, B, 1, H) -> (S, B, H)
+            mhc_res_hidden = self.bmm(mhc_res, mhc_hidden_states) # (S, B, N, N) @ (S, B, N, H) -> (S, B, N, H)
+            mhc_pre_hidden = self.bmm(mhc_pre, mhc_hidden_states) # (S, B, 1, N) @ (S, B, N, H) -> (S, B, 1, H)
+            mhc_post = mhc_post.squeeze(axis=2).unsqueeze(-1) # (S, B, 1, N) -> (S, B, N) -> (S, B, N, 1)
+            hidden_states = mhc_pre_hidden.squeeze(axis=2) # (S, B, 1, H) -> (S, B, H)
 
         # Layer norm at the beginning
         input_layernorm_output = self.input_layernorm(hidden_states)
@@ -241,8 +242,8 @@ class TransformerLayer(nn.Cell, BaseTransformerLayer):
         # Residual connection
         if self.apply_residual_connection_post_norm:
             residual = input_layernorm_output
-        # elif self.apply_manifold_constrained_hyper_connections:
-        #     residual = mhc_res_hidden
+        elif self.apply_manifold_constrained_hyper_connections:
+            residual = mhc_res_hidden
         else:
             residual = hidden_states
 
@@ -268,20 +269,17 @@ class TransformerLayer(nn.Cell, BaseTransformerLayer):
         dropout_output = self.hidden_states_dropout(attention_output)
 
         # # Add residual
-        # if self.apply_manifold_constrained_hyper_connections:
-        #     dropout_expand = dropout_output.unsqueeze(axis=2) # (S, B, H) -> (S, B, 1, H)
-        #     mhc_post_hidden = self.bmm(mhc_post, dropout_expand) # (S, B, N, 1) @ (S, B, 1, H) -> (S, B, N, H)
-        #     norm_input = self.add(residual, mhc_post_hidden)
-        #     norm_input = norm_input.sum(dim=2) # (S, B, N, H) -> (S, B, H)
-        # else:
-        #     norm_input = self.add(residual, dropout_output)
-
-        norm_input = self.add(residual, dropout_output)
+        if self.apply_manifold_constrained_hyper_connections:
+            dropout_expand = dropout_output.unsqueeze(axis=2) # (S, B, H) -> (S, B, 1, H)
+            mhc_post_hidden = self.bmm(mhc_post, dropout_expand) # (S, B, N, 1) @ (S, B, 1, H) -> (S, B, N, H)
+            norm_input = self.add(residual, mhc_post_hidden)
+        else:
+            norm_input = self.add(residual, dropout_output)
 
         # Before layernorm
         if self.apply_manifold_constrained_hyper_connections:
-            mhc_hidden_states = expand_hidden_states(norm_input)
+            mhc_hidden_states = norm_input
             mhc_hidden_states = self.mhc_rmsnorm_moe(mhc_hidden_states)
 
             mhc_res = self.mhc_res_mapping_moe(mhc_hidden_states) # (S, B, N, N)
@@ -326,7 +324,10 @@ class TransformerLayer(nn.Cell, BaseTransformerLayer):
             dropout_expand = dropout_output.unsqueeze(axis=2) # (S, B, H) -> (S, B, 1, H)
             mhc_post_hidden = mhc_post @ dropout_expand # (S, B, N, 1) @ (S, B, 1, H) -> (S, B, N, H)
             output = self.add(residual, mhc_post_hidden)
-            output = self.mhc_reduce_rmsnorm(output.sum(dim=2)) # (S, B, N, H) -> (S, B, H)
+            if self.layer_number == self.config.num_layers - 1:
+                output = self.mhc_reduce_rmsnorm(output.sum(dim=2)) # (S, B, N, H) -> (S, B, H)
+            else:
+                output = self.mhc_reduce_rmsnorm(output)
         else:
             output = self.add(residual, dropout_output)
         # Note: context parameter is returned for API compatibility but currently unused.
--
Gitee
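
A minimal, self-contained sketch (not part of the patch) of the mHC (manifold-constrained hyper-connections) data flow that the transformer_layer.py changes re-enable. NumPy stands in for MindSpore, the random res/pre/post matrices are hypothetical placeholders for the repo's mhc_*_mapping modules, and the RMSNorm steps are omitted; only the shape bookkeeping from the in-line comments is reproduced.

# Illustrative NumPy sketch of the mHC mixing path. Shapes follow the comments in
# transformer_layer.py: S = seq_len, B = batch, N = mhc_expansion_rate, H = hidden_size.
# The res/pre/post tensors below are random placeholders, not the trained mappings.
import numpy as np

S, B, N, H = 4, 2, 4, 8

def expand_hidden_states(h):                 # layer 0 only: (S, B, H) -> (S, B, N, H)
    return np.repeat(h[:, :, None, :], N, axis=2)

def sublayer(x):                             # stand-in for the attention / MoE sublayer
    return 0.1 * x

hidden = np.random.randn(S, B, H).astype(np.float32)
mhc_hidden = expand_hidden_states(hidden)    # later layers already carry the N axis

mhc_res = np.random.randn(S, B, N, N).astype(np.float32)   # residual mixing   (S, B, N, N)
mhc_pre = np.random.randn(S, B, 1, N).astype(np.float32)   # pre-mixing        (S, B, 1, N)
mhc_post = np.random.randn(S, B, 1, N).astype(np.float32)  # post-distribution (S, B, 1, N)

residual = mhc_res @ mhc_hidden              # (S, B, N, N) @ (S, B, N, H) -> (S, B, N, H)
x = (mhc_pre @ mhc_hidden).squeeze(axis=2)   # (S, B, 1, H) -> (S, B, H), fed to the sublayer

out = sublayer(x)                            # (S, B, H)
post = np.swapaxes(mhc_post, 2, 3)           # (S, B, 1, N) -> (S, B, N, 1)
mixed = residual + post @ out[:, :, None, :] # (S, B, N, 1) @ (S, B, 1, H) -> (S, B, N, H)

final = mixed.sum(axis=2)                    # last layer only: (S, B, N, H) -> (S, B, H)
print(mixed.shape, final.shape)              # (4, 2, 4, 8) (4, 2, 8)

The layer_number == 0 branch in the patch corresponds to expand_hidden_states here, and the num_layers - 1 branch to the final sum over the expansion axis.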
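Also illustrative only: assuming grad_scale is meant to average the summed gradients across data-parallel ranks, the shard_size == 1 fallback added to get_hsdp_model amounts to injecting 1/data_parallel into the kwargs built from the new parallelism.hsdp YAML block. The hsdp() call itself is omitted here because its signature lives in the external hyper_parallel package.

# Sketch of the kwargs handling in the new get_hsdp_model fallback (hypothetical values).
hsdp_args = {"shard_size": 1, "optimizer_level": "level3", "threshold": 0}
data_parallel = 4

# When no HSDP sharding is configured, add a gradient scale of 1/data_parallel
# so that summed gradients end up averaged across the data-parallel group.
if hsdp_args.get("shard_size", 1) == 1:
    hsdp_args["grad_scale"] = 1 / data_parallel

print(hsdp_args)  # {'shard_size': 1, 'optimizer_level': 'level3', 'threshold': 0, 'grad_scale': 0.25}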