diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index 95049036d..bf7c29e4f 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -151,10 +151,18 @@ cur_iter=-1, ) +# cpu_offloading = dict( +# enable=True, +# num_layers=3, +# ) +# selective_checkpoint = True +# selective_checkpoint_offload = False + use_fp32_norm = False model = dict( checkpoint=False, # The proportion of layers for activation checkpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, diff --git a/internlm/core/parallel/comm/cpu_offload.py b/internlm/core/parallel/comm/cpu_offload.py new file mode 100644 index 000000000..89e5912b3 --- /dev/null +++ b/internlm/core/parallel/comm/cpu_offload.py @@ -0,0 +1,505 @@ +# Adapted from https://github.com/NVIDIA/TransformerEngine/blob/v1.12/transformer_engine/pytorch/cpu_offload.py +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Functionality for CPU offloading of tensors saved for backward pass.""" +from __future__ import annotations + +from contextlib import nullcontext +from typing import Any, Dict, Optional + +import torch + +__all__ = ["get_cpu_offload_context"] + +CPUOffloadEnabled = False + + +def is_cpu_offload_enabled() -> bool: + """Check if CPU offloading is currently enabled.""" + return CPUOffloadEnabled + + +class CpuOffloadSavedTensorHook: + """Context-manager that executes a pair of pack/unpack hooks for saved tensors. + In this context, the ``on_save_for_backward`` method will be called every time + a tensor is saved for backward (this includes intermediary results saved using + :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but + also those recorded by a PyTorch-defined operation). + The ``on_get_saved_tensor`` method will be called when the backward function + of this op attempts to retrieve the saved tensor from context (this includes + :func:`torch.Tensor.backward()` or :func:`torch.autograd.grad()`). It takes + as input the return value of the ``on_save_for_backward``, and is meant to return + an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of + size, device and element values. + Example: + >>> import torch + >>> from typing import Any + >>> + >>> class DummyHook(CpuOffloadSavedTensorHook): + ... + ... def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + ... logging.info("On save", tensor) + ... return (tensor,) + ... + ... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + ... logging.info("On get", saved_state) + ... tensor, = saved_state + ... return tensor + ... + >>> a = torch.ones(5, requires_grad=True) + >>> b = torch.ones(5, requires_grad=True) * 2 + >>> with DummyHook(): + ... y = a * b + ... 
+ On save tensor([1., 1., 1., 1., 1.], requires_grad=True) + On save tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>) + >>> y.sum().backward() + On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),) + On get (tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>),) + """ + + def __init__(self) -> None: + self.inside_context = False + + def __enter__(self): + global CPUOffloadEnabled + CPUOffloadEnabled = True + + self.inside_context = True + torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor) + + def __exit__(self, *args: Any): + global CPUOffloadEnabled + CPUOffloadEnabled = False + + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + """On save for backward.""" + raise NotImplementedError( + "`on_save_for_backward: Callable[[torch.Tensor], Any]`" + "is not implemented in CpuOffloadHook class. Inherit " + "this class and implement your custom hooks" + ) + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + """On get saved tensor.""" + raise NotImplementedError( + "`on_get_saved_tensor: Callable[[Any], torch.Tensor]`" + "is not implemented in CpuOffloadHook class. Inherit " + "this class and implement your custom hooks" + ) + + +class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): + """Context-manager that offloads/recovers tensors through an offload handler. + The hook just offloads/recovers the tensor object to the handler through the `tensor_push` + and `tensor_pop` interface. How the offload-handler manages the offloading, recovering + or prefetching timing is transparent to this hook. + """ + + def __init__( + self, + offload_handler: OffloadHandler, + handler_extra_kwargs: Optional[Dict[str, Any]] = None, + debug: bool = False, + ) -> None: + if handler_extra_kwargs is None: + handler_extra_kwargs = {} + self.debug: bool = debug + self.offload_handler: OffloadHandler = offload_handler + self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs + super().__init__() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) + return retrieve_identifier + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) + return tensor + + +class OffloadHandler: + """A base class for CPU offload-handler.""" + + def __init__(self) -> None: + pass + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + """Tensor push.""" + raise NotImplementedError( + "`tensor_push` is not implemented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_push." + ) + + def tensor_pop(self, tensor_tag: Any, **kwargs): + """Tensor pop.""" + raise NotImplementedError( + "`tensor_pop` is not implemented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_pop." + ) + + +class GroupCommitFunction(torch.autograd.Function): + """This is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for the offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to take actions in both forward and backward. 
+ """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring + cpu_offload_handler.on_group_commit_forward() + ctx.cpu_offload_handler = cpu_offload_handler + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None + + +group_prefetch_offload_commit = GroupCommitFunction.apply + + +class SynchronizedGroupOffloadHandler(OffloadHandler): + """Offload Handler that offloads/reloads in a synchronized way. + The device-to-host and host-to-device copying happen in the same stream + as the computation kernels, thus the copying will block computation. + """ + + def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False) -> None: + super().__init__() + + self.num_offload_group = num_offload_group + self.tensor_need_offloading_checker = tensor_need_offloading_checker + self.debug = debug + + self.groupid_reset() + + def groupid_reset(self): + """Groupid reset.""" + # Data structures to label saved tensors and book-keep their cpu copies. + # Currently, on push, create a new cpu tensor and copies; on pop, copies + # the tensor back to gpu and deletes the cpu tensor. + # These will increment whenever `group_commit()` is invoked + self.current_group, self.tensor_count_current_group = (0, 0) + self.torch_tensor_count = 0 + self.tensor_tag_to_state = {} + + def on_group_commit_forward(self): + """On group commit forward.""" + # finishing up with updating current group and tensor count + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset + + def on_group_commit_backward(self): + """On group commit backward.""" + self.current_group -= 1 + assert self.current_group >= 0 + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=src_tensor.dtype, + layout=src_tensor.layout, + device="cpu", + pin_memory=pin_memory, + ) + + cpu_backup.copy_(src_tensor, non_blocking=pin_memory) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def tensor_push(self, tensor: torch.Tensor, **kwargs): + """Tensor push.""" + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): + state = SynchronizedGroupOffloadHandler.offload(tensor) + self.tensor_tag_to_state[tensor_tag] = state + else: + # will be offloaded together after group commit + self.tensor_tag_to_state[tensor_tag] = tensor + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + state = self.tensor_tag_to_state.pop(tensor_tag) + if isinstance(state, tuple): + tensor = SynchronizedGroupOffloadHandler.reload(state) + else: + tensor = state + return tensor + + +class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): + """Compared to synchronize, this uses more memory because of the buffer but + achieves 
better performance due to the overlapping. D2h and h2d copying are + completely hidden behind computation if computation time of a layer is longer + than host-device communication time. Bulk offloading with delay and bulk reloading + with prefetch are implemented.""" + + def __init__( + self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_model_group, + tensor_need_offloading_checker=(lambda t: True), + debug=False, + ) -> None: + super().__init__( + num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + debug=debug, + ) + # Number of layers in the model + self.num_layers = num_model_group + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant + + if torch.distributed.get_rank() == 0: + print( + f"Offloading {self.num_offload_group} layers' activations with " + f"layer_window_map:{self.layer_window_map}", + flush=True, + ) + + # allocate streams and events for synchronization + self.d2h_stream = torch.cuda.Stream() + self.h2d_stream = torch.cuda.Stream() + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + torch_stray_tensor = False + + # torch2.4 + # torch_stray_tensor = isinstance( + # tensor, + # ( + # torch._subclasses.fake_tensor.FakeTensor, + # torch._subclasses.functional_tensor.FunctionalTensor, + # ), + # ) + + if not torch_stray_tensor: + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + + self.tensor_tag_to_state[tensor_tag] = tensor + + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): + self.tensor_tag_to_buf[tensor_tag] = tensor + else: + tensor_tag = (-1, self.torch_tensor_count) + self.torch_tensor_count += 1 + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + tensor = self.tensor_tag_to_state.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) + # the tensor should have been copied back in on_group_commit_backward() + # which invokes bulk_reload_group. 
+ assert not isinstance(tensor, tuple) + return tensor + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + with torch.cuda.stream(self.d2h_stream): + for tensor_tag, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + tensor_on_device = state + + # if offload, return the reference to cpu copy + if self.tensor_need_offloading_checker(tensor_on_device): + state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) + self.tensor_tag_to_state[tensor_tag] = state + + def synchronize_on_group_commit_forward(self, current_group): + """Synchronize on group commit forward.""" + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + self.bulk_offload_group(current_group) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + # e.g. layer_window_map={0: 10, 1: 21, 2: 31} + if self.layer_window_map[self.offloaded_group_count] == current_group: + + # Stream synchronization both ways + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, _ in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 + + def on_group_commit_forward(self): + """This function will cause host device synchronization""" + # handle synchronization events + self.synchronize_on_group_commit_forward(self.current_group) + + super().on_group_commit_forward() + + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + assert group_to_reload < self.num_offload_group + + with torch.cuda.stream(self.h2d_stream): + # move back tensors + for tensor_label, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload: + if isinstance(state, tuple): + recovered_tensor = SynchronizedGroupOffloadHandler.reload(state) + self.tensor_tag_to_state[tensor_label] = recovered_tensor + + def on_group_commit_backward(self): + # first decrement the current group. + # after last commit in forward, the group will +1; in backward it -1. + # Finally it should be decremented to 0. + self.current_group -= 1 + assert self.current_group >= 0 + + # Layer window data structure helps us to reload at right times + # e.g. 
layer_window_map={0: 10, 1: 21, 2: 31} + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: + + # Stream synchronization both ways + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.h2d_stream) + + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) + + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 + + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + torch.cuda.current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 + + +def get_cpu_offload_context( + enabled: bool = False, + num_layers: int = 1, + model_layers: int = 1, + offload_activations: bool = False, + offload_weights: bool = False, +): + """ + This function returns the CPU Offload context and the synchronizer function that needs to be + used after every transformer layer. Returns `nullcontext()` if offloading is not enabled. + Usage: + .. code-block:: python + cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True) + with cpu_offload_context: + te_layer.forward(inp_tensor) + cpu_offload_synchronizer() + Parameters + ---------- + enabled: bool, default = `False` + When set to True, CPU Offloading functionality is enabled. + num_layers: int, default = 1 + Determines the number of transformer layers + you want to offload activations/weights for. + model_layers: int, default = 1 + Number of layers in the model that will be used under this context. + offload_activations: bool, default = `False` + When set to `True`, offloads the tensors with attribute 'activation_offloading' for the layer. + offload_weights: bool, default = `False` + When set to `True`, offloads the weights with attribute 'weight_offloading' for the layer. 
+ """ + + def tensor_need_offloading_checker_base(tensor): # pylint: disable=W0613 + return True + + def tensor_need_offloading_checker_activations(tensor): + return hasattr(tensor, "activation_offloading") + + # This includes the Gradient Accumulation Buffer + def tensor_need_offloading_checker_weights(tensor): + return hasattr(tensor, "weight_offloading") + + def tensor_need_offloading_checker_all(tensor): + return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading") + + if offload_activations and offload_weights: + tensor_need_offloading_checker = tensor_need_offloading_checker_all + elif offload_activations: + tensor_need_offloading_checker = tensor_need_offloading_checker_activations + elif offload_weights: + tensor_need_offloading_checker = tensor_need_offloading_checker_weights + else: + tensor_need_offloading_checker = tensor_need_offloading_checker_base + + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( + num_offload_group=num_layers, + num_model_group=model_layers, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + + def group_prefetch_offload_commit_async(tensor): + return group_prefetch_offload_commit(tensor, cpu_offload_handler) + + if enabled: + return ( + CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), + group_prefetch_offload_commit_async, + ) + return nullcontext(), group_prefetch_offload_commit_async diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index e1cb2f0d2..a1cf2022d 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -83,6 +83,11 @@ def args_sanity_check(): num_layers = gpc.config.model.num_layers gpc.config.isp_num_layers = num_layers + if "cpu_offloading" not in gpc.config: + gpc.config._add_item("cpu_offloading", dict(enable=False, num_layers=0)) + if gpc.config.cpu_offloading.enable is False: + assert gpc.config.cpu_offloading.num_layers == 0, "num_layers should be 0 when cpu_offloading is disabled." + if "use_apex_adam" not in gpc.config: gpc.config._add_item("use_apex_adam", False) diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 0453b9dcb..e15b9979a 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -1,6 +1,7 @@ # Copyright (c) InternLM. All rights reserved. 
import math import os +from contextlib import nullcontext from functools import reduce from typing import Optional @@ -11,6 +12,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.parallel.comm.cpu_offload import get_cpu_offload_context from internlm.core.parallel.shard import partition_uniform from internlm.initialize.initialize_tensor import ( normal_, @@ -387,6 +389,16 @@ def __init__( checkpoint_layer_num = int(num_layers * checkpoint) self.embed_grad_scale = embed_grad_scale self.parallel_output = parallel_output + self.enable_cpu_offloading = gpc.config.cpu_offloading.enable + + if self.enable_cpu_offloading: + (self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context( + gpc.config.cpu_offloading.enable, + gpc.config.cpu_offloading.num_layers, + gpc.config.model.num_layers, + ) + else: + self.offload_context, self.group_prefetch_offload_commit_async = nullcontext(), None if first: self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) @@ -409,7 +421,7 @@ def __init__( max_position_embeddings=max_position_embeddings, dtype=dtype, layer_norm_epsilon=layer_norm_epsilon, - checkpoint=lid < checkpoint_layer_num, + checkpoint=gpc.config.cpu_offloading.num_layers <= lid < checkpoint_layer_num, layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation use_dynamic_ntk_rope=use_dynamic_ntk_rope, residual_in_fp32=residual_in_fp32, @@ -467,7 +479,15 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs): ) for _, block in enumerate(self.layers): - hidden_states = block(hidden_states, residual=None, **kwargs) + with self.offload_context: + hidden_states = block(hidden_states, residual=None, **kwargs) + + if ( + torch.is_grad_enabled() + and self.enable_cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index f40d35f32..71ed5ef7e 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- - import math +from contextlib import nullcontext from typing import Optional import torch @@ -9,6 +9,7 @@ from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.parallel.comm.cpu_offload import get_cpu_offload_context from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal from internlm.model.base_model import BaseModel from internlm.model.modules.embedding import Embedding1D @@ -319,6 +320,16 @@ def __init__( super().__init__() checkpoint_layer_num = int(num_layers * checkpoint) + self.enable_cpu_offloading = gpc.config.cpu_offloading.enable + + if self.enable_cpu_offloading: + (self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context( + gpc.config.cpu_offloading.enable, + gpc.config.cpu_offloading.num_layers, + gpc.config.model.num_layers, + ) + else: + self.offload_context, self.group_prefetch_offload_commit_async = nullcontext(), None if first: self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) @@ -337,7 +348,7 @@ def __init__( 
max_position_embeddings=max_position_embeddings, dtype=dtype, layer_norm_epsilon=layer_norm_epsilon, - checkpoint=lid < checkpoint_layer_num, + checkpoint=gpc.config.cpu_offloading.num_layers <= lid < checkpoint_layer_num, layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation use_dynamic_ntk_rope=use_dynamic_ntk_rope, residual_in_fp32=residual_in_fp32, @@ -386,8 +397,16 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs): moe_losses = [] for _, block in enumerate(self.blocks): - hidden_states, mos_loss = block(hidden_states, **kwargs) - moe_losses.append(mos_loss) + with self.offload_context: + hidden_states, mos_loss = block(hidden_states, **kwargs) + moe_losses.append(mos_loss) + + if ( + torch.is_grad_enabled() + and self.enable_cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) diff --git a/internlm/model/ops/_flash_attn.py b/internlm/model/ops/_flash_attn.py index 87aac2eb8..1d1416d94 100644 --- a/internlm/model/ops/_flash_attn.py +++ b/internlm/model/ops/_flash_attn.py @@ -50,8 +50,9 @@ def forward( k, v = kv[:, 0], kv[:, 1] _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) + _is_ckpt_layer = gpc.config.cpu_offloading.num_layers <= layer_idx < _ckpt_block_num - if gpc.is_forward is False and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + if gpc.is_forward is False and gpc.config.selective_checkpoint and _is_ckpt_layer: out, out_padded, softmax_lse, S_dmask, rng_state = get_offload_manager().get_fa_output_with_layer(layer_idx) else: ( @@ -82,7 +83,7 @@ def forward( ) # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled - if gpc.is_forward and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + if gpc.is_forward and gpc.config.selective_checkpoint and _is_ckpt_layer: get_offload_manager().insert_fa_output_with_layer( layer_idx=layer_idx, output=(out, out_padded, softmax_lse, S_dmask, rng_state) ) @@ -159,8 +160,9 @@ def forward( k, v = kv[:, 0], kv[:, 1] _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) + _is_ckpt_layer = gpc.config.cpu_offloading.num_layers <= layer_idx < _ckpt_block_num - if gpc.is_forward is False and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + if gpc.is_forward is False and gpc.config.selective_checkpoint and _is_ckpt_layer: out, out_padded, softmax_lse, S_dmask, rng_state = get_offload_manager().get_fa_output_with_layer(layer_idx) else: out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( @@ -178,7 +180,7 @@ def forward( ) # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled - if gpc.is_forward and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + if gpc.is_forward and gpc.config.selective_checkpoint and _is_ckpt_layer: get_offload_manager().insert_fa_output_with_layer( layer_idx=layer_idx, output=(out, out_padded, softmax_lse, S_dmask, rng_state) ) diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 967398e17..128e25739 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -471,16 +471,16 @@ def test_training_with_isp(): global CONFIG_FILE_PATH, BASELINE_LOSS_LIST CONFIG_FILE_PATH = "./configs/7B_isp_sft.py" 
BASELINE_LOSS_LIST = [ - 12.225811004638672, - 12.103824615478516, - 12.223844528198242, - 11.87704849243164, - 11.651590347290039, - 11.629219055175781, - 10.242591857910156, - 9.768388748168945, - 9.330610275268555, - 5.505439758300781, + 12.159960746765137, + 12.22106647491455, + 12.106496810913086, + 11.951896667480469, + 11.644429206848145, + 11.459924697875977, + 10.127229690551758, + 9.795705795288086, + 9.255647659301758, + 5.301709175109863, ] # model training
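Usage note (not part of the diff): below is a minimal sketch of how the new option could be enabled in a training config, based on the commented-out block added to configs/7B_isp_sft.py and the sanity check added in internlm/initialize/launch.py; the exact values are illustrative only.

# Enable asynchronous CPU offloading of activations for the first 3 transformer layers.
# When enable=False, launch.py asserts that num_layers == 0.
cpu_offloading = dict(
    enable=True,
    num_layers=3,
)
# Optional, pre-existing knobs that interact with offloading: keep FlashAttention
# forward outputs for checkpointed layers instead of recomputing them
# (see internlm/model/ops/_flash_attn.py).
selective_checkpoint = True
selective_checkpoint_offload = False

Per the guards added in this diff (checkpoint=gpc.config.cpu_offloading.num_layers <= lid < checkpoint_layer_num), layers whose activations are offloaded are excluded from activation checkpointing and from the selective-checkpoint FlashAttention cache.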