
Commit df53cc4
Merge branch 'main' into main
2 parents: fc906f7 + b54a964

File tree: 8 files changed, +171 -88 lines changed

README.md

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ First time here? Go to our [setup guide](https://outlines-dev.github.io/outlines
 - [x] 💾 Caching of generations
 - [x] 🗂️ Batch inference
 - [x] 🎲 Sample with the greedy, multinomial and beam search algorithms (and more to come!)
-- [x] 🚀 [Serve with vLLM](https://outlines-dev.github.io/outlines/reference/vllm), with official Docker image, [`outlinesdev/outlines`](https://hub.docker.com/r/outlinesdev/outlines)!
+- [x] 🚀 [Serve with vLLM](https://outlines-dev.github.io/outlines/reference/serve/vllm), with official Docker image, [`outlinesdev/outlines`](https://hub.docker.com/r/outlinesdev/outlines)!
 
 
 Outlines 〰 has new releases and features coming every week. Make sure to ⭐ star and 👀 watch this repository, follow [@dottxtai][dottxt-twitter] to stay up to date!

benchmarks/bench_processors.py

Lines changed: 81 additions & 25 deletions

@@ -1,8 +1,13 @@
-import mlx.core as mx
 import numpy as np
 import torch
 
-from outlines.processors import OutlinesLogitsProcessor
+import outlines.models as models
+from outlines.processors import OutlinesLogitsProcessor, RegexLogitsProcessor
+
+try:
+    import mlx.core as mx
+except ImportError:
+    pass
 
 
 def is_mlx_lm_allowed():
@@ -13,40 +18,91 @@ def is_mlx_lm_allowed():
     return mx.metal.is_available()
 
 
+def get_mock_processor_inputs(array_library, num_tokens=30000):
+    """
+    logits: (4, 30,000) dtype=float
+    input_ids shape: (4, 2048) dtype=int
+    """
+    if array_library == "torch":
+        logits = torch.rand((4, num_tokens), dtype=torch.float)
+        input_ids = torch.randint(
+            low=0, high=num_tokens, size=(4, 2048), dtype=torch.int
+        )
+    elif array_library == "torch_cuda":
+        logits = torch.rand((4, num_tokens), dtype=torch.float, device="cuda")
+        input_ids = torch.randint(
+            low=0, high=num_tokens, size=(4, 2048), dtype=torch.int, device="cuda"
+        )
+    elif array_library == "numpy":
+        logits = np.random.rand(4, num_tokens).astype(np.float32)
+        input_ids = np.random.randint(low=0, high=num_tokens, size=(4, 2048))
+    elif array_library == "mlx":
+        logits = mx.random.uniform(
+            low=-1e9, high=1e9, shape=(4, num_tokens), dtype=mx.float32
+        )
+        input_ids = mx.random.randint(
+            low=0, high=num_tokens, shape=(4, 2048), dtype=mx.int32
+        )
+    else:
+        raise ValueError
+
+    return logits, input_ids
+
+
 class HalvingLogitsProcessor(OutlinesLogitsProcessor):
     """Simply halve the passed logits"""
 
     def process_logits(self, input_ids, logits):
         return logits / 2
 
 
-class LogitsProcessorBenchmark:
+class LogitsProcessorPassthroughBenchmark:
+    """
+    Benchmark the time it takes to convert between array frameworks.
+    This should be on the order of microseconds.
+    """
+
     params = ["torch", "numpy"]
-    if mx.metal.is_available():
+    if is_mlx_lm_allowed():
         params += ["mlx"]
+    if torch.cuda.is_available():
+        params += ["torch_cuda"]
 
     def setup(self, array_library):
         self.logits_processor = HalvingLogitsProcessor()
 
-        # logits: (4, 30,000) dtype=float
-        # input_ids shape: (4, 2048) dtype=int
-        if array_library == "torch":
-            self.logits = torch.rand((4, 30000), dtype=torch.float)
-            self.input_ids = torch.randint(
-                low=0, high=30000, size=(4, 2048), dtype=torch.int
-            )
-        elif array_library == "numpy":
-            self.logits = np.random.rand(4, 30000).astype(np.float32)
-            self.input_ids = np.random.randint(low=0, high=30000, size=(4, 2048))
-        elif array_library == "mlx":
-            self.logits = mx.random.uniform(
-                low=-1e9, high=1e9, shape=(4, 30000), dtype=mx.float32
-            )
-            self.input_ids = mx.random.randint(
-                low=0, high=30000, shape=(4, 2048), dtype=mx.int32
-            )
-        else:
-            raise ValueError
-
-    def time_logits_processor(self, array_library):
+        self.logits, self.input_ids = get_mock_processor_inputs(array_library)
+
+    def time_passthrough(self, *params):
+        self.logits_processor(self.input_ids, self.logits)
+
+
+class LogitsProcessorStructuredBenchmark:
+    """
+    Benchmark structured generation mask application for a single decoder pass.
+    """
+
+    array_libraries = ["torch", "numpy"]
+    if is_mlx_lm_allowed():
+        array_libraries += ["mlx"]
+    # PR TODO
+    if torch.cuda.is_available():
+        array_libraries += ["torch_cuda"]
+
+    # accept very many or very few tokens, respectively
+    patterns = [r"[^Z]*", "Z*"]
+
+    params = [array_libraries, patterns]
+    param_names = ["array_library", "pattern"]
+
+    def setup(self, array_library, pattern):
+        tokenizer = models.transformers("facebook/opt-125m", device="cpu").tokenizer
+
+        self.logits_processor = RegexLogitsProcessor(pattern, tokenizer)
+
+        self.logits, self.input_ids = get_mock_processor_inputs(
+            array_library, len(tokenizer.vocabulary)
+        )
+
+    def time_structured_generation(self, array_library, pattern):
         self.logits_processor(self.input_ids, self.logits)
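For context, asv discovers these classes by name, calls `setup` once per parameter combination, and times each `time_*` method. A minimal sketch of exercising the new passthrough benchmark by hand, outside asv (the timing loop and `number=100` are illustrative, not part of the suite):

```python
import timeit

from bench_processors import LogitsProcessorPassthroughBenchmark

# Instantiate and prime the benchmark the way asv would for one parameter.
bench = LogitsProcessorPassthroughBenchmark()
bench.setup("torch")

# Time the passthrough call; asv reports a similar per-call statistic.
elapsed = timeit.timeit(lambda: bench.time_passthrough("torch"), number=100)
print(f"mean passthrough latency: {elapsed / 100 * 1e6:.1f} µs")
```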

docs/community/contribute.md

Lines changed: 7 additions & 3 deletions

@@ -66,17 +66,21 @@ You can run the benchmark test suite locally with the following command:
 asv run --config benchmarks/asv.conf.json
 ```
 
-Run a specific test:
+Caveats:
+- If you're on a device with CUDA, you must add the argument `--launch-method spawn`
+- Uncommitted code will not be benchmarked; you must first commit your changes.
+
+#### Run a specific test:
 ```
 asv run --config benchmarks/asv.conf.json -b bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm
 ```
 
-Profile a specific test:
+#### Profile a specific test:
 ```
 asv run --config benchmarks/asv.conf.json --profile -b bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm
 ```
 
-Compare to `origin/main`
+#### Compare to `origin/main`
 ```
 git fetch origin
 asv continuous origin/main HEAD --config benchmarks/asv.conf.json
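Combining the base command with the new CUDA caveat, a machine with CUDA would run the suite as follows (assembled from the command and flag shown above, not a separately documented invocation):

```
asv run --config benchmarks/asv.conf.json --launch-method spawn
```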

docs/reference/prompting.md

Lines changed: 2 additions & 1 deletion

@@ -260,7 +260,8 @@ pretty print a dictionary from within an Outlines prompt function
 def my_prompt(response_model):
     """{{ response_model | schema }}"""
 
-my_prompt(MyResponse)
+prompt = my_prompt(MyResponse)
+print(prompt)
 # {
 #   "field1": "an int",
 #   "field2": "<field2>"

outlines/fsm/guide.py

Lines changed: 12 additions & 4 deletions

@@ -2,6 +2,7 @@
 from typing import TYPE_CHECKING, List, Optional, Protocol, Tuple, Union
 
 import interegular
+import torch
 from lark import Lark
 
 from outlines import grammars
@@ -146,6 +147,13 @@ def __init__(self, regex_string: str, tokenizer):
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}
 
+        # cache the returned token masks;
+        # this substantially improves mask-construction performance
+        self.states_to_token_mask = {
+            state: torch.tensor(list(next_tokens_to_end_states.keys()))
+            for state, next_tokens_to_end_states in self.states_to_token_maps.items()
+        }
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
 
@@ -169,11 +177,11 @@ def get_next_instruction(self, state: int) -> Instruction:
         A `Generate` instance that contains the model and the allowed token ids.
 
         """
-        next_tokens_to_end_states = self.states_to_token_maps.get(state)
-        if next_tokens_to_end_states is None:
-            return Write([self.eos_token_id])
+        next_tokens_mask = self.states_to_token_mask.get(state)
+        if next_tokens_mask is None:
+            return Write(torch.tensor([self.eos_token_id]))
 
-        return Generate(list(next_tokens_to_end_states.keys()))
+        return Generate(next_tokens_mask)
 
     def get_next_state(self, state: int, token_id: int) -> int:
         """Update the state of the guide.

outlines/generate/api.py

Lines changed: 2 additions & 2 deletions

@@ -231,7 +231,7 @@ def __call__(
 
         # We reshape the output to (batch_size, sample_size)
         output: List[List[FormattedOutput]] = list()
-        for i in range(batch_size):
+        for i in range(0, batch_size * num_samples, num_samples):
             output.append(formatted[i : i + num_samples])
 
         # We remove leading dimensions for the output
@@ -372,7 +372,7 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
             previously_generated_sequences = generated_sequences
             # We reshape the output to (batch_size, sample_size)
             output: List[List[str]] = list()
-            for i in range(batch_size):
+            for i in range(0, batch_size * num_samples, num_samples):
                 output.append(next_tokens[i : i + num_samples])
 
             # We remove leading dimensions for the output
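The old loop advanced its index by one while slicing `num_samples`-wide windows, so for `num_samples > 1` consecutive slices overlapped and trailing samples were never emitted; stepping by `num_samples` partitions the flat list cleanly. A worked toy example (data invented for illustration):

```python
batch_size, num_samples = 2, 3
formatted = ["a0", "a1", "a2", "b0", "b1", "b2"]  # flat, batch-major

# Old (buggy): range(batch_size) -> slices [0:3] and [1:4], which overlap
# and drop "b1" and "b2" entirely.
old = [formatted[i : i + num_samples] for i in range(batch_size)]
# [['a0', 'a1', 'a2'], ['a1', 'a2', 'b0']]

# Fixed: step by num_samples -> slices [0:3] and [3:6], a clean partition.
new = [
    formatted[i : i + num_samples]
    for i in range(0, batch_size * num_samples, num_samples)
]
# [['a0', 'a1', 'a2'], ['b0', 'b1', 'b2']]
print(old, new, sep="\n")
```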
