Skip to content

Commit da0cd9b

Browse files
authored
Merge pull request #3 from OskNum/unit-tests
Unit tests
2 parents 2073672 + f411670 commit da0cd9b

16 files changed

+123
-13
lines changed

.github/CODE_OF_CONDUCT.md

100644100755
File mode changed.

.github/CONTRIBUTING.md

100644100755
File mode changed.

.github/ISSUE_TEMPLATE/BUG_REPORT.md

100644100755
File mode changed.

.github/ISSUE_TEMPLATE/FEATURE_REQUEST.md

100644100755
File mode changed.

.github/PULL_REQUEST_TEMPLATE.md

100644100755
File mode changed.

.github/workflows/r_new.yml

100644100755
File mode changed.

.gitignore

100644100755
File mode changed.

LICENSE

100644100755
File mode changed.

LICENSE.md

100644100755
File mode changed.

R/preprocess.R

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,13 @@ preprocess <- function(df,
6060
function_vector <- strings_to_functions(numeric_operation_list)
6161
names(function_vector) <- numeric_operation_list
6262

63-
numeric_df <- df %>%
64-
select(-target) %>%
63+
if (!is.na(target)) {
64+
numeric_df <- select(df, -target)
65+
} else {
66+
numeric_df <- df
67+
}
68+
69+
numeric_df <- numeric_df %>%
6570
group_by(.data$customerid) %>%
6671
summarise_if(is.numeric, function_vector) %>%
6772
ungroup()
@@ -71,18 +76,19 @@ preprocess <- function(df,
7176
} else {
7277
evaluated_columns <- names(df)[sapply(df, is.numeric) & names(df) != 'customerid' & names(df) != target]
7378
}
79+
7480

7581
if (length(evaluated_columns) == 1) {
7682
adjusted_name <- paste0(evaluated_columns, '_', names(numeric_df)[!(names(numeric_df) %in% c('customerid', target))])
7783
names(numeric_df) <- c('customerid', adjusted_name)
7884
}
79-
85+
8086
# Filters categorical columns and grabs the top n category for each
8187
# categorical column
8288
final_df <- inner_join(final_df, numeric_df, by = 'customerid')
8389
}
8490

