Conversation

@Electronic-Waste (Member) commented Sep 17, 2025

What this PR does / why we need it:

This PR adds support for LoRA/QLoRA/DoRA in LLM Trainer V2.

I tested it with the following script:

from kubeflow.trainer import *

client = TrainerClient()

# QLoRA
client.train(
    runtime=client.get_runtime(name="torchtune-llama3.2-1b"),
    initializer=Initializer(
        dataset=HuggingFaceDatasetInitializer(
            storage_uri="hf://tatsu-lab/alpaca/data"
        ),
        model=HuggingFaceModelInitializer(
            storage_uri="hf://meta-llama/Llama-3.2-1B-Instruct",
            access_token="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",
        )
    ),
    trainer=BuiltinTrainer(
        config=TorchTuneConfig(
            dataset_preprocess_config=TorchTuneInstructDataset(
                source=DataFormat.PARQUET,
            ),
            peft_config=LoraConfig(
                apply_lora_to_mlp=True,
                lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
                quantize_base=True,
            ),
            resources_per_node={
                "gpu": 1,
            }
        )
    )
)
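
The call above exercises the QLoRA path (quantize_base=True); the results below are from that run. A plain LoRA or DoRA run should differ only in the peft_config passed to TorchTuneConfig. A minimal sketch, assuming LoraConfig also exposes lora_rank, lora_alpha, and a use_dora flag mirroring the corresponding torchtune options (field names not confirmed here, check the SDK types):

# LoRA: keep the base model in bf16 and train only the low-rank adapters.
lora_peft = LoraConfig(
    apply_lora_to_mlp=True,
    lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
    lora_rank=64,        # assumed field, mirrors torchtune's lora_rank
    lora_alpha=128,      # assumed field, mirrors torchtune's lora_alpha
    quantize_base=False,
)

# DoRA: weight-decomposed LoRA on the same target modules.
dora_peft = LoraConfig(
    apply_lora_to_mlp=True,
    lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
    use_dora=True,       # assumed field, mirrors torchtune's use_dora
)

# Pass either object as peft_config in the TorchTuneConfig above.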

Results:

Running LoRAFinetuneRecipeSingleDevice with resolved config:

batch_size: 4
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /workspace/model
  checkpoint_files:
  - model.safetensors
  model_type: LLAMA3_2
  output_dir: /workspace/output
  recipe_checkpoint: null
clip_grad_norm: null
compile: false
dataset:
  _component_: torchtune.datasets.instruct_dataset
  data_dir: /workspace/dataset/data
  packed: false
  source: parquet
device: cuda
dtype: bf16
enable_activation_checkpointing: false
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 8
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /workspace/output/logs
model:
  _component_: torchtune.models.llama3_2.qlora_llama3_2_1b
  apply_lora_to_mlp: true
  lora_alpha: 128
  lora_attn_modules:
  - q_proj
  - k_proj
  - v_proj
  - output_proj
  lora_dropout: 0.0
  lora_rank: 64
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 0.0003
  weight_decay: 0.01
output_dir: /workspace/output
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /workspace/output/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
save_adapter_weights_only: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /workspace/model/original/tokenizer.model

Setting manual seed to local seed 2668943352. Local seed is seed + rank = 2668943352 + 0
Model is initialized with precision torch.bfloat16.
Memory stats after model init:
        GPU peak memory allocation: 1.21 GiB
        GPU peak memory reserved: 1.25 GiB
        GPU peak memory active: 1.21 GiB
Tokenizer is initialized from file.
Optimizer and loss are initialized.
Loss is initialized.
Writing logs to /workspace/output/logs/log_1758120939.txt
Generating train split: 52002 examples [00:00, 275879.07 examples/s]
Learning rate scheduler is initialized.
Profiling disabled.
Profiler config after instantiation: {'enabled': False}
1|1625|Loss: 1.6104315519332886: 100%|██████████| 1625/1625 [32:52<00:00,  1.26s/it]Starting checkpoint save...
Model checkpoint of size 2.30 GiB saved to /workspace/output/epoch_0/model-00001-of-00001.safetensors
Adapter checkpoint of size 0.08 GiB saved to /workspace/output/epoch_0/adapter_model.pt
Adapter checkpoint of size 0.08 GiB saved to /workspace/output/epoch_0/adapter_model.safetensors
Adapter checkpoint of size 0.00 GiB saved to /workspace/output/epoch_0/adapter_config.json
Saving final epoch checkpoint.
The full model checkpoint, including all weights and configurations, has been saved successfully. You can now use this checkpoint for further training or inference.
Checkpoint saved in 5.90 seconds.
1|1625|Loss: 1.6104315519332886: 100%|██████████| 1625/1625 [32:57<00:00,  1.22s/it]
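
As the log notes, the saved checkpoint can be used for further training or inference. A minimal sketch of loading the trained adapter with Hugging Face PEFT, assuming the adapter_config.json / adapter_model.safetensors pair written by torchtune's HF checkpointer is PEFT-compatible and has been copied from /workspace/output/epoch_0 on the training volume to a local ./epoch_0 directory:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "meta-llama/Llama-3.2-1B-Instruct"
base = AutoModelForCausalLM.from_pretrained(base_id)
tokenizer = AutoTokenizer.from_pretrained(base_id)

# Attach the fine-tuned LoRA adapter to the base model.
model = PeftModel.from_pretrained(base, "./epoch_0")

inputs = tokenizer("Give three tips for staying healthy.", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0]))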

/cc @kubeflow/kubeflow-trainer-team @kubeflow/wg-training-leads @kramaranya @szaher @deepanker13 @franciscojavierarceo @varodrig @rudeigerc

Which issue(s) this PR fixes (optional, in Fixes #<issue number>, #<issue number>, ... format, will close the issue(s) when PR gets merged):
Fixes #2505

Checklist:

  • Docs included if any changes are user facing

@google-oss-prow commented

@Electronic-Waste: GitHub didn't allow me to request PR reviews from the following users: kubeflow/kubeflow-trainer-team.

Note that only kubeflow members and repo collaborators can review this PR, and authors cannot review their own PRs.

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

@coveralls commented Sep 17, 2025

Pull Request Test Coverage Report for Build 18493960817

Details

  • 107 of 152 (70.39%) changed or added relevant lines in 2 files are covered.
  • 2 unchanged lines in 1 file lost coverage.
  • Overall coverage decreased (-0.07%) to 54.419%

Changes Missing Coverage                            Covered Lines   Changed/Added Lines   %
pkg/runtime/framework/plugins/torch/torch.go        32              52                    61.54%
pkg/runtime/framework/plugins/torch/torchtune.go    75              100                   75.0%

Files with Coverage Reduction                       New Missed Lines   %
pkg/runtime/framework/plugins/torch/torch.go        2                  84.18%

Totals
Change from base Build 18479440514: -0.07%
Covered Lines: 1250
Relevant Lines: 2297

💛 - Coveralls

@andreyvelich (Member) commented

/milestone v2.1

google-oss-prow bot added this to the v2.1 milestone on Sep 24, 2025.
@andreyvelich (Member) left a comment

Thank you @Electronic-Waste!
Please rebase your PR.

@Electronic-Waste (Member, Author) commented

@andreyvelich Thanks for your detailed review. I've addressed all of your comments. If you have any other questions, please don't hesitate to let me know. :)

@andreyvelich (Member) left a comment

@Electronic-Waste (Member, Author) left a comment

@andreyvelich Thanks for your detailed review. I've refactored getNumProcPerNode as you requested. :)

@Electronic-Waste (Member, Author) commented

@andreyvelich I've refactored the code; it looks tidier now. PTAL when you have time.

@Electronic-Waste (Member, Author) commented

/retest

1 similar comment:

@Electronic-Waste (Member, Author) commented

/retest

@andreyvelich (Member) left a comment

Thank you @Electronic-Waste!
/lgtm
/approve
/hold in case @astefanutti or @tenzen-y want to give additional comments.

@google-oss-prow commented

[APPROVALNOTIFIER] This PR is APPROVED

This pull-request has been approved by: andreyvelich

The full list of commands accepted by this bot can be found here.

The pull request process is described here.

Needs approval from an approver in each of these files:

  • Approvers can indicate their approval by writing /approve in a comment
  • Approvers can cancel approval by writing /approve cancel in a comment

@Electronic-Waste (Member, Author) commented

/retest

@astefanutti (Contributor) commented

/lgtm

Thanks @Electronic-Waste!

@andreyvelich (Member) commented

/hold cancel

google-oss-prow bot merged commit 484ca7e into kubeflow:master on Oct 15, 2025.
58 of 68 checks passed
Electronic-Waste deleted the feat/lora-support branch on October 15, 2025 at 14:26.

Development

Successfully merging this pull request may close these issues:

KEP-2401: Support LoRA/QLoRA/DoRA fine-tuning in LLM Trainer V2