diff --git a/literalai/callback/langchain_callback.py b/literalai/callback/langchain_callback.py
index 57db15a..ac5e9f2 100644
--- a/literalai/callback/langchain_callback.py
+++ b/literalai/callback/langchain_callback.py
@@ -425,6 +425,8 @@ def _on_run_update(self, run: Run) -> None:
                     throughput = chat_start["token_count"] / duration
                 else:
                     throughput = None
+                kwargs = message.get("kwargs", {})
+                usage_metadata = kwargs.get("usage_metadata", {})
                 message_completion = self._convert_message(message)
                 current_step.generation = ChatGeneration(
                     provider=provider,
@@ -440,6 +442,9 @@ def _on_run_update(self, run: Run) -> None:
                         for m in chat_start["input_messages"]
                     ],
                     message_completion=message_completion,
+                    input_token_count=usage_metadata.get("input_tokens"),
+                    output_token_count=usage_metadata.get("output_tokens"),
+                    token_count=usage_metadata.get("total_tokens"),
                 )
                 # find first message with prompt_id
                 prompt_id = None
@@ -469,6 +474,8 @@ def _on_run_update(self, run: Run) -> None:
                 else:
                     throughput = None
                 completion = generation.get("text", "")
+                kwargs = message.get("kwargs", {})
+                usage_metadata = kwargs.get("usage_metadata", {})
                 current_step.generation = CompletionGeneration(
                     provider=provider,
                     model=model,
@@ -479,6 +486,9 @@ def _on_run_update(self, run: Run) -> None:
                     tt_first_token=completion_start.get("tt_first_token"),
                     prompt=completion_start["prompt"],
                     completion=completion,
+                    input_token_count=usage_metadata.get("input_tokens"),
+                    output_token_count=usage_metadata.get("output_tokens"),
+                    token_count=usage_metadata.get("total_tokens"),
                 )
                 current_step.output = {"content": completion}
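
For context, a minimal sketch of the serialized message shape the new `message["kwargs"]["usage_metadata"]` lookups read from. It assumes langchain-core >= 0.2 (which added `usage_metadata` to `AIMessage`) and uses `dumpd` to mirror the serialized dict the tracer sees; the message content and token counts below are hypothetical, and the exact serialized shape can vary across langchain-core versions.

```python
# Minimal sketch, assuming langchain-core >= 0.2 (AIMessage.usage_metadata)
# and a provider that populated token usage. Values here are hypothetical.
from langchain_core.load import dumpd
from langchain_core.messages import AIMessage

msg = AIMessage(
    content="Hello!",
    usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
)

# dumpd() yields {"lc": 1, "type": "constructor", "id": [...], "kwargs": {...}};
# a dict of this shape is what the callback reads as `message`, so
# usage_metadata lands under the "kwargs" key.
dumped = dumpd(msg)
usage_metadata = dumped.get("kwargs", {}).get("usage_metadata", {})

print(usage_metadata.get("input_tokens"))   # 12
print(usage_metadata.get("output_tokens"))  # 3
print(usage_metadata.get("total_tokens"))   # 15
```

Chained `.get(..., {})` calls keep the lookups safe when a provider or an older langchain-core version omits `usage_metadata`: the token-count fields then come back as `None` rather than raising.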