Skip to content

Commit c38a57e

Browse files
committed
Draft
1 parent 222ca17 commit c38a57e

19 files changed

+1539
-54
lines changed

docs/features/advanced/backends.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
---
2+
title: Structured Generation Backends
3+
---
4+
5+
# Structured Generation Backends
6+
7+
Outlines relies on a structured generation backend to control text generation for steerable models such thah they conform to the output type provided. One of those backends is of course `outlines-core`, but you also have access to two other libraries that fulfill the same purpose: `llguidance` and `xgrammar`.
8+
9+
## Overview
10+
11+
To select the backend to use for your generation, provide a value for the `backend` argument when calling a model or a generator.
12+
13+
For instance:
14+
15+
```python
16+
from typing import Literal
17+
import outlines
18+
from transformers import AutoModelForCausalLM, AutoTokenizer
19+
20+
output_type = Literal["Paris", "London", "Rome", "Berlin"]
21+
22+
model = outlines.from_transformers(
23+
AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct"),
24+
AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
25+
)
26+
27+
result = model("What is the capital of France?", output_type, backend="llguidance")
28+
print(result) # 'Paris'
29+
30+
generator = outlines.Generaor(model, output_type)
31+
result = generator("What is the capital of France?", backend="xgrammar")
32+
print(result) # 'Paris'
33+
```
34+
35+
If you do not provide a value for the `backend` argument, the default value will be used. The default value depends on the type of output type:
36+
37+
- JSON schema: `outlines_core`
38+
- Regex: `outlines_core`
39+
- Context-free grammar: `llguidance`
40+
- Interegular FSM: `outlines_core`
41+
42+
## Features matrix
43+
44+
As mentioned previously, selecting the structured generation backend is only applicable to steerable models, so `Transformers`, `LlmaCpp` and `MLXLM`. Additionaly, some backends do not support some models within those or some output types.
45+
46+
| | outlines_core | llguidance | xgrammar |
47+
|---|---|---|---|
48+
| **Models** | | | |
49+
| Transformers ||||
50+
| LlamaCpp ||||
51+
| MLXLM ||||
52+
| **Output Types** | | | |
53+
| JSON Schema ||||
54+
| Regex ||||
55+
| Grammar ||||
56+
| FSM ||||

docs/features/models/llamacpp.md

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ for chunk in model.stream("Write a short story about a cat.", max_tokens=100):
109109

110110
## Structured Generation
111111

112-
The `LlamaCpp` model supports all output types available in Outlines except for context-free grammars. Simply provide an `output_type` after the prompt when calling the model.
112+
The `LlamaCpp` model supports all output types available in Outlines. Simply provide an `output_type` after the prompt when calling the model.
113113

114114
### Basic Type
115115

