From 34b9ceb846352bc5660f3f8e2d849917d0bdcb16 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Tue, 19 Nov 2024 11:23:16 -0800
Subject: [PATCH 1/3] fix

---
 .../llama3_2_vision/_component_builders.py    | 21 +------------------
 .../models/llama3_2_vision/_model_builders.py |  2 --
 torchtune/modules/peft/dora.py                |  2 +-
 torchtune/modules/peft/lora.py                |  2 +-
 4 files changed, 3 insertions(+), 24 deletions(-)

diff --git a/torchtune/models/llama3_2_vision/_component_builders.py b/torchtune/models/llama3_2_vision/_component_builders.py
index 3de323d368..4f3e6403e0 100644
--- a/torchtune/models/llama3_2_vision/_component_builders.py
+++ b/torchtune/models/llama3_2_vision/_component_builders.py
@@ -338,7 +338,6 @@ def lora_llama3_2_vision_encoder(
     fusion_lora: bool,
     lora_attn_modules: List[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
-    apply_lora_to_output: bool = False,
     *,
     # clip encoder parameters
     patch_size: int,
@@ -377,8 +376,6 @@ def lora_llama3_2_vision_encoder(
             ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
         apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
             Default: False
-        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
-            Default: False
         patch_size (int): The size of each patch. Used to divide the tiles into patches.
             E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches
             with shape (40, 40) each.
@@ -412,7 +409,6 @@ def lora_llama3_2_vision_encoder(
     lora_options = {
         "lora_modules": lora_attn_modules,
         "apply_lora_to_mlp": apply_lora_to_mlp,
-        "apply_lora_to_output": apply_lora_to_output,
         "lora_rank": lora_rank,
         "lora_alpha": lora_alpha,
         "lora_dropout": lora_dropout,
@@ -679,7 +675,6 @@ def lora_llama3_2_vision_projection_head(
     num_hidden_inputs: int,
     # LoRA args
     apply_lora_to_mlp: bool,
-    apply_lora_to_output: bool,
     lora_rank: int,
     lora_alpha: float,
     lora_dropout: float = 0.0,
@@ -701,8 +696,6 @@ def lora_llama3_2_vision_projection_head(
         num_hidden_inputs (int): number of hidden inputs to the projection head.
         apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
             Default: False
-        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
-            Default: False
         lora_rank (int): rank of each low-rank approximation
         lora_alpha (float): scaling factor for the low-rank approximation
         lora_dropout (float): LoRA dropout probability. Default: 0.0
@@ -773,19 +766,7 @@ def lora_llama3_2_vision_projection_head(
     # cross encoding
     # TODO: quantize_base is not applied to final output_proj currently.
     proj_in = clip_embed_dim * (num_hidden_inputs + 1)
-    adapter_cls = DoRALinear if use_dora else LoRALinear
-    output_proj = (
-        adapter_cls(
-            proj_in,
-            decoder_embed_dim,
-            rank=lora_rank,
-            alpha=lora_alpha,
-            dropout=lora_dropout,
-            use_bias=True,
-        )
-        if apply_lora_to_output
-        else nn.Linear(proj_in, decoder_embed_dim)
-    )
+    output_proj = nn.Linear(proj_in, decoder_embed_dim)
     return Llama3VisionProjectionHead(
         layers=layers,
         output=output_proj,
diff --git a/torchtune/models/llama3_2_vision/_model_builders.py b/torchtune/models/llama3_2_vision/_model_builders.py
index d13ff2dcc4..91e54781af 100644
--- a/torchtune/models/llama3_2_vision/_model_builders.py
+++ b/torchtune/models/llama3_2_vision/_model_builders.py
@@ -172,7 +172,6 @@ def lora_llama3_2_vision_11b(
         fusion_lora=fusion_type == LoRATrainable.LORA,
         lora_attn_modules=lora_attn_modules,
         apply_lora_to_mlp=apply_lora_to_mlp,
-        apply_lora_to_output=apply_lora_to_output,
         patch_size=14,
         num_heads=16,
         clip_embed_dim=1280,
@@ -330,7 +329,6 @@ def lora_llama3_2_vision_90b(
         fusion_lora=fusion_type == LoRATrainable.LORA,
         lora_attn_modules=lora_attn_modules,
         apply_lora_to_mlp=apply_lora_to_mlp,
-        apply_lora_to_output=apply_lora_to_output,
         patch_size=14,
         num_heads=16,
         clip_embed_dim=1280,
diff --git a/torchtune/modules/peft/dora.py b/torchtune/modules/peft/dora.py
index bc1e5eeb03..6f097da6d0 100644
--- a/torchtune/modules/peft/dora.py
+++ b/torchtune/modules/peft/dora.py
@@ -65,7 +65,7 @@ def __init__(
         self.use_bias = use_bias
         self._quantize_base = quantize_base
 
-        if not self._quantize_base and quantization_kwargs:
+        if not self._quantize_base and any([v for v in quantization_kwargs.values()]):
             raise ValueError(
                 f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}"
             )
diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py
index 138dd0c5ee..e03d854f1f 100644
--- a/torchtune/modules/peft/lora.py
+++ b/torchtune/modules/peft/lora.py
@@ -65,7 +65,7 @@ def __init__(
         self.use_bias = use_bias
         self._quantize_base = quantize_base
 
-        if not self._quantize_base and quantization_kwargs:
+        if not self._quantize_base and any([v for v in quantization_kwargs.values()]):
             raise ValueError(
                 f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}"
             )

From 80abb7c760d28a8510d2cc1d2a0361f7f03e26fe Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Wed, 20 Nov 2024 11:34:49 -0800
Subject: [PATCH 2/3] fix

---
 .../configs/llama3_2_vision/11B_qlora.yaml    |  2 +-
 .../11B_qlora_single_device.yaml              |  2 +-
 recipes/configs/llama3_2_vision/90B_full.yaml |  2 +-
 recipes/configs/llama3_2_vision/90B_lora.yaml |  2 +-
 .../configs/llama3_2_vision/90B_qlora.yaml    |  2 +-
 .../llama3_2_vision/_component_builders.py    | 26 ++++++++++++++++---
 6 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/recipes/configs/llama3_2_vision/11B_qlora.yaml b/recipes/configs/llama3_2_vision/11B_qlora.yaml
index d18209adfe..c934e78008 100644
--- a/recipes/configs/llama3_2_vision/11B_qlora.yaml
+++ b/recipes/configs/llama3_2_vision/11B_qlora.yaml
@@ -87,7 +87,7 @@ metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
 log_every_n_steps: 1
-log_peak_memory_stats: False
+log_peak_memory_stats: True
 
 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
index 8261d8eeac..531f27a52f 100644
--- a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
+++ b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
@@ -87,7 +87,7 @@ metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
 log_every_n_steps: 1
-log_peak_memory_stats: False
+log_peak_memory_stats: True
 
 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_2_vision/90B_full.yaml b/recipes/configs/llama3_2_vision/90B_full.yaml
index 09a7a22769..2ef3c271eb 100644
--- a/recipes/configs/llama3_2_vision/90B_full.yaml
+++ b/recipes/configs/llama3_2_vision/90B_full.yaml
@@ -78,7 +78,7 @@ metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: /tmp/Llama-3.2-90B-Vision-Instruct/logs
 log_every_n_steps: 1
-log_peak_memory_stats: False
+log_peak_memory_stats: True
 
 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_2_vision/90B_lora.yaml b/recipes/configs/llama3_2_vision/90B_lora.yaml
index 14388cc4ea..970c7dab81 100644
--- a/recipes/configs/llama3_2_vision/90B_lora.yaml
+++ b/recipes/configs/llama3_2_vision/90B_lora.yaml
@@ -87,7 +87,7 @@ metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: /tmp/Llama-3.2-90B-Vision-Instruct/logs
 log_every_n_steps: 1
-log_peak_memory_stats: False
+log_peak_memory_stats: True
 
 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_2_vision/90B_qlora.yaml b/recipes/configs/llama3_2_vision/90B_qlora.yaml
index 30810e90b1..888093d574 100644
--- a/recipes/configs/llama3_2_vision/90B_qlora.yaml
+++ b/recipes/configs/llama3_2_vision/90B_qlora.yaml
@@ -86,7 +86,7 @@ metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: /tmp/Llama-3.2-90B-Vision-Instruct/logs
 log_every_n_steps: 1
-log_peak_memory_stats: False
+log_peak_memory_stats: True
 
 # Profiler (disabled)
 profiler:
diff --git a/torchtune/models/llama3_2_vision/_component_builders.py b/torchtune/models/llama3_2_vision/_component_builders.py
index 4f3e6403e0..6db3631444 100644
--- a/torchtune/models/llama3_2_vision/_component_builders.py
+++ b/torchtune/models/llama3_2_vision/_component_builders.py
@@ -338,6 +338,7 @@ def lora_llama3_2_vision_encoder(
     fusion_lora: bool,
     lora_attn_modules: List[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
+    apply_lora_to_output: bool = False,
     *,
     # clip encoder parameters
     patch_size: int,
@@ -376,6 +377,8 @@ def lora_llama3_2_vision_encoder(
             ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
         apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
             Default: False
+        apply_lora_to_output (bool): whether to apply LoRA to the model's decoder and encoder output projection.
+            Default: False
         patch_size (int): The size of each patch. Used to divide the tiles into patches.
             E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches
             with shape (40, 40) each.
@@ -446,7 +449,9 @@ def lora_llama3_2_vision_encoder(
     }
     if fusion_lora:
         projection_head = lora_llama3_2_vision_projection_head(
-            **projection_options, **lora_options
+            apply_lora_to_output=apply_lora_to_output,
+            **projection_options,
+            **lora_options,
         )
     else:
         projection_head = lora_llama3_2_vision_projection_head(**projection_options)
@@ -675,6 +680,7 @@ def lora_llama3_2_vision_projection_head(
     num_hidden_inputs: int,
     # LoRA args
     apply_lora_to_mlp: bool,
+    apply_lora_to_output: bool,
     lora_rank: int,
     lora_alpha: float,
     lora_dropout: float = 0.0,
@@ -695,7 +701,7 @@ def lora_llama3_2_vision_projection_head(
         clip_embed_dim (int): embedding dimension for the CLIP encoder.
         num_hidden_inputs (int): number of hidden inputs to the projection head.
         apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
-            Default: False
+        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
         lora_rank (int): rank of each low-rank approximation
         lora_alpha (float): scaling factor for the low-rank approximation
         lora_dropout (float): LoRA dropout probability. Default: 0.0
@@ -717,7 +723,7 @@ def lora_llama3_2_vision_projection_head(
         lora_modules=lora_modules,
         embed_dim=clip_embed_dim,
         num_heads=num_heads,
-        num_kv_heads=num_heads,
+        num_kv_heads=num_kv_heads,
         head_dim=head_dim,
         attn_dropout=0.0,
         lora_rank=lora_rank,
@@ -766,7 +772,19 @@ def lora_llama3_2_vision_projection_head(
     # cross encoding
     # TODO: quantize_base is not applied to final output_proj currently.
     proj_in = clip_embed_dim * (num_hidden_inputs + 1)
-    output_proj = nn.Linear(proj_in, decoder_embed_dim)
+    adapter_cls = DoRALinear if use_dora else LoRALinear
+    output_proj = (
+        adapter_cls(
+            proj_in,
+            decoder_embed_dim,
+            rank=lora_rank,
+            alpha=lora_alpha,
+            dropout=lora_dropout,
+            use_bias=True,
+        )
+        if apply_lora_to_output
+        else nn.Linear(proj_in, decoder_embed_dim)
+    )
     return Llama3VisionProjectionHead(
         layers=layers,
         output=output_proj,

From c4d155f4068a5c06b3ea55cc1a805eaddbe7f71a Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Wed, 20 Nov 2024 11:47:30 -0800
Subject: [PATCH 3/3] fix

---
 torchtune/models/llama3_2_vision/_model_builders.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchtune/models/llama3_2_vision/_model_builders.py b/torchtune/models/llama3_2_vision/_model_builders.py
index 91e54781af..d13ff2dcc4 100644
--- a/torchtune/models/llama3_2_vision/_model_builders.py
+++ b/torchtune/models/llama3_2_vision/_model_builders.py
@@ -172,6 +172,7 @@ def lora_llama3_2_vision_11b(
         fusion_lora=fusion_type == LoRATrainable.LORA,
         lora_attn_modules=lora_attn_modules,
         apply_lora_to_mlp=apply_lora_to_mlp,
+        apply_lora_to_output=apply_lora_to_output,
         patch_size=14,
         num_heads=16,
         clip_embed_dim=1280,
@@ -329,6 +330,7 @@ def lora_llama3_2_vision_90b(
         fusion_lora=fusion_type == LoRATrainable.LORA,
         lora_attn_modules=lora_attn_modules,
         apply_lora_to_mlp=apply_lora_to_mlp,
+        apply_lora_to_output=apply_lora_to_output,
        patch_size=14,
         num_heads=16,
         clip_embed_dim=1280,
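
Usage sketch (not part of the patches above): with the full series applied, ``apply_lora_to_output`` is accepted by ``lora_llama3_2_vision_encoder`` and forwarded from the model builders down to the fusion projection head, so its ``output_proj`` becomes a ``LoRALinear`` (or ``DoRALinear`` when ``use_dora=True``) instead of a plain ``nn.Linear`` whenever the fusion module is trained with LoRA. The call below assumes the builder's public signature matches the keyword arguments shown in the diffs; the ``lora_rank``/``lora_alpha`` values are illustrative only.

    # Hypothetical example exercising the restored apply_lora_to_output plumbing.
    # Keyword arguments are assumed from the signatures visible in the diffs above.
    from torchtune.models.llama3_2_vision import lora_llama3_2_vision_11b

    model = lora_llama3_2_vision_11b(
        lora_attn_modules=["q_proj", "v_proj", "output_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=True,  # now reaches the projection head's output_proj as well
        lora_rank=8,   # illustrative values, not taken from the patch
        lora_alpha=16,
    )
    # When the fusion module is LoRA-trainable, the projection head's output
    # projection is built as an adapter layer rather than a frozen nn.Linear.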