Skip to content

Commit 1b5a6c6

Browse files
committed
test fixed
1 parent 16476a9 commit 1b5a6c6

File tree

1 file changed

+26
-42
lines changed

1 file changed

+26
-42
lines changed

test/DecisionTreeExt/tree.jl

Lines changed: 26 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -81,46 +81,30 @@ printmodel.(sort(interesting_rules, by = readmetrics); show_metrics = (; round_d
8181
# Data Validation #
8282
# ---------------------------------------------------------------------------- #
8383
@testset "data validation" begin
84-
Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree
85-
86-
for train_ratio in 0.5:0.1:0.9
87-
for seed in 1:40
88-
train, test = partition(eachindex(y), train_ratio; shuffle=true, rng=Xoshiro(seed))
89-
X_train, y_train = X[train, :], y[train]
90-
X_test, y_test = X[test, :], y[test]
91-
92-
for max_depth in 2:1:6
93-
# solemodel
94-
model = Tree(; max_depth, rng=Xoshiro(seed))
95-
mach = machine(model, X_train, y_train)
96-
fit!(mach, verbosity=0)
97-
solem = solemodel(MLJ.fitted_params(mach).tree)
98-
preds = apply!(solem, X_test, y_test)
99-
100-
# decisiontree
101-
dt_model = DT.build_tree(y_train, Matrix(X_train), 0, max_depth; rng=Xoshiro(seed))
102-
dt_preds = DT.apply_tree(dt_model, Matrix(X_test))
103-
104-
@test preds == dt_preds
105-
end
106-
end
107-
end
84+
Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree
85+
86+
for train_ratio in 0.5:0.1:0.9
87+
for seed in 1:40
88+
train, test = partition(eachindex(y), train_ratio; shuffle=true, rng=Xoshiro(seed))
89+
X_train, y_train = X[train, :], y[train]
90+
X_test, y_test = X[test, :], y[test]
91+
92+
for max_depth in 2:1:6
93+
# solemodel
94+
model = Tree(; max_depth, rng=Xoshiro(seed))
95+
mach = machine(model, X_train, y_train)
96+
fit!(mach, verbosity=0)
97+
solem = solemodel(MLJ.fitted_params(mach).tree)
98+
preds = apply!(solem, X_test, y_test)
99+
100+
# decisiontree
101+
y_coded_train = @. CategoricalArrays.levelcode(y_train)
102+
dt_model = DT.build_tree(y_coded_train, Matrix(X_train), 0, max_depth; rng=Xoshiro(seed))
103+
dt_preds = DT.apply_tree(dt_model, Matrix(X_test))
104+
105+
preds_coded = CategoricalArrays.levelcode.(CategoricalArray(preds))
106+
@test preds_coded == dt_preds
107+
end
108+
end
109+
end
108110
end
109-
110-
### the problem rises in fit! in MLJDecisionTreeInterface
111-
Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree
112-
seed = 1
113-
max_depth = 3
114-
train_ratio = 0.5
115-
train, test = partition(eachindex(y), train_ratio; shuffle=true, rng=Xoshiro(seed))
116-
X_train, y_train = X[train, :], y[train]
117-
X_test, y_test = X[test, :], y[test]
118-
119-
model = Tree(; max_depth, rng=Xoshiro(seed))
120-
mach = machine(model, X_train, y_train)
121-
fit!(mach, verbosity=0)
122-
solem = solemodel(MLJ.fitted_params(mach).tree)
123-
preds = apply!(solem, X_test, y_test)
124-
125-
dt_model = DT.build_tree(y_train, Matrix(X_train), 0, max_depth; rng=Xoshiro(seed))
126-
dt_preds = DT.apply_tree(dt_model, Matrix(X_test))

0 commit comments

Comments
 (0)