diff --git a/ds_pynative.yaml b/ds_pynative.yaml index ed512c1ed2e5251ba83eb21ea7e131b8635bf1e7..08422ab837a978d1ee9d3237bb00d23947917e65 100644 --- a/ds_pynative.yaml +++ b/ds_pynative.yaml @@ -31,6 +31,13 @@ training: local_batch_size: 1 save_steps: 2000 +# memory swap config (optional, uncomment to enable) +# memory_swap: +# prefetch_interval: 1 # Number of iterations ahead to prefetch +# min_tensor_size: 1048576 # Minimum tensor size in bytes to swap (1MB) +# enable_stats: True # Whether to collect and print statistics +# enable_profiling: False # Whether to add profiling markers + # original diff --git a/mindformers/pynative/memory/__init__.py b/mindformers/pynative/memory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..152b582d65c39e1b893badc6a7bb358afefa8cd9 --- /dev/null +++ b/mindformers/pynative/memory/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Swap configuration for memory management in iterative solvers."""


class SwapConfig:
    """User-facing options for device<->host activation swapping.

    Passing an instance of this class to a solver/trainer enables moving
    saved activation tensors to host memory during the forward pass and
    restoring them ahead of use in the backward pass, lowering the peak
    NPU memory footprint.

    Args:
        prefetch_interval (int, optional): How many iterations ahead of the
            backward pass activations are copied back to device. Default: ``1`` .
        min_tensor_size (int, optional): Tensors smaller than this many bytes
            are never swapped. Default: ``1048576`` (1 MB).
        enable_stats (bool, optional): Collect and log swap statistics.
            Default: ``False`` .
        enable_profiling (bool, optional): Emit mstx profiling markers
            (the MindSpore profiler itself must be enabled separately with
            mstx=True). Default: ``False`` .

    Examples:
        >>> from mindformers.pynative.memory import SwapConfig
        >>> swap_config = SwapConfig(prefetch_interval=2)
        >>> trainer = Trainer(model=model, swap_config=swap_config)
    """

    def __init__(self, prefetch_interval=1, min_tensor_size=1 * 1024 * 1024,
                 enable_stats=False, enable_profiling=False):
        self.prefetch_interval = prefetch_interval
        self.min_tensor_size = min_tensor_size
        self.enable_stats = enable_stats
        self.enable_profiling = enable_profiling

    def create_manager(self, n_iter):
        """Build a SwapManager from these options.

        Args:
            n_iter (int): Total number of iterations per swap cycle.

        Returns:
            SwapManager: A manager configured with this object's settings.
        """
        # Imported lazily: swap_manager imports nothing from this module, but
        # a top-level import here would create an import cycle through the
        # package __init__.
        from .swap_manager import SwapManager

        options = dict(
            prefetch_interval=self.prefetch_interval,
            min_tensor_size=self.min_tensor_size,
            enable_stats=self.enable_stats,
            enable_profiling=self.enable_profiling,
        )
        return SwapManager(n_iter=n_iter, **options)
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Memory swap manager for reducing NPU memory usage during training.

