@@ -180,6 +180,7 @@ def generate(
     max_seq_length = (
         min(T + max_new_tokens, model.config.block_size) if not interactive else 350
     )
+    print(f"max_seq_length={max_seq_length}, prompt_length={T}")
     new_tokens = max_seq_length - T
 
     # format model input
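
For reference, the values reported by the newly added print follow directly from the lines above it; a worked example with hypothetical numbers (not taken from the repository):

    T = 32                  # prompt length in tokens
    max_new_tokens = 200
    block_size = 2048       # stands in for model.config.block_size
    max_seq_length = min(T + max_new_tokens, block_size)  # 232 in non-interactive mode
    new_tokens = max_seq_length - T                        # 200
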
@@ -242,11 +243,13 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
 
 
 def _load_model(checkpoint_path, device, precision):
-    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    checkpoint = torch.load(
+        str(checkpoint_path), mmap=True, weights_only=True, map_location=device
+    )
     if "model" in checkpoint and "stories" in str(checkpoint_path):
         checkpoint = checkpoint["model"]
     with torch.device("meta"):
-        model = Transformer.from_name(checkpoint_path.parent.name)
+        model = Transformer.from_name(checkpoint_path)
     model.load_state_dict(checkpoint, assign=True)
     model = model.to(device=device, dtype=precision)
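
The torch.load change above adds map_location so the memory-mapped checkpoint tensors are materialized on the target device rather than defaulting to CPU. A minimal standalone sketch of the same call pattern, assuming a hypothetical local checkpoint file "model.pth" rather than the repository's checkpoint layout:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # mmap=True maps the file lazily; map_location places the loaded tensors on `device`
    # during loading, and weights_only=True restricts unpickling to tensor/primitive data.
    checkpoint = torch.load("model.pth", mmap=True, weights_only=True, map_location=device)
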
@@ -585,7 +588,7 @@ def ffn_or_attn_only(mod, fqn):
             weight_dtype = getattr(torch, f"int{_quant_args[1]}")
             group_size = int(_quant_args[2])
             granularity = PerGroup(group_size) if group_size > 0 else PerAxis(0)
-            is_asymmetric = bool(_quant_args[3])
+            is_asymmetric = bool(_quant_args[3].lower() == "true")
             quantize_(
                 model,
                 Int8DynamicActivationIntxWeightConfig(
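
The is_asymmetric change addresses a common string-to-bool pitfall: bool() on any non-empty string is True, so a command-line value of "False" was previously parsed as asymmetric. A minimal illustration with a hypothetical _quant_args value (not taken from the repository):

    _quant_args = ["int8_dynamic_activation_intx_weight", "4", "256", "False"]
    print(bool(_quant_args[3]))              # True  -- any non-empty string is truthy
    print(_quant_args[3].lower() == "true")  # False -- explicit comparison parses the flag correctly
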