Skip to content

Commit eec22c9

Browse files
committed
doc mapper: convert number handling to deserialization
change number deserialization in docmapper from json to generic deserialization. This improves codes reuse between different code paths, e.g. serialization and validation.
1 parent a91d2e7 commit eec22c9

File tree

6 files changed

+252
-123
lines changed

6 files changed

+252
-123
lines changed

quickwit/Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

quickwit/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ sea-query-binder = { version = "0.5", features = [
212212
# ^1.0.184 due to serde-rs/serde#2538
213213
serde = { version = "1.0.184", features = ["derive", "rc"] }
214214
serde_json = "1.0"
215-
serde_json_borrow = "0.5"
215+
serde_json_borrow = "0.7"
216216
serde_qs = { version = "0.12", features = ["warp"] }
217217
serde_with = "3.9.0"
218218
serde_yaml = "0.9"
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
// Copyright (C) 2024 Quickwit, Inc.
2+
//
3+
// Quickwit is offered under the AGPL v3.0 and as commercial software.
4+
// For commercial licensing, contact us at hello@quickwit.io.
5+
//
6+
// AGPL:
7+
// This program is free software: you can redistribute it and/or modify
8+
// it under the terms of the GNU Affero General Public License as
9+
// published by the Free Software Foundation, either version 3 of the
10+
// License, or (at your option) any later version.
11+
//
12+
// This program is distributed in the hope that it will be useful,
13+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
// GNU Affero General Public License for more details.
16+
//
17+
// You should have received a copy of the GNU Affero General Public License
18+
// along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
20+
use std::fmt::{self, Display};
21+
22+
use serde::de::{self, Deserializer, IntoDeserializer, Visitor};
23+
use serde::Deserialize;
24+
use serde_json::Value;
25+
26+
/// Deserialize a number.
27+
///
28+
/// If the value is a string, it can be optionally coerced to a number.
29+
fn deserialize_num_with_coerce<'de, T, D>(deserializer: D, coerce: bool) -> Result<T, String>
30+
where
31+
T: std::str::FromStr + Deserialize<'de>,
32+
T::Err: fmt::Display,
33+
D: Deserializer<'de>,
34+
{
35+
struct CoerceVisitor<T> {
36+
coerce: bool,
37+
marker: std::marker::PhantomData<T>,
38+
}
39+
40+
impl<'de, T> Visitor<'de> for CoerceVisitor<T>
41+
where
42+
T: std::str::FromStr + Deserialize<'de>,
43+
T::Err: fmt::Display,
44+
{
45+
type Value = T;
46+
47+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
48+
if self.coerce {
49+
formatter
50+
.write_str("any number of i64, u64, or f64 or a string that can be coerced")
51+
} else {
52+
formatter.write_str("any number of i64, u64, or f64")
53+
}
54+
}
55+
56+
fn visit_str<E>(self, v: &str) -> Result<T, E>
57+
where E: de::Error {
58+
if self.coerce {
59+
v.parse::<T>().map_err(|_e| {
60+
de::Error::custom(format!(
61+
"failed to coerce JSON string `\"{}\"` to {}",
62+
v,
63+
std::any::type_name::<T>(),
64+
))
65+
})
66+
} else {
67+
Err(de::Error::custom(format!(
68+
"expected JSON number, got string `\"{}\"`. enable coercion to {} with the \
69+
`coerce` parameter in the field mapping",
70+
v,
71+
std::any::type_name::<T>()
72+
)))
73+
}
74+
}
75+
76+
fn visit_i64<E>(self, v: i64) -> Result<T, E>
77+
where E: de::Error {
78+
T::deserialize(v.into_deserializer()).map_err(|_: E| {
79+
de::Error::custom(format!(
80+
"expected {}, got inconvertible JSON number `{}`",
81+
std::any::type_name::<T>(),
82+
v
83+
))
84+
})
85+
}
86+
87+
fn visit_u64<E>(self, v: u64) -> Result<T, E>
88+
where E: de::Error {
89+
T::deserialize(v.into_deserializer()).map_err(|_: E| {
90+
de::Error::custom(format!(
91+
"expected {}, got inconvertible JSON number `{}`",
92+
std::any::type_name::<T>(),
93+
v
94+
))
95+
})
96+
}
97+
98+
fn visit_f64<E>(self, v: f64) -> Result<T, E>
99+
where E: de::Error {
100+
T::deserialize(v.into_deserializer()).map_err(|_: E| {
101+
de::Error::custom(format!(
102+
"expected {}, got inconvertible JSON number `{}`",
103+
std::any::type_name::<T>(),
104+
v
105+
))
106+
})
107+
}
108+
109+
fn visit_map<M>(self, mut map: M) -> Result<T, M::Error>
110+
where M: de::MapAccess<'de> {
111+
let json_value: Value =
112+
Deserialize::deserialize(de::value::MapAccessDeserializer::new(&mut map))?;
113+
Err(de::Error::custom(error_message(json_value, self.coerce)))
114+
}
115+
116+
fn visit_seq<S>(self, mut seq: S) -> Result<T, S::Error>
117+
where S: de::SeqAccess<'de> {
118+
let json_value: Value =
119+
Deserialize::deserialize(de::value::SeqAccessDeserializer::new(&mut seq))?;
120+
Err(de::Error::custom(error_message(json_value, self.coerce)))
121+
}
122+
123+
fn visit_none<E>(self) -> Result<Self::Value, E>
124+
where E: de::Error {
125+
Err(de::Error::custom(error_message("null", self.coerce)))
126+
}
127+
128+
fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
129+
where E: de::Error {
130+
Err(de::Error::custom(error_message(v, self.coerce)))
131+
}
132+
}
133+
134+
deserializer
135+
.deserialize_any(CoerceVisitor {
136+
coerce,
137+
marker: std::marker::PhantomData,
138+
})
139+
.map_err(|err| err.to_string())
140+
}
141+
142+
fn error_message<T: Display>(got: T, coerce: bool) -> String {
143+
if coerce {
144+
format!("expected JSON number or string, got `{}`", got)
145+
} else {
146+
format!("expected JSON, got `{}`", got)
147+
}
148+
}
149+
150+
pub fn deserialize_i64<'de, D>(deserializer: D, coerce: bool) -> Result<i64, String>
151+
where D: Deserializer<'de> {
152+
deserialize_num_with_coerce(deserializer, coerce)
153+
}
154+
155+
pub fn deserialize_u64<'de, D>(deserializer: D, coerce: bool) -> Result<u64, String>
156+
where D: Deserializer<'de> {
157+
deserialize_num_with_coerce(deserializer, coerce)
158+
}
159+
160+
pub fn deserialize_f64<'de, D>(deserializer: D, coerce: bool) -> Result<f64, String>
161+
where D: Deserializer<'de> {
162+
deserialize_num_with_coerce(deserializer, coerce)
163+
}
164+
165+
#[cfg(test)]
166+
mod tests {
167+
use serde_json::json;
168+
169+
use super::*;
170+
171+
#[test]
172+
fn test_deserialize_i64_with_coercion() {
173+
let json_data = json!("-123");
174+
let result: i64 = deserialize_i64(json_data.into_deserializer(), true).unwrap();
175+
assert_eq!(result, -123);
176+
177+
let json_data = json!("456");
178+
let result: i64 = deserialize_i64(json_data.into_deserializer(), true).unwrap();
179+
assert_eq!(result, 456);
180+
}
181+
182+
#[test]
183+
fn test_deserialize_u64_with_coercion() {
184+
let json_data = json!("789");
185+
let result: u64 = deserialize_u64(json_data.into_deserializer(), true).unwrap();
186+
assert_eq!(result, 789);
187+
188+
let json_data = json!(123);
189+
let result: u64 = deserialize_u64(json_data.into_deserializer(), false).unwrap();
190+
assert_eq!(result, 123);
191+
}
192+
193+
#[test]
194+
fn test_deserialize_f64_with_coercion() {
195+
let json_data = json!("78.9");
196+
let result: f64 = deserialize_f64(json_data.into_deserializer(), true).unwrap();
197+
assert_eq!(result, 78.9);
198+
199+
let json_data = json!(45.6);
200+
let result: f64 = deserialize_f64(json_data.into_deserializer(), false).unwrap();
201+
assert_eq!(result, 45.6);
202+
}
203+
204+
#[test]
205+
fn test_deserialize_invalid_string_coercion() {
206+
let json_data = json!("abc");
207+
let result: Result<i64, _> = deserialize_i64(json_data.into_deserializer(), true);
208+
assert!(result.is_err());
209+
210+
let err_msg = result.unwrap_err().to_string();
211+
assert_eq!(err_msg, "failed to coerce JSON string `\"abc\"` to i64");
212+
}
213+
214+
#[test]
215+
fn test_deserialize_json_object() {
216+
let json_data = json!({ "key": "value" });
217+
let result: Result<i64, _> = deserialize_i64(json_data.into_deserializer(), true);
218+
assert!(result.is_err());
219+
220+
let err_msg = result.unwrap_err().to_string();
221+
assert_eq!(
222+
err_msg,
223+
"expected JSON number or string, got `{\"key\":\"value\"}`"
224+
);
225+
}
226+
}

quickwit/quickwit-doc-mapper/src/doc_mapper/doc_mapper_impl.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1846,7 +1846,7 @@ mod tests {
18461846
}"#,
18471847
"concat",
18481848
r#"{"some_int": 25}"#,
1849-
vec![25_u64.into()],
1849+
vec![25_i64.into()],
18501850
);
18511851
}
18521852

0 commit comments

Comments
 (0)