
Commit acd5ba8

basic validator implementation (#1362)
PR Summary: Implements a validator that can be easily plugged into the training loop and configured from the job-specific config file.

Changes:
- Created a validation section in job_config with enabled, dataset, freq, and steps fields
- Created a builder function for the validator in train_spec
- Created a separate builder function for the validation dataset in hf_datasets.py
- Created the Validator class
- The Validator class initializes the validation dataloader via build_hf_validation_dataloader but leaves this builder unexposed to the train_spec
- The Validator class supports DDP, FSDP, CP, and TP (but not PP yet)
- Integrated the validation call into the training loop
- Created an integration test to test the parallelization

Updated tests training the same base model weights from a seed checkpoint under four configurations: FSDP=2; FSDP=2, TP=4; FSDP=2, CP=4; and FSDP=2, TP=2, CP=2 (loss-curve screenshots omitted).
1 parent be15836 commit acd5ba8

File tree

8 files changed: +287 −3 lines changed


tests/integration_tests.py

Lines changed: 14 additions & 0 deletions
@@ -509,6 +509,20 @@ def build_test_list():
             "gradient_accumulation",
             ngpu=2,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--validation.enabled",
+                    "--validation.dataset c4_test",
+                    "--parallelism.data_parallel_replicate_degree=2",
+                    "--parallelism.tensor_parallel_degree=2",
+                    "--parallelism.context_parallel_degree=2",
+                ],
+            ],
+            "Validation test with fsdp, tp, cp",
+            "validation_fsdp_tp_cp",
+            ngpu=8,
+        ),
     ]
     return integration_tests_flavors

torchtitan/components/validate.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Generator

import torch
import torch.nn as nn
from torch.distributed.fsdp import FSDPModule
from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.loss import LossFunction
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.tools import utils
from torchtitan.tools.logging import logger


class BaseValidator:
    def __init__(self, job_config: JobConfig):
        self.job_config = job_config

    def validate(self, model_parts: list[nn.Module]) -> dict[str, float]:
        raise NotImplementedError("validate method not implemented")

    def should_validate(self, step: int) -> bool:
        return step % self.job_config.validation.freq == 0


class Validator(BaseValidator):
    """
    Simple validator focused on correctness and integration.

    Args:
        job_config: Job configuration
        validation_dataloader: The validation dataloader
        loss_fn: Loss function to use for validation
        model: The model to validate (single model, no parallelism)
    """

    validation_dataloader: BaseDataLoader

    def __init__(
        self,
        job_config: JobConfig,
        dp_world_size: int,
        dp_rank: int,
        tokenizer: Tokenizer,
        parallel_dims: ParallelDims,
        world_mesh: torch.distributed.DeviceMesh,
        loss_fn: LossFunction,
        validation_context: Generator[None, None, None],
        maybe_enable_amp: Generator[None, None, None],
    ):
        self.job_config = job_config
        self.parallel_dims = parallel_dims
        self.world_mesh = world_mesh
        self.loss_fn = loss_fn
        self.validation_dataloader = build_hf_validation_dataloader(
            job_config=job_config,
            dp_world_size=dp_world_size,
            dp_rank=dp_rank,
            tokenizer=tokenizer,
        )
        self.validation_context = validation_context
        self.maybe_enable_amp = maybe_enable_amp

    @torch.no_grad()
    def validate(
        self,
        model_parts: list[nn.Module],
    ) -> dict[str, float]:
        # Set model to eval mode
        # TODO: currently does not support pipeline parallelism
        model = model_parts[0]
        model.eval()

        accumulated_losses = []
        device_type = utils.device_type
        num_steps = 0

        for input_dict, labels in self.validation_dataloader:
            if (
                self.job_config.validation.steps != -1
                and num_steps >= self.job_config.validation.steps
            ):
                break

            for k, v in input_dict.items():
                input_dict[k] = v.to(device_type)
            inputs = input_dict["input"]
            labels = labels.to(device_type)

            optional_context_parallel_ctx = (
                dist_utils.create_context_parallel_ctx(
                    cp_mesh=self.world_mesh["cp"],
                    cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                    cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                    cp_no_restore_buffers={inputs, labels},
                    cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
                )
                if self.parallel_dims.cp_enabled
                else None
            )

            with self.validation_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                with self.maybe_enable_amp:
                    predictions = model(inputs)
                    loss = self.loss_fn(predictions, labels)

            accumulated_losses.append(loss.detach())

            num_steps += 1

        # Compute average loss
        loss = torch.sum(torch.stack(accumulated_losses))
        loss /= num_steps
        if self.parallel_dims.dp_cp_enabled:
            global_avg_loss = dist_utils.dist_mean(loss, self.world_mesh["dp_cp"])
        else:
            global_avg_loss = loss

        logger.info(
            f"Validation completed. Average loss: {global_avg_loss:.4f} over {num_steps} batches"
        )

        # Reshard after running the forward pass.
        # This is to ensure the model weights are sharded the same way for checkpoint saving.
        for module in model.modules():
            if isinstance(module, FSDPModule):
                module.reshard()

        # Set model back to train mode
        model.train()


def build_validator(
    job_config: JobConfig,
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    parallel_dims: ParallelDims,
    world_mesh: torch.distributed.DeviceMesh,
    loss_fn: LossFunction,
    validation_context: Generator[None, None, None],
    maybe_enable_amp: Generator[None, None, None],
) -> BaseValidator:
    """Build a simple validator focused on correctness."""
    return Validator(
        job_config=job_config,
        dp_world_size=dp_world_size,
        dp_rank=dp_rank,
        tokenizer=tokenizer,
        parallel_dims=parallel_dims,
        world_mesh=world_mesh,
        loss_fn=loss_fn,
        validation_context=validation_context,
        maybe_enable_amp=maybe_enable_amp,
    )
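
The hookup into the training loop itself is not shown in this diff. A minimal sketch of the call site, assuming validator was produced by the train spec's builder and step is the current training step:

# Hypothetical training-loop excerpt; names mirror this PR's API,
# but the exact call site is an assumption.
if job_config.validation.enabled and validator.should_validate(step):
    validator.validate(model_parts)  # eval-mode forward passes, logs the global average loss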

torchtitan/config_manager.py

Lines changed: 30 additions & 0 deletions
@@ -665,6 +665,35 @@ class Experimental:
     """


+@dataclass
+class Validation:
+    enabled: bool = False
+    """Enable validation to default run validation after each training loop"""
+
+    dataset: str = "c4_validation"
+    """Dataset to use for validation"""
+
+    dataset_path: str | None = None
+    """Path to dataset to use for validation"""
+
+    local_batch_size: int = 8
+    """Batch size for validation"""
+
+    seq_len: int = 2048
+    """Sequence length for validation"""
+
+    freq: int = 10
+    """Frequency of validation"""
+
+    steps: int = -1
+    """Number of steps to take in the validation set, -1 means consuming all the data in the validation dataset"""
+
+    def __post_init__(self):
+        assert (
+            self.steps > 0 or self.steps == -1
+        ), "validation steps must be positive or -1"
+
+
 @dataclass
 class JobConfig:
     """
@@ -689,6 +718,7 @@ class JobConfig:
     memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)
     fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
     experimental: Experimental = field(default_factory=Experimental)
+    validation: Validation = field(default_factory=Validation)

     def to_dict(self) -> dict[str, Any]:
         return asdict(self)
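
As a quick illustration of the new __post_init__ guard, a minimal standalone sketch (run outside any training job):

from torchtitan.config_manager import Validation

cfg = Validation(enabled=True, freq=5, steps=10)  # valid: validate every 5 training steps, 10 batches per run
full = Validation(steps=-1)                       # valid: -1 consumes the entire validation dataset

try:
    Validation(steps=0)  # neither positive nor -1
except AssertionError as err:
    print(err)  # validation steps must be positive or -1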

torchtitan/datasets/hf_datasets.py

Lines changed: 40 additions & 3 deletions
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
+
+from functools import partial
 from typing import Any, Callable

 import torch
@@ -20,9 +22,9 @@
 from torchtitan.tools.logging import logger


-def _load_c4_dataset(dataset_path: str):
+def _load_c4_dataset(dataset_path: str, split: str):
     """Load C4 dataset with default configuration."""
-    return load_dataset(dataset_path, name="en", split="train", streaming=True)
+    return load_dataset(dataset_path, name="en", split=split, streaming=True)


 def _process_c4_text(sample: dict[str, Any]) -> str:
@@ -41,14 +43,19 @@ class DatasetConfig:
 DATASETS = {
     "c4": DatasetConfig(
         path="allenai/c4",
-        loader=_load_c4_dataset,
+        loader=partial(_load_c4_dataset, split="train"),
         text_processor=_process_c4_text,
     ),
     "c4_test": DatasetConfig(
         path="tests/assets/c4_test",
         loader=lambda path: load_dataset(path, split="train"),
         text_processor=_process_c4_text,
     ),
+    "c4_validation": DatasetConfig(
+        path="allenai/c4",
+        loader=partial(_load_c4_dataset, split="validation"),
+        text_processor=_process_c4_text,
+    ),
 }


@@ -193,3 +200,33 @@ def build_hf_dataloader(
         dp_world_size=dp_world_size,
         batch_size=batch_size,
     )
+
+
+def build_hf_validation_dataloader(
+    dp_world_size: int,
+    dp_rank: int,
+    tokenizer: Tokenizer,
+    job_config: JobConfig,
+) -> ParallelAwareDataloader:
+    """Build a validation data loader for HuggingFace datasets."""
+    dataset_name = job_config.validation.dataset
+    dataset_path = job_config.validation.dataset_path
+    batch_size = job_config.validation.local_batch_size
+    seq_len = job_config.validation.seq_len
+
+    hf_ds = HuggingFaceDataset(
+        dataset_name=dataset_name,
+        dataset_path=dataset_path,
+        tokenizer=tokenizer,
+        seq_len=seq_len,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        infinite=False,
+    )
+
+    return ParallelAwareDataloader(
+        dataset=hf_ds,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        batch_size=batch_size,
+    )
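
The DATASETS registry makes it easy to point validation at a different corpus. A minimal sketch of a hypothetical entry mirroring "c4_validation" (the wikitext dataset path, config name, and text field are illustrative assumptions, not part of this commit):

from functools import partial

from datasets import load_dataset

from torchtitan.datasets.hf_datasets import DATASETS, DatasetConfig


def _load_wikitext(dataset_path: str, split: str):
    # Assumed HF dataset path/config, for illustration only.
    return load_dataset(dataset_path, name="wikitext-103-raw-v1", split=split, streaming=True)


def _process_wikitext_text(sample: dict) -> str:
    return sample["text"]


DATASETS["wikitext_validation"] = DatasetConfig(
    path="Salesforce/wikitext",
    loader=partial(_load_wikitext, split="validation"),
    text_processor=_process_wikitext_text,
)

Such an entry would then be selected with --validation.dataset wikitext_validation.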

torchtitan/models/llama3/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
+from torchtitan.components.validate import build_validator
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
@@ -81,5 +82,6 @@
         build_dataloader_fn=build_hf_dataloader,
         build_tokenizer_fn=build_tiktoken_tokenizer,
         build_loss_fn=build_cross_entropy_loss,
+        build_validator_fn=build_validator,
     )
 )
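
With the spec registered, a trainer can build the validator generically through the optional hook. A minimal sketch, assuming torchtitan's get_train_spec lookup and that the caller already holds the listed arguments:

from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("llama3")
if job_config.validation.enabled and spec.build_validator_fn is not None:
    validator = spec.build_validator_fn(
        job_config=job_config,
        dp_world_size=dp_world_size,
        dp_rank=dp_rank,
        tokenizer=tokenizer,
        parallel_dims=parallel_dims,
        world_mesh=world_mesh,
        loss_fn=loss_fn,
        validation_context=validation_context,
        maybe_enable_amp=maybe_enable_amp,
    )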

torchtitan/models/llama3/train_configs/debug_model.toml

Lines changed: 6 additions & 0 deletions
@@ -71,3 +71,9 @@ selective_ac_option = '2'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
 filter_fqns = ["output"]
+
+[validation]
+enabled = false
+dataset = "c4_validation"
+freq = 5
+steps = 10

torchtitan/protocols/train_spec.py

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,7 @@
 from torchtitan.components.metrics import MetricsProcessor
 from torchtitan.components.optimizer import OptimizersContainer
 from torchtitan.components.tokenizer import Tokenizer
+from torchtitan.components.validate import BaseValidator
 from torchtitan.config_manager import JobConfig
 from torchtitan.distributed import ParallelDims

@@ -80,6 +81,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
     [OptimizersContainer, JobConfig], LRSchedulersContainer
 ]
 LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]
+ValidatorBuilder: TypeAlias = Callable[..., BaseValidator]


 @dataclass
@@ -94,6 +96,7 @@ class TrainSpec:
     build_dataloader_fn: DataLoaderBuilder
     build_tokenizer_fn: TokenizerBuilder | None
     build_loss_fn: LossFunctionBuilder
+    build_validator_fn: ValidatorBuilder | None = None
     build_metrics_processor_fn: MetricsProcessorBuilder | None = None
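Because ValidatorBuilder is just a Callable returning a BaseValidator, a model can supply its own validation strategy through the new optional field. A minimal sketch of a hypothetical no-op validator (NoOpValidator and build_noop_validator are illustrative names, not part of this commit):

from torchtitan.components.validate import BaseValidator
from torchtitan.config_manager import JobConfig


class NoOpValidator(BaseValidator):
    """Hypothetical validator that skips evaluation entirely."""

    def validate(self, model_parts):
        return {}


def build_noop_validator(job_config: JobConfig, **kwargs) -> BaseValidator:
    return NoOpValidator(job_config)


# Wired in when registering a spec, e.g.:
#   register_train_spec(TrainSpec(..., build_validator_fn=build_noop_validator))

Keeping build_validator_fn optional (defaulting to None) means existing TrainSpecs continue to work unchanged.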