Skip to content

Commit 8eae031

Browse files
authored
feat(ai): ModelCosts version 2 (#4825)
Introduces a new field `models` to the ModelCosts struct, which will be used to calculate the cost of LLM calls by the type of tokens used. The current solution was not flexible enough: it only supported 2 types of tokens, input and output. Cost calculation has gotten more complex in the year since this was first implemented, and we need a more flexible structure that will also support future changes in cost calculation. This PR only introduces the new struct; the cost calculation will be changed in a separate PR to keep this PR smaller and easier to review. Part of [TET-645: Automate cost/pricing updating](https://linear.app/getsentry/issue/TET-645/automate-costpricing-updating)
1 parent 6d87298 commit 8eae031

File tree

5 files changed

+282
-35
lines changed

5 files changed

+282
-35
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
**Internal**:
1212

1313
- Always combine replay payloads and remove feature flag guarding it. ([#4812](https://github.com/getsentry/relay/pull/4812))
14+
- Added version 2 of LLM cost specification. ([#4825](https://github.com/getsentry/relay/pull/4825))
1415

1516
## 25.6.0
1617

relay-dynamic-config/src/global.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ pub struct GlobalConfig {
4747
pub metric_extraction: ErrorBoundary<MetricExtractionGroups>,
4848

4949
/// Configuration for AI span measurements.
50-
#[serde(skip_serializing_if = "is_missing")]
50+
#[serde(skip_serializing_if = "is_model_costs_empty")]
5151
pub ai_model_costs: ErrorBoundary<ModelCosts>,
5252

5353
/// Configuration to derive the `span.op` from other span fields.
@@ -378,11 +378,8 @@ fn is_ok_and_empty(value: &ErrorBoundary<MetricExtractionGroups>) -> bool {
378378
)
379379
}
380380

381-
fn is_missing(value: &ErrorBoundary<ModelCosts>) -> bool {
382-
matches!(
383-
value,
384-
&ErrorBoundary::Ok(ModelCosts{ version, ref costs }) if version == 0 && costs.is_empty()
385-
)
381+
fn is_model_costs_empty(value: &ErrorBoundary<ModelCosts>) -> bool {
382+
matches!(value, ErrorBoundary::Ok(model_costs) if model_costs.is_empty())
386383
}
387384

388385
#[cfg(test)]

relay-event-normalization/src/event.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,6 +1498,7 @@ fn normalize_app_start_measurements(measurements: &mut Measurements) {
14981498
mod tests {
14991499

15001500
use std::collections::BTreeMap;
1501+
use std::collections::HashMap;
15011502

15021503
use insta::assert_debug_snapshot;
15031504
use itertools::Itertools;
@@ -2261,6 +2262,7 @@ mod tests {
22612262
cost_per_1k_tokens: 20.0,
22622263
},
22632264
],
2265+
models: HashMap::new(),
22642266
}),
22652267
..NormalizationConfig::default()
22662268
},
@@ -2348,6 +2350,7 @@ mod tests {
23482350
cost_per_1k_tokens: 20.0,
23492351
},
23502352
],
2353+
models: HashMap::new(),
23512354
}),
23522355
..NormalizationConfig::default()
23532356
},

relay-event-normalization/src/normalize/mod.rs

Lines changed: 253 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::HashMap;
12
use std::hash::Hash;
23

34
use relay_base_schema::metrics::MetricUnit;
@@ -217,35 +218,94 @@ pub struct PerformanceScoreConfig {
217218
}
218219

219220
/// A mapping of AI model types (like GPT-4) to their respective costs.
221+
///
222+
/// This struct supports multiple versions with different cost structures:
223+
/// - Version 1: Array-based costs with glob pattern matching for model IDs (uses `costs` field)
224+
/// - Version 2: Dictionary-based costs with exact model ID keys and granular token pricing (uses `models` field)
225+
///
226+
/// Example V1 JSON:
227+
/// ```json
228+
/// {
229+
/// "version": 1,
230+
/// "costs": [
231+
/// {
232+
/// "modelId": "gpt-4*",
233+
/// "forCompletion": false,
234+
/// "costPer1kTokens": 0.03
235+
/// }
236+
/// ]
237+
/// }
238+
/// ```
239+
///
240+
/// Example V2 JSON:
241+
/// ```json
242+
/// {
243+
/// "version": 2,
244+
/// "models": {
245+
/// "gpt-4": {
246+
/// "inputPerToken": 0.03,
247+
/// "outputPerToken": 0.06,
248+
/// "outputReasoningPerToken": 0.12,
249+
/// "inputCachedPerToken": 0.015
250+
/// }
251+
/// }
252+
/// }
253+
/// ```
220254
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
221255
#[serde(rename_all = "camelCase")]
222256
pub struct ModelCosts {
223257
/// The version of the model cost struct
224258
pub version: u16,
225259

226-
/// The mappings of model ID => cost
227-
#[serde(skip_serializing_if = "Vec::is_empty")]
260+
/// The mappings of model ID => cost (used in version 1)
261+
#[serde(default, skip_serializing_if = "Vec::is_empty")]
228262
pub costs: Vec<ModelCost>,
263+
264+
/// The mappings of model ID => cost as a dictionary (version 2)
265+
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
266+
pub models: HashMap<String, ModelCostV2>,
229267
}
230268

231269
impl ModelCosts {
232-
const MAX_SUPPORTED_VERSION: u16 = 1;
270+
const MAX_SUPPORTED_VERSION: u16 = 2;
271+
272+
/// `true` if the model costs are empty and the version is supported.
273+
pub fn is_empty(&self) -> bool {
274+
(self.costs.is_empty() && self.models.is_empty()) || !self.is_enabled()
275+
}
233276

234277
/// `false` if measurement and metrics extraction should be skipped.
235278
pub fn is_enabled(&self) -> bool {
236279
self.version > 0 && self.version <= ModelCosts::MAX_SUPPORTED_VERSION
237280
}
238281

239-
/// Gets the cost per 1000 tokens, if defined for the given model.
240-
pub fn cost_per_1k_tokens(&self, model_id: &str, for_completion: bool) -> Option<f64> {
241-
self.costs
242-
.iter()
243-
.find(|cost| cost.matches(model_id, for_completion))
244-
.map(|c| c.cost_per_1k_tokens)
282+
/// Gets the cost per token, if defined for the given model.
283+
pub fn cost_per_token(&self, model_id: &str) -> Option<ModelCostV2> {
284+
match self.version {
285+
1 => {
286+
let input_cost = self.costs.iter().find(|cost| cost.matches(model_id, false));
287+
let output_cost = self.costs.iter().find(|cost| cost.matches(model_id, true));
288+
289+
// V1 costs were defined per 1k tokens, so we need to convert to per token.
290+
if input_cost.is_some() || output_cost.is_some() {
291+
Some(ModelCostV2 {
292+
input_per_token: input_cost.map_or(0.0, |c| c.cost_per_1k_tokens / 1000.0),
293+
output_per_token: output_cost
294+
.map_or(0.0, |c| c.cost_per_1k_tokens / 1000.0),
295+
output_reasoning_per_token: 0.0, // in v1 this info is not available
296+
input_cached_per_token: 0.0, // in v1 this info is not available
297+
})
298+
} else {
299+
None
300+
}
301+
}
302+
2 => self.models.get(model_id).copied(),
303+
_ => None,
304+
}
245305
}
246306
}
247307

248-
/// A single mapping of (AI model ID, input/output, cost)
308+
/// A single cost entry matching model IDs by glob pattern (used in version 1).
249309
#[derive(Clone, Debug, Serialize, Deserialize)]
250310
#[serde(rename_all = "camelCase")]
251311
pub struct ModelCost {
@@ -261,6 +321,21 @@ impl ModelCost {
261321
}
262322
}
263323

324+
/// Version 2 of a mapping of AI model types (like GPT-4) to their respective costs.
325+
/// Version 1 had some limitations, so we're moving to a more flexible format.
326+
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
327+
#[serde(rename_all = "camelCase")]
328+
pub struct ModelCostV2 {
329+
/// The cost per input token
330+
pub input_per_token: f64,
331+
/// The cost per output token
332+
pub output_per_token: f64,
333+
/// The cost per output reasoning token
334+
pub output_reasoning_per_token: f64,
335+
/// The cost per input cached token
336+
pub input_cached_per_token: f64,
337+
}
338+
264339
#[cfg(test)]
265340
mod tests {
266341
use chrono::{TimeZone, Utc};
@@ -288,27 +363,191 @@ mod tests {
288363

289364
use super::*;
290365

366+
/// Test that integer versions are handled correctly in the struct format
291367
#[test]
292-
fn test_model_cost_config() {
293-
let original = r#"{"version":1,"costs":[{"modelId":"babbage-002.ft-*","forCompletion":false,"costPer1kTokens":0.0016}]}"#;
368+
fn test_model_cost_version_sent_as_number() {
369+
// Test integer version 1
370+
let original = r#"{"version":1,"costs":[{"modelId":"babbage-002.ft","forCompletion":false,"costPer1kTokens":0.0016}]}"#;
294371
let deserialized: ModelCosts = serde_json::from_str(original).unwrap();
295-
assert_debug_snapshot!(deserialized, @r###"
372+
assert_debug_snapshot!(
373+
deserialized,
374+
@r#"
375+
ModelCosts {
376+
version: 1,
377+
costs: [
378+
ModelCost {
379+
model_id: LazyGlob("babbage-002.ft"),
380+
for_completion: false,
381+
cost_per_1k_tokens: 0.0016,
382+
},
383+
],
384+
models: {},
385+
}
386+
"#,
387+
);
388+
389+
// Test integer version 2
390+
let original_v2 = r#"{"version":2,"models":{"gpt-4":{"inputPerToken":0.03,"outputPerToken":0.06,"outputReasoningPerToken":0.12,"inputCachedPerToken":0.015}}}"#;
391+
let deserialized_v2: ModelCosts = serde_json::from_str(original_v2).unwrap();
392+
assert_debug_snapshot!(
393+
deserialized_v2,
394+
@r###"
395+
ModelCosts {
396+
version: 2,
397+
costs: [],
398+
models: {
399+
"gpt-4": ModelCostV2 {
400+
input_per_token: 0.03,
401+
output_per_token: 0.06,
402+
output_reasoning_per_token: 0.12,
403+
input_cached_per_token: 0.015,
404+
},
405+
},
406+
}
407+
"###,
408+
);
409+
410+
// Test unknown integer version
411+
let original_unknown = r#"{"version":99,"costs":[]}"#;
412+
let deserialized_unknown: ModelCosts = serde_json::from_str(original_unknown).unwrap();
413+
assert_eq!(deserialized_unknown.version, 99);
414+
assert!(!deserialized_unknown.is_enabled());
415+
}
416+
417+
#[test]
418+
fn test_model_cost_config_v1() {
419+
let original = r#"{"version":1,"costs":[{"modelId":"babbage-002.ft","forCompletion":false,"costPer1kTokens":0.0016}]}"#;
420+
let deserialized: ModelCosts = serde_json::from_str(original).unwrap();
421+
assert_debug_snapshot!(deserialized, @r#"
296422
ModelCosts {
297423
version: 1,
298424
costs: [
299425
ModelCost {
300-
model_id: LazyGlob("babbage-002.ft-*"),
426+
model_id: LazyGlob("babbage-002.ft"),
301427
for_completion: false,
302428
cost_per_1k_tokens: 0.0016,
303429
},
304430
],
431+
models: {},
432+
}
433+
"#);
434+
435+
let serialized = serde_json::to_string(&deserialized).unwrap();
436+
assert_eq!(&serialized, original);
437+
}
438+
439+
#[test]
440+
fn test_model_cost_config_v2() {
441+
let original = r#"{"version":2,"models":{"gpt-4":{"inputPerToken":0.03,"outputPerToken":0.06,"outputReasoningPerToken":0.12,"inputCachedPerToken":0.015}}}"#;
442+
let deserialized: ModelCosts = serde_json::from_str(original).unwrap();
443+
assert_debug_snapshot!(deserialized, @r###"
444+
ModelCosts {
445+
version: 2,
446+
costs: [],
447+
models: {
448+
"gpt-4": ModelCostV2 {
449+
input_per_token: 0.03,
450+
output_per_token: 0.06,
451+
output_reasoning_per_token: 0.12,
452+
input_cached_per_token: 0.015,
453+
},
454+
},
305455
}
306456
"###);
307457

308458
let serialized = serde_json::to_string(&deserialized).unwrap();
309459
assert_eq!(&serialized, original);
310460
}
311461

462+
#[test]
463+
fn test_model_cost_functionality_v1_only_input_tokens() {
464+
// Test V1 functionality
465+
let v1_config = ModelCosts {
466+
version: 1,
467+
costs: vec![ModelCost {
468+
model_id: LazyGlob::new("gpt-4*"),
469+
for_completion: false,
470+
cost_per_1k_tokens: 0.03,
471+
}],
472+
models: HashMap::new(),
473+
};
474+
assert!(v1_config.is_enabled());
475+
let costs = v1_config.cost_per_token("gpt-4-turbo").unwrap();
476+
assert_eq!(costs.input_per_token * 1000.0, 0.03); // multiplying by 1000 to avoid floating point errors
477+
assert_eq!(costs.output_per_token, 0.0); // output tokens are not defined
478+
}
479+
480+
#[test]
481+
fn test_model_cost_functionality_v1() {
482+
let v1_config = ModelCosts {
483+
version: 1,
484+
costs: vec![
485+
ModelCost {
486+
model_id: LazyGlob::new("gpt-4*"),
487+
for_completion: false,
488+
cost_per_1k_tokens: 0.03,
489+
},
490+
ModelCost {
491+
model_id: LazyGlob::new("gpt-4*"),
492+
for_completion: true,
493+
cost_per_1k_tokens: 0.06,
494+
},
495+
],
496+
models: HashMap::new(),
497+
};
498+
assert!(v1_config.is_enabled());
499+
let costs = v1_config.cost_per_token("gpt-4").unwrap();
500+
assert_eq!(costs.input_per_token * 1000.0, 0.03); // multiplying by 1000 to avoid floating point errors
501+
assert_eq!(costs.output_per_token * 1000.0, 0.06); // multiplying by 1000 to avoid floating point errors
502+
}
503+
504+
#[test]
505+
fn test_model_cost_functionality_v2() {
506+
// Test V2 functionality
507+
let mut models_map = HashMap::new();
508+
models_map.insert(
509+
"gpt-4".to_owned(),
510+
ModelCostV2 {
511+
input_per_token: 0.03,
512+
output_per_token: 0.06,
513+
output_reasoning_per_token: 0.12,
514+
input_cached_per_token: 0.015,
515+
},
516+
);
517+
let v2_config = ModelCosts {
518+
version: 2,
519+
costs: vec![],
520+
models: models_map,
521+
};
522+
assert!(v2_config.is_enabled());
523+
let cost = v2_config.cost_per_token("gpt-4").unwrap();
524+
assert_eq!(
525+
cost,
526+
ModelCostV2 {
527+
input_per_token: 0.03,
528+
output_per_token: 0.06,
529+
output_reasoning_per_token: 0.12,
530+
input_cached_per_token: 0.015,
531+
}
532+
);
533+
}
534+
535+
#[test]
536+
fn test_model_cost_unknown_version() {
537+
// Test that unknown versions are handled properly
538+
let unknown_version_json = r#"{"version":3,"models":{"some-model":{"inputPerToken":0.01,"outputPerToken":0.02,"outputReasoningPerToken":0.03,"inputCachedPerToken":0.005}}}"#;
539+
let deserialized: ModelCosts = serde_json::from_str(unknown_version_json).unwrap();
540+
assert_eq!(deserialized.version, 3);
541+
assert!(!deserialized.is_enabled());
542+
assert_eq!(deserialized.cost_per_token("some-model"), None);
543+
544+
// Test version 0 (invalid)
545+
let version_zero_json = r#"{"version":0,"models":{}}"#;
546+
let deserialized: ModelCosts = serde_json::from_str(version_zero_json).unwrap();
547+
assert_eq!(deserialized.version, 0);
548+
assert!(!deserialized.is_enabled());
549+
}
550+
312551
#[test]
313552
fn test_merge_builtin_measurement_keys() {
314553
let foo = BuiltinMeasurementKey::new("foo", MetricUnit::Duration(DurationUnit::Hour));

0 commit comments

Comments
 (0)