
style: apply ALL ruff rules #158

Open · wants to merge 3 commits into main
Changes from 1 commit
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -144,8 +144,8 @@ target-version = "py310"
docstring-code-format = true

[tool.ruff.lint]
select = ["A", "ASYNC", "B", "BLE", "C4", "C90", "D", "DTZ", "E", "EM", "ERA", "F", "FBT", "FLY", "FURB", "G", "I", "ICN", "INP", "INT", "ISC", "LOG", "N", "NPY", "PERF", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "Q", "RET", "RSE", "RUF", "S", "SIM", "SLF", "SLOT", "T10", "T20", "TCH", "TID", "TRY", "UP", "W", "YTT"]
ignore = ["D203", "D213", "E501", "RET504", "RUF002", "RUF022", "S101", "S307", "TC004"]
select = ["ALL"]
ignore = ["CPY", "FIX", "ARG001", "COM812", "D203", "D213", "E501", "PD008", "PD009", "RET504", "S101", "TD003"]
unfixable = ["ERA001", "F401", "F841", "T201", "T203"]

[tool.ruff.lint.flake8-tidy-imports]
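A note on the change above: `select = ["ALL"]` opts the project into every rule Ruff ships, including rules added in future releases, while the `ignore` list carves out the unwanted rules and whole categories (`CPY` for copyright headers, `FIX` for flagging FIXME/TODO markers). The `unfixable` list keeps a few rules report-only, so `ruff check --fix` points them out without deleting code. A minimal sketch of how three of those rules behave under this config; the function is illustrative, not repository code:

```python
def debug_pipeline(values: list[int]) -> int:
    """Sum the values, with some deliberately lint-relevant lines."""
    total = sum(values)
    # total *= 2  # ERA001 (commented-out code): reported but unfixable, so --fix never deletes it.
    print(f"total={total}")  # T201 (print): reported but unfixable, so the line stays put.
    assert total >= 0  # S101 (assert): in the ignore list, so no diagnostic at all.
    return total
```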

2 changes: 1 addition & 1 deletion src/raglite/__init__.py
@@ -17,7 +17,7 @@
vector_search,
)

__all__ = [
__all__ = [ # noqa: RUF022
# Config
"RAGLiteConfig",
# Insert
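The `# noqa: RUF022` above is needed because RUF022 wants `__all__` kept as one alphabetically sorted list, whereas this package groups its exports under section comments. A minimal sketch of the two shapes, using only entries visible in this diff:

```python
# What RUF022 would auto-sort to: a single flat, alphabetized list.
__all__ = ["RAGLiteConfig", "vector_search"]

# What the package keeps: grouped sections under comments. Sorting would scatter
# the groups, so the rule is silenced on this one line instead of reshuffling.
__all__ = [  # noqa: RUF022
    # Config
    "RAGLiteConfig",
    # Search
    "vector_search",
]
```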

23 changes: 14 additions & 9 deletions src/raglite/_bench.py
@@ -94,7 +94,7 @@ def __init__(
insert_variant: str | None = None,
search_variant: str | None = None,
config: RAGLiteConfig | None = None,
):
) -> None:
super().__init__(
dataset,
num_results=num_results,
@@ -145,7 +145,7 @@ def __init__(
num_results: int = 10,
insert_variant: str | None = None,
search_variant: str | None = None,
):
) -> None:
super().__init__(
dataset,
num_results=num_results,
@@ -156,7 +156,7 @@
self.embedder_dim = 3072
self.persist_path = self.cwd / self.insert_id

def insert_documents(self, max_workers: int | None = None) -> None:
def insert_documents(self, max_workers: int | None = None) -> None: # noqa: ARG002
# Adapted from https://docs.llamaindex.ai/en/stable/examples/vector_stores/FaissIndexDemo/.
import faiss
from llama_index.core import Document, StorageContext, VectorStoreIndex
@@ -178,14 +178,15 @@ def insert_documents(self, max_workers: int | None = None) -> None:
index.storage_context.persist(persist_dir=self.persist_path)

@cached_property
def index(self) -> Any:
def index(self) -> Any: # noqa: ANN401
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.persist_path.as_posix())
storage_context = StorageContext.from_defaults(
vector_store=vector_store, persist_dir=self.persist_path.as_posix()
vector_store=vector_store,
persist_dir=self.persist_path.as_posix(),
)
embed_model = OpenAIEmbedding(model=self.embedder, dimensions=self.embedder_dim)
index = load_index_from_storage(storage_context, embed_model=embed_model)
@@ -215,7 +216,7 @@ def __init__(
num_results: int = 10,
insert_variant: str | None = None,
search_variant: str | None = None,
):
) -> None:
super().__init__(
dataset,
num_results=num_results,
@@ -227,7 +228,7 @@
)

@cached_property
def client(self) -> Any:
def client(self) -> Any: # noqa: ANN401
import openai

return openai.OpenAI()
@@ -269,7 +270,9 @@ def insert_documents(self, max_workers: int | None = None) -> None:
files.append(temp_file.open("rb"))
if len(files) == max_files_per_batch or (i == self.dataset.docs_count() - 1):
self.client.vector_stores.file_batches.upload_and_poll(
vector_store_id=vector_store.id, files=files, max_concurrency=max_workers
vector_store_id=vector_store.id,
files=files,
max_concurrency=max_workers,
)
for f in files:
f.close()
@@ -283,7 +286,9 @@ def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[Sc
if not self.vector_store_id:
return []
response = self.client.vector_stores.search(
vector_store_id=self.vector_store_id, query=query, max_num_results=2 * num_results
vector_store_id=self.vector_store_id,
query=query,
max_num_results=2 * num_results,
)
scored_docs = [
ScoredDoc(
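The recurring edits in this file come from three rules: ANN204 requires a return annotation on `__init__` (hence the added `-> None`), ANN401 flags `typing.Any` in annotations (suppressed on the lazily imported `client` and `index` properties), and ARG002 flags unused method arguments (suppressed where `max_workers` is kept for signature compatibility). A minimal, self-contained sketch of the same pattern; the class and its internals are invented, not RAGLite code:

```python
from functools import cached_property
from typing import Any


class ExampleEvaluator:
    def __init__(self, num_results: int = 10) -> None:  # ANN204: __init__ needs "-> None".
        self.num_results = num_results

    def insert_documents(self, max_workers: int | None = None) -> None:  # noqa: ARG002
        # ARG002 flags the unused `max_workers`, but the parameter is kept so that all
        # evaluators share one call signature, so the rule is suppressed locally.
        ...

    @cached_property
    def client(self) -> Any:  # noqa: ANN401
        # ANN401 discourages `Any`, but the client type lives in an optional, lazily
        # imported dependency, so `Any` plus a local suppression is the pragmatic choice.
        import sqlite3  # Stand-in for a heavyweight optional dependency.

        return sqlite3.connect(":memory:")
```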

6 changes: 4 additions & 2 deletions src/raglite/_chainlit.py
@@ -39,7 +39,7 @@ async def start_chat() -> None:
TextInput(id="llm", label="LLM", initial=config.llm),
TextInput(id="embedder", label="Embedder", initial=config.embedder),
Switch(id="vector_search_query_adapter", label="Query adapter", initial=True),
]
],
).send()
await update_config(settings)

@@ -95,7 +95,9 @@ async def handle_message(user_message: cl.Message) -> None:
messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call]
messages.append({"role": "user", "content": user_prompt})
async for token in async_rag(
messages, on_retrieval=lambda x: chunk_spans.extend(x), config=config
messages,
on_retrieval=lambda x: chunk_spans.extend(x),
config=config,
):
await assistant_message.stream_token(token)
# Append RAG sources, if any.
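The purely mechanical reflows in this file (and in `_bench.py`, `_cli.py`, `_chatml_function_calling.py`, and `_config.py`) look like the formatter's magic trailing comma at work: once a bracketed call or list ends with a trailing comma, `ruff format` keeps one element per line. COM812 (missing trailing comma) sits in the ignore list because Ruff recommends disabling it when its formatter is in use, so the commas are a formatting choice rather than a lint requirement. A small sketch of the two shapes, with an invented helper:

```python
def configure(llm: str, embedder: str) -> dict[str, str]:
    """Collect keyword arguments into a mapping (illustrative helper)."""
    return {"llm": llm, "embedder": embedder}


# Without a trailing comma, the formatter collapses the call onto one line when it fits:
settings = configure(llm="example-llm", embedder="example-embedder")

# With a magic trailing comma, the formatter pins the call to one argument per line,
# which is the shape the reflowed calls in this PR settle on:
settings = configure(
    llm="example-llm",
    embedder="example-embedder",
)
```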

55 changes: 38 additions & 17 deletions src/raglite/_chatml_function_calling.py
@@ -98,9 +98,9 @@ def _convert_chunks_to_completion(
{
"text": text,
"index": 0,
"logprobs": logprobs, # TODO: Improve accumulation of logprobs
"logprobs": logprobs, # TODO(lsorber): Improve accumulation of logprobs
"finish_reason": finish_reason, # type: ignore[typeddict-item]
}
},
],
}
# Add usage section if present in the chunks
@@ -131,7 +131,8 @@ def _stream_tool_calls(
prompt += f"functions.{tool_name}:\n"
try:
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
json.dumps(tool["function"]["parameters"]),
verbose=llama.verbose,
)
except Exception as e:
warnings.warn(
Expand All @@ -140,7 +141,8 @@ def _stream_tool_calls(
stacklevel=2,
)
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF, verbose=llama.verbose
llama_grammar.JSON_GBNF,
verbose=llama.verbose,
)
completion_or_chunks = llama.create_completion(
prompt=prompt,
@@ -182,7 +184,8 @@ def _stream_tool_calls(
"stop": [*completion_kwargs["stop"], ":", "</function_calls>"],
"max_tokens": None,
"grammar": llama_grammar.LlamaGrammar.from_string(
follow_up_gbnf_tool_grammar, verbose=llama.verbose
follow_up_gbnf_tool_grammar,
verbose=llama.verbose,
),
},
),
@@ -253,7 +256,7 @@ def chatml_function_calling_with_streaming(
grammar: Optional[llama.LlamaGrammar] = None, # type: ignore[name-defined]
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
**kwargs: Any,
**kwargs: Any, # noqa: ANN401
) -> Union[
llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
@@ -381,7 +384,10 @@ def chatml_function_calling_with_streaming(
or len(tools) == 0
):
prompt = template_renderer.render(
messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
messages=messages,
tools=[],
tool_calls=None,
add_generation_prompt=True,
)
return llama_chat_format._convert_completion_to_chat( # noqa: SLF001
llama.create_completion(
@@ -404,7 +410,10 @@
assert tools
function_names = " | ".join([f'''"functions.{t["function"]["name"]}:"''' for t in tools])
prompt = template_renderer.render(
messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True
messages=messages,
tools=tools,
tool_calls=True,
add_generation_prompt=True,
)
initial_gbnf_tool_grammar = (
(
@@ -429,7 +438,8 @@
"stream": False,
"max_tokens": None,
"grammar": llama_grammar.LlamaGrammar.from_string(
initial_gbnf_tool_grammar, verbose=llama.verbose
initial_gbnf_tool_grammar,
verbose=llama.verbose,
),
},
),
@@ -449,7 +459,10 @@
# Case 2 step 2A: Respond with a message
if tool_name is None:
prompt = template_renderer.render(
messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
messages=messages,
tools=[],
tool_calls=None,
add_generation_prompt=True,
)
prompt += think
return llama_chat_format._convert_completion_to_chat( # noqa: SLF001
@@ -469,7 +482,12 @@
prompt += "<function_calls>\n"
if stream:
return _stream_tool_calls(
llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar
llama,
prompt,
tools,
tool_name,
completion_kwargs,
follow_up_gbnf_tool_grammar,
)
tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
completions: List[llama_types.CreateCompletionResponse] = []
@@ -479,7 +497,8 @@
prompt += f"functions.{tool_name}:\n"
try:
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
json.dumps(tool["function"]["parameters"]),
verbose=llama.verbose,
)
except Exception as e:
warnings.warn(
Expand All @@ -488,7 +507,8 @@ def chatml_function_calling_with_streaming(
stacklevel=2,
)
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF, verbose=llama.verbose
llama_grammar.JSON_GBNF,
verbose=llama.verbose,
)
completion_or_chunks = llama.create_completion(
prompt=prompt,
@@ -515,7 +535,8 @@
"stop": [*completion_kwargs["stop"], ":", "</function_calls>"], # type: ignore[misc]
"max_tokens": None,
"grammar": llama_grammar.LlamaGrammar.from_string(
follow_up_gbnf_tool_grammar, verbose=llama.verbose
follow_up_gbnf_tool_grammar,
verbose=llama.verbose,
),
},
),
@@ -533,7 +554,7 @@
"finish_reason": "tool_calls",
"index": 0,
"logprobs": _convert_text_completion_logprobs_to_chat(
completion["choices"][0]["logprobs"]
completion["choices"][0]["logprobs"],
),
"message": {
"role": "assistant",
@@ -548,11 +569,11 @@
},
}
for i, (tool_name, completion) in enumerate(
zip(completions_tool_name, completions, strict=True)
zip(completions_tool_name, completions, strict=True),
)
],
},
}
},
],
"usage": {
"completion_tokens": sum(

25 changes: 18 additions & 7 deletions src/raglite/_cli.py
@@ -14,7 +14,9 @@ class RAGLiteCLIConfig(BaseSettings):
"""RAGLite CLI config."""

model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
env_prefix="RAGLITE_", env_file=".env", extra="allow"
env_prefix="RAGLITE_",
env_file=".env",
extra="allow",
)

mcp_server_name: str = "RAGLite"
@@ -67,7 +69,7 @@ def install_mcp_server(
claude_config_path = get_claude_config_path()
if not claude_config_path:
typer.echo(
"Please download the Claude desktop app from https://claude.ai/download before installing an MCP server."
"Please download the Claude desktop app from https://claude.ai/download before installing an MCP server.",
)
return
claude_config_filepath = claude_config_path / "claude_desktop_config.json"
@@ -88,7 +90,7 @@
"--python",
"3.11",
"--with",
"numpy<2.0.0", # TODO: Remove this constraint when uv no longer needs it to solve the environment.
"numpy<2.0.0", # TODO(lsorber): Remove this constraint when uv no longer needs it to solve the environment.
"raglite",
"mcp",
"run",
@@ -112,7 +114,9 @@ def run_mcp_server(
from raglite._mcp import create_mcp_server

config = RAGLiteConfig(
db_url=ctx.obj["db_url"], llm=ctx.obj["llm"], embedder=ctx.obj["embedder"]
db_url=ctx.obj["db_url"],
llm=ctx.obj["llm"],
embedder=ctx.obj["embedder"],
)
mcp = create_mcp_server(server_name, config=config)
mcp.run()
@@ -122,7 +126,10 @@
def bench(
ctx: typer.Context,
dataset_name: str = typer.Option(
"nano-beir/hotpotqa", "--dataset", "-d", help="Dataset to use from https://ir-datasets.com/"
"nano-beir/hotpotqa",
"--dataset",
"-d",
help="Dataset to use from https://ir-datasets.com/",
),
measure: str = typer.Option(
"AP@10",
@@ -157,7 +164,9 @@ def bench(
)
dataset = ir_datasets.load(dataset_name)
evaluator = RAGLiteEvaluator(
dataset, insert_variant=f"single-vector-{chunk_max_size // 4}t", config=config
dataset,
insert_variant=f"single-vector-{chunk_max_size // 4}t",
config=config,
)
index.append("RAGLite (single-vector)")
results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))
@@ -170,7 +179,9 @@
)
dataset = ir_datasets.load(dataset_name)
evaluator = RAGLiteEvaluator(
dataset, insert_variant=f"multi-vector-{chunk_max_size // 4}t", config=config
dataset,
insert_variant=f"multi-vector-{chunk_max_size // 4}t",
config=config,
)
index.append("RAGLite (multi-vector)")
results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))

11 changes: 7 additions & 4 deletions src/raglite/_config.py
@@ -24,9 +24,12 @@


# Lazily load the default search method to avoid circular imports.
# TODO: Replace with search_and_rerank_chunk_spans after benchmarking.
# TODO(lsorber): Replace with search_and_rerank_chunk_spans after benchmarking.
def _vector_search(
query: str, *, num_results: int = 8, config: "RAGLiteConfig | None" = None
query: str,
*,
num_results: int = 8,
config: "RAGLiteConfig | None" = None,
) -> tuple[list[ChunkId], list[float]]:
from raglite._search import vector_search

@@ -45,7 +48,7 @@ class RAGLiteConfig:
"llama-cpp-python/unsloth/Qwen3-8B-GGUF/*Q4_K_M.gguf@8192"
if llama_supports_gpu_offload()
else "llama-cpp-python/unsloth/Qwen3-4B-GGUF/*Q4_K_M.gguf@8192"
)
),
)
llm_max_tries: int = 4
# Embedder config used for indexing.
@@ -54,7 +57,7 @@
"llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf@512"
if llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 4 # noqa: PLR2004
else "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@512"
)
),
)
embedder_normalize: bool = True
# Chunk config used to partition documents into chunks.