@@ -64,6 +64,34 @@ def dataset_perfect_classes(spark_fixture, test_data_dir):
64
64
)
65
65
66
66
67
+ @pytest .fixture ()
68
+ def dataset_talk (spark_fixture , test_data_dir ):
69
+ yield (
70
+ spark_fixture .read .csv (
71
+ f"{ test_data_dir } /reference/multiclass/reference_sentiment_analysis_talk.csv" ,
72
+ header = True ,
73
+ ),
74
+ spark_fixture .read .csv (
75
+ f"{ test_data_dir } /current/multiclass/current_sentiment_analysis_talk.csv" ,
76
+ header = True ,
77
+ ),
78
+ )
79
+
80
+
81
+ @pytest .fixture ()
82
+ def dataset_demo (spark_fixture , test_data_dir ):
83
+ yield (
84
+ spark_fixture .read .csv (
85
+ f"{ test_data_dir } /reference/multiclass/3_classes_reference.csv" ,
86
+ header = True ,
87
+ ),
88
+ spark_fixture .read .csv (
89
+ f"{ test_data_dir } /current/multiclass/3_classes_current1.csv" ,
90
+ header = True ,
91
+ ),
92
+ )
93
+
94
+
67
95
def test_calculation_dataset_perfect_classes (spark_fixture , dataset_perfect_classes ):
68
96
output = OutputType (
69
97
prediction = ColumnDefinition (
@@ -363,3 +391,174 @@ def test_percentages_abalone(spark_fixture, test_dataset_abalone):
363
391
ignore_order = True ,
364
392
significant_digits = 6 ,
365
393
)
394
+
395
+
396
+ def test_percentages_dataset_talk (spark_fixture , dataset_talk ):
397
+ output = OutputType (
398
+ prediction = ColumnDefinition (
399
+ name = "content" , type = SupportedTypes .int , field_type = FieldTypes .categorical
400
+ ),
401
+ prediction_proba = None ,
402
+ output = [
403
+ ColumnDefinition (
404
+ name = "content" ,
405
+ type = SupportedTypes .int ,
406
+ field_type = FieldTypes .categorical ,
407
+ )
408
+ ],
409
+ )
410
+ target = ColumnDefinition (
411
+ name = "label" , type = SupportedTypes .int , field_type = FieldTypes .categorical
412
+ )
413
+ timestamp = ColumnDefinition (
414
+ name = "rbit_prediction_ts" ,
415
+ type = SupportedTypes .datetime ,
416
+ field_type = FieldTypes .datetime ,
417
+ )
418
+ granularity = Granularity .HOUR
419
+ features = [
420
+ ColumnDefinition (
421
+ name = "total_tokens" ,
422
+ type = SupportedTypes .int ,
423
+ field_type = FieldTypes .numerical ,
424
+ ),
425
+ ColumnDefinition (
426
+ name = "prompt_tokens" ,
427
+ type = SupportedTypes .int ,
428
+ field_type = FieldTypes .numerical ,
429
+ ),
430
+ ]
431
+ model = ModelOut (
432
+ uuid = uuid .uuid4 (),
433
+ name = "talk model" ,
434
+ description = "description" ,
435
+ model_type = ModelType .MULTI_CLASS ,
436
+ data_type = DataType .TABULAR ,
437
+ timestamp = timestamp ,
438
+ granularity = granularity ,
439
+ outputs = output ,
440
+ target = target ,
441
+ features = features ,
442
+ frameworks = "framework" ,
443
+ algorithm = "algorithm" ,
444
+ created_at = str (datetime .datetime .now ()),
445
+ updated_at = str (datetime .datetime .now ()),
446
+ )
447
+
448
+ raw_reference_dataset , raw_current_dataset = dataset_talk
449
+ current_dataset = CurrentDataset (model = model , raw_dataframe = raw_current_dataset )
450
+ reference_dataset = ReferenceDataset (
451
+ model = model , raw_dataframe = raw_reference_dataset
452
+ )
453
+
454
+ drift = DriftCalculator .calculate_drift (
455
+ spark_session = spark_fixture ,
456
+ current_dataset = current_dataset ,
457
+ reference_dataset = reference_dataset ,
458
+ )
459
+
460
+ metrics_service = CurrentMetricsMulticlassService (
461
+ spark_session = spark_fixture ,
462
+ current = current_dataset ,
463
+ reference = reference_dataset ,
464
+ )
465
+
466
+ model_quality = metrics_service .calculate_model_quality ()
467
+
468
+ percentages = PercentageCalculator .calculate_percentages (
469
+ spark_session = spark_fixture ,
470
+ drift = drift ,
471
+ model_quality_current = model_quality ,
472
+ current_dataset = current_dataset ,
473
+ reference_dataset = reference_dataset ,
474
+ model = model ,
475
+ )
476
+
477
+ assert not deepdiff .DeepDiff (
478
+ percentages ,
479
+ res .test_dataset_talk ,
480
+ ignore_order = True ,
481
+ significant_digits = 6 ,
482
+ )
483
+
484
+
485
+ def test_percentages_dataset_demo (spark_fixture , dataset_demo ):
486
+ output = OutputType (
487
+ prediction = ColumnDefinition (
488
+ name = "prediction" ,
489
+ type = SupportedTypes .int ,
490
+ field_type = FieldTypes .categorical ,
491
+ ),
492
+ prediction_proba = None ,
493
+ output = [
494
+ ColumnDefinition (
495
+ name = "prediction" ,
496
+ type = SupportedTypes .int ,
497
+ field_type = FieldTypes .categorical ,
498
+ )
499
+ ],
500
+ )
501
+ target = ColumnDefinition (
502
+ name = "ground_truth" , type = SupportedTypes .int , field_type = FieldTypes .categorical
503
+ )
504
+ timestamp = ColumnDefinition (
505
+ name = "timestamp" , type = SupportedTypes .datetime , field_type = FieldTypes .datetime
506
+ )
507
+ granularity = Granularity .DAY
508
+ features = [
509
+ ColumnDefinition (
510
+ name = "age" , type = SupportedTypes .int , field_type = FieldTypes .numerical
511
+ )
512
+ ]
513
+ model = ModelOut (
514
+ uuid = uuid .uuid4 (),
515
+ name = "talk model" ,
516
+ description = "description" ,
517
+ model_type = ModelType .MULTI_CLASS ,
518
+ data_type = DataType .TABULAR ,
519
+ timestamp = timestamp ,
520
+ granularity = granularity ,
521
+ outputs = output ,
522
+ target = target ,
523
+ features = features ,
524
+ frameworks = "framework" ,
525
+ algorithm = "algorithm" ,
526
+ created_at = str (datetime .datetime .now ()),
527
+ updated_at = str (datetime .datetime .now ()),
528
+ )
529
+
530
+ raw_reference_dataset , raw_current_dataset = dataset_demo
531
+ current_dataset = CurrentDataset (model = model , raw_dataframe = raw_current_dataset )
532
+ reference_dataset = ReferenceDataset (
533
+ model = model , raw_dataframe = raw_reference_dataset
534
+ )
535
+
536
+ drift = DriftCalculator .calculate_drift (
537
+ spark_session = spark_fixture ,
538
+ current_dataset = current_dataset ,
539
+ reference_dataset = reference_dataset ,
540
+ )
541
+
542
+ metrics_service = CurrentMetricsMulticlassService (
543
+ spark_session = spark_fixture ,
544
+ current = current_dataset ,
545
+ reference = reference_dataset ,
546
+ )
547
+
548
+ model_quality = metrics_service .calculate_model_quality ()
549
+
550
+ percentages = PercentageCalculator .calculate_percentages (
551
+ spark_session = spark_fixture ,
552
+ drift = drift ,
553
+ model_quality_current = model_quality ,
554
+ current_dataset = current_dataset ,
555
+ reference_dataset = reference_dataset ,
556
+ model = model ,
557
+ )
558
+
559
+ assert not deepdiff .DeepDiff (
560
+ percentages ,
561
+ res .test_dataset_demo ,
562
+ ignore_order = True ,
563
+ significant_digits = 6 ,
564
+ )
0 commit comments