Skip to content

Commit 3d9ecfd

Browse files
committed
Update tests.
1 parent ce6d4e5 commit 3d9ecfd

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import aiohttp
2323
import requests
24+
import traceback
2425
from langchain_core.callbacks import (
2526
AsyncCallbackManagerForLLMRun,
2627
CallbackManagerForLLMRun,
@@ -175,6 +176,7 @@ def _completion_with_retry(**kwargs: Any) -> Any:
175176
except TokenExpiredError as e:
176177
raise e
177178
except Exception as err:
179+
traceback.print_exc()
178180
logger.debug(
179181
f"Requests payload: {data}. Requests arguments: "
180182
f"url={self.endpoint},timeout={request_timeout},stream={stream}. "
@@ -221,6 +223,7 @@ async def _completion_with_retry(**kwargs: Any) -> Any:
221223
except TokenExpiredError as e:
222224
raise e
223225
except Exception as err:
226+
traceback.print_exc()
224227
logger.debug(
225228
f"Requests payload: `{data}`. "
226229
f"Stream mode={stream}. "
@@ -272,6 +275,7 @@ def _parse_stream(self, lines: Iterator[bytes]) -> Iterator[str]:
272275
An iterator that yields parsed lines as strings.
273276
"""
274277
for line in lines:
278+
print("***" + str(line))
275279
_line = self._parse_stream_line(line)
276280
if _line is not None:
277281
yield _line
@@ -307,13 +311,16 @@ def _parse_stream_line(self, line: bytes) -> Optional[str]:
307311
The processed line as a string if valid, otherwise `None`.
308312
"""
309313
line = line.strip()
310-
if line:
311-
_line = line.decode("utf-8")
312-
if "[DONE]" in _line:
313-
return None
314+
if not line:
315+
return None
316+
_line = line.decode("utf-8")
317+
318+
if _line.lower().startswith("data:"):
319+
_line = _line[5:].lstrip()
314320

315-
if _line.lower().startswith("data:"):
316-
return _line[5:].lstrip()
321+
if _line.startswith("[DONE]"):
322+
return None
323+
return _line
317324
return None
318325

319326
async def _aiter_sse(
@@ -589,11 +596,11 @@ def _stream(
589596
response = self.completion_with_retry(
590597
data=body, run_manager=run_manager, stream=True, **requests_kwargs
591598
)
592-
593599
for line in self._parse_stream(response.iter_lines()):
594600
chunk = self._handle_sse_line(line)
595601
if run_manager:
596602
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
603+
597604
yield chunk
598605

599606
async def _astream(
@@ -751,7 +758,7 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
751758
752759
"""
753760

754-
api: Literal["/generate", "/v1/completions"] = "/generate"
761+
api: Literal["/generate", "/v1/completions"] = "/v1/completions"
755762
"""Api spec."""
756763

757764
frequency_penalty: float = 0.0

0 commit comments

Comments
 (0)