Skip to content

Commit f340eba

Browse files
committed
Draft
1 parent 222ca17 commit f340eba

File tree

16 files changed

+1512
-35
lines changed

16 files changed

+1512
-35
lines changed

outlines/backends/__init__.py

Lines changed: 160 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,160 @@
1+
"""Module to define the backends in charge of creating logits processors."""
2+
3+
import interegular
4+
5+
from outlines.backends.base import (
6+
BaseBackend,
7+
LogitsProcessorType,
8+
)
9+
from outlines.backends.llguidance import LLGuidanceBackend
10+
from outlines.backends.outlines_core import OutlinesCoreBackend
11+
from outlines.backends.xgrammar import XGrammarBackend
12+
from outlines.models import SteerableModel
13+
14+
15+
# Default backend name per constraint type. Each `get_*_logits_processor`
# helper in this module falls back to the matching constant when its
# `backend_name` argument is `None` (or otherwise falsy).
CFG_DEFAULT_BACKEND = "llguidance"
16+
FSM_DEFAULT_BACKEND = "outlines_core"
17+
JSON_SCHEMA_DEFAULT_BACKEND = "outlines_core"
18+
REGEX_DEFAULT_BACKEND = "outlines_core"
19+
20+
21+
def _get_backend(backend_name: str, model: SteerableModel) -> BaseBackend:
    """Instantiate the backend registered under `backend_name`.

    Parameters
    ----------
    backend_name: str
        Name of the backend to build: "outlines_core", "xgrammar" or
        "llguidance".
    model: SteerableModel
        The Outlines model of the user, handed to the backend constructor.

    Returns
    -------
    BaseBackend
        A new backend instance built around `model`.

    Raises
    ------
    ValueError
        If `backend_name` is not one of the supported names.

    """
    # Guard clauses: every supported name returns straight away, so anything
    # that reaches the end of the function is an unknown backend.
    if backend_name == "outlines_core":
        return OutlinesCoreBackend(model)
    if backend_name == "xgrammar":
        return XGrammarBackend(model)
    if backend_name == "llguidance":
        return LLGuidanceBackend(model)

    raise ValueError(f"Backend {backend_name} not supported")
45+
46+
47+
def get_json_schema_logits_processor(
    backend_name: str | None,
    model: SteerableModel,
    json_schema: str,
) -> LogitsProcessorType:
    """Build a JSON-schema-constrained logits processor.

    Parameters
    ----------
    backend_name: str | None
        Name of the backend to use. When `None` (or empty),
        `JSON_SCHEMA_DEFAULT_BACKEND` is used instead.
    model: SteerableModel
        The Outlines model of the user.
    json_schema: str
        The JSON schema the logits processor is created from.

    Returns
    -------
    LogitsProcessorType
        The logits processor produced by the selected backend.

    """
    selected_name = backend_name or JSON_SCHEMA_DEFAULT_BACKEND
    selected_backend = _get_backend(selected_name, model)
    return selected_backend.get_json_schema_logits_processor(json_schema)
74+
75+
76+
def get_regex_logits_processor(
    backend_name: str | None,
    model: SteerableModel,
    regex: str,
) -> LogitsProcessorType:
    """Build a regex-constrained logits processor.

    Parameters
    ----------
    backend_name: str | None
        Name of the backend to use. When `None` (or empty),
        `REGEX_DEFAULT_BACKEND` is used instead.
    model: SteerableModel
        The Outlines model of the user.
    regex: str
        The regex the logits processor is created from.

    Returns
    -------
    LogitsProcessorType
        The logits processor produced by the selected backend.

    """
    selected_name = backend_name or REGEX_DEFAULT_BACKEND
    selected_backend = _get_backend(selected_name, model)
    return selected_backend.get_regex_logits_processor(regex)
