Skip to content

Commit 523d197

Browse files
authored
fix: percentage fix (#206)
1 parent 5c3ca24 commit 523d197

File tree

9 files changed

+2602
-23
lines changed

9 files changed

+2602
-23
lines changed

spark/jobs/metrics/percentages.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def calculate_percentages(
6262

6363
model_quality_reference = metrics_service.calculate_model_quality()
6464

65-
def compute_mq_percentage(metrics_cur, metric_ref):
65+
def _compute_mq_percentage(metrics_cur, metric_ref):
6666
metrics_cur_np = np.array(metrics_cur)
6767

6868
# bootstrap Parameters
@@ -80,6 +80,7 @@ def compute_mq_percentage(metrics_cur, metric_ref):
8080
# calculate 95% confidence interval
8181
lower_bound = np.percentile(bootstrap_means, 2.5)
8282
upper_bound = np.percentile(bootstrap_means, 97.5)
83+
8384
return 1 if not (lower_bound <= metric_ref <= upper_bound) else 0
8485

8586
perc_model_quality = {"value": 0, "details": []}
@@ -95,15 +96,15 @@ def compute_mq_percentage(metrics_cur, metric_ref):
9596
perc_model_quality["value"] = -1
9697
break
9798
else:
98-
is_flag = compute_mq_percentage(metrics_cur, metric_ref)
99+
is_flag = _compute_mq_percentage(metrics_cur, metric_ref)
99100
flagged_metrics += is_flag
100101
if is_flag:
101102
perc_model_quality["details"].append(
102103
{"feature_name": key_m, "score": -1}
103104
)
104-
perc_model_quality["value"] = 1 - (
105-
flagged_metrics / len(model_quality_reference)
106-
)
105+
perc_model_quality["value"] = 1 - (
106+
flagged_metrics / len(model_quality_current["grouped_metrics"])
107+
)
107108

108109
elif model.model_type == ModelType.MULTI_CLASS:
109110
flagged_metrics = 0
@@ -119,23 +120,22 @@ def compute_mq_percentage(metrics_cur, metric_ref):
119120
# not enough values to do the test, return -1
120121
cumulative_sum -= 10000
121122
else:
122-
is_flag = compute_mq_percentage(metrics_cur, metric_ref)
123+
is_flag = _compute_mq_percentage(metrics_cur, metric_ref)
123124
flagged_metrics += is_flag
124-
perc_model_quality["details"].append(
125-
{
126-
"feature_name": cm["class_name"] + "_" + key_m,
127-
"score": -1,
128-
}
129-
)
130-
cumulative_sum += 1 - (
131-
flagged_metrics / len(model_quality_reference)
132-
)
133-
perc_model_quality["value"] = (
134-
cumulative_sum
135-
/ (
136-
len(model_quality_reference["classes"])
137-
* len(model_quality_reference["class_metrics"][0])
125+
if is_flag:
126+
perc_model_quality["details"].append(
127+
{
128+
"feature_name": cm["class_name"] + "_" + key_m,
129+
"score": -1,
130+
}
131+
)
132+
cumulative_sum += 1 - (
133+
flagged_metrics
134+
/ len(model_quality_reference["class_metrics"][0]["metrics"])
138135
)
136+
flagged_metrics = 0
137+
perc_model_quality["value"] = (
138+
cumulative_sum / len(model_quality_reference["classes"])
139139
if cumulative_sum > 0
140140
else -1
141141
)

spark/jobs/utils/reference_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def calculate_model_quality(self) -> ModelQualityRegression:
2020
model=self.reference.model,
2121
dataframe=self.reference.reference,
2222
dataframe_count=self.reference.reference_count,
23-
).dict()
23+
).model_dump()
2424

2525
metrics["residuals"] = ModelQualityRegressionCalculator.residual_metrics(
2626
model=self.reference.model,

spark/tests/percentages_test.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,34 @@ def dataset_perfect_classes(spark_fixture, test_data_dir):
6464
)
6565

6666

67+
@pytest.fixture()
68+
def dataset_talk(spark_fixture, test_data_dir):
69+
yield (
70+
spark_fixture.read.csv(
71+
f"{test_data_dir}/reference/multiclass/reference_sentiment_analysis_talk.csv",
72+
header=True,
73+
),
74+
spark_fixture.read.csv(
75+
f"{test_data_dir}/current/multiclass/current_sentiment_analysis_talk.csv",
76+
header=True,
77+
),
78+
)
79+
80+
81+
@pytest.fixture()
82+
def dataset_demo(spark_fixture, test_data_dir):
83+
yield (
84+
spark_fixture.read.csv(
85+
f"{test_data_dir}/reference/multiclass/3_classes_reference.csv",
86+
header=True,
87+
),
88+
spark_fixture.read.csv(
89+
f"{test_data_dir}/current/multiclass/3_classes_current1.csv",
90+
header=True,
91+
),
92+
)
93+
94+
6795
def test_calculation_dataset_perfect_classes(spark_fixture, dataset_perfect_classes):
6896
output = OutputType(
6997
prediction=ColumnDefinition(
@@ -363,3 +391,174 @@ def test_percentages_abalone(spark_fixture, test_dataset_abalone):
363391
ignore_order=True,
364392
significant_digits=6,
365393
)
394+
395+
396+
def test_percentages_dataset_talk(spark_fixture, dataset_talk):
397+
output = OutputType(
398+
prediction=ColumnDefinition(
399+
name="content", type=SupportedTypes.int, field_type=FieldTypes.categorical
400+
),
401+
prediction_proba=None,
402+
output=[
403+
ColumnDefinition(
404+
name="content",
405+
type=SupportedTypes.int,
406+
field_type=FieldTypes.categorical,
407+
)
408+
],
409+
)
410+
target = ColumnDefinition(
411+
name="label", type=SupportedTypes.int, field_type=FieldTypes.categorical
412+
)
413+
timestamp = ColumnDefinition(
414+
name="rbit_prediction_ts",
415+
type=SupportedTypes.datetime,
416+
field_type=FieldTypes.datetime,
417+
)
418+
granularity = Granularity.HOUR
419+
features = [
420+
ColumnDefinition(
421+
name="total_tokens",
422+
type=SupportedTypes.int,
423+
field_type=FieldTypes.numerical,
424+
),
425+
ColumnDefinition(
426+
name="prompt_tokens",
427+
type=SupportedTypes.int,
428+
field_type=FieldTypes.numerical,
429+
),
430+
]
431+
model = ModelOut(
432+
uuid=uuid.uuid4(),
433+
name="talk model",
434+
description="description",
435+
model_type=ModelType.MULTI_CLASS,
436+
data_type=DataType.TABULAR,
437+
timestamp=timestamp,
438+
granularity=granularity,
439+
outputs=output,
440+
target=target,
441+
features=features,
442+
frameworks="framework",
443+
algorithm="algorithm",
444+
created_at=str(datetime.datetime.now()),
445+
updated_at=str(datetime.datetime.now()),
446+
)
447+
448+
raw_reference_dataset, raw_current_dataset = dataset_talk
449+
current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset)
450+
reference_dataset = ReferenceDataset(
451+
model=model, raw_dataframe=raw_reference_dataset
452+
)
453+
454+
drift = DriftCalculator.calculate_drift(
455+
spark_session=spark_fixture,
456+
current_dataset=current_dataset,
457+
reference_dataset=reference_dataset,
458+
)
459+
460+
metrics_service = CurrentMetricsMulticlassService(
461+
spark_session=spark_fixture,
462+
current=current_dataset,
463+
reference=reference_dataset,
464+
)
465+
466+
model_quality = metrics_service.calculate_model_quality()
467+
468+
percentages = PercentageCalculator.calculate_percentages(
469+
spark_session=spark_fixture,
470+
drift=drift,
471+
model_quality_current=model_quality,
472+
current_dataset=current_dataset,
473+
reference_dataset=reference_dataset,
474+
model=model,
475+
)
476+
477+
assert not deepdiff.DeepDiff(
478+
percentages,
479+
res.test_dataset_talk,
480+
ignore_order=True,
481+
significant_digits=6,
482+
)
483+
484+
485+
def test_percentages_dataset_demo(spark_fixture, dataset_demo):
486+
output = OutputType(
487+
prediction=ColumnDefinition(
488+
name="prediction",
489+
type=SupportedTypes.int,
490+
field_type=FieldTypes.categorical,
491+
),
492+
prediction_proba=None,
493+
output=[
494+
ColumnDefinition(
495+
name="prediction",
496+
type=SupportedTypes.int,
497+
field_type=FieldTypes.categorical,
498+
)
499+
],
500+
)
501+
target = ColumnDefinition(
502+
name="ground_truth", type=SupportedTypes.int, field_type=FieldTypes.categorical
503+
)
504+
timestamp = ColumnDefinition(
505+
name="timestamp", type=SupportedTypes.datetime, field_type=FieldTypes.datetime
506+
)
507+
granularity = Granularity.DAY
508+
features = [
509+
ColumnDefinition(
510+
name="age", type=SupportedTypes.int, field_type=FieldTypes.numerical
511+
)
512+
]
513+
model = ModelOut(
514+
uuid=uuid.uuid4(),
515+
name="talk model",
516+
description="description",
517+
model_type=ModelType.MULTI_CLASS,
518+
data_type=DataType.TABULAR,
519+
timestamp=timestamp,
520+
granularity=granularity,
521+
outputs=output,
522+
target=target,
523+
features=features,
524+
frameworks="framework",
525+
algorithm="algorithm",
526+
created_at=str(datetime.datetime.now()),
527+
updated_at=str(datetime.datetime.now()),
528+
)
529+
530+
raw_reference_dataset, raw_current_dataset = dataset_demo
531+
current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset)
532+
reference_dataset = ReferenceDataset(
533+
model=model, raw_dataframe=raw_reference_dataset
534+
)
535+
536+
drift = DriftCalculator.calculate_drift(
537+
spark_session=spark_fixture,
538+
current_dataset=current_dataset,
539+
reference_dataset=reference_dataset,
540+
)
541+
542+
metrics_service = CurrentMetricsMulticlassService(
543+
spark_session=spark_fixture,
544+
current=current_dataset,
545+
reference=reference_dataset,
546+
)
547+
548+
model_quality = metrics_service.calculate_model_quality()
549+
550+
percentages = PercentageCalculator.calculate_percentages(
551+
spark_session=spark_fixture,
552+
drift=drift,
553+
model_quality_current=model_quality,
554+
current_dataset=current_dataset,
555+
reference_dataset=reference_dataset,
556+
model=model,
557+
)
558+
559+
assert not deepdiff.DeepDiff(
560+
percentages,
561+
res.test_dataset_demo,
562+
ignore_order=True,
563+
significant_digits=6,
564+
)

0 commit comments

Comments (0)