Commit 2b4b5d2

feat: add normalize option (#70)
1 parent 618076e commit 2b4b5d2

File tree

18 files changed (+268 −261 lines)


Cargo.lock

Lines changed: 204 additions & 184 deletions
Diff not rendered (generated file).

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-embeddings-inference"
 
 [patch.crates-io]
-cudarc = { git = "https://github.com/OlivierDehaene/cudarc", rev = "8be6ff46e4a2014fb563570e0d206c09aea88152" }
+cudarc = { git = "https://github.com/OlivierDehaene/cudarc", rev = "c19522f1e411ab453d71bdfad3383b118cd4216f" }
 candle = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-core" }
 candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-nn" }
 candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-transformers" }

backends/candle/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ candle-nn = { version = "0.3.0" }
 candle-transformers = { version = "0.3.0" }
 candle-flash-attn = { version = "0.3.0", optional = true }
 candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "62b75f1ea4e0961fad7b983ee8d723ed6fd68be5", optional = true }
-candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "ffd246552c266640fab217f964a83960e07a66ec", optional = true }
+candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "58684e116aae248c353f87846ddf0b2a8a7ed855", optional = true }
 candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "5ed96012a693dff9685320765dd55a57fdaecdd6", optional = true }
 lazy_static = "^1.4"
 text-embeddings-backend-core = { path = "../core" }

backends/candle/src/models/bert.rs

Lines changed: 1 addition & 4 deletions
@@ -554,10 +554,7 @@ impl BertModel {
             }
         };
 
-        // Normalize
-        let normalized_results = results.broadcast_div(&results.sqr()?.sum_keepdim(1)?.sqrt()?)?;
-
-        Ok(normalized_results)
+        Ok(results)
     }
 }

backends/candle/src/models/bert_quant.rs

Lines changed: 1 addition & 4 deletions
@@ -448,10 +448,7 @@ impl QuantBertModel {
             Pool::Mean => (outputs.sum_keepdim(0)? / (batch.max_length as f64))?,
         };
 
-        // Normalize
-        let normalized_results = results.broadcast_div(&results.sqr()?.sum_keepdim(1)?.sqrt()?)?;
-
-        Ok(normalized_results)
+        Ok(results)
     }
 }

backends/candle/src/models/flash_bert.rs

Lines changed: 1 addition & 4 deletions
@@ -383,10 +383,7 @@ impl FlashBertModel {
             }
         };
 
-        // Normalize
-        let normalized_results = results.broadcast_div(&results.sqr()?.sum_keepdim(1)?.sqrt()?)?;
-
-        Ok(normalized_results)
+        Ok(results)
     }
 }

backends/candle/src/models/jina.rs

Lines changed: 1 addition & 4 deletions
@@ -582,10 +582,7 @@ impl JinaBertModel {
             }
         };
 
-        // Normalize
-        let normalized_results = results.broadcast_div(&results.sqr()?.sum_keepdim(1)?.sqrt()?)?;
-
-        Ok(normalized_results)
+        Ok(results)
    }
 }
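
The four candle models above all drop the same unconditional L2 normalization, so pooled embeddings now leave the model un-normalized. As a minimal sketch of what the removed lines computed, here is that step factored behind a hypothetical `normalize` flag; the flag name, the helper function, and where the option is actually threaded through are assumptions, not part of these hunks (the `use candle::...` path follows the workspace's rename of candle-core to `candle`):

use candle::{Result, Tensor};

// Sketch only: `results` is the pooled embedding tensor of shape [batch, hidden].
// The branch body is exactly the expression removed in the hunks above:
// divide each row by its Euclidean (L2) norm.
fn maybe_normalize(results: Tensor, normalize: bool) -> Result<Tensor> {
    if normalize {
        let norms = results.sqr()?.sum_keepdim(1)?.sqrt()?;
        results.broadcast_div(&norms)
    } else {
        Ok(results)
    }
}

With normalization pulled out of the forward pass, callers that still want unit-length vectors (for cosine similarity, for example) either apply this step themselves or enable the normalize option this commit introduces.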

backends/python/server/text_embeddings_server/models/default_model.py

Lines changed: 1 addition & 3 deletions
@@ -42,9 +42,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
 
         output = self.model(**kwargs)
         embedding = output[0][:, 0]
-        results = torch.nn.functional.normalize(embedding, p=2, dim=1)
-
-        cpu_results = results.view(-1).tolist()
+        cpu_results = embedding.view(-1).tolist()
 
         return [
             Embedding(

backends/python/server/text_embeddings_server/models/flash_bert.py

Lines changed: 1 addition & 2 deletions
@@ -243,8 +243,7 @@ def embed(self, batch: FlashBatch) -> List[Embedding]:
             cu_seqlens=batch.cu_seqlens,
             max_s=batch.max_s,
         )
-        results = torch.nn.functional.normalize(embedding, p=2, dim=1)
-        cpu_results = results.view(-1).tolist()
+        cpu_results = embedding.view(-1).tolist()
 
         return [
             Embedding(

backends/python/server/text_embeddings_server/utils/tracing.py

Lines changed: 3 additions & 1 deletion
@@ -55,7 +55,9 @@ def _start_span(self, handler_call_details, context, set_status_on_exception=Fal
 
 
 def setup_tracing(otlp_endpoint: str):
-    resource = Resource.create(attributes={"service.name": f"text-embeddings-inference.server"})
+    resource = Resource.create(
+        attributes={"service.name": f"text-embeddings-inference.server"}
+    )
     span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
    span_processor = BatchSpanProcessor(span_exporter)
 
