|
21 | 21 |
|
22 | 22 | import aiohttp
|
23 | 23 | import requests
|
| 24 | +import traceback |
24 | 25 | from langchain_core.callbacks import (
|
25 | 26 | AsyncCallbackManagerForLLMRun,
|
26 | 27 | CallbackManagerForLLMRun,
|
@@ -175,6 +176,7 @@ def _completion_with_retry(**kwargs: Any) -> Any:
|
175 | 176 | except TokenExpiredError as e:
|
176 | 177 | raise e
|
177 | 178 | except Exception as err:
|
| 179 | + traceback.print_exc() |
178 | 180 | logger.debug(
|
179 | 181 | f"Requests payload: {data}. Requests arguments: "
|
180 | 182 | f"url={self.endpoint},timeout={request_timeout},stream={stream}. "
|
@@ -221,6 +223,7 @@ async def _completion_with_retry(**kwargs: Any) -> Any:
|
221 | 223 | except TokenExpiredError as e:
|
222 | 224 | raise e
|
223 | 225 | except Exception as err:
|
| 226 | + traceback.print_exc() |
224 | 227 | logger.debug(
|
225 | 228 | f"Requests payload: `{data}`. "
|
226 | 229 | f"Stream mode={stream}. "
|
@@ -272,6 +275,7 @@ def _parse_stream(self, lines: Iterator[bytes]) -> Iterator[str]:
|
272 | 275 | An iterator that yields parsed lines as strings.
|
273 | 276 | """
|
274 | 277 | for line in lines:
|
| 278 | + print("***" + str(line)) |
275 | 279 | _line = self._parse_stream_line(line)
|
276 | 280 | if _line is not None:
|
277 | 281 | yield _line
|
@@ -307,13 +311,16 @@ def _parse_stream_line(self, line: bytes) -> Optional[str]:
|
307 | 311 | The processed line as a string if valid, otherwise `None`.
|
308 | 312 | """
|
309 | 313 | line = line.strip()
|
310 |
| - if line: |
311 |
| - _line = line.decode("utf-8") |
312 |
| - if "[DONE]" in _line: |
313 |
| - return None |
| 314 | + if not line: |
| 315 | + return None |
| 316 | + _line = line.decode("utf-8") |
| 317 | + |
| 318 | + if _line.lower().startswith("data:"): |
| 319 | + _line = _line[5:].lstrip() |
314 | 320 |
|
315 |
| - if _line.lower().startswith("data:"): |
316 |
| - return _line[5:].lstrip() |
| 321 | + if _line.startswith("[DONE]"): |
| 322 | + return None |
| 323 | + return _line |
317 | 324 | return None
|
318 | 325 |
|
319 | 326 | async def _aiter_sse(
|
@@ -589,11 +596,11 @@ def _stream(
|
589 | 596 | response = self.completion_with_retry(
|
590 | 597 | data=body, run_manager=run_manager, stream=True, **requests_kwargs
|
591 | 598 | )
|
592 |
| - |
593 | 599 | for line in self._parse_stream(response.iter_lines()):
|
594 | 600 | chunk = self._handle_sse_line(line)
|
595 | 601 | if run_manager:
|
596 | 602 | run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
| 603 | + |
597 | 604 | yield chunk
|
598 | 605 |
|
599 | 606 | async def _astream(
|
@@ -751,7 +758,7 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
|
751 | 758 |
|
752 | 759 | """
|
753 | 760 |
|
754 |
| - api: Literal["/generate", "/v1/completions"] = "/generate" |
| 761 | + api: Literal["/generate", "/v1/completions"] = "/v1/completions" |
755 | 762 | """Api spec."""
|
756 | 763 |
|
757 | 764 | frequency_penalty: float = 0.0
|
|
0 commit comments