
Commit 2764a77

[float8] Fix module filter function (#1391)
In a prior PR we added `_init_filter_fn()` to configure a module filter function at Float8 component init time, but didn't actually use it. This went unnoticed because the existing module filter (`partial(module_filter_fn, filter_fqns=self.filter_fqns)`) behaves the same way except when the user uses `auto_filter_small_kn`. In this PR we fix that by using `self.filter_fn` (see the sketch after the test plan for how these two filter styles differ).

## Test plan

- Test `auto_filter_small_kn` and verify wk/wv are filtered for Llama3 8B: `NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --training.compile --model.converters="float8" --float8.recipe_name="rowwise" --parallelism.tensor_parallel_degree=2 --float8.filter_fqns="auto_filter_small_kn" --model.print-after-conversion`
- Test without `auto_filter_small_kn` and verify all linears are converted: `NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --training.compile --model.converters="float8" --float8.recipe_name="rowwise" --parallelism.tensor_parallel_degree=2 --float8.filter_fqns="auto_filter" --model.print-after-conversion`
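For context, a module filter function decides per-module whether a linear layer gets converted to float8. The sketch below is illustrative only: it assumes a `(module, fqn) -> bool` callback shape and uses a made-up dimension threshold; the actual filters live in torchtitan/torchao and may differ.

```python
import torch.nn as nn

# Illustrative only: a dimension-aware filter in the spirit of
# "auto_filter_small_kn". The threshold is an example value, not the real one.
def auto_filter_small_kn(mod: nn.Module, fqn: str, min_dim: int = 4096) -> bool:
    """Return True to convert this module to float8, False to skip it."""
    if not isinstance(mod, nn.Linear):
        return False
    # Skip linears whose K (in_features) or N (out_features) is too small for
    # float8 GEMMs to be profitable, e.g. wk/wv after tensor-parallel sharding.
    return mod.in_features >= min_dim and mod.out_features >= min_dim

# By contrast, a plain FQN-based filter only checks the module name, which is
# roughly what partial(module_filter_fn, filter_fqns=...) provided before.
def fqn_filter(mod: nn.Module, fqn: str, filter_fqns: tuple[str, ...] = ()) -> bool:
    return isinstance(mod, nn.Linear) and not any(f in fqn for f in filter_fqns)
```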
1 parent 05e47c3 commit 2764a77

1 file changed: +1 -1 lines changed


torchtitan/components/quantization/float8.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -165,7 +165,7 @@ def convert(self, model: nn.Module):
         convert_to_float8_training(
             model,
             config=self.config,
-            module_filter_fn=partial(module_filter_fn, filter_fqns=self.filter_fqns),
+            module_filter_fn=self.filter_fn,
         )
         logger.info(
             "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
```