85-
91+
8692
if (!is.null(categories)) {
8793
for (col_name in categories) {
8894

@@ -118,7 +124,6 @@ preprocess <- function(df,
118124

119125
return(final_df)
120126
}
121-
122127
strings_to_functions <- function(string_vector) {
123128
function_vector <- c()
124129
for (obj_name in string_vector) {

R/validate.R

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,17 @@
55
#' @param supervised logical, TRUE for supervised learning, FALSE for unsupervised
66
#' @importFrom dplyr n_distinct
77
#' @export
8-
validate <- function(df, supervised = TRUE) {
8+
validate <- function(df, supervised = TRUE, hyperparameters = NULL) {
99
missing_columns <- c()
1010
other_errors <- c()
1111
toomanylevels_columns <- c()
12-
categorical_columns <- df %>% select(-customerid) %>% select_if(is.character) %>% summarise_all(n_distinct)
12+
categorical_columns <- df[,names(df) != 'customerid'] %>% select_if(is.character) %>% summarise_all(n_distinct)
13+
14+
15+
16+
if (!is.null(hyperparameters$segmentation_variables)) {
17+
df <- df[,names(df) %in% hyperparameters$segmentation_variables]
18+
}
1319

1420
if (!('response' %in% names(df)) & (supervised == TRUE)) {
1521
missing_columns <- c(missing_columns, 'response')

README.md

100644100755
File mode changed.

tests/testthat/test_output_table.R

100644100755
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@ test_that("Number of Columns", {
1919
expect_equal(ncol(output), 16)
2020
})
2121

22-
2322
test_that("Number of Rows", {
2423
expect_equal(nrow(output), 6)
2524
})
2625

2726
test_that("No nulls", {
2827
expect_equal(ncol(output[complete.cases(output), ]), 16)
29-
})
28+
})
29+
30+
31+

tests/testthat/test_preprocess.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ library(dplyr)
55
transactional_data <- citrus::transactional_data
66
data <- transactional_data %>% select(c('transactionid', 'transactionvalue', 'customerid', 'orderdate'))
77
output_preprocess <- citrus::preprocess(data, numeric_operation_list = 'mean')
8+
output_with_response <- citrus::preprocess(data, numeric_operation_list = 'mean', target = 'transactionvalue')
89
output_preprocess_na <- citrus::preprocess(data, numeric_operation_list = NULL)
910

1011
test_that("Number of Columns", {
1112
expect_equal(ncol(output_preprocess), 5)
1213
})
1314

14-
1515
test_that("String Customerid Check", {
1616
expect_true(is.character(typeof(output_preprocess$customerid)))
1717
})
@@ -21,7 +21,10 @@ test_that("Passing NA to numeric_operations_list defaults to RFM", {
2121
expect_equal(sort(colnames(output_preprocess_na)), sort(c('customerid', 'recency', 'frequency', 'monetary')))
2222
})
2323

24-
2524
test_that("Correct Labelling", {
2625
expect_equal(colnames(output_preprocess), c('customerid', 'recency', 'frequency', 'monetary', 'transactionvalue_mean'))
2726
})
27+
28+
test_that("Custom target included", {
29+
expect_true('response' %in% colnames(output_with_response))
30+
})

tests/testthat/test_segment.R

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
library(citrus)
2+
library(testthat)
3+
library(dplyr)
4+
5+
transactional_data <- citrus::transactional_data
6+
data <- transactional_data %>% select(c('transactionid', 'transactionvalue', 'customerid', 'orderdate'))
7+
8+
test_that("Supervised default output object check", {
9+
output_supervised <- segment(data, modeltype = 'tree')
10+
11+
# Should contain all the right variables
12+
expect_true(all(names(output_supervised) %in% c('OutputTable', 'segments', 'CitrusModel')))
13+
# Output table should be 6 rows
14+
expect_equal(nrow(output_supervised$OutputTable), 6)
15+
# No nulls allowed in the output table
16+
expect_equal(ncol(output_supervised$OutputTable[complete.cases(output_supervised$OutputTable), ]), 10)
17+
# The segment lookup should not contain any NAs
18+
expect_true(all(!is.na(output_supervised$segments$persona)))
19+
expect_true(all(!is.na(output_supervised$segments$customerid)))
20+
})
21+
22+
test_that("Supervised custom output object check", {
23+
24+
output_supervised <- segment(data, modeltype = 'tree', hyperparameters = list(dependent_variable = 'response',
25+
min_segmentation_fraction = 0.1,
26+
print_safety_check = 20,
27+
number_of_personas = 4,
28+
print_plot = FALSE))
29+
30+
# Should contain all the right variables
31+
expect_true(all(names(output_supervised) %in% c('OutputTable', 'segments', 'CitrusModel')))
32+
# Output table should be 4 rows
33+
expect_equal(nrow(output_supervised$OutputTable), 4)
34+
# No nulls allowed in the output table
35+
expect_equal(ncol(output_supervised$OutputTable[complete.cases(output_supervised$OutputTable), ]), 10)
36+
# There shouldn't be segments smaller than 10% of the total population
37+
expect_true(all(output_supervised$OutputTable$percentage >= 0.1))
38+
# The segment lookup should not contain any NAs
39+
expect_true(all(!is.na(output_supervised$segments$persona)))
40+
expect_true(all(!is.na(output_supervised$segments$customerid)))
41+
# The model hyperparameters should agree with the custom ones
42+
expect_true(output_supervised$CitrusModel$model_hyperparameters$dependent_variable == 'response')
43+
expect_true(output_supervised$CitrusModel$model_hyperparameters$min_segmentation_fraction == 0.1)
44+
expect_true(output_supervised$CitrusModel$model_hyperparameters$number_of_personas == 4)
45+
})
46+
47+
test_that("Unsupervised default output object check", {
48+
output_unsupervised <- segment(data, modeltype = 'unsupervised')
49+
50+
# Should contain all the right variables
51+
expect_true(all(names(output_unsupervised) %in% c('OutputTable', 'segments', 'CitrusModel')))
52+
# Output table should be 3 rows
53+
expect_equal(nrow(output_unsupervised$OutputTable), 3)
54+
# No nulls allowed in the output table
55+
expect_equal(ncol(output_unsupervised$OutputTable[complete.cases(output_unsupervised$OutputTable), ]), 3)
56+
# The segment lookup should not contain any NAs
57+
expect_true(all(!is.na(output_unsupervised$segments$persona)))
58+
expect_true(all(!is.na(output_unsupervised$segments$customerid)))
59+
})
60+
61+
test_that("Supervised custom output object check", {
62+
63+
output_unsupervised <- segment(data, modeltype = 'unsupervised', hyperparameters = list(centers = 'auto',
64+
iter_max = 35,
65+
nstart = 2,
66+
max_centers = 3,
67+
segmentation_variables = NULL,
68+
standardize = TRUE))
69+
70+
# Should contain all the right variables
71+
expect_true(all(names(output_unsupervised) %in% c('OutputTable', 'segments', 'CitrusModel')))
72+
# Output table should be 2 rows
73+
expect_equal(nrow(output_unsupervised$OutputTable), 2)
74+
# No nulls allowed in the output table
75+
expect_equal(ncol(output_unsupervised$OutputTable[complete.cases(output_unsupervised$OutputTable), ]), 3)
76+
# The segment lookup should not contain any NAs
77+
expect_true(all(!is.na(output_unsupervised$segments$persona)))
78+
expect_true(all(!is.na(output_unsupervised$segments$customerid)))
79+
# The model hyperparameters should agree with the custom ones
80+
expect_true(output_unsupervised$CitrusModel$model_hyperparameters$centers == 'auto')
81+
expect_true(output_unsupervised$CitrusModel$model_hyperparameters$iter_max == 35)
82+
expect_true(output_unsupervised$CitrusModel$model_hyperparameters$nstart == 2)
83+
})

tests/testthat/test_validate.R

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,23 @@ library(dplyr)
55
transactional_data <- citrus::transactional_data
66
data <- transactional_data %>% select(c('transactionid', 'transactionvalue', 'customerid', 'orderdate'))
77
output_preprocess <- citrus::preprocess(data, numeric_operation_list = 'mean')
8+
preprocess_too_many_categories <- citrus::preprocessed_data %>%
9+
mutate(faulty_feature = as.character(customerid))
10+
preprocess_no_customerid <- citrus::preprocessed_data %>%
11+
select(-customerid)
812

913
test_that("Supervised without response variable", {
10-
expect_error(validate(output_preprocess))
14+
expect_error(citrus::validate(output_preprocess), regexp = "Columns missing: response")
1115
})
1216

17+
test_that("Correct error when customerid is missing.", {
18+
expect_error(citrus::validate(preprocess_no_customerid), regexp = "Columns missing: customerid")
19+
})
20+
21+
test_that("Throw error when too many categorical levels", {
22+
expect_error(citrus::validate(preprocess_too_many_categories), regexp = "Categorical Columns have too many levels: faulty_feature")
23+
})
1324

1425
test_that("Unique customer count", {
1526
expect_equal(nrow(output_preprocess), length(unique(output_preprocess[["customerid"]])))
16-
})
27+
})

0 commit comments

Comments
 (0)