From 7c8a3a73dc5b575c2db3b076ee71c533a55f9782 Mon Sep 17 00:00:00 2001 From: Benedikt Kantz Date: Tue, 1 Apr 2025 10:37:48 +0200 Subject: [PATCH 1/2] Add support for literals --- .../gbnf_grammar_from_pydantic_models.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py b/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py index 5551cb2..0dd84a9 100644 --- a/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py +++ b/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py @@ -17,6 +17,7 @@ get_args, get_origin, get_type_hints, + Literal ) from docstring_parser import parse @@ -65,6 +66,7 @@ class PydanticDataType(Enum): CUSTOM_CLASS = "custom-class" CUSTOM_DICT = "custom-dict" SET = "set" + LITERAL = "literal" def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str: @@ -99,6 +101,10 @@ def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str: elif get_origin(pydantic_type) is dict: key_type, value_type = get_args(pydantic_type) return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}" + elif get_origin(pydantic_type) is Literal: + literal_types = get_args(pydantic_type) + literal_rules = [map_pydantic_type_to_gbnf(lt) for lt in literal_types] + return f"literal-{'-or-'.join(literal_rules)}" else: return "unknown" @@ -540,6 +546,13 @@ def generate_gbnf_rule_for_type( gbnf_type, rules = generate_gbnf_integer_rules( max_digit=max_digits, min_digit=min_digits ) + elif gbnf_type.startswith("literal-"): + literal_types = get_args(field_type) + literal_types_str = [ + json.dumps(lt).replace('"', '\\"') for lt in literal_types + ] + literal_types_str=[f'"{lt}"' for lt in literal_types_str] + gbnf_type = "|".join(literal_types_str) else: gbnf_type, rules = gbnf_type, [] From 96553d70e857107af9525ead3aec7f4b388bfa67 Mon Sep 17 00:00:00 2001 From: Benedikt Kantz Date: Tue, 1 Apr 2025 10:56:46 +0200 Subject: [PATCH 2/2] simplify logic, fix int types --- .../gbnf_grammar_from_pydantic_models.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py b/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py index 0dd84a9..5c38226 100644 --- a/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py +++ b/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py @@ -17,7 +17,7 @@ get_args, get_origin, get_type_hints, - Literal + Literal, ) from docstring_parser import parse @@ -101,10 +101,6 @@ def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str: elif get_origin(pydantic_type) is dict: key_type, value_type = get_args(pydantic_type) return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}" - elif get_origin(pydantic_type) is Literal: - literal_types = get_args(pydantic_type) - literal_rules = [map_pydantic_type_to_gbnf(lt) for lt in literal_types] - return f"literal-{'-or-'.join(literal_rules)}" else: return "unknown" @@ -546,13 +542,11 @@ def generate_gbnf_rule_for_type( gbnf_type, rules = generate_gbnf_integer_rules( max_digit=max_digits, min_digit=min_digits ) - elif gbnf_type.startswith("literal-"): + elif get_origin(field_type) is Literal: literal_types = get_args(field_type) - literal_types_str = [ - json.dumps(lt).replace('"', '\\"') for lt in literal_types - ] - literal_types_str=[f'"{lt}"' for lt in literal_types_str] - gbnf_type = "|".join(literal_types_str) + literal_types_str = [json.dumps(lt).replace('"', '\\"') for lt in literal_types] + literal_types_str = [f'"{lt}"' for lt in literal_types_str] + gbnf_type = f"({'|'.join(literal_types_str)})" else: gbnf_type, rules = gbnf_type, [] @@ -776,9 +770,9 @@ def generate_gbnf_grammar_from_pydantic_models( model, processed_models, created_rules ) if add_request_heartbeat and model.__name__ in request_heartbeat_models: - model_rules[ - 0 - ] += rf' "," ws "\"{request_heartbeat_field_name}\"" ":" ws boolean ' + model_rules[0] += ( + rf' "," ws "\"{request_heartbeat_field_name}\"" ":" ws boolean ' + ) # if not has_special_string: # model_rules[0] += r' ws "}"'