diff --git a/.vscode/debug/config.toml b/.vscode/debug/config.toml new file mode 100644 index 0000000..4686406 --- /dev/null +++ b/.vscode/debug/config.toml @@ -0,0 +1,47 @@ +ai_model = "openai:gpt-5-nano" +temp_auto_clean = true +#temp_file_dir = "temp" +allow_successful = false +loading_hints = false +source_code_format = "full" +dev_mode = true + +[ai_custom."llama3.2:3b"] +server_type = "ollama" +url = "localhost:11434" +max_tokens = 100000 + +[ai_custom."llama3.3:latest"] +server_type = "ollama" +url = "localhost:11435" +max_tokens = 100000 + +[solution] +filenames = ["samples/threading.c"] + +[verifier] +name = "esbmc" + +[verifier.esbmc] +path = "~/.local/bin/esbmc" +params = [ + "--interval-analysis", + "--memory-leak-check", + "--goto-unwind", + "--unlimited-goto-unwind", + "--k-induction", + "--state-hashing", + "--add-symex-value-sets", + "--k-step", + "2", + "--floatbv", + "--unlimited-k-steps", + "--context-bound", + "2", +] +output_type = "full" +timeout = 60 + +[llm_requests] +max_tries = 5 +timeout = 60 diff --git a/.vscode/launch.json b/.vscode/launch.json index 29a9500..5c1455b 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,6 +4,21 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { + "name": "Run on Open File", + "type": "debugpy", + "request": "launch", + "module": "esbmc_ai", + "justMyCode": true, + "cwd": "${workspaceFolder}", + "env": { + "ESBMCAI_CONFIG_FILE": "${workspaceFolder}/.vscode/debug/config.toml" + }, + "args": [ + "-vvv", + "debug-view-config" + ] + }, { "name": "Fix Code on Open File", "type": "debugpy", diff --git a/config.toml b/config.toml index f30949e..1224858 100644 --- a/config.toml +++ b/config.toml @@ -1,4 +1,15 @@ -ai_model = "gpt-3.5-turbo" +# ESBMC-AI Default Config File. +# The ESBMC-AI config is loaded in the following order: +# 1. CLI +# 2. env +# 3. dotenv +# 4. TOML +# 5. file_secret +# This file is provided as an example; customize it to your needs, as it +# probably contains values that are wrong for your system. +# To view all the possible config options and their descriptions, run the +# subcommand "list-config". +ai_model = "openai:gpt-5-nano" temp_auto_clean = true #temp_file_dir = "temp" allow_successful = false @@ -31,28 +42,3 @@ timeout = 60 [llm_requests] max_tries = 5 timeout = 60 - -# FIX CODE - -[fix_code] -temperature = 0.7 -max_attempts = 5 -message_history = "normal" - -[fix_code.prompt_templates.base] -initial = "The ESBMC output is:\n\n```\n$esbmc_output\n```\n\nThe source code is:\n\n```c\n$source_code\n```\n Using the ESBMC output, show the fixed text." - -[[fix_code.prompt_templates.base.system]] -role = "System" -content = "From now on, act as an Automated Code Repair Tool that repairs AI C code. You will be shown AI C code, along with ESBMC output. Pay close attention to the ESBMC output, which contains a stack trace along with the type of error that occurred and its location that you need to fix. Provide the repaired C code as output, as would an Automated Code Repair Tool. Aside from the corrected source code, do not output any other text." - -[fix_code.prompt_templates."division by zero"] -initial = "The ESBMC output is:\n\n```\n$esbmc_output\n```\n\nThe source code is:\n\n```c\n$source_code\n```\n Using the ESBMC output, show the fixed text." 
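The load order listed in the config.toml header mirrors the source priority that pydantic-settings resolves for the main Config. A minimal sketch of such a precedence chain, assuming pydantic-settings v2 with `TomlConfigSettingsSource` available; the `ExampleConfig` class and its fields are illustrative, not the project's real model:

```python
from pydantic_settings import (
    BaseSettings,
    PydanticBaseSettingsSource,
    SettingsConfigDict,
    TomlConfigSettingsSource,
)


class ExampleConfig(BaseSettings):
    """Illustrative config model; field names mirror config.toml."""

    model_config = SettingsConfigDict(
        env_prefix="ESBMCAI_", env_file=".env", toml_file="config.toml"
    )

    ai_model: str = "openai:gpt-5-nano"
    dev_mode: bool = False

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        # Earlier sources win: CLI/init > env > .env > TOML > secrets.
        return (
            init_settings,
            env_settings,
            dotenv_settings,
            TomlConfigSettingsSource(settings_cls),
            file_secret_settings,
        )
```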
- -[[fix_code.prompt_templates."division by zero".system]] -role = "System" -content = "Here's a C program with a vulnerability:\n```c\n$source_code\n```\nA Formal Verification tool identified a division by zero issue:\n$esbmc_output\nTask: Modify the C code to safely handle scenarios where division by zero might occur. The solution should prevent undefined behavior or crashes due to division by zero. \nGuidelines: Focus on making essential changes only. Avoid adding or modifying comments, and ensure the changes are precise and minimal.\nGuidelines: Ensure the revised code avoids undefined behavior and handles division by zero cases effectively.\nGuidelines: Implement safeguards (like comparison) to prevent division by zero instead of using literal divisions like 1.0/0.0.Output: Provide the corrected, complete C code. The solution should compile and run error-free, addressing the division by zero vulnerability.\nStart the code snippet with ```c and end with ```. Reply OK if you understand." - -[[fix_code.prompt_templates."division by zero".system]] -role = "AI" -content = "OK." diff --git a/esbmc_ai/__init__.py b/esbmc_ai/__init__.py index 2737bb4..27e3c57 100644 --- a/esbmc_ai/__init__.py +++ b/esbmc_ai/__init__.py @@ -4,21 +4,19 @@ from esbmc_ai.__about__ import __version__, __author__ -from esbmc_ai.config_field import ConfigField from esbmc_ai.config import Config from esbmc_ai.chat_command import ChatCommand -from esbmc_ai.base_component import BaseComponent +from esbmc_ai.base_component import BaseComponent, BaseComponentConfig from esbmc_ai.verifiers import BaseSourceVerifier -from esbmc_ai.ai_models import AIModel, AIModels -from esbmc_ai.log_utils import LogCategories +from esbmc_ai.ai_models import AIModel +from esbmc_ai.log_categories import LogCategories __all__ = [ - "ConfigField", "Config", "BaseComponent", + "BaseComponentConfig", "ChatCommand", "BaseSourceVerifier", "AIModel", - "AIModels", "LogCategories", ] diff --git a/esbmc_ai/__main__.py b/esbmc_ai/__main__.py index 05fdfe4..c4d1879 100755 --- a/esbmc_ai/__main__.py +++ b/esbmc_ai/__main__.py @@ -2,92 +2,26 @@ # Author: Yiannis Charalambous 2023 +import logging import sys -import readline import argparse -from typing import Any +from pydantic_settings import CliApp, CliSettingsSource from structlog import get_logger from structlog.stdlib import BoundLogger from esbmc_ai import Config, ChatCommand, __author__, __version__ from esbmc_ai.addon_loader import AddonLoader from esbmc_ai.command_result import CommandResult -from esbmc_ai.log_utils import LogCategories +from esbmc_ai.log_utils import LogCategories, get_log_level, init_logging from esbmc_ai.verifiers.base_source_verifier import BaseSourceVerifier from esbmc_ai.verifiers.esbmc import ESBMC -from esbmc_ai.component_loader import ComponentLoader +from esbmc_ai.component_manager import ComponentManager import esbmc_ai.commands -# Enables arrow key functionality for input(). Do not remove import. -_ = readline - -HELP_MESSAGE: str = ( - "Automated Program Repair platform. To view help on subcommands, run with " - 'the subcommand "help".' 
-) - -_default_command_name: str = "help" - - -def _init_builtin_components() -> None: - """Initializes the builtin verifiers and commands.""" - # Built-in verifiers - esbmc = ESBMC.create() - assert isinstance(esbmc, BaseSourceVerifier) - ComponentLoader().add_verifier(esbmc) - # Init built-in commands - commands: list[ChatCommand] = [] - for cmd_classname in getattr(esbmc_ai.commands, "__all__"): - cmd_type: type = getattr(esbmc_ai.commands, cmd_classname) - assert issubclass(cmd_type, ChatCommand), f"{cmd_type} is not a ChatCommand" - cmd: object = cmd_type.create() - assert isinstance(cmd, ChatCommand) - cmd.config = Config() - commands.append(cmd) - ComponentLoader().set_builtin_commands(commands) - - -def _init_args( - parser: argparse.ArgumentParser, - map_field_names: dict[str, list[str]], - ignore_fields: dict[str, list[str]], -) -> None: - """Initializes the Config's ConfigFields to be accepted by the argument - parser, this allows all ConfigFields to be loaded as arguments. The Config - will then use the map_field_names to automatically pre-load values before - loading the rest of the config. - - Args: - - parser: The parser to add the arguments into. - - map_field_names: The parser will map the config fields to use - alternative names. - - ignore_fields: Dictionary of field names to not encode automatically. - This takes precedence over map_fields. The field that matches the key - in this dictionary will not be mapped. It is a dictionary because they - can optionally be manually initialized and mapped in the Config, so it - is worth keeping track of the aliases.""" - - parser.add_argument( - "command", - type=str, - nargs="?", - default=_default_command_name, - help=( - "The command to run using the program. To see addon commands " - "available: Run with 'help' as the default command." 
- ), - ) - - parser.add_argument( - *ignore_fields["solution.filenames"], - type=str, - # nargs=argparse.REMAINDER, - nargs=argparse.ZERO_OR_MORE, - help="The filename(s) to pass to the verifier.", - ) +def _init_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--version", action="version", @@ -107,142 +41,130 @@ def _init_args( ), ) - # Init arg groups - arg_groups: dict[str, argparse._ArgumentGroup] = {} - for f in Config().get_config_fields(): - name_split: list[str] = f.name.split(".") - if len(name_split) == 1: - continue # No group - arg_groups[name_split[0]] = parser.add_argument_group( - title=name_split[0], - ) - for f in Config().get_config_fields(): - if f.name in ignore_fields: - continue +def _load_config( + parser: argparse.ArgumentParser, +) -> Config: + # Parse args to get verbose level before Pydantic CLI parsing + args, _ = parser.parse_known_args() + verbose_level: int = min(args.verbose, 3) if hasattr(args, "verbose") else 0 + + # Create custom CLI settings source using our argparse parser + # This allows us to use custom arguments like -v/--verbose with action='count' + cli_settings: CliSettingsSource = CliSettingsSource( + Config, cli_parse_args=True, root_parser=parser + ) + + # Use CliApp.run with custom CLI settings source + # settings_customise_sources will handle TOML/env/dotenv loading process + config: Config = CliApp.run(Config, cli_settings_source=cli_settings) - # Get either the general parser or a group parser - arg_parser: argparse._ArgumentGroup | argparse.ArgumentParser = parser - name_split: list[str] = f.name.split(".") - if len(name_split) > 1: - arg_parser = arg_groups[name_split[0]] + # Set argparse config values - These fields have exclude=True + config.verbose_level = verbose_level - # Get the name that will be shown. - mappings: list[str] = ( - map_field_names[f.name] if f.name in map_field_names else [f.name] - ) - # Add single or double dash. Change _ into -, argparse will automatically - # convert into _ anyway. - mappings = [ - f"-{m}" if len(m) == 1 else f"--{m.replace("_", "-")}" for m in mappings - ] - - action: Any - match f.default_value: - case bool(): - action = "store_true" - case _: - action = "store" - - # Set type - try: - # Type is only accepted when the action is not some specific values. - kwargs = {} - if action not in ( - "store_true", - "store_false", - "append_const", - "count", - "help", - "version", - ): - # If None then it will only accept None values so basically useless - kwargs["type"] = ( - str if f.default_value is None else type(f.default_value) - ) - - # Create the argument. - arg_parser.add_argument( - *mappings, - action=action, - required=False, - # Will not show up in the Namespace - default=argparse.SUPPRESS, - help=f.help_message, - **kwargs, - ) - except TypeError as e: - get_logger().critical(f"Failed to encode config into args: {f.name}") - raise e + return config + + +def _init_builtin_components() -> None: + """Initializes the builtin verifiers and commands.""" + component_manager = ComponentManager() + + # Built-in verifiers + esbmc = ESBMC.create() + assert isinstance(esbmc, BaseSourceVerifier) + component_manager.add_verifier(esbmc) + # Load component-specific configuration + component_manager.load_component_config(esbmc, builtin=True) + + # Init built-in commands - Loads everything in the esbmc_ai.commands module. 
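The new `_load_config` pre-parses custom argparse-only flags (such as the counted `-v`) before handing the same parser to pydantic-settings for the model fields. A condensed sketch of that two-stage pattern; `DemoSettings` is a hypothetical stand-in for the real Config:

```python
import argparse

from pydantic_settings import BaseSettings, CliApp, CliSettingsSource


class DemoSettings(BaseSettings):
    """Hypothetical stand-in for the real Config model."""

    ai_model: str = "openai:gpt-5-nano"


parser = argparse.ArgumentParser(prog="demo")
# Stage 1: flags that pydantic-settings cannot express directly, like a
# counted verbosity flag, are pre-parsed with parse_known_args.
parser.add_argument("-v", "--verbose", action="count", default=0)
args, _ = parser.parse_known_args()

# Stage 2: the same parser is handed to pydantic-settings, which adds
# arguments for the model fields and performs the final parse.
cli_settings = CliSettingsSource(DemoSettings, cli_parse_args=True, root_parser=parser)
settings: DemoSettings = CliApp.run(DemoSettings, cli_settings_source=cli_settings)
```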
+ commands: list[ChatCommand] = [] + for cmd_classname in getattr(esbmc_ai.commands, "__all__"): + cmd_type: type = getattr(esbmc_ai.commands, cmd_classname) + assert issubclass(cmd_type, ChatCommand), f"{cmd_type} is not a ChatCommand" + cmd: object = cmd_type.create() + assert isinstance(cmd, ChatCommand) + cmd.global_config = Config() + # Load component-specific configuration + component_manager.load_component_config(cmd, builtin=True) + commands.append(cmd) + + component_manager.set_builtin_commands(commands) + + +def _init_logging() -> None: + # Add logging handlers with config options + config = Config() + logging_handlers: list[logging.Handler] = config.log.logging_handlers + + # Reinit logging + init_logging( + level=get_log_level(config.verbose_level), + file_handlers=logging_handlers, + init_basic=config.log.basic, + ) def main() -> None: """Entry point function""" parser: argparse.ArgumentParser = argparse.ArgumentParser( prog="ESBMC-AI", - description=HELP_MESSAGE, + description="""Automated Program Repair platform. To view help on subcommands, run with the subcommand "help". + +Configuration Precedence (highest to lowest): + * CLI args > Environment variables > .env file > TOML config > Defaults + * NOTE: Setting nested values through environment variables and files is currently not supported (https://github.com/esbmc/esbmc-ai/issues/229)""", epilog=f"Made by {__author__}", + formatter_class=argparse.RawDescriptionHelpFormatter, ) - # Will rename these config fields - arg_mappings: dict[str, list[str]] = { - "solution.entry_function": ["entry-function"], - "ai_model": ["m", "ai-model"], - "solution.output_dir": ["o", "output-dir"], - "log.output": ["log-output"], - "log.by_cat": ["log-by-cat"], - "log.by_name": ["log-by-name"], - } - - # Will not expose these to the arguments. - manual_mappings: dict[str, list[str]] = { - "solution.filenames": ["filenames"], - "ai_custom": ["ai_custom"], # Block - } - - _init_args( - parser=parser, - map_field_names=arg_mappings, - ignore_fields=manual_mappings, - ) + _init_args(parser=parser) - args: argparse.Namespace = parser.parse_args() + config: Config + config = _load_config(parser=parser) + # Set the config singleton + Config.set_singleton(config) + config: Config = Config() print(f"ESBMC-AI {__version__}") print(f"Made by {__author__}") print() - Config().load( - args=args, - arg_mapping_overrides=arg_mappings | manual_mappings, - compound_load_args=[v for values in manual_mappings.values() for v in values], - ) - logger: BoundLogger = get_logger().bind(category=LogCategories.SYSTEM) + _init_logging() - logger.info(f"Config File: {Config().get_value("ESBMCAI_CONFIG_FILE")}") + logger: BoundLogger = get_logger().bind(category=LogCategories.SYSTEM) + logger.debug("Global config loaded successfully") + logger.debug("Initialized logging") _init_builtin_components() + logger.debug("Builtin components loaded successfully") - if Config().get_value("dev_mode"): + if config.dev_mode: logger.warn("Development Mode Activated") # Load addons - AddonLoader(Config()) + addon_loader: AddonLoader = AddonLoader(config) + logger.debug("Addon components loaded successfully") + logger.info("Configuration loaded successfully") # Bind addons to component loader. 
- ComponentLoader().addon_commands.update(AddonLoader().chat_command_addons) - ComponentLoader().verifiers.update(AddonLoader().verifier_addons) - ComponentLoader().set_verifier_by_name(Config().get_value("verifier.name")) + cm = ComponentManager() + for command in addon_loader.chat_command_addons.values(): + cm.add_command(command, builtin=False) + + for verifier in addon_loader.verifier_addons.values(): + cm.add_verifier(verifier, builtin=False) + + cm.set_verifier_by_name(config.verifier.name) # Run the command - command_name = args.command - command_names: list[str] = ComponentLoader().command_names + command_name = config.command_name + command_names: list[str] = cm.command_names if command_name in command_names: logger.info(f"Running Command: {command_name}\n") - command: ChatCommand = ComponentLoader().commands[command_name] - result: CommandResult | None = command.execute(kwargs=vars(args)) + command: ChatCommand = cm.commands[command_name] + result: CommandResult | None = command.execute(kwargs=vars(config)) if result: - if Config().get_value("json"): + if config.use_json: print(vars(result)) else: print(result) diff --git a/esbmc_ai/addon_loader.py b/esbmc_ai/addon_loader.py index 6b84c67..f606395 100644 --- a/esbmc_ai/addon_loader.py +++ b/esbmc_ai/addon_loader.py @@ -2,34 +2,20 @@ """This module contains code regarding configuring and loading addon modules.""" -import inspect -from typing import Any import traceback import sys import importlib -from importlib.util import find_spec -from importlib.machinery import ModuleSpec import structlog -from typing_extensions import Optional from esbmc_ai.base_component import BaseComponent from esbmc_ai.chat_command import ChatCommand -from esbmc_ai.component_loader import ComponentLoader from esbmc_ai.verifiers.base_source_verifier import BaseSourceVerifier -from esbmc_ai.config_field import ConfigField from esbmc_ai.config import Config from esbmc_ai.singleton import SingletonMeta class AddonLoader(metaclass=SingletonMeta): - """The addon loader manages loading addon modules. This includes: - * Managing the config fields of the addons. - * Dynamically loading the fields when the addons request them. - - When an addon requests a config value from an addon that is not loaded, that - addon's config fields get loaded. This means that addons will have dependency - management (as long as there's no loops). - """ + """The addon loader manages loading addon modules and initializing them.""" addon_prefix: str = "addons" @@ -41,34 +27,20 @@ def __init__(self, config: Config | None = None) -> None: assert config self._config: Config = config - self._config.on_load_value.append(self._on_get_config_value) # Keeps track of the addons that have been loaded. self._loaded_addons: dict[str, BaseComponent] = {} - # Keeps track of the addons that had their config fields initialized. - self._initialized_addons: set[BaseComponent] = set() # Dev mode: Ensure the current directory is in sys.path in order for # relative addon modules to be imported (used for dev purposes). - if self._config.get_value("dev_mode") and "" not in sys.path: + if self._config.dev_mode and "" not in sys.path: sys.path.insert(0, "") - # Register field with Config to know which modules to load. - config.load_config_field( - ConfigField( - name="addon_modules", - default_value=[], - validate=self._validate_addon_modules, - error_message="couldn't find module: must be a list of Python modules to load", - help_message="The addon modules to load during startup. 
Additional " - "modules may be loaded by the specified modules as dependencies.", - ), - ) - # Load the config fields. - if self._config.get_value("addon_modules"): + if self._config.addon_modules: print("Loading Addons:") - for m in self._config.get_value("addon_modules"): + + for m in self._config.addon_modules: addons: list[BaseComponent] = self.load_addons_module(m) for addon in addons: print(f"\t* {addon.name} by {addon.authors}") @@ -113,37 +85,6 @@ def verifier_addon_names(self) -> list[str]: def loaded_addons(self) -> dict[str, BaseComponent]: return self._loaded_addons - @staticmethod - def _validate_addon_modules(mods: str) -> bool: - """Validates that a module exists.""" - for m in mods: - if not isinstance(m, str): - return False - spec: Optional[ModuleSpec] = find_spec(m) - if spec is None: - return False - return True - - def _on_get_config_value(self, name: str) -> None: - """Checks if an addon's values were loaded prior to it requesting an - a value. If they weren't load them. This allows for addon ConfigFields - to be loaded dynamically. - - Will only check for ConfigFields that have a name that is prefixed with - "addon.".""" - # Check if it references an addon and is not loaded. - split_key: list[str] = name.split(".") - if split_key[0] == AddonLoader.addon_prefix and len(split_key) > 1: - # Check if the addon is valid. - addon: BaseComponent | None = self.get_addon_by_name(split_key[1]) - if addon in self._initialized_addons: - return - elif addon: - self._load_addon_config(addon) - self._initialized_addons.add(addon) - else: - raise KeyError(f"AddonLoader: failed to load config of addon {name}") - def get_addon_by_name(self, name: str) -> BaseComponent | None: """Returns an addon by the addon name defined in __init__.""" for addon in self.loaded_addons.values(): @@ -151,44 +92,6 @@ def get_addon_by_name(self, name: str) -> BaseComponent | None: return addon return None - def _get_addon_resolved_fields(self, addon: BaseComponent) -> list[ConfigField]: - """Returns the addons's config fields. After resolving each field to - their namespace using the name of the component.""" - # If an addon prefix is defined, then add a . - addons_prefix: str = ( - AddonLoader.addon_prefix + "." if AddonLoader.addon_prefix else "" - ) - - fields_resolved: list[ConfigField] = [] - # Loop through each field of the verifier addon - fields: list[ConfigField] = addon.get_config_fields() - for field in fields: - new_field: ConfigField = self._resolve_config_field( - field, f"{addons_prefix}{addon.name}" - ) - fields_resolved.append(new_field) - return fields_resolved - - def _resolve_config_field(self, field: ConfigField, prefix: str): - """Resolve the name of each field by prefixing it with the component name. - Returns a new config field with the name resolved to the prefix - supplied. Using inspection all the other fields are copied. The returning - field is exactly the same as the original, aside from the resolved name.""" - - # Inspect the signature of the ConfigField which is a named tuple. - signature = inspect.signature(ConfigField) - params: dict[str, Any] = {} - # Iterate and capture all parameters - for param_name, param in signature.parameters.items(): - _ = param - match param_name: - case "name": - params[param_name] = f"{prefix}.{getattr(field, param_name)}" - case _: - params[param_name] = getattr(field, param_name) - - return ConfigField(**params) - def load_addons_module(self, module_name: str) -> list[BaseComponent]: """Loads an addon, needs to expose a BaseComponent. 
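`load_addons_module` expects each addon module to expose its components at the top level. The general shape of that discovery step looks roughly like the sketch below; the `discover_components` helper is hypothetical, and the real loader applies additional validation:

```python
import importlib
import inspect
from types import ModuleType


def discover_components(module_name: str, base: type) -> list[object]:
    """Import a module and instantiate the subclasses of `base` it exposes."""
    module: ModuleType = importlib.import_module(module_name)
    found: list[object] = []
    for _, obj in inspect.getmembers(module, inspect.isclass):
        if issubclass(obj, base) and obj is not base:
            # BaseComponent subclasses provide the create() factory method.
            found.append(obj.create())
    return found
```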
@@ -223,20 +126,11 @@ def init_base_component(self, t: type[BaseComponent]) -> BaseComponent: self._logger.debug(f"Loading addon: {addon.__class__.__name__}") # Register config with modules - addon.config = self._config + addon.global_config = self._config - return addon + # Load component-specific configuration using ComponentManager + from esbmc_ai.component_manager import ComponentManager - def _load_addon_config(self, addon: BaseComponent) -> None: - """Loads the config fields defined by an addon""" - - # Load config fields - added_field_names: set[ConfigField] = set() - for f in self._get_addon_resolved_fields(addon): - # Add config fields - if f.name in added_field_names: - raise KeyError(f"AddonLoader: field already loaded: {f.name}") - try: - self._config.load_config_field(f) - except Exception: - self._logger.error(f"failed to register config field: {f.name}") + ComponentManager().load_component_config(addon, builtin=False) + + return addon diff --git a/esbmc_ai/ai_models.py b/esbmc_ai/ai_models.py index 15c4e95..e1d88c0 100644 --- a/esbmc_ai/ai_models.py +++ b/esbmc_ai/ai_models.py @@ -1,514 +1,150 @@ # Author: Yiannis Charalambous -from abc import ABC, abstractmethod -from dataclasses import dataclass, field, replace -from datetime import time, timedelta, datetime -from pathlib import Path -from typing import Any, Iterable -from platformdirs import user_cache_dir -from pydantic.types import SecretStr +from typing import Any +from uuid import UUID +from langchain.schema import BaseMessage, LLMResult from typing_extensions import override -import tiktoken import structlog -from anthropic import Anthropic as AnthropicClient -from openai import Client as OpenAIClient +from langchain.chat_models import init_chat_model +from langchain_core.language_models import BaseChatModel +from langchain_core.rate_limiters import InMemoryRateLimiter +from langchain_core.callbacks import CallbackManager, BaseCallbackHandler -from langchain.prompts.chat import ChatPromptValue -from langchain_core.language_models import BaseChatModel, LanguageModelInput -from langchain_core.messages import get_buffer_string -from langchain_openai import ChatOpenAI -from langchain_ollama import ChatOllama -from langchain_anthropic import ChatAnthropic -from langchain.schema import ( - BaseMessage, - PromptValue, -) +from esbmc_ai.config import Config from esbmc_ai.log_utils import LogCategories -from esbmc_ai.singleton import SingletonMeta -@dataclass(frozen=True, kw_only=True) -class AIModel(ABC): - """This base class represents an abstract AI model. Each AIModel has the - required properties to invoke the underlying langchain implementation - BaseChatModel. To configure the properties, call the bind method and set - them.""" +class LoggingCallbackHandler(BaseCallbackHandler): + """Callback handler that logs LLM invocation events as debug messages.""" - name: str - tokens: int - temperature: float = 1.0 - requests_max_tries: int = 5 - requests_timeout: float = 60 - _llm: BaseChatModel | None = field(default=None) - - def __post_init__(self): - object.__setattr__(self, "_llm", self.create_llm()) - - @abstractmethod - def create_llm(self) -> BaseChatModel: - """Initializes a langchain BaseChatModel with the provided parameters. - Used internally by low-level functions.
Bind should be used for - AIModels.""" - raise NotImplementedError() - - def bind(self, **kwargs: Any) -> "AIModel": - """Returns a new model with new parameters.""" - new_ai_model: AIModel = replace(self, **kwargs) - llm: BaseChatModel = new_ai_model.create_llm() - return replace(new_ai_model, _llm=llm) - - def invoke(self, input: LanguageModelInput, **kwargs: Any) -> BaseMessage: - """Invokes the underlying BaseChatModel implementation and returns the - message.""" - if not self._llm: - raise ValueError("LLM is not initialized, call bind.") - return self._llm.invoke(input, **kwargs) - - def get_num_tokens(self, content: str) -> int: - """Gets the number of tokens for this AI model.""" - if not self._llm: - raise ValueError("LLM is not initialized, call bind.") - return self._llm.get_num_tokens(content) - - def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int: - """Gets the number of tokens for this AI model for a list of messages.""" - if not self._llm: - raise ValueError("LLM is not initialized, call bind.") - return self._llm.get_num_tokens_from_messages(messages) - - @classmethod - def convert_messages_to_tuples( - cls, messages: Iterable[BaseMessage] - ) -> list[tuple[str, str]]: - """Converts messages into a format understood by the ChatPromptTemplate, - since it won't format BaseMessage derived classes for some reason, but - will for tuples, because they get converted into Templates in function - `_convert_to_message`.""" - return [(message.type, str(message.content)) for message in messages] - - @classmethod - def safe_substitute(cls, content: str, **values: Any) -> str: - """Safe template substitution. Replaces $var with provided values, - leaves undefined $vars unchanged.""" - from string import Template - - template = Template(content) - return template.safe_substitute(**values) - - def apply_chat_template( - self, - messages: Iterable[BaseMessage], - **format_values: Any, - ) -> PromptValue: - """Applies the formatted values onto the message chat template. For example, - if the message contains the token $source, then format_values contains a - value for source then it will be substituted.""" - result_messages = [] - for msg in messages: - content = AIModel.safe_substitute(str(msg.content), **format_values) - new_msg = msg.model_copy() - new_msg.content = content - result_messages.append(new_msg) - return ChatPromptValue(messages=result_messages) - - def apply_str_template( - self, - text: str, - **format_values: Any, - ) -> str: - """Applies the formatted values onto a string template. For example, - if the message contains the token $source, then format_values contains a - value for source then it will be substituted.""" - return AIModel.safe_substitute(text, **format_values) - - -@dataclass(frozen=True, kw_only=True) -class AIModelService(AIModel): - """Represents an AI model from a service.""" - - api_key: str = "" - - @staticmethod - def _get_max_tokens(name: str, token_groups: dict[str, int]) -> int: - """Dynamically resolves the max tokens from a base model.""" - - # Split into - segments and remove each section from the end to find out - # which one matches the most. 
- - # Base Case - if name in token_groups: - return token_groups[name] - - # Step Case - name_split: list[str] = name.split("-") - for i in range(1, name.count("-")): - subname: str = "-".join(name_split[:-i]) - if subname in token_groups: - return token_groups[subname] - - raise ValueError(f"Could not figure out max tokens for model: {name}") - - @classmethod - @abstractmethod - def get_models_list(cls, api_key: str) -> list[str]: - """Get available models from the API service.""" - raise NotImplementedError() - - @classmethod - @abstractmethod - def create_model(cls, name: str) -> "AIModel": - """Create an AI model instance from model name.""" - raise NotImplementedError() - - @classmethod - @abstractmethod - def get_cache_filename(cls) -> str: - """Get the cache filename for this service.""" - raise NotImplementedError() - - @classmethod - @abstractmethod - def get_canonical_name(cls) -> str: - """Get the canonical name of this service for dictionary access (lowercase).""" - raise NotImplementedError() - - -@dataclass(frozen=True, kw_only=True) -class AIModelOpenAI(AIModelService): - """OpenAI model.""" - - requests_max_tries: int = 5 - requests_timeout: float = 60 - - @property - def _reason_model(self) -> bool: - if "o3-mini" in self.name: - return True - return False - - @override - def create_llm(self) -> BaseChatModel: - kwargs = {} - if self.api_key: - kwargs["api_key"] = (SecretStr(self.api_key) or None,) - - return ChatOpenAI( - model=self.name, - temperature=None if self._reason_model else self.temperature, - reasoning_effort="high" if self._reason_model else None, - max_retries=self.requests_max_tries, - timeout=self.requests_timeout, - model_kwargs={}, - **kwargs, - ) - - @override - def get_num_tokens(self, content: str) -> int: - encoding: tiktoken.Encoding = tiktoken.encoding_for_model(self.name) - return len(encoding.encode(content)) - - @override - def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int: - encoding: tiktoken.Encoding - try: - encoding: tiktoken.Encoding = tiktoken.encoding_for_model(self.name) - except KeyError: - logger: structlog.stdlib.BoundLogger = structlog.get_logger().bind( - category=LogCategories.SYSTEM, - ) - logger.error( - f"TikToken error: failed to map model {self.name} to an " - "encoder. This could possibly be because the model is new and " - "has not been added yet. Make sure you update TikToken." 
- ) - raise - return sum(len(encoding.encode(get_buffer_string([m]))) for m in messages) - - @classmethod - def get_max_tokens(cls, name: str) -> int: - """Dynamically resolves the max tokens from a base model.""" - # https://platform.openai.com/docs/models - tokens: dict[str, int] = { - "gpt-3.5": 16385, - "gpt-4": 8192, - "gpt-4-turbo": 128000, - "gpt-4.1": 1047576, - "gpt-4.5": 128000, - "gpt-4o": 128000, - "o1": 200000, - "o3": 200000, - "o4-mini": 200000, - } - return cls._get_max_tokens(name, tokens) - - @classmethod - def get_models_list(cls, api_key: str) -> list[str]: - """Get available models from the OpenAI API service.""" - if not api_key: - return [] - try: - return [ - str(model.id) - for model in OpenAIClient(api_key=api_key).models.list().data - ] - except ImportError: - return [] - - @classmethod - def create_model(cls, name: str) -> "AIModel": - """Create an OpenAI AI model instance from model name.""" - return cls( - name=name.strip(), - tokens=cls.get_max_tokens(name), + def __init__(self, ai_model: str) -> None: + super().__init__() + self.logger: structlog.stdlib.BoundLogger = structlog.get_logger().bind( + category=LogCategories.CHAT, + prefix_name=ai_model, ) - @classmethod - def get_cache_filename(cls) -> str: - """Get the cache filename for OpenAI service.""" - return "openai_models.txt" - - @classmethod - def get_canonical_name(cls) -> str: - """Get the canonical name of the OpenAI service.""" - return "openai" - - -@dataclass(frozen=True, kw_only=True) -class OllamaAIModel(AIModel): - """A model that is running on the Ollama service.""" - - url: str - @override - def create_llm(self) -> BaseChatModel: - return ChatOllama( - base_url=self.url, - model=self.name, - temperature=self.temperature, - client_kwargs={ - "timeout": self.requests_timeout, - }, - ) - + def on_llm_start( + self, + serialized: dict[str, Any], + prompts: list[str], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + **kwargs: Any, + ) -> Any: + """Run when LLM starts running. + + .. ATTENTION:: + This method is called for non-chat models (regular LLMs). If you're + implementing a handler for a chat model, you should use + ``on_chat_model_start`` instead. -class AIModelAnthropic(AIModelService): - """An AI model that uses the Anthropic service.""" + Args: + serialized (dict[str, Any]): The serialized LLM. + prompts (list[str]): The prompts. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[list[str]]): The tags. + metadata (Optional[dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. 
+ """ + _ = run_id, parent_run_id, tags, metadata, kwargs + self.logger.debug("Invoke LLM") @override - def create_llm(self) -> BaseChatModel: - kwargs = {} - if self.api_key: - kwargs["api_key"] = SecretStr(self.api_key) or None - return ChatAnthropic( # pyright: ignore [reportCallIssue] - model_name=self.name, - temperature=self.temperature, - timeout=self.requests_timeout, - max_retries=self.requests_max_tries, - **kwargs, - ) - - @classmethod - def get_max_tokens(cls, name: str) -> int: - # docs.anthropic.com/en/docs/about-claude/models/overview#model-names - tokens = { - "claude-3": 200000, - "claude-sonnet-4": 200000, - "claude-opus-4": 200000, - } - - return cls._get_max_tokens(name, tokens) + def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> Any: + _ = run_id, parent_run_id, kwargs + self.logger.debug("LLM End") @override - def get_num_tokens(self, content: str) -> int: - # Delete trailing whitespace from last message because of (this might - # change in the future because of their API): - # anthropic.BadRequestError: Error code: 400 - # final assistant content cannot end with trailing whitespace. - return super().get_num_tokens(content.strip()) + def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> Any: + _ = run_id, parent_run_id, kwargs + self.logger.debug("LLM Error") @override - def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int: - # Delete trailing whitespace from last message because of (this might - # change in the future because of their API): - # anthropic.BadRequestError: Error code: 400 - # final assistant content cannot end with trailing whitespace. - if messages: - messages[-1].content = str(messages[-1].content).strip() - return super().get_num_tokens_from_messages(messages) - - @classmethod - def get_models_list(cls, api_key: str) -> list[str]: - """Get available models from the Anthropic API service.""" - if not api_key: - return [] - client = AnthropicClient(api_key=api_key) - return [str(i.id) for i in client.models.list()] + [ - # Include also latest models ID not returned by the API but can be - # used to use the latest version of a model. 
- # docs.anthropic.com/en/docs/about-claude/models/overview#model-aliases - "claude-opus-4-0", - "claude-sonnet-4-0", - "claude-3-7-sonnet-latest", - "claude-3-5-sonnet-latest", - "claude-3-5-haiku-latest", - "claude-3-opus-latest", - ] - - @classmethod - def create_model(cls, name: str) -> "AIModel": - """Create an Anthropic AI model instance from model name.""" - return cls( - name=name.strip(), - tokens=cls.get_max_tokens(name), - ) - - @classmethod - def get_cache_filename(cls) -> str: - """Get the cache filename for Anthropic service.""" - return "anthropic_models.txt" - - @classmethod - def get_canonical_name(cls) -> str: - """Get the canonical name of the Anthropic service.""" - return "anthropic" - - -class AIModels(metaclass=SingletonMeta): - """Manages the loading of AI Models from different sources.""" - - def __init__(self) -> None: - super().__init__() - - self._logger: structlog.stdlib.BoundLogger = structlog.get_logger().bind( - category=LogCategories.SYSTEM, - ) - self._api_keys: dict[str, str] = {} - self._ai_models: dict[str, AIModel] = {} - self._cache_dir.mkdir(parents=True, exist_ok=True) - - def load_default_models( + def on_chat_model_start( self, - api_keys: dict[str, str], - refresh_duration_seconds: int = 86400, - ) -> None: - """Loads the default AI models from OpenAI and Anthropic services. + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + **kwargs: Any, + ) -> Any: + """Run when a chat model starts running. + + **ATTENTION**: This method is called for chat models. If you're implementing + a handler for a non-chat model, you should use ``on_llm_start`` instead. Args: - api_keys - Dictionary with the canonical names of the models as keys - for API access. If a model does not come with a valid API key, - then it will not load the models. - refresh_duration_seconds - If the refresh duration has passed since - the last update, then the models will be loaded from the API - rather from cache.""" - - self._api_keys = api_keys - - # Load models from each service - services = [AIModelOpenAI, AIModelAnthropic] - for service in services: - api_key = api_keys.get(service.get_canonical_name(), "") - self._load_service_ai_models_list( - service=service, - api_key=api_key, - refresh_duration_seconds=refresh_duration_seconds, - ) - - @property - def model_names(self) -> list[str]: - return list(self._ai_models.keys()) - - @property - def _cache_dir(self) -> Path: - cache: Path = Path(user_cache_dir("esbmc-ai", "Yiannis Charalambous")) - return cache - - def is_valid_ai_model(self, ai_model: str | AIModel) -> bool: - """Returns true if the model exists.""" - - # Get the name of the model - name: str = ai_model.name if isinstance(ai_model, AIModel) else ai_model - - # Use the predefined list of models. - return name in self._ai_models - - @property - def ai_models(self) -> dict[str, AIModel]: - """Gets all loaded AI models""" - return self._ai_models - - def get_ai_model(self, name: str) -> AIModel: - """Checks for built-in and custom_ai models""" - if name in self._ai_models: - return self._ai_models[name] - - raise KeyError(f'The AI "{name}" was not found...') - - def add_ai_model(self, ai_model: AIModel, replace: bool = False) -> None: - """Registers a custom AI model.""" - # Check if AI already already exists. 
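`AIModel.get_model` now delegates provider dispatch to LangChain's `init_chat_model` instead of per-provider subclasses. A rough usage sketch; the model/provider strings and limits are illustrative, and running it requires the provider's API key:

```python
from langchain.chat_models import init_chat_model
from langchain_core.rate_limiters import InMemoryRateLimiter

from esbmc_ai.ai_models import LoggingCallbackHandler

llm = init_chat_model(
    model="gpt-5-nano",
    model_provider="openai",
    temperature=0.0,
    timeout=60,
    max_retries=5,
    rate_limiter=InMemoryRateLimiter(requests_per_second=10),
)

# Runnable.with_config is one way to attach the logging handler defined
# above without mutating the chat model object itself.
observed = llm.with_config(callbacks=[LoggingCallbackHandler(ai_model="openai:gpt-5-nano")])
print(observed.invoke("Reply with OK.").content)
```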
- if ai_model.name in self._ai_models and not replace: - raise KeyError(f'AI Model "{ai_model.name}" already exists...') - - self._ai_models[ai_model.name] = ai_model - - def _load_service_ai_models_list( - self, - service: type[AIModelService], - api_key: str, - refresh_duration_seconds: int, - ) -> None: - """Loads the service model names from cache or refreshes them from the internet.""" - - duration: timedelta = timedelta(seconds=refresh_duration_seconds) - service_name: str = service.get_canonical_name().title() - cache_name: str = service.get_cache_filename() - self._logger.info(f"Loading {service_name} models list") - models_list: list[str] = [] - - # Read the last updated date to determine if a new update is required - try: - last_update, models_list = self._load_cache(cache_name) - # Write new & updated cache file - if datetime.now() >= last_update + duration: - self._logger.info("\tModels list outdated, refreshing...") - models_list = service.get_models_list(api_key) - self._write_cache(cache_name, models_list) - except ValueError as e: - self._logger.error(f"Loading {service_name} models list failed:", e) - self._logger.info("\tCreating new models list.") - models_list = service.get_models_list(api_key) - self._write_cache(cache_name, models_list) - except FileNotFoundError: - self._logger.info("\tModels list not found, creating new...") - models_list = service.get_models_list(api_key) - self._write_cache(cache_name, models_list) - - # Add models that have been loaded. - for model_name in models_list: - try: - self.add_ai_model(service.create_model(model_name), replace=True) - except ValueError as e: - # Ignore models that don't count, like image only models. - self._logger.debug(f"Could not add model: {e}") - pass + serialized (dict[str, Any]): The serialized chat model. + messages (list[list[BaseMessage]]): The messages. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[list[str]]): The tags. + metadata (Optional[dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. + """ + _ = run_id, parent_run_id, tags, metadata, kwargs + self.logger.debug("Invoke Chat Model LLM") + + +class AIModel: + """Loading utils for models.""" + + @classmethod + def get_model( + cls, + *, + model: str, + provider: str | None = None, + temperature: float | None = None, + url: str | None = None, + ) -> BaseChatModel: + chat_model: BaseChatModel = init_chat_model( + model=model, + model_provider=provider, + temperature=temperature, + max_tokens=None, # Use all remaining tokens + base_url=url, + timeout=Config().llm_requests_timeout, + max_retries=Config().llm_requests_max_retries, + rate_limiter=InMemoryRateLimiter( + requests_per_second=10, + check_every_n_seconds=0.1, + max_bucket_size=100, + ), + ) - def _write_cache(self, name: str, models_list: list[str]): - with open(self._cache_dir / name, "w") as file: - file.seek(0) - file.write(datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n") - file.writelines(model_name + "\n" for model_name in models_list) - return models_list + handler: BaseCallbackHandler = LoggingCallbackHandler( + ai_model=f"{provider}:{model}" + ) + # Add the logging handler. 
+ if chat_model.callback_manager: + chat_model.callback_manager.add_handler(handler) + else: + chat_model.callback_manager = CallbackManager([handler]) - def _load_cache(self, path: str) -> tuple[datetime, list[str]]: - cache: Path = Path(user_cache_dir("esbmc-ai", "Yiannis Charalambous")) - cache.mkdir(parents=True, exist_ok=True) - with open(self._cache_dir / path, "r") as file: - last_update: datetime = datetime.strptime( - file.readline().strip(), "%Y-%m-%d %H:%M:%S" - ) - model_names: list[str] = [] - for line in file.readlines(): - model_names.append(line.strip()) - return last_update, model_names + return chat_model diff --git a/esbmc_ai/base_component.py b/esbmc_ai/base_component.py index 20fa951..710b516 100644 --- a/esbmc_ai/base_component.py +++ b/esbmc_ai/base_component.py @@ -5,25 +5,157 @@ import inspect from abc import ABC +import os +from pathlib import Path import re -from typing import Any +from typing import Any, cast, override +from pydantic import FilePath +from pydantic.fields import FieldInfo +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, +) +from pydantic_settings.sources import InitSettingsSource import structlog - -from esbmc_ai.base_config import BaseConfig -from esbmc_ai.config_field import ConfigField +import tomllib + +from esbmc_ai.config import Config + + +class DictConfigSettingsSource(PydanticBaseSettingsSource): + """Custom settings source that loads from a pre-loaded dictionary.""" + + def __init__( + self, settings_cls: type[BaseSettings], config_dict: dict[str, Any] + ) -> None: + super().__init__(settings_cls) + self.config_dict = config_dict + + @override + def get_field_value( + self, field: FieldInfo, field_name: str + ) -> tuple[Any, str, bool]: + """Get field value from the config dictionary.""" + _ = field + if field_name in self.config_dict: + return self.config_dict[field_name], field_name, False + return None, field_name, False + + def __call__(self) -> dict[str, Any]: + """Return the config dictionary.""" + return self.config_dict + + +class BaseComponentConfig(BaseSettings): + """Pydantic BaseSettings preconfigured to be able to load config values. + + Component configs are loaded from the TOML file under the 'addons.' + section by ComponentManager. + + The component name is passed via _component_name kwarg to __init__ and is + extracted in settings_customise_sources to determine the TOML table header. + """ + + # Used to allow loading from cli and env. 
+    model_config = SettingsConfigDict( + env_prefix="ESBMCAI_", + env_file=".env", + env_file_encoding="utf-8", + # Do not parse CLI args in component configs - only the main Config should + cli_parse_args=False, + # Ignore extra fields from .env file that don't match the component schema + extra="ignore", + ) + + def __init__(self, **values: Any) -> None: + """Used to provide static analyzers with type annotations so that we don't get + errors in ComponentManager.""" + super().__init__(**values) + + @override + @classmethod + def settings_customise_sources( + cls, + settings_cls: type["BaseSettings"], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + # Get config file path from global Config + # Note: .env is already loaded by the global Config before component + # configs are loaded + config_file_path: FilePath | None = Config().config_file + + sources: list[PydanticBaseSettingsSource] = [ + init_settings, + env_settings, + dotenv_settings, + ] + + # Add TOML config source if config file is specified + if config_file_path: + config_file: Path = Path(config_file_path).expanduser() + if config_file.exists(): + # Cast to InitSettingsSource to access init_kwargs + init_source: InitSettingsSource = cast( + InitSettingsSource, init_settings + ) + + # Extract component name and builtin flag from init_settings if + # provided + component_name: str = cast( + str, init_source.init_kwargs.get("_component_name") + ) + builtin: bool = cast(bool, init_source.init_kwargs.get("_builtin")) + + # If a component name is given, then the config will be + # loaded from the file. If None, then nothing is loaded from the + # config file, and only the other sources are used. + if component_name: + # Load TOML file and extract the component's section + with open(config_file, "rb") as f: + config_data: dict = tomllib.load(f) + + # Get the component-specific config from either the top-level + # <component_name> table (builtin) or the addons.<component_name> table + component_config: dict = config_data + if builtin: + component_config = config_data.get(component_name, {}) + else: + component_config = config_data.get("addons", {}).get( + component_name, {} + ) + + # Add custom dict source with component config + if component_config: + sources.append( + DictConfigSettingsSource(settings_cls, component_config) + ) + + # Priority order: init > env > dotenv > TOML > file_secret + sources.append(file_secret_settings) + return tuple(sources) class BaseComponent(ABC): """The base component class that is inherited by chat commands and verifiers - and allows them to be loaded by the AddonLoader.""" + and allows them to be loaded by the AddonLoader. + + The component is mapped to the config via its name prefix. So if + the name of a BaseComponent is "MyComponent" and it has a Field "my_field", + then its value is loaded from the config key "addons.MyComponent.my_field". + """ @classmethod def create(cls) -> "BaseComponent": - """Factory method to instantiate a default version of this class.""" + """Factory method to instantiate a default version of this class. Used
Used + by AddonLoader.""" # Check if __init__ takes only self (no required args) - sig = inspect.signature(cls.__init__) - params = list(sig.parameters.values()) + sig: inspect.Signature = inspect.signature(cls.__init__) + params: list[inspect.Parameter] = list(sig.parameters.values()) # params[0] is always 'self' if len(params) > 1 and any(p.default is p.empty for p in params[1:]): raise TypeError( @@ -36,7 +168,7 @@ def create(cls) -> "BaseComponent": def __init__(self) -> None: super().__init__() - self._config: BaseConfig + self._global_config: Config self._name: str = self.__class__.__name__ pattern = re.compile(r"[a-zA-Z_]\w*") @@ -65,22 +197,35 @@ def authors(self) -> str: return self._authors @property - def config(self) -> BaseConfig: + def global_config(self) -> Config: """Gets the config for this chat command.""" - return self._config + return self._global_config + + @global_config.setter + def global_config(self, value: Config) -> None: + self._global_config = value + + @property + def config(self) -> BaseComponentConfig: + """Gets the component-specific configuration. + + This property provides access to the component's configuration. + The configuration must be set by the ComponentManager before accessing. + + Returns: + The component's configuration instance. + + Raises: + RuntimeError: If configuration has not been set. + """ + raise NotImplementedError(f"Configuration not set for component {self.name}") @config.setter - def config(self, value: BaseConfig) -> None: - self._config: BaseConfig = value - - def get_config_fields(self) -> list[ConfigField]: - """Called during initialization, this is meant to return all config - fields that are going to be loaded from the config. The name that each - field has will automatically be prefixed with {verifier name}.""" - return [] - - def get_config_value(self, key: str) -> Any: - """Loads a value from the config. If the value is defined in the namespace - of the verifier name then that value will be returned. + def config(self, value: BaseComponentConfig) -> None: + """Sets the component-specific configuration. + + Args: + value: The configuration instance for this component. """ - return self._config.get_value(key) + _ = value + raise NotImplementedError(f"Configuration not set for component {self.name}") diff --git a/esbmc_ai/base_config.py b/esbmc_ai/base_config.py deleted file mode 100644 index 58df655..0000000 --- a/esbmc_ai/base_config.py +++ /dev/null @@ -1,143 +0,0 @@ -# Author: Yiannis Charalambous 2023 - -"""ABC Config that can be used to load config files.""" - -from abc import ABC -import sys -from pathlib import Path -import tomllib as toml -from typing import ( - Any, - Callable, - Dict, - List, -) - -from esbmc_ai.config_field import ConfigField - - -class BaseConfig(ABC): - """Config loader for ESBMC-AI""" - - def __init__(self) -> None: - super().__init__() - self._fields: List[ConfigField] = [] - self._values: Dict[str, Any] = {} - self.original_config_file: dict[str, Any] - self.config_file: dict[str, Any] - self.on_load_value: list[Callable[[str], None]] = [] - - def load_config_fields(self, cfg_path: Path, fields: list[ConfigField]) -> None: - """Initializes the base config structures. 
Loads the config file and fields.""" - - if not (cfg_path.exists() and cfg_path.is_file()): - print(f"Error: Config not found: {cfg_path}") - sys.exit(1) - - with open(cfg_path, "r") as file: - self.original_config_file = toml.loads(file.read()) - - # Flatten dict as the _fields are defined in a flattened format for - # convenience. - self.config_file = self.flatten_dict(self.original_config_file) - - # Load all the config file field entries - for field in fields: - self.load_config_field(field) - - def load_config_field(self, field: ConfigField) -> None: - """Loads a new field from the config. Init needs to be called before - calling this to initialize the base config.""" - if field not in self._fields: - self._fields.append(field) - - # If on_read is overwritten, then the reading process is manually - # defined so fallback to that. - if field.on_read: - self._values[field.name] = field.on_read(self.original_config_file) - return - - # Proceed to default read - - # Is field entry found in config? - if field.name in self.config_file: - # Check if None and not allowed! - if ( - field.default_value is None - and not field.default_value_none - and self.config_file[field.name] is None - ): - raise ValueError( - f"The config entry {field.name} has a None value when it can't be" - ) - - # Validate field - if not field.validate(self.config_file[field.name]): - msg = f"Field: {field.name} is invalid: {self.config_file[field.name]}" - if field.get_error_message is not None: - msg += ": " + field.get_error_message(self.config_file[field.name]) - elif field.error_message: - msg += ": " + field.error_message - raise ValueError(f"Config loading error: {msg}") - - # Assign field from config file - self._values[field.name] = field.on_load(self.config_file[field.name]) - elif field.default_value is None and not field.default_value_none: - raise KeyError(f"{field.name} is missing from config file") - else: - # Use default value - self._values[field.name] = field.default_value - - def set_custom_field(self, field: ConfigField, value: Any) -> None: - """Loads a new field from a custom source. Still validates the value given - to it.""" - # Check if None and not allowed! - if ( - field.default_value is None - and not field.default_value_none - and value is None - ): - raise ValueError( - f"Failed to add field from custom source: {field.name} has a " - "None value when it can't be" - ) - - # Validate field - if not field.validate(value): - msg = f"Field: {field.name} is invalid: {value}" - if field.get_error_message is not None: - msg += ": " + field.get_error_message(value) - elif field.error_message: - msg += ": " + field.error_message - raise ValueError(f"Config loading error: {msg}") - - if field not in self._fields: - self._fields.append(field) - self._values[field.name] = field.on_load(value) - - def get_value(self, name: str) -> Any: - """Gets the value of key name""" - for cb in self.on_load_value: - cb(name) - return self._values[name] - - def set_value(self, name: str, value: Any) -> None: - """Sets a value in the config, if it does not exist, it will create one. 
- This uses toml notation dot notation to namespace the elements.""" - self._values[name] = value - - def contains_field(self, name: str) -> bool: - """Check if config has a field.""" - return any(name == field.name for field in self._fields) - - @classmethod - def flatten_dict(cls, d, parent_key="", sep="."): - """Recursively flattens a nested dictionary.""" - items = {} - for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k - if isinstance(v, dict): - items.update(cls.flatten_dict(v, new_key, sep=sep)) - else: - items[new_key] = v - return items diff --git a/esbmc_ai/chat_response.py b/esbmc_ai/chat_response.py deleted file mode 100644 index 641d203..0000000 --- a/esbmc_ai/chat_response.py +++ /dev/null @@ -1,54 +0,0 @@ -# Author: Yiannis Charalambous - -from enum import Enum -from typing import NamedTuple - -from langchain.schema import ( - AIMessage, - BaseMessage, - HumanMessage, - SystemMessage, -) - - -"""Contains classes and functions that relate to sending/receiving messages from LLMs.""" - - -class FinishReason(Enum): - # API response still in progress or incomplete - null = 0 - # API returned complete model output - stop = 1 - # Incomplete model output due to max_tokens parameter or token limit - length = 2 - # Omitted content due to a flag from our content filters - content_filter = 3 - - -class ChatResponse(NamedTuple): - message: BaseMessage = AIMessage(content="") - total_tokens: int = 0 - finish_reason: FinishReason = FinishReason.null - - -def dict_to_base_message(json_string: dict) -> BaseMessage: - """Converts a json representation of messages (such as in config.json), - into LangChain object messages. The three recognized roles are: - 1. System - 2. AI - 3. Human""" - role: str = json_string["role"] - content: str = json_string["content"] - if role == "System": - return SystemMessage(content=content) - elif role == "AI": - return AIMessage(content=content) - elif role == "Human": - return HumanMessage(content=content) - else: - raise Exception() - - -def list_to_base_messages(json_messages: list[dict]) -> tuple[BaseMessage, ...]: - """Converts a list of messages from JSON format to a list of BaseMessage.""" - return tuple(dict_to_base_message(msg) for msg in json_messages) diff --git a/esbmc_ai/chats/__init__.py b/esbmc_ai/chats/__init__.py index 5ac142e..179a197 100644 --- a/esbmc_ai/chats/__init__.py +++ b/esbmc_ai/chats/__init__.py @@ -3,8 +3,14 @@ """This module contains different chat interfaces. 
Along with `BaseChatInterface` that provides necessary boilet-plate for implementing an LLM based chat.""" -from .base_chat_interface import BaseChatInterface +from .template_key_provider import ( + TemplateKeyProvider, + GenericTemplateKeyProvider, + ESBMCTemplateKeyProvider, +) __all__ = [ - "BaseChatInterface", + "TemplateKeyProvider", + "GenericTemplateKeyProvider", + "ESBMCTemplateKeyProvider", ] diff --git a/esbmc_ai/chats/base_chat_interface.py b/esbmc_ai/chats/base_chat_interface.py deleted file mode 100644 index a5a92c4..0000000 --- a/esbmc_ai/chats/base_chat_interface.py +++ /dev/null @@ -1,165 +0,0 @@ -# Author: Yiannis Charalambous - -"""Contains code for the base class for interacting with the LLMs in a -conversation-based way.""" - -from time import sleep, time -from typing import Any, Sequence - -from langchain.schema import ( - BaseMessage, - HumanMessage, - PromptValue, -) -import structlog - -from esbmc_ai.chat_response import ChatResponse, FinishReason -from esbmc_ai.ai_models import AIModel -from esbmc_ai.log_utils import LogCategories -from esbmc_ai.chats.template_key_provider import ( - TemplateKeyProvider, - GenericTemplateKeyProvider, -) - - -class BaseChatInterface: - """Base class for interacting with an LLM. It allows for interactions with - text generation LLMs and also chat LLMs.""" - - _last_attempt: float = 0 - cooldown_total: float = 20.0 - - def __init__( - self, - system_messages: list[BaseMessage], - ai_model: AIModel, - template_key_provider: TemplateKeyProvider | None = None, - ) -> None: - super().__init__() - self._logger: structlog.stdlib.BoundLogger = structlog.get_logger( - category=LogCategories.CHAT - ) - self.ai_model: AIModel = ai_model - self._system_messages: list[BaseMessage] = system_messages - self.messages: list[BaseMessage] = [] - self._template_key_provider = ( - template_key_provider or GenericTemplateKeyProvider() - ) - - def compress_message_stack(self) -> None: - """Compress the message stack, is abstract and needs to be implemented.""" - self.messages = [] - - def push_to_message_stack( - self, - message: BaseMessage | tuple[BaseMessage, ...] | list[BaseMessage], - ) -> None: - """Pushes a message(s) to the message stack without querying the LLM.""" - assert isinstance(message, BaseMessage | Sequence) - if isinstance(message, Sequence): - for m in message: - assert isinstance(m, BaseMessage) - if isinstance(message, list) or isinstance(message, tuple): - self.messages.extend(list(message)) - else: - self.messages.append(message) - - def get_template_keys(self, **kwargs: Any) -> dict[str, Any]: - """Gets template keys for applying in template values using the configured provider.""" - return self._template_key_provider.get_template_keys(**kwargs) - - def apply_template_value(self, **kwargs: str) -> None: - """Will substitute an f-string in the message stack and system messages to - the provided value. 
The new substituted messages will become the new - message stack, so the substitution is permanent.""" - - system_message_prompts: PromptValue = self.ai_model.apply_chat_template( - messages=self._system_messages, - **kwargs, - ) - self._system_messages = system_message_prompts.to_messages() - - message_prompts: PromptValue = self.ai_model.apply_chat_template( - messages=self.messages, - **kwargs, - ) - self.messages = message_prompts.to_messages() - - def get_applied_messages(self, **kwargs: str) -> tuple[BaseMessage, ...]: - """Applies the f-string substituion and returns the result instead of assigning - it to the message stack.""" - message_prompts: PromptValue = self.ai_model.apply_chat_template( - messages=self.messages, - **kwargs, - ) - return tuple(message_prompts.to_messages()) - - def get_applied_system_messages(self, **kwargs: str) -> tuple[BaseMessage, ...]: - """Same as `get_applied_messages` but for system messages.""" - message_prompts: PromptValue = self.ai_model.apply_chat_template( - messages=self._system_messages, - **kwargs, - ) - return tuple(message_prompts.to_messages()) - - @staticmethod - def send_messages( - ai_model: AIModel, - messages: list[BaseMessage], - logger: structlog.stdlib.BoundLogger | None = None, - ) -> ChatResponse: - """Static method to send messages.""" - - # Check cooldown - time_passed: float = time() - BaseChatInterface._last_attempt - if time_passed < BaseChatInterface.cooldown_total: - sleep_total_seconds: float = BaseChatInterface.cooldown_total - time_passed - if logger: - logger.info(f"Sleeping for {sleep_total_seconds}...") - sleep(sleep_total_seconds) - BaseChatInterface._last_attempt = time() - - response_message: BaseMessage = ai_model.invoke(input=messages) - - # Check if token limit has been exceeded. - new_tokens: int = ai_model.get_num_tokens_from_messages( - messages=messages + [response_message], - ) - - response: ChatResponse - if new_tokens > ai_model.tokens: - response = ChatResponse( - finish_reason=FinishReason.length, - message=response_message, - total_tokens=ai_model.tokens, - ) - else: - response = ChatResponse( - finish_reason=FinishReason.stop, - message=response_message, - total_tokens=new_tokens, - ) - - return response - - def send_message(self, message: str | None = None) -> ChatResponse: - """Sends a message to the AI model. 
Returns solution.""" - if message: - self.push_to_message_stack(message=HumanMessage(content=message)) - - all_messages = self._system_messages.copy() - all_messages.extend(self.messages.copy()) - - if message: - self._logger.debug(f"LLM Prompt: {message}") - - response: ChatResponse = self.send_messages( - ai_model=self.ai_model, - messages=all_messages, - logger=self._logger, - ) - - self._logger.debug(f"LLM Response: {response.message.content}") - - self.push_to_message_stack(message=response.message) - return response diff --git a/esbmc_ai/chats/solution_generator.py b/esbmc_ai/chats/solution_generator.py index 6f005b5..cc0a8c4 100644 --- a/esbmc_ai/chats/solution_generator.py +++ b/esbmc_ai/chats/solution_generator.py @@ -2,31 +2,33 @@ """Contains code for automatically repairing code using ESBMC.""" -from dataclasses import dataclass, replace -from typing_extensions import override +from dataclasses import replace +from langchain_core.prompts.chat import MessageLikeRepresentation +from pydantic import BaseModel, Field, SkipValidation from langchain.schema import BaseMessage +from langchain_core.language_models import BaseChatModel -from esbmc_ai.chat_response import ChatResponse, FinishReason from esbmc_ai.solution import SourceFile - -from esbmc_ai.ai_models import AIModel from esbmc_ai.verifiers.base_source_verifier import ( SourceCodeParseError, VerifierTimedOutException, ) -from esbmc_ai.chats.base_chat_interface import BaseChatInterface -from esbmc_ai.chats.template_key_provider import ESBMCTemplateKeyProvider +from esbmc_ai.chats.template_key_provider import ( + ESBMCTemplateKeyProvider, + TemplateKeyProvider, +) from esbmc_ai.verifiers.esbmc import ESBMCOutput +from esbmc_ai.chats.template_renderer import KeyTemplateRenderer default_scenario: str = "base" -@dataclass -class FixCodeScenario: - """Type for scenarios. A single scenario contains initial and system components.""" - - initial: BaseMessage - system: tuple[BaseMessage, ...] +class FixCodeScenario(BaseModel): + initial: str = Field(default="") + # Going to be manually instantiated by FixCodeCommandConfig + system: list[SkipValidation[MessageLikeRepresentation]] = Field( + default_factory=list, + ) def apply_formatting(esbmc_output: ESBMCOutput, format: str) -> str: @@ -57,32 +59,7 @@ def apply_formatting(esbmc_output: ESBMCOutput, format: str) -> str: raise ValueError(f"Not a valid ESBMC output type: {format}") -def get_source_code_formatted( - source_code_format: str, - source_code: str, - esbmc_output: ESBMCOutput, -) -> str: - """Gets the formatted output source code, based on the source_code_format - passed.""" - match source_code_format: - case "single": - # Get source code error line from esbmc output - line: int | None = esbmc_output.get_error_line_idx() - if line: - return source_code.splitlines(True)[line] - - raise AssertionError( - f"error line not found in esbmc output:\n{esbmc_output}" - ) - case "full": - return source_code - case _: - raise ValueError( - f"Not a valid format for source code: {source_code_format}" - ) - - -class SolutionGenerator(BaseChatInterface): +class SolutionGenerator: """SolutionGenerator is a simple conversation-based automated program repair class. The class works in a cycle, by first calling update_state with the new source_code and esbmc_output, then by calling generate_solution. 
The @@ -92,37 +69,25 @@ class supports scenarios to customize the system message and initial prompt def __init__( self, scenarios: dict[str, FixCodeScenario], - ai_model: AIModel, - source_code_format: str = "full", + ai_model: BaseChatModel, esbmc_output_type: str = "full", ) -> None: """Initializes the solution generator.""" + super().__init__() - super().__init__( - ai_model=ai_model, - system_messages=[], # Empty as it will be updated in the update method. - template_key_provider=ESBMCTemplateKeyProvider(), - ) + self.ai_model: BaseChatModel = ai_model + self.template_key_provider: TemplateKeyProvider = ESBMCTemplateKeyProvider() + self.messages: list[BaseMessage] = [] self.scenarios: dict[str, FixCodeScenario] = scenarios self.scenario: str = "" self.esbmc_output_type: str = esbmc_output_type - self.source_code_format: str = source_code_format self.source_code: SourceFile | None = None - self.source_code_formatted: str | None = None self.esbmc_output: ESBMCOutput | None = None self.invokations: int = 0 - @override - def compress_message_stack(self) -> None: - # Resets the conversation - cannot summarize code - # If generate_solution is called after this point, it will start new - # with the currently set state. - self.messages: list[BaseMessage] = [] - self.invokations = 0 - @staticmethod def extract_code_from_solution(solution: str) -> str: """Strip the source code of any leftover text as sometimes the AI model @@ -176,48 +141,10 @@ def update_state( # big. self.esbmc_output = verifier_output - # Format source code - self.source_code_formatted = get_source_code_formatted( - source_code_format=self.source_code_format, - source_code=source_file.content, - esbmc_output=verifier_output, - ) - - def _get_system_messages( - self, override_scenario: str | None = None - ) -> tuple[BaseMessage, ...]: - if override_scenario: - system_messages = self.scenarios[override_scenario].system - else: - assert self.scenario, "Call update or set the scenario" - if self.scenario in self.scenarios: - system_messages = self.scenarios[self.scenario].system - else: - system_messages = self.scenarios[default_scenario].system - - assert isinstance(system_messages, tuple) - assert all(isinstance(msg, BaseMessage) for msg in system_messages) - return system_messages - - def _get_initial_message(self, override_scenario: str | None = None) -> BaseMessage: - if override_scenario: - return self.scenarios[override_scenario].initial - else: - assert self.scenario, "Call update or set the scenario" - if self.scenario in self.scenarios: - return self.scenarios[self.scenario].initial - else: - return self.scenarios[default_scenario].initial - - def generate_solution( - self, - override_scenario: str | None = None, - ignore_system_message: bool = False, - ) -> tuple[str, FinishReason]: + def generate_solution(self, override_scenario: str | None = None) -> str: """Prompts the LLM to repair the source code using the verifier output. If this is the first time the method is called, the system message will - be sent to the LLM, unless ignore_system_message is True. Then the - initial prompt will be sent. + be sent to the LLM. Then the initial prompt will be sent. In subsequent invokations of generate_solution, the initial prompt will be used only. @@ -231,106 +158,41 @@ def generate_solution( assert ( self.source_code_raw is not None - and self.source_code_formatted is not None and self.esbmc_output is not None and self.scenario is not None ), "Call update_state before calling generate_solution." 
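The rewritten body that follows resolves the scenario (falling back to "base"), renders the prompt templates through KeyTemplateRenderer, and invokes the chat model directly. A minimal sketch of how a caller might drive the rewritten class; FakeListChatModel and the one-scenario literal are illustrative stand-ins, not part of this patch, and source_file / verifier_output are assumed to come from the Solution and verifier as elsewhere in the diff:

```python
from langchain_core.language_models import FakeListChatModel

# Hypothetical single-scenario setup mirroring the shapes used in this patch.
scenarios = {
    "base": FixCodeScenario(
        initial=(
            "The ESBMC output is:\n{esbmc_output}\n"
            "The source code is:\n{source_code}\nShow the fixed code."
        ),
        system=[("system", "Act as an automated code repair tool.")],
    )
}

generator = SolutionGenerator(
    scenarios=scenarios,
    # Canned-response fake model; any BaseChatModel would work here.
    ai_model=FakeListChatModel(
        responses=["```c\nint main(void) { return 0; }\n```"]
    ),
    esbmc_output_type="full",
)

# State must be set first, as the assert above enforces.
generator.update_state(source_file, verifier_output)
solution = generator.generate_solution()  # now returns the extracted source as str
```

Note that generate_solution now returns a plain string rather than the old (solution, FinishReason) tuple, so the cooldown/compression bookkeeping of the removed BaseChatInterface no longer leaks into callers.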
- # Show system message - if not ignore_system_message and self.invokations <= 0: - # Get scenario system messages and push it to message stack. Don't - # push to system message stack because we want to regenerate from - # the beginning at every reset. - system_messages: tuple[BaseMessage, ...] = self._get_system_messages( - override_scenario=override_scenario - ) - if len(system_messages) > 0: - self.push_to_message_stack(system_messages) + scenario_name: str = override_scenario or self.scenario + scenario: FixCodeScenario = self.scenarios[ + scenario_name if scenario_name in self.scenarios else "base" + ] + new_templates: list[MessageLikeRepresentation] = [] + + # Apply system message if first cycle + if self.invokations == 0: + new_templates.extend(scenario.system) # Get scenario initial message and push it to message stack - self.push_to_message_stack( - self._get_initial_message(override_scenario=override_scenario) + new_templates.append(("human", scenario.initial)) + # Prepare template values + key_template_renderer: KeyTemplateRenderer = KeyTemplateRenderer( + messages=new_templates, + key_provider=self.template_key_provider, ) - - self.invokations += 1 - error_type: str | None = self.esbmc_output.get_error_type() - - # Apply template substitution to message stack - self.apply_template_value( - **self.get_template_keys( - source_code=self.source_code_formatted, + self.messages.extend( + key_template_renderer.format_messages( + source_code=self.source_code, esbmc_output=self.esbmc_output.output, error_line=str(self.esbmc_output.get_error_line()), error_type=error_type if error_type else "unknown error", ) ) - # Generate the solution - response: ChatResponse = self.send_message() - solution: str = str(response.message.content) - - solution = SolutionGenerator.extract_code_from_solution(solution) - - # Post process source code - # If source code passed to LLM is formatted then we need to recombine to - # full source code before giving to ESBMC - match self.source_code_format: - case "single": - # Get source code error line from esbmc output - line: int | None = self.esbmc_output.get_error_line_idx() - assert line, ( - "fix code command: error line could not be found to apply " - "brutal patch replacement" - ) - solution = SourceFile.apply_line_patch( - self.source_code_raw, solution, line, line - ) - - return solution, response.finish_reason - - -class ReverseOrderSolutionGenerator(SolutionGenerator): - """SolutionGenerator that shows the source code and verifier output state in - reverse order.""" - - @override - def send_message(self, message: str | None = None) -> ChatResponse: - # Reverse the messages - messages: list[BaseMessage] = self.messages.copy() - self.messages.reverse() - - response: ChatResponse = super().send_message(message) - - # Add to the reversed message the new message received by the LLM. - messages.append(self.messages[-1]) - # Restore - self.messages = messages - - return response - + self.invokations += 1 -class LatestStateSolutionGenerator(SolutionGenerator): - """SolutionGenerator that only shows the latest source code and verifier - output state.""" + # Generate the solution + response: BaseMessage = self.ai_model.invoke(self.messages) + solution = SolutionGenerator.extract_code_from_solution(str(response.content)) - @override - def generate_solution( - self, - override_scenario: str | None = None, - ignore_system_message: bool = False, - ) -> tuple[str, FinishReason]: - # Backup message stack and clear before sending base message. 
We want - # to keep the message stack intact because we will print it with - # print_raw_conversation. - messages: list[BaseMessage] = self.messages - self.messages: list[BaseMessage] = [] - solution, finish_reason = super().generate_solution( - override_scenario=override_scenario, - ignore_system_message=ignore_system_message, - ) - # Append last messages to the messages stack - messages.extend(self.messages) - # Restore - self.messages = messages - return solution, finish_reason + return solution diff --git a/esbmc_ai/chats/template_key_provider.py b/esbmc_ai/chats/template_key_provider.py index 47bfdf1..73abde6 100644 --- a/esbmc_ai/chats/template_key_provider.py +++ b/esbmc_ai/chats/template_key_provider.py @@ -29,7 +29,7 @@ def get_template_keys( **kwargs: Any, ) -> dict[str, Any]: """Get canonical template keys for ESBMC code repair workflows.""" - keys = { + keys: dict["str", Any] = { "source_code": source_code, "esbmc_output": esbmc_output, "error_line": error_line, diff --git a/esbmc_ai/chats/template_renderer.py b/esbmc_ai/chats/template_renderer.py new file mode 100644 index 0000000..c4bfb0f --- /dev/null +++ b/esbmc_ai/chats/template_renderer.py @@ -0,0 +1,52 @@ +# Author: Yiannis Charalambous + +"""Template renderer that integrates TemplateKeyProvider with ChatPromptTemplate.""" + +from typing import Any, Sequence, override +from langchain.schema import BaseMessage +from langchain_core.prompt_values import ChatPromptValue +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.prompts.chat import MessageLikeRepresentation +from langchain_core.prompts.string import PromptTemplateFormat +from pydantic import PrivateAttr +from esbmc_ai.chats.template_key_provider import TemplateKeyProvider + + +class KeyTemplateRenderer(ChatPromptTemplate): + """Derives ChatPromptTemplate and automatically provides template keys via + TemplateKeyProvider. 
This is to force standarization across types and keys.""" + + _key_provider: TemplateKeyProvider = PrivateAttr() + + def __init__( + self, + messages: Sequence[MessageLikeRepresentation], + key_provider: TemplateKeyProvider, + *, + template_format: PromptTemplateFormat = "f-string", + **kwargs: Any, + ) -> None: + super().__init__( + messages=messages, + template_format=template_format, + **kwargs, + ) + self._key_provider = key_provider + + @override + def format_prompt(self, **kwargs: Any) -> ChatPromptValue: + auto_keys = self._key_provider.get_template_keys(**kwargs) + kwargs = {**auto_keys, **kwargs} + return super().format_prompt(**kwargs) + + @override + def format_messages(self, **kwargs: Any) -> list[BaseMessage]: + auto_keys = self._key_provider.get_template_keys(**kwargs) + kwargs = {**auto_keys, **kwargs} + return super().format_messages(**kwargs) + + @override + def format(self, **kwargs: Any) -> str: + auto_keys = self._key_provider.get_template_keys(**kwargs) + kwargs = {**auto_keys, **kwargs} + return super().format(**kwargs) diff --git a/esbmc_ai/commands/__init__.py b/esbmc_ai/commands/__init__.py index 4591ad3..8a9540a 100644 --- a/esbmc_ai/commands/__init__.py +++ b/esbmc_ai/commands/__init__.py @@ -3,16 +3,14 @@ from .exit_command import ExitCommand from .help_command import HelpCommand -from .list_models_command import ListModelsCommand from .help_config import HelpConfigCommand from .fix_code_command import FixCodeCommand -from .reload_models_command import ReloadAIModelsCommand +from .debug_config import DebugConfigViewCommand __all__ = [ "ExitCommand", "HelpCommand", "HelpConfigCommand", - "ListModelsCommand", "FixCodeCommand", - "ReloadAIModelsCommand", + "DebugConfigViewCommand", ] diff --git a/esbmc_ai/commands/debug_config.py b/esbmc_ai/commands/debug_config.py new file mode 100644 index 0000000..6ebf87b --- /dev/null +++ b/esbmc_ai/commands/debug_config.py @@ -0,0 +1,100 @@ +# Author: Yiannis Charalambous + +from typing import Any, override +from esbmc_ai.chat_command import ChatCommand +from esbmc_ai.command_result import CommandResult +from esbmc_ai.component_manager import ComponentManager + + +class DebugConfigViewCommand(ChatCommand): + """Displays all loaded config values.""" + + def __init__(self) -> None: + super().__init__( + command_name="debug-view-config", + authors="", + help_message="Used for debug to view the current config state.", + ) + + @staticmethod + def _format_value(value: Any, indent: str = " ") -> None: + """Format and print a config value with appropriate formatting.""" + if isinstance(value, dict): + if value: + for k, v in value.items(): + print(f"{indent}{k}: {v}") + else: + print(f"{indent}(empty)") + elif isinstance(value, (list, tuple)): + if value: + for item in value: + print(f"{indent}- {item}") + else: + print(f"{indent}(empty)") + else: + print(f"{indent}{value}") + + def _print_config_section( + self, config_dict: dict[str, Any], indent: str = "" + ) -> None: + """Print a config dictionary with formatted field names.""" + for field_name, value in config_dict.items(): + display_name = field_name.replace("_", " ").title() + print(f"\n{indent}{display_name}:") + self._format_value(value, indent + " ") + + @override + def execute(self, **kwargs: Any | None) -> CommandResult | None: + _ = kwargs + + print("\n" + "=" * 80) + print("GLOBAL CONFIGURATION") + print("=" * 80) + + self._print_config_section({"Config File": self.global_config.config_file}) + self._print_config_section(self.global_config.model_dump()) + + print("\n" + 
"=" * 80) + print("BUILTIN COMPONENT CONFIGURATIONS") + print("=" * 80) + + cm = ComponentManager() + builtin_components = cm.builtin_components + if not builtin_components: + print("\n(No builtin components)") + else: + for name, component in builtin_components.items(): + print(f"\n{name}:") + try: + if component.config: + self._print_config_section( + component.config.model_dump(), indent=" " + ) + else: + print(" (no config)") + except NotImplementedError: + print(" (no config)") + + print("\n" + "=" * 80) + print("ADDON COMPONENT CONFIGURATIONS") + print("=" * 80) + + addon_components = cm.addon_components + if not addon_components: + print("\n(No addon components)") + else: + for name, component in addon_components.items(): + print(f"\n{name}:") + try: + if component.config: + self._print_config_section( + component.config.model_dump(), indent=" " + ) + else: + print(" (no config)") + except NotImplementedError: + print(" (no config)") + + print("\n" + "=" * 80 + "\n") + + return None diff --git a/esbmc_ai/commands/fix_code_command.py b/esbmc_ai/commands/fix_code_command.py index 1bbe802..46e85ad 100644 --- a/esbmc_ai/commands/fix_code_command.py +++ b/esbmc_ai/commands/fix_code_command.py @@ -1,31 +1,27 @@ # Author: Yiannis Charalambous +from enum import Enum import os from pathlib import Path import sys -from typing import Any, Optional -from langchain.schema import HumanMessage +from typing import Any +from pydantic import Field, field_validator from typing_extensions import override -from esbmc_ai.config_field import ConfigField -from esbmc_ai.component_loader import ComponentLoader +from esbmc_ai.base_component import BaseComponentConfig +from esbmc_ai.component_manager import ComponentManager from esbmc_ai.solution import Solution, SourceFile from esbmc_ai.ai_models import AIModel -from esbmc_ai.chat_response import FinishReason from esbmc_ai.chats.solution_generator import ( SolutionGenerator, - LatestStateSolutionGenerator, - ReverseOrderSolutionGenerator, FixCodeScenario, - default_scenario, ) +from esbmc_ai.chats.template_key_provider import ESBMCTemplateKeyProvider from esbmc_ai.verifiers.base_source_verifier import VerifierTimedOutException from esbmc_ai.command_result import CommandResult from esbmc_ai.verifier_output import VerifierOutput from esbmc_ai.chat_command import ChatCommand from esbmc_ai.log_utils import get_log_level, print_horizontal_line -from esbmc_ai.chat_response import list_to_base_messages -from esbmc_ai.msg_bus import Signal from esbmc_ai.loading_widget import BaseLoadingWidget, LoadingWidget import esbmc_ai.prompt_utils as prompt_utils from esbmc_ai.verifiers.esbmc import ESBMC @@ -38,12 +34,12 @@ def __init__( self, successful: bool, attempts: int, - repaired_source: Optional[str] = None, + repaired_source: str | None = None, ) -> None: super().__init__() self._successful: bool = successful self.attempts: int = attempts - self.repaired_source: Optional[str] = repaired_source + self.repaired_source: str | None = repaired_source @property @override @@ -55,146 +51,113 @@ def __str__(self) -> str: if self._successful and self.repaired_source is not None: return self.repaired_source - return "ESBMC-AI Notice: Failed all attempts..." + return "Failed all attempts..." 
+ + +class FixCodeCommandConfig(BaseComponentConfig): + class VerifierOutputType(str, Enum): + full = "full" + ce = "ce" + vp = "vp" + + verifier_output_type: str = Field( + default=VerifierOutputType.full, + description="The type of output from ESBMC in the fix code command.", + ) + + temperature: float = Field( + default=0, + description="The temperature of the LLM for the fix code command.", + ) + + max_attempts: int = Field( + default=5, + description="Fix code command max attempts.", + ) + + prompt_templates: dict[str, FixCodeScenario] = Field( + default={ + "base": FixCodeScenario( + initial="The ESBMC output is:\n\n```\n{esbmc_output}\n```\n\nThe source code is:\n\n```c\n{source_code}\n```\n Using the ESBMC output, show the fixed text.", + system=[ + { + "role": "system", + "content": "From now on, act as an Automated Code Repair Tool that repairs AI C code. You will be shown AI C code, along with ESBMC output. Pay close attention to the ESBMC output, which contains a stack trace along with the type of error that occurred and its location that you need to fix. Provide the repaired C code as output, as would an Automated Code Repair Tool. Aside from the corrected source code, do not output any other text.", + } + ], + ), + "division by zero": FixCodeScenario( + initial="The ESBMC output is:\n\n```\n{esbmc_output}\n```\n\nThe source code is:\n\n```c\n{source_code}\n```\n Using the ESBMC output, show the fixed text.", + system=[ + { + "role": "system", + "content": "Here's a C program with a vulnerability:\n```c\n{source_code}\n```\nA Formal Verification tool identified a division by zero issue:\n{esbmc_output}\nTask: Modify the C code to safely handle scenarios where division by zero might occur. The solution should prevent undefined behavior or crashes due to division by zero. \nGuidelines: Focus on making essential changes only. Avoid adding or modifying comments, and ensure the changes are precise and minimal.\nGuidelines: Ensure the revised code avoids undefined behavior and handles division by zero cases effectively.\nGuidelines: Implement safeguards (like comparison) to prevent division by zero instead of using literal divisions like 1.0/0.0.Output: Provide the corrected, complete C code. The solution should compile and run error-free, addressing the division by zero vulnerability.\nStart the code snippet with ```c and end with ```. 
Reply OK if you understand.", + }, + {"role": "ai", "content": "OK."}, + ], + ), + }, + description="Scenario prompt templates for different types of bugs for the fix code command.", + ) + + @field_validator("verifier_output_type", mode="after") + @classmethod + def validate_verifier_output_type(cls, v: str) -> str: + if v not in ["full", "vp", "ce"]: + raise ValueError("verifier_output_type must be 'full', 'vp', or 'ce'") + return v class FixCodeCommand(ChatCommand): """Command for automatically fixing code using a verifier.""" - on_solution_signal: Signal = Signal() - def __init__(self) -> None: super().__init__( command_name="fix-code", help_message="Generates a solution for this code, and reevaluates it with ESBMC.", ) + # Set default config instance + self._config: FixCodeCommandConfig = FixCodeCommandConfig() + self.anim: BaseLoadingWidget - def print_raw_conversation(self, solution_generator: SolutionGenerator) -> None: - """Debug prints the raw conversation""" - print_horizontal_line(get_log_level()) - self.logger.info("ESBMC-AI Notice: Printing raw conversation...") - all_messages = solution_generator._system_messages + solution_generator.messages - messages: list[str] = [f"{msg.type}: {msg.content}" for msg in all_messages] - self.logger.info("\n" + "\n\n".join(messages)) - self.logger.info("ESBMC-AI Notice: End of raw conversation") - print_horizontal_line(get_log_level()) + @classmethod + def _get_config_class(cls) -> type[BaseComponentConfig]: + """Return the config class for this component.""" + return FixCodeCommandConfig + @property @override - def get_config_fields(self) -> list[ConfigField]: - return [ - ConfigField( - name="fix_code.verifier_output_type", - default_value="full", - validate=lambda v: v in ["full", "vp", "ce"], - help_message="The type of output from ESBMC in the fix code command.", - ), - ConfigField( - name="fix_code.temperature", - default_value=1.0, - validate=lambda v: isinstance(v, float) and 0 <= v <= 2, - error_message="Temperature needs to be a value between 0 and 2.0", - help_message="The temperature of the LLM for the fix code command.", - ), - ConfigField( - name="fix_code.max_attempts", - default_value=5, - validate=lambda v: isinstance(v, int), - help_message="Fix code command max attempts.", - ), - ConfigField( - name="fix_code.message_history", - default_value="normal", - validate=lambda v: v in ["normal", "latest_only", "reverse"], - error_message='fix_code.message_history can only be "normal", ' - + '"latest_only", "reverse"', - help_message="The type of history to be shown in the fix code command.", - ), - ConfigField( - name="fix_code.raw_conversation", - default_value=False, - help_message="Print the raw conversation at different parts of execution.", - ), - ConfigField( - name="fix_code.source_code_format", - default_value="full", - validate=lambda v: isinstance(v, str) and v in ["full", "single"], - error_message="source_code_format can only be 'full' or 'single'", - help_message="The source code format in the fix code prompt.", - ), - # Here we have a list of prompt templates that are for each scenario. - # The base scenario prompt template is required. 
- ConfigField( - name="fix_code.prompt_templates", - default_value=None, - validate=lambda v: default_scenario in v - and all( - prompt_utils.validate_prompt_template(prompt_template) - for prompt_template in v.values() - ), - on_read=lambda config_file: ( - { - scenario: FixCodeScenario( - initial=HumanMessage(content=conv["initial"]), - system=list_to_base_messages(conv["system"]), - ) - for scenario, conv in config_file["fix_code"][ - "prompt_templates" - ].items() - } - if "prompt_templates" in config_file["fix_code"] - else {} - ), - help_message="Scenario prompt templates for differnet types of bugs " - "for the fix code command.", - ), - ] + def config(self) -> BaseComponentConfig: + return self._config + + @config.setter + def config(self, value: BaseComponentConfig) -> None: + assert isinstance(value, FixCodeCommandConfig) + self._config = value @override def execute(self, **kwargs: Any) -> FixCodeCommandResult: - ComponentLoader().load_base_component_config(self) - # Handle kwargs source_file: SourceFile = SourceFile.load( - self.get_config_value("solution.filenames")[0], + self.global_config.solution.filenames[0], Path(os.getcwd()), ) original_source_file: SourceFile = SourceFile( source_file.file_path, source_file.base_path, source_file.content ) self.anim = ( - LoadingWidget() - if self.get_config_value("loading_hints") - else BaseLoadingWidget() + LoadingWidget() if self.global_config.loading_hints else BaseLoadingWidget() ) - generate_patches: bool = self.get_config_value("generate_patches") - message_history: str = self.get_config_value("fix_code.message_history") - ai_model: AIModel = self.get_config_value("ai_model") - temperature: float = self.get_config_value("fix_code.temperature") - max_tries: int = self.get_config_value("fix_code.max_attempts") - timeout: int = self.get_config_value("llm_requests.timeout") - source_code_format: str = self.get_config_value("fix_code.source_code_format") - esbmc_output_format: str = self.get_config_value( - "fix_code.verifier_output_type" - ) - scenarios: dict[str, FixCodeScenario] = self.get_config_value( - "fix_code.prompt_templates" - ) - max_attempts: int = self.get_config_value("fix_code.max_attempts") - raw_conversation: bool = self.get_config_value("fix_code.raw_conversation") - entry_function: str = self.get_config_value("solution.entry_function") - output_dir: Path = self.get_config_value("solution.output_dir") # End of handle kwargs solution: Solution = Solution([]) solution.add_source_file(source_file) - self._logger.info(f"Temperature: {temperature}") - self._logger.info(f"Verifying function: {entry_function}") + self._logger.info(f"FixCodeConfig: {self._config}") - verifier: Any = ComponentLoader().get_verifier("esbmc") + verifier: Any = ComponentManager().get_verifier("esbmc") assert isinstance(verifier, ESBMC) - self._logger.info(f"Running verifier: {verifier.verifier_name}") verifier_result: VerifierOutput = verifier.verify_source( solution=solution, **kwargs ) @@ -203,50 +166,24 @@ def execute(self, **kwargs: Any) -> FixCodeCommandResult: if verifier_result.successful(): self.logger.info("File verified successfully") returned_source: str - if generate_patches: + if self.global_config.generate_patches: returned_source = source_file.get_diff(source_file) else: returned_source = source_file.content return FixCodeCommandResult(True, 0, returned_source) - match message_history: - case "normal": - solution_generator = SolutionGenerator( - ai_model=ai_model.bind( - temperature=temperature, - requests_max_tries=max_tries, - 
requests_timeout=timeout, - ), - scenarios=scenarios, - source_code_format=source_code_format, - esbmc_output_type=esbmc_output_format, - ) - case "latest_only": - solution_generator = LatestStateSolutionGenerator( - ai_model=ai_model.bind( - temperature=temperature, - requests_max_tries=max_tries, - requests_timeout=timeout, - ), - scenarios=scenarios, - source_code_format=source_code_format, - esbmc_output_type=esbmc_output_format, - ) - case "reverse": - solution_generator = ReverseOrderSolutionGenerator( - ai_model=ai_model.bind( - temperature=temperature, - requests_max_tries=max_tries, - requests_timeout=timeout, - ), - scenarios=scenarios, - source_code_format=source_code_format, - esbmc_output_type=esbmc_output_format, - ) - case _: - raise NotImplementedError( - f"error: {message_history} has not been implemented in the Fix Code Command" - ) + # Create the AI model with the specified parameters + ai_model = AIModel.get_model( + model=self.global_config.ai_model.id, + temperature=self._config.temperature, + url=self.global_config.ai_model.base_url, + ) + + solution_generator: SolutionGenerator = SolutionGenerator( + ai_model=ai_model, + scenarios=self._config.prompt_templates, + esbmc_output_type=self._config.verifier_output_type, + ) try: solution_generator.update_state( @@ -259,62 +196,42 @@ def execute(self, **kwargs: Any) -> FixCodeCommandResult: print() - for attempt in range(1, max_attempts + 1): - result: Optional[FixCodeCommandResult] = self._attempt_repair( + for attempt in range(1, self._config.max_attempts + 1): + result: FixCodeCommandResult | None = self._attempt_repair( attempt=attempt, solution_generator=solution_generator, verifier=verifier, - max_attempts=max_attempts, - output_dir=output_dir, solution=solution, - raw_conversation=raw_conversation, ) if result: - if raw_conversation: - self.print_raw_conversation(solution_generator) - - if generate_patches: + if self.global_config.generate_patches: result.repaired_source = source_file.get_diff(original_source_file) return result - if raw_conversation: - self.print_raw_conversation(solution_generator) - - return FixCodeCommandResult(False, max_attempts, None) + return FixCodeCommandResult(False, self._config.max_attempts, None) def _attempt_repair( self, attempt: int, - max_attempts: int, solution_generator: SolutionGenerator, solution: Solution, verifier: ESBMC, - output_dir: Optional[Path], - raw_conversation: bool, - ) -> Optional[FixCodeCommandResult]: + ) -> FixCodeCommandResult | None: source_file: SourceFile = solution.files[0] - # Get a response. Use while loop to account for if the message stack - # gets full, then need to compress and retry. - while True: - # Generate AI solution - with self.anim("Generating Solution... Please Wait"): - llm_solution, finish_reason = solution_generator.generate_solution() + # Generate AI solution + with self.anim("Generating Solution... 
Please Wait"): + llm_solution = solution_generator.generate_solution() - if finish_reason == FinishReason.length: - solution_generator.compress_message_stack() - else: - # Update the source file state - source_file.content = llm_solution - break + # Update the source file state + source_file.content = llm_solution # Print verbose lvl 2 - self._logger.debug("\nESBMC-AI Notice: Source Code Generation:") print_horizontal_line(get_log_level(3)) + self._logger.debug("\nSource Code Generation:") self._logger.debug(source_file.content) print_horizontal_line(get_log_level(3)) - self._logger.debug("") solution = solution.save_temp() @@ -325,27 +242,16 @@ def _attempt_repair( source_file.verifier_output = verifier_result - # Print verbose lvl 2 - self._logger.debug("\nESBMC-AI Notice: ESBMC Output:") - print_horizontal_line(get_log_level(3)) - self._logger.debug(source_file.verifier_output.output) - print_horizontal_line(get_log_level(3)) - # Solution found if verifier_result.return_code == 0: - self.on_solution_signal.emit(source_file.content) - - if raw_conversation: - self.print_raw_conversation(solution_generator) self.logger.info("Successfully verified code") # Check if an output directory is specified and save to it - if output_dir: - assert ( - output_dir.is_dir() - ), "FixCodeCommand: Output directory needs to be valid" - source_file.save_file(output_dir / source_file.file_path.name) + if self.global_config.solution.output_dir: + source_file.save_file( + self.global_config.solution.output_dir / source_file.file_path.name + ) return FixCodeCommandResult(True, attempt, source_file.content) try: @@ -355,13 +261,15 @@ def _attempt_repair( source_file.verifier_output, ) except VerifierTimedOutException: - if raw_conversation: - self.print_raw_conversation(solution_generator) self.logger.error("ESBMC has timed out...") sys.exit(1) # Failure case - if attempt != max_attempts: - self.logger.info(f"Failure {attempt}/{max_attempts}: Retrying...") + if attempt != self._config.max_attempts: + self.logger.info( + f"Failure {attempt}/{self._config.max_attempts}: Retrying..." + ) else: - self.logger.info(f"Failure {attempt}/{max_attempts}: Exiting...") + self.logger.info( + f"Failure {attempt}/{self._config.max_attempts}: Exiting..." 
+ ) diff --git a/esbmc_ai/commands/help_command.py b/esbmc_ai/commands/help_command.py index 4bdcb01..aad001b 100644 --- a/esbmc_ai/commands/help_command.py +++ b/esbmc_ai/commands/help_command.py @@ -6,7 +6,7 @@ from typing_extensions import override from esbmc_ai.chat_command import ChatCommand -from esbmc_ai.component_loader import ComponentLoader +from esbmc_ai.component_manager import ComponentManager class HelpCommand(ChatCommand): @@ -24,16 +24,16 @@ def __init__(self) -> None: @override def execute(self, **_: Any) -> Any: print("Commands:") - for command in ComponentLoader().builtin_commands.values(): + for command in ComponentManager().builtin_commands.values(): print(f"* {command.command_name}: {command.help_message}") if command.authors: print(f"\tAuthors: {command.authors}") - if ComponentLoader().addon_commands: + if ComponentManager().addon_commands: print("\nAddon Commands:") else: self.logger.info("No addon commands to show...") - for command in ComponentLoader().addon_commands.values(): + for command in ComponentManager().addon_commands.values(): print(f"* {command.command_name}: {command.help_message}") if command.authors: print(f"\tAuthors: {command.authors}") diff --git a/esbmc_ai/commands/help_config.py b/esbmc_ai/commands/help_config.py index f215c5b..1ebc2c5 100644 --- a/esbmc_ai/commands/help_config.py +++ b/esbmc_ai/commands/help_config.py @@ -1,12 +1,15 @@ # Author: Yiannis Charalambous - from typing import Any from typing_extensions import override +from pydantic.fields import FieldInfo +from pydantic import BaseModel -from esbmc_ai.addon_loader import Config, AddonLoader -from esbmc_ai.config_field import ConfigField +from esbmc_ai.config import Config from esbmc_ai.chat_command import ChatCommand +from esbmc_ai.command_result import CommandResult +from esbmc_ai.component_manager import ComponentManager +from esbmc_ai.base_component import BaseComponent class HelpConfigCommand(ChatCommand): @@ -19,38 +22,149 @@ def __init__(self) -> None: ) @staticmethod - def _print_config_field(field: ConfigField) -> None: - value_type: type = type(field.default_value) - # Default value: for strings enforce a limit - default_value: Any = field.default_value - if value_type is str and len(field.default_value) > 30: - default_value = field.default_value[:30] + "..." - - print( - f"\t* {field.name}: " - f'{value_type.__name__} = "{default_value}" - {field.help_message}' - ) + def _print_config_field( + field_name: str, field_info: FieldInfo, level: int = 0 + ) -> None: + """Print information about a single config field.""" + # Create indentation based on level + indent = "\t" * (level + 1) + + # Get default value and format it + default_value = field_info.default + if isinstance(default_value, str) and len(default_value) > 30: + default_value = default_value[:30] + "..." + + print(f"{indent}* {field_name}:") + + if field_info.description: + print(f"{indent} Description: {field_info.description}") + + # Check if this field's annotation is a BaseModel type + field_annotation = field_info.annotation + if field_annotation and hasattr(field_annotation, "__origin__"): + # Handle generic types like list, dict, etc. 
+ field_annotation = field_annotation.__origin__ + + # Check if the field type is a BaseModel subclass + if ( + field_annotation + and isinstance(field_annotation, type) + and issubclass(field_annotation, BaseModel) + ): + + print(f"{indent} Nested fields:") + # Recursively print fields of the BaseModel + for ( + nested_field_name, + nested_field_info, + ) in field_annotation.model_fields.items(): + HelpConfigCommand._print_config_field( + nested_field_name, nested_field_info, level + 1 + ) + else: + # Print default value for non-BaseModel fields + if field_info.default is not None: + print(f"{indent} Default: {default_value}") + + if hasattr(field_info, "alias") and field_info.alias: + print(f"{indent} Alias: {field_info.alias}") + + print() + + @staticmethod + def _print_component_config_fields( + component_name: str, component: BaseComponent, level: int = 0 + ) -> bool: + """Print config fields for a component if it has any. + + Returns True if the component has config fields, False otherwise. + """ + try: + # Try to access the component's config + config = component.config + if config is None: + return False + + # Get the config class and its fields + config_class = type(config) + config_fields = config_class.model_fields + + if not config_fields: + return False + + indent = "\t" * level + print(f"{indent}{component_name}:") + + for field_name, field_info in config_fields.items(): + if not getattr(field_info, "exclude", False): + HelpConfigCommand._print_config_field(field_name, field_info, level) + + return True + + except (NotImplementedError, RuntimeError, AttributeError): + # Component doesn't have config + return False + + @staticmethod + def _print_components_section( + title: str, components: dict[str, BaseComponent] + ) -> None: + """Print a section for a group of components (builtin or addon).""" + # First, collect components that have config + components_with_config: list[tuple[str, BaseComponent]] = [] + + for name, component in components.items(): + try: + # Try to access config to check if it exists + if component.config is not None: + components_with_config.append((name, component)) + except (NotImplementedError, RuntimeError, AttributeError): + # Component doesn't have config, skip it + continue + + if not components_with_config: + # No components with config in this section + return + + print(title) + print() + + for name, component in components_with_config: + HelpConfigCommand._print_component_config_fields(name, component) @override - def execute(self, **kwargs: Any | None) -> Any: + def execute(self, **kwargs: Any | None) -> CommandResult | None: _ = kwargs - addon_fields: list[ConfigField] = [] - print("ESBMC-AI Config Fields:") - for field in Config()._fields: - split_field_name: list[str] = field.name.split(".") - if ( - len(split_field_name) > 1 - and split_field_name[0] == AddonLoader.addon_prefix - ): - addon_fields.append(field) - else: - self._print_config_field(field) - - if addon_fields: - print("\nESBMC-AI Addon Fields:") - for field in addon_fields: - self._print_config_field(field) + print() + + # Get all fields from the Config model + config_fields = Config.model_fields + + for field_name, field_info in config_fields.items(): + # Skip excluded fields (like command_name and config_file) + if not getattr(field_info, "exclude", False): + self._print_config_field(field_name, field_info) + + from esbmc_ai.log_utils import print_horizontal_line + + print_horizontal_line() + + # Get component manager to access builtin and addon components + component_manager 
= ComponentManager() + + # Print builtin component config fields + self._print_components_section( + "Builtin Component Config Fields:", + dict(component_manager.builtin_components), + ) + + print_horizontal_line() + + # Print addon component config fields + self._print_components_section( + "Addon Component Config Fields:", dict(component_manager.addon_components) + ) return None diff --git a/esbmc_ai/commands/list_models_command.py b/esbmc_ai/commands/list_models_command.py deleted file mode 100644 index 80f4a94..0000000 --- a/esbmc_ai/commands/list_models_command.py +++ /dev/null @@ -1,35 +0,0 @@ -# Author: Yiannis Charalambous - -from typing import Any, DefaultDict, override -from esbmc_ai.ai_models import AIModels -from esbmc_ai.chat_command import ChatCommand -from esbmc_ai.command_result import CommandResult - - -class ListModelsCommand(ChatCommand): - """Command to list all models that are available.""" - - def __init__(self) -> None: - super().__init__( - command_name="list-models", - help_message="Lists all available AI models.", - authors="", - ) - - @override - def execute(self, **kwargs: Any | None) -> CommandResult | None: - _ = kwargs - - # Sort models into categories based on type: OpenAI, Anthropic, Ollama... - model_types: dict[str, list[str]] = DefaultDict(list) - for n, m in sorted( - AIModels().ai_models.items(), key=lambda v: type(v[1]).__name__ - ): - model_types[type(m).__name__].append(n) - - # Show in ordered list - for model_type, models in model_types.items(): - for model_name in sorted(models): - print(f"* {model_type}: {model_name}") - - return None diff --git a/esbmc_ai/commands/reload_models_command.py b/esbmc_ai/commands/reload_models_command.py deleted file mode 100644 index 119c9f3..0000000 --- a/esbmc_ai/commands/reload_models_command.py +++ /dev/null @@ -1,21 +0,0 @@ -# Author: Yiannis Charalambous - -from typing import Any -from esbmc_ai import ChatCommand -from esbmc_ai.command_result import CommandResult -from esbmc_ai import AIModels - - -class ReloadAIModelsCommand(ChatCommand): - def __init__(self) -> None: - super().__init__( - command_name="reload-models", - authors="", - help_message="Refreshes the list of models by pulling the model definitions from online.", - ) - - def execute(self, **kwargs: Any | None) -> CommandResult | None: - _ = kwargs - - self.logger.info("Reloading AI Models...") - AIModels().load_default_models(self.get_config_value("api_keys")) diff --git a/esbmc_ai/component_loader.py b/esbmc_ai/component_loader.py deleted file mode 100644 index cdd5dab..0000000 --- a/esbmc_ai/component_loader.py +++ /dev/null @@ -1,122 +0,0 @@ -# Author: Yiannis Charalambous - -"""Module contains class for keeping track and managing built-in base -components.""" - - -from typing import Set -import structlog -import re - -from esbmc_ai.chat_command import ChatCommand -from esbmc_ai.base_component import BaseComponent -from esbmc_ai.config_field import ConfigField -from esbmc_ai.singleton import SingletonMeta -from esbmc_ai.verifiers.base_source_verifier import BaseSourceVerifier -from esbmc_ai.log_utils import LogCategories - -_loaded_fields: Set[str] = set() - - -class ComponentLoader(metaclass=SingletonMeta): - """Class for keeping track of and initializing local components. - - Local components are classes derived from BaseComponent that use base - component features (maybe for readability). Built-in commands, built-in - verifiers for example. - - Manages all the verifiers that are used. 
Can get the appropriate one based - on the config.""" - - def __init__(self) -> None: - self._logger: structlog.stdlib.BoundLogger = structlog.get_logger( - self.__class__.__name__ - ).bind(category=LogCategories.SYSTEM) - self.verifiers: dict[str, BaseSourceVerifier] = {} - self._verifier: BaseSourceVerifier | None = None - - self._builtin_commands: dict[str, ChatCommand] = {} - self._addon_commands: dict[str, ChatCommand] = {} - - def load_base_component_config(self, component: BaseComponent) -> None: - """Loads the config fields of a built-in base component, should be called at - execute.""" - from esbmc_ai.config import Config - - # Handle loading config (since this is a built-in module) - fields: list[ConfigField] = component.get_config_fields() - for f in fields: - if f.name in _loaded_fields: - continue - _loaded_fields.add(f.name) - Config().load_config_field(f) - - @property - def verfifier(self) -> BaseSourceVerifier: - """Returns the verifier that is selected.""" - assert self._verifier, "Verifier is not set..." - return self._verifier - - @verfifier.setter - def verifier(self, value: BaseSourceVerifier) -> None: - assert ( - value not in self.verifiers - ), f"Unregistered verifier set: {value.verifier_name}" - self._verifier = value - - def add_verifier(self, verifier: BaseSourceVerifier) -> None: - """Adds a verifier.""" - from esbmc_ai.config import Config - - self.verifiers[verifier.name] = verifier - verifier.config = Config() - - def set_verifier_by_name(self, value: str) -> None: - self.verifier = self.verifiers[value] - self._logger.info(f"Main Verifier: {value}") - - def get_verifier(self, value: str) -> BaseSourceVerifier | None: - return self.verifiers.get(value) - - @property - def commands(self) -> dict[str, ChatCommand]: - """Returns all commands.""" - return self._builtin_commands | self._addon_commands - - @property - def command_names(self) -> list[str]: - """Returns a list of built-in commands. This is a reference to the - internal list.""" - return list(self.commands.keys()) - - @property - def builtin_commands(self) -> dict[str, ChatCommand]: - return self._builtin_commands - - @property - def addon_commands(self) -> dict[str, ChatCommand]: - return self._addon_commands - - def add_command(self, command: ChatCommand, builtin: bool = False) -> None: - if builtin: - self._builtin_commands[command.name] = command - else: - self._addon_commands[command.name] = command - - def set_builtin_commands(self, builtin_commands: list[ChatCommand]) -> None: - """Sets the builtin commands.""" - self._builtin_commands = {cmd.command_name: cmd for cmd in builtin_commands} - - @staticmethod - def parse_command(user_prompt_string: str) -> tuple[str, list[str]]: - """Parses a command and returns it based on the command rules outlined in - the wiki: https://github.com/Yiannis128/esbmc-ai/wiki/User-Chat-Mode""" - regex_pattern: str = ( - r'\s+(?=(?:[^\\"]*(?:\\.[^\\"]*)*)$)|(?:(? 
None: + self._logger: structlog.stdlib.BoundLogger = structlog.get_logger( + self.__class__.__name__ + ).bind(category=LogCategories.SYSTEM) + + # Internal storage + self._builtin_commands: dict[str, ChatCommand] = {} + self._addon_commands: dict[str, ChatCommand] = {} + self._builtin_verifiers: dict[str, BaseSourceVerifier] = {} + self._addon_verifiers: dict[str, BaseSourceVerifier] = {} + self._verifier: BaseSourceVerifier | None = None + + # Cached combined dictionaries for efficiency + self._all_commands_cache: dict[str, ChatCommand] | None = None + self._all_verifiers_cache: dict[str, BaseSourceVerifier] | None = None + self._all_components_cache: dict[str, BaseComponent] | None = None + self._builtin_components_cache: dict[str, BaseComponent] | None = None + self._addon_components_cache: dict[str, BaseComponent] | None = None + + def _invalidate_caches(self) -> None: + """Invalidate all cached combined dictionaries.""" + self._all_commands_cache = None + self._all_verifiers_cache = None + self._all_components_cache = None + self._builtin_components_cache = None + self._addon_components_cache = None + + # ==================== + # Commands API + # ==================== + + @property + def builtin_commands(self) -> MappingProxyType[str, ChatCommand]: + """Returns a read-only view of all builtin commands.""" + return MappingProxyType(self._builtin_commands) + + @property + def addon_commands(self) -> MappingProxyType[str, ChatCommand]: + """Returns a read-only view of all addon commands.""" + return MappingProxyType(self._addon_commands) + + @property + def commands(self) -> MappingProxyType[str, ChatCommand]: + """Returns a read-only view of all commands (builtin + addon).""" + if self._all_commands_cache is None: + self._all_commands_cache = self._builtin_commands | self._addon_commands + return MappingProxyType(self._all_commands_cache) + + @property + def command_names(self) -> list[str]: + """Returns a list of all command names.""" + return list(self.commands.keys()) + + def add_command(self, command: ChatCommand, builtin: bool = False) -> None: + """Add a command to the manager.""" + if builtin: + self._builtin_commands[command.name] = command + else: + self._addon_commands[command.name] = command + self._invalidate_caches() + + def remove_command(self, name: str) -> bool: + """Remove a command by name. 
Returns True if removed, False if not found.""" + removed = False + if name in self._builtin_commands: + del self._builtin_commands[name] + removed = True + if name in self._addon_commands: + del self._addon_commands[name] + removed = True + if removed: + self._invalidate_caches() + return removed + + def get_command(self, name: str) -> ChatCommand | None: + """Get a command by name.""" + return self.commands.get(name) + + def set_builtin_commands(self, builtin_commands: list[ChatCommand]) -> None: + """Sets the builtin commands, replacing any existing builtin commands.""" + self._builtin_commands = {cmd.command_name: cmd for cmd in builtin_commands} + self._invalidate_caches() + + # ==================== + # Verifiers API + # ==================== + + @property + def builtin_verifiers(self) -> MappingProxyType[str, BaseSourceVerifier]: + """Returns a read-only view of all builtin verifiers.""" + return MappingProxyType(self._builtin_verifiers) + + @property + def addon_verifiers(self) -> MappingProxyType[str, BaseSourceVerifier]: + """Returns a read-only view of all addon verifiers.""" + return MappingProxyType(self._addon_verifiers) + + @property + def verifiers(self) -> MappingProxyType[str, BaseSourceVerifier]: + """Returns a read-only view of all verifiers (builtin + addon).""" + if self._all_verifiers_cache is None: + self._all_verifiers_cache = self._builtin_verifiers | self._addon_verifiers + return MappingProxyType(self._all_verifiers_cache) + + @property + def verifier(self) -> BaseSourceVerifier: + """Returns the currently selected verifier.""" + assert self._verifier, "Verifier is not set..." + return self._verifier + + @verifier.setter + def verifier(self, value: BaseSourceVerifier) -> None: + assert value.name in self.verifiers, f"Unregistered verifier set: {value.name}" + self._verifier = value + + def add_verifier(self, verifier: BaseSourceVerifier, builtin: bool = True) -> None: + """Adds a verifier.""" + from esbmc_ai.config import Config + + if builtin: + self._builtin_verifiers[verifier.name] = verifier + else: + self._addon_verifiers[verifier.name] = verifier + verifier.global_config = Config() + self._invalidate_caches() + + def remove_verifier(self, name: str) -> bool: + """Remove a verifier by name. 
Returns True if removed, False if not found.""" + removed = False + if name in self._builtin_verifiers: + del self._builtin_verifiers[name] + removed = True + if name in self._addon_verifiers: + del self._addon_verifiers[name] + removed = True + if removed: + self._invalidate_caches() + return removed + + def set_verifier_by_name(self, value: str) -> None: + """Set the active verifier by name.""" + self.verifier = self.verifiers[value] + self._logger.info(f"Main Verifier: {value}") + + def get_verifier(self, value: str) -> BaseSourceVerifier | None: + """Get a verifier by name.""" + return self.verifiers.get(value) + + # ==================== + # Components API + # ==================== + + @property + def builtin_components(self) -> MappingProxyType[str, BaseComponent]: + """Returns a read-only view of all builtin components (commands + verifiers).""" + if self._builtin_components_cache is None: + self._builtin_components_cache = cast( + dict[str, BaseComponent], self._builtin_commands + ) | cast(dict[str, BaseComponent], self._builtin_verifiers) + return MappingProxyType(self._builtin_components_cache) + + @property + def addon_components(self) -> MappingProxyType[str, BaseComponent]: + """Returns a read-only view of all addon components (commands + verifiers).""" + if self._addon_components_cache is None: + self._addon_components_cache = cast( + dict[str, BaseComponent], self._addon_commands + ) | cast(dict[str, BaseComponent], self._addon_verifiers) + + assert self._addon_components_cache is not None + return MappingProxyType(self._addon_components_cache) + + @property + def components(self) -> MappingProxyType[str, BaseComponent]: + """Returns a read-only view of all components (builtin + addon commands + verifiers).""" + if self._all_components_cache is None: + self._all_components_cache = ( + cast(dict[str, BaseComponent], self._builtin_commands) + | cast(dict[str, BaseComponent], self._addon_commands) + | cast(dict[str, BaseComponent], self._builtin_verifiers) + | cast(dict[str, BaseComponent], self._addon_verifiers) + ) + return MappingProxyType(self._all_components_cache) + + def get_component(self, name: str) -> BaseComponent | None: + """Get a component by name (command or verifier).""" + return self.components.get(name) + + def load_component_config(self, component: BaseComponent, builtin: bool) -> None: + """Load component-specific configuration. + + Component configs are loaded automatically via BaseComponentConfig.settings_customise_sources(), + which loads from TOML, env vars, and .env files. + """ + try: + # Check if component has a config instance set + if component.config is None: + raise NotImplementedError() + + # Get the config class from the existing instance + config_class: type[BaseComponentConfig] = type(component.config) + + # Instantiate the config - settings_customise_sources will handle + # actual loading. Pass component_name via _component_name parameter + # and builtin via _builtin so BaseComponentConfig can use it for + # TOML table header. These are not defined fields but captured via + # extra="ignore" and used in settings_customise_sources. 
+ loaded_config = config_class( # type: ignore[call-arg] + _component_name=component.name, _builtin=builtin + ) + + # Replace the component's config with the loaded one + component.config = loaded_config + self._logger.debug(f"Loaded component config for {component.name}") + + except NotImplementedError: + self._logger.debug(f"No config for component: {component.name}") + + except Exception as e: + self._logger.error( + f"Failed to load config for component {component.name}: {e}" + ) diff --git a/esbmc_ai/config.py b/esbmc_ai/config.py index a3513b5..7d81c23 100644 --- a/esbmc_ai/config.py +++ b/esbmc_ai/config.py @@ -1,669 +1,541 @@ # Author: Yiannis Charalambous 2023 -import argparse +from collections import defaultdict +from importlib.machinery import ModuleSpec +from importlib.util import find_spec import logging import os -import sys -from platform import system as system_name from pathlib import Path -from typing import ( - Any, - override, -) -import argparse -from dotenv import load_dotenv, find_dotenv -import structlog +from pydantic_settings import ( + BaseSettings, + CliPositionalArg, + NoDecode, + PydanticBaseSettingsSource, + SettingsConfigDict, + TomlConfigSettingsSource, +) +from typing import Annotated +from pydantic import ( + AliasChoices, + BaseModel, + BeforeValidator, + DirectoryPath, + Field, + FilePath, + field_validator, +) -from esbmc_ai.chats.base_chat_interface import BaseChatInterface from esbmc_ai.singleton import SingletonMeta, makecls -from esbmc_ai.config_field import ConfigField -from esbmc_ai.base_config import BaseConfig -from esbmc_ai.log_utils import ( +from esbmc_ai.log_handlers import ( CategoryFileHandler, - LogCategories, NameFileHandler, - get_log_level, - init_logging, - set_horizontal_lines, - set_horizontal_line_width, -) -from esbmc_ai.ai_models import ( - AIModel, - AIModelAnthropic, - AIModelOpenAI, - AIModels, - OllamaAIModel, ) -class Config(BaseConfig, metaclass=makecls(SingletonMeta)): - """Config loader for ESBMC-AI""" +def _alias_choice(value: str) -> AliasChoices: + """Adds aliases to each option that requires a different alias that works + with all the config setting sources we are using.""" + return AliasChoices( + value, # exact field name for TOML and direct matching + value.replace("_", "-"), # dashed alias for CLI or other uses + f"ESBMCAI_{value.replace('-', '_').upper()}", # prefixed env var alias + ) - def __init__(self) -> None: - - super().__init__() - - self._args: argparse.Namespace - self._arg_mappings: dict[str, list[str]] = {} - self._compound_load_args: list[str] = [] - self._logger: structlog.stdlib.BoundLogger - - # Huggingface warning supress - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - # Even though this gets initialized when we call self.load_config_fields - # it should be fine because load_config_field wont double add. - self._config_fields = [ - ConfigField( - name="dev_mode", - default_value=False, - help_message="Adds to the python system path the current " - "directory so addons can be developed.", - ), - ConfigField( - name="json", - default_value=False, - help_message="Print the result of the chat command as a JSON output", - ), - ConfigField( - name="show_horizontal_lines", - default_value=True, - on_load=set_horizontal_lines, - help_message="True to print horizontal lines to segment the output. 
" - "Makes it easier to read.", - ), - ConfigField( - name="horizontal_line_width", - default_value=None, - default_value_none=True, - on_load=set_horizontal_line_width, - help_message="Sets the width of the horizontal lines to draw. " - "Don't set a value to use the terminal width. Needs to have " - "show_horizontal_lines set to true.", - ), - ConfigField( - name="ai_custom", - default_value=[], - on_read=self.load_custom_ai, - error_message="Invalid custom AI specification", - ), - ConfigField( - name="llm_requests.model_refresh_seconds", - # Default is to refresh once a day - default_value=86400, - validate=lambda v: isinstance(v, int), - help_message="How often to refresh the models list provided by OpenAI. " - "Make sure not to spam them as they can IP block. Default is once a day.", - error_message="Invalid value, needs to be an int in seconds", - ), - ConfigField( - name="llm_requests.cooldown_seconds", - default_value=0.0, - validate=lambda v: 0 <= v, - help_message="Cooldown applied in seconds between LLM requests.", - ), - ConfigField( - name="ai_model", - default_value=None, - validate=lambda v: isinstance(v, str), - help_message="Which AI model to use.", - ), - ConfigField( - name="temp_auto_clean", - default_value=True, - validate=lambda v: isinstance(v, bool), - help_message="Should the temporary files created be cleared automatically?", - ), - ConfigField( - name="temp_file_dir", - default_value=None, - validate=lambda v: isinstance(v, str) and Path(v).is_file(), - on_load=Path, - default_value_none=True, - help_message="Sets the directory to store temporary ESBMC-AI files. " - "Don't supply a value to use the system default.", - ), - ConfigField( - name="allow_successful", - default_value=False, - validate=lambda v: isinstance(v, bool), - help_message="Run the ESBMC-AI command even if the verifier has not " - "found any problems.", - ), - ConfigField( - name="loading_hints", - default_value=False, - validate=lambda v: isinstance(v, bool), - help_message="Show loading hints when running. Turn off if output " - "is going to be logged to a file.", - ), - ConfigField( - name="generate_patches", - default_value=False, - help_message="Should the repaired result be returned as a patch " - "instead of a new file. Generate patch files and place them in " - "the same folder as the source files.", - ), - ConfigField( - name="log.output", - default_value=None, - default_value_none=True, - validate=lambda v: Path(v).exists() or Path(v).parent.exists(), - on_load=lambda v: Path(v), - help_message="Save the output logs to a location. Do not add " - ".log suffix, it will be added automatically.", - ), - ConfigField( - name="log.append", - default_value=False, - help_message="Will append to the logs rather than replace them.", - ), - ConfigField( - name="log.by_cat", - default_value=False, - help_message="Will split the logs by category and write them to" - " different files. They will have the same base log.output path" - " but will have an extension to differentiate them.", - ), - ConfigField( - name="log.by_name", - default_value=False, - help_message="Will split the logs by name and write them to" - " different files. They will have the same base log.output path" - " but will have an extension to differentiate them.", - ), - ConfigField( - name="log.basic", - default_value=False, - help_message="Enable basic logging mode, will contain no " - "formatting and also will render --log-by-name (log.by_name) " - "and --log-by-cat (log.by_cat) useless. 
Used for debugging " - "noisy libs.", - ), - # This is the parameters that the user passes as args which are the - # file names of the source code to target. It can also be a directory. - ConfigField( - name="solution.filenames", - default_value=[], - validate=lambda v: isinstance(v, list) - and ( - len(v) == 0 - or all(isinstance(f, str) and Path(f).exists() for f in v) - ), - on_load=self._filenames_load, - get_error_message=self._filenames_error_msg, - ), - ConfigField( - name="solution.include_dirs", - default_value=[], - validate=lambda v: isinstance(v, list) - and all( - isinstance(f, str) and Path(f).exists() and Path(f).is_dir() - for f in v - ), - help_message="Include directories for C files.", - on_load=lambda v: [Path(path) for path in v], - ), - # If argument is passed, then the config value is ignored. - ConfigField( - name="solution.entry_function", - default_value="main", - error_message="The entry function name needs to be a string", - help_message="The name of the entry function to repair, defaults to main.", - ), - ConfigField( - name="solution.output_dir", - default_value=None, - default_value_none=True, - validate=lambda v: Path(v).exists() and Path(v).is_dir(), - on_load=lambda v: Path(v).expanduser(), - error_message="Dir does not exist", - help_message="Set the output directory to save successfully repaired " - "files in. Leave empty to not use. Specifying the same directory will " - "overwrite the original file.", - ), - # The value is checked in AddonLoader. - ConfigField( - name="verifier.name", - default_value="esbmc", - validate=lambda v: isinstance(v, str), - error_message="Invalid verifier name specified.", - help_message="The verifier to use. Default is ESBMC.", - ), - ConfigField( - name="verifier.enable_cache", - default_value=True, - help_message="Cache the results of verification in order to save time. " - "This is not supported by all verifiers.", - ), - ConfigField( - name="verifier.esbmc.path", - default_value=None, - validate=lambda v: isinstance(v, str) - and Path(v).expanduser().is_file(), - on_load=lambda v: Path(os.path.expanduser(os.path.expandvars(v))), - help_message="Path to the ESBMC binary.", - ), - ConfigField( - name="verifier.esbmc.params", - default_value=[ - "--interval-analysis", - "--goto-unwind", - "--unlimited-goto-unwind", - "--k-induction", - "--state-hashing", - "--add-symex-value-sets", - "--k-step", - "2", - "--floatbv", - "--unlimited-k-steps", - "--compact-trace", - "--context-bound", - "2", - ], - validate=lambda v: isinstance(v, list | str), - on_load=lambda v: [ - str(arg) for arg in (v.split(" ") if isinstance(v, str) else v) - ], - help_message="Parameters for ESBMC. Can accept as a list or a string.", - ), - ConfigField( - name="verifier.esbmc.timeout", - default_value=None, - default_value_none=True, - validate=lambda v: v is None or isinstance(v, int), - help_message="The timeout set for ESBMC.", - ), - ConfigField( - name="llm_requests.max_tries", - default_value=5, - validate=lambda v: isinstance(v, int), - help_message="How many times to query the AI service before giving up.", - ), - ConfigField( - name="llm_requests.timeout", - default_value=60, - validate=lambda v: isinstance(v, int), - help_message="The timeout for querying the AI service.", - ), - ] - def load( - self, - args: Any, - arg_mapping_overrides: dict[str, list[str]], - compound_load_args: list[str], - ) -> None: - """Begins the loading procedure to load the configuration. - - 1. Environment variables - 2. 
Arguments (the arguments are mapped 1:1 to the config file, if there - is a difference in mapping it is specified in the arg_mapping_overrides). - Some config fields should be loaded when the config file has been loaded - too, so defer_loading_args should have a list of argument names to skip - and defer the loading behaviour to the config file loading. This process - is not automated, so the respective ConfigField should specify how to - load from the args. - - Args: - - args: The values from the argument parser. They will take precedence - over the config file. - - arg_mapping_overrides: A dictionary that maps configuration field - IDs to custom argument names. Use this to override the default - mapping, allowing arguments to have names different from their - corresponding configuration fields. - - compound_load_fields: A list of argument names that, - when supplied, should be loaded from both command-line arguments - and a config file. - """ - - self._args = args - # Create an argument mapping of all the config files, the arg_mapping is - # applied over that to translate the mappings from the arguments to the - # mappings of the config fields. - self._arg_mappings = { - f.name: [f.name] for f in self.get_config_fields() - } | arg_mapping_overrides - self._compound_load_args = compound_load_args - - # Init logging - init_logging(level=get_log_level(args.verbose)) - self._logger = structlog.get_logger().bind(category=LogCategories.CONFIG) - - # Load config fields from environment - self._load_envs( - ConfigField.from_env( - name="ESBMCAI_CONFIG_FILE", - default_value=None, - on_load=lambda v: Path(os.path.expanduser(os.path.expandvars(str(v)))), - ), - ConfigField.from_env( - name="OPENAI_API_KEY", - default_value=None, - default_value_none=True, - ), - ConfigField.from_env( - name="ANTHROPIC_API_KEY", - default_value=None, - default_value_none=True, - ), - ) +class AIModelConfig(BaseModel): + id: str = Field( + default="openai:gpt-5-nano", + validation_alias=_alias_choice("ai_model"), + description="Which AI model to use. Prefix with openai, anthropic, or " + "ollama then separate with : and enter the model name to use.", + ) + + base_url: str | None = Field( + default=None, + exclude=True, + description="Gets initialized by the config if this model is an Ollama model.", + ) + + @property + def provider(self) -> str: + """The provider part of a model string.""" + return self.id.split(":", maxsplit=1)[0] - fields: list[ConfigField] = self.get_config_fields() - self._load_args(args, fields) - # Base init needs to be called last - self.load_config_fields(self.get_value("ESBMCAI_CONFIG_FILE"), fields) + @property + def name(self) -> str: + """The name part of a model string.""" + return self.id.split(":", maxsplit=1)[1] + + +class AICustomModelConfig(BaseModel): + server_type: str + url: str + max_tokens: int + + +class LogConfig(BaseModel): + output: FilePath | None = Field( + default=None, + description="Save the output logs to a location. Do not add " + ".log suffix, it will be added automatically.", + ) + + append: bool = Field( + default=False, + description="Will append to the logs rather than replace them.", + ) + + by_cat: bool = Field( + default=False, + description="Will split the logs by category and write them to" + " different files. They will have the same base log.output path" + " but will have an extension to differentiate them.", + ) + + by_name: bool = Field( + default=False, + description="Will split the logs by name and write them to" + " different files. 
They will have the same base log.output path" + " but will have an extension to differentiate them.", + ) + + basic: bool = Field( + default=False, + description="Enable basic logging mode, will contain no " + "formatting and also will render --log-by-name (log.by_name) " + "and --log-by-cat (log.by_cat) useless. Used for debugging " + "noisy libs.", + ) - # =============== Post Init - Set to good values to fields ============ - # Add logging handlers with config options + @property + def logging_handlers(self) -> list[logging.Handler]: logging_handlers: list[logging.Handler] = [] - if self.get_value("log.output"): - log_path: Path = self.get_value("log.output") + if self.output: # Log categories - if self.get_value("log.by_cat"): + if self.by_cat: logging_handlers.append( CategoryFileHandler( - log_path, - append=self.get_value("log.append"), + self.output, + append=self.append, skip_uncategorized=True, ) ) # Log by name - if self.get_value("log.by_name"): + if self.by_name: logging_handlers.append( NameFileHandler( - log_path, - append=self.get_value("log.append"), + self.output, + append=self.append, ) ) # Normal logging file_log_handler: logging.Handler = logging.FileHandler( - str(log_path) + ".log", - mode="a" if self.get_value("log.append") else "w", + str(self.output) + ".log", + mode="a" if self.append else "w", ) logging_handlers.append(file_log_handler) + return logging_handlers - # Reinit logging - init_logging( - level=get_log_level(args.verbose), - file_handlers=logging_handlers, - init_basic=self.get_value("log.basic"), - ) - self.set_custom_field( - ConfigField( - name="api_keys", - default_value={}, - ), - value={ - AIModelOpenAI.get_canonical_name(): self.get_value("OPENAI_API_KEY"), - AIModelAnthropic.get_canonical_name(): self.get_value( - "ANTHROPIC_API_KEY" - ), - }, - ) - # Load AI models and set ai_model - AIModels().load_default_models( - self.get_value("api_keys"), - self.get_value("llm_requests.model_refresh_seconds"), - ) - self.set_value("ai_model", AIModels().get_ai_model(self.get_value("ai_model"))) - # BaseChatInterface cooldown - BaseChatInterface.cooldown_total = self.get_value( - "llm_requests.cooldown_seconds" - ) - self._logger.debug(f"LLM Cooldown Total: {BaseChatInterface.cooldown_total}") - - def _load_envs(self, *fields: ConfigField) -> None: - """Loads the environment variables. - Environment variables are loaded in the following order: - - 1. Environment variables already loaded. Any variable not present will be looked for in - .env files in the following locations. - 2. .env file in the current directory, moving upwards in the directory tree. - 3. esbmc-ai.env file in the current directory, moving upwards in the directory tree. - 4. esbmc-ai.env file in $HOME/.config/ for Linux/macOS and %userprofile% for Windows. - - Note: ESBMCAI_CONFIG_FILE undergoes tilde user expansion and also environment - variable expansion. - """ - - for field in fields: - assert ( - field.on_read is None - ), f"ConfigField on_read for envs is not supported: {field.name}" - - keys: dict[str, ConfigField] = {field.name: field for field in fields} - - def get_env_vars() -> None: - """Gets all the system environment variables that are currently in the env - and loads them. 
Will only load keys that have not already been loaded.""" - for field_name, field in keys.items(): - value: str | None = os.getenv(field_name) - - # Assign field from config file - self._values[field_name] = field.on_load(value) - - # Load from system env - get_env_vars() - - # Find .env in current working directory and load it. - dotenv_file_path: str = find_dotenv(usecwd=True) - if dotenv_file_path != "": - load_dotenv(dotenv_path=dotenv_file_path, override=False, verbose=True) - else: - # Find esbmc-ai.env in current working directory and load it. - dotenv_file_path: str = find_dotenv(filename="esbmc-ai.env", usecwd=True) - if dotenv_file_path != "": - load_dotenv(dotenv_path=dotenv_file_path, override=False, verbose=True) - - get_env_vars() - - # Look for .env in home folder. - home_path: Path = Path.home() - match system_name(): - case "Linux" | "Darwin": - home_path /= ".config/esbmc-ai.env" - case "Windows": - home_path /= "esbmc-ai.env" - case _: - raise ValueError(f"Unknown OS type: {system_name()}") - - load_dotenv(home_path, override=False, verbose=True) - - get_env_vars() - - # Check all field values are set, if they aren't then error. - for field_name, field in keys.items(): - # If field name is not loaded in the end... - if field_name not in self._values: - if field.default_value is None and not field.default_value_none: - print(f"Error: No ${field_name} in environment.") - sys.exit(1) - self._values[field_name] = field.default_value - - # Validate field - value: Any = self._values[field_name] - if not field.validate(value): - msg = f"Field: {field.name} is invalid: {value}" - if field.get_error_message is not None: - msg += ": " + field.get_error_message(value) - elif field.error_message: - msg += ": " + field.error_message - raise ValueError(f"Env loading error: {msg}") - - self._fields.extend(fields) +class SolutionConfig(BaseModel): + filenames: CliPositionalArg[list[Path]] = Field( + default_factory=list, + description="The filename(s) to pass to the verifier.", + ) - @property - def _arg_reverse_mappings(self) -> dict[str, str]: - """Returns the reverse mapping of arg mappings. args --> field - - This also replaces all the - with _ in the names since this is done by - argparse, for example: --ai-model will be accessible through ai_model.""" - reverse_mappings: dict[str, str] = {} - for field_name, mappings in self._arg_mappings.items(): - for mapping in mappings: - reverse_mappings[mapping.replace("-", "_")] = field_name - - return reverse_mappings - - def _load_args(self, args: argparse.Namespace, fields: list[ConfigField]) -> None: - """Will load the fields set in the program arguments.""" - - # Track the names of the fields set, fields that are already set are - # skipped: --ai-models and -m are treated as one this way. - fields_set: set[str] = set() - reverse_mappings: dict[str, str] = self._arg_reverse_mappings - fields_mapped: dict[str, ConfigField] = {f.name: f for f in fields} - for mapped_name, value in vars(args).items(): - # Check if a field is set in args - if mapped_name in reverse_mappings: - # Get the field name - field_name: str = reverse_mappings[mapped_name] - # Skip if added - if field_name in fields_set: - continue - - self._logger.debug(f"Loading from arg: {field_name}") - - fields_set.add(field_name) - # Load with value. 
-                self.set_custom_field(fields_mapped[field_name], value)
-
-    @override
-    def load_config_fields(self, cfg_path: Path, fields: list[ConfigField]) -> None:
-        """Override to only load fields that have not been loaded by the args."""
-
-        # Track the names of the fields set, fields that are already set are
-        # skipped: --ai-models and -m are treated as one this way.
-        fields_set: set[str] = set()
-        reverse_mappings: dict[str, str] = self._arg_reverse_mappings
-
-        for mapped_name in vars(self._args).keys():
-            # Check if a field is set in args (exclude if in compound load list)
-            if (
-                mapped_name in reverse_mappings
-                and mapped_name not in self._compound_load_args
-            ):
-                # Get the field name
-                fields_set.add(reverse_mappings[mapped_name])
-
-        load_fields: list[ConfigField] = [f for f in fields if f.name not in fields_set]
-        return super().load_config_fields(cfg_path, load_fields)
-
-    def _filenames_load(self, file_names: list[str]) -> list[Path]:
-        """Loads the filenames from the command line first then from the config."""
-
-        results: list[Path] = []
-
-        if len(self._args.filenames):
-            results.extend(Path(f) for f in self._args.filenames)
-
-        for file in file_names:
-            results.append(Path(file))
-
-        return results
-
-    def _validate_custom_ai(self, ai_config_list: dict) -> bool:
-        for name, ai_config in ai_config_list.items():
-            # Check the field is a dict not a list
-            if not isinstance(ai_config, dict):
-                raise ValueError(
-                    f"The value of each entry in ai_custom needs to be a dict: {ai_config}"
-                )
+    @field_validator("filenames", mode="after")
+    @classmethod
+    def on_set_filenames(cls, value: list[Path]) -> list[Path]:
+        """Validates that filenames are either all file paths or a single directory path."""
+        if not value:
+            return []
-            # Max tokens
-            if "max_tokens" not in ai_config:
-                raise KeyError(
-                    f'max_tokens field not found in "ai_custom" entry "{name}".'
-                )
-            elif not isinstance(ai_config["max_tokens"], int):
-                raise TypeError(
-                    f'custom_ai_max_tokens in ai_custom entry "{name}" needs to '
-                    "be an int and greater than 0."
-                )
-            elif ai_config["max_tokens"] <= 0:
+        paths: list[Path] = []
+        files: list[str] = []
+        dirs: list[str] = []
+
+        for filename in value:
+            path = Path(filename).expanduser()
+
+            # Check if path exists
+            if not path.exists():
+                raise ValueError(f"File or directory not found: '{filename}'")
+
+            # Categorize as file or directory (keep the names as strings so
+            # they can be joined into error messages)
+            if path.is_file():
+                files.append(str(filename))
+                paths.append(path)
+            elif path.is_dir():
+                dirs.append(str(filename))
+                paths.append(path)
+            else:
-                raise ValueError(
-                    f'custom_ai_max_tokens in ai_custom entry "{name}" needs to '
-                    "be an int and greater than 0."
-                )
-
-            # URL
-            if "url" not in ai_config:
-                raise KeyError(f'url field not found in "ai_custom" entry "{name}".')
-
-            # Server type
-            if "server_type" not in ai_config:
-                raise KeyError(
-                    f"server_type for custom AI '{name}' is invalid, it needs to be a valid string"
-                )
+                raise ValueError(
+                    f"Path exists but is neither a file nor directory: '{filename}'"
+                )
-        return True
-
-    def load_custom_ai(self, config_file: dict) -> list[AIModel]:
-        """Loads custom AI defined in the config and ascociates it with the AIModels
-        module."""
+        # Validate: either all files OR single directory
+        if dirs and files:
+            raise ValueError(
+                f"Cannot mix files and directories.
Either provide file paths or a single directory.\n" + f" Files provided: {', '.join(files)}\n" + f" Directories provided: {', '.join(dirs)}" + ) - if "ai_custom" not in config_file: - return [] + if len(dirs) > 1: + raise ValueError( + f"Only one directory can be specified. Got {len(dirs)} directories: {', '.join(dirs)}" + ) - ai_config_list: dict = config_file["ai_custom"] - self._validate_custom_ai(ai_config_list) + return paths + + include_dirs: list[DirectoryPath] = Field( + default_factory=list, + description="Include directories for C files.", + ) + + entry_function: str = Field( + default="main", + description="The name of the entry function to repair, defaults to main.", + ) + + output_dir: DirectoryPath | None = Field( + default=None, + description="Set the output directory to save successfully repaired " + "files in. Leave empty to not use. Specifying the same directory will " + "overwrite the original file.", + ) + + @field_validator("output_dir", mode="before") + @classmethod + def on_set_output_dir(cls, value: DirectoryPath | None) -> DirectoryPath | None: + if value is None: + return None + return Path(value).expanduser() + + +class ESBMCConfig(BaseModel): + path: FilePath | None = Field( + default=None, + description="Path to the ESBMC binary.", + ) + + @field_validator("path", mode="before") + @classmethod + def on_set_path(cls, value: FilePath | None) -> Path | None: + if value is None: + return None + return Path(value).expanduser() + + params: list[str] = Field( + default=[ + "--interval-analysis", + "--goto-unwind", + "--unlimited-goto-unwind", + "--k-induction", + "--state-hashing", + "--add-symex-value-sets", + "--k-step", + "2", + "--floatbv", + "--unlimited-k-steps", + "--compact-trace", + "--context-bound", + "2", + ], + description="Parameters for ESBMC. Can accept as a list or a string.", + ) + + timeout: int | None = Field( + default=None, + description="The timeout set for ESBMC.", + ) + + +class VerifierConfig(BaseModel): + # The value is checked in AddonLoader. + name: str = Field( + default="esbmc", + description="The verifier to use. Default is ESBMC.", + ) + + enable_cache: bool = Field( + default=True, + description="Cache the results of verification in order to save time. " + "This is not supported by all verifiers.", + ) + + esbmc: ESBMCConfig = Field( + default_factory=ESBMCConfig, + description="ESBMC-specific configuration.", + ) + + +def _parse_ai_model(value: str | dict | AIModelConfig) -> AIModelConfig: + """Validator function to convert string/dict to AIModelConfig.""" + # If it's already an AIModelConfig, return as-is + if isinstance(value, AIModelConfig): + return value + # If it's a string, create AIModelConfig with validation_alias key + if isinstance(value, str): + return AIModelConfig(**{"ai_model": value}) + # If it's a dict, create AIModelConfig from it + return AIModelConfig(**value) + + +class Config(BaseSettings, metaclass=makecls(SingletonMeta)): + """Config loader for ESBMC-AI""" - custom_ai: list[AIModel] = [] + config_file: FilePath | None = Field( + default=None, + exclude=True, + description="Path to configuration file (TOML format). Can be set via " + "ESBMCAI_CONFIG_FILE environment variable.", + ) + + command_name: CliPositionalArg[str] = Field( + default="help", + exclude=True, + description="The (sub-)command to run.", + ) + + addon_modules: list[str] = Field( + default_factory=list, + description="The addon modules to load during startup. 
Additional " + "modules may be loaded by the specified modules as dependencies.", + ) + + @field_validator("addon_modules", mode="after") + @classmethod + def on_set_addon_modules(cls, mods: list[str]) -> list[str]: + """Validates that a module exists.""" + for m in mods: + if not isinstance(m, str): + raise ValueError("Needs to be a string") + spec: ModuleSpec | None = find_spec(m) + if spec is None: + raise ValueError("Could not find specification for module") + return mods + + verbose_level: int = Field( + default=0, + ge=0, + le=3, + exclude=True, # Exclude from Pydantic CLI parsing + description="Show up to 3 levels of verbose output. Level 1: extra information." + " Level 2: show failed generations, show ESBMC output. Level 3: " + "print hidden pushes to the message stack.", + ) + + dev_mode: bool = Field( + default=False, + validation_alias=_alias_choice("dev_mode"), + description="Adds to the python system path the current " + "directory so addons can be developed.", + ) + + use_json: bool = Field( + default=False, + alias="json", + description="Print the result of the chat command as a JSON output", + ) + + show_horizontal_lines: bool = Field( + default=True, + validation_alias=_alias_choice("show_horizontal_lines"), + description="True to print horizontal lines to segment the output. " + "Makes it easier to read.", + ) + + horizontal_line_width: int | None = Field( + default=None, + validation_alias="horizontal_line_width", + description="Sets the width of the horizontal lines to draw. " + "Don't set a value to use the terminal width. Needs to have " + "show_horizontal_lines set to true.", + ) + + ai_custom: dict[str, AICustomModelConfig] = Field( + default_factory=defaultdict, + description=( + "Dictionary of Ollama AI models configurations. " + "Each key is a model name (arbitrary string), and each value is a " + "structure defining the model with the following fields:\n" + " - server_type: string, must be 'ollama'\n" + " - url: string specifying the Ollama service URL\n" + " - max_tokens: integer specifying the token limit" + ), + ) + + @field_validator("ai_custom", mode="after") + @classmethod + def _validate_custom_ai( + cls, ai_config_list: dict[str, AICustomModelConfig] + ) -> dict[str, AICustomModelConfig]: for name, ai_config in ai_config_list.items(): - # Load the max tokens - max_tokens: int = ai_config["max_tokens"] - - # Load the URL - url: str = ai_config["url"] - - # Get provider type - server_type = ai_config["server_type"] - - # Create correct type of LLM - llm: AIModel - match server_type: - case "ollama": - llm = OllamaAIModel( - name=name, - tokens=max_tokens, - url=url, + # If already an AICustomModelConfig instance, validate its fields + if isinstance(ai_config, AICustomModelConfig): + if ai_config.max_tokens <= 0: + raise ValueError( + f'max_tokens in ai_custom entry "{name}" needs to be greater than 0.' ) - case _: - raise NotImplementedError( - f"The custom AI server type is not implemented: {server_type}" + else: + # If it's a dict, validate the dict structure + if not isinstance(ai_config, dict): + raise ValueError( + f"The value of each entry in ai_custom needs to be a dict: {ai_config}" ) - # Add the custom AI. - custom_ai.append(llm) - AIModels().add_ai_model(llm) - - return custom_ai + # Max tokens + if "max_tokens" not in ai_config: + raise KeyError( + f'max_tokens field not found in "ai_custom" entry "{name}".' 
+ ) + elif not isinstance(ai_config["max_tokens"], int): + raise TypeError( + f'max_tokens in ai_custom entry "{name}" needs to be an int and greater than 0.' + ) + elif ai_config["max_tokens"] <= 0: + raise ValueError( + f'max_tokens in ai_custom entry "{name}" needs to be greater than 0.' + ) - @staticmethod - def _filenames_error_msg(file_names: list) -> str: - """Gets the error message for an invalid list of file_names specified in - the config.""" + # URL + if "url" not in ai_config: + raise KeyError( + f'url field not found in "ai_custom" entry "{name}".' + ) - wrong: list[str] = [] - for file_name in file_names: - if not isinstance(file_name, str) or not ( - Path(file_name).is_file() or Path(file_name).is_dir() - ): - wrong.append(file_name) + # Server type + if "server_type" not in ai_config: + raise KeyError( + f"server_type for custom AI '{name}' is invalid, it needs to be a valid string" + ) - # Don't return the error because if there's too many files, the message - # get's truncated. - print("The following files or directories cannot be found:") - for f in wrong: - print("\t*", f) + return ai_config_list - return "Error while loading files..." + ai_model: Annotated[AIModelConfig, NoDecode, BeforeValidator(_parse_ai_model)] = ( + Field( + default_factory=AIModelConfig, + description="AI Model configuration group. Can be a string like 'openai:gpt-4' or a config object.", + ) + ) + + temp_auto_clean: bool = Field( + default=True, + validation_alias=_alias_choice("temp_auto_clean"), + description="Should the temporary files created be cleared automatically?", + ) + + temp_file_dir: DirectoryPath = Field( + default=Path("/tmp"), + validation_alias=_alias_choice("temp_file_dir"), + description="Sets the directory to store temporary ESBMC-AI files. " + "Don't supply a value to use the system default.", + ) + + loading_hints: bool = Field( + default=False, + validation_alias=_alias_choice("loading_hints"), + description="Show loading hints when running. Turn off if output " + "is going to be logged to a file.", + ) + + generate_patches: bool = Field( + default=False, + validation_alias=_alias_choice("generate_patches"), + description="Should the repaired result be returned as a patch " + "instead of a new file. 
Generate patch files and place them in " "the same folder as the source files.",
+    )
+
+    log: LogConfig = Field(
+        default_factory=LogConfig,
+        description="Logging configuration group.",
+    )
+
+    solution: SolutionConfig = Field(
+        default_factory=SolutionConfig,
+        description="Solution config group.",
+    )
+
+    verifier: VerifierConfig = Field(
+        default_factory=VerifierConfig,
+        description="Verifier config group.",
+    )
+
+    llm_requests_max_retries: int = Field(
+        default=5,
+        ge=1,
+        validation_alias="llm_requests.max_retries",
+        description="How many times to query the AI service before giving up.",
+    )
+
+    llm_requests_timeout: int = Field(
+        default=60,
+        ge=1,
+        validation_alias="llm_requests.timeout",
+        description="The timeout for querying the AI service.",
+    )
+
+    model_config = SettingsConfigDict(
+        env_prefix="ESBMCAI_",
+        env_file=".env",
+        env_file_encoding="utf-8",
+        # Enables CLI parse support out-of-the-box for CliApp.run integration
+        cli_parse_args=True,
+        # Allow extra fields for compatibility with pydantic_settings, since
+        # .env contains fields for services like OpenAI and Anthropic that
+        # aren't mapped to the config directly
+        extra="ignore",
+    )
+
+    @classmethod
+    def settings_customise_sources(
+        cls,
+        settings_cls: type["BaseSettings"],
+        init_settings: PydanticBaseSettingsSource,
+        env_settings: PydanticBaseSettingsSource,
+        dotenv_settings: PydanticBaseSettingsSource,
+        file_secret_settings: PydanticBaseSettingsSource,
+    ) -> tuple[PydanticBaseSettingsSource, ...]:
+        # Manually load .env file to get ESBMCAI_CONFIG_FILE before creating sources
+        from dotenv import load_dotenv
+        load_dotenv(".env", override=False)
+
+        # Get config file path from environment variable
+        config_file_path = os.getenv("ESBMCAI_CONFIG_FILE")
+
+        sources: list[PydanticBaseSettingsSource] = [
+            init_settings,
+            env_settings,
+            dotenv_settings,
+        ]
-    def get_config_fields(self) -> list[ConfigField]:
-        """Returns the config fields. Excluding env fields."""
-        return self._config_fields
+        # Add TOML config source if config file is specified
+        if config_file_path:
+            config_file = Path(config_file_path).expanduser()
+            if config_file.exists():
+                sources.append(TomlConfigSettingsSource(settings_cls, config_file))
+
+        # Priority order: init/CLI > env > dotenv > TOML > file_secret
+        sources.append(file_secret_settings)
+        return tuple(sources)
+
+    def __init__(self, **kwargs) -> None:
+        # Accept keyword arguments to support pydantic_settings CliApp.run()
+        # which passes settings_sources and other configuration parameters.
+        # The singleton metaclass intercepts the constructor call and needs
+        # to pass these arguments through to BaseSettings.__init__().
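+        # Sketch of the expected call path (an assumption, based on
+        # pydantic-settings' CliApp behaviour and the SingletonMeta
+        # semantics used here):
+        #
+        #     cfg = CliApp.run(Config)  # parses CLI, env, .env, and TOML once
+        #     assert Config() is cfg    # later calls reuse the same instance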
+ super().__init__(**kwargs) + + # Populate ai_model.base_url from ai_custom if model exists there + if self.ai_model.id in self.ai_custom: + self.ai_model.base_url = self.ai_custom[self.ai_model.id].url + + @classmethod + def set_singleton(cls, config: "Config") -> None: + """Sets the singleton.""" + getattr(cls, "_instances")[cls] = config diff --git a/esbmc_ai/config_field.py b/esbmc_ai/config_field.py deleted file mode 100644 index 56e5439..0000000 --- a/esbmc_ai/config_field.py +++ /dev/null @@ -1,73 +0,0 @@ -# Author: Yiannis Charalambous - -"""This module can be used by other modules to declare config entries.""" - -from typing import ( - Any, - Callable, - NamedTuple, -) - - -class ConfigField(NamedTuple): - """Represents a loadable entry in the config.""" - - name: str - """The name of the config field and also namespace""" - default_value: Any - """If a default value is supplied, then it can be omitted from the config. - In order to have a "None" default value, default_value_none must be set.""" - default_value_none: bool = False - """If true, then the default value will be None, so during - validation, if no value is supplied, then None will be the - the default value, instead of failing due to None being the - default value which under normal circumstances means that the - field is not optional.""" - validate: Callable[[Any], bool] = lambda _: True - """Lambda function to validate if field has a valid value. - Default is identity function which is return true.""" - on_load: Callable[[Any], Any] = lambda v: v - """Transform the value once loaded, this allows the value to be saved - as a more complex type than that which is represented in the config - file. - - Is ignored if on_read is defined.""" - on_read: Callable[[dict[str, Any]], Any] | None = None - """If defined, will be called and allows to custom load complex types that - may not match 1-1 in the config. The config file passed as a parameter here - is the original, unflattened version. The value returned should be the value - assigned to this field. - - This is a more versatile version of on_load. So if this is used, the on_load - will be ignored.""" - help_message: str | None = None - error_message: str | None = None - """Optional string to provide a generic error message.""" - get_error_message: Callable[[Any], str] | None = None - """Optionsl function to get more verbose output than error_message.""" - - @staticmethod - def from_env( - name: str, - default_value: Any, - default_value_none: bool = False, - validate: Callable[[Any], bool] = lambda _: True, - on_load: Callable[[Any], Any] = lambda v: v, - help_message: str | None = None, - error_message: str | None = None, - get_error_message: ( - Callable[[Any], str] | None - ) = lambda v: f"Error: No ${v} in environment.", - ) -> "ConfigField": - """Defines an env var loaded from the environment. 
Will prefix name with
-        'env.'"""
-        return ConfigField(
-            name=name,
-            default_value=default_value,
-            default_value_none=default_value_none,
-            validate=validate,
-            on_load=on_load,
-            help_message=help_message,
-            error_message=error_message,
-            get_error_message=get_error_message,
-        )
diff --git a/esbmc_ai/log_categories.py b/esbmc_ai/log_categories.py
new file mode 100644
index 0000000..9993444
--- /dev/null
+++ b/esbmc_ai/log_categories.py
@@ -0,0 +1,13 @@
+# Author: Yiannis Charalambous
+
+from enum import Enum
+
+
+class LogCategories(Enum):
+    NONE = "NONE"
+    ALL = "ALL"
+    SYSTEM = "ESBMC_AI"
+    VERIFIER = "VERIFIER"
+    COMMAND = "COMMAND"
+    CONFIG = "CONFIG"
+    CHAT = "CHAT"
diff --git a/esbmc_ai/log_handlers.py b/esbmc_ai/log_handlers.py
new file mode 100644
index 0000000..0b42f97
--- /dev/null
+++ b/esbmc_ai/log_handlers.py
@@ -0,0 +1,120 @@
+# Author: Yiannis Charalambous
+
+from enum import Enum
+from pathlib import Path
+from typing import override
+import logging
+import re
+
+from esbmc_ai.log_categories import LogCategories
+
+_ansi_escape = re.compile(r"\x1b\[[0-9;]*m")
+
+
+def _strip_ansi_escape_processor(record: logging.LogRecord) -> bool | logging.LogRecord:
+    """
+    Remove ANSI escape sequences from all string values in the LogRecord.
+    """
+    for attr, value in list(record.__dict__.items()):
+        if isinstance(value, str):
+            setattr(record, attr, _ansi_escape.sub("", value))
+
+    args = record.args
+    if isinstance(args, tuple):
+        record.args = tuple(
+            _ansi_escape.sub("", v) if isinstance(v, str) else v for v in args
+        )
+    elif isinstance(args, dict):
+        record.args = {
+            k: (_ansi_escape.sub("", v) if isinstance(v, str) else v)
+            for k, v in args.items()
+        }
+
+    return record
+
+
+class CategoryFileHandler(logging.Handler):
+    """Logger that will save by category."""
+
+    def __init__(
+        self,
+        base_path: Path,
+        append: bool = False,
+        skip_uncategorized: bool = False,
+    ) -> None:
+        super().__init__()
+        self.base_path = base_path
+        self.append = append
+        self.skip_uncategorized = skip_uncategorized
+        self.handlers: dict[str, logging.FileHandler] = {}
+        self.addFilter(_strip_ansi_escape_processor)
+
+    def emit(self, record: logging.LogRecord) -> None:
+        # Grab the category (because of wrap_for_formatter)
+        # Try attribute first (for stdlib logging)
+        category: str | Enum | None = getattr(record, "category", None)
+
+        # If not present, try to get it from record.msg (structlog dict)
+        if category is None and isinstance(record.msg, dict):
+            category = record.msg.get("category", None)
+
+        # Convert to string
+        if isinstance(category, Enum):
+            category = category.value
+
+        assert category is None or isinstance(category, str)
+
+        # Skip uncategorized if desired
+        if (
+            not category or category == LogCategories.NONE.value
+        ) and self.skip_uncategorized:
+            return
+
+        # None category is a catch all
+        if not category:
+            category = LogCategories.NONE.value
+
+        # Write ALL category
+        if category == LogCategories.ALL.value:
+            for h in self.handlers.values():
+                h.emit(record)
+            return
+
+        # Lazily build the per‐category FileHandler
+        if category not in self.handlers:
+            fn: Path = Path(f"{self.base_path}-{category}.log")
+            fh: logging.FileHandler = logging.FileHandler(
+                fn, mode="a" if self.append else "w"
+            )
+            fh.setFormatter(self.formatter)
+            self.handlers[category] = fh
+
+        # Delegate to the per‐category handler
+        self.handlers[category].emit(record)
+
+
+class NameFileHandler(logging.Handler):
+    """Logging file handler that will write by logger name."""
+
+    def __init__(
+        self, base_path: Path, append:
bool = False, skip_unnamed: bool = False + ) -> None: + super().__init__() + self.base_path: Path = base_path + self.append: bool = append + self.handlers: dict[str, logging.FileHandler] = {} + self.skip_unnamed: bool = skip_unnamed + self.addFilter(_strip_ansi_escape_processor) + + @override + def emit(self, record: logging.LogRecord) -> None: + logger_name: str = record.name + # Write to file (and also to stdout, if desired) + if logger_name not in self.handlers: + handler = logging.FileHandler( + f"{self.base_path}-{logger_name}.log", + mode="a" if self.append else "w", + ) + handler.setFormatter(self.formatter) + self.handlers[logger_name] = handler + self.handlers[logger_name].emit(record) diff --git a/esbmc_ai/log_utils.py b/esbmc_ai/log_utils.py index 2b56083..00603f5 100644 --- a/esbmc_ai/log_utils.py +++ b/esbmc_ai/log_utils.py @@ -3,32 +3,16 @@ """Horizontal line logging integrated with Structlog.""" from enum import Enum -from pathlib import Path -import re -from typing import Optional, override from os import get_terminal_size import logging import structlog from structlog.typing import EventDict -_enable_horizontal_lines: bool = True -_horizontal_line_width: Optional[int] = None +from esbmc_ai.log_categories import LogCategories + _verbose_level: int = logging.INFO _logging_format: str = "%(name)s %(message)s" - - -class LogCategories(Enum): - NONE = "NONE" - ALL = "ALL" - SYSTEM = "ESBMC_AI" - VERIFIER = "VERIFIER" - COMMAND = "COMMAND" - CONFIG = "CONFIG" - CHAT = "CHAT" - - _largest_cat_len: int = min(10, max(len(cat.value) for cat in LogCategories)) -_ansi_escape = re.compile(r"\x1b\[[0-9;]*m") def get_log_level(verbosity: int | None = None) -> int: @@ -165,29 +149,22 @@ def init_logging( ) -def set_horizontal_lines(value: bool) -> None: - global _enable_horizontal_lines - _enable_horizontal_lines = value - - -def set_horizontal_line_width(value: Optional[int]) -> None: - global _horizontal_line_width - _horizontal_line_width = value - - def print_horizontal_line( level: str | int = "info", *, char: str = "=", category: Enum | str = LogCategories.ALL, - width: Optional[int] = None, + width: int | None = None, logger: structlog.stdlib.BoundLogger | None = None, ) -> None: """ Print a horizontal line if logging is enabled for the specified level. Both an int of the level or the verbose name could be surprised. """ - if not _enable_horizontal_lines: + # Import Config locally to avoid circular import + from esbmc_ai.config import Config + + if not Config().show_horizontal_lines: return # Convert level name to numeric value (e.g., "info" -> logging.INFO) @@ -198,15 +175,19 @@ def print_horizontal_line( ) # Determine line width + line_width: int if width is not None: line_width = width - elif _horizontal_line_width is not None: - line_width = _horizontal_line_width + else: - try: - line_width = get_terminal_size().columns - except OSError: - line_width = 80 - _largest_cat_len + config_hlw: int | None = Config().horizontal_line_width + if config_hlw is not None: + line_width = config_hlw + else: + try: + line_width = get_terminal_size().columns + except OSError: + line_width = 80 - _largest_cat_len if logger is None: logger = structlog.get_logger() @@ -286,112 +267,3 @@ def _filter_keys_processor( event_dict.pop("_record", None) event_dict.pop("category", None) return event_dict - - -def _strip_ansi_escape_processor(record: logging.LogRecord) -> bool | logging.LogRecord: - """ - Remove ANSI escape sequences from all string values in the LogRecord. 
- """ - for attr, value in list(record.__dict__.items()): - if isinstance(value, str): - setattr(record, attr, _ansi_escape.sub("", value)) - - args = record.args - if isinstance(args, tuple): - record.args = tuple( - _ansi_escape.sub("", v) if isinstance(v, str) else v for v in args - ) - elif isinstance(args, dict): - record.args = { - k: (_ansi_escape.sub("", v) if isinstance(v, str) else v) - for k, v in args.items() - } - - return record - - -class CategoryFileHandler(logging.Handler): - """Logger that will save by category.""" - - def __init__( - self, - base_path: Path, - append: bool = False, - skip_uncategorized: bool = False, - ) -> None: - super().__init__() - self.base_path = base_path - self.append = append - self.skip_uncategorized = skip_uncategorized - self.handlers: dict[str, logging.FileHandler] = {} - self.addFilter(_strip_ansi_escape_processor) - - def emit(self, record: logging.LogRecord) -> None: - # Grab the category (because of wrap_for_formatter) - # Try attribute first (for stdlib logging) - category: str | Enum | None = getattr(record, "category", None) - - # If not present, try to get it from record.msg (structlog dict) - if category is None and isinstance(record.msg, dict): - category = record.msg.get("category", None) - - # Convert to string - if isinstance(category, Enum): - category = category.value - - assert isinstance(category, str) - - # Skip uncategorized if desired - if ( - not category or category == LogCategories.NONE.value - ) and self.skip_uncategorized: - return - - # None category is a catch all - if not category: - category = LogCategories.NONE.value - - # Write ALL category - if category == LogCategories.ALL.value: - for h in self.handlers.values(): - h.emit(record) - return - - # Lazily build the per‐category FileHandler - if category not in self.handlers: - fn: Path = Path(f"{self.base_path}-{category}.log") - fh: logging.FileHandler = logging.FileHandler( - fn, mode="a" if self.append else "w" - ) - fh.setFormatter(self.formatter) - self.handlers[category] = fh - - # Delegate to the per‐category handler - self.handlers[category].emit(record) - - -class NameFileHandler(logging.Handler): - """Logging file handler that will write by logger name.""" - - def __init__( - self, base_path: Path, append: bool = False, skip_unnamed: bool = False - ) -> None: - super().__init__() - self.base_path: Path = base_path - self.append: bool = append - self.handlers: dict[str, logging.FileHandler] = {} - self.skip_unnamed: bool = skip_unnamed - self.addFilter(_strip_ansi_escape_processor) - - @override - def emit(self, record: logging.LogRecord) -> None: - logger_name: str = record.name - # Write to file (and also to stdout, if desired) - if logger_name not in self.handlers: - handler = logging.FileHandler( - f"{self.base_path}-{logger_name}.log", - mode="a" if self.append else "w", - ) - handler.setFormatter(self.formatter) - self.handlers[logger_name] = handler - self.handlers[logger_name].emit(record) diff --git a/esbmc_ai/msg_bus.py b/esbmc_ai/msg_bus.py deleted file mode 100644 index 12a01a0..0000000 --- a/esbmc_ai/msg_bus.py +++ /dev/null @@ -1,17 +0,0 @@ -# Author: Yiannis Charalambous - -from typing import Callable - - -class Signal(object): - subscribers: list[Callable] = [] - - def add_listener(self, fn: Callable) -> None: - self.subscribers.append(fn) - - def remove_listener(self, fn: Callable) -> None: - self.subscribers.remove(fn) - - def emit(self, *args, **params) -> None: - for sub in self.subscribers: - sub(*args, **params) diff --git 
a/esbmc_ai/solution.py b/esbmc_ai/solution.py index c50d6ad..108ee06 100644 --- a/esbmc_ai/solution.py +++ b/esbmc_ai/solution.py @@ -1,6 +1,6 @@ # Author: Yiannis Charalambous -"""Keeps track of all the source files that ESBMC-AI is targeting. """ +"""Keeps track of all the source files that ESBMC-AI is targeting.""" from dataclasses import dataclass from os import getcwd, walk @@ -10,9 +10,9 @@ from shutil import copytree from typing import override +from langchain_core.language_models import BaseChatModel import lizard -from esbmc_ai.ai_models import AIModel from esbmc_ai.log_utils import get_log_level, print_horizontal_line from esbmc_ai.verifier_output import VerifierOutput @@ -102,7 +102,7 @@ def file_extension(self) -> str: def get_num_tokens( self, - ai_model: AIModel, + ai_model: BaseChatModel, lower_idx: int | None = None, upper_idx: int | None = None, ) -> int: diff --git a/esbmc_ai/verifiers/esbmc.py b/esbmc_ai/verifiers/esbmc.py index 39234e7..d0516b1 100644 --- a/esbmc_ai/verifiers/esbmc.py +++ b/esbmc_ai/verifiers/esbmc.py @@ -300,7 +300,9 @@ def __init__(self) -> None: @property def esbmc_path(self) -> Path: """Returns the ESBMC path from config.""" - return self.get_config_value("verifier.esbmc.path").absolute() + if not self.global_config.verifier.esbmc.path: + raise ValueError("No esbmc path set.") + return self.global_config.verifier.esbmc.path.absolute() @override def verify_source( @@ -315,7 +317,7 @@ def verify_source( _ = kwargs esbmc_params: list[str] = ( - params if params else self.get_config_value("verifier.esbmc.params") + params if params else self.global_config.verifier.esbmc.params ) if "--multi-property" in esbmc_params: @@ -343,7 +345,7 @@ def verify_source( raise SolutionIntegrityError(solution.files) # Check if cached version exists. 
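+        # Config reads are now typed attribute lookups, for example
+        #     self.global_config.verifier.enable_cache
+        # instead of the old string-keyed
+        #     self.get_config_value("verifier.enable_cache")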
- enable_cache: bool = self.get_config_value("verifier.enable_cache") + enable_cache: bool = self.global_config.verifier.enable_cache cache_properties: Any = [solution, entry_function, timeout, params, kwargs] if enable_cache: cached_result: Any = self._load_cached(cache_properties) diff --git a/pyproject.toml b/pyproject.toml index 5402ea3..2fdf4e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "structlog", "platformdirs", "python-dotenv", + "pydantic", "regex", "torch", # Needed by transformers "transformers", # Needed by langchain-core to calculate get_token_ids diff --git a/tests/regtest/_regtest_outputs/test_base_chat_interface.test_push_message_stack.out b/tests/regtest/_regtest_outputs/test_base_chat_interface.test_push_message_stack.out deleted file mode 100644 index dad5f03..0000000 --- a/tests/regtest/_regtest_outputs/test_base_chat_interface.test_push_message_stack.out +++ /dev/null @@ -1,5 +0,0 @@ -system: System message -ai: OK -ai: Test 1 -human: Test 2 -system: Test 3 diff --git a/tests/regtest/_regtest_outputs/test_base_chat_interface.test_send_message.out b/tests/regtest/_regtest_outputs/test_base_chat_interface.test_send_message.out deleted file mode 100644 index e804908..0000000 --- a/tests/regtest/_regtest_outputs/test_base_chat_interface.test_send_message.out +++ /dev/null @@ -1,11 +0,0 @@ -FinishReason.stop 26 OK 1 -FinishReason.stop 36 OK 2 -FinishReason.stop 46 OK 3 -System message: 0 System message -System message: 1 OK -Message: 0 Test 1 -Message: 1 OK 1 -Message: 2 Test 2 -Message: 3 OK 2 -Message: 4 Test 3 -Message: 5 OK 3 diff --git a/tests/regtest/test_base_chat_interface.py b/tests/regtest/test_base_chat_interface.py deleted file mode 100644 index 8096711..0000000 --- a/tests/regtest/test_base_chat_interface.py +++ /dev/null @@ -1,70 +0,0 @@ -# Author: Yiannis Charalambous - -import pytest - -from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage - -from esbmc_ai.ai_models import AIModel -from esbmc_ai.chat_response import ChatResponse -from esbmc_ai.chats.base_chat_interface import BaseChatInterface -from tests.test_ai_models import MockAIModel -from pprint import pprint - - -@pytest.fixture -def setup() -> BaseChatInterface: - responses: list[str] = ["OK 1", "OK 2", "OK 3"] - - ai_model: AIModel = MockAIModel(name="test", tokens=1024, responses=responses) - assert isinstance(ai_model, MockAIModel) - - system_messages = [ - SystemMessage(content="System message"), - AIMessage(content="OK"), - ] - - chat: BaseChatInterface = BaseChatInterface( - system_messages=system_messages, - ai_model=ai_model, - ) - chat.cooldown_total = 0 - - return chat - - -def test_push_message_stack(regtest, setup) -> None: - chat = setup - - messages: list[BaseMessage] = [ - AIMessage(content="Test 1"), - HumanMessage(content="Test 2"), - SystemMessage(content="Test 3"), - ] - - chat.push_to_message_stack(messages[0]) - chat.push_to_message_stack(messages[1]) - chat.push_to_message_stack(messages[2]) - - with regtest: - for msg in chat._system_messages: - print(f"{msg.type}: {msg.content}") - - for msg in chat.messages: - print(f"{msg.type}: {msg.content}") - - -def test_send_message(regtest, setup) -> None: - chat: BaseChatInterface = setup - - with regtest: - r: ChatResponse = chat.send_message("Test 1") - print(r.finish_reason, r.total_tokens, r.message.content) - r = chat.send_message("Test 2") - print(r.finish_reason, r.total_tokens, r.message.content) - r = chat.send_message("Test 3") - print(r.finish_reason, 
r.total_tokens, r.message.content) - # Show system messages - for idx, msg in enumerate(chat._system_messages): - print("System message:", idx, msg.content) - for idx, msg in enumerate(chat.messages): - print("Message:", idx, msg.content) diff --git a/tests/test_ai_models.py b/tests/test_ai_models.py index 61435f4..e8a25e2 100644 --- a/tests/test_ai_models.py +++ b/tests/test_ai_models.py @@ -1,271 +1,32 @@ # Author: Yiannis Charalambous from dataclasses import dataclass, field -from typing import override -from langchain.prompts.chat import ChatPromptValue -from langchain.schema import ( - AIMessage, - BaseMessage, - HumanMessage, - PromptValue, - SystemMessage, -) +from langchain.schema import BaseMessage from langchain_core.language_models import BaseChatModel, FakeListChatModel -from pytest import raises -import pytest - -from esbmc_ai.ai_models import ( - AIModel, - AIModelOpenAI, - AIModels, -) @dataclass(frozen=True, kw_only=True) -class MockAIModel(AIModel): - """Used to test AIModels, it implements some mock versions of abstract - methods.""" +class MockAIModel: + """Mock AI model for testing purposes.""" + name: str = "mock-model" + tokens: int = 1024 responses: list[str] = field(default_factory=list[str]) - @override def create_llm(self) -> BaseChatModel: + """Create a fake LLM for testing.""" return FakeListChatModel(responses=self.responses) - @override - def get_num_tokens(self, content: str) -> int: - _ = content - return len(content) - - @override - def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int: - _ = messages - return sum(len(str(msg.content)) for msg in messages) - - -@pytest.fixture(autouse=True) -def clear_ai_models(): - """Clear the AIModels singleton state before each test.""" - # Clear the singleton's internal state - AIModels()._ai_models.clear() - yield - # Optionally clear after test as well - AIModels()._ai_models.clear() - - -def test_is_not_valid_ai_model() -> None: - custom_model: MockAIModel = MockAIModel( - name="custom_ai", - tokens=999, - ) - assert not AIModels().is_valid_ai_model(custom_model) - assert not AIModels().is_valid_ai_model("doesn't exist") - - -def test_add_custom_ai_model() -> None: - custom_model: MockAIModel = MockAIModel( - name="custom_ai", - tokens=999, - ) - - AIModels().add_ai_model(custom_model) - - assert AIModels().is_valid_ai_model(custom_model.name) - - # Test add again. - - if AIModels().is_valid_ai_model(custom_model.name): - with raises(Exception): - AIModels().add_ai_model(custom_model) - - assert AIModels().is_valid_ai_model(custom_model.name) - - -def test_get_ai_model_by_name() -> None: - # Try with first class AI - # assert get_ai_model_by_name("falcon-7b") - - # Try with custom AI. - # Add custom AI model if not added by previous tests. - custom_model: MockAIModel = MockAIModel( - name="custom_ai", - tokens=999, - ) - if not AIModels().is_valid_ai_model(custom_model.name): - AIModels().add_ai_model(custom_model) - assert AIModels().get_ai_model("custom_ai") is not None - - # Try with non existent. - with raises(Exception): - AIModels().get_ai_model("not-exists") - - -def test_apply_chat_template() -> None: - messages: list = [ - SystemMessage(content="M1"), - HumanMessage(content="M2"), - AIMessage(content="M3"), - ] - # Test the identity method. 
- custom_model_1: MockAIModel = MockAIModel( - name="custom", - tokens=999, +def test_mock_ai_model_creation() -> None: + """Test that MockAIModel can be created.""" + mock_model = MockAIModel( + name="test-model", tokens=2048, responses=["response1", "response2"] ) + assert mock_model.name == "test-model" + assert mock_model.tokens == 2048 + assert len(mock_model.responses) == 2 - prompt: PromptValue = custom_model_1.apply_chat_template(messages=messages) - - assert prompt == ChatPromptValue(messages=messages) - - -def test_safe_substitute() -> None: - """Tests that safe template substitution works correctly.""" - - # Test basic substitution - result = MockAIModel.safe_substitute("Hello $name, you have $count messages", name="Alice", count=5) - assert result == "Hello Alice, you have 5 messages" - - # Test missing variables are left unchanged - result = MockAIModel.safe_substitute("Hello $name, you have $count messages", name="Alice") - assert result == "Hello Alice, you have $count messages" - - # Test no variables - result = MockAIModel.safe_substitute("Hello world") - assert result == "Hello world" - - # Test mixed defined and undefined - result = MockAIModel.safe_substitute("Process $input with $method using $tool", input="data.txt", tool="hammer") - assert result == "Process data.txt with $method using hammer" - - -def test_safe_substitute_special_characters() -> None: - """Test substitution with special characters and escaping.""" - - # Test literal dollar signs (double dollar becomes single) - result = MockAIModel.safe_substitute("Price is $$50", price=100) - assert result == "Price is $50" - - # Test variables with underscores and numbers - result = MockAIModel.safe_substitute("File: $file_name_2", file_name_2="test.txt") - assert result == "File: test.txt" - - # Test variables at string boundaries - result = MockAIModel.safe_substitute("$start middle $end", start="BEGIN", end="FINISH") - assert result == "BEGIN middle FINISH" - - -def test_safe_substitute_data_types() -> None: - """Test substitution with different data types.""" - - # Test with None values - result = MockAIModel.safe_substitute("Value: $value", value=None) - assert result == "Value: None" - - # Test with boolean values - result = MockAIModel.safe_substitute("Enabled: $enabled, Debug: $debug", enabled=True, debug=False) - assert result == "Enabled: True, Debug: False" - - # Test with numeric values - result = MockAIModel.safe_substitute("Count: $count, Rate: $rate", count=0, rate=3.14) - assert result == "Count: 0, Rate: 3.14" - - # Test with list and dict (stringify behavior) - result = MockAIModel.safe_substitute("List: $items, Dict: $config", items=[1, 2, 3], config={"key": "value"}) - assert result == "List: [1, 2, 3], Dict: {'key': 'value'}" - - -def test_safe_substitute_whitespace_formatting() -> None: - """Test substitution with whitespace and formatting edge cases.""" - - # Test multiline strings - multiline = """Line 1: $var1 -Line 2: $var2 -Line 3: $undefined""" - result = MockAIModel.safe_substitute(multiline, var1="value1", var2="value2") - expected = """Line 1: value1 -Line 2: value2 -Line 3: $undefined""" - assert result == expected - - # Test with tabs and spaces - result = MockAIModel.safe_substitute("\t$var\n $other ", var="tabbed", other="spaced") - assert result == "\ttabbed\n spaced " - - -def test_safe_substitute_consecutive_variables() -> None: - """Test substitution with consecutive variables.""" - - # Test consecutive variables without separators - result = 
MockAIModel.safe_substitute("$prefix$suffix$extension", prefix="file", suffix="name", extension=".txt") - assert result == "filename.txt" - - # Test multiple undefined consecutive variables - result = MockAIModel.safe_substitute("$a$b$c$d", a="A", c="C") - assert result == "A$bC$d" - - -def test_safe_substitute_boundary_conditions() -> None: - """Test boundary conditions for substitution.""" - - # Test empty string - result = MockAIModel.safe_substitute("", var="value") - assert result == "" - - # Test string with only variable - result = MockAIModel.safe_substitute("$only", only="result") - assert result == "result" - - # Test undefined variable only - result = MockAIModel.safe_substitute("$undefined") - assert result == "$undefined" - - -def test_safe_substitute_invalid_variable_patterns() -> None: - """Test handling of invalid or malformed variable patterns.""" - - # Test variables starting with numbers (invalid Python identifiers) - result = MockAIModel.safe_substitute("Value: $1var $2test", var="valid") - assert result == "Value: $1var $2test" # Should remain unchanged as invalid identifiers - - # Test variables with hyphens - Template treats "var" as the variable name and stops at hyphen - result = MockAIModel.safe_substitute("Config: $var-name $test-value", var="valid") - assert result == "Config: valid-name $test-value" # "var" gets substituted, hyphen treated as separator - - # Test lone dollar sign - result = MockAIModel.safe_substitute("Price: $ and $valid", valid="100") - assert result == "Price: $ and 100" - - -def test_safe_substitute_recursive_values() -> None: - """Test substitution where values contain variable-like patterns.""" - - # Test when substitution values contain dollar signs - result = MockAIModel.safe_substitute("Template: $template_content", template_content="Use $variable for substitution") - assert result == "Template: Use $variable for substitution" - - # Test when values look like variables but shouldn't be re-substituted - result = MockAIModel.safe_substitute("$msg", msg="Hello $world") - assert result == "Hello $world" # Should not recursively substitute $world - - -def test_safe_substitute_unicode() -> None: - """Test substitution with Unicode characters.""" - - # Test Unicode in content and values - result = MockAIModel.safe_substitute("Greeting: $greeting", greeting="Hello 世界!") - assert result == "Greeting: Hello 世界!" 
- - # Test Unicode variable names - Python's Template class doesn't support Unicode identifiers - # So Unicode variable names remain unchanged - result = MockAIModel.safe_substitute("文档: $文档", **{"文档": "document.txt"}) - assert result == "文档: $文档" - - -def test__get_openai_model_max_tokens() -> None: - assert AIModelOpenAI.get_max_tokens("gpt-4o") == 128000 - assert AIModelOpenAI.get_max_tokens("gpt-4-turbo") == 128000 - assert AIModelOpenAI.get_max_tokens("gpt-3.5-turbo") == 16385 - assert AIModelOpenAI.get_max_tokens("gpt-3.5-turbo-aaaaaa") == 16385 - - with raises(ValueError): - AIModelOpenAI.get_max_tokens("aaaaa") + llm = mock_model.create_llm() + assert llm is not None + assert isinstance(llm, FakeListChatModel) diff --git a/tests/test_base_chat_interface.py b/tests/test_base_chat_interface.py deleted file mode 100644 index e1c40bc..0000000 --- a/tests/test_base_chat_interface.py +++ /dev/null @@ -1,133 +0,0 @@ -# Author: Yiannis Charalambous - -import pytest - -from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage -from esbmc_ai.chats.base_chat_interface import BaseChatInterface -from esbmc_ai.chat_response import ChatResponse -from tests.test_ai_models import MockAIModel - - -@pytest.fixture(scope="module") -def setup(): - responses: list[str] = ["OK 1", "OK 2", "OK 3"] - ai_model: MockAIModel = MockAIModel(name="test", tokens=1024, responses=responses) - - system_messages: list[BaseMessage] = [ - SystemMessage(content="First system message"), - AIMessage(content="OK"), - ] - - return ai_model, system_messages, responses - - -def test_push_message_stack(setup) -> None: - ai_model, system_messages, _ = setup - - chat: BaseChatInterface = BaseChatInterface( - system_messages=system_messages, - ai_model=ai_model, - ) - - for msg, chat_msg in zip(system_messages, chat._system_messages): - assert msg.type == chat_msg.type - assert msg.content == chat_msg.content - - messages: list[BaseMessage] = [ - AIMessage(content="Test 1"), - HumanMessage(content="Test 2"), - SystemMessage(content="Test 3"), - SystemMessage(content="Test 4"), - SystemMessage(content="Test 5"), - SystemMessage(content="Test 6"), - ] - - chat.push_to_message_stack(message=messages[0]) - chat.push_to_message_stack(message=messages[1]) - chat.push_to_message_stack(message=messages[2]) - - assert chat.messages[0] == messages[0] - assert chat.messages[1] == messages[1] - assert chat.messages[2] == messages[2] - - chat.push_to_message_stack(message=messages[3:]) - - assert chat.messages[3] == messages[3] - assert chat.messages[4] == messages[4] - assert chat.messages[5] == messages[5] - - -def test_send_message(setup) -> None: - ai_model, system_messages, responses = setup - - chat: BaseChatInterface = BaseChatInterface( - system_messages=system_messages, - ai_model=ai_model, - ) - - chat_responses: list[ChatResponse] = [ - chat.send_message("Test 1"), - chat.send_message("Test 2"), - chat.send_message("Test 3"), - ] - - assert chat_responses[0].message.content == responses[0] - assert chat_responses[1].message.content == responses[1] - assert chat_responses[2].message.content == responses[2] - - -def test_apply_template() -> None: - ai_model: MockAIModel = MockAIModel(name="test", tokens=1024) - - system_messages: list[BaseMessage] = [ - SystemMessage(content="This is a $source_code message"), - SystemMessage(content="Replace with $esbmc_output message"), - SystemMessage(content="$source_code$esbmc_output"), - ] - - responses: list[str] = [ - "This is a replaced message", - "Replace with 
$esbmc_output message", - "replaced$esbmc_output", - "This is a replaced message", - "Replace with also replaced message", - "replacedalso replaced", - ] - - chat: BaseChatInterface = BaseChatInterface( - system_messages=system_messages, - ai_model=ai_model, - ) - - chat.apply_template_value(source_code="replaced") - - assert chat._system_messages[0].content == responses[0] - assert chat._system_messages[1].content == responses[1] - assert chat._system_messages[2].content == responses[2] - - chat.apply_template_value(esbmc_output="also replaced") - - assert chat._system_messages[0].content == responses[3] - assert chat._system_messages[1].content == responses[4] - assert chat._system_messages[2].content == responses[5] - - -def test_apply_template_escape_characters() -> None: - ai_model: MockAIModel = MockAIModel(name="test", tokens=1024) - - system_messages: list[BaseMessage] = [ - SystemMessage(content="This is a $$source_code message"), - SystemMessage(content="Escaped $$aaa but not $source_code"), - ] - - chat: BaseChatInterface = BaseChatInterface( - system_messages=system_messages, - ai_model=ai_model, - ) - - chat.apply_template_value(source_code="replaced", aaa="should_not_appear") - - # $$source_code should become $source_code (escaped) - assert chat._system_messages[0].content == "This is a $source_code message" - # $$aaa should become $aaa (escaped), but $source_code should be replaced - assert chat._system_messages[1].content == "Escaped $aaa but not replaced" diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py deleted file mode 100644 index d1549c3..0000000 --- a/tests/test_command_parser.py +++ /dev/null @@ -1,26 +0,0 @@ -# Author: Yiannis Charalambous - -from esbmc_ai.component_loader import ComponentLoader - - -def test_parse() -> None: - sentence = 'Your sentence goes "here and \\"here\\" as well."' - result = ComponentLoader.parse_command(sentence) - assert result == ( - "Your", - [ - "sentence", - "goes", - '"here and \\"here\\" as well."', - ], - ) - - -def test_parse_command() -> None: - result = ComponentLoader.parse_command("/fix-code") - assert result == ("/fix-code", []) - - -def test_parse_command_args() -> None: - result = ComponentLoader.parse_command("/optimize-code main") - assert result == ("/optimize-code", ["main"]) diff --git a/tests/test_config.py b/tests/test_config.py index f726be9..537ec43 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,86 +1,47 @@ # Author: Yiannis Charalambous 2023 from pytest import raises +from pydantic import ValidationError -from esbmc_ai.config import Config -from esbmc_ai.ai_models import AIModels +from esbmc_ai.config import Config, AICustomModelConfig def test_load_custom_ai() -> None: - custom_ai_config: dict = { - "ai_custom": { - "example_ai": { - "max_tokens": 4096, - "url": "www.example.com", - "server_type": "ollama", - } - } - } + """Test loading custom AI models through config.""" + # Test valid custom AI config + custom_config = AICustomModelConfig( + max_tokens=4096, + url="www.example.com", + server_type="ollama", + ) - Config().load_custom_ai(custom_ai_config) - - assert AIModels().is_valid_ai_model("example_ai") + assert custom_config.max_tokens == 4096 + assert custom_config.url == "www.example.com" + assert custom_config.server_type == "ollama" def test_load_custom_ai_fail() -> None: - # Wrong max_tokens type - ai_conf: dict = { - "example_ai_2": { - "max_tokens": "1024", - "url": "www.example.com", - "server_type": "ollama", - } - } - - with raises(TypeError): - 
Config()._validate_custom_ai(ai_conf) - - # Wrong max_tokens value - ai_conf: dict = { - "example_ai_2": { - "max_tokens": 0, - "url": "www.example.com", - "server_type": "ollama", - } - } - - with raises(ValueError): - Config()._validate_custom_ai(ai_conf) + """Test that invalid custom AI configurations are rejected.""" + # Note: Pydantic v2 coerces strings to int when possible, so "1024" -> 1024 + # This is expected behavior and not an error # Missing max_tokens - ai_conf: dict = { - "example_ai_2": { - "url": "www.example.com", - "server_type": "ollama", - } - } - - with raises(KeyError): - Config()._validate_custom_ai(ai_conf) + with raises(ValidationError): + AICustomModelConfig( # type: ignore + url="www.example.com", + server_type="ollama", + ) # Missing url - ai_conf: dict = { - "example_ai_2": { - "max_tokens": 1000, - "server_type": "ollama", - } - } - - with raises(KeyError): - Config()._validate_custom_ai(ai_conf) + with raises(ValidationError): + AICustomModelConfig( # type: ignore + max_tokens=1000, + server_type="ollama", + ) # Missing server type - ai_conf: dict = { - "example_ai_2": { - "max_tokens": 100, - "url": "www.example.com", - } - } - - with raises(KeyError): - Config()._validate_custom_ai(ai_conf) - - # Test load empty - ai_conf: dict = {} - - Config()._validate_custom_ai(ai_conf) + with raises(ValidationError): + AICustomModelConfig( # type: ignore + max_tokens=100, + url="www.example.com", + ) diff --git a/tests/test_template_renderer.py b/tests/test_template_renderer.py new file mode 100644 index 0000000..1bde7fe --- /dev/null +++ b/tests/test_template_renderer.py @@ -0,0 +1,56 @@ +# Author: Yiannis Charalambous + +from esbmc_ai.chats.template_renderer import KeyTemplateRenderer +from esbmc_ai.chats.template_key_provider import ESBMCTemplateKeyProvider + + +def test_template_substitution(): + """Test that template variables are correctly substituted.""" + template_str = "The ESBMC output is:\n\n```\n{esbmc_output}\n```\n\nThe source code is:\n\n```c\n{source_code}\n```" + messages = [("human", template_str)] + + renderer = KeyTemplateRenderer( + messages=messages, + key_provider=ESBMCTemplateKeyProvider(), + ) + + formatted = renderer.format_messages( + source_code="int main() { return 0; }", + esbmc_output="VERIFICATION SUCCESSFUL", + error_line="0", + error_type="none", + ) + + assert len(formatted) == 1 + assert "int main() { return 0; }" in formatted[0].content + assert "VERIFICATION SUCCESSFUL" in formatted[0].content + assert "{source_code}" not in formatted[0].content + assert "{esbmc_output}" not in formatted[0].content + + +def test_template_substitution_with_multiline_code(): + """Test template substitution with multiline source code.""" + template_str = "Source:\n{source_code}\nError: {error_type}" + messages = [("human", template_str)] + + renderer = KeyTemplateRenderer( + messages=messages, + key_provider=ESBMCTemplateKeyProvider(), + ) + + source = """int main() { + int x = 5; + return x / 0; +}""" + + formatted = renderer.format_messages( + source_code=source, + esbmc_output="Division by zero", + error_line="3", + error_type="division by zero", + ) + + assert len(formatted) == 1 + assert "int x = 5;" in formatted[0].content + assert "return x / 0;" in formatted[0].content + assert "division by zero" in formatted[0].content
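
A few notes on pieces this patch touches but does not fully show. The verifier hunk at the top keys cached results on the full property list [solution, entry_function, timeout, params, kwargs] guarded by verifier.enable_cache. A minimal sketch of how such hash-keyed caching could work, assuming pickle serialization and an on-disk cache directory; the actual _load_cached implementation is not shown in this diff and may differ:

    import hashlib
    import pickle
    from pathlib import Path
    from typing import Any

    CACHE_DIR = Path(".verifier_cache")  # hypothetical location

    def _cache_key(properties: list[Any]) -> str:
        # Hash the request-defining properties so that an identical
        # verification run can be answered from disk instead of re-invoking
        # the verifier.
        return hashlib.sha256(pickle.dumps(properties)).hexdigest()

    def load_cached(properties: list[Any]) -> Any | None:
        path = CACHE_DIR / _cache_key(properties)
        if path.exists():
            return pickle.loads(path.read_bytes())
        return None

    def save_cached(properties: list[Any], result: Any) -> None:
        CACHE_DIR.mkdir(exist_ok=True)
        (CACHE_DIR / _cache_key(properties)).write_bytes(pickle.dumps(result))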
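
The slimmed-down MockAIModel in tests/test_ai_models.py delegates to langchain's FakeListChatModel, which replays its canned responses in order and ignores the prompt, which is what makes the tests deterministic. A usage sketch (the prompt text is arbitrary):

    from langchain_core.language_models import FakeListChatModel

    llm = FakeListChatModel(responses=["response1", "response2"])
    # Each invoke returns the next canned response, regardless of the prompt.
    assert llm.invoke("any prompt").content == "response1"
    assert llm.invoke("any prompt").content == "response2"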
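
The rewritten tests/test_config.py exercises AICustomModelConfig as a pydantic model: all three fields are required, missing fields raise ValidationError, and pydantic v2's lax mode coerces numeric strings to int (hence the note that "1024" is accepted rather than rejected). A minimal sketch of the shape the tests imply; the real model in esbmc_ai.config may add defaults and extra validators:

    from pydantic import BaseModel, ValidationError

    class AICustomModelConfigSketch(BaseModel):
        max_tokens: int  # "1024" is coerced to 1024 under pydantic v2 lax mode
        url: str
        server_type: str

    cfg = AICustomModelConfigSketch(
        max_tokens="4096", url="www.example.com", server_type="ollama"
    )
    assert cfg.max_tokens == 4096  # numeric string coerced to int

    try:
        # max_tokens missing: required fields raise ValidationError,
        # matching what test_load_custom_ai_fail asserts.
        AICustomModelConfigSketch(url="www.example.com", server_type="ollama")
    except ValidationError:
        pass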
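
Finally, tests/test_template_renderer.py pins down the contract of the new KeyTemplateRenderer: it takes (role, template) message tuples whose templates use {key} placeholders, plus a key provider, and format_messages(**kwargs) returns rendered chat messages. A hypothetical reconstruction inferred only from that usage; the actual class in esbmc_ai.chats.template_renderer and the role played by ESBMCTemplateKeyProvider may well differ:

    from langchain_core.messages import (
        AIMessage,
        BaseMessage,
        HumanMessage,
        SystemMessage,
    )

    _ROLE_TYPES = {"human": HumanMessage, "ai": AIMessage, "system": SystemMessage}

    class KeyTemplateRendererSketch:
        """str.format-based renderer; the key provider presumably declares
        which keys templates may use (source_code, esbmc_output, ...)."""

        def __init__(self, messages: list[tuple[str, str]], key_provider) -> None:
            self._messages = messages
            self._key_provider = key_provider

        def format_messages(self, **kwargs: str) -> list[BaseMessage]:
            rendered: list[BaseMessage] = []
            for role, template in self._messages:
                # Substitute every {key} placeholder; substituted values are
                # not re-scanned, so braces inside the injected C source code
                # are left untouched, as the multiline test expects.
                rendered.append(_ROLE_TYPES[role](content=template.format(**kwargs)))
            return rendered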