From b1bbddac93b03b17585a9bdeebe47f41dd883c53 Mon Sep 17 00:00:00 2001 From: userA Date: Tue, 10 Mar 2026 11:56:42 +0800 Subject: [PATCH 1/8] Add swap. --- ds_pynative.yaml | 7 + mindformers/pynative/memory/__init__.py | 21 + mindformers/pynative/memory/swap_config.py | 68 +++ mindformers/pynative/memory/swap_manager.py | 515 ++++++++++++++++++ mindformers/pynative/memory/swap_tensor.py | 167 ++++++ mindformers/pynative/trainer/trainer.py | 31 ++ .../transformers/transformer_block.py | 11 + run_pynative.py | 7 + 8 files changed, 827 insertions(+) create mode 100644 mindformers/pynative/memory/__init__.py create mode 100644 mindformers/pynative/memory/swap_config.py create mode 100644 mindformers/pynative/memory/swap_manager.py create mode 100644 mindformers/pynative/memory/swap_tensor.py diff --git a/ds_pynative.yaml b/ds_pynative.yaml index ed512c1ed..08422ab83 100644 --- a/ds_pynative.yaml +++ b/ds_pynative.yaml @@ -31,6 +31,13 @@ training: local_batch_size: 1 save_steps: 2000 +# memory swap config (optional, uncomment to enable) +# memory_swap: +# prefetch_interval: 1 # Number of iterations ahead to prefetch +# min_tensor_size: 1048576 # Minimum tensor size in bytes to swap (1MB) +# enable_stats: True # Whether to collect and print statistics +# enable_profiling: False # Whether to add profiling markers + # original diff --git a/mindformers/pynative/memory/__init__.py b/mindformers/pynative/memory/__init__.py new file mode 100644 index 000000000..152b582d6 --- /dev/null +++ b/mindformers/pynative/memory/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Memory management utilities for reducing NPU memory usage.""" + +from .swap_config import SwapConfig +from .swap_manager import SwapManager +from .swap_tensor import SwapTensor + +__all__ = ['SwapConfig', 'SwapManager', 'SwapTensor'] diff --git a/mindformers/pynative/memory/swap_config.py b/mindformers/pynative/memory/swap_config.py new file mode 100644 index 000000000..bc8a8c8ff --- /dev/null +++ b/mindformers/pynative/memory/swap_config.py @@ -0,0 +1,68 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Swap configuration for memory management in iterative solvers.""" + + +class SwapConfig: + """Configuration for memory swapping in iterative solvers. 
+ + When passed to a solver (e.g., CBS), this enables swapping of activation + tensors between device and host memory to reduce peak NPU memory usage + during training. + + Args: + prefetch_interval (int, optional): Number of iterations ahead to prefetch + during backward pass. Default: ``1`` . + min_tensor_size (int, optional): Minimum tensor size in bytes to swap. + Default: ``1048576`` . + enable_stats (bool, optional): Whether to collect and print swap + statistics. Default: ``False`` . + enable_profiling (bool, optional): Whether to add profiling markers for + MindSpore profiler. Requires user to enable profiler externally with + mstx=True. Default: ``False`` . + + Examples: + >>> from mindformers.pynative.memory import SwapConfig + >>> swap_config = SwapConfig(prefetch_interval=2) + >>> trainer = Trainer(model=model, swap_config=swap_config) + """ + + def __init__(self, + prefetch_interval=1, + min_tensor_size=1*1024*1024, + enable_stats=False, + enable_profiling=False): + self.prefetch_interval = prefetch_interval + self.min_tensor_size = min_tensor_size + self.enable_stats = enable_stats + self.enable_profiling = enable_profiling + + def create_manager(self, n_iter): + """Create a SwapManager from this configuration. + + Args: + n_iter (int): Total number of iterations. + + Returns: + SwapManager: A configured swap manager instance. + """ + from .swap_manager import SwapManager + return SwapManager( + n_iter=n_iter, + prefetch_interval=self.prefetch_interval, + min_tensor_size=self.min_tensor_size, + enable_stats=self.enable_stats, + enable_profiling=self.enable_profiling, + ) diff --git a/mindformers/pynative/memory/swap_manager.py b/mindformers/pynative/memory/swap_manager.py new file mode 100644 index 000000000..79759ec14 --- /dev/null +++ b/mindformers/pynative/memory/swap_manager.py @@ -0,0 +1,515 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Memory swap manager for reducing NPU memory usage during training. + +This module provides a framework for swapping activation between device +(GPU/NPU) and host (CPU) memory to reduce peak memory usage during training. +It is designed to work with iterative solvers like CBS where activations from +each iteration need to be saved for backward propagation. +""" +import time +import logging +from contextlib import contextmanager +from typing import Generator, Union, Optional +import psutil +import mindspore as ms +from mindspore import Tensor +from mindspore import runtime +from mindspore.profiler import mstx + +from .swap_tensor import SwapTensor + +logger = logging.getLogger(__name__) +process = psutil.Process() + + +class SwapManager: + """Memory swap manager for coordinating tensor swapping across iterations. + + This class manages the swapping of activation tensors between device and host + memory during training. 
It uses saved_tensors_hooks to intercept tensor saving + during forward pass and restore them during backward pass. + + Args: + n_iter (int): Total number of iterations. + prefetch_interval (int): Number of iterations ahead to prefetch. Default: 1. + min_tensor_size (int): Minimum tensor size in bytes to swap. Default: 1MB. + enable_stats (bool): Whether to collect and print statistics. Default: False. + enable_profiling (bool): Whether to add profiling markers. Default: False. + + Examples: + >>> swap_manager = SwapManager(n_iter=100, prefetch_interval=2) + >>> for i in range(n_iter): + ... swap_manager.set_iteration(i) + ... with ms.saved_tensors_hooks(swap_manager.pack_hook, swap_manager.unpack_hook): + ... output = model(input) + ... swap_manager.end_iteration() + """ + + def __init__(self, n_iter: int, prefetch_interval: int = 1, + min_tensor_size: int = 1*1024*1024, + enable_stats: bool = False, enable_profiling: bool = False): + self.n_iter = n_iter + self.prefetch_interval = prefetch_interval + self.min_tensor_size = min_tensor_size + self.enable_stats = enable_stats + self.enable_profiling = enable_profiling + + self.transfer_stream = runtime.Stream() + + self.current_iter = 0 + + self.prefetch_list = {} + self.prefetched_iters = set() + + self.current_swap_tensors = [] + self.current_iter_tensor_ptrs = set() + self.last_prefetched_iter = None + self.last_d2h_iter = None + + self.iteration_h2d_events = {} + + if self.enable_stats: + self._reset_stats() + + def _mstx_mark(self, message: str) -> None: + """Mark a profiling point if profiling is enabled. + + Args: + message: The profiling message to mark. + """ + if self.enable_profiling: + mstx.mark(message, domain="memory_swap") + + def _mstx_range_start(self, message: str, stream=None) -> Optional[int]: + """Start a profiling range if profiling is enabled. + + Args: + message: The profiling message for the range. + stream: Optional stream to associate with the range. + + Returns: + Range ID if profiling is enabled, None otherwise. + """ + if self.enable_profiling: + return mstx.range_start(message, stream, "memory_swap") + return None + + def _mstx_range_end(self, range_id: Optional[int]) -> None: + """End a profiling range if profiling is enabled. + + Args: + range_id: The range ID returned by _mstx_range_start. + """ + if self.enable_profiling and range_id is not None: + mstx.range_end(range_id, "memory_swap") + + @contextmanager + def _mstx_range(self, message: str, stream=None) -> Generator[None, None, None]: + """Context manager for profiling range. + + Args: + message: The profiling message for the range. + stream: Optional stream to associate with the range. + + Yields: + None + """ + range_id = self._mstx_range_start(message, stream) + try: + yield + finally: + self._mstx_range_end(range_id) + + def should_swap(self, tensor: Tensor) -> bool: + """Determine whether a tensor should be swapped. + + Args: + tensor (Tensor): The tensor to check. + + Returns: + bool: True if the tensor should be swapped, False otherwise. + """ + tensor_size = tensor.numel() * tensor.itemsize + if tensor_size < self.min_tensor_size: + return False + + return True + + def pack_hook(self, tensor: Tensor) -> Union['SwapTensor', Tensor]: + """Pack hook for forward pass tensor saving. + + This hook is called by saved_tensors_hooks during forward pass. + It creates a SwapTensor and records forward event for later batch D2H. + + Args: + tensor (Tensor): The tensor to be saved. 
+ + Returns: + SwapTensor or Tensor: SwapTensor if swapping, original tensor otherwise. + """ + if self.enable_stats: + start = time.perf_counter() + + if not self.should_swap(tensor): + if self.enable_stats: + self.skipped_count += 1 + return tensor + + tensor_ptr = tensor.data_ptr() + if tensor_ptr in self.current_iter_tensor_ptrs: + for swap_tensor in self.current_swap_tensors: + if swap_tensor.tensor.data_ptr() == tensor_ptr: + if self.enable_stats: + self.duplicate_tensor_count += 1 + if self.enable_profiling: + mstx.mark( + f"MemSwap/Iter{self.current_iter}/Pack_Duplicate_shape{list(tensor.shape)}", + domain="memory_swap" + ) + return swap_tensor + + self.current_iter_tensor_ptrs.add(tensor_ptr) + + tensor_idx = len(self.current_swap_tensors) + self._mstx_mark( + f"MemSwap/Iter{self.current_iter}/Pack_T{tensor_idx}_shape{list(tensor.shape)}_size{tensor.numel() * tensor.itemsize}" + ) + + swap_tensor = SwapTensor(tensor, self.current_iter, self.transfer_stream) + + if self.current_iter < self.n_iter - self.prefetch_interval: + swap_tensor.record_forward_event() + self.current_swap_tensors.append(swap_tensor) + + if self.enable_stats: + self.pack_count += 1 + self.total_swap_bytes += tensor.numel() * tensor.itemsize + self.pack_time += time.perf_counter() - start + + return swap_tensor + + def unpack_hook(self, packed: Union['SwapTensor', Tensor]) -> Tensor: + """Unpack hook for backward pass tensor restoration. + + This hook is called by saved_tensors_hooks during backward pass. + It triggers prefetching and waits for H2D transfer to complete. + + Args: + packed: SwapTensor or original Tensor. + + Returns: + Tensor: The restored device tensor. + """ + if isinstance(packed, Tensor): + return packed + + if self.enable_stats: + start = time.perf_counter() + swap_tensor = packed + + if self.enable_profiling: + tensor_shape = list(swap_tensor.shape) + status = swap_tensor.stat + mstx.mark( + f"MemSwap/Iter{swap_tensor.iteration_id}/Unpack_shape{tensor_shape}_stat{status}", + domain="memory_swap" + ) + + if swap_tensor.stat == "host": + swap_tensor.tensor = swap_tensor.tensor_cpu.to('Ascend') + swap_tensor.tensor_cpu = None + swap_tensor.stat = "device" + + if self.enable_stats: + self.prefetch_misses += 1 + if self.enable_profiling: + mstx.mark( + f"MemSwap/Iter{swap_tensor.iteration_id}/Prefetch_Miss", + domain="memory_swap" + ) + elif self.enable_stats: + self.prefetch_hits += 1 + + iter_id = swap_tensor.iteration_id + if iter_id in self.iteration_h2d_events: + if self.enable_stats: + wait_start = time.perf_counter() + + self._mstx_mark(f"MemSwap/Iter{iter_id}/UnpackWait") + + runtime.current_stream().wait_event(self.iteration_h2d_events[iter_id]) + + if self.enable_stats: + elapsed = time.perf_counter() - wait_start + self.h2d_wait_count += 1 + self.h2d_wait_time += elapsed + + self.iteration_h2d_events.pop(iter_id, None) + if iter_id in self.prefetch_list: + for st in self.prefetch_list[iter_id]: + if st.stat == "h2d": + st.stat = "device" + + del self.prefetch_list[iter_id] + + prefetch_target = swap_tensor.iteration_id - self.prefetch_interval + if prefetch_target >= 0 and prefetch_target in self.prefetch_list: + if prefetch_target not in self.prefetched_iters: + self.prefetch_h2d(prefetch_target) + self.prefetched_iters.add(prefetch_target) + + if self.enable_stats: + self.unpack_count += 1 + self.unpack_time += time.perf_counter() - start + + return swap_tensor.get_tensor() + + def set_iteration(self, iter_id: int) -> None: + """Set the current iteration ID. 
+ + Args: + iter_id (int): The current iteration index. + """ + self.current_iter = iter_id + self.current_swap_tensors = [] + self.current_iter_tensor_ptrs.clear() + + def end_iteration(self) -> None: + """End the current iteration and launch D2H transfers. + + This method waits for the previous iteration's D2H to complete first, + then launches D2H transfers for the current iteration. This allows the + previous iteration's D2H to overlap with the current iteration's computation. + """ + + if self.last_d2h_iter is not None and self.last_d2h_iter in self.prefetch_list: + prev_tensors = self.prefetch_list[self.last_d2h_iter] + if prev_tensors and prev_tensors[0].stat == "d2h": + wait_range_id = None + if self.enable_profiling: + wait_range_id = mstx.range_start( + f"MemSwap/Iter{self.last_d2h_iter}/D2H_Wait", + self.transfer_stream, + "memory_swap" + ) + + wait_start = time.perf_counter() if self.enable_stats else 0 + runtime.current_stream().wait_stream(self.transfer_stream) + if self.enable_stats: + elapsed = time.perf_counter() - wait_start + self.d2h_wait_count += 1 + self.d2h_wait_time += elapsed + + if self.enable_profiling and wait_range_id is not None: + mstx.range_end(wait_range_id, "memory_swap") + + for swap_tensor in prev_tensors: + swap_tensor.wait_d2h_finished(need_wait=False) + + self.last_d2h_iter = None + + if not self.current_swap_tensors: + if self.enable_stats: + self.iter_swap_counts.append(0) + self.current_iter_tensor_ptrs.clear() + return + + with self._mstx_range(f"MemSwap/Iter{self.current_iter}/D2H_Batch", self.transfer_stream): + with runtime.StreamCtx(self.transfer_stream): + for swap_tensor in self.current_swap_tensors: + if swap_tensor.stat == "event_recorded": + swap_tensor.execute_d2h_copy() + + if self.current_swap_tensors: + self.last_d2h_iter = self.current_iter + + n_swapped = len(self.current_swap_tensors) + if self.enable_stats: + self.iter_swap_counts.append(n_swapped) + swapped_bytes = sum(t.size for t in self.current_swap_tensors) + mem_info = psutil.virtual_memory() + logger.debug( + f"Iteration {self.current_iter}: swapped {n_swapped} tensors ({swapped_bytes / (1024 * 1024):.2f}MB), " + f"device={ms.runtime.memory_allocated() / (1024 * 1024 * 1024):.2f}GB, " + f"cpu={process.memory_info().rss / (1024 * 1024 * 1024):.2f}GB, " + f"system_free={mem_info.free / (1024 * 1024 * 1024):.2f}GB" + ) + + self.prefetch_list[self.current_iter] = self.current_swap_tensors + self.current_swap_tensors = [] + self.current_iter_tensor_ptrs.clear() + + def prefetch_h2d(self, target_iter: int) -> None: + """Prefetch tensors from a target iteration with layer-wise waiting. + + This method waits for the previous layer's H2D to complete before + starting the current layer's H2D, enabling timely memory release. + + Args: + target_iter (int): The iteration index to prefetch. 
+ """ + if self.last_prefetched_iter is not None and self.last_prefetched_iter in self.prefetch_list: + prev_tensors = self.prefetch_list[self.last_prefetched_iter] + if prev_tensors and prev_tensors[0].stat == "h2d": + wait_start = time.perf_counter() if self.enable_stats else 0 + runtime.current_stream().wait_stream(self.transfer_stream) + if self.enable_stats: + elapsed = time.perf_counter() - wait_start + self.h2d_wait_count += 1 + self.h2d_wait_time += elapsed + + for swap_tensor in prev_tensors: + swap_tensor.wait_h2d_finished(need_wait=False) + self.last_prefetched_iter = None + + if target_iter not in self.prefetch_list: + return + + swap_tensors = self.prefetch_list[target_iter] + + backward_event = runtime.current_stream().record_event() + with self._mstx_range(f"MemSwap/Iter{target_iter}/H2D_Prefetch", self.transfer_stream): + with runtime.StreamCtx(self.transfer_stream): + self.transfer_stream.wait_event(backward_event) + for swap_tensor in swap_tensors: + if swap_tensor.stat == "host": + swap_tensor.execute_h2d_copy() + self.iteration_h2d_events[target_iter] = self.transfer_stream.record_event() + + self.last_prefetched_iter = target_iter + + def finish(self) -> None: + """Finish swap management and print statistics. + + Waits for any pending D2H operations before cleanup. + Prints swap statistics and clears all stored swap tensors, + resetting the state for potential reuse. + """ + self.print_stats() + + # Explicitly release all tensor references in SwapTensor objects + for iter_tensors in self.prefetch_list.values(): + for swap_tensor in iter_tensors: + swap_tensor.tensor = None + swap_tensor.tensor_cpu = None + + self.prefetch_list.clear() + self.prefetched_iters.clear() + self.iteration_h2d_events.clear() + self.current_swap_tensors = [] + self.current_iter_tensor_ptrs.clear() + self.last_prefetched_iter = None + self.last_d2h_iter = None + self.current_iter = 0 + if self.enable_stats: + self._reset_stats() + + def _reset_stats(self) -> None: + """Reset all statistics counters.""" + self.pack_count = 0 + self.unpack_count = 0 + self.pack_time = 0.0 + self.unpack_time = 0.0 + self.skipped_count = 0 + self.total_swap_bytes = 0 + self.prefetch_hits = 0 + self.prefetch_misses = 0 + self.d2h_wait_count = 0 + self.d2h_wait_time = 0.0 + self.h2d_wait_count = 0 + self.h2d_wait_time = 0.0 + self.iter_times = [] + self.iter_swap_counts = [] + self.duplicate_tensor_count = 0 + + @contextmanager + def iteration(self, iter_id: int) -> Generator[None, None, None]: + """Context manager that wraps set_iteration, saved_tensors_hooks, and end_iteration. + + This provides a convenient way to use the swap manager in a for loop, + and centralizes the hook setup so that logging and statistics can be + added in one place. + + Args: + iter_id (int): The current iteration index. + + Yields: + None + + Examples: + >>> swap_manager = SwapManager(n_iter=100) + >>> for i in range(n_iter): + ... with swap_manager.iteration(i): + ... output = model(input) + """ + self.set_iteration(iter_id) + iter_start = time.perf_counter() if self.enable_stats else 0 + with ms.saved_tensors_hooks(self.pack_hook, self.unpack_hook): + yield + self.end_iteration() + if self.enable_stats: + self.iter_times.append(time.perf_counter() - iter_start) + + def print_stats(self) -> None: + """Print a summary of swap statistics to logger. + + Logs hook timing, tensor swap counts, data volume, + prefetch hit rate, and per-iteration breakdown. + Does nothing if enable_stats is False. 
+ """ + if not self.enable_stats: + return + + total_hook_time = self.pack_time + self.unpack_time + total_iter_time = sum(self.iter_times) if self.iter_times else 0.0 + total_prefetch = self.prefetch_hits + self.prefetch_misses + prefetch_hit_rate = ( + self.prefetch_hits / total_prefetch if total_prefetch > 0 else 0.0 + ) + overhead_pct = (total_hook_time / total_iter_time * 100) if total_iter_time > 0 else 0.0 + + mem_info = psutil.virtual_memory() + device_gb = ms.runtime.memory_allocated() / (1024 * 1024 * 1024) + cpu_gb = process.memory_info().rss / (1024 * 1024 * 1024) + system_total_gb = mem_info.total / (1024 * 1024 * 1024) + system_free_gb = mem_info.free / (1024 * 1024 * 1024) + + logger.info("=== SwapManager Statistics ===") + logger.info( + f"Hook timing: pack={self.pack_time * 1000:.1f}ms ({self.pack_count} calls), " + f"unpack={self.unpack_time * 1000:.1f}ms ({self.unpack_count} calls), " + f"total={total_hook_time * 1000:.1f}ms ({overhead_pct:.1f}% overhead)" + ) + logger.info( + f"Tensors: {self.pack_count} swapped, {self.skipped_count} skipped, " + f"{self.duplicate_tensor_count} duplicates, " + f"{self.total_swap_bytes / (1024 * 1024 * 1024):.2f} GB total" + ) + logger.info( + f"Prefetch: {self.prefetch_hits} hits, {self.prefetch_misses} misses, " + f"hit_rate={prefetch_hit_rate * 100:.1f}%" + ) + logger.info( + f"Iterations: {len(self.iter_times)}, total_time={total_iter_time:.3f}s" + ) + logger.info( + f"Stream sync: D2H wait={self.d2h_wait_count} times {self.d2h_wait_time * 1000:.1f}ms, " + f"H2D wait={self.h2d_wait_count} times {self.h2d_wait_time * 1000:.1f}ms" + ) + logger.info( + f"Memory: device={device_gb:.2f}GB, cpu={cpu_gb:.2f}GB, " + f"system_total={system_total_gb:.2f}GB, system_free={system_free_gb:.2f}GB" + ) diff --git a/mindformers/pynative/memory/swap_tensor.py b/mindformers/pynative/memory/swap_tensor.py new file mode 100644 index 000000000..95fc2a471 --- /dev/null +++ b/mindformers/pynative/memory/swap_tensor.py @@ -0,0 +1,167 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SwapTensor for managing tensor transmission between device and host.""" +import logging +import psutil +import mindspore as ms +from mindspore import Tensor +from mindspore import runtime + +logger = logging.getLogger(__name__) +process = psutil.Process() + + +class SwapTensor: + """Manages the transmission of a single tensor between device and host. + + This class handles the lifecycle of a tensor as it moves between device + memory and host memory. It tracks the tensor's state and provides methods + for asynchronous data transfer. + + State machine: + device -> d2h -> host -> h2d -> device + + Args: + tensor (Tensor): The original device tensor to be swapped. + iteration_id (int): The iteration index this tensor belongs to. + stream (mindspore.runtime.Stream): The stream used for asynchronous data transfer. 
+ + Attributes: + tensor (Tensor): Reference to the original device tensor. + tensor_cpu (Tensor): CPU buffer for the tensor data. + shape (tuple): Shape of the tensor. + iteration_id (int): Iteration index this tensor belongs to. + stream (Stream): Transfer stream for async operations. + stat (str): Current state of the tensor. + """ + + def __init__(self, tensor: Tensor, iteration_id: int, stream: runtime.Stream): + """Initialize SwapTensor. + + Args: + tensor (Tensor): The original device tensor to be swapped. + iteration_id (int): The iteration index this tensor belongs to. + stream (mindspore.runtime.Stream): The stream used for asynchronous data transfer. + """ + self.tensor = tensor + self.shape = tensor.shape + self.size = tensor.numel() * tensor.itemsize + self.iteration_id = iteration_id + self.stream = stream + + self.tensor_cpu = None + + self.stat = "device" + self.forward_event = None + + def record_forward_event(self) -> None: + """Record forward event without switching stream. + + This method only records the event on the current stream, deferring + the actual D2H copy to be executed later in batch mode. + """ + self.forward_event = runtime.current_stream().record_event() + self.stat = "event_recorded" + + def execute_d2h_copy(self) -> None: + """Execute D2H copy in the transfer stream. + + This method must be called within a StreamCtx(transfer_stream) context. + It waits for the forward event and performs the actual D2H copy. + """ + if self.forward_event is not None: + self.stream.wait_event(self.forward_event) + self.forward_event = None + + self.tensor_cpu = self.tensor.to('CPU', non_blocking=True) + self.tensor = None + self.stat = "d2h" + + def launch_d2h(self) -> None: + """Launch asynchronous device-to-host transfer (legacy method). + + This method is kept for backward compatibility but internally uses + the new batch processing approach. + """ + if self.stat != "device": + return + + self.record_forward_event() + with runtime.StreamCtx(self.stream): + self.execute_d2h_copy() + + def wait_d2h_finished(self, need_wait: bool = False) -> None: + """Wait for D2H transfer to complete and release device memory. + + Args: + need_wait (bool): Whether to synchronize with the transfer stream. + Only the first tensor in each iteration needs to wait. + """ + if self.stat != "d2h": + return + + if need_wait: + runtime.current_stream().wait_stream(self.stream) + + self.tensor = None + self.stat = "host" + + def execute_h2d_copy(self) -> None: + """Execute H2D copy in the transfer stream. + + This method must be called within a StreamCtx(transfer_stream) context. + It waits for the backward event and performs the actual H2D copy. + """ + self.tensor = self.tensor_cpu.to('Ascend', non_blocking=True) + self.tensor_cpu = None + self.stat = "h2d" + + def launch_h2d(self) -> None: + """Launch asynchronous host-to-device transfer (legacy method). + + This method is kept for backward compatibility but internally uses + the new batch processing approach. + """ + if self.stat != "host": + return + + backward_event = runtime.current_stream().record_event() + with runtime.StreamCtx(self.stream): + self.stream.wait_event(backward_event) + self.execute_h2d_copy() + + def wait_h2d_finished(self, need_wait: bool = False) -> None: + """Wait for H2D transfer to complete. + + Args: + need_wait (bool): Whether to synchronize with the transfer stream. + Only the last tensor in each iteration needs to wait. 
+ """ + if self.stat != "h2d": + return + + if need_wait: + runtime.current_stream().wait_stream(self.stream) + + self.tensor_cpu = None + self.stat = "device" + + def get_tensor(self) -> Tensor: + """Get the restored device tensor. + + Returns: + Tensor: The tensor on device memory. + """ + return self.tensor diff --git a/mindformers/pynative/trainer/trainer.py b/mindformers/pynative/trainer/trainer.py index 274e968a9..70a07236a 100644 --- a/mindformers/pynative/trainer/trainer.py +++ b/mindformers/pynative/trainer/trainer.py @@ -18,6 +18,7 @@ import enum from typing import Optional, Callable, List, Dict, Any import numpy as np +import mindspore as ms from mindspore import manual_seed, value_and_grad, mint from mindspore.dataset import Dataset from mindspore.common import set_seed @@ -87,6 +88,7 @@ class Trainer: lr_scheduler: Optional[Any] = None, compute_loss_func: Optional[Callable] = None, callbacks: Optional[List] = None, + swap_config: Optional['SwapConfig'] = None, ): """ Initialize the Trainer. @@ -155,6 +157,10 @@ class Trainer: self.grad_fn = None self.clip_grad = None + # Memory swap configuration + self.swap_config = swap_config + self.swap_manager = None + def _init_config(self, config: str) -> MindFormerConfig: """ Initialize trainer config from yaml file or MindFormerConfig instance. @@ -459,6 +465,23 @@ class Trainer: global_batch_size=self.global_batch_size, ) + # Initialize swap manager if configured + if self.swap_config is not None: + from mindformers.pynative.memory import SwapManager + # Get number of layers from model + num_layers = getattr(self.model.model.decoder, 'num_layers', 1) + n_iter = num_layers * self.state.max_steps + self.swap_manager = self.swap_config.create_manager(n_iter=n_iter) + self.swap_manager.current_step = 0 + + # Attach swap_manager to TransformerBlock and register hooks for each layer + decoder = self.model.model.decoder + decoder.swap_manager = self.swap_manager + for layer_idx, layer in enumerate(decoder.layers): + layer.register_saved_tensors_hooks(self.swap_manager.pack_hook, self.swap_manager.unpack_hook) + + logger.info(f"Memory swap enabled: {num_layers} layers, prefetch_interval={self.swap_config.prefetch_interval}") + # Load checkpoint if checkpoint_path: self._load_checkpoint(checkpoint_path) @@ -476,6 +499,10 @@ class Trainer: # Execute training loop self._inner_train_loop() + # Cleanup swap manager + if self.swap_manager is not None: + self.swap_manager.finish() + # Call train end callback if self.callback_handler is not None: self.callback_handler.on_train_end(self.config, self.state) @@ -733,6 +760,10 @@ class Trainer: """ loss_sync_allreduce = AllReduce(ReduceOp.SUM) + # Update swap manager current step if enabled + if self.swap_manager is not None: + self.swap_manager.current_step = self.state.global_step + # Forward and compute loss (loss, _), grads = self.grad_fn(model, inputs) if self.use_parallel: diff --git a/mindformers/pynative/transformers/transformer_block.py b/mindformers/pynative/transformers/transformer_block.py index 9de09968d..a12f1fb8c 100644 --- a/mindformers/pynative/transformers/transformer_block.py +++ b/mindformers/pynative/transformers/transformer_block.py @@ -157,9 +157,16 @@ class TransformerBlock(nn.Cell): - aux_loss (Tensor | None): Optional auxiliary loss accumulated over layers. 
""" aux_loss = None + swap_manager = getattr(self, 'swap_manager', None) + for index in range(self.num_layers): layer = self._get_layer(index) prefix_kv = prefix_keys_values[index] if prefix_keys_values is not None else None + + if swap_manager is not None: + iter_id = swap_manager.current_step * self.num_layers + index + swap_manager.set_iteration(iter_id) + hidden_states, _, layer_aux_loss = layer( hidden_states, attention_mask, @@ -168,6 +175,10 @@ class TransformerBlock(nn.Cell): actual_seq_len=actual_seq_len, input_ids=input_ids, ) + + if swap_manager is not None: + swap_manager.end_iteration() + if layer_aux_loss is not None: aux_loss = layer_aux_loss if aux_loss is None else aux_loss + layer_aux_loss diff --git a/run_pynative.py b/run_pynative.py index 1bdf0de36..6c1bd17e4 100644 --- a/run_pynative.py +++ b/run_pynative.py @@ -32,9 +32,16 @@ loss = mint.nn.functional.cross_entropy # format='safetensors' # ) +# Load memory swap config if present +swap_config = None +if hasattr(mf_config, 'memory_swap') and mf_config.memory_swap: + from mindformers.pynative.memory import SwapConfig + swap_config = SwapConfig(**mf_config.memory_swap.to_dict()) + trainer = Trainer( model=model, config='ds_pynative.yaml', + swap_config=swap_config, # compute_loss_func=loss ) -- Gitee From a68d48f91fc47d8396487462d5e4ac0711236f33 Mon Sep 17 00:00:00 2001 From: userA Date: Tue, 10 Mar 2026 14:23:30 +0800 Subject: [PATCH 2/8] Update memory swap config. --- run_pynative.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/run_pynative.py b/run_pynative.py index 6c1bd17e4..0a96cc461 100644 --- a/run_pynative.py +++ b/run_pynative.py @@ -2,6 +2,7 @@ import mindspore as ms from mindspore import nn, mint from mindspore.mint.distributed import init_process_group, destroy_process_group from regex import F +import yaml from mindformers.pynative.trainer.trainer import Trainer from mindformers.tools.register import MindFormerRegister, MindFormerModuleType, MindFormerConfig @@ -32,11 +33,13 @@ loss = mint.nn.functional.cross_entropy # format='safetensors' # ) -# Load memory swap config if present +# Load memory swap config directly from YAML swap_config = None -if hasattr(mf_config, 'memory_swap') and mf_config.memory_swap: - from mindformers.pynative.memory import SwapConfig - swap_config = SwapConfig(**mf_config.memory_swap.to_dict()) +with open('ds_pynative.yaml', 'r') as f: + yaml_config = yaml.safe_load(f) + if 'memory_swap' in yaml_config: + from mindformers.pynative.memory import SwapConfig + swap_config = SwapConfig(**yaml_config['memory_swap']) trainer = Trainer( model=model, -- Gitee From db9d9d5490da4ce3731379bd0434f2a0e39c0227 Mon Sep 17 00:00:00 2001 From: userA Date: Tue, 10 Mar 2026 14:37:20 +0800 Subject: [PATCH 3/8] Fix iter bug --- mindformers/pynative/trainer/trainer.py | 17 ++++++----------- .../pynative/transformers/transformer_block.py | 3 +-- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/mindformers/pynative/trainer/trainer.py b/mindformers/pynative/trainer/trainer.py index 70a07236a..94cca8570 100644 --- a/mindformers/pynative/trainer/trainer.py +++ b/mindformers/pynative/trainer/trainer.py @@ -470,9 +470,8 @@ class Trainer: from mindformers.pynative.memory import SwapManager # Get number of layers from model num_layers = getattr(self.model.model.decoder, 'num_layers', 1) - n_iter = num_layers * self.state.max_steps - self.swap_manager = self.swap_config.create_manager(n_iter=n_iter) - self.swap_manager.current_step = 0 + # n_iter = num_layers (each 
training step is a complete swap cycle) + self.swap_manager = self.swap_config.create_manager(n_iter=num_layers) # Attach swap_manager to TransformerBlock and register hooks for each layer decoder = self.model.model.decoder @@ -499,10 +498,6 @@ class Trainer: # Execute training loop self._inner_train_loop() - # Cleanup swap manager - if self.swap_manager is not None: - self.swap_manager.finish() - # Call train end callback if self.callback_handler is not None: self.callback_handler.on_train_end(self.config, self.state) @@ -760,10 +755,6 @@ class Trainer: """ loss_sync_allreduce = AllReduce(ReduceOp.SUM) - # Update swap manager current step if enabled - if self.swap_manager is not None: - self.swap_manager.current_step = self.state.global_step - # Forward and compute loss (loss, _), grads = self.grad_fn(model, inputs) if self.use_parallel: @@ -798,6 +789,10 @@ class Trainer: # Optimizer step self.optimizer(grads) + # Finish swap cycle for this training step + if self.swap_manager is not None: + self.swap_manager.finish() + return loss, grad_norm def _get_data_parallel(self): diff --git a/mindformers/pynative/transformers/transformer_block.py b/mindformers/pynative/transformers/transformer_block.py index a12f1fb8c..987599486 100644 --- a/mindformers/pynative/transformers/transformer_block.py +++ b/mindformers/pynative/transformers/transformer_block.py @@ -164,8 +164,7 @@ class TransformerBlock(nn.Cell): prefix_kv = prefix_keys_values[index] if prefix_keys_values is not None else None if swap_manager is not None: - iter_id = swap_manager.current_step * self.num_layers + index - swap_manager.set_iteration(iter_id) + swap_manager.set_iteration(index) hidden_states, _, layer_aux_loss = layer( hidden_states, -- Gitee From 81d8452c8182702058de7587c9fc89ca243be65e Mon Sep 17 00:00:00 2001 From: wangchengzhao Date: Tue, 10 Mar 2026 14:49:16 +0800 Subject: [PATCH 4/8] Update swap logger. --- mindformers/pynative/memory/swap_manager.py | 2 +- mindformers/pynative/memory/swap_tensor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mindformers/pynative/memory/swap_manager.py b/mindformers/pynative/memory/swap_manager.py index 79759ec14..054ae9554 100644 --- a/mindformers/pynative/memory/swap_manager.py +++ b/mindformers/pynative/memory/swap_manager.py @@ -31,7 +31,7 @@ from mindspore.profiler import mstx from .swap_tensor import SwapTensor -logger = logging.getLogger(__name__) +from mindformers.tools.logger import logger process = psutil.Process() diff --git a/mindformers/pynative/memory/swap_tensor.py b/mindformers/pynative/memory/swap_tensor.py index 95fc2a471..7a8a7feba 100644 --- a/mindformers/pynative/memory/swap_tensor.py +++ b/mindformers/pynative/memory/swap_tensor.py @@ -19,7 +19,7 @@ import mindspore as ms from mindspore import Tensor from mindspore import runtime -logger = logging.getLogger(__name__) +from mindformers.tools.logger import logger process = psutil.Process() -- Gitee From 1f1d1892613b33e58e3c7cf3327f7c8e2035f36e Mon Sep 17 00:00:00 2001 From: userA Date: Tue, 10 Mar 2026 15:27:34 +0800 Subject: [PATCH 5/8] Update hook use. 
--- mindformers/pynative/trainer/trainer.py | 4 +-- .../transformers/transformer_block.py | 29 ++++++++++++------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/mindformers/pynative/trainer/trainer.py b/mindformers/pynative/trainer/trainer.py index 94cca8570..b43ddf01c 100644 --- a/mindformers/pynative/trainer/trainer.py +++ b/mindformers/pynative/trainer/trainer.py @@ -473,11 +473,9 @@ class Trainer: # n_iter = num_layers (each training step is a complete swap cycle) self.swap_manager = self.swap_config.create_manager(n_iter=num_layers) - # Attach swap_manager to TransformerBlock and register hooks for each layer + # Attach swap_manager to TransformerBlock decoder = self.model.model.decoder decoder.swap_manager = self.swap_manager - for layer_idx, layer in enumerate(decoder.layers): - layer.register_saved_tensors_hooks(self.swap_manager.pack_hook, self.swap_manager.unpack_hook) logger.info(f"Memory swap enabled: {num_layers} layers, prefetch_interval={self.swap_config.prefetch_interval}") diff --git a/mindformers/pynative/transformers/transformer_block.py b/mindformers/pynative/transformers/transformer_block.py index 987599486..fcc446fc8 100644 --- a/mindformers/pynative/transformers/transformer_block.py +++ b/mindformers/pynative/transformers/transformer_block.py @@ -165,18 +165,25 @@ class TransformerBlock(nn.Cell): if swap_manager is not None: swap_manager.set_iteration(index) - - hidden_states, _, layer_aux_loss = layer( - hidden_states, - attention_mask, - rotary_pos_emb=rotary_pos_emb, - prefix_keys_values=prefix_kv, - actual_seq_len=actual_seq_len, - input_ids=input_ids, - ) - - if swap_manager is not None: + with ms.saved_tensors_hooks(swap_manager.pack_hook, swap_manager.unpack_hook): + hidden_states, _, layer_aux_loss = layer( + hidden_states, + attention_mask, + rotary_pos_emb=rotary_pos_emb, + prefix_keys_values=prefix_kv, + actual_seq_len=actual_seq_len, + input_ids=input_ids, + ) swap_manager.end_iteration() + else: + hidden_states, _, layer_aux_loss = layer( + hidden_states, + attention_mask, + rotary_pos_emb=rotary_pos_emb, + prefix_keys_values=prefix_kv, + actual_seq_len=actual_seq_len, + input_ids=input_ids, + ) if layer_aux_loss is not None: aux_loss = layer_aux_loss if aux_loss is None else aux_loss + layer_aux_loss -- Gitee From 1a565658be53d0e6585997533fcbef18cf94c1fc Mon Sep 17 00:00:00 2001 From: wangchengzhao Date: Tue, 10 Mar 2026 16:15:58 +0800 Subject: [PATCH 6/8] Update minor bug. 
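The swap_manager.iteration() context manager adopted here is equivalent to the
explicit calls it replaces (plus per-iteration timing when enable_stats is on).
A simplified sketch, with identifiers as used in TransformerBlock.construct and
the layer call abbreviated:

    import mindspore as ms

    # what `with swap_manager.iteration(index):` expands to for each layer
    swap_manager.set_iteration(index)
    with ms.saved_tensors_hooks(swap_manager.pack_hook, swap_manager.unpack_hook):
        hidden_states, _, layer_aux_loss = layer(hidden_states, attention_mask)
    swap_manager.end_iteration()
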
--- mindformers/pynative/memory/swap_manager.py | 3 ++- mindformers/pynative/trainer/trainer.py | 1 - mindformers/pynative/transformers/transformer_block.py | 4 +--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/mindformers/pynative/memory/swap_manager.py b/mindformers/pynative/memory/swap_manager.py index 054ae9554..87097761a 100644 --- a/mindformers/pynative/memory/swap_manager.py +++ b/mindformers/pynative/memory/swap_manager.py @@ -483,6 +483,7 @@ class SwapManager: mem_info = psutil.virtual_memory() device_gb = ms.runtime.memory_allocated() / (1024 * 1024 * 1024) + reserved_gb = ms.runtime.memory_reserved() / (1024 * 1024 * 1024) cpu_gb = process.memory_info().rss / (1024 * 1024 * 1024) system_total_gb = mem_info.total / (1024 * 1024 * 1024) system_free_gb = mem_info.free / (1024 * 1024 * 1024) @@ -510,6 +511,6 @@ class SwapManager: f"H2D wait={self.h2d_wait_count} times {self.h2d_wait_time * 1000:.1f}ms" ) logger.info( - f"Memory: device={device_gb:.2f}GB, cpu={cpu_gb:.2f}GB, " + f"Memory: device={device_gb:.2f}GB, reserved={reserved_gb:.2f}GB, cpu={cpu_gb:.2f}GB, " f"system_total={system_total_gb:.2f}GB, system_free={system_free_gb:.2f}GB" ) diff --git a/mindformers/pynative/trainer/trainer.py b/mindformers/pynative/trainer/trainer.py index b43ddf01c..c3c9be396 100644 --- a/mindformers/pynative/trainer/trainer.py +++ b/mindformers/pynative/trainer/trainer.py @@ -467,7 +467,6 @@ class Trainer: # Initialize swap manager if configured if self.swap_config is not None: - from mindformers.pynative.memory import SwapManager # Get number of layers from model num_layers = getattr(self.model.model.decoder, 'num_layers', 1) # n_iter = num_layers (each training step is a complete swap cycle) diff --git a/mindformers/pynative/transformers/transformer_block.py b/mindformers/pynative/transformers/transformer_block.py index fcc446fc8..147dead60 100644 --- a/mindformers/pynative/transformers/transformer_block.py +++ b/mindformers/pynative/transformers/transformer_block.py @@ -164,8 +164,7 @@ class TransformerBlock(nn.Cell): prefix_kv = prefix_keys_values[index] if prefix_keys_values is not None else None if swap_manager is not None: - swap_manager.set_iteration(index) - with ms.saved_tensors_hooks(swap_manager.pack_hook, swap_manager.unpack_hook): + with swap_manager.iteration(index): hidden_states, _, layer_aux_loss = layer( hidden_states, attention_mask, @@ -174,7 +173,6 @@ class TransformerBlock(nn.Cell): actual_seq_len=actual_seq_len, input_ids=input_ids, ) - swap_manager.end_iteration() else: hidden_states, _, layer_aux_loss = layer( hidden_states, -- Gitee From 64149efe426e460e7395bddb3700295ba9b8cd4e Mon Sep 17 00:00:00 2001 From: wang_cheng_zhao Date: Tue, 10 Mar 2026 17:04:54 +0800 Subject: [PATCH 7/8] Fix npu memory leak. --- mindformers/pynative/memory/swap_manager.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/mindformers/pynative/memory/swap_manager.py b/mindformers/pynative/memory/swap_manager.py index 87097761a..a1c730b3f 100644 --- a/mindformers/pynative/memory/swap_manager.py +++ b/mindformers/pynative/memory/swap_manager.py @@ -398,13 +398,29 @@ class SwapManager: Prints swap statistics and clears all stored swap tensors, resetting the state for potential reuse. 
""" + # Wait for pending transfers before cleanup + if self.last_d2h_iter is not None and self.last_d2h_iter in self.prefetch_list: + prev_tensors = self.prefetch_list[self.last_d2h_iter] + if prev_tensors and prev_tensors[0].stat == "d2h": + runtime.current_stream().wait_stream(self.transfer_stream) + for swap_tensor in prev_tensors: + swap_tensor.wait_d2h_finished(need_wait=False) + + if self.last_prefetched_iter is not None and self.last_prefetched_iter in self.prefetch_list: + prev_tensors = self.prefetch_list[self.last_prefetched_iter] + if prev_tensors and prev_tensors[0].stat == "h2d": + runtime.current_stream().wait_stream(self.transfer_stream) + for swap_tensor in prev_tensors: + swap_tensor.wait_h2d_finished(need_wait=False) + self.print_stats() - # Explicitly release all tensor references in SwapTensor objects + # Explicitly release all tensor and event references for iter_tensors in self.prefetch_list.values(): for swap_tensor in iter_tensors: swap_tensor.tensor = None swap_tensor.tensor_cpu = None + swap_tensor.forward_event = None self.prefetch_list.clear() self.prefetched_iters.clear() -- Gitee From 1a5a78b0f6d47c008516d4cb7775ba9187e20371 Mon Sep 17 00:00:00 2001 From: wangchengzhao Date: Thu, 12 Mar 2026 14:27:17 +0800 Subject: [PATCH 8/8] Clean print to avoid npu memory leak. --- mindformers/pynative/transformers/transformer_layer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mindformers/pynative/transformers/transformer_layer.py b/mindformers/pynative/transformers/transformer_layer.py index 40bb092ab..bde32cfac 100644 --- a/mindformers/pynative/transformers/transformer_layer.py +++ b/mindformers/pynative/transformers/transformer_layer.py @@ -218,8 +218,8 @@ class TransformerLayer(nn.Cell, BaseTransformerLayer): def expand_hidden_states(hidden_states): hidden_states = hidden_states.unsqueeze(axis=2) return hidden_states.repeat(1, 1, self.mhc_expansion_rate, 1).astype(float32) - print(f"-------------- Transformer Layer {self.layer_number} ----------------") - print(f"@ Original hidden_states AMax: {ops.abs(hidden_states).max()} {ops.abs(hidden_states).mean()}") + # print(f"-------------- Transformer Layer {self.layer_number} ----------------") + # print(f"@ Original hidden_states AMax: {ops.abs(hidden_states).max()} {ops.abs(hidden_states).mean()}") # Before layernorm # if self.apply_manifold_constrained_hyper_connections: @@ -331,5 +331,5 @@ class TransformerLayer(nn.Cell, BaseTransformerLayer): output = self.add(residual, dropout_output) # Note: context parameter is returned for API compatibility but currently unused. # It may be deprecated in future versions. - print(f"@ transformer layer output AMax: {ops.abs(output).max()} {ops.abs(output).mean()}") + # print(f"@ transformer layer output AMax: {ops.abs(output).max()} {ops.abs(output).mean()}") return output, context, aux_loss -- Gitee