103+
104+
105+
def get_cfg_logits_processor(
    backend_name: str | None,
    model: SteerableModel,
    grammar: str,
) -> LogitsProcessorType:
    """Build a logits processor constrained by a context-free grammar.

    Parameters
    ----------
    backend_name: str | None
        Name of the backend to use. When `None` (or empty),
        `CFG_DEFAULT_BACKEND` is used instead.
    model: SteerableModel
        The Outlines model of the user.
    grammar: str
        The context-free grammar the logits processor is created from.

    Returns
    -------
    LogitsProcessorType
        The logits processor produced by the selected backend.

    """
    selected_name = backend_name or CFG_DEFAULT_BACKEND
    selected_backend = _get_backend(selected_name, model)
    return selected_backend.get_cfg_logits_processor(grammar)
132+
133+
134+
def get_fsm_logits_processor(
    backend_name: str | None,
    model: SteerableModel,
    fsm: "interegular.fsm.FSM",
) -> LogitsProcessorType:
    """Create a logits processor from an interegular FSM.

    Parameters
    ----------
    backend_name: str | None
        The name of the backend to use. When `None` (or empty),
        `FSM_DEFAULT_BACKEND` is used instead.
    model: SteerableModel
        The Outlines model of the user.
    fsm: interegular.fsm.FSM
        The interegular FSM to create a logits processor from.

    Returns
    -------
    LogitsProcessorType
        The logits processor.

    """
    # NOTE: `fsm` was previously annotated with the `interegular` module
    # itself, which is not a type. The annotation is a string so that it does
    # not need to be resolved when the function is defined.
    backend = _get_backend(
        backend_name or FSM_DEFAULT_BACKEND,
        model,
    )
    return backend.get_fsm_logits_processor(fsm)

outlines/backends/base.py

Lines changed: 88 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,88 @@
1+
"""Base class for all backends."""
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Any
5+
6+
from interegular.fsm import FSM
7+
8+
9+
# Alias used as the return annotation of the backend factory methods below;
# currently unconstrained (`Any`).
LogitsProcessorType = Any
10+
11+
12+
class BaseBackend(ABC):
    """Abstract interface shared by every backend.

    A concrete backend has to provide one factory method per constraint
    type: JSON schema, regex, context-free grammar and interegular FSM.

    """

    @abstractmethod
    def get_json_schema_logits_processor(
        self,
        json_schema: str,
    ) -> LogitsProcessorType:
        """Build a logits processor constrained by a JSON schema.

        Parameters
        ----------
        json_schema: str
            The JSON schema the logits processor is created from.

        Returns
        -------
        LogitsProcessorType
            The logits processor.

        """

    @abstractmethod
    def get_regex_logits_processor(
        self,
        regex: str,
    ) -> LogitsProcessorType:
        """Build a logits processor constrained by a regex.

        Parameters
        ----------
        regex: str
            The regex the logits processor is created from.

        Returns
        -------
        LogitsProcessorType
            The logits processor.

        """

    @abstractmethod
    def get_cfg_logits_processor(
        self,
        grammar: str,
    ) -> LogitsProcessorType:
        """Build a logits processor constrained by a context-free grammar.

        Parameters
        ----------
        grammar: str
            The context-free grammar the logits processor is created from.

        Returns
        -------
        LogitsProcessorType
            The logits processor.

        """

    @abstractmethod
    def get_fsm_logits_processor(
        self,
        fsm: FSM,
    ) -> LogitsProcessorType:
        """Build a logits processor constrained by an interegular FSM.

        Parameters
        ----------
        fsm: interegular.fsm.FSM
            The interegular FSM the logits processor is created from.

        Returns
        -------
        LogitsProcessorType
            The logits processor.

        """

outlines/backends/llguidance.py