This module provides a framework for swapping activation tensors between device
(GPU/NPU) and host (CPU) memory to reduce peak memory usage during training.
It is designed to work with iterative solvers like CBS where activations from
each iteration need to be saved for backward propagation.
"""
import time
import logging
from contextlib import contextmanager
from typing import Generator, Union, Optional
import psutil
import mindspore as ms
from mindspore import Tensor
from mindspore import runtime
from mindspore.profiler import mstx

from .swap_tensor import SwapTensor

from mindformers.tools.logger import logger
# Handle to the current OS process, used to report CPU RSS in statistics.
process = psutil.Process()


class SwapManager:
    """Memory swap manager for coordinating tensor swapping across iterations.

    This class manages the swapping of activation tensors between device and host
    memory during training. It uses saved_tensors_hooks to intercept tensor saving
    during forward pass and restore them during backward pass.

    Args:
        n_iter (int): Total number of iterations.
        prefetch_interval (int): Number of iterations ahead to prefetch. Default: 1.
        min_tensor_size (int): Minimum tensor size in bytes to swap. Default: 1MB.
        enable_stats (bool): Whether to collect and print statistics. Default: False.
        enable_profiling (bool): Whether to add profiling markers. Default: False.

    Examples:
        >>> swap_manager = SwapManager(n_iter=100, prefetch_interval=2)
        >>> for i in range(n_iter):
        ...     swap_manager.set_iteration(i)
        ...     with ms.saved_tensors_hooks(swap_manager.pack_hook, swap_manager.unpack_hook):
        ...         output = model(input)
        ...     swap_manager.end_iteration()
    """

    def __init__(self, n_iter: int, prefetch_interval: int = 1,
                 min_tensor_size: int = 1*1024*1024,
                 enable_stats: bool = False, enable_profiling: bool = False):
        self.n_iter = n_iter
        self.prefetch_interval = prefetch_interval
        self.min_tensor_size = min_tensor_size
        self.enable_stats = enable_stats
        self.enable_profiling = enable_profiling

        # Dedicated side stream so D2H/H2D copies overlap with compute on the
        # default stream.
        self.transfer_stream = runtime.Stream()

        self.current_iter = 0

        # iteration id -> list[SwapTensor] whose D2H has been launched.
        self.prefetch_list = {}
        # Iterations for which an H2D prefetch has already been issued.
        self.prefetched_iters = set()

        # Tensors packed during the iteration currently running forward.
        self.current_swap_tensors = []
        # data_ptr() values seen this iteration, for duplicate detection.
        self.current_iter_tensor_ptrs = set()
        self.last_prefetched_iter = None
        self.last_d2h_iter = None

        # iteration id -> event recorded after that iteration's H2D batch.
        self.iteration_h2d_events = {}

        if self.enable_stats:
            self._reset_stats()

    def _mstx_mark(self, message: str) -> None:
        """Mark a profiling point if profiling is enabled.

        Args:
            message: The profiling message to mark.
        """
        if self.enable_profiling:
            mstx.mark(message, domain="memory_swap")

    def _mstx_range_start(self, message: str, stream=None) -> Optional[int]:
        """Start a profiling range if profiling is enabled.

        Args:
            message: The profiling message for the range.
            stream: Optional stream to associate with the range.

        Returns:
            Range ID if profiling is enabled, None otherwise.
        """
        if self.enable_profiling:
            return mstx.range_start(message, stream, "memory_swap")
        return None

    def _mstx_range_end(self, range_id: Optional[int]) -> None:
        """End a profiling range if profiling is enabled.

        Args:
            range_id: The range ID returned by _mstx_range_start.
        """
        if self.enable_profiling and range_id is not None:
            mstx.range_end(range_id, "memory_swap")

    @contextmanager
    def _mstx_range(self, message: str, stream=None) -> Generator[None, None, None]:
        """Context manager for profiling range.

        Args:
            message: The profiling message for the range.
            stream: Optional stream to associate with the range.

        Yields:
            None
        """
        range_id = self._mstx_range_start(message, stream)
        try:
            yield
        finally:
            self._mstx_range_end(range_id)

    def should_swap(self, tensor: Tensor) -> bool:
        """Determine whether a tensor should be swapped.

        Args:
            tensor (Tensor): The tensor to check.

        Returns:
            bool: True if the tensor should be swapped, False otherwise.
        """
        # Small tensors are not worth the transfer latency/overhead.
        tensor_size = tensor.numel() * tensor.itemsize
        if tensor_size < self.min_tensor_size:
            return False

        return True

    def pack_hook(self, tensor: Tensor) -> Union['SwapTensor', Tensor]:
        """Pack hook for forward pass tensor saving.

        This hook is called by saved_tensors_hooks during forward pass.
        It creates a SwapTensor and records forward event for later batch D2H.

        Args:
            tensor (Tensor): The tensor to be saved.

        Returns:
            SwapTensor or Tensor: SwapTensor if swapping, original tensor otherwise.
        """
        if self.enable_stats:
            start = time.perf_counter()

        if not self.should_swap(tensor):
            if self.enable_stats:
                self.skipped_count += 1
            return tensor

        # De-duplicate: if the same device buffer was already packed this
        # iteration, return the existing SwapTensor so it is transferred once.
        # NOTE(review): assumes data_ptr() uniquely identifies a live tensor
        # within one iteration — confirm the allocator cannot reuse a pointer
        # mid-iteration.
        tensor_ptr = tensor.data_ptr()
        if tensor_ptr in self.current_iter_tensor_ptrs:
            for swap_tensor in self.current_swap_tensors:
                if swap_tensor.tensor.data_ptr() == tensor_ptr:
                    if self.enable_stats:
                        self.duplicate_tensor_count += 1
                    if self.enable_profiling:
                        mstx.mark(
                            f"MemSwap/Iter{self.current_iter}/Pack_Duplicate_shape{list(tensor.shape)}",
                            domain="memory_swap"
                        )
                    return swap_tensor

        self.current_iter_tensor_ptrs.add(tensor_ptr)

        tensor_idx = len(self.current_swap_tensors)
        self._mstx_mark(
            f"MemSwap/Iter{self.current_iter}/Pack_T{tensor_idx}_shape{list(tensor.shape)}_size{tensor.numel() * tensor.itemsize}"
        )

        swap_tensor = SwapTensor(tensor, self.current_iter, self.transfer_stream)

        # The last `prefetch_interval` iterations are needed almost immediately
        # by backward, so their tensors stay on device (no forward event means
        # end_iteration will not launch a D2H for them).
        if self.current_iter < self.n_iter - self.prefetch_interval:
            swap_tensor.record_forward_event()
        self.current_swap_tensors.append(swap_tensor)

        if self.enable_stats:
            self.pack_count += 1
            self.total_swap_bytes += tensor.numel() * tensor.itemsize
            self.pack_time += time.perf_counter() - start

        return swap_tensor

    def unpack_hook(self, packed: Union['SwapTensor', Tensor]) -> Tensor:
        """Unpack hook for backward pass tensor restoration.

        This hook is called by saved_tensors_hooks during backward pass.
        It triggers prefetching and waits for H2D transfer to complete.

        Args:
            packed: SwapTensor or original Tensor.

        Returns:
            Tensor: The restored device tensor.
        """
        # Tensors below min_tensor_size were passed through pack_hook unchanged.
        if isinstance(packed, Tensor):
            return packed

        if self.enable_stats:
            start = time.perf_counter()
        swap_tensor = packed

        if self.enable_profiling:
            tensor_shape = list(swap_tensor.shape)
            status = swap_tensor.stat
            mstx.mark(
                f"MemSwap/Iter{swap_tensor.iteration_id}/Unpack_shape{tensor_shape}_stat{status}",
                domain="memory_swap"
            )

        # Prefetch miss: the tensor is still on host, so restore it
        # synchronously on the current stream (blocking copy).
        if swap_tensor.stat == "host":
            swap_tensor.tensor = swap_tensor.tensor_cpu.to('Ascend')
            swap_tensor.tensor_cpu = None
            swap_tensor.stat = "device"

            if self.enable_stats:
                self.prefetch_misses += 1
            if self.enable_profiling:
                mstx.mark(
                    f"MemSwap/Iter{swap_tensor.iteration_id}/Prefetch_Miss",
                    domain="memory_swap"
                )
        elif self.enable_stats:
            self.prefetch_hits += 1

        # If an H2D batch was prefetched for this iteration, make the compute
        # stream wait on its completion event once, then mark the whole batch
        # as resident on device.
        iter_id = swap_tensor.iteration_id
        if iter_id in self.iteration_h2d_events:
            if self.enable_stats:
                wait_start = time.perf_counter()

            self._mstx_mark(f"MemSwap/Iter{iter_id}/UnpackWait")

            runtime.current_stream().wait_event(self.iteration_h2d_events[iter_id])

            if self.enable_stats:
                elapsed = time.perf_counter() - wait_start
                self.h2d_wait_count += 1
                self.h2d_wait_time += elapsed

            self.iteration_h2d_events.pop(iter_id, None)
            if iter_id in self.prefetch_list:
                for st in self.prefetch_list[iter_id]:
                    if st.stat == "h2d":
                        st.stat = "device"

                del self.prefetch_list[iter_id]

        # Backward walks iterations in decreasing order, so the iteration that
        # will be needed `prefetch_interval` steps from now is current - interval.
        prefetch_target = swap_tensor.iteration_id - self.prefetch_interval
        if prefetch_target >= 0 and prefetch_target in self.prefetch_list:
            if prefetch_target not in self.prefetched_iters:
                self.prefetch_h2d(prefetch_target)
                self.prefetched_iters.add(prefetch_target)

        if self.enable_stats:
            self.unpack_count += 1
            self.unpack_time += time.perf_counter() - start

        return swap_tensor.get_tensor()

    def set_iteration(self, iter_id: int) -> None:
        """Set the current iteration ID.

        Args:
            iter_id (int): The current iteration index.
        """
        self.current_iter = iter_id
        self.current_swap_tensors = []
        self.current_iter_tensor_ptrs.clear()

    def end_iteration(self) -> None:
        """End the current iteration and launch D2H transfers.

        This method waits for the previous iteration's D2H to complete first,
        then launches D2H transfers for the current iteration. This allows the
        previous iteration's D2H to overlap with the current iteration's computation.
        """

        # Drain the previous iteration's D2H batch before queuing a new one.
        # NOTE(review): prev_tensors[0].stat is used as representative of the
        # whole batch — relies on all tensors of an iteration moving through
        # states together.
        if self.last_d2h_iter is not None and self.last_d2h_iter in self.prefetch_list:
            prev_tensors = self.prefetch_list[self.last_d2h_iter]
            if prev_tensors and prev_tensors[0].stat == "d2h":
                wait_range_id = None
                if self.enable_profiling:
                    wait_range_id = mstx.range_start(
                        f"MemSwap/Iter{self.last_d2h_iter}/D2H_Wait",
                        self.transfer_stream,
                        "memory_swap"
                    )

                wait_start = time.perf_counter() if self.enable_stats else 0
                runtime.current_stream().wait_stream(self.transfer_stream)
                if self.enable_stats:
                    elapsed = time.perf_counter() - wait_start
                    self.d2h_wait_count += 1
                    self.d2h_wait_time += elapsed

                if self.enable_profiling and wait_range_id is not None:
                    mstx.range_end(wait_range_id, "memory_swap")

                # need_wait=False: the single wait_stream above covers the batch.
                for swap_tensor in prev_tensors:
                    swap_tensor.wait_d2h_finished(need_wait=False)

            self.last_d2h_iter = None

        if not self.current_swap_tensors:
            if self.enable_stats:
                self.iter_swap_counts.append(0)
            self.current_iter_tensor_ptrs.clear()
            return

        # Launch this iteration's D2H copies on the transfer stream; only
        # tensors that recorded a forward event participate (see pack_hook).
        with self._mstx_range(f"MemSwap/Iter{self.current_iter}/D2H_Batch", self.transfer_stream):
            with runtime.StreamCtx(self.transfer_stream):
                for swap_tensor in self.current_swap_tensors:
                    if swap_tensor.stat == "event_recorded":
                        swap_tensor.execute_d2h_copy()

        if self.current_swap_tensors:
            self.last_d2h_iter = self.current_iter

        n_swapped = len(self.current_swap_tensors)
        if self.enable_stats:
            self.iter_swap_counts.append(n_swapped)
            swapped_bytes = sum(t.size for t in self.current_swap_tensors)
            mem_info = psutil.virtual_memory()
            logger.debug(
                f"Iteration {self.current_iter}: swapped {n_swapped} tensors ({swapped_bytes / (1024 * 1024):.2f}MB), "
                f"device={ms.runtime.memory_allocated() / (1024 * 1024 * 1024):.2f}GB, "
                f"cpu={process.memory_info().rss / (1024 * 1024 * 1024):.2f}GB, "
                f"system_free={mem_info.free / (1024 * 1024 * 1024):.2f}GB"
            )

        self.prefetch_list[self.current_iter] = self.current_swap_tensors
        self.current_swap_tensors = []
        self.current_iter_tensor_ptrs.clear()

    def prefetch_h2d(self, target_iter: int) -> None:
        """Prefetch tensors from a target iteration with layer-wise waiting.

        This method waits for the previous layer's H2D to complete before
        starting the current layer's H2D, enabling timely memory release.

        Args:
            target_iter (int): The iteration index to prefetch.
        """
        if self.last_prefetched_iter is not None and self.last_prefetched_iter in self.prefetch_list:
            prev_tensors = self.prefetch_list[self.last_prefetched_iter]
            if prev_tensors and prev_tensors[0].stat == "h2d":
                wait_start = time.perf_counter() if self.enable_stats else 0
                runtime.current_stream().wait_stream(self.transfer_stream)
                if self.enable_stats:
                    elapsed = time.perf_counter() - wait_start
                    self.h2d_wait_count += 1
                    self.h2d_wait_time += elapsed

                for swap_tensor in prev_tensors:
                    swap_tensor.wait_h2d_finished(need_wait=False)
            self.last_prefetched_iter = None

        if target_iter not in self.prefetch_list:
            return

        swap_tensors = self.prefetch_list[target_iter]

        # Order the H2D batch after the backward work issued so far, then
        # record a completion event for unpack_hook to wait on.
        backward_event = runtime.current_stream().record_event()
        with self._mstx_range(f"MemSwap/Iter{target_iter}/H2D_Prefetch", self.transfer_stream):
            with runtime.StreamCtx(self.transfer_stream):
                self.transfer_stream.wait_event(backward_event)
                for swap_tensor in swap_tensors:
                    if swap_tensor.stat == "host":
                        swap_tensor.execute_h2d_copy()
                self.iteration_h2d_events[target_iter] = self.transfer_stream.record_event()

        self.last_prefetched_iter = target_iter

    def finish(self) -> None:
        """Finish swap management and print statistics.

        Waits for any pending D2H operations before cleanup.
        Prints swap statistics and clears all stored swap tensors,
        resetting the state for potential reuse.
        """
        # Wait for pending transfers before cleanup
        if self.last_d2h_iter is not None and self.last_d2h_iter in self.prefetch_list:
            prev_tensors = self.prefetch_list[self.last_d2h_iter]
            if prev_tensors and prev_tensors[0].stat == "d2h":
                runtime.current_stream().wait_stream(self.transfer_stream)
                for swap_tensor in prev_tensors:
                    swap_tensor.wait_d2h_finished(need_wait=False)

        if self.last_prefetched_iter is not None and self.last_prefetched_iter in self.prefetch_list:
            prev_tensors = self.prefetch_list[self.last_prefetched_iter]
            if prev_tensors and prev_tensors[0].stat == "h2d":
                runtime.current_stream().wait_stream(self.transfer_stream)
                for swap_tensor in prev_tensors:
                    swap_tensor.wait_h2d_finished(need_wait=False)

        self.print_stats()

        # Explicitly release all tensor and event references
        for iter_tensors in self.prefetch_list.values():
            for swap_tensor in iter_tensors:
                swap_tensor.tensor = None
                swap_tensor.tensor_cpu = None
                swap_tensor.forward_event = None

        self.prefetch_list.clear()
        self.prefetched_iters.clear()
        self.iteration_h2d_events.clear()
        self.current_swap_tensors = []
        self.current_iter_tensor_ptrs.clear()
        self.last_prefetched_iter = None
        self.last_d2h_iter = None
        self.current_iter = 0
        if self.enable_stats:
            self._reset_stats()

    def _reset_stats(self) -> None:
        """Reset all statistics counters."""
        self.pack_count = 0
        self.unpack_count = 0
        self.pack_time = 0.0
        self.unpack_time = 0.0
        self.skipped_count = 0
        self.total_swap_bytes = 0
        self.prefetch_hits = 0
        self.prefetch_misses = 0
        self.d2h_wait_count = 0
        self.d2h_wait_time = 0.0
        self.h2d_wait_count = 0
        self.h2d_wait_time = 0.0
        self.iter_times = []
        self.iter_swap_counts = []
        self.duplicate_tensor_count = 0

    @contextmanager
    def iteration(self, iter_id: int) -> Generator[None, None, None]:
        """Context manager that wraps set_iteration, saved_tensors_hooks, and end_iteration.

        This provides a convenient way to use the swap manager in a for loop,
        and centralizes the hook setup so that logging and statistics can be
        added in one place.

        Args:
            iter_id (int): The current iteration index.

        Yields:
            None

        Examples:
            >>> swap_manager = SwapManager(n_iter=100)
            >>> for i in range(n_iter):
            ...     with swap_manager.iteration(i):
            ...         output = model(input)
        """
        # NOTE(review): end_iteration is skipped if the body raises, leaving
        # this iteration's tensors un-batched — confirm whether cleanup on
        # exception (try/finally) is desired here.
        self.set_iteration(iter_id)
        iter_start = time.perf_counter() if self.enable_stats else 0
        with ms.saved_tensors_hooks(self.pack_hook, self.unpack_hook):
            yield
        self.end_iteration()
        if self.enable_stats:
            self.iter_times.append(time.perf_counter() - iter_start)

    def print_stats(self) -> None:
        """Print a summary of swap statistics to logger.

        Logs hook timing, tensor swap counts, data volume,
        prefetch hit rate, and per-iteration breakdown.
        Does nothing if enable_stats is False.
        """
        if not self.enable_stats:
            return

        total_hook_time = self.pack_time + self.unpack_time
        total_iter_time = sum(self.iter_times) if self.iter_times else 0.0
        total_prefetch = self.prefetch_hits + self.prefetch_misses
        prefetch_hit_rate = (
            self.prefetch_hits / total_prefetch if total_prefetch > 0 else 0.0
        )
        overhead_pct = (total_hook_time / total_iter_time * 100) if total_iter_time > 0 else 0.0

        mem_info = psutil.virtual_memory()
        device_gb = ms.runtime.memory_allocated() / (1024 * 1024 * 1024)
        reserved_gb = ms.runtime.memory_reserved() / (1024 * 1024 * 1024)
        cpu_gb = process.memory_info().rss / (1024 * 1024 * 1024)
        system_total_gb = mem_info.total / (1024 * 1024 * 1024)
        system_free_gb = mem_info.free / (1024 * 1024 * 1024)

        logger.info("=== SwapManager Statistics ===")
        logger.info(
            f"Hook timing: pack={self.pack_time * 1000:.1f}ms ({self.pack_count} calls), "
            f"unpack={self.unpack_time * 1000:.1f}ms ({self.unpack_count} calls), "
            f"total={total_hook_time * 1000:.1f}ms ({overhead_pct:.1f}% overhead)"
        )
        logger.info(
            f"Tensors: {self.pack_count} swapped, {self.skipped_count} skipped, "
            f"{self.duplicate_tensor_count} duplicates, "
            f"{self.total_swap_bytes / (1024 * 1024 * 1024):.2f} GB total"
        )
        logger.info(
            f"Prefetch: {self.prefetch_hits} hits, {self.prefetch_misses} misses, "
            f"hit_rate={prefetch_hit_rate * 100:.1f}%"
        )
        logger.info(
            f"Iterations: {len(self.iter_times)}, total_time={total_iter_time:.3f}s"
        )
        logger.info(
            f"Stream sync: D2H wait={self.d2h_wait_count} times {self.d2h_wait_time * 1000:.1f}ms, "
            f"H2D wait={self.h2d_wait_count} times {self.h2d_wait_time * 1000:.1f}ms"
        )
        logger.info(
            f"Memory: device={device_gb:.2f}GB, reserved={reserved_gb:.2f}GB, cpu={cpu_gb:.2f}GB, "
            f"system_total={system_total_gb:.2f}GB, system_free={system_free_gb:.2f}GB"
        )
class SwapTensor:
    """Manages the transmission of a single tensor between device and host.

    This class handles the lifecycle of a tensor as it moves between device
    memory and host memory. It tracks the tensor's state and provides methods
    for asynchronous data transfer.

    State machine:
        device -> event_recorded -> d2h -> host -> h2d -> device

    Args:
        tensor (Tensor): The original device tensor to be swapped.
        iteration_id (int): The iteration index this tensor belongs to.
        stream (mindspore.runtime.Stream): The stream used for asynchronous data transfer.

    Attributes:
        tensor (Tensor): Reference to the original device tensor.
        tensor_cpu (Tensor): CPU buffer for the tensor data.
        shape (tuple): Shape of the tensor.
        size (int): Size of the tensor in bytes.
        iteration_id (int): Iteration index this tensor belongs to.
        stream (Stream): Transfer stream for async operations.
        stat (str): Current state of the tensor.
    """

    def __init__(self, tensor: "Tensor", iteration_id: int, stream: "runtime.Stream"):
        """Initialize SwapTensor.

        Args:
            tensor (Tensor): The original device tensor to be swapped.
            iteration_id (int): The iteration index this tensor belongs to.
            stream (mindspore.runtime.Stream): The stream used for asynchronous data transfer.
        """
        self.tensor = tensor
        self.shape = tensor.shape
        self.size = tensor.numel() * tensor.itemsize
        self.iteration_id = iteration_id
        self.stream = stream

        self.tensor_cpu = None

        self.stat = "device"
        self.forward_event = None

    def record_forward_event(self) -> None:
        """Record forward event without switching stream.

        This method only records the event on the current stream, deferring
        the actual D2H copy to be executed later in batch mode.
        """
        self.forward_event = runtime.current_stream().record_event()
        self.stat = "event_recorded"

    def execute_d2h_copy(self) -> None:
        """Execute D2H copy in the transfer stream.

        This method must be called within a StreamCtx(transfer_stream) context.
        It waits for the forward event and launches the asynchronous D2H copy.

        The device tensor reference is intentionally KEPT until
        ``wait_d2h_finished`` runs: dropping the only reference right after a
        non-blocking copy would let the allocator reclaim the device buffer
        while the transfer stream may still be reading it.
        (NOTE(review): assumes MindSpore does not pin the source buffer of an
        in-flight async copy on its own — confirm against runtime docs.)
        """
        if self.forward_event is not None:
            self.stream.wait_event(self.forward_event)
            self.forward_event = None

        self.tensor_cpu = self.tensor.to('CPU', non_blocking=True)
        # BUGFIX: previously `self.tensor = None` here released the device
        # tensor before the copy was known complete; release now happens in
        # wait_d2h_finished (which already nulled it — formerly a dead store).
        # Keeping the reference also makes get_tensor() valid in "d2h" state.
        self.stat = "d2h"

    def launch_d2h(self) -> None:
        """Launch asynchronous device-to-host transfer (legacy method).

        This method is kept for backward compatibility but internally uses
        the new batch processing approach.
        """
        if self.stat != "device":
            return

        self.record_forward_event()
        with runtime.StreamCtx(self.stream):
            self.execute_d2h_copy()

    def wait_d2h_finished(self, need_wait: bool = False) -> None:
        """Wait for D2H transfer to complete and release device memory.

        Args:
            need_wait (bool): Whether to synchronize with the transfer stream.
                Only the first tensor in each iteration needs to wait; the
                caller may also perform a single wait_stream for the whole
                batch and pass False here.
        """
        if self.stat != "d2h":
            return

        if need_wait:
            runtime.current_stream().wait_stream(self.stream)

        # Safe to drop the device reference now: the copy is complete (either
        # via the wait above or a batch-level wait performed by the caller).
        self.tensor = None
        self.stat = "host"

    def execute_h2d_copy(self) -> None:
        """Execute H2D copy in the transfer stream.

        This method must be called within a StreamCtx(transfer_stream) context
        and launches the asynchronous H2D copy.

        As with D2H, the host buffer reference is KEPT until
        ``wait_h2d_finished`` (or garbage collection of this object) so the
        source of the in-flight copy cannot be freed mid-transfer.
        """
        self.tensor = self.tensor_cpu.to('Ascend', non_blocking=True)
        # BUGFIX: previously `self.tensor_cpu = None` here released the host
        # source buffer before the async copy completed; release is deferred
        # to wait_h2d_finished.
        self.stat = "h2d"

    def launch_h2d(self) -> None:
        """Launch asynchronous host-to-device transfer (legacy method).

        This method is kept for backward compatibility but internally uses
        the new batch processing approach.
        """
        if self.stat != "host":
            return

        backward_event = runtime.current_stream().record_event()
        with runtime.StreamCtx(self.stream):
            self.stream.wait_event(backward_event)
            self.execute_h2d_copy()

    def wait_h2d_finished(self, need_wait: bool = False) -> None:
        """Wait for H2D transfer to complete and release the host buffer.

        Args:
            need_wait (bool): Whether to synchronize with the transfer stream.
                Only the last tensor in each iteration needs to wait.
        """
        if self.stat != "h2d":
            return

        if need_wait:
            runtime.current_stream().wait_stream(self.stream)

        self.tensor_cpu = None
        self.stat = "device"

    def get_tensor(self) -> "Tensor":
        """Get the restored device tensor.

        Returns:
            Tensor: The tensor on device memory (also valid while a D2H copy
            is still in flight, since the device reference is held until
            wait_d2h_finished).
        """
        return self.tensor
@@ -459,6 +465,19 @@ class Trainer: global_batch_size=self.global_batch_size, ) + # Initialize swap manager if configured + if self.swap_config is not None: + # Get number of layers from model + num_layers = getattr(self.model.model.decoder, 'num_layers', 1) + # n_iter = num_layers (each training step is a complete swap cycle) + self.swap_manager = self.swap_config.create_manager(n_iter=num_layers) + + # Attach swap_manager to TransformerBlock + decoder = self.model.model.decoder + decoder.swap_manager = self.swap_manager + + logger.info(f"Memory swap enabled: {num_layers} layers, prefetch_interval={self.swap_config.prefetch_interval}") + # Load checkpoint if checkpoint_path: self._load_checkpoint(checkpoint_path) @@ -767,6 +786,10 @@ class Trainer: # Optimizer step self.optimizer(grads) + # Finish swap cycle for this training step + if self.swap_manager is not None: + self.swap_manager.finish() + return loss, grad_norm def _get_data_parallel(self): diff --git a/mindformers/pynative/transformers/transformer_block.py b/mindformers/pynative/transformers/transformer_block.py index 9de09968dc9e0318445d04cee0ef8ac265972fef..147dead60d56fd10fbe40f2ac6aa22c92d21deec 100644 --- a/mindformers/pynative/transformers/transformer_block.py +++ b/mindformers/pynative/transformers/transformer_block.py @@ -157,17 +157,32 @@ class TransformerBlock(nn.Cell): - aux_loss (Tensor | None): Optional auxiliary loss accumulated over layers. 
""" aux_loss = None + swap_manager = getattr(self, 'swap_manager', None) + for index in range(self.num_layers): layer = self._get_layer(index) prefix_kv = prefix_keys_values[index] if prefix_keys_values is not None else None - hidden_states, _, layer_aux_loss = layer( - hidden_states, - attention_mask, - rotary_pos_emb=rotary_pos_emb, - prefix_keys_values=prefix_kv, - actual_seq_len=actual_seq_len, - input_ids=input_ids, - ) + + if swap_manager is not None: + with swap_manager.iteration(index): + hidden_states, _, layer_aux_loss = layer( + hidden_states, + attention_mask, + rotary_pos_emb=rotary_pos_emb, + prefix_keys_values=prefix_kv, + actual_seq_len=actual_seq_len, + input_ids=input_ids, + ) + else: + hidden_states, _, layer_aux_loss = layer( + hidden_states, + attention_mask, + rotary_pos_emb=rotary_pos_emb, + prefix_keys_values=prefix_kv, + actual_seq_len=actual_seq_len, + input_ids=input_ids, + ) + if layer_aux_loss is not None: aux_loss = layer_aux_loss if aux_loss is None else aux_loss + layer_aux_loss diff --git a/mindformers/pynative/transformers/transformer_layer.py b/mindformers/pynative/transformers/transformer_layer.py index 40bb092ab095e66c6c3e0303b4c4097f2a1feb5f..bde32cfac22291236d8425a4d7066722aa0f4e3d 100644 --- a/mindformers/pynative/transformers/transformer_layer.py +++ b/mindformers/pynative/transformers/transformer_layer.py @@ -218,8 +218,8 @@ class TransformerLayer(nn.Cell, BaseTransformerLayer): def expand_hidden_states(hidden_states): hidden_states = hidden_states.unsqueeze(axis=2) return hidden_states.repeat(1, 1, self.mhc_expansion_rate, 1).astype(float32) - print(f"-------------- Transformer Layer {self.layer_number} ----------------") - print(f"@ Original hidden_states AMax: {ops.abs(hidden_states).max()} {ops.abs(hidden_states).mean()}") + # print(f"-------------- Transformer Layer {self.layer_number} ----------------") + # print(f"@ Original hidden_states AMax: {ops.abs(hidden_states).max()} {ops.abs(hidden_states).mean()}") # Before 
layernorm # if self.apply_manifold_constrained_hyper_connections: @@ -331,5 +331,5 @@ class TransformerLayer(nn.Cell, BaseTransformerLayer): output = self.add(residual, dropout_output) # Note: context parameter is returned for API compatibility but currently unused. # It may be deprecated in future versions. - print(f"@ transformer layer output AMax: {ops.abs(output).max()} {ops.abs(output).mean()}") + # print(f"@ transformer layer output AMax: {ops.abs(output).max()} {ops.abs(output).mean()}") return output, context, aux_loss diff --git a/run_pynative.py b/run_pynative.py index 1bdf0de366215c95baf302094bf17b95ac30b3b9..0a96cc4616f84efc56a0088676777244069f846b 100644 --- a/run_pynative.py +++ b/run_pynative.py @@ -2,6 +2,7 @@ import mindspore as ms from mindspore import nn, mint from mindspore.mint.distributed import init_process_group, destroy_process_group from regex import F +import yaml from mindformers.pynative.trainer.trainer import Trainer from mindformers.tools.register import MindFormerRegister, MindFormerModuleType, MindFormerConfig @@ -32,9 +33,18 @@ loss = mint.nn.functional.cross_entropy # format='safetensors' # ) +# Load memory swap config directly from YAML +swap_config = None +with open('ds_pynative.yaml', 'r') as f: + yaml_config = yaml.safe_load(f) + if 'memory_swap' in yaml_config: + from mindformers.pynative.memory import SwapConfig + swap_config = SwapConfig(**yaml_config['memory_swap']) + trainer = Trainer( model=model, config='ds_pynative.yaml', + swap_config=swap_config, # compute_loss_func=loss )