import argparse
import multiprocessing as mp
from time import perf_counter
+import logging

import numpy as np
import torch

logger = get_logger()

+# Configure logging at INFO level so the [DEBUG]-tagged messages below are visible
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+

def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
-    parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
+    parser.add_argument("--seq_len", type=int, default=2048, help="Number of tokens to generate (generation length)")
+    parser.add_argument("--prompt_len", type=int, default=None, help="Desired prompt/prefill length in tokens (optional)")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
+    parser.add_argument("--batch_size", type=int, default=1, help="Client batch size (number of sequences to generate in parallel)")
    args = parser.parse_args()

    if args.n_processes == "n_gpus":
@@ -45,34 +55,135 @@ def main():
def benchmark_inference(process_idx, args, result_pipe):
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway
+
+    # Set pad_token for LLaMA tokenizer (required for batch padding)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+        logger.info(f"[DEBUG] Set pad_token to eos_token: {tokenizer.pad_token}")

    model = AutoDistributedModelForCausalLM.from_pretrained(
-        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
+        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype],
+        use_server_to_server=True  # Explicitly enable server-to-server communication
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

-    test_prompt = ""
-    input_ids = tokenizer.encode(test_prompt, return_tensors="pt", add_special_tokens=True)
+    # Prepare a batch of prompts for benchmarking
+    batch_size = getattr(args, 'batch_size', 1)
+
+    # Create a different prompt for each sequence in the batch to verify independent generation
+    if batch_size == 1:
+        prompts = [""]
+    elif batch_size == 2:
+        prompts = ["Once upon a time", "In a galaxy far away"]
+    elif batch_size == 3:
+        prompts = ["Once upon a time", "In a galaxy far away", "The quick brown fox"]
+    else:
+        base_prompt = (
+            "Quantum mechanics explains the behavior of particles at very small scales. "
+            "Neural networks learn patterns by adjusting weights through backpropagation. "
+            "Distributed systems require robust consensus mechanisms to maintain state. "
+            "Optimization algorithms like gradient descent are fundamental to machine learning. "
+            "Transformer architectures rely on attention mechanisms to capture dependencies. "
+            "Reinforcement learning optimizes actions by maximizing cumulative rewards. "
+            "Bayesian inference updates beliefs based on observed evidence and prior knowledge. "
+            "Convex optimization problems guarantee global minima under certain conditions. "
+            "Signal processing extracts meaningful information from noisy measurements. "
+        )
+        prompts = [
+            f"{base_prompt} Example {i + 1} discusses large-scale AI systems and scientific discovery."
+            for i in range(batch_size)
+        ]
+
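+    # Build input_ids: with --prompt_len, each prompt is truncated or padded with a filler sentence
+    # to exactly that many tokens; otherwise the batch is simply tokenized and padded to its longest prompt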
+    if args.prompt_len is None:
+        encodings = tokenizer(prompts, return_tensors="pt", padding=True, add_special_tokens=True)
+        input_ids = encodings["input_ids"]
+    else:
+        target_prompt_length = args.prompt_len
+        bos_token_id = tokenizer.bos_token_id
+        filler_sentence = (
+            " Advanced research explores interdisciplinary insights, collaborative innovation, "
+            "scientific computation, trustworthy deployment, and sustainable engineering practices."
+        )
+        filler_tokens = tokenizer(filler_sentence, add_special_tokens=False)["input_ids"]
+        if not filler_tokens:
+            filler_tokens = [tokenizer.eos_token_id or tokenizer.pad_token_id or 0]
+        processed = []
+        for prompt in prompts:
+            prompt_tokens = tokenizer(prompt, add_special_tokens=False)["input_ids"]
+            if bos_token_id is not None:
+                full_tokens = [bos_token_id] + prompt_tokens
+            else:
+                full_tokens = prompt_tokens[:]
+            if len(full_tokens) >= target_prompt_length:
+                full_tokens = full_tokens[:target_prompt_length]
+            else:
+                while len(full_tokens) < target_prompt_length:
+                    need = target_prompt_length - len(full_tokens)
+                    full_tokens.extend(filler_tokens[:need])
+            processed.append(full_tokens)
+        input_ids = torch.tensor(processed, dtype=torch.long)
+
+    logger.info(f"[DEBUG] {process_idx=} Client batch_size={batch_size}, input_ids.shape={input_ids.shape}")
+    for i, prompt in enumerate(prompts):
+        logger.info(f"[DEBUG] {process_idx=} batch[{i}] prompt: '{prompt}' (token_ids: {input_ids[i].tolist()})")
    temp_result_tokens = input_ids

+    # Calculate max_length: prompt_length + number of tokens to generate
+    prompt_length = input_ids.shape[1]
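+    # Safeguard: truncate or pad the tokenized prompts if they do not already match --prompt_len
+    # (the construction above already yields exactly prompt_len tokens, so this branch is normally a no-op)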
+    if args.prompt_len is not None:
+        target_prompt_length = args.prompt_len
+        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
+        if target_prompt_length < prompt_length:
+            input_ids = input_ids[:, :target_prompt_length]
+        elif target_prompt_length > prompt_length:
+            extra = target_prompt_length - prompt_length
+            pad_block = torch.full((batch_size, extra), pad_token_id, dtype=input_ids.dtype)
+            input_ids = torch.cat([input_ids, pad_block], dim=1)
+        prompt_length = target_prompt_length
+        temp_result_tokens = input_ids
+        logger.info(f"[DEBUG] {process_idx=} adjusted prompt_length to {prompt_length} tokens")
+
+    total_max_length = prompt_length + args.seq_len
+    logger.info(f"[DEBUG] {process_idx=} prompt_length={prompt_length}, generating {args.seq_len} tokens, total_max_length={total_max_length}")
+

    step_times = []

-    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
+    with model.transformer.h.inference_session(max_length=total_max_length) as sess:
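+        # The session keeps the servers' attention caches alive, so each one-token generate() call below
+        # continues the same sequences instead of re-processing the prompt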
+        logger.info(f"[DEBUG] {process_idx=} Created inference session with max_length={total_max_length}")
+        logger.info(f"[BENCHMARK_START] Process={process_idx} | BatchSize={batch_size} | SeqLen={args.seq_len}")
+
        for step in range(args.seq_len):
-            start_time = perf_counter()
+            step_start_time = perf_counter()
+
+            # For the first step, pass input_ids; for subsequent steps, generate() will use session state
+            if step == 0:
+                logger.info(f"[DEBUG] {process_idx=} {step=} First step, passing input_ids.shape={input_ids.shape}")
+                outputs = model.generate(input_ids, max_new_tokens=1, session=sess)
+            else:
+                outputs = model.generate(max_new_tokens=1, session=sess)
+
+            # Log generated tokens for all sequences in the batch
+            for batch_idx in range(outputs.shape[0]):
+                new_token_id = outputs[batch_idx][-1].item()
+                new_token_text = tokenizer.decode([new_token_id])
+                logger.info(f"[DEBUG] {process_idx=} {step=} batch[{batch_idx}] Generated token: '{new_token_text}' (id={new_token_id})")

-            outputs = model.generate(max_new_tokens=1, session=sess)
-            new_token_id = outputs[0][-1].item()
-            new_token_text = tokenizer.decode([new_token_id])
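+            # Append each sequence's newly generated token to the running transcript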
            temp_result_tokens = torch.cat([temp_result_tokens, outputs[:, -1:]], dim=1)

            if step >= args.warmup_steps:
-                step_times.append(perf_counter() - start_time)
+                step_times.append(perf_counter() - step_start_time)
                speed = 1 / np.mean(step_times)
-                logger.info(f"{process_idx=} {step=} {speed=:.2f}")
+                # speed is tokens/sec per sequence; effective throughput scales it by the client batch size
+                effective_speed = speed * batch_size
+                logger.info(f"{process_idx=} {step=} {speed=:.2f} tokens/sec/sequence, effective={effective_speed:.2f} tokens/sec")
+
+    # Show the final generated text for each sequence in the batch
+    for batch_idx in range(temp_result_tokens.shape[0]):
+        full_text = tokenizer.decode(temp_result_tokens[batch_idx], skip_special_tokens=True)
+        logger.info(f"\nbatch[{batch_idx}] Full generated text:\n{full_text}\n")

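+    # Report the final per-sequence speed back to the parent process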
    result_pipe.send(speed)


if __name__ == "__main__":
-    main()
+    main()