Lines changed: 137 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,137 @@
1+
"""Backend class for LLGuidance."""
2+
3+
from typing import TYPE_CHECKING
4+
5+
from outlines.backends.base import BaseBackend
6+
from outlines.models import LlamaCpp, MLXLM, SteerableModel, Transformers
7+
from outlines.processors.llguidance import LLGuidanceLogitsProcessor
8+
9+
if TYPE_CHECKING:
10+
from llguidance import LLGTokenizer
11+
12+
13+
class LLGuidanceBackend(BaseBackend):
    """Backend that creates logits processors through `llguidance`."""

    def __init__(self, model: SteerableModel):
        """
        Parameters
        ----------
        model
            The Outlines model of the user. Its `tensor_library_name` is
            recorded and its tokenizer is converted to an llg tokenizer.

        """
        # Imported at call time and kept on the instance so the factory
        # methods can reach `grammar_from` through `self.llg`.
        import llguidance as llg

        self.llg = llg
        self.tensor_library_name = model.tensor_library_name
        self.llg_tokenizer = self._create_llg_tokenizer(model)

    def _create_llg_tokenizer(self, model: SteerableModel) -> "LLGTokenizer":
        """Convert the tokenizer of an Outlines model into an llg tokenizer.

        Parameters
        ----------
        model: SteerableModel
            The Outlines model whose tokenizer is converted.

        Returns
        -------
        LLGTokenizer
            The llg tokenizer.

        Raises
        ------
        ValueError
            If `model` is not a Transformers, LlamaCpp or MLXLM instance.

        """
        if isinstance(model, Transformers):
            import llguidance.hf

            return llguidance.hf.from_tokenizer(model.hf_tokenizer)

        if isinstance(model, LlamaCpp):
            import llama_cpp
            import llguidance.llamacpp

            llama_vocab = llama_cpp.llama_model_get_vocab(model.model.model)
            return llguidance.llamacpp.lltokenizer_from_vocab(llama_vocab)

        if isinstance(model, MLXLM):
            import llguidance.hf

            wrapped_tokenizer = model.mlx_tokenizer._tokenizer
            return llguidance.hf.from_tokenizer(wrapped_tokenizer)

        raise ValueError(  # pragma: no cover
            f"Unsupported model type: {type(model)}. "
            "Llguidance only supports LlamaCpp, MLXLM "
            "and Transformers models."
        )

    def _build_processor(
        self, grammar_kind: str, definition: str
    ) -> LLGuidanceLogitsProcessor:
        """Compile `definition` as a `grammar_kind` grammar and wrap it.

        Shared by the JSON schema, regex and CFG factory methods, which only
        differ in the grammar kind passed to `llguidance.grammar_from`.

        """
        grammar_spec = self.llg.grammar_from(grammar_kind, definition)
        return LLGuidanceLogitsProcessor(
            grammar_spec, self.llg_tokenizer, self.tensor_library_name
        )

    def get_json_schema_logits_processor(
        self, json_schema: str
    ) -> LLGuidanceLogitsProcessor:
        """Build a logits processor constrained by a JSON schema.

        Parameters
        ----------
        json_schema: str
            The JSON schema the logits processor is created from.

        Returns
        -------
        LLGuidanceLogitsProcessor
            The logits processor to use to constrain the generation.

        """
        return self._build_processor("json_schema", json_schema)

    def get_regex_logits_processor(
        self, regex: str
    ) -> LLGuidanceLogitsProcessor:
        """Build a logits processor constrained by a regex.

        Parameters
        ----------
        regex: str
            The regex the logits processor is created from.

        Returns
        -------
        LLGuidanceLogitsProcessor
            The logits processor to use to constrain the generation.

        """
        return self._build_processor("regex", regex)

    def get_cfg_logits_processor(
        self, grammar: str
    ) -> LLGuidanceLogitsProcessor:
        """Build a logits processor constrained by a context-free grammar.

        Parameters
        ----------
        grammar: str
            The context-free grammar the logits processor is created from;
            it is passed to `llguidance.grammar_from` as a "lark" grammar.

        Returns
        -------
        LLGuidanceLogitsProcessor
            The logits processor to use to constrain the generation.

        """
        return self._build_processor("lark", grammar)

    def get_fsm_logits_processor(self, fsm):
        """Reject FSM-based constraints, which this backend does not support.

        Parameters
        ----------
        fsm
            Ignored; present to match the `BaseBackend` interface.

        Raises
        ------
        NotImplementedError
            Always.

        """
        raise NotImplementedError(
            "LLGuidanceBackend does not support FSM logits processors. "
            "Use the outlines_core backend instead."
        )

0 commit comments

Comments
 (0)