
Commit 862bd3b

Add multi-batch inference support, fix hivemind dependency, and improve installation process (#27)
* Add batch inference support and CPU compatibility
  - Add --batch_size CLI argument for parallel sequence processing
  - Add conditional CUDA stream creation for CPU-only mode
  - Add device-aware ExecutionEnv and Policy resource distribution
  - Fix MPS compatibility on macOS
* fix hardcode of model loading and support batch size
* Resolving dependency conflicts
* docs: refine README setup and usage sections for clarity and correctness
* Add batch size related updates
* delete debug output
* delete .id files
* fix max token size problem
* add prompt
* clear the debug print

---------

Co-authored-by: Danny Willow Liu <dannywillowliu@uchicago.edu>
1 parent c3f0e88 commit 862bd3b

25 files changed: +432 -132 lines

README.md

Lines changed: 65 additions & 53 deletions
@@ -25,78 +25,90 @@ pip install bloombee
 ```bash
 git clone https://github.com/ai-decentralized/BloomBee.git
 cd BloomBee
-python3 -m venv bloombee-venv
-source bloombee-venv/bin/activate
-pip install -e .
-
-pip install pynvml
-pip install attrs
 ```
-If you are using Hivemind (required for BloomBee setup), please install this as well:
+Create and activate an environment (either one):
+
+```bash
+# Using venv
+python3 -m venv bloombee-venv && source bloombee-venv/bin/activate
+
+# OR using conda (recommended)
+conda create -n bloombee python=3.10.16 && conda activate bloombee
 ```
-git clone https://github.com/learning-at-home/hivemind
-cd hivemind
-pip install -e .
 
+Then install:
+
+```bash
+pip install -e .
 ```
 ## How to use BloomBee(<a href="https://colab.research.google.com/drive/1pENMOEoEV01DqBImZzuX_4jTV3fNwNga#scrollTo=oyCFDemCZsRs">Try now in Colab</a>)
-#### 1. Start the main server
-```
-python -m bloombee.cli.run_dht --host_maddrs /ip4/0.0.0.0/tcp/31340 --identity_path bootstrapp1.id
+
+#### 1. Start the main server
+Start the DHT main node:
+```bash
+python -m bloombee.cli.run_dht --host_maddrs /ip4/0.0.0.0/tcp/31340 --identity_path bootstrapp1.id
+```
+
+After running, you will see output similar to:
 
 ```
-Now you will get the BloomBee's main server location:
+[INFO] Running a DHT instance. To connect other peers to this one, use:
+--initial_peers /ip4/10.0.4.215/tcp/31340/p2p/QmZtZJwF8G2qspQxEVxXfipV4fR7EgpfnkXdbbzaEooaVf
 ```
-Mon 00 01:23:45.678 [INFO] Running a DHT instance. To connect other peers to this one, use --initial_peers /ip4/YOUR_IP_ADDRESS/tcp/31340/p2p/QmefxzDL1DaJ7TcrZjLuz7Xs9sUVKpufyg7f5276ZHFjbQ
-```
-You can provide this address as --initial_peers to workers or other backbone servers.
 
-If you want your swarm to be accessible outside of your local network, ensure that you have a **public IP address** or set up **port forwarding** correctly, so that your peer is reachable from the outside.
+Copy **your own** full address (including the `/p2p/...` part).
+Each DHT node generates a unique Peer ID, so do **not** copy the example above.
 
-#### 2. Connect the workers to the main bloombee server
-Here is the BloomBee Server location:
-```
-export BBSERVER=/ip4/10.52.2.249/tcp/31340/p2p/QmefxzDL1DaJ7TcrZjLuz7Xs9sUVKpufyg7f5276ZHFjbQ
+You can provide this address as `--initial_peers` to connect workers or other backbone servers.
 
-```
-To setup the workers, connect to the GPUs being used (If using remote SSH to instance):
-```
-chmod 400 ~/.ssh/<YOURKEYPAIR>.pem
-ssh -i ~/.ssh/<YOURKEYPAIR.pem cc@<FLOATING IP>
-```
-Next, make sure that the workers are fully set up in the BloomBee environment.
-```
-git clone https://github.com/ai-decentralized/BloomBee.git
-cd BloomBee
-python3 -m venv bloombee-venv
-source bloombee-venv/bin/activate
-pip install -e .
+> 💡 **Tip:**
+> If you want your swarm to be accessible outside of your local network,
+> ensure you have a **public IP address** or set up **port forwarding** correctly.
 
-pip install pynvml
-pip install attrs
+---
 
-git clone https://github.com/learning-at-home/hivemind
-cd hivemind
-pip install -e .
-```
-Start one worker to hold 16 blocks (16 tranformer layers)
-```
-python -m bloombee.cli.run_server huggyllama/llama-7b --initial_peers $BBSERVER --num_blocks 16 --identity_path bootstrap_1.id
-```
-Start second worker to hold another 16 blocks (16 tranformer layers)
+#### 2. Connect the workers to the main BloomBee server
+
+Set your main server address (replace with your actual output from step 1):
+
+```bash
+export BBSERVER=/ip4/10.0.4.215/tcp/31340/p2p/QmZtZJwF8G2qspQxEVxXfipV4fR7EgpfnkXdbbzaEooaVf
 ```
-python -m bloombee.cli.run_server huggyllama/llama-7b --initial_peers $BBSERVER --num_blocks 16 --identity_path bootstrap_1.id
+
+Activate the BloomBee environment on each worker
+(you can reuse the environment created in **From Source**).
+
+Each worker should be started **in a separate terminal** (or on a separate node)
+after activating its environment.
+
+Start the first worker to hold 16 blocks (e.g., 16 transformer layers):
+
+```bash
+python -m bloombee.cli.run_server huggyllama/llama-7b \
+    --initial_peers $BBSERVER --num_blocks 16 --identity_path bootstrap_1.id
 ```
-In case your workers do not run do to IP connection resets, please configure the config files containing the workers' IPs.
 
-If a bitsandbytes error comes up, please use this fix:
+Start the second worker in another activated terminal:
+
+```bash
+python -m bloombee.cli.run_server huggyllama/llama-7b \
+    --initial_peers $BBSERVER --num_blocks 16 --identity_path bootstrap_2.id
 ```
-cd ~/BloomBee
-rm -rf bitsandbytes
+
+If you encounter network issues (e.g., connection resets),
+please verify your worker IP configurations in the relevant config files.
+
+**Optional:** If `bitsandbytes` causes a CUDA version error:
+
+```bash
+cd ~
 git clone https://github.com/TimDettmers/bitsandbytes.git
-cd bitsandbytes
+cd bitsandbytes && python setup.py install
 ```
-Make sure to set CUDA versions to the correct library paths if necessary.
+
+Ensure your CUDA library path matches your environment.
+
+
 
 #### 3. Run inference or finetune jobs

benchmarks/benchmark_forward.py

File mode changed from 100755 to 100644.

benchmarks/benchmark_inference.py

Lines changed: 123 additions & 12 deletions
@@ -3,6 +3,7 @@
 import argparse
 import multiprocessing as mp
 from time import perf_counter
+import logging
 
 import numpy as np
 import torch
@@ -14,15 +15,24 @@
 
 logger = get_logger()
 
+# Set logging level to INFO to see all debug messages
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+
 
 def main():
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("--model", type=str, required=True, help="Model")
     parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
     parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
     parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
-    parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
+    parser.add_argument("--seq_len", type=int, default=2048, help="Number of tokens to generate (generation length)")
+    parser.add_argument("--prompt_len", type=int, default=None, help="Desired prompt/prefill length in tokens (optional)")
     parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
+    parser.add_argument("--batch_size", type=int, default=1, help="Client batch size (number of sequences to generate in parallel)")
     args = parser.parse_args()
 
     if args.n_processes == "n_gpus":
@@ -45,34 +55,135 @@ def main():
 def benchmark_inference(process_idx, args, result_pipe):
     tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
     # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway
+
+    # Set pad_token for LLaMA tokenizer (required for batch padding)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+        logger.info(f"[DEBUG] Set pad_token to eos_token: {tokenizer.pad_token}")
 
     model = AutoDistributedModelForCausalLM.from_pretrained(
-        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
+        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype],
+        use_server_to_server=True  # Explicitly enable server-to-server communication
     )
     logger.info(f"Created model: {process_idx=} {model.device=}")
 
-    test_prompt = ""
-    input_ids = tokenizer.encode(test_prompt, return_tensors="pt", add_special_tokens=True)
+    # Prepare batch of prompts for benchmarking
+    batch_size = getattr(args, 'batch_size', 1)
+
+    # Create different prompts for each batch to verify independent generation
+    if batch_size == 1:
+        prompts = [""]
+    elif batch_size == 2:
+        prompts = ["Once upon a time", "In a galaxy far away"]
+    elif batch_size == 3:
+        prompts = ["Once upon a time", "In a galaxy far away", "The quick brown fox"]
+    else:
+        base_prompt = (
+            "Quantum mechanics explains the behavior of particles at very small scales. "
+            "Neural networks learn patterns by adjusting weights through backpropagation. "
+            "Distributed systems require robust consensus mechanisms to maintain state. "
+            "Optimization algorithms like gradient descent are fundamental to machine learning. "
+            "Transformer architectures rely on attention mechanisms to capture dependencies. "
+            "Reinforcement learning optimizes actions by maximizing cumulative rewards. "
+            "Bayesian inference updates beliefs based on observed evidence and prior knowledge. "
+            "Convex optimization problems guarantee global minima under certain conditions. "
+            "Signal processing extracts meaningful information from noisy measurements. "
+        )
+        prompts = [
+            f"{base_prompt} Example {i + 1} discusses large-scale AI systems and scientific discovery."
+            for i in range(batch_size)
+        ]
+
+    if args.prompt_len is None:
+        encodings = tokenizer(prompts, return_tensors="pt", padding=True, add_special_tokens=True)
+        input_ids = encodings["input_ids"]
+    else:
+        target_prompt_length = args.prompt_len
+        bos_token_id = tokenizer.bos_token_id
+        filler_sentence = (
+            " Advanced research explores interdisciplinary insights, collaborative innovation, "
+            "scientific computation, trustworthy deployment, and sustainable engineering practices."
+        )
+        filler_tokens = tokenizer(filler_sentence, add_special_tokens=False)["input_ids"]
+        if not filler_tokens:
+            filler_tokens = [tokenizer.eos_token_id or tokenizer.pad_token_id or 0]
+        processed = []
+        for prompt in prompts:
+            prompt_tokens = tokenizer(prompt, add_special_tokens=False)["input_ids"]
+            if bos_token_id is not None:
+                full_tokens = [bos_token_id] + prompt_tokens
+            else:
+                full_tokens = prompt_tokens[:]
+            if len(full_tokens) >= target_prompt_length:
+                full_tokens = full_tokens[:target_prompt_length]
+            else:
+                while len(full_tokens) < target_prompt_length:
+                    need = target_prompt_length - len(full_tokens)
+                    full_tokens.extend(filler_tokens[:need])
+            processed.append(full_tokens)
+        input_ids = torch.tensor(processed, dtype=torch.long)
+
+    logger.info(f"[DEBUG] {process_idx=} Client batch_size={batch_size}, input_ids.shape={input_ids.shape}")
+    for i, prompt in enumerate(prompts):
+        logger.info(f"[DEBUG] {process_idx=} batch[{i}] prompt: '{prompt}' (token_ids: {input_ids[i].tolist()})")
     temp_result_tokens = input_ids
 
+    # Calculate max_length: prompt_length + number of tokens to generate
+    prompt_length = input_ids.shape[1]
+    if args.prompt_len is not None:
+        target_prompt_length = args.prompt_len
+        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
+        if target_prompt_length < prompt_length:
+            input_ids = input_ids[:, :target_prompt_length]
+        elif target_prompt_length > prompt_length:
+            extra = target_prompt_length - prompt_length
+            pad_block = torch.full((batch_size, extra), pad_token_id, dtype=input_ids.dtype)
+            input_ids = torch.cat([input_ids, pad_block], dim=1)
+        prompt_length = target_prompt_length
+        temp_result_tokens = input_ids
+        logger.info(f"[DEBUG] {process_idx=} adjusted prompt_length to {prompt_length} tokens")
+
+    total_max_length = prompt_length + args.seq_len
+    logger.info(f"[DEBUG] {process_idx=} prompt_length={prompt_length}, generating {args.seq_len} tokens, total_max_length={total_max_length}")
+
     step_times = []
 
-    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
+    with model.transformer.h.inference_session(max_length=total_max_length) as sess:
+        logger.info(f"[DEBUG] {process_idx=} Created inference session with max_length={total_max_length}")
+        logger.info(f"[BENCHMARK_START] Process={process_idx} | BatchSize={batch_size} | SeqLen={args.seq_len}")
+
         for step in range(args.seq_len):
-            start_time = perf_counter()
+            step_start_time = perf_counter()
+
+            # For the first step, pass input_ids; for subsequent steps, generate() will use session state
+            if step == 0:
+                logger.info(f"[DEBUG] {process_idx=} {step=} First step, passing input_ids.shape={input_ids.shape}")
+                outputs = model.generate(input_ids, max_new_tokens=1, session=sess)
+            else:
+                outputs = model.generate(max_new_tokens=1, session=sess)
+
+            # Log generated tokens for all sequences in the batch
+            for batch_idx in range(outputs.shape[0]):
+                new_token_id = outputs[batch_idx][-1].item()
+                new_token_text = tokenizer.decode([new_token_id])
+                logger.info(f"[DEBUG] {process_idx=} {step=} batch[{batch_idx}] Generated token: '{new_token_text}' (id={new_token_id})")
 
-            outputs = model.generate(max_new_tokens=1, session=sess)
-            new_token_id = outputs[0][-1].item()
-            new_token_text = tokenizer.decode([new_token_id])
             temp_result_tokens = torch.cat([temp_result_tokens, outputs[:, -1:]], dim=1)
 
             if step >= args.warmup_steps:
-                step_times.append(perf_counter() - start_time)
+                step_times.append(perf_counter() - step_start_time)
                 speed = 1 / np.mean(step_times)
-                logger.info(f"{process_idx=} {step=} {speed=:.2f}")
+                # Report speed per sequence (total tokens / time)
+                effective_speed = speed * batch_size
+                logger.info(f"{process_idx=} {step=} {speed=:.2f} tokens/sec/sequence, effective={effective_speed:.2f} tokens/sec")
+
+        # Show final generated text for each batch
+        for batch_idx in range(temp_result_tokens.shape[0]):
+            full_text = tokenizer.decode(temp_result_tokens[batch_idx], skip_special_tokens=True)
+            logger.info(f"\nbatch[{batch_idx}] Full generated text:\n{full_text}\n")
 
     result_pipe.send(speed)
 
 
 if __name__ == "__main__":
-    main()
+    main()
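
For reference, a minimal invocation that exercises the new benchmark flags might look like the sketch below. The model name and `$BBSERVER` address are placeholders carried over from the README section above; the exact values depend on your own swarm.

```bash
# Hypothetical example: 4 parallel sequences, 128-token prompts, 32 generated tokens each.
# Assumes the DHT node and workers from the README are already running and $BBSERVER is set.
python benchmarks/benchmark_inference.py \
    --model huggyllama/llama-7b \
    --initial_peers $BBSERVER \
    --torch_dtype float32 \
    --batch_size 4 \
    --prompt_len 128 \
    --seq_len 32
```

Per-step throughput is reported per sequence, with an "effective" figure multiplied by the batch size, as implemented in the generation loop above.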

benchmarks/benchmark_training.py

File mode changed from 100755 to 100644.

bootstrapP100_7.id

Binary file deleted (-1.17 KB); contents not shown.

bootstrap_1.id

Binary file deleted (-1.17 KB); contents not shown.

bootstrapp1.id

Binary file deleted (-1.17 KB); contents not shown.

bootstrapp100.id

Binary file deleted (-1.17 KB); contents not shown.

setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ packages = find:
 python_requires = >=3.8
 install_requires =
     torch>=1.12
-    bitsandbytes==0.46.0
+    bitsandbytes==0.41.0
     accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
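
Note that the pin moves down from 0.46.0 to 0.41.0. A fresh `pip install -e .` applies the pin automatically; if an existing environment already has a newer bitsandbytes installed, re-aligning it is a one-liner (a minimal sketch, assuming a pip-managed environment):

```bash
# Re-align an existing environment with the pinned version (assumes pip-managed packages)
pip install "bitsandbytes==0.41.0"
```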

setup.py

Lines changed: 5 additions & 3 deletions
@@ -54,7 +54,7 @@ def get_version():
         "tokenizers>=0.13.3",
         "transformers==4.43.1",
         "speedtest-cli==2.1.3",
-        "hivemind",
+        "hivemind @ git+https://github.com/learning-at-home/hivemind.git@213bff98a62accb91f254e2afdccbf1d69ebdea9",
         "tensor_parallel==1.0.23",
         "humanfriendly",
         "async-timeout>=4.0.2",
@@ -63,8 +63,10 @@ def get_version():
         "sentencepiece>=0.1.99",
         "peft==0.8.2",
         "safetensors>=0.3.1",
-        "Dijkstar>=2.6.0",
-        "numpy<2",
+        "Dijkstar>=2.6.0",
+        "numpy<2",
+        "attrs",
+        "nvidia-ml-py"
     ],
     extras_require={
         "dev": [

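With this change, `pip install -e .` resolves hivemind from the pinned commit rather than from PyPI, so the separate clone-and-install step removed from the README is no longer needed. If you want to install or rebuild the dependency by itself (for example, while debugging it in isolation), a manual install of the same revision might look roughly like:

```bash
# Manual install of the pinned hivemind revision (normally handled by `pip install -e .`)
pip install "hivemind @ git+https://github.com/learning-at-home/hivemind.git@213bff98a62accb91f254e2afdccbf1d69ebdea9"
```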