Skip to content

Commit 756d074

Browse files
authored
Fix torchao generate script for cpu device (#2267)
* up * up
1 parent 0aa8dbd commit 756d074

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

torchao/_models/llama/generate.py

Lines changed: 6 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -180,6 +180,7 @@ def generate(
180180
max_seq_length = (
181181
min(T + max_new_tokens, model.config.block_size) if not interactive else 350
182182
)
183+
print(f"max_seq_length={max_seq_length}, prompt_length={T}")
183184
new_tokens = max_seq_length - T
184185

185186
# format model input
@@ -242,11 +243,13 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
242243

243244

244245
def _load_model(checkpoint_path, device, precision):
245-
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
246+
checkpoint = torch.load(
247+
str(checkpoint_path), mmap=True, weights_only=True, map_location=device
248+
)
246249
if "model" in checkpoint and "stories" in str(checkpoint_path):
247250
checkpoint = checkpoint["model"]
248251
with torch.device("meta"):
249-
model = Transformer.from_name(checkpoint_path.parent.name)
252+
model = Transformer.from_name(checkpoint_path)
250253
model.load_state_dict(checkpoint, assign=True)
251254
model = model.to(device=device, dtype=precision)
252255

@@ -585,7 +588,7 @@ def ffn_or_attn_only(mod, fqn):
585588
weight_dtype = getattr(torch, f"int{_quant_args[1]}")
586589
group_size = int(_quant_args[2])
587590
granularity = PerGroup(group_size) if group_size > 0 else PerAxis(0)
588-
is_asymmetric = bool(_quant_args[3])
591+
is_asymmetric = bool(_quant_args[3].lower() == "true")
589592
quantize_(
590593
model,
591594
Int8DynamicActivationIntxWeightConfig(

0 commit comments

Comments (0)