Commit 2085d1d

Merge branch 'main' into dev-cb-heterogeneous-tkv
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
2 parents: a58087f + 11562e9

15 files changed: 124 additions, 239 deletions

.github/workflows/build_docker.yml

Lines changed: 6 additions & 2 deletions
@@ -10,8 +10,12 @@ on:
   pull_request:
     branches:
       - "main"
-    paths-ignore:
-      - "**.md"
+    paths:
+      - ".github/workflows/build_docker.yml"
+      - "docker/**"
+      - "vllm_spyre/**/*.py"
+      - "pyproject.toml"
+      - "uv.lock"
   release:
     types: [published]

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ jobs:
           files: |
             .github/workflows/test.yml
             pyproject.toml
+            uv.lock
             tests/**/*.py
             vllm_spyre/**/*.py

.github/workflows/type_check.yaml

Lines changed: 31 additions & 22 deletions
@@ -1,43 +1,52 @@
 name: Type Check
 
 on:
-  # Trigger the workflow on push or pull request, but only for the main branch.
-  # Don't use pull_request.paths filter since this workflow is required for
-  # all pull requests on main irrespective of file type or location.
+  # Don't use `paths` or `paths-ignore` filter since this workflow is required
+  # for all pull requests on main irrespective of file type or location
+  # Use `changed-src-files` step to determine if source code was changed
   pull_request:
     branches:
       - main
   push:
     branches:
       - main
-    paths:
-      - '**/*.py'
-      - '.github/workflows/type_check.yaml'
-      - 'tools/type_check.sh'
-      - 'pyproject.toml'
 
 jobs:
   type-check:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.10", "3.11", "3.12"]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      - name: "Checkout"
+        uses: actions/checkout@v4
+
+      - name: "Get changed source files"
+        id: changed-src-files
+        uses: tj-actions/changed-files@v46
+        with: # Avoid using single or double quotes for multiline patterns
+          files: |
+            .github/workflows/type_check.yaml
+            tools/type_check.sh
+            pyproject.toml
+            **.py
+
+      - name: "Set up Python ${{ matrix.python-version }}"
+        if: steps.changed-src-files.outputs.any_changed == 'true'
+        uses: astral-sh/setup-uv@v5
         with:
           python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          # TODO: use `uv`
-          python -m pip install --upgrade pip
-          pip install mypy==1.11.1
-          pip install types-setuptools
-          pip install types-PyYAML
-          pip install types-requests
-          pip install types-setuptools
-      - name: Mypy
+          enable-cache: true
+          ignore-nothing-to-cache: true
+          cache-dependency-glob: |
+            pyproject.toml
+
+      - name: "Install dependencies"
+        if: steps.changed-src-files.outputs.any_changed == 'true'
+        run: uv sync --frozen --only-group lint
+
+      - name: "Run mypy"
+        if: steps.changed-src-files.outputs.any_changed == 'true'
         run: |
           echo "::add-matcher::.github/workflows/matchers/mypy.json"
           tools/type_check.sh 1 ${{ matrix.python-version }}

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -193,3 +193,6 @@ benchmarks/*.json
 # Linting
 actionlint
 shellcheck*/
+
+# version file generated by setuptools-scm
+/vllm_spyre/_version.py
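
The newly ignored file is generated at build time by setuptools-scm, so it should not live in the repository. As a hedged aside (the distribution name "vllm-spyre" is an assumption here), the recorded version can be read through standard packaging metadata instead of importing the generated module:

# Sketch: read the setuptools-scm-derived version via packaging metadata.
# Assumes the project is installed as the "vllm-spyre" distribution.
from importlib.metadata import version

print(version("vllm-spyre"))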

docker/Dockerfile.amd64

Lines changed: 0 additions & 2 deletions
@@ -72,8 +72,6 @@ ENV COMPILATION_MODE=offline_decoder \
     FLEX_COMPUTE=SENTIENT \
     FLEX_DEVICE=PF \
     FLEX_OVERWRITE_NMB_FRAME=1 \
-    FLEX_UNLINK_DEVMEM=false \
-    FLEX_RDMA_MODE_FULL=1 \
     TOKENIZERS_PARALLELISM=false \
     TORCH_SENDNN_LOG=WARNING

docs/.nav.yml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ nav:
   - Kubernetes: deploying/k8s.md
   - Examples:
     - Offline Inference: examples/offline_inference
-    - Other: examples/other
+    - Online Inference: examples/online_inference
   - User Guide:
     - Configuration: user_guide/configuration.md
     - Environment Variables: user_guide/env_vars.md
examples/offline_inference (continuous batching CPU example — file path not shown)

@@ -1,16 +1,29 @@
+"""
+This example shows how to run offline inference using continuous batching
+on CPU.
+"""
+
+import argparse
 import os
 import platform
 import time
 
 from vllm import LLM, SamplingParams
 
-# RUN with fms branch: https://github.com/foundation-model-stack/
-# foundation-model-stack/tree/paged_attn_mock
+# Continuous batching currently requires installing the branch
+# https://github.com/foundation-model-stack/foundation-model-stack/tree/paged_attn_mock
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model", type=str, default="/models/llama-194m")
+parser.add_argument("--max_model_len", type=int, default=2048)
+parser.add_argument("--max_num_seqs", type=int, default=2)
+parser.add_argument("--tp", type=int, default=1)
+args = parser.parse_args()
 
 max_tokens1 = 65
 max_tokens2 = 67
 max_tokens3 = 7
-max_num_seqs = 2  # defines max batch size
+max_num_seqs = args.max_num_seqs  # defines the max batch size
 
 if platform.machine() == "arm64":
     print("Detected arm64 running environment. "
@@ -19,58 +32,40 @@
           "locally on arm64.")
     os.environ["HF_HUB_OFFLINE"] = "1"
 
-# defining here to be able to run/debug directly from VSC (not via terminal)
-os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager'
+if "VLLM_SPYRE_DYNAMO_BACKEND" not in os.environ:
+    os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager'
 os.environ['VLLM_SPYRE_USE_CB'] = '1'
 os.environ['VLLM_SPYRE_HETEROGEN_TKV'] = '0'
 os.environ['VLLM_USE_V1'] = '1'
 
-# Sample prompts.
 template = (
     "Below is an instruction that describes a task. Write a response that "
    "appropriately completes the request. Be polite in your response to the "
    "user.\n\n### Instruction:\n{}\n\n### Response:")
 
-prompt1 = template.format(
-    "Provide a list of instructions for preparing chicken soup for a family "
-    "of four.")
-
-prompt2 = template.format("Provide instructions for preparing chicken soup.")
-
-prompt3 = template.format(
-    "Provide a list of instructions for preparing chicken soup for a family.")
-
-prompts = [
-    prompt1,
-    prompt2,
-    prompt3,
+instructions = [
+    "Provide a list of instructions for preparing chicken soup for a family" + \
+    " of four.",
+    "Provide instructions for preparing chicken soup.",
+    "Provide a list of instructions for preparing chicken soup for a family.",
 ]
 
-# Create a sampling params object.
-sampling_params1 = SamplingParams(max_tokens=max_tokens1,
-                                  temperature=0.0,
-                                  ignore_eos=True)
-
-sampling_params2 = SamplingParams(max_tokens=max_tokens2,
-                                  temperature=0.0,
-                                  ignore_eos=True)
+prompts = [template.format(instr) for instr in instructions]
 
-sampling_params3 = SamplingParams(max_tokens=max_tokens3,
-                                  temperature=0.0,
-                                  ignore_eos=True)
+max_tokens_list = [max_tokens1, max_tokens2, max_tokens3]
 
 sampling_params = [
-    sampling_params1,
-    sampling_params2,
-    sampling_params3,
+    SamplingParams(max_tokens=mt, temperature=0.0, ignore_eos=True)
+    for mt in max_tokens_list
 ]
 
 # Create an LLM.
-llm = LLM(model="/models/llama-194m",
-          tokenizer="/models/llama-194m",
-          max_model_len=2048,
+llm = LLM(model=args.model,
+          tokenizer=args.model,
+          max_model_len=args.max_model_len,
           block_size=2048,
-          max_num_seqs=max_num_seqs)
+          max_num_seqs=max_num_seqs,
+          tensor_parallel_size=args.tp)
 
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
@@ -81,9 +76,11 @@
       (len(outputs[0].outputs[0].token_ids), time.time() - t0))
 print("===============")
 for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    print(output.outputs[0])
 print("===============")
 for output in outputs:
-    print(output.outputs[0])
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"\nPrompt:\n {prompt!r}")
+    print(f"\nGenerated text:\n {generated_text!r}\n")
+    print("-----------------------------------")

examples/offline_inference/offline_inference_multi_spyre.py renamed to examples/offline_inference/multi_spyre_inference.py

Lines changed: 6 additions & 2 deletions
@@ -1,3 +1,8 @@
+"""
+This example shows how to use Spyre with vLLM for running offline inference
+with multiple cards.
+"""
+
 import gc
 import os
 import platform
@@ -18,13 +23,12 @@
 os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens)
 os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = '1'
 
-# stuff for multi-spyre
+# Multi-spyre related variables
 os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
 os.environ["DISTRIBUTED_STRATEGY_IGNORE_MODULES"] = "WordEmbedding"
 os.environ["MASTER_ADDR"] = "localhost"
 os.environ["MASTER_PORT"] = "12355"
 
-# Sample prompts.
 template = (
     "Below is an instruction that describes a task. Write a response that "
     "appropriately completes the request. Be polite in your response to the "

examples/offline_inference/offline_inference_spyre.py renamed to examples/offline_inference/spyre_inference.py

Lines changed: 8 additions & 6 deletions
@@ -1,3 +1,7 @@
+"""
+This example shows how to use Spyre with vLLM for running offline inference.
+"""
+
 import os
 import platform
 import time
@@ -17,19 +21,17 @@
 os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens)
 os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = '1'
 
-# Sample prompts.
 template = (
     "Below is an instruction that describes a task. Write a response that "
     "appropriately completes the request. Be polite in your response to the "
     "user.\n\n### Instruction:\n{}\n\n### Response:")
-prompt1 = template.format(
-    "Provide a list of instructions for preparing chicken soup for a family "
-    "of four.")
 prompts = [
-    prompt1,
+    template.format(
+        "Provide a list of instructions for preparing chicken soup for a" + \
+        " family of four.",
+    )
 ]
 
-# Create a sampling params object.
 sampling_params = SamplingParams(max_tokens=max_tokens,
                                  temperature=0.0,
                                  ignore_eos=True)
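
For context on the renamed file's structure: the warmup shape variables are exported before the engine is created, and a single SamplingParams is reused for the whole static batch. A hedged sketch with illustrative values (the real example defines its own prompt, token budget, and model path):

# Sketch of the static-batching setup used by spyre_inference.py.
# The token budget and model path below are illustrative assumptions.
import os

max_tokens = 20
os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens)
os.environ["VLLM_SPYRE_WARMUP_BATCH_SIZES"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="/models/llama-194m")  # assumed local model path
sampling_params = SamplingParams(max_tokens=max_tokens,
                                 temperature=0.0,
                                 ignore_eos=True)
outputs = llm.generate(["Write a short greeting."], sampling_params)
print(outputs[0].outputs[0].text)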