@@ -195,6 +195,29 @@ result = model("Generate a fake social security number.", output_type)
195195
print(result) # '782-32-3789'
196196
```
197197

198+
### Context-free grammar
199+
200+
```python
201+
from outlines.types import CFG
202+
import outlines
203+
from llama_cpp import Llama
204+
205+
output_type = CFG("""
206+
root ::= answer
207+
answer ::= "yes" | "no"
208+
""")
209+
210+
model = outlines.from_llamacpp(
211+
Llama.from_pretrained(
212+
repo_id="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
213+
filename="mistral-7b-instruct-v0.2.Q5_K_M.gguf",
214+
)
215+
)
216+
217+
result = model("Are you feeling good today?", output_type)
218+
print(result) # 'yes'
219+
```
220+
198221
## Inference Arguments
199222

200223
When calling the model, you can provide optional inference parameters on top of the prompt and the output type. These parameters will be passed on to the `__call__` method of the `llama_cpp.Llama` model. Some common inference arguments include `max_tokens`, `temperature`, `frequency_penalty` and `top_p`.

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ nav:
161161

162162
- Advanced:
163163
- Logits Processors: features/advanced/logits_processors.md
164+
- Structured Generation Backends: features/advanced/backends.md
164165

165166
- API Reference: api_reference/
166167

outlines/backends/__init__.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""Module to define the backends in charge of creating logits processors."""
2+
3+
import interegular
4+
5+
from outlines.backends.base import (
6+
BaseBackend,
7+
LogitsProcessorType,
8+
)
9+
from outlines.backends.llguidance import LLGuidanceBackend
10+
from outlines.backends.outlines_core import OutlinesCoreBackend
11+
from outlines.backends.xgrammar import XGrammarBackend
12+
from outlines.models import SteerableModel
13+
14+
15+
CFG_DEFAULT_BACKEND = "llguidance"
16+
FSM_DEFAULT_BACKEND = "outlines_core"
17+
JSON_SCHEMA_DEFAULT_BACKEND = "outlines_core"
18+
REGEX_DEFAULT_BACKEND = "outlines_core"
19+
20+
21+
def _get_backend(backend_name: str, model: SteerableModel) -> BaseBackend:
22+
"""Create a Backend instance.
23+
24+
Parameters
25+
----------
26+
backend_name: str
27+
The name of the backend to get.
28+
model: Model
29+
The Outlines model of the user.
30+
31+
Returns
32+
-------
33+
backend: BaseBackend
34+
The backend instance.
35+
36+
"""
37+
if backend_name == "outlines_core":
38+
return OutlinesCoreBackend(model)
39+
elif backend_name == "xgrammar":
40+
return XGrammarBackend(model)
41+
elif backend_name == "llguidance":
42+
return LLGuidanceBackend(model)
43+
else:
44+
raise ValueError(f"Backend {backend_name} not supported")
45+
46+
47+
def get_json_schema_logits_processor(
48+
backend_name: str | None,
49+
model: SteerableModel,
50+
json_schema: str,
51+
) -> LogitsProcessorType:
52+
"""Create a logits processor from a JSON schema.
53+
54+
Parameters
55+
----------
56+
backend_name: str | None
57+
The name of the backend to use.
58+
model: Model
59+
The Outlines model of the user.
60+
json_schema: str
61+
The JSON schema to create a logits processor from.
62+
63+
Returns
64+
-------
65+
LogitsProcessorType
66+
The logits processor.
67+
68+
"""
69+
backend = _get_backend(
70+
backend_name or JSON_SCHEMA_DEFAULT_BACKEND,
71+
model,
72+
)
73+
return backend.get_json_schema_logits_processor(json_schema)
74+
75+
76+
def get_regex_logits_processor(
77+
backend_name: str | None,
78+
model: SteerableModel,
79+
regex: str,
80+
) -> LogitsProcessorType:
81+
"""Create a logits processor from a regex.
82+
83+
Parameters
84+
----------
85+
backend_name: str | None
86+
The name of the backend to use.
87+
model: Model
88+
The Outlines model of the user.
89+
regex: str
90+
The regex to create a logits processor from.
91+
92+
Returns
93+
-------
94+
LogitsProcessorType
95+
The logits processor.
96+
97+
"""
98+
backend = _get_backend(
99+
backend_name or REGEX_DEFAULT_BACKEND,
100+
model,
101+
)
102+
return backend.get_regex_logits_processor(regex)
103+
104+
105+
def get_cfg_logits_processor(
106+
backend_name: str | None,
107+
model: SteerableModel,
108+
grammar: str,
109+
) -> LogitsProcessorType:
110+
"""Create a logits processor from a context-free grammar.
111+
112+
Parameters
113+
----------
114+
backend_name: str | None
115+
The name of the backend to use.
116+
model: Model
117+
The Outlines model of the user.
118+
grammar: str
119+
The context-free grammar to create a logits processor from.
120+
121+
Returns
122+
-------
123+
LogitsProcessorType
124+
The logits processor.
125+
126+
"""
127+
backend = _get_backend(
128+
backend_name or CFG_DEFAULT_BACKEND,
129+
model,
130+
)
131+
return backend.get_cfg_logits_processor(grammar)
132+
133+
134+
def get_fsm_logits_processor(
135+
backend_name: str | None,
136+
model: SteerableModel,
137+
fsm: interegular,
138+
) -> LogitsProcessorType:
139+
"""Create a logits processor from an interegular FSM.
140+
141+
Parameters
142+
----------
143+
backend_name: str | None
144+
The name of the backend to use.
145+
model: Model
146+
The Outlines model of the user.
147+
fsm: interegular.fsm.FSM
148+
The interegular FSM to create a logits processor from.
149+
150+
Returns
151+
-------
152+
LogitsProcessorType
153+
The logits processor.
154+
155+
"""
156+
backend = _get_backend(
157+
backend_name or FSM_DEFAULT_BACKEND,
158+
model,
159+
)
160+
return backend.get_fsm_logits_processor(fsm)

outlines/backends/base.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Base class for all backends."""
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Any
5+
6+
from interegular.fsm import FSM
7+
8+
9+
LogitsProcessorType = Any
10+
11+
12+
class BaseBackend(ABC):
13+
"""Base class for all backends.
14+
15+
The subclasses must implement methods that create a logits processor
16+
from a JSON schema, regex, CFG or FSM.
17+
18+
"""
19+
20+
@abstractmethod
21+
def get_json_schema_logits_processor(
22+
self, json_schema: str
23+
) -> LogitsProcessorType:
24+
"""Create a logits processor from a JSON schema.
25+
26+
Parameters
27+
----------
28+
json_schema: str
29+
The JSON schema to create a logits processor from.
30+
31+
Returns
32+
-------
33+
LogitsProcessorType
34+
The logits processor.
35+
36+
"""
37+
...
38+
39+
@abstractmethod
40+
def get_regex_logits_processor(self, regex: str) -> LogitsProcessorType:
41+
"""Create a logits processor from a regex.
42+
43+
Parameters
44+
----------
45+
regex: str
46+
The regex to create a logits processor from.
47+
48+
Returns
49+
-------
50+
LogitsProcessorType
51+
The logits processor.
52+
53+
"""
54+
...
55+
56+
@abstractmethod
57+
def get_cfg_logits_processor(self, grammar: str) -> LogitsProcessorType:
58+
"""Create a logits processor from a context-free grammar.
59+
60+
Parameters
61+
----------
62+
grammar: str
63+
The context-free grammar to create a logits processor from.
64+
65+
Returns
66+
-------
67+
LogitsProcessorType
68+
The logits processor.
69+
70+
"""
71+
...
72+
73+
@abstractmethod
74+
def get_fsm_logits_processor(self, fsm: FSM) -> LogitsProcessorType:
75+
"""Create a logits processor from an interegular FSM.
76+
77+
Parameters
78+
----------
79+
fsm: interegular.fsm.FSM
80+
The interegular FSM to create a logits processor from.
81+
82+
Returns
83+
-------
84+
LogitsProcessorType
85+
The logits processor.
86+
87+
"""
88+
...

0 commit comments

Comments
 (0)