Skip to content

Commit f162016

Browse files
committed
Adds basic TQDM progress bar to KG creation pipeline
1 parent 77db5e2 commit f162016

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

poetry.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ pinecone-client = {version = "^4.1.0", optional = true}
3737
types-mock = "^5.1.0.20240425"
3838
eval-type-backport = "^0.2.0"
3939
pypdf = "^4.3.1"
40+
tqdm = "^4.66.5"
4041

4142
[tool.poetry.group.dev.dependencies]
4243
pylint = "^3.1.0"

src/neo4j_genai/experimental/pipeline/pipeline.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Any, AsyncGenerator, Awaitable, Callable, Optional
2323

2424
from pydantic import BaseModel, Field
25+
from tqdm.asyncio import tqdm
2526

2627
from neo4j_genai.experimental.pipeline.component import Component, DataModel
2728
from neo4j_genai.experimental.pipeline.exceptions import (
@@ -410,6 +411,7 @@ def on_task_complete(self, node: TaskPipelineNode, result: RunResult) -> None:
410411
if result.result:
411412
res_to_save = result.result.model_dump()
412413
self.add_result_for_component(node.name, res_to_save, is_final=node.is_leaf())
414+
self.pbar.update(1)
413415

414416
def add_result_for_component(
415417
self, name: str, result: dict[str, Any] | None, is_final: bool = False
@@ -471,12 +473,14 @@ def validate_inputs_config(self, data: dict[str, Any]) -> None:
471473
task.validate_inputs_config(data)
472474

473475
async def run(self, data: dict[str, Any]) -> dict[str, Any]:
474-
logging.info("Starting pipeline")
476+
logger.debug("Starting pipeline")
477+
self.pbar = tqdm(total=len(self._nodes), desc="Creating knowledge graph")
475478
start_time = default_timer()
476479
self.validate_inputs_config(data)
477480
self.reinitialize()
478481
orchestrator = Orchestrator(self)
479482
await orchestrator.run(data)
480483
end_time = default_timer()
481-
logging.info(f"Pipeline finished in {end_time - start_time}s")
484+
self.pbar.close()
485+
logger.debug(f"Pipeline finished in {end_time - start_time}s")
482486
return self._final_results.all()

0 commit comments

Comments
 (0)