
Commit c4b6365

kylesayrs and dsikka authored
[Tests] Start oneshot tests on CPU (#1555)
## Purpose ##

* Speed up tests by reducing device movement

## Background ##

As of #1263, the model is dispatched to different device maps depending on which pipelines are used. If the model starts on anything but the CPU, these dispatches and undispatches create device movement. Starting on the CPU ensures that no device movement occurs when offloaded dispatches happen.

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
1 parent 3262d85 commit c4b6365
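The change is mechanical: each touched test drops `device_map="auto"` from its `from_pretrained` call so the model loads on the CPU. A minimal before/after sketch of the pattern, using one of the checkpoints from the diffs below:

```python
from transformers import AutoModelForCausalLM

# Before: device_map="auto" eagerly places weights on any available
# accelerator, which the calibration pipeline may then immediately
# redispatch, causing avoidable device movement.
# model = AutoModelForCausalLM.from_pretrained(
#     "nm-testing/llama2.c-stories15M", device_map="auto", torch_dtype="auto"
# )

# After: with device_map omitted, the model starts on the CPU, so offloaded
# dispatches performed later by the pipelines move no weights up front.
model = AutoModelForCausalLM.from_pretrained(
    "nm-testing/llama2.c-stories15M", torch_dtype="auto"
)
```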

File tree

3 files changed: +3 -10 lines changed


tests/llmcompressor/recipe/test_recipe_parsing.py

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@ def setup_model_and_config(tmp_path):
     """
     model = AutoModelForCausalLM.from_pretrained(
         "Xenova/llama2.c-stories110M",
-        device_map="auto",
         torch_dtype="auto",
     )

tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py

Lines changed: 3 additions & 7 deletions
@@ -21,7 +21,7 @@ def setUp(self):
     def test_oneshot_sparsification_then_finetune(self):
         recipe_str = "tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml"
         model = AutoModelForCausalLM.from_pretrained(
-            "nm-testing/llama2.c-stories15M", device_map="auto", torch_dtype="auto"
+            "nm-testing/llama2.c-stories15M", torch_dtype="auto"
         )
         dataset = "open_platypus"
         concatenate_data = False
@@ -47,12 +47,11 @@ def test_oneshot_sparsification_then_finetune(self):
         # Explictly decompress the model for training using quantization_config
         model = AutoModelForCausalLM.from_pretrained(
             self.output / "oneshot_out",
-            device_map="auto",
             torch_dtype="auto",
             quantization_config=self.quantization_config,
         )
         distill_teacher = AutoModelForCausalLM.from_pretrained(
-            "nm-testing/llama2.c-stories15M", device_map="auto", torch_dtype="auto"
+            "nm-testing/llama2.c-stories15M", torch_dtype="auto"
         )
         dataset = "open_platypus"
         concatenate_data = False
@@ -88,7 +87,6 @@ def test_oneshot_sparsification_then_finetune(self):
         # Explictly decompress the model for training using quantization_config
         model = AutoModelForCausalLM.from_pretrained(
             output_dir,
-            device_map="auto",
             torch_dtype="auto",
             quantization_config=self.quantization_config,
         )
@@ -112,7 +110,7 @@ def test_oneshot_quantization_then_finetune(self):
         )

         model = AutoModelForCausalLM.from_pretrained(
-            "TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto", torch_dtype="auto"
+            "TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype="auto"
         )
         dataset = "open_platypus"
         concatenate_data = False
@@ -136,7 +134,6 @@ def test_oneshot_quantization_then_finetune(self):
         quantization_config = CompressedTensorsConfig(run_compressed=False)
         model = AutoModelForCausalLM.from_pretrained(
             output_dir,
-            device_map="auto",
             torch_dtype="auto",
             quantization_config=quantization_config,
         )
@@ -159,7 +156,6 @@ def test_oneshot_quantization_then_finetune(self):
         # test reloading checkpoint and final model
         model = AutoModelForCausalLM.from_pretrained(
             output_dir,
-            device_map="auto",
             torch_dtype="auto",
             quantization_config=quantization_config,
         )

tests/llmcompressor/transformers/obcq/test_consecutive_runs.py

Lines changed: 0 additions & 2 deletions
@@ -44,7 +44,6 @@ def _test_consecutive_runs(

         first_model = AutoModelForCausalLM.from_pretrained(
             self.output_first,
-            device_map="auto",
             torch_dtype="auto",
             quantization_config=self.quantization_config,
         )
@@ -74,7 +73,6 @@ def _test_consecutive_runs(
         second_model = AutoModelForCausalLM.from_pretrained(
             self.output_second,
             quantization_config=self.quantization_config,
-            device_map="auto",
             torch_dtype="auto",
         )

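The pattern is identical across all three files: every `from_pretrained` call in the touched tests now omits `device_map`. A quick check, hypothetical and not part of this commit, that loading without `device_map` leaves every parameter on the CPU:

```python
import torch
from transformers import AutoModelForCausalLM

# transformers places weights on the CPU when no device_map is given, so the
# oneshot tests now begin with an entirely CPU-resident model.
model = AutoModelForCausalLM.from_pretrained(
    "nm-testing/llama2.c-stories15M", torch_dtype="auto"
)
assert all(p.device.type == "cpu" for p in model.parameters())
```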