From cc4645f0c87d20d28566b4789089c76f7f471407 Mon Sep 17 00:00:00 2001
From: niujunhao
Date: Thu, 22 Jan 2026 15:48:04 +0800
Subject: [PATCH] Update hsdp.

---
 mindformers/pynative/config/__init__.py |   0
 mindformers/pynative/config/config.py   | 299 ++++++++++++++++++++++++
 mindformers/pynative/trainer/trainer.py |  31 ++-
 3 files changed, 318 insertions(+), 12 deletions(-)
 create mode 100644 mindformers/pynative/config/__init__.py
 create mode 100644 mindformers/pynative/config/config.py

diff --git a/mindformers/pynative/config/__init__.py b/mindformers/pynative/config/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindformers/pynative/config/config.py b/mindformers/pynative/config/config.py
new file mode 100644
index 000000000..0d1c0caa3
--- /dev/null
+++ b/mindformers/pynative/config/config.py
@@ -0,0 +1,299 @@
# Copyright 202 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pynative Config."""
import yaml
import os
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Optional, List, Any, Dict, Union, get_type_hints

# Helper for defining fields with metadata
def config_field(default=None, default_factory=None, help_text="", choices=None):
    metadata = {"help": help_text}
    if choices is not None:
        metadata["choices"] = choices

    if default_factory is not None:
        return field(default_factory=default_factory, metadata=metadata)
    return field(default=default, metadata=metadata)

class DictConfig:
    """Helper for custom configurations to allow dot notation"""
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                setattr(self, k, DictConfig(**v))
            else:
                setattr(self, k, v)

    @classmethod
    def from_value(cls, value):
        if isinstance(value, dict):
            return cls(**value)
        elif isinstance(value, list):
            return [cls.from_value(v) for v in value]
        else:
            return value

    def __repr__(self):
        return str(self.__dict__)

class BaseConfig:
    @classmethod
    def load(cls, path: str):
        if not os.path.exists(path):
            raise FileNotFoundError(f"Config file not found: {path}")

        with open(path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)

        return cls.from_dict(data)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        if not is_dataclass(cls):
            raise TypeError(f"{cls.__name__} must be a dataclass")

        field_names = {f.name for f in fields(cls)}
        known_args = {}
        unknown_args = {}

        for k, v in data.items():
            if k in field_names:
                known_args[k] = v
            else:
                unknown_args[k] = v

        init_args = {}
        type_hints = get_type_hints(cls)

        for k, v in known_args.items():
            field_type = type_hints.get(k)
            # Handle Optional[T] by stripping Optional (Union[T, None])
            # Simplified check for our specific use case
            if hasattr(field_type, '__origin__') and field_type.__origin__ is Union:
                args = field_type.__args__
                # Assume non-None type is what we want
valid_args = [a for a in args if a is not type(None)] + if valid_args: + field_type = valid_args[0] + + if is_dataclass(field_type): + if isinstance(v, dict): + init_args[k] = field_type.from_dict(v) + else: + init_args[k] = v + elif hasattr(field_type, '__origin__') and field_type.__origin__ is list: + # Handle List[Dataclass] + args = field_type.__args__ + if args and is_dataclass(args[0]) and isinstance(v, list): + init_args[k] = [args[0].from_dict(item) if isinstance(item, dict) else item for item in v] + else: + init_args[k] = v + else: + init_args[k] = v + + instance = cls(**init_args) + + # Handle custom fields for classes that support it (indicated by _custom_config attribute existence) + # Note: _custom_config is initialized in __post_init__ + if hasattr(instance, '_custom_config'): + for k, v in unknown_args.items(): + val = DictConfig.from_value(v) + instance._custom_config[k] = val + # Monkey-patch for direct access if no conflict + if not hasattr(instance, k): + setattr(instance, k, val) + + instance.validate() + return instance + + def validate(self): + for f in fields(self): + value = getattr(self, f.name) + + # Check choices + choices = f.metadata.get("choices") + if choices is not None and value not in choices: + raise ValueError(f"Config validation error in '{self.__class__.__name__}.{f.name}': Value '{value}' is not in allowed choices {choices}") + + # Recursive validation + if isinstance(value, BaseConfig): + value.validate() + elif isinstance(value, list): + for item in value: + if isinstance(item, BaseConfig): + item.validate() + +@dataclass +class CheckpointConfig(BaseConfig): + save_path: str = config_field(default="/path/checkpoint", help_text="directory to save checkpoints") + save_max: int = config_field(default=5, help_text="maximum number of checkpoints to retain") + save_interleaved_steps: int = config_field(default=100, help_text="number of training steps between checkpoint saves") + no_save_optim: bool = config_field(default=False, help_text="whether to skip saving optimizer states") + async_save: bool = config_field(default=False, help_text="enable asynchronous checkpoint saving") + prefix: Optional[str] = config_field(default=None, help_text="filename prefix for saved checkpoints") + remove_redundancy: bool = config_field(default=True, help_text="remove redundant data when saving checkpoints") + + load_path: str = config_field(default="", help_text="directory to load checkpoints from") + load_balanced: bool = config_field(default=False, help_text="enable balanced loading across ranks/devices") + no_load_optim: bool = config_field(default=True, help_text="whether to skip loading optimizer states") + load_worker_number: int = config_field(default=1, help_text="number of worker threads used for checkpoint loading") + +@dataclass +class TrainingConfig(BaseConfig): + steps: int = config_field(default=100, help_text="total number of training steps") + local_batch_size: int = config_field(default=2, help_text="per-rank batch size") + global_batch_size: int = config_field(default=4, help_text="total number of samples processed per step") + max_norm: float = config_field(default=1.0, help_text="enable gradient clipping if value > 0") + seed: int = config_field(default=42, help_text="random seed for training") + deterministic: int = config_field(default=1, help_text="enable deterministic training behavior") + +@dataclass +class ParallelismConfig(BaseConfig): + tensor_parallel: int = config_field(default=1, help_text="tensor parallelism degree") + context_parallel: 
int = config_field(default=1, help_text="context parallelism degree")
    context_parallel_method: str = config_field(default="colossal", help_text="implementation method for context parallelism")
    pipeline_parallel: int = config_field(default=2, help_text="pipeline parallelism degree")
    pipeline_parallel_schedule: str = config_field(default="1f1b", help_text="pipeline execution schedule")
    pipeline_parallel_microbatch_size: int = config_field(default=1, help_text="number of micro-batches per pipeline step")
    pipeline_parallel_interleave_num: int = config_field(default=2, help_text="number of interleaved model chunks")
    pipeline_parallel_layers_per_stage: List[List[int]] = config_field(default_factory=list, help_text="layers assigned to each pipeline stage")
    hsdp_shard_size: int = config_field(default=2, help_text="HSDP sharding group size")
    hsdp_optimizer_level: str = config_field(default="level1", help_text="optimizer state sharding level for HSDP")
    hsdp_threshold: int = config_field(default=64, help_text="parameter size threshold (in MB)")
    sequence_parallel: bool = config_field(default=False, help_text="enable sequence parallelism")

@dataclass
class OptimizerConfig(BaseConfig):
    type: str = config_field(default="AdamW", help_text="Optimizer type")
    betas: List[float] = config_field(default_factory=lambda: [0.9, 0.95], help_text="Optimizer betas")
    eps: float = config_field(default=1e-8, help_text="Optimizer epsilon")
    weight_decay: float = config_field(default=0.0, help_text="Weight decay")

@dataclass
class LrSchedulerConfig(BaseConfig):
    type: str = config_field(default="ConstantWarmUpLR", help_text="Scheduler type")
    learning_rate: float = config_field(default=1e-6, help_text="Learning rate")
    warmup_lr_init: float = config_field(default=0.0, help_text="Initial warmup LR")
    warmup_steps: int = config_field(default=0, help_text="Warmup steps")

@dataclass
class DataLoaderConfig(BaseConfig):
    type: str = config_field(default="BlendedMegatronDatasetDataLoader", help_text="DataLoader type")

@dataclass
class TrainDatasetConfig(BaseConfig):
    dataloader: DataLoaderConfig = config_field(default_factory=DataLoaderConfig, help_text="Dataloader config")
    column_names: List[str] = config_field(default_factory=lambda: ['input_ids'], help_text="Input feature column names")
    num_parallel_workers: int = config_field(default=8, help_text="Number of parallel data loading workers")
    shuffle: bool = config_field(default=False, help_text="Whether to shuffle the dataset")
    prefetch_size: int = config_field(default=1, help_text="Number of prefetched batches")
    numa_enable: bool = config_field(default=False, help_text="Enable NUMA-aware data loading")

@dataclass
class ModelConfig(BaseConfig):
    model_type: str = config_field(default="qwen3", help_text="Model type")
    architectures: str = config_field(default="Qwen3ForCausalLM", help_text="Model architectures")

@dataclass
class HealthCheckpointConfig(BaseConfig):
    embedding_local_norm_threshold: float = config_field(default=100.0, help_text="Threshold for embedding parameter local norm")
    global_norm_skip_threshold: float = config_field(default=100.0, help_text="Global norm threshold for skipping updates")
    global_norm_skip_time: int = config_field(default=10, help_text="Number of consecutive skips allowed")

@dataclass
class TrainStateConfig(BaseConfig):
    local_norm: bool = config_field(default=True, help_text="Monitor local gradient/parameter norm")
    local_loss: bool = config_field(default=True, 
help_text="Monitor local training loss") + device_norm: bool = config_field(default=True, help_text="Monitor device-level norm statistics") + device_loss: bool = config_field(default=True, help_text="Monitor device-level loss statistics") + +@dataclass +class MonitorConfig(BaseConfig): + health_checkpoint: HealthCheckpointConfig = config_field(default_factory=HealthCheckpointConfig, help_text="Health monitoring") + train_state: TrainStateConfig = config_field(default_factory=TrainStateConfig, help_text="Training state monitoring") + +@dataclass +class CallbackConfig(BaseConfig): + type: str = config_field(default="", help_text="Callback type") + param: int = config_field(default=0, help_text="Callback parameter") + +@dataclass +class TensorBoardConfig(BaseConfig): + tensorboard_dir: str = config_field(default="", help_text="Directory to store TensorBoard logs") + tensorboard_log_interval: int = config_field(default=1, help_text="Logging interval") + tensorboard_queue_size: int = config_field(default=1000, help_text="Maximum size of the event queue") + log_xxx_to_tensorboard: bool = config_field(default=True, help_text="Enable logging of extended metrics") + +@dataclass +class ContextConfig(BaseConfig): + mode: int = config_field(default=1, help_text="MindSpore context mode", choices=[0, 1]) + max_device_memory: str = config_field(default="59GB", help_text="Maximum device memory usage") + device_target: str = config_field(default="Ascend", help_text="Target device type", choices=["Ascend", "GPU", "CPU"]) + +@dataclass +class ProfilerConfig(BaseConfig): + start_step: int = config_field(default=10, help_text="Profiling start step") + stop_step: int = config_field(default=20, help_text="Profiling stop step") + output_path: str = config_field(default="", help_text="Profiler output directory") + profiler_level: str = config_field(default="Level0", help_text="Profiling detail level") + profile_memory: bool = config_field(default=False, help_text="Enable memory profiling") + +@dataclass +class RecomputeConfig(BaseConfig): + recompute: bool = config_field(default=True, help_text="Enable activation recomputation") + select_recompute: bool = config_field(default=False, help_text="Enable selective recomputation") + +@dataclass +class CpuOffloadConfig(BaseConfig): + optimizer_offload: bool = config_field(default=True, help_text="Offload optimizer states to CPU") + ops_offload: bool = config_field(default=True, help_text="Offload computation ops to CPU") + layers_offload: bool = config_field(default=True, help_text="Offload model layers to CPU") + +@dataclass +class PynativeConfig(BaseConfig): + checkpoint: CheckpointConfig = config_field(default_factory=CheckpointConfig) + training: TrainingConfig = config_field(default_factory=TrainingConfig) + parallelism: ParallelismConfig = config_field(default_factory=ParallelismConfig) + optimizer: OptimizerConfig = config_field(default_factory=OptimizerConfig) + lr_scheduler: LrSchedulerConfig = config_field(default_factory=LrSchedulerConfig) + train_dataset: TrainDatasetConfig = config_field(default_factory=TrainDatasetConfig) + model: ModelConfig = config_field(default_factory=ModelConfig) + monitor: MonitorConfig = config_field(default_factory=MonitorConfig) + callbacks: List[CallbackConfig] = config_field(default_factory=list) + tensorboard: TensorBoardConfig = config_field(default_factory=TensorBoardConfig) + context: ContextConfig = config_field(default_factory=ContextConfig) + profiler: ProfilerConfig = config_field(default_factory=ProfilerConfig) + 
    recompute: RecomputeConfig = config_field(default_factory=RecomputeConfig)
    cpu_offload: CpuOffloadConfig = config_field(default_factory=CpuOffloadConfig)

    def __post_init__(self):
        # Initialize custom config container
        self._custom_config = {}

    def show_config_status(self):
        builtin_keys = {f.name for f in fields(self)}
        custom_keys = [k for k in self.__dict__ if k not in builtin_keys and k != "_custom_config"]

        print("Built-in Configurations:")
        for k in sorted(builtin_keys):
            print(f" - {k}")

        print("\nCustom Configurations:")
        if not custom_keys:
            print(" (None)")
        else:
            for k in sorted(custom_keys):
                print(f" - {k}")
diff --git a/mindformers/pynative/trainer/trainer.py b/mindformers/pynative/trainer/trainer.py
index 274e968a9..8f52d9a7d 100644
--- a/mindformers/pynative/trainer/trainer.py
+++ b/mindformers/pynative/trainer/trainer.py
@@ -18,6 +18,7 @@
 import enum
 from typing import Optional, Callable, List, Dict, Any
 import numpy as np
+import mindspore as ms
 from mindspore import manual_seed, value_and_grad, mint
 from mindspore.dataset import Dataset
 from mindspore.common import set_seed
@@ -126,6 +127,11 @@ class Trainer:
         # Create model
         self.model = self._create_model(model, getattr(self.config, "model_config", None))
         self._compute_parameters()
+        # Initialize parallel config and wrappers
+        if self.use_parallel:
+            self.model = self._init_parallel_config(
+                self.model, getattr(self.config, 'parallelism', None)
+            )
 
         # Create datasets
         self.train_dataset = self._create_dataset(
@@ -226,6 +232,10 @@ class Trainer:
             self.config.model.model_config = self.config.model_config
             del self.config.model_config
 
+        src_model_config = model_config
+        model_config = MindFormerConfig()
+        model_config.model_config = src_model_config
+
         logger.info("Building model from config...")
         model = build_model(model_config)
 
@@ -294,7 +304,8 @@ class Trainer:
             data_parallel * self.config.training.local_batch_size
         )
         logger.info(
-            f"Calculate gradient_accumulation_steps={self.gradient_accumulation_steps}"
+            f"Calculate global_batch_size={self.global_batch_size}, "
+            f"gradient_accumulation_steps={self.gradient_accumulation_steps}."
         )
 
     def _create_dataset(self, dataset, dataset_config: Optional[Dict]) -> Optional[Any]:
@@ -384,6 +395,7 @@ class Trainer:
 
         default_args = {"params": grouped_params, "learning_rate": lr}
         optimizer = build_optim(optimizer_config, default_args=default_args)
+        # optimizer = optim.AdamW(self.model.trainable_params(), lr=1.e-6)
 
         return optimizer, lr
 
@@ -445,12 +457,6 @@ class Trainer:
         if checkpoint_path is None:
             checkpoint_path = getattr(self.config.checkpoint, "load_path", None)
 
-        # Initialize parallel config and wrappers
-        if self.use_parallel:
-            self.model = self._init_parallel_config(
-                self.model, getattr(self.config, 'parallelism', None)
-            )
-
         # Initialize training state
         self.state = TrainerState(
             max_steps=getattr(self.config.training, "max_steps", 1000),
@@ -490,7 +496,7 @@ class Trainer:
             return model
 
         logger.info("Initializing parallel config...")
-        return get_hsdp_model(self.model, parallelism.hsdp)
+        return get_hsdp_model(self.model, parallelism.hsdp, data_parallel=self._get_data_parallel())
 
     def _load_checkpoint(self, checkpoint_path: str):
         """
@@ -761,9 +767,6 @@ class Trainer:
         # function 2: legacy
         # grads, grad_norm = self.clip_grad(grads)
 
-        # Backward pass
-        # In real implementation with MindSpore:
-
         # Optimizer step
         self.optimizer(grads)
 
@@ -773,6 +776,10 @@ class Trainer:
         return self.config.parallelism.data_parallel
 
 
-def get_hsdp_model(model, hsdp_args):
+def get_hsdp_model(model, hsdp_args, data_parallel):
+    shard_size = hsdp_args.get('shard_size', 1)
+    if shard_size == 1:
+        grad_scale = 1 / data_parallel
+        hsdp_args['grad_scale'] = grad_scale
     model = hsdp(model, **hsdp_args)
     return model
-- 
Gitee
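
Usage note (illustrative, not part of the patch): a minimal sketch of how the PynativeConfig added in
mindformers/pynative/config/config.py is expected to be consumed. The field names come from the dataclasses
defined above; the import path and the "my_experiment" key are assumptions used only to show the
custom-config passthrough handled by DictConfig.

    from mindformers.pynative.config.config import PynativeConfig

    # Build a config from a plain dict; PynativeConfig.load(path) does the same from a YAML file.
    cfg = PynativeConfig.from_dict({
        "training": {"steps": 10, "local_batch_size": 1, "global_batch_size": 2},
        "parallelism": {"pipeline_parallel": 1, "hsdp_shard_size": 2},
        "my_experiment": {"tag": "debug"},  # unknown key -> stored as a DictConfig custom field
    })

    print(cfg.training.steps)               # 10
    print(cfg.parallelism.hsdp_shard_size)  # 2
    print(cfg.my_experiment.tag)            # "debug", dot access via DictConfig
    cfg.show_config_status()                # lists built-in vs. custom keys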