diff --git a/examples/llm_embedding.py b/examples/llm_embedding.py index e6dff763c..2fb07e7f1 100644 --- a/examples/llm_embedding.py +++ b/examples/llm_embedding.py @@ -88,12 +88,13 @@ def __init__(self, model: str = "embed-english-v3.0"): self.co = cohere.Client(api_key) def __call__(self, sentences: List[str]) -> Tensor: - from cohere import Embeddings + from cohere import EmbedResponse - items: Embeddings = self.co.embed(model=self.model, texts=sentences, - input_type="classification") - assert len(items) == len(sentences) - embeddings = torch.tensor(items.embeddings) + response: EmbedResponse = self.co.embed(model=self.model, + texts=sentences, + input_type="classification") + assert len(response.embeddings) == len(sentences) + embeddings = torch.tensor(response.embeddings) return embeddings @@ -138,7 +139,7 @@ def __call__(self, sentences: List[str]) -> Tensor: dataset = MultimodalTextBenchmark( root=path, name=args.dataset, - text_embedder_cfg=TextEmbedderConfig( + col_to_text_embedder_cfg=TextEmbedderConfig( text_embedder=text_encoder, batch_size=text_encoder.text_embedder_batch_size, ),