Skip to content

feat: Add persistent site exclusion settings (#652) #675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions sample.config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ OLLAMA = "http://127.0.0.1:11434"
LM_STUDIO = "http://localhost:1234/v1"
OPENAI = "https://api.openai.com/v1"


[LOGGING]
LOG_REST_API = "true"
LOG_PROMPTS = "false"

[TIMEOUT]
INFERENCE = 60
INFERENCE = 60

[SITE_EXCLUSIONS]
EXCLUDED_SITES = []
25 changes: 17 additions & 8 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,23 @@ def _load_config(self):
# check if all the keys are present in the config file
with open("sample.config.toml", "r") as f:
sample_config = toml.load(f)

with open("config.toml", "r+") as f:
config = toml.load(f)

# Update the config with any missing keys and their keys of keys
for key, value in sample_config.items():
config.setdefault(key, value)
if isinstance(value, dict):
for sub_key, sub_value in value.items():
config[key].setdefault(sub_key, sub_value)

f.seek(0)
toml.dump(config, f)
f.truncate()

self.config = config

def get_config(self):
return self.config

Expand All @@ -59,7 +59,7 @@ def get_google_search_api_endpoint(self):

def get_ollama_api_endpoint(self):
return self.config["API_ENDPOINTS"]["OLLAMA"]

def get_lmstudio_api_endpoint(self):
return self.config["API_ENDPOINTS"]["LM_STUDIO"]

Expand Down Expand Up @@ -107,10 +107,19 @@ def get_logging_rest_api(self):

def get_logging_prompts(self):
return self.config["LOGGING"]["LOG_PROMPTS"] == "true"

def get_timeout_inference(self):
return self.config["TIMEOUT"]["INFERENCE"]

def get_excluded_sites(self):
return self.config.get("SITE_EXCLUSIONS", {}).get("EXCLUDED_SITES", [])

def set_excluded_sites(self, sites):
if "SITE_EXCLUSIONS" not in self.config:
self.config["SITE_EXCLUSIONS"] = {}
self.config["SITE_EXCLUSIONS"]["EXCLUDED_SITES"] = sites
self.save_config()

def set_bing_api_key(self, key):
self.config["API_KEYS"]["BING"] = key
self.save_config()
Expand All @@ -134,7 +143,7 @@ def set_google_search_api_endpoint(self, endpoint):
def set_ollama_api_endpoint(self, endpoint):
self.config["API_ENDPOINTS"]["OLLAMA"] = endpoint
self.save_config()

def set_lmstudio_api_endpoint(self, endpoint):
self.config["API_ENDPOINTS"]["LM_STUDIO"] = endpoint
self.save_config()
Expand Down
45 changes: 45 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
import pytest
from src.config import Config

@pytest.fixture
def config():
# Create a temporary config for testing
if os.path.exists("config.toml"):
os.rename("config.toml", "config.toml.bak")

yield Config()

# Restore original config
if os.path.exists("config.toml.bak"):
os.rename("config.toml.bak", "config.toml")
else:
os.remove("config.toml")

def test_excluded_sites_empty_by_default(config):
"""Test that excluded sites list is empty by default."""
assert config.get_excluded_sites() == []

def test_set_excluded_sites(config):
"""Test setting and getting excluded sites."""
test_sites = ["example.com", "test.org"]
config.set_excluded_sites(test_sites)
assert config.get_excluded_sites() == test_sites

def test_excluded_sites_persistence(config):
"""Test that excluded sites persist after saving."""
test_sites = ["example.com", "test.org"]
config.set_excluded_sites(test_sites)

# Create new config instance to test persistence
new_config = Config()
assert new_config.get_excluded_sites() == test_sites

def test_update_excluded_sites(config):
"""Test updating excluded sites list."""
initial_sites = ["example.com"]
config.set_excluded_sites(initial_sites)

updated_sites = ["example.com", "test.org"]
config.set_excluded_sites(updated_sites)
assert config.get_excluded_sites() == updated_sites