Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 52cc55b

Browse files
[C++ API] Graduate set/getAtenSeed
These functions were confined to a test harness but they are more widely useful. Graduate them to aten.h
1 parent 674ed01 commit 52cc55b

14 files changed

+24
-21
lines changed

tc/aten/aten-inl.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,15 @@ inline std::vector<DLConstTensorUPtr> makeDLConstTensors(
4646
}
4747
return dlTensors;
4848
}
49+
50+
inline void setAtenSeed(uint64_t seed, at::Backend backend) {
51+
at::Generator& gen = at::globalContext().defaultGenerator(backend);
52+
gen.manualSeed(seed);
53+
}
54+
55+
inline uint64_t getAtenSeed(at::Backend backend) {
56+
at::Generator& gen = at::globalContext().defaultGenerator(backend);
57+
return gen.seed();
58+
}
4959
} // namespace aten
5060
} // namespace tc

tc/aten/aten.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ inline std::vector<DLTensorUPtr> makeDLTensors(
3131
inline std::vector<DLConstTensorUPtr> makeDLConstTensors(
3232
const std::vector<at::Tensor>& tensors);
3333

34+
inline void setAtenSeed(uint64_t seed, at::Backend backend);
35+
inline uint64_t getAtenSeed(at::Backend backend);
36+
3437
} // namespace aten
3538
} // namespace tc
3639

tc/benchmarks/MLP_model.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,6 @@ int main(int argc, char** argv) {
10011001
::testing::InitGoogleTest(&argc, argv);
10021002
::gflags::ParseCommandLineFlags(&argc, &argv, true);
10031003
::google::InitGoogleLogging(argv[0]);
1004-
setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
1004+
tc::aten::setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
10051005
return RUN_ALL_TESTS();
10061006
}

tc/benchmarks/batchmatmul.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,6 @@ int main(int argc, char** argv) {
182182
::testing::InitGoogleTest(&argc, argv);
183183
::gflags::ParseCommandLineFlags(&argc, &argv, true);
184184
::google::InitGoogleLogging(argv[0]);
185-
setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
185+
tc::aten::setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
186186
return RUN_ALL_TESTS();
187187
}

tc/benchmarks/group_convolution.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,6 @@ int main(int argc, char** argv) {
380380
::testing::InitGoogleTest(&argc, argv);
381381
::gflags::ParseCommandLineFlags(&argc, &argv, true);
382382
::google::InitGoogleLogging(argv[0]);
383-
setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
383+
tc::aten::setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
384384
return RUN_ALL_TESTS();
385385
}

tc/benchmarks/tmm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,6 @@ int main(int argc, char** argv) {
231231
::testing::InitGoogleTest(&argc, argv);
232232
::gflags::ParseCommandLineFlags(&argc, &argv, true);
233233
::google::InitGoogleLogging(argv[0]);
234-
setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
234+
tc::aten::setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
235235
return RUN_ALL_TESTS();
236236
}

tc/examples/blockdiagperm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,6 @@ int main(int argc, char** argv) {
133133
::testing::InitGoogleTest(&argc, argv);
134134
::gflags::ParseCommandLineFlags(&argc, &argv, true);
135135
::google::InitGoogleLogging(argv[0]);
136-
setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
136+
tc::aten::setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
137137
return RUN_ALL_TESTS();
138138
}

tc/examples/tensordot.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,6 @@ int main(int argc, char** argv) {
124124
::testing::InitGoogleTest(&argc, argv);
125125
::gflags::ParseCommandLineFlags(&argc, &argv, true);
126126
::google::InitGoogleLogging(argv[0]);
127-
setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
127+
tc::aten::setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
128128
return RUN_ALL_TESTS();
129129
}

tc/examples/wavenet.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,6 @@ int main(int argc, char** argv) {
174174
::testing::InitGoogleTest(&argc, argv);
175175
::gflags::ParseCommandLineFlags(&argc, &argv, true);
176176
::google::InitGoogleLogging(argv[0]);
177-
setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
177+
tc::aten::setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
178178
return RUN_ALL_TESTS();
179179
}

test/cuda/test_compile_and_run.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,6 @@ int main(int argc, char** argv) {
278278
::testing::InitGoogleTest(&argc, argv);
279279
::gflags::ParseCommandLineFlags(&argc, &argv, true);
280280
::google::InitGoogleLogging(argv[0]);
281-
setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
281+
tc::aten::setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
282282
return RUN_ALL_TESTS();
283283
}

0 commit comments

Comments
 (0)