Skip to content

Commit 7bffa33

Browse files
authored
Make seed independent of num.threads and add legacy option (#1447)
1 parent 8b08d83 commit 7bffa33

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2245
-2104
lines changed

REFERENCE.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,13 +403,11 @@ While the algorithm in `regression_forest` is very similar to that of classic ra
403403

404404
Overall, GRF is designed to produce the same estimates across platforms when using a consistent value for the random seed through the training option seed. However, there are still some cases where GRF can produce different estimates across platforms. When it comes to cross-platform predictions, the output of GRF will depend on a few factors beyond the forest seed.
405405

406-
One such factor is the compiler that was used to build GRF. Different compilers may have different default behavior around floating-point rounding, and these could lead to slightly different forest splits if the data requires numerical precision. Another factor is how the forest construction is distributed across different threads. Right now, our forest splitting algorithm can give different results depending on the number of threads that were used to build the forest.
406+
One such factor is the compiler that was used to build GRF. Different compilers may have different default behavior around floating-point behavior and instruction optimizations, and these could lead to slightly different forest splits if the data requires numerical precision. In addition to setting the seed argument, rounding all input data to at most 8 significant digits may help.
407407

408-
Therefore, in order to ensure consistent results, we provide the following recommendations.
409-
- Make sure arguments `seed` and `num.threads` are the same across platforms
410-
- Round data to 8 significant digits
408+
Even though the compiler is the same, different CPU architectures may produce slightly different output. One such example is GRF compiled with clang and run on x86 (Intel) vs. ARM (Apple Silicon).
411409

412-
Also, please note that we have not done extensive testing on Windows platforms, although we do not expect random number generation issues there to be different from Linux/Mac. Regardless of the platform, if results are still not consistent please help us by submitting a Github issue.
410+
Prior to GRF version 2.4.0, another factor was how the forest construction was distributed across different threads. In these versions, our forest splitting algorithm can give different results depending on the number of threads used to build the forest, meaning that the num.threads argument had to be the same for cross-platform reproducibility. To restore this behavior in current versions of GRF, you can set the global R option `options(grf.legacy.seed=TRUE)` and exactly recover results produced with past versions of the package.
413411

414412

415413
## References

core/src/forest/ForestOptions.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@ ForestOptions::ForestOptions(uint num_trees,
3737
double imbalance_penalty,
3838
uint num_threads,
3939
uint random_seed,
40+
bool legacy_seed,
4041
const std::vector<size_t>& sample_clusters,
4142
uint samples_per_cluster):
4243
ci_group_size(ci_group_size),
4344
sample_fraction(sample_fraction),
4445
tree_options(mtry, min_node_size, honesty, honesty_fraction, honesty_prune_leaves, alpha, imbalance_penalty),
4546
sampling_options(samples_per_cluster, sample_clusters),
46-
random_seed(random_seed) {
47+
random_seed(random_seed),
48+
legacy_seed(legacy_seed) {
4749

4850
this->num_threads = validate_num_threads(num_threads);
4951

@@ -85,6 +87,10 @@ uint ForestOptions::get_random_seed() const {
8587
return random_seed;
8688
}
8789

90+
bool ForestOptions::get_legacy_seed() const {
91+
return legacy_seed;
92+
}
93+
8894
uint ForestOptions::validate_num_threads(uint num_threads) {
8995
if (num_threads == DEFAULT_NUM_THREADS) {
9096
return std::thread::hardware_concurrency();

core/src/forest/ForestOptions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class ForestOptions {
4141
double imbalance_penalty,
4242
uint num_threads,
4343
uint random_seed,
44+
bool legacy_seed,
4445
const std::vector<size_t>& sample_clusters,
4546
uint samples_per_cluster);
4647

@@ -55,6 +56,8 @@ class ForestOptions {
5556

5657
uint get_num_threads() const;
5758
uint get_random_seed() const;
59+
// Toggle between seed and num_threads dependence to reproduce behavior prior to grf 2.4.0.
60+
bool get_legacy_seed() const;
5861

5962
private:
6063
uint num_trees;
@@ -66,6 +69,7 @@ class ForestOptions {
6669

6770
uint num_threads;
6871
uint random_seed;
72+
bool legacy_seed;
6973
};
7074

7175
} // namespace grf

core/src/forest/ForestTrainer.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,12 @@ std::vector<std::unique_ptr<Tree>> ForestTrainer::train_batch(
107107
trees.reserve(num_trees * ci_group_size);
108108

109109
for (size_t i = 0; i < num_trees; i++) {
110-
uint tree_seed = udist(random_number_generator);
110+
uint tree_seed;
111+
if (options.get_legacy_seed()) {
112+
tree_seed = udist(random_number_generator);
113+
} else {
114+
tree_seed = static_cast<uint>(options.get_random_seed() + start + i);
115+
}
111116
RandomSampler sampler(tree_seed, options.get_sampling_options());
112117

113118
if (ci_group_size == 1) {

core/test/forest/ForestSmokeTest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ TEST_CASE("forests don't crash when there are fewer trees than threads", "[fores
5252
uint samples_per_cluster = 0;
5353

5454
ForestOptions options(num_trees, ci_group_size, sample_fraction, mtry, min_node_size, honesty, honesty_fraction,
55-
prune, alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
55+
prune, alpha, imbalance_penalty, num_threads, seed, true, empty_clusters, samples_per_cluster);
5656

5757
Forest forest = trainer.train(data, options);
5858
ForestPredictor predictor = regression_predictor(4);

core/test/forest/LocalLinearForestTest.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ TEST_CASE("LLF gives reasonable prediction on friedman data", "[local linear], [
4949
ForestOptions options (
5050
num_trees, ci_group_size, sample_fraction,
5151
mtry, min_node_size, honesty, honesty_fraction, prune,
52-
alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
52+
alpha, imbalance_penalty, num_threads, seed, true, empty_clusters, samples_per_cluster);
5353
ForestTrainer trainer = regression_trainer();
5454
Forest forest = trainer.train(data, options);
5555

@@ -136,7 +136,7 @@ TEST_CASE("local linear forests give reasonable variance estimates", "[regressio
136136
ForestOptions options (
137137
num_trees, ci_group_size, sample_fraction,
138138
mtry, min_node_size, honesty, honesty_fraction, prune,
139-
alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
139+
alpha, imbalance_penalty, num_threads, seed, true, empty_clusters, samples_per_cluster);
140140
ForestTrainer trainer = regression_trainer();
141141
Forest forest = trainer.train(data, options);
142142

core/test/utilities/ForestTestUtilities.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ ForestOptions ForestTestUtilities::default_options(bool honesty,
4242
uint samples_per_cluster = 0;
4343
uint num_threads = 4;
4444
uint seed = 42;
45+
bool legacy_seed = true;
4546

4647
return ForestOptions(num_trees,
4748
ci_group_size, sample_fraction, mtry, min_node_size, honesty, honesty_fraction,
48-
prune, alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
49+
prune, alpha, imbalance_penalty, num_threads, seed, legacy_seed, empty_clusters, samples_per_cluster);
4950
}

r-package/grf/DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Imports:
2929
methods,
3030
Rcpp (>= 0.12.15),
3131
sandwich (>= 2.4-0)
32-
RoxygenNote: 7.2.3
32+
RoxygenNote: 7.3.2
3333
Suggests:
3434
DiagrammeR,
3535
MASS,

r-package/grf/NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ export(get_leaf_node)
3838
export(get_sample_weights)
3939
export(get_scores)
4040
export(get_tree)
41+
export(grf_options)
4142
export(instrumental_forest)
4243
export(ll_regression_forest)
4344
export(lm_forest)

0 commit comments

Comments
 (0)