Skip to content

Commit 32ab369

Browse files
committed
fixes
1 parent b004994 commit 32ab369

File tree

9 files changed

+118
-208
lines changed

9 files changed

+118
-208
lines changed

golem-embed/embed-cohere/src/client.rs

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -59,32 +59,22 @@ impl EmbeddingsApi {
5959

6060
fn parse_response<T: DeserializeOwned + Debug>(response: Response) -> Result<T, Error> {
6161
let status = response.status();
62-
match status.is_success() {
63-
true => {
64-
let response_text = response.text().map_err(|err| from_reqwest_error("Failed to read response body", err))?;
65-
match serde_json::from_str::<T>(&response_text) {
66-
Ok(response_data) => {
67-
trace!("Response from Cohere API: {response_data:?}");
68-
Ok(response_data)
69-
}
70-
Err(error) => {
71-
trace!("Error parsing response: {error:?}");
72-
Err(Error {
73-
code: error_code_from_status(status),
74-
message: format!("Failed to decode response body: {}", response_text),
75-
provider_error_json: Some(error.to_string()),
76-
})
77-
}
78-
}
79-
},
80-
false => {
81-
let error_text = response.text().ok();
62+
let response_text = response
63+
.text()
64+
.map_err(|err| from_reqwest_error("Failed to read response body", err))?;
65+
match serde_json::from_str::<T>(&response_text) {
66+
Ok(response_data) => {
67+
trace!("Response from Cohere API: {response_data:?}");
68+
Ok(response_data)
69+
}
70+
Err(error) => {
71+
trace!("Error parsing response: {error:?}");
8272
Err(Error {
8373
code: error_code_from_status(status),
84-
message: "Failed to parse response".to_string(),
85-
provider_error_json: error_text,
74+
message: format!("Failed to decode response body: {}", response_text),
75+
provider_error_json: Some(error.to_string()),
8676
})
87-
},
77+
}
8878
}
8979
}
9080

golem-embed/embed-hugging-face/src/client.rs

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -38,74 +38,51 @@ impl EmbeddingsApi {
3838
trace!("Sending request to Hugging Face API: {request:?}");
3939
let response = self
4040
.client
41-
.request(Method::POST, format!(
42-
"{BASE_URL}/models/{model}/pipeline/feature-extraction"
43-
))
41+
.request(
42+
Method::POST,
43+
format!("{BASE_URL}/models/{model}/pipeline/feature-extraction"),
44+
)
4445
.bearer_auth(&self.huggingface_api_key)
4546
.json(&request)
4647
.send()
4748
.map_err(|err| from_reqwest_error("Request failed", err))?;
4849
parse_response::<EmbeddingResponse>(response)
4950
}
5051

51-
52-
pub fn rerank(&self, request: RerankRequest, model: &str) -> Result<RerankResponse, Error> {
53-
trace!("Sending rerank request to Hugging Face API: {request:?}");
54-
let response = self
55-
.client
56-
.request(Method::POST, format!(
57-
"{BASE_URL}/models/{model}/pipeline/text-classification"
58-
))
59-
.bearer_auth(&self.huggingface_api_key)
60-
.json(&request)
61-
.send()
62-
.map_err(|err| from_reqwest_error("Request failed", err))?;
63-
parse_response::<RerankResponse>(response)
64-
}
6552
}
6653

6754
fn parse_response<T: DeserializeOwned + Debug>(response: Response) -> Result<T, Error> {
6855
let status = response.status();
69-
match status.is_success() {
70-
true => {
71-
let response_text = response.text().map_err(|err| from_reqwest_error("Failed to read response body", err))?;
72-
match serde_json::from_str::<T>(&response_text) {
73-
Ok(response_data) => {
74-
trace!("Response from Hugging Face API: {response_data:?}");
75-
Ok(response_data)
76-
}
77-
Err(error) => {
78-
trace!("Error parsing response: {error:?}");
79-
Err(Error {
80-
code: error_code_from_status(status),
81-
message: format!("Failed to decode response body: {}", response_text),
82-
provider_error_json: Some(error.to_string()),
83-
})
84-
}
85-
}
86-
},
87-
false => {
88-
let error_text = response.text().ok();
56+
let response_text = response
57+
.text()
58+
.map_err(|err| from_reqwest_error("Failed to read response body", err))?;
59+
match serde_json::from_str::<T>(&response_text) {
60+
Ok(response_data) => {
61+
trace!("Response from Hugging Face API: {response_data:?}");
62+
Ok(response_data)
63+
}
64+
Err(error) => {
65+
trace!("Error parsing response: {error:?}");
8966
Err(Error {
9067
code: error_code_from_status(status),
91-
message: "Failed to parse response".to_string(),
92-
provider_error_json: error_text,
68+
message: format!("Failed to decode response body: {}", response_text),
69+
provider_error_json: Some(error.to_string()),
9370
})
94-
},
71+
}
9572
}
9673
}
9774

9875
#[derive(Debug, Clone, Serialize, Deserialize)]
9976
pub struct EmbeddingRequest {
100-
pub input: Vec<String>,
77+
pub inputs: Vec<String>,
10178
#[serde(skip_serializing_if = "Option::is_none")]
10279
pub normalize: Option<bool>,
10380
/// The name of the prompt that should be used for encoding.
10481
/// If not set, no prompt will be applied. Must be a key in the
10582
/// `sentence-transformers` configuration `prompts` dictionary.
106-
/// For example if `prompt_name` is query and the `prompts` is {query”: “query: , …},
107-
/// then the sentence What is the capital of France? will be encoded as
108-
/// query: What is the capital of France? because the prompt text will
83+
/// For example if `prompt_name` is "query" and the `prompts` is {"query": "query: ", …},
84+
/// then the sentence "What is the capital of France?" will be encoded as
85+
/// "query: What is the capital of France?" because the prompt text will
10986
/// be prepended before any text to encode.
11087
#[serde(skip_serializing_if = "Option::is_none")]
11188
pub prompt_name: Option<String>,

golem-embed/embed-hugging-face/src/conversions.rs

Lines changed: 2 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ pub fn create_embedding_request(inputs: Vec<ContentPart>, config: Config) -> Res
2222
.unwrap_or_else(|| "sentence-transformers/all-MiniLM-L6-v2".to_string());
2323

2424
let request = EmbeddingRequest {
25-
input: input_texts,
25+
inputs: input_texts,
2626
normalize: Some(true),
2727
prompt_name: None,
2828
truncate: config.truncation,
@@ -52,46 +52,6 @@ pub fn process_embedding_response(
5252
})
5353
}
5454

55-
pub fn create_rerank_request(
56-
query: String,
57-
documents: Vec<String>,
58-
config: Config,
59-
) -> Result<(RerankRequest, String), Error> {
60-
let model = config
61-
.model
62-
.unwrap_or_else(|| "cross-encoder/ms-marco-MiniLM-L-2-v2".to_string());
63-
64-
let request = RerankRequest {
65-
query,
66-
documents,
67-
top_k: config.dimensions,
68-
return_documents: Some(true),
69-
};
70-
71-
Ok((request, model))
72-
}
73-
74-
pub fn process_rerank_response(
75-
response: RerankResponse,
76-
model: String,
77-
) -> Result<GolemRerankResponse, Error> {
78-
let mut results = Vec::new();
79-
for result in response.results {
80-
results.push(golem_embed::golem::embed::embed::RerankResult {
81-
index: result.index,
82-
relevance_score: result.relevance_score,
83-
document: result.document,
84-
});
85-
}
86-
87-
Ok(GolemRerankResponse {
88-
results,
89-
usage: None,
90-
model,
91-
provider_metadata_json: None,
92-
})
93-
}
94-
9555
#[cfg(test)]
9656
mod tests {
9757
use golem_embed::golem::embed::embed::{ImageUrl, OutputDtype, OutputFormat, TaskType};
@@ -113,7 +73,7 @@ mod tests {
11373
};
11474
let result = create_embedding_request(inputs, config);
11575
let (request, model) = result.unwrap();
116-
assert_eq!(request.input, vec!["Hello, world!"]);
76+
assert_eq!(request.inputs, vec!["Hello, world!"]);
11777
assert_eq!(model, "sentence-transformers/all-MiniLM-L6-v2");
11878
assert_eq!(request.normalize, Some(true));
11979
assert_eq!(request.truncate, Some(false));
@@ -152,27 +112,4 @@ mod tests {
152112
let request = create_embedding_request(inputs, config);
153113
assert!(request.is_err());
154114
}
155-
156-
#[test]
157-
fn test_create_rerank_request() {
158-
let query = "What is the capital of France?".to_string();
159-
let documents = vec!["Paris is the capital of France.".to_string()];
160-
let config = Config {
161-
model: Some("cross-encoder/ms-marco-MiniLM-L-2-v2".to_string()),
162-
dimensions: Some(10),
163-
user: Some("test_user".to_string()),
164-
task_type: Some(TaskType::RetrievalQuery),
165-
truncation: Some(false),
166-
output_format: Some(OutputFormat::FloatArray),
167-
output_dtype: Some(OutputDtype::FloatArray),
168-
provider_options: vec![],
169-
};
170-
let result = create_rerank_request(query.clone(), documents.clone(), config);
171-
let (request, model) = result.unwrap();
172-
assert_eq!(request.query, query);
173-
assert_eq!(request.documents, documents);
174-
assert_eq!(request.top_k, Some(10));
175-
assert_eq!(request.return_documents, Some(true));
176-
assert_eq!(model, "cross-encoder/ms-marco-MiniLM-L-2-v2");
177-
}
178115
}

golem-embed/embed-hugging-face/src/lib.rs

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,7 @@ impl HuggingFaceComponent {
3232
}
3333
}
3434

35-
fn rerank(
36-
client: EmbeddingsApi,
37-
query: String,
38-
documents: Vec<String>,
39-
config: Config,
40-
) -> Result<RerankResponse, Error> {
41-
let (request, model) = create_rerank_request(query, documents, config)?;
42-
match client.rerank(request, &model) {
43-
Ok(response) => process_rerank_response(response, model),
44-
Err(err) => Err(err),
45-
}
46-
}
35+
4736
}
4837

4938
impl Guest for HuggingFaceComponent {
@@ -60,10 +49,10 @@ impl Guest for HuggingFaceComponent {
6049
documents: Vec<String>,
6150
config: Config,
6251
) -> Result<RerankResponse, Error> {
63-
LOGGING_STATE.with_borrow_mut(|state| state.init());
64-
with_config_key(Self::ENV_VAR_NAME, Err, |huggingface_api_key| {
65-
let client = EmbeddingsApi::new(huggingface_api_key);
66-
Self::rerank(client, query, documents, config)
52+
Err(Error {
53+
code: ErrorCode::Unsupported,
54+
message: "Hugging Face inference does not support rerank".to_string(),
55+
provider_error_json: None,
6756
})
6857
}
6958
}

golem-embed/embed-openai/src/client.rs

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,32 +49,22 @@ impl EmbeddingsApi {
4949

5050
fn parse_response<T: DeserializeOwned + Debug>(response: Response) -> Result<T, Error> {
5151
let status = response.status();
52-
match status.is_success() {
53-
true => {
54-
let response_text = response.text().map_err(|err| from_reqwest_error("Failed to read response body", err))?;
55-
match serde_json::from_str::<T>(&response_text) {
56-
Ok(response_data) => {
57-
trace!("Response from OpenAI API: {response_data:?}");
58-
Ok(response_data)
59-
}
60-
Err(error) => {
61-
trace!("Error parsing response: {error:?}");
62-
Err(Error {
63-
code: error_code_from_status(status),
64-
message: format!("Failed to decode response body: {}", response_text),
65-
provider_error_json: Some(error.to_string()),
66-
})
67-
}
68-
}
69-
},
70-
false => {
71-
let error_text = response.text().ok();
52+
let response_text = response
53+
.text()
54+
.map_err(|err| from_reqwest_error("Failed to read response body", err))?;
55+
match serde_json::from_str::<T>(&response_text) {
56+
Ok(response_data) => {
57+
trace!("Response from OpenAI API: {response_data:?}");
58+
Ok(response_data)
59+
}
60+
Err(error) => {
61+
trace!("Error parsing response: {error:?}");
7262
Err(Error {
7363
code: error_code_from_status(status),
74-
message: "Failed to parse response".to_string(),
75-
provider_error_json: error_text,
64+
message: format!("Failed to decode response body: {}", response_text),
65+
provider_error_json: Some(error.to_string()),
7666
})
77-
},
67+
}
7868
}
7969
}
8070

0 commit comments

Comments
 (0)