
Commit 80791e8

Merge pull request #137 from anaregdesign/feature/similarity

feature: similarity

2 parents 2dda416 + 16c19fc

4 files changed, +94 -1 lines changed


src/openaivec/pandas_ext.py

Lines changed: 12 additions & 0 deletions

@@ -37,6 +37,7 @@
 import logging
 from typing import Awaitable, Callable, Type, TypeVar
 
+import numpy as np
 import pandas as pd
 from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
 from pydantic import BaseModel
@@ -67,6 +68,11 @@
 _TIKTOKEN_ENCODING = tiktoken.encoding_for_model(_RESPONSES_MODEL_NAME)
 
 
+# Internal helper for accessing the .ai accessor in Spark UDFs.
+def _wakeup() -> None:
+    pass
+
+
 def use(client: OpenAI) -> None:
     """Register a custom OpenAI‑compatible client.
 
@@ -460,6 +466,12 @@ def responses(
             )
         )
 
+    def similarity(self, col1: str, col2: str) -> pd.Series:
+        return self._obj.apply(
+            lambda row: np.dot(row[col1], row[col2]) / (np.linalg.norm(row[col1]) * np.linalg.norm(row[col2])),
+            axis=1,
+        ).rename("similarity")
+
 
 @pd.api.extensions.register_series_accessor("aio")
 class AsyncOpenAIVecSeriesAccessor:
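
For context, a minimal usage sketch of the new DataFrame-level similarity method. It assumes the embeddings are already materialized as NumPy arrays in two columns, as in the unit tests further down; the import used to register the .ai accessor is an assumption.

    import numpy as np
    import pandas as pd

    from openaivec import pandas_ext  # noqa: F401  (assumed import path; registers the .ai accessor)

    df = pd.DataFrame(
        {
            "left": [np.array([1.0, 0.0]), np.array([0.0, 1.0])],
            "right": [np.array([1.0, 0.0]), np.array([1.0, 1.0])],
        }
    )

    # Row-wise cosine similarity between the two vector columns,
    # returned as a Series named "similarity".
    scores = df.ai.similarity("left", "right")
    print(scores.tolist())  # [1.0, 0.7071...]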

src/openaivec/spark.py

Lines changed: 18 additions & 0 deletions

@@ -458,3 +458,21 @@ def fn(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
         yield part.map(lambda x: len(_TIKTOKEN_ENC.encode(x)) if isinstance(x, str) else 0)
 
     return fn
+
+
+def similarity_udf() -> UserDefinedFunction:
+    @pandas_udf(FloatType())
+    def fn(a: pd.Series, b: pd.Series) -> pd.Series:
+        """Compute cosine similarity between two vectors.
+
+        Args:
+            a: First vector.
+            b: Second vector.
+
+        Returns:
+            Cosine similarity between the two vectors.
+        """
+        pandas_ext._wakeup()
+        return pd.DataFrame({"a": a, "b": b}).ai.similarity("a", "b")
+
+    return fn
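
The Spark UDF delegates to the pandas accessor above. A minimal sketch of registering and calling it from Spark SQL, mirroring the test further down; the column names and local SparkSession setup are illustrative:

    from pyspark.sql import SparkSession

    from openaivec.spark import similarity_udf

    spark = SparkSession.builder.getOrCreate()
    spark.udf.register("similarity", similarity_udf())

    df = spark.createDataFrame(
        [(1, [0.1, 0.2, 0.3], [0.3, 0.2, 0.1])],
        ["id", "a", "b"],
    )
    df.createOrReplaceTempView("pairs")

    # Row-wise cosine similarity of the two array columns.
    spark.sql("SELECT id, similarity(a, b) AS score FROM pairs").show()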

tests/test_pandas_ext.py

Lines changed: 30 additions & 0 deletions

@@ -203,3 +203,33 @@ def test_count_tokens(self):
 
         # assert all values are elements of int
         self.assertTrue(all(isinstance(num_token, int) for num_token in num_tokens))
+
+    def test_similarity(self):
+        sample_df = pd.DataFrame(
+            {
+                "vector1": [np.array([1, 0]), np.array([0, 1]), np.array([1, 1])],
+                "vector2": [np.array([1, 0]), np.array([0, 1]), np.array([1, -1])],
+            }
+        )
+        similarity_scores = sample_df.ai.similarity("vector1", "vector2")
+
+        # Expected cosine similarity values
+        expected_scores = [
+            1.0,  # Cosine similarity between [1, 0] and [1, 0]
+            1.0,  # Cosine similarity between [0, 1] and [0, 1]
+            0.0,  # Cosine similarity between [1, 1] and [1, -1]
+        ]
+
+        # Assert similarity scores match expected values
+        self.assertTrue(np.allclose(similarity_scores, expected_scores))
+
+    def test_similarity_with_invalid_vectors(self):
+        sample_df = pd.DataFrame(
+            {
+                "vector1": [np.array([1, 0]), "invalid", np.array([1, 1])],
+                "vector2": [np.array([1, 0]), np.array([0, 1]), np.array([1, -1])],
+            }
+        )
+
+        with self.assertRaises(TypeError):
+            sample_df.ai.similarity("vector1", "vector2")

tests/test_spark.py

Lines changed: 34 additions & 1 deletion

@@ -6,7 +6,13 @@
 from pyspark.sql.session import SparkSession
 from pyspark.sql.types import ArrayType, FloatType, IntegerType, StringType, StructField, StructType
 
-from openaivec.spark import EmbeddingsUDFBuilder, ResponsesUDFBuilder, _pydantic_to_spark_schema, count_tokens_udf
+from openaivec.spark import (
+    EmbeddingsUDFBuilder,
+    ResponsesUDFBuilder,
+    _pydantic_to_spark_schema,
+    count_tokens_udf,
+    similarity_udf,
+)
 
 
 class TestUDFBuilder(TestCase):
@@ -143,3 +149,30 @@ def test_count_token(self):
             SELECT sentence, count_tokens(sentence) as token_count from sentences
             """
         ).show(truncate=False)
+
+
+class TestSimilarityUDF(TestCase):
+    def setUp(self):
+        self.spark: SparkSession = SparkSession.builder.getOrCreate()
+        self.spark.sparkContext.setLogLevel("INFO")
+        self.spark.udf.register("similarity", similarity_udf())
+
+    def test_similarity(self):
+        df = self.spark.createDataFrame(
+            [
+                (1, [0.1, 0.2, 0.3]),
+                (2, [0.4, 0.5, 0.6]),
+                (3, [0.7, 0.8, 0.9]),
+            ],
+            ["id", "vector"],
+        )
+        df.createOrReplaceTempView("vectors")
+        result_df = self.spark.sql(
+            """
+            SELECT id, similarity(vector, vector) as similarity_score
+            FROM vectors
+            """
+        )
+        result_df.show(truncate=False)
+        df_pandas = result_df.toPandas()
+        assert df_pandas.shape == (3, 2)
