|
2 | 2 | import logging |
3 | 3 | from textwrap import dedent |
4 | 4 |
|
| 5 | +import tiktoken |
5 | 6 | from daiv.settings.components import DATA_DIR |
6 | 7 | from langchain_core.documents import Document |
7 | 8 | from langchain_core.embeddings import Embeddings |
@@ -129,21 +130,34 @@ def _build_content_to_embed(self, document: Document, description: str) -> str: |
129 | 130 | str: Content to embed |
130 | 131 | """ |
131 | 132 | if not description: |
132 | | - return dedent(f"""\ |
| 133 | + content = dedent(f"""\ |
133 | 134 | Repository: {document.metadata.get("repo_id", "")} |
134 | 135 | File Path: {document.metadata.get("source", "")} |
135 | 136 |
|
136 | 137 | {document.page_content} |
137 | 138 | """) |
138 | 139 | else: |
139 | | - return dedent(f"""\ |
| 140 | + content = dedent(f"""\ |
140 | 141 | Repository: {document.metadata.get("repo_id", "")} |
141 | 142 | File Path: {document.metadata.get("source", "")} |
142 | 143 | Description: {description} |
143 | 144 |
|
144 | 145 | {document.page_content} |
145 | 146 | """) |
146 | 147 |
|
| 148 | + count = self._embeddings_count_tokens(content) |
| 149 | + |
| 150 | + if count > settings.EMBEDDINGS_MAX_INPUT_TOKENS: |
| 151 | + logger.warning( |
| 152 | + "Chunk is too large, truncating: %s. Chunk tokens: %d, max allowed: %d", |
| 153 | + document.metadata["source"], |
| 154 | + self._embeddings_count_tokens(content), |
| 155 | + settings.EMBEDDINGS_MAX_INPUT_TOKENS, |
| 156 | + ) |
| 157 | + return content[: settings.EMBEDDINGS_MAX_INPUT_TOKENS] |
| 158 | + |
| 159 | + return content |
| 160 | + |
147 | 161 | def delete_documents(self, namespace: CodebaseNamespace, source: str | list[str]): |
148 | 162 | """ |
149 | 163 | Deletes documents from the namespace matching the given source(s). |
@@ -196,3 +210,15 @@ def as_retriever(self, namespace: CodebaseNamespace | None = None, **kwargs) -> |
196 | 210 | if namespace is None: |
197 | 211 | return PostgresRetriever(embeddings=self.embeddings, **kwargs) |
198 | 212 | return ScopedPostgresRetriever(namespace=namespace, embeddings=self.embeddings, **kwargs) |
| 213 | + |
| 214 | + def _embeddings_count_tokens(self, text: str) -> int: |
| 215 | + """ |
| 216 | + Count the number of tokens in the text. |
| 217 | + """ |
| 218 | + provider, model_name = settings.EMBEDDINGS_MODEL_NAME.split("/", 1) |
| 219 | + |
| 220 | + if provider == "voyageai": |
| 221 | + return self.embeddings._client.count_tokens([text], model=model_name) |
| 222 | + elif provider == "openai": |
| 223 | + return len(tiktoken.encoding_for_model(model_name).encode(text)) |
| 224 | + return len(tiktoken.get_encoding("cl100k_base").encode(text)) |
0 commit comments