
Commit 55af917

Merge pull request #139 from StarlightSearch/add-siglip
Add-siglip
2 parents (ccb3bf9 + 67c1d39), commit 55af917

File tree

18 files changed, +820 -145 lines changed


Cargo.lock

Lines changed: 1 addition & 0 deletions
Diff not rendered (generated file).

examples/clip.py

Lines changed: 10 additions & 5 deletions
@@ -8,7 +8,7 @@
 # Load the model.
 model = embed_anything.EmbeddingModel.from_pretrained_hf(
     embed_anything.WhichModel.Clip,
-    model_id="openai/clip-vit-base-patch16",
+    model_id="google/siglip-base-patch16-224",
 )
 data: list[EmbedData] = embed_anything.embed_image_directory(
     "test_files", embedder=model
@@ -17,10 +17,8 @@
 # Convert the embeddings to a numpy array
 embeddings = np.array([data.embedding for data in data])
 
-print(data[0])
-
 # Embed a query
-query = ["Photo of a monkey?"]
+query = ["Photo of a monkey"]
 query_embedding = np.array(
     embed_anything.embed_query(query, embedder=model)[0].embedding
 )
@@ -31,7 +29,14 @@
 # Find the index of the most similar embedding
 max_index = np.argmax(similarities)
 
+print("Descending order of similarity: ")
+indices = np.argsort(similarities)[::-1]
+for idx in indices:
+    print(data[idx].text)
+
+print("----------- ")
+
 # Print the most similar image
-print(data[max_index].text)
+print("Most similar image: ", data[max_index].text)
 end = time.time()
 print("Time taken: ", end - start)

examples/cohere_pdf.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+from embed_anything import EmbeddingModel, TextEmbedConfig, WhichModel
+import numpy as np
+from pathlib import Path
+from tabulate import tabulate
+from embed_anything import EmbedData
+from pdf2image import convert_from_path
+
+
+# Initialize the model once
+model: EmbeddingModel = EmbeddingModel.from_pretrained_cloud(
+    WhichModel.CohereVision, model_id="embed-v4.0"
+)
+
+
+# Get all PDF files in the directory
+directory = Path("test_files")
+files = directory.glob("*.pdf")
+# files = [Path("test_files/attention.pdf")]
+
+file_embed_data: list[EmbedData] = []
+for file in files:
+    try:
+        embedding: list[EmbedData] = model.embed_file(
+            str(file), TextEmbedConfig(batch_size=8)
+        )
+        file_embed_data.extend(embedding)
+    except Exception as e:
+        print(f"Error embedding file {file}: {e}")
+
+# Define the query
+query = "What are the Bleu score results for the attention paper?"
+
+# Scoring
+file_embeddings = np.array([e.embedding for e in file_embed_data])
+query_embedding = model.embed_query([query])
+query_embeddings = np.array([e.embedding for e in query_embedding])
+print(file_embeddings.shape)
+print(query_embeddings.shape)
+
+
+scores = np.dot(query_embeddings, file_embeddings.T).squeeze()
+
+# Get top pages
+top_pages = np.argsort(scores)[-5:][::-1].tolist()  # Convert to list
+
+print(top_pages)
+# Extract file names and page numbers
+table = [
+    [
+        file_embed_data[int(page)].metadata["file_path"],
+        file_embed_data[int(page)].metadata["page_number"],
+    ]
+    for page in top_pages
+]
+
+# Print the results in a table
+print(tabulate(table, headers=["File Name", "Page Number"], tablefmt="grid"))
+
+images = [file_embed_data[int(page)].metadata["image"] for page in top_pages]

processors/src/markdown_processor.rs

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ impl MarkdownProcessor {
 impl DocumentProcessor for MarkdownProcessor {
 
     fn process_document(&self, content: &str) -> anyhow::Result<Document> {
-        let chunks = self.splitter.chunks(content).into_iter()
+        let chunks = self.splitter.chunks(content)
             .map(|x| x.to_string())
             .collect();
         Ok(Document {

python/python/embed_anything/_embed_anything.pyi

Lines changed: 23 additions & 5 deletions
@@ -239,7 +239,7 @@ def embed_html(
         file_name: The path to the HTML file to embed.
         embedder: The embedding model to use.
         origin: The origin of the HTML file.
-        config: The configuration for the embedding model. 
+        config: The configuration for the embedding model.
         adapter: The adapter to use for storing the embeddings.
 
     Returns:
@@ -259,7 +259,6 @@ def embed_html(
     ```
     """
 
-
 def embed_audio_file(
     file_path: str,
     audio_decoder: AudioDecoderModel,
@@ -542,19 +541,26 @@ class ImageEmbedConfig:
 
     Attributes:
         buffer_size: The buffer size for the Image Embedding model. Default is 100.
+        batch_size: The batch size for processing the embeddings. Default is 32. Based on the memory, you can increase or decrease the batch size.
     """
 
-    def __init__(self, buffer_size: int | None = None):
+    def __init__(self, buffer_size: int | None = None, batch_size: int | None = None):
         self.buffer_size = buffer_size
+        self.batch_size = batch_size
     buffer_size: int | None
+    batch_size: int | None
 
 class EmbeddingModel:
     """
     Represents an embedding model.
     """
 
     def from_pretrained_hf(
-        model: WhichModel, model_id: str, revision: str | None = None, token: str | None = None, dtype: Dtype | None = None
+        model: WhichModel,
+        model_id: str,
+        revision: str | None = None,
+        token: str | None = None,
+        dtype: Dtype | None = None,
     ) -> EmbeddingModel:
         """
         Loads an embedding model from the Hugging Face model hub.
@@ -586,9 +592,12 @@ class EmbeddingModel:
         Attributes:
             model (WhichModel): The cloud service to use. Currently supports WhichModel.OpenAI and WhichModel.Cohere.
             model_id (str): The ID of the model to use.
+
                 - For OpenAI, see available models at https://platform.openai.com/docs/guides/embeddings/embedding-models
                 - For Cohere, see available models at https://docs.cohere.com/docs/cohere-embed
+                - For CohereVision, see available models at https://docs.cohere.com/docs/cohere-embed
             api_key (str | None, optional): The API key for accessing the model. If not provided, it is taken from the environment variable:
+
                 - For OpenAI: OPENAI_API_KEY
                 - For Cohere: CO_API_KEY
 
@@ -680,6 +689,7 @@ class EmbeddingModel:
         Returns:
            A list of EmbedData objects.
         """
+
    def embed_files_batch(
        self,
        files: list[str],
@@ -697,6 +707,7 @@ class EmbeddingModel:
         Returns:
            A list of EmbedData objects.
         """
+
    def embed_audio_file(
        self,
        audio_file: str,
@@ -714,6 +725,7 @@ class EmbeddingModel:
         Returns:
            A list of EmbedData objects.
         """
+
    def embed_query(
        self,
        query: list[str],
@@ -747,6 +759,7 @@ class EmbeddingModel:
         Returns:
            A list of EmbedData objects.
         """
+
    def embed_directory(
        self,
        directory: str,
@@ -764,6 +777,7 @@ class EmbeddingModel:
         Returns:
            A list of EmbedData objects.
         """
+
    def embed_directory_stream(
        self,
        directory: str,
@@ -781,6 +795,7 @@ class EmbeddingModel:
         Returns:
            A list of EmbedData objects.
         """
+
    def embed_webpage(
        self,
        url: str,
@@ -798,6 +813,7 @@ class EmbeddingModel:
         Returns:
            A list of EmbedData objects.
         """
+
 class AudioDecoderModel:
     """
     Represents an audio decoder model.
@@ -835,13 +851,15 @@ class AudioDecoderModel:
 class WhichModel(Enum):
     OpenAI = ("OpenAI",)
     Cohere = ("Cohere",)
+    CohereVision = ("CohereVision",)
     Bert = ("Bert",)
     Jina = ("Jina",)
     Clip = ("Clip",)
     Colpali = ("Colpali",)
     ColBert = ("ColBert",)
     SparseBert = ("SparseBert",)
     ModernBert = ("ModernBert",)
+
 class ONNXModel(Enum):
     """
     Enum representing various ONNX models.
@@ -952,4 +970,4 @@ class ONNXModel(Enum):
 
     SPLADEPPENV2 = "SPLADEPPENV2"
 
-    ModernBERTBase = "ModernBERTBase"
+    ModernBERTBase = "ModernBERTBase"
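
For context, a minimal sketch of how the new batch_size option on ImageEmbedConfig could be used from Python. The `config` keyword on embed_image_directory is an assumption based on the `config` parameter visible in the python/src/lib.rs binding further down, not something shown in this stub file:

import embed_anything
from embed_anything import EmbeddingModel, ImageEmbedConfig, WhichModel

# Load a vision model, as in examples/clip.py above (SigLIP via the Clip loader).
model = EmbeddingModel.from_pretrained_hf(
    WhichModel.Clip, model_id="google/siglip-base-patch16-224"
)

# New knob from this commit: lower batch_size if memory is tight.
config = ImageEmbedConfig(buffer_size=100, batch_size=16)

# Assumes embed_image_directory accepts the config via a `config` keyword argument.
data = embed_anything.embed_image_directory("test_files", embedder=model, config=config)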

python/src/config.rs

Lines changed: 8 additions & 3 deletions
@@ -73,15 +73,20 @@ pub struct ImageEmbedConfig {
 #[pymethods]
 impl ImageEmbedConfig {
     #[new]
-    #[pyo3(signature = (buffer_size=None))]
-    pub fn new(buffer_size: Option<usize>) -> Self {
+    #[pyo3(signature = (buffer_size=None, batch_size=None))]
+    pub fn new(buffer_size: Option<usize>, batch_size: Option<usize>) -> Self {
         Self {
-            inner: embed_anything::config::ImageEmbedConfig::new(buffer_size),
+            inner: embed_anything::config::ImageEmbedConfig::new(buffer_size, batch_size),
         }
     }
 
     #[getter]
     pub fn buffer_size(&self) -> Option<usize> {
         self.inner.buffer_size
     }
+
+    #[getter]
+    pub fn batch_size(&self) -> Option<usize> {
+        self.inner.batch_size
+    }
 }

python/src/lib.rs

Lines changed: 13 additions & 3 deletions
@@ -82,6 +82,7 @@ impl EmbedData {
 pub enum WhichModel {
     OpenAI,
     Cohere,
+    CohereVision,
     Bert,
     SparseBert,
     ColBert,
@@ -275,6 +276,18 @@ impl EmbeddingModel {
                     inner: Arc::new(model),
                 })
             }
+            WhichModel::CohereVision => {
+                let model_id = model_id.unwrap_or("embed-v4.0");
+                let model = Embedder::Vision(VisionEmbedder::Cohere(
+                    embed_anything::embeddings::cloud::cohere::CohereEmbedder::new(
+                        model_id.to_string(),
+                        api_key,
+                    ),
+                ));
+                Ok(EmbeddingModel {
+                    inner: Arc::new(model),
+                })
+            }
             _ => panic!("Invalid model"),
         }
     }
@@ -668,7 +681,6 @@ pub fn embed_directory(
     let embedding_model = &embedder.inner;
 
     let rt = Builder::new_multi_thread().enable_all().build().unwrap();
-    println!("Runtime created");
     let adapter = match adapter {
         Some(adapter) => {
             let callback = move |data: Vec<embed_anything::embeddings::embed::EmbedData>| {
@@ -725,8 +737,6 @@ pub fn embed_image_directory(
     let embedding_model = &embedder.inner;
     let config = config.map(|c| &c.inner);
     let rt = Builder::new_multi_thread().enable_all().build().unwrap();
-    println!("Runtime created");
-
     let adapter = match adapter {
         Some(adapter) => {
             let callback = move |data: Vec<embed_anything::embeddings::embed::EmbedData>| {
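
The new CohereVision arm is reached from Python through from_pretrained_cloud, exactly as examples/cohere_pdf.py above does. A minimal sketch, assuming the CO_API_KEY environment variable is set (the same variable the existing Cohere model documents):

from embed_anything import EmbeddingModel, WhichModel

# "embed-v4.0" matches the default model_id in the new match arm.
model = EmbeddingModel.from_pretrained_cloud(WhichModel.CohereVision, model_id="embed-v4.0")

# Query embedding from the cloud vision model, as in examples/cohere_pdf.py.
query_embedding = model.embed_query(["Photo of a monkey"])[0].embedding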

rust/Cargo.toml

Lines changed: 3 additions & 0 deletions
@@ -69,6 +69,9 @@ statistical = "1.0.0"
 half = "2.4.1"
 candle-flash-attn = { workspace = true, optional = true }
 
+# Logging
+log = "0.4"
+
 [dev-dependencies]
 tempdir = "0.3.7"
 lazy_static = "1.4.0"

rust/examples/clip.rs

Lines changed: 11 additions & 10 deletions
@@ -11,7 +11,7 @@ async fn main() {
 
     let model = EmbedderBuilder::new()
         .model_architecture("clip")
-        .model_id(Some("openai/clip-vit-base-patch32"))
+        .model_id(Some("google/siglip-base-patch16-224"))
         .revision(None)
         .token(None)
         .from_pretrained_hf()
@@ -22,7 +22,8 @@ async fn main() {
         .unwrap()
         .unwrap();
 
-    let query_emb_data = embed_query(&["Photo of a monkey"], &model, None)
+
+    let query_emb_data = embed_query(&["Photo of a monkey?"], &model, None)
         .await
         .unwrap();
     let n_vectors = out.len();
@@ -68,18 +69,18 @@ async fn main() {
         .unwrap()
         .to_vec1::<f32>()
         .unwrap();
+
     let mut indices: Vec<usize> = (0..similarities.len()).collect();
     indices.sort_by(|a, b| similarities[*b].partial_cmp(&similarities[*a]).unwrap());
+
+    println!("Descending order of similarity: ");
+    for idx in &indices {
+        println!("{}", image_paths[*idx]);
+    }
 
-    let top_3_indices = indices[0..3].to_vec();
-    let top_3_image_paths = top_3_indices
-        .iter()
-        .map(|i| image_paths[*i].clone())
-        .collect::<Vec<String>>();
-
-    let similar_image = top_3_image_paths[0].clone();
+    println!("-----------");
 
-    println!("{:?}", similar_image);
+    println!("Most similar image: {}", image_paths[indices[0]]);
 
     let elapsed_time = now.elapsed();
     println!("Elapsed Time: {}", elapsed_time.as_secs_f32());
