Skip to content

Commit f3d858d

Browse files
authored
Use parent forest num.threads in auxiliary forest (#1437)
1 parent 49aac3d commit f3d858d

11 files changed

+16
-3
lines changed

r-package/grf/R/causal_forest.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ causal_forest <- function(X, Y, W,
278278
forest <- do.call.rcpp(causal_train, c(data, args))
279279
class(forest) <- c("causal_forest", "grf")
280280
forest[["seed"]] <- seed
281+
forest[["num.threads"]] <- num.threads
281282
forest[["ci.group.size"]] <- ci.group.size
282283
forest[["X.orig"]] <- X
283284
forest[["Y.orig"]] <- Y

r-package/grf/R/causal_survival_forest.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ causal_survival_forest <- function(X, Y, W, D,
378378
forest <- do.call.rcpp(causal_survival_train, c(data, args))
379379
class(forest) <- c("causal_survival_forest", "grf")
380380
forest[["seed"]] <- seed
381+
forest[["num.threads"]] <- num.threads
381382
forest[["_psi"]] <- psi
382383
forest[["X.orig"]] <- X
383384
forest[["Y.orig"]] <- Y

r-package/grf/R/get_scores.R

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ get_scores.causal_forest <- function(forest,
7979
sample.weights = forest$sample.weights,
8080
num.trees = num.trees.for.weights,
8181
ci.group.size = 1,
82-
seed = forest$seed)
82+
seed = forest$seed,
83+
num.threads = forest$num.threads)
8384
V.hat <- predict(variance_forest)$predictions
8485
debiasing.weights.all <- (forest$W.orig - forest$W.hat) / V.hat
8586
debiasing.weights <- debiasing.weights.all[subset]
@@ -178,7 +179,8 @@ get_scores.instrumental_forest <- function(forest,
178179
sample.weights = forest$sample.weights,
179180
clusters = clusters,
180181
num.trees = num.trees.for.weights,
181-
seed = forest$seed)
182+
seed = forest$seed,
183+
num.threads = forest$num.threads)
182184
compliance.score <- predict(compliance.forest)$predictions
183185
compliance.score <- compliance.score[subset]
184186
} else if (length(compliance.score) == length(forest$Y.orig)) {
@@ -342,7 +344,8 @@ get_scores.causal_survival_forest <- function(forest,
342344
sample.weights = forest$sample.weights,
343345
num.trees = num.trees.for.weights,
344346
ci.group.size = 1,
345-
seed = forest$seed)
347+
seed = forest$seed,
348+
num.threads = forest$num.threads)
346349
V.hat <- predict(variance_forest)$predictions[subset]
347350
}
348351

r-package/grf/R/instrumental_forest.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ instrumental_forest <- function(X, Y, W, Z,
245245
forest <- do.call.rcpp(instrumental_train, c(data, args))
246246
class(forest) <- c("instrumental_forest", "grf")
247247
forest[["seed"]] <- seed
248+
forest[["num.threads"]] <- num.threads
248249
forest[["ci.group.size"]] <- ci.group.size
249250
forest[["X.orig"]] <- X
250251
forest[["Y.orig"]] <- Y

r-package/grf/R/ll_regression_forest.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ ll_regression_forest <- function(X, Y,
201201

202202
class(forest) <- c("ll_regression_forest", "grf")
203203
forest[["seed"]] <- seed
204+
forest[["num.threads"]] <- num.threads
204205
forest[["ci.group.size"]] <- ci.group.size
205206
forest[["X.orig"]] <- X
206207
forest[["Y.orig"]] <- Y

r-package/grf/R/multi_arm_causal_forest.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ multi_arm_causal_forest <- function(X, Y, W,
289289
forest <- do.call.rcpp(multi_causal_train, c(data, args))
290290
class(forest) <- c("multi_arm_causal_forest", "grf")
291291
forest[["seed"]] <- seed
292+
forest[["num.threads"]] <- num.threads
292293
forest[["ci.group.size"]] <- ci.group.size
293294
forest[["X.orig"]] <- X
294295
forest[["Y.orig"]] <- Y

r-package/grf/R/multi_regression_forest.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ multi_regression_forest <- function(X, Y,
110110
forest <- do.call.rcpp(multi_regression_train, c(data, args))
111111
class(forest) <- c("multi_regression_forest", "grf")
112112
forest[["seed"]] <- seed
113+
forest[["num.threads"]] <- num.threads
113114
forest[["X.orig"]] <- X
114115
forest[["Y.orig"]] <- Y
115116
forest[["sample.weights"]] <- sample.weights

r-package/grf/R/probability_forest.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ probability_forest <- function(X, Y,
137137
forest <- do.call.rcpp(probability_train, c(data, args))
138138
class(forest) <- c("probability_forest", "grf")
139139
forest[["seed"]] <- seed
140+
forest[["num.threads"]] <- num.threads
140141
forest[["X.orig"]] <- X
141142
forest[["Y.orig"]] <- Y
142143
forest[["Y.relabeled"]] <- Y.relabeled

r-package/grf/R/quantile_forest.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ quantile_forest <- function(X, Y,
132132
forest <- do.call.rcpp(quantile_train, c(data, args))
133133
class(forest) <- c("quantile_forest", "grf")
134134
forest[["seed"]] <- seed
135+
forest[["num.threads"]] <- num.threads
135136
forest[["X.orig"]] <- X
136137
forest[["Y.orig"]] <- Y
137138
forest[["quantiles.orig"]] <- quantiles

r-package/grf/R/regression_forest.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ regression_forest <- function(X, Y,
168168
forest <- do.call.rcpp(regression_train, c(data, args))
169169
class(forest) <- c("regression_forest", "grf")
170170
forest[["seed"]] <- seed
171+
forest[["num.threads"]] <- num.threads
171172
forest[["ci.group.size"]] <- ci.group.size
172173
forest[["X.orig"]] <- X
173174
forest[["Y.orig"]] <- Y

r-package/grf/R/survival_forest.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ survival_forest <- function(X, Y, D,
173173
forest <- do.call.rcpp(survival_train, c(data, args))
174174
class(forest) <- c("survival_forest", "grf")
175175
forest[["seed"]] <- seed
176+
forest[["num.threads"]] <- num.threads
176177
forest[["X.orig"]] <- X
177178
forest[["Y.orig"]] <- Y
178179
forest[["Y.relabeled"]] <- Y.relabeled

0 commit comments

Comments
 (0)