diff --git a/pyproject.toml b/pyproject.toml index 1bef8c1a..0b1b7f1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ classifiers = [ requires-python = ">=3.9,<4.0" [project.optional-dependencies] -burr = ["burr[start]==0.19.1"] +burr = ["burr[start]==0.22.1"] docs = ["sphinx==6.0", "furo==2024.5.6"] [build-system] diff --git a/scrapegraphai/integrations/burr_bridge.py b/scrapegraphai/integrations/burr_bridge.py index 0cac9f4d..019427ef 100644 --- a/scrapegraphai/integrations/burr_bridge.py +++ b/scrapegraphai/integrations/burr_bridge.py @@ -4,6 +4,8 @@ """ import re +import uuid +from hashlib import md5 from typing import Any, Dict, List, Tuple import inspect @@ -13,7 +15,7 @@ raise ImportError("burr package is not installed. Please install it with 'pip install scrapegraphai[burr]'") from burr import tracking -from burr.core import Application, ApplicationBuilder, State, Action, default +from burr.core import Application, ApplicationBuilder, State, Action, default, ApplicationContext from burr.lifecycle import PostRunStepHook, PreRunStepHook @@ -55,7 +57,7 @@ def writes(self) -> list[str]: def update(self, result: dict, state: State) -> State: return state.update(**result) - + def get_source(self) -> str: return inspect.getsource(self.node.__class__) @@ -100,13 +102,12 @@ class BurrBridge: def __init__(self, base_graph, burr_config): self.base_graph = base_graph self.burr_config = burr_config - self.project_name = burr_config.get("project_name", "default-project") - self.tracker = tracking.LocalTrackingClient(project=self.project_name) + self.project_name = burr_config.get("project_name", "scrapegraph: {}") self.app_instance_id = burr_config.get("app_instance_id", "default-instance") self.burr_inputs = burr_config.get("inputs", {}) self.burr_app = None - def _initialize_burr_app(self, initial_state: Dict[str, Any] = {}) -> Application: + def _initialize_burr_app(self, initial_state: Dict[str, Any] = None) -> Application: """ Initialize a Burr application from the base graph. @@ -116,24 +117,41 @@ def _initialize_burr_app(self, initial_state: Dict[str, Any] = {}) -> Applicatio Returns: Application: The Burr application instance. """ + if initial_state is None: + initial_state = {} actions = self._create_actions() transitions = self._create_transitions() hooks = [PrintLnHook()] burr_state = State(initial_state) - - app = ( + application_context = ApplicationContext.get() + builder = ( ApplicationBuilder() .with_actions(**actions) .with_transitions(*transitions) .with_entrypoint(self.base_graph.entry_point) .with_state(**burr_state) - .with_identifiers(app_id=self.app_instance_id) - .with_tracker(self.tracker) + .with_identifiers(app_id=str(uuid.uuid4())) # TODO -- grab this from state .with_hooks(*hooks) - .build() ) - return app + if application_context is not None: + builder = ( + builder + # if we're using a tracker, we want to copy it/pass in + .with_tracker( + application_context.tracker.copy() if application_context.tracker is not None else None + ) # remember to do `copy()` here! + .with_spawning_parent( + application_context.app_id, + application_context.sequence_id, + application_context.partition_key, + ) + ) + else: + # This is the case in which nothing is spawning it + # in this case, we want to create a new tracker from scratch + builder = builder.with_tracker(tracking.LocalTrackingClient(project=self.project_name)) + return builder.build() def _create_actions(self) -> Dict[str, Any]: """