From 3b69572b133626dbbb5b654ce6f5aa2e59d63c94 Mon Sep 17 00:00:00 2001 From: ionmincu Date: Fri, 4 Apr 2025 14:58:52 +0300 Subject: [PATCH] feat(tracer): support traceable attribute --- sdk/langchain/pyproject.toml | 2 +- .../_cli/_runtime/_runtime.py | 4 +- .../uipath_langchain/_utils/__init__.py | 3 +- .../tracers/AsyncUiPathTracer.py | 163 ++++++---- .../uipath_langchain/tracers/__init__.py | 3 +- .../uipath_langchain/tracers/_events.py | 33 ++ .../tracers/_instrument_traceable.py | 285 ++++++++++++++++++ 7 files changed, 432 insertions(+), 61 deletions(-) create mode 100644 sdk/langchain/uipath_langchain/tracers/_events.py create mode 100644 sdk/langchain/uipath_langchain/tracers/_instrument_traceable.py diff --git a/sdk/langchain/pyproject.toml b/sdk/langchain/pyproject.toml index fd65947a..f1d3d93c 100644 --- a/sdk/langchain/pyproject.toml +++ b/sdk/langchain/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-langchain" -version = "0.0.83" +version = "0.0.84" description = "UiPath Langchain" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.9" diff --git a/sdk/langchain/uipath_langchain/_cli/_runtime/_runtime.py b/sdk/langchain/uipath_langchain/_cli/_runtime/_runtime.py index 333d211f..048a7832 100644 --- a/sdk/langchain/uipath_langchain/_cli/_runtime/_runtime.py +++ b/sdk/langchain/uipath_langchain/_cli/_runtime/_runtime.py @@ -14,6 +14,7 @@ UiPathRuntimeResult, ) +from ..._utils import _instrument_traceable from ...tracers import AsyncUiPathTracer from ._context import LangGraphRuntimeContext from ._exception import LangGraphRuntimeError @@ -43,6 +44,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]: Raises: LangGraphRuntimeError: If execution fails """ + _instrument_traceable() await self.validate() @@ -71,7 +73,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]: callbacks: List[BaseCallbackHandler] = [] if self.context.job_id and self.context.tracing_enabled: - tracer = AsyncUiPathTracer() + tracer = AsyncUiPathTracer(context=self.context.trace_context) await tracer.init_trace( self.context.entrypoint, self.context.job_id ) diff --git a/sdk/langchain/uipath_langchain/_utils/__init__.py b/sdk/langchain/uipath_langchain/_utils/__init__.py index 1902d730..31605f17 100644 --- a/sdk/langchain/uipath_langchain/_utils/__init__.py +++ b/sdk/langchain/uipath_langchain/_utils/__init__.py @@ -1,3 +1,4 @@ +from ..tracers._instrument_traceable import _instrument_traceable from ._request_mixin import UiPathRequestMixin -__all__ = ["UiPathRequestMixin"] +__all__ = ["UiPathRequestMixin", "_instrument_traceable"] diff --git a/sdk/langchain/uipath_langchain/tracers/AsyncUiPathTracer.py b/sdk/langchain/uipath_langchain/tracers/AsyncUiPathTracer.py index b90e637a..4173192c 100644 --- a/sdk/langchain/uipath_langchain/tracers/AsyncUiPathTracer.py +++ b/sdk/langchain/uipath_langchain/tracers/AsyncUiPathTracer.py @@ -6,13 +6,15 @@ import uuid import warnings from os import environ as env -from typing import Any, Optional +from typing import Any, Dict, Optional import httpx from langchain_core.tracers.base import AsyncBaseTracer from langchain_core.tracers.schemas import Run from pydantic import PydanticDeprecationWarning +from uipath_sdk._cli._runtime._contracts import UiPathTraceContext +from ._events import CustomTraceEvents, FunctionCallEventData from ._utils import _setup_tracer_httpx_logging, _simple_serialize_defaults logger = logging.getLogger(__name__) @@ -27,78 +29,98 @@ class Status: class AsyncUiPathTracer(AsyncBaseTracer): - def __init__(self, client=None, **kwargs): + def __init__( + self, + context: Optional[UiPathTraceContext] = None, + client: Optional[httpx.AsyncClient] = None, + **kwargs, + ): super().__init__(**kwargs) self.client = client or httpx.AsyncClient() self.retries = 3 self.log_queue: queue.Queue[dict[str, Any]] = queue.Queue() + self.context = context or UiPathTraceContext() + llm_ops_pattern = self._get_base_url() + "{orgId}/llmops_" - self.orgId = env.get( - "UIPATH_ORGANIZATION_ID", "00000000-0000-0000-0000-000000000000" - ) - self.tenantId = env.get( - "UIPATH_TENANT_ID", "00000000-0000-0000-0000-000000000000" - ) - self.url = llm_ops_pattern.format(orgId=self.orgId).rstrip("/") - self.auth_token = env.get("UNATTENDED_USER_ACCESS_TOKEN") or env.get( + self.url = llm_ops_pattern.format(orgId=self.context.org_id).rstrip("/") + + auth_token = env.get("UNATTENDED_USER_ACCESS_TOKEN") or env.get( "UIPATH_ACCESS_TOKEN" ) - self.jobKey = env.get("UIPATH_JOB_KEY") - self.folderKey = env.get("UIPATH_FOLDER_KEY") - self.processKey = env.get("UIPATH_PROCESS_UUID") - self.parent_span_id = env.get("UIPATH_PARENT_SPAN_ID") - - self.referenceId = self.jobKey or str(uuid.uuid4()) - - self.headers = { - "Authorization": f"Bearer {self.auth_token}", - } + self.headers = {"Authorization": f"Bearer {auth_token}"} self.running = True self.worker_task = asyncio.create_task(self._worker()) + self.function_call_run_map: Dict[str, Run] = {} + + async def on_custom_event( + self, + name: str, + data: Any, + *, + run_id: uuid.UUID, + tags=None, + metadata=None, + **kwargs: Any, + ) -> None: + if name == CustomTraceEvents.UIPATH_TRACE_FUNCTION_CALL: + # only handle the function call event + + if not isinstance(data, FunctionCallEventData): + logger.warning( + f"Received unexpected data type for function call event: {type(data)}" + ) + return - def _get_base_url(self) -> str: - uipath_url = ( - env.get("UIPATH_URL") or "https://cloud.uipath.com/dummyOrg/dummyTennant/" - ) - uipath_url = uipath_url.rstrip("/") + if data.event_type == "call": + run = self.run_map[str(run_id)] + child_run = run.create_child( + name=data.function_name, run_type=data.run_type, tags=data.tags + ) - # split by "//" to get ['', 'https:', 'alpha.uipath.com/ada/byoa'] - parts = uipath_url.split("//") + if data.metadata is not None: + run.add_metadata(data.metadata) - # after splitting by //, the base URL will be at index 1 along with the rest, - # hence split it again using "/" to get ['https:', 'alpha.uipath.com', 'ada', 'byoa'] - base_url_parts = parts[1].split("/") + call_uuid = data.call_uuid + self.function_call_run_map[call_uuid] = child_run - # combine scheme and netloc to get the base URL - base_url = parts[0] + "//" + base_url_parts[0] + "/" + self._send_span(run) - return base_url + if data.event_type == "completion": + call_uuid = data.call_uuid + previous_run = self.function_call_run_map.pop(call_uuid, None) + + if previous_run: + previous_run.end( + outputs=self._safe_dict_dump(data.output), error=data.error + ) + self._send_span(previous_run) async def init_trace(self, run_name, trace_id=None) -> None: - trace_id_env = env.get("UIPATH_TRACE_ID") + if self.context.trace_id: + # trace id already set no need to do anything + return - if trace_id_env: - self.trace_parent = trace_id_env - else: - await self.start_trace(run_name, trace_id) + # no trace id, start a new trace + await self.start_trace(run_name, trace_id) async def start_trace(self, run_name, trace_id=None) -> None: - self.trace_parent = trace_id or str(uuid.uuid4()) - run_name = run_name or f"Job Run: {self.trace_parent}" + self.context.trace_id = str(uuid.uuid4()) + + run_name = run_name or f"Job Run: {self.context.trace_id}" trace_data = { - "id": self.trace_parent, + "id": self.context.trace_id, "name": re.sub( "[!@#$<>\.]", "", run_name ), # if we use these characters the Agents UI throws some error (but llmops backend seems fine) - "referenceId": self.referenceId, + "referenceId": self.context.reference_id, "attributes": "{}", - "organizationId": self.orgId, - "tenantId": self.tenantId, + "organizationId": self.context.org_id, + "tenantId": self.context.tenant_id, } for attempt in range(self.retries): @@ -176,9 +198,9 @@ async def _worker(self): async def _persist_run(self, run: Run) -> None: # Determine if this is a start or end trace based on whether end_time is set - await self._send_span(run) + self._send_span(run) - async def _send_span(self, run: Run) -> None: + def _send_span(self, run: Run) -> None: """Send span data for a run to the API""" run_id = str(run.id) @@ -193,27 +215,27 @@ async def _send_span(self, run: Run) -> None: parent_id = ( str(run.parent_run_id) if run.parent_run_id is not None - else self.parent_span_id + else self.context.parent_span_id ) - attributes = self._safe_json_dump(self._run_to_dict(run)) + attributes = self._safe_jsons_dump(self._run_to_dict(run)) status = self._determine_status(run.error) span_data = { "id": run_id, "parentId": parent_id, - "traceId": self.trace_parent, + "traceId": self.context.trace_id, "name": run.name, "startTime": start_time, "endTime": end_time, - "referenceId": self.referenceId, + "referenceId": self.context.reference_id, "attributes": attributes, - "organizationId": self.orgId, - "tenantId": self.tenantId, + "organizationId": self.context.org_id, + "tenantId": self.context.tenant_id, "spanType": "LangGraphRun", "status": status, - "jobKey": self.jobKey, - "folderKey": self.folderKey, - "processKey": self.processKey, + "jobKey": self.context.job_id, + "folderKey": self.context.folder_key, + "processKey": self.context.folder_key, } self.log_queue.put(span_data) @@ -237,14 +259,23 @@ def _determine_status(self, error: Optional[str]): return Status.SUCCESS - def _safe_json_dump(self, obj) -> str: + def _safe_jsons_dump(self, obj) -> str: try: json_str = json.dumps(obj, default=_simple_serialize_defaults) return json_str except Exception as e: - logger.warning(e) + logger.warning(f"Error serializing object to JSON: {e}") return "{ }" + def _safe_dict_dump(self, obj) -> Dict[str, Any]: + try: + serialized = json.loads(json.dumps(obj, default=_simple_serialize_defaults)) + return serialized + except Exception as e: + # Last resort - string representation + logger.warning(f"Error serializing object to JSON: {e}") + return {"raw": str(obj)} + def _run_to_dict(self, run: Run): with warnings.catch_warnings(): warnings.simplefilter("ignore", category=PydanticDeprecationWarning) @@ -254,3 +285,21 @@ def _run_to_dict(self, run: Run): "inputs": run.inputs.copy() if run.inputs is not None else None, "outputs": run.outputs.copy() if run.outputs is not None else None, } + + def _get_base_url(self) -> str: + uipath_url = ( + env.get("UIPATH_URL") or "https://cloud.uipath.com/dummyOrg/dummyTennant/" + ) + uipath_url = uipath_url.rstrip("/") + + # split by "//" to get ['', 'https:', 'alpha.uipath.com/ada/byoa'] + parts = uipath_url.split("//") + + # after splitting by //, the base URL will be at index 1 along with the rest, + # hence split it again using "/" to get ['https:', 'alpha.uipath.com', 'ada', 'byoa'] + base_url_parts = parts[1].split("/") + + # combine scheme and netloc to get the base URL + base_url = parts[0] + "//" + base_url_parts[0] + "/" + + return base_url diff --git a/sdk/langchain/uipath_langchain/tracers/__init__.py b/sdk/langchain/uipath_langchain/tracers/__init__.py index 5cec8284..360d26bd 100644 --- a/sdk/langchain/uipath_langchain/tracers/__init__.py +++ b/sdk/langchain/uipath_langchain/tracers/__init__.py @@ -1,4 +1,5 @@ +from ._instrument_traceable import _instrument_traceable from .AsyncUiPathTracer import AsyncUiPathTracer from .UiPathTracer import UiPathTracer -__all__ = ["AsyncUiPathTracer", "UiPathTracer"] +__all__ = ["AsyncUiPathTracer", "UiPathTracer", "_instrument_traceable"] diff --git a/sdk/langchain/uipath_langchain/tracers/_events.py b/sdk/langchain/uipath_langchain/tracers/_events.py new file mode 100644 index 00000000..9cb1e8a7 --- /dev/null +++ b/sdk/langchain/uipath_langchain/tracers/_events.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, List, Literal, Optional + +RUN_TYPE_T = Literal[ + "tool", "chain", "llm", "retriever", "embedding", "prompt", "parser" +] + + +class CustomTraceEvents: + UIPATH_TRACE_FUNCTION_CALL = "__uipath_trace_function_call" + + +class FunctionCallEventData: + def __init__( + self, + function_name: str, + event_type: str, + inputs: Dict[str, Any], + call_uuid: str, + output: Any, + error: str, + run_type: Optional[RUN_TYPE_T] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + self.function_name = function_name + self.event_type = event_type + self.inputs = inputs + self.call_uuid = call_uuid + self.output = output + self.error = error + self.run_type = run_type or "chain" + self.tags = tags + self.metadata = metadata diff --git a/sdk/langchain/uipath_langchain/tracers/_instrument_traceable.py b/sdk/langchain/uipath_langchain/tracers/_instrument_traceable.py new file mode 100644 index 00000000..87b3979e --- /dev/null +++ b/sdk/langchain/uipath_langchain/tracers/_instrument_traceable.py @@ -0,0 +1,285 @@ +import functools +import importlib +import inspect +import logging +import sys +import uuid +from typing import Any, Dict, List, Literal, Optional + +from langchain_core.callbacks import dispatch_custom_event + +from ._events import CustomTraceEvents, FunctionCallEventData + +# Original module and traceable function references +original_langsmith: Any = None +original_traceable: Any = None + +logger = logging.getLogger(__name__) + + +def dispatch_trace_event( + func_name, + inputs: Dict[str, Any], + event_type="call", + call_uuid=None, + result=None, + exception=None, + run_type: Optional[ + Literal["tool", "chain", "llm", "retriever", "embedding", "prompt", "parser"] + ] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, +): + """Dispatch trace event to our server.""" + + event_data = FunctionCallEventData( + function_name=func_name, + event_type=event_type, + inputs=inputs, + call_uuid=call_uuid, + output=result, + error=str(exception), + run_type=run_type, + tags=tags, + metadata=metadata, + ) + dispatch_custom_event(CustomTraceEvents.UIPATH_TRACE_FUNCTION_CALL, event_data) + + +def format_args_for_trace( + signature: inspect.Signature, *args: Any, **kwargs: Any +) -> Dict[str, Any]: + try: + """Return a dictionary of inputs from the function signature.""" + # Create a parameter mapping by partially binding the arguments + parameter_binding = signature.bind_partial(*args, **kwargs) + + # Fill in default values for any unspecified parameters + parameter_binding.apply_defaults() + + # Extract the input parameters, skipping special Python parameters + result = {} + for name, value in parameter_binding.arguments.items(): + # Skip class and instance references + if name in ("self", "cls"): + continue + + # Handle **kwargs parameters specially + param_info = signature.parameters.get(name) + if param_info and param_info.kind == inspect.Parameter.VAR_KEYWORD: + # Flatten nested kwargs directly into the result + if isinstance(value, dict): + result.update(value) + else: + # Regular parameter + result[name] = value + + return result + except Exception as e: + logger.warning( + f"Error formatting arguments for trace: {e}. Using args and kwargs directly." + ) + return {"args": args, "kwargs": kwargs} + + +# Create patched version of traceable +def patched_traceable(*decorator_args, **decorator_kwargs): + # Handle the case when @traceable is used directly as decorator without arguments + if ( + len(decorator_args) == 1 + and callable(decorator_args[0]) + and not decorator_kwargs + ): + func = decorator_args[0] + return _create_appropriate_wrapper(func, original_traceable(func), {}) + + # Handle the case when @traceable(args) is used with parameters + original_decorated = original_traceable(*decorator_args, **decorator_kwargs) + + def uipath_trace_decorator(func): + # Apply the original decorator with its arguments + wrapped_func = original_decorated(func) + return _create_appropriate_wrapper(func, wrapped_func, decorator_kwargs) + + return uipath_trace_decorator + + +def _create_appropriate_wrapper( + original_func: Any, wrapped_func: Any, decorator_kwargs: Dict[str, Any] +): + """Create the appropriate wrapper based on function type.""" + + # Get the function name and tags from decorator arguments + func_name = decorator_kwargs.get("name", original_func.__name__) + tags = decorator_kwargs.get("tags", None) + metadata = decorator_kwargs.get("metadata", None) + run_type = decorator_kwargs.get("run_type", None) + + # Async generator function + if inspect.isasyncgenfunction(wrapped_func): + + @functools.wraps(wrapped_func) + async def async_gen_wrapper(*args, **kwargs): + try: + call_uuid = str(uuid.uuid4()) + + inputs = format_args_for_trace( + inspect.signature(original_func), *args, **kwargs + ) + + dispatch_trace_event( + func_name, + inputs, + "call", + call_uuid, + run_type=run_type, + tags=tags, + metadata=metadata, + ) + async_gen = wrapped_func(*args, **kwargs) + + results = [] + + async for item in async_gen: + results.append(item) + yield item + + dispatch_trace_event( + func_name, inputs, "completion", call_uuid, results + ) + except Exception as e: + dispatch_trace_event( + func_name, inputs, "completion", call_uuid, exception=e + ) + raise + + return async_gen_wrapper + + # Sync generator function + elif inspect.isgeneratorfunction(wrapped_func): + + @functools.wraps(wrapped_func) + def gen_wrapper(*args, **kwargs): + try: + call_uuid = str(uuid.uuid4()) + + inputs = format_args_for_trace( + inspect.signature(original_func), *args, **kwargs + ) + + results = [] + + dispatch_trace_event( + func_name, + inputs, + "call", + call_uuid, + run_type=run_type, + tags=tags, + metadata=metadata, + ) + gen = wrapped_func(*args, **kwargs) + for item in gen: + results.append(item) + yield item + dispatch_trace_event( + func_name, inputs, "completion", call_uuid, results + ) + except Exception as e: + dispatch_trace_event( + func_name, inputs, "completion", call_uuid, exception=e + ) + raise + + return gen_wrapper + + # Async function + elif inspect.iscoroutinefunction(wrapped_func): + + @functools.wraps(wrapped_func) + async def async_wrapper(*args, **kwargs): + try: + call_uuid = str(uuid.uuid4()) + + inputs = format_args_for_trace( + inspect.signature(original_func), *args, **kwargs + ) + + dispatch_trace_event( + func_name, + inputs, + "call", + call_uuid, + run_type=run_type, + tags=tags, + metadata=metadata, + ) + result = await wrapped_func(*args, **kwargs) + dispatch_trace_event(func_name, inputs, "completion", call_uuid, result) + return result + except Exception as e: + dispatch_trace_event( + func_name, inputs, "completion", call_uuid, exception=e + ) + raise + + return async_wrapper + + # Regular sync function (default case) + else: + + @functools.wraps(wrapped_func) + def sync_wrapper(*args, **kwargs): + try: + call_uuid = str(uuid.uuid4()) + + inputs = format_args_for_trace( + inspect.signature(original_func), *args, **kwargs + ) + + dispatch_trace_event( + func_name, + inputs, + "call", + call_uuid, + run_type=run_type, + tags=tags, + metadata=metadata, + ) + result = wrapped_func(*args, **kwargs) + dispatch_trace_event(func_name, inputs, "completion", call_uuid, result) + return result + except Exception as e: + dispatch_trace_event( + func_name, inputs, "completion", call_uuid, exception=e + ) + raise + + return sync_wrapper + + +# Apply the patch +def _instrument_traceable(): + """Apply the patch to langsmith module at import time.""" + global original_langsmith, original_traceable + + # Import the original module if not already done + if original_langsmith is None: + # Temporarily remove our custom module from sys.modules + if "langsmith" in sys.modules: + original_langsmith = sys.modules["langsmith"] + del sys.modules["langsmith"] + + # Import the original module + original_langsmith = importlib.import_module("langsmith") + + # Store the original traceable + original_traceable = original_langsmith.traceable + + # Replace the traceable function with our patched version + original_langsmith.traceable = patched_traceable + + # Put our modified module back + sys.modules["langsmith"] = original_langsmith + + return original_langsmith