From c898e4e5365b54c94aca33a8b852bdd94f5e0843 Mon Sep 17 00:00:00 2001
From: kohya-ss
Date: Mon, 17 Mar 2025 21:28:02 +0900
Subject: [PATCH] fix: optimize weight device swapping with no_grad context

Wrap the CUDA stream-based weight swap in torch.no_grad() so the raw
.data reassignments and in-place copies are kept out of autograd, and
pass the target device to the synchronize_device() calls in the
non-CUDA path.
---
 library/custom_offloading_utils.py | 27 ++++++++++++++-------------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py
index 84c2b743e..4d52cfb4d 100644
--- a/library/custom_offloading_utils.py
+++ b/library/custom_offloading_utils.py
@@ -42,19 +42,20 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
 
     torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
 
-    stream = torch.cuda.Stream()
-    with torch.cuda.stream(stream):
-        # cuda to cpu
-        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
-            cuda_data_view.record_stream(stream)
-            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
+    with torch.no_grad():
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            # cuda to cpu
+            for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+                cuda_data_view.record_stream(stream)
+                module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
 
-        stream.synchronize()
+            stream.synchronize()
 
-        # cpu to cuda
-        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
-            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
-            module_to_cuda.weight.data = cuda_data_view
+            # cpu to cuda
+            for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+                cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
+                module_to_cuda.weight.data = cuda_data_view
 
     stream.synchronize()
     torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
@@ -75,14 +76,14 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
     for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
         module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
 
-    synchronize_device()
+    synchronize_device(device)
 
     # cpu to device
     for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
         cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
         module_to_cuda.weight.data = cuda_data_view
 
-    synchronize_device()
+    synchronize_device(device)
 
 
 def weighs_to_device(layer: nn.Module, device: torch.device):
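
A minimal harness to sanity-check the patched swap, for reference. The
import path and the swap_weight_devices_cuda signature come from the
patched file; the layer sizes and the checks themselves are assumptions,
not part of the repo:

    import torch
    import torch.nn as nn

    from library.custom_offloading_utils import swap_weight_devices_cuda

    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Two layers of the same class: one resident on GPU (about to be
        # offloaded), one resident on CPU (about to reuse that GPU buffer).
        layer_to_cpu = nn.Linear(1024, 1024).to(device)
        layer_to_cuda = nn.Linear(1024, 1024)
        w_offloaded = layer_to_cpu.weight.detach().clone()
        w_restored = layer_to_cuda.weight.detach().clone()

        swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda)

        # The weights should have traded devices...
        assert layer_to_cpu.weight.device.type == "cpu"
        assert layer_to_cuda.weight.device.type == "cuda"
        # ...and their values should survive the round trip through the
        # shared GPU allocation.
        torch.testing.assert_close(layer_to_cpu.weight.detach().cpu(), w_offloaded.cpu())
        torch.testing.assert_close(layer_to_cuda.weight.detach().cpu(), w_restored)
        print("swap OK")

Because the swap reuses the departing layer's GPU allocation instead of
allocating a fresh one, record_stream() plus the stream.synchronize()
between the two copy loops is what keeps the refill from overwriting the
buffer before the GPU-to-CPU copies complete.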