Skip to content

Commit 6124fbd

Browse files
committed
add embeddings with openai
1 parent 45f02cd commit 6124fbd

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

scrapegraphai/nodes/generate_answer_node_k_level.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -97,10 +97,24 @@ def execute(self, state: dict) -> dict:
97 97

98 98
client = state["vectorial_db"]
99 99

100-
answer_db = client.query(
101-
collection_name="vectorial_collection",
102-
query_text=state["question"]
100+
if state.get("embeddings"):
101+
import openai
102+
openai_client = openai.Client()
103+
104+
answer_db = client.search(
105+
collection_name="collection",
106+
query_vector=openai_client.embeddings.create(
107+
input=["What is the best to use for vector search scaling?"],
108+
model=state.get("embeddings").get("model"),
109+
)
110+
.data[0]
111+
.embedding,
103 112
)
113+
else:
114+
answer_db = client.query(
115+
collection_name="vectorial_collection",
116+
query_text=state["question"]
117+
)
104 118

105 119
## TODO: from the id get the data
106 120
results_db = [elem for elem in state[answer_db]]

scrapegraphai/nodes/rag_node.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
4 4
from typing import List, Optional
5 5
from .base_node import BaseNode
6 6
from qdrant_client import QdrantClient
7+
from qdrant_client.models import PointStruct, VectorParams, Distance
7 8

8 9
class RAGNode(BaseNode):
9 10
"""
@@ -52,6 +53,41 @@ def execute(self, state: dict) -> dict:
52 53
docs = [elem.get("summary") for elem in state.get("descriptions", {})]
53 54
ids = [elem.get("id") for elem in state.get("descriptions", {})]
54 55

56+
if state.get("embeddings"):
57+
import openai
58+
openai_client = openai.Client()
59+
60+
files = state.get("documents")
61+
62+
array_of_embeddings = []
63+
i=0
64+
65+
for file in files:
66+
embeddings = openai_client.embeddings.create(input=file,
67+
model=state.get("embeddings").get("model"))
68+
i+=1
69+
points = PointStruct(
70+
id=i,
71+
vector=embeddings,
72+
payload={"text": file},
73+
)
74+
75+
array_of_embeddings.append(points)
76+
77+
collection_name = "collection"
78+
79+
client.create_collection(
80+
collection_name,
81+
vectors_config=VectorParams(
82+
size=1536,
83+
distance=Distance.COSINE,
84+
),
85+
)
86+
client.upsert(collection_name, points)
87+
88+
state["vectorial_db"] = client
89+
return state
90+
55 91
client.add(
56 92
collection_name="vectorial_collection",
57 93
documents=docs,

0 commit comments

Comments
 (0)