diff --git a/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py b/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py index 5551cb2..5c38226 100644 --- a/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py +++ b/src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py @@ -17,6 +17,7 @@ get_args, get_origin, get_type_hints, + Literal, ) from docstring_parser import parse @@ -65,6 +66,7 @@ class PydanticDataType(Enum): CUSTOM_CLASS = "custom-class" CUSTOM_DICT = "custom-dict" SET = "set" + LITERAL = "literal" def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str: @@ -540,6 +542,11 @@ def generate_gbnf_rule_for_type( gbnf_type, rules = generate_gbnf_integer_rules( max_digit=max_digits, min_digit=min_digits ) + elif get_origin(field_type) is Literal: + literal_types = get_args(field_type) + literal_types_str = [json.dumps(lt).replace('"', '\\"') for lt in literal_types] + literal_types_str = [f'"{lt}"' for lt in literal_types_str] + gbnf_type = f"({'|'.join(literal_types_str)})" else: gbnf_type, rules = gbnf_type, [] @@ -763,9 +770,9 @@ def generate_gbnf_grammar_from_pydantic_models( model, processed_models, created_rules ) if add_request_heartbeat and model.__name__ in request_heartbeat_models: - model_rules[ - 0 - ] += rf' "," ws "\"{request_heartbeat_field_name}\"" ":" ws boolean ' + model_rules[0] += ( + rf' "," ws "\"{request_heartbeat_field_name}\"" ":" ws boolean ' + ) # if not has_special_string: # model_rules[0] += r' ws "}"'