
Commit adaa6bb

doc: specify device_map argument in the examples (#1621)
SUMMARY:
This PR implements the change proposed in #1620. Specifying the `device_map` argument in the `from_pretrained` method is more coherent with the description published in the example:
```
The model is first loaded onto the cpu, as indicated through the use of None for the device_map argument in the from_pretrained method when loading the model.
```
TEST PLAN:
Documentation change only.

---------

Signed-off-by: Soren Dreano <soren@numind.ai>
Co-authored-by: Soren Dreano <soren@numind.ai>
1 parent 2c41df7 commit adaa6bb
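The semantics the commit message relies on can be sketched as follows. `resolve_device_map` is a hypothetical helper written only to illustrate the behavior; it is not the actual `transformers` implementation of `from_pretrained`.

```python
# Illustrative sketch only: a hypothetical helper showing the placement
# semantics that `device_map` controls. NOT the transformers implementation.
def resolve_device_map(device_map):
    """Mimic how `device_map` steers initial placement in `from_pretrained`.

    * None   -> the whole model is materialized on the CPU; no dispatch.
    * "auto" -> placement is delegated to an accelerate-style planner.
    * dict   -> an explicit module-to-device mapping is used as given.
    """
    if device_map is None:
        # Everything on CPU, matching the README sentence quoted above.
        return {"": "cpu"}
    if device_map == "auto":
        # Placeholder standing in for automatic multi-device placement.
        return "delegate-to-planner"
    return device_map


# Passing None is therefore an explicit "load on CPU first" request,
# which the updated examples now spell out instead of leaving implicit.
print(resolve_device_map(None))
```

In other words, the doc change makes the code match the prose: the example text already claimed the model starts on the CPU because of `device_map=None`, so the call site now passes that argument explicitly.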

File tree

2 files changed: +7 -3 lines changed


examples/big_models_with_sequential_onloading/README.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -18,7 +18,7 @@ The Llama 3.3 70b is larger than 80 GB, surpassing the size of 1 A100. However,
 
 ```python
 model_id = "meta-llama/Llama-3.3-70B-Instruct"
-model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map=None)
 ```
 
 The model is first loaded onto the `cpu`, as indicated through the use of `None` for the `device_map` argument in the `from_pretrained` method when loading the model.
@@ -42,4 +42,4 @@ output = model.generate(**sample, max_new_tokens=100)
 print(tokenizer.decode(output[0]))
 ```
 
-Finally, we call `dispatch_for_generation` to evenly load the model across available devices (potentially offloading the model if required) and run sample generations on the newly quantized model.
+Finally, we call `dispatch_for_generation` to evenly load the model across available devices (potentially offloading the model if required) and run sample generations on the newly quantized model.
````

examples/big_models_with_sequential_onloading/llama3.3_70b.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -8,7 +8,11 @@
 
 # Select model and load it.
 model_id = "meta-llama/Llama-3.3-70B-Instruct"
-model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype="auto",
+    device_map=None,
+)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 # Select calibration dataset.
```

0 commit comments
