This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 72e0703

Merge pull request #408 from nicolasvasilache/pr/fix-fbcode-issues
Fix fbcode issues
2 parents 9d11b7b + fa8ad51 commit 72e0703

File tree: 9 files changed (+58 -52 lines)

benchmarks_python/caffe2_benchmark.py

Lines changed: 2 additions & 2 deletions

@@ -51,7 +51,7 @@ def GetArgumentParser():
     parser.add_argument("--tuner_cache_file", type=str,
                         default="tuner_cache",
                         help="File to store tuned mapping options")
-    parser.add_argument("--tuner_gpus", type=str,
+    parser.add_argument("--tuner_devices", type=str,
                         default="0",
                         help="String representation of gpus to use for tuning (e.g. \"0,1\")")
     parser.add_argument("--tuner_threads", type=int, default=10,
@@ -70,7 +70,7 @@ def main():
     core.GlobalInit([
         'tc_bench',
         '--caffe2_logging_operator_dyno_sampling_rate=0',
-        '--tuner_gpus=' + args.tuner_gpus,
+        '--tuner_devices=' + args.tuner_devices,
         '--caffe2_simple_net_benchmark_run_whole_net=0',
     ] + extra_args)
     mapping_options = tune(args)
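Note: caffe2.python.core.GlobalInit forwards these strings to the C++ flag parser, so the argparse option and the forwarded '--tuner_devices=' string have to be renamed together; the change presumably tracks a matching gpus-to-devices rename of the underlying C++ flag. The help text still says "gpus".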

tc/core/polyhedral/codegen_llvm.cc

Lines changed: 11 additions & 10 deletions

@@ -614,31 +614,32 @@ IslCodegenRes codegenISL(const Scop& scop) {
   auto collectIteratorMaps =
       [](isl::ast_node node,
          isl::ast_build build,
-         IteratorMapsType& iteratorMaps,
-         const Scop& scop,
-         StmtSubscriptExprMapType& stmtSubscripts) -> isl::ast_node {
+         IteratorMapsType& iteratorMapsInFun,
+         const Scop& scopInFun,
+         StmtSubscriptExprMapType& stmtSubscriptsInFun) -> isl::ast_node {
     auto user = node.as<isl::ast_node_user>();
     CHECK(user);
     auto expr = user.get_expr().as<isl::ast_expr_op>();
     auto schedule = build.get_schedule();
     auto scheduleMap = isl::map::from_union_map(schedule);
 
     auto stmtId = expr.get_arg(0).as<isl::ast_expr_id>().get_id();
-    CHECK_EQ(0u, iteratorMaps.count(stmtId)) << "entry exists: " << stmtId;
+    CHECK_EQ(0u, iteratorMapsInFun.count(stmtId))
+        << "entry exists: " << stmtId;
     auto iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
-    auto iterators = scop.halide.iterators.at(stmtId);
-    auto& stmtIteratorMap = iteratorMaps[stmtId];
+    auto iterators = scopInFun.halide.iterators.at(stmtId);
+    auto& stmtIteratorMap = iteratorMapsInFun[stmtId];
     for (size_t i = 0; i < iterators.size(); ++i) {
       auto expr = build.expr_from(iteratorMap.get_pw_aff(i));
       stmtIteratorMap.emplace(iterators[i], expr);
     }
-    auto& subscripts = stmtSubscripts[stmtId];
-    auto provide =
-        scop.halide.statements.at(stmtId).as<Halide::Internal::Provide>();
+    auto& subscripts = stmtSubscriptsInFun[stmtId];
+    auto provide = scopInFun.halide.statements.at(stmtId)
+                       .as<Halide::Internal::Provide>();
     for (auto e : provide->args) {
       const auto& map = iteratorMap;
       auto space = map.get_space().params();
-      auto aff = scop.makeIslAffFromStmtExpr(stmtId, space, e);
+      auto aff = scopInFun.makeIslAffFromStmtExpr(stmtId, space, e);
       auto pulled = isl::pw_aff(aff).pullback(map);
       CHECK_EQ(pulled.n_piece(), 1);
       subscripts.push_back(build.expr_from(pulled));
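These renames (and the matching ones in mapped_scop.cc below) give the lambda parameters names distinct from the enclosing function's variables they previously shadowed, the kind of -Wshadow violation a stricter fbcode build would reject. A minimal sketch of the pattern, with hypothetical names:

// Minimal sketch of the shadowing fix (hypothetical names);
// compile with -Wshadow -Werror to reproduce the original failure.
#include <vector>

int sumWithBias(const std::vector<int>& values, int bias) {
  int total = 0;
  // Before: the lambda parameter reused the enclosing name,
  // triggering -Wshadow:
  //   auto add = [&total](int value, int bias) { total += value + bias; };
  // After: rename the parameter so both names stay distinct.
  auto add = [&total](int value, int biasInFun) {
    total += value + biasInFun;
  };
  for (int v : values) {
    add(v, bias);
  }
  return total;
}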

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 6 additions & 4 deletions

@@ -463,14 +463,16 @@ isl::union_set modifyMappingNames(
     space = space.set_dim_name(isl::dim_type::param, dim, name + suffix);
   }
   auto newSet = isl::union_set::empty(space);
-  set.foreach_set([&newSet, &identifiers, &suffix](isl::set set) {
+  set.foreach_set([&newSet, &identifiers, &suffix](isl::set setInFun) {
     for (auto id : identifiers) {
       auto name = id.get_name();
-      auto dim = set.get_space().find_dim_by_name(isl::dim_type::param, name);
+      auto dim =
+          setInFun.get_space().find_dim_by_name(isl::dim_type::param, name);
       CHECK_LE(0, dim);
-      set = set.set_dim_name(isl::dim_type::param, dim, name + suffix);
+      setInFun =
+          setInFun.set_dim_name(isl::dim_type::param, dim, name + suffix);
     }
-    newSet = newSet.unite(set);
+    newSet = newSet.unite(setInFun);
   });
   return newSet;
 }

tc/library/common.h

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@
 
 namespace tc {
 
-std::string replaceString(
+inline std::string replaceString(
     std::string str,
    const std::string& search,
    const std::string& replace) {
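replaceString is defined, not merely declared, in this header, so without the inline keyword every translation unit that includes tc/library/common.h emits its own external definition and the link fails with duplicate symbols. inline permits identical definitions across translation units under the one-definition rule. A minimal sketch with hypothetical names:

// util.h -- a header included from several .cc files (hypothetical).
#pragma once
#include <string>

// Without `inline`, each including translation unit emits its own
// external definition of greet() and linking fails with duplicate
// symbols; `inline` permits identical definitions across TUs.
inline std::string greet(const std::string& name) {
  return "hello, " + name;
}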

test/caffe2/test_harness-inl.h

Lines changed: 4 additions & 7 deletions

@@ -19,10 +19,7 @@ namespace caffe2 {
 
 namespace detail {
 
-std::mutex& RNGMutex() {
-  static std::mutex rng_mutex;
-  return rng_mutex;
-}
+std::mutex& RNGMutex();
 
 template <typename T>
 T* NewTensor(
@@ -97,9 +94,9 @@ at::Tensor MakeAtenTensor(
 
 template <
     typename Backend,
-    class IterableInputs = std::initializer_list<string>,
-    class IterableOutputs = std::initializer_list<string>,
-    class IterableArgs = std::initializer_list<Argument>>
+    class IterableInputs,
+    class IterableOutputs,
+    class IterableArgs>
 OperatorDef MakeOperatorDef(
     std::string type,
     IterableInputs ins,
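The second hunk drops the default template arguments from what the -inl.h layout suggests is a redeclaration of MakeOperatorDef whose first declaration, with the defaults, lives in test_harness.h. C++ allows a default template argument to be specified only once per scope, so repeating it here is ill-formed and stricter builds diagnose it. A minimal sketch, hypothetical names:

// widget.h -- the first declaration carries the default (hypothetical).
template <class T = int>
T makeWidget();

// widget-inl.h -- the definition must not repeat `= int`;
// specifying a default template argument twice in one scope
// is ill-formed.
template <class T>
T makeWidget() {
  return T{};
}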

test/caffe2/test_harness.cc

Lines changed: 14 additions & 0 deletions

@@ -16,6 +16,20 @@
 #include "test_harness.h"
 
 namespace caffe2 {
+namespace detail {
+
+std::mutex& RNGMutex() {
+  static std::mutex rng_mutex;
+  return rng_mutex;
+}
+
+} // namespace detail
+
+ReferenceImplementationBuilder MakeDefaultReferenceImplementationBuilder() {
+  return [](const OperatorDef& op_def, NetDef* net_def) {
+    caffe2::ReferenceImplementationRegistry::Append(net_def, op_def);
+  };
+}
 
 void CheckEqual(
     const caffe2::Tensor<caffe2::CPUContext>& Texpected,

test/caffe2/test_harness.h

Lines changed: 1 addition & 5 deletions

@@ -87,11 +87,7 @@ at::Tensor MakeAtenTensor(
 using ReferenceImplementationBuilder =
     std::function<void(const OperatorDef& op_def, NetDef* net_def)>;
 
-ReferenceImplementationBuilder MakeDefaultReferenceImplementationBuilder() {
-  return [](const OperatorDef& op_def, NetDef* net_def) {
-    caffe2::ReferenceImplementationRegistry::Append(net_def, op_def);
-  };
-}
+ReferenceImplementationBuilder MakeDefaultReferenceImplementationBuilder();
 
 /// Creates an OperatorDef for a particular Backend
 /// op_name is the name of the operator (e.g. TcOp)
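Together with the test_harness.cc hunk above, this is the complementary fix to common.h's: instead of marking header definitions inline, RNGMutex and MakeDefaultReferenceImplementationBuilder become declarations in the headers while test_harness.cc provides the single definitions, so exactly one translation unit emits each symbol. A minimal sketch of the split, hypothetical names:

// counters.h -- declaration only, safe to include everywhere
// (hypothetical names).
#include <mutex>

std::mutex& CounterMutex();

// counters.cc -- the one and only definition; the function-local
// static is initialized thread-safely on first use.
std::mutex& CounterMutex() {
  static std::mutex counter_mutex;
  return counter_mutex;
}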

test/cuda/test_tc_mapper.cc

Lines changed: 18 additions & 22 deletions

@@ -253,18 +253,17 @@ TEST_F(TcCudaMapperTest, BatchTripleHadamard) {
   at::Tensor V = at::CUDA(at::kFloat).rand({B, D});
   at::Tensor W = at::CUDA(at::kFloat).rand({B, D});
   std::vector<at::Tensor> inputs = {U, V, W};
-  std::vector<at::Tensor> outputs;
 
   static constexpr auto TC = R"TC(
 def batch_triple_hadamard(float(B, D) U, float(B, D) V, float(B, D) W) -> (Z) {
   Z(b, d) = U(b, d) * V(b, d) * W(b, d)
 }
 )TC";
 
-  auto checkFun = [=](const std::vector<at::Tensor>& inputs,
-                      std::vector<at::Tensor>& outputs) {
-    at::Tensor diff = outputs[0].sub(inputs[0] * inputs[1] * inputs[2]);
-    checkRtol(diff, inputs, D);
+  auto checkFun = [=](const std::vector<at::Tensor>& ins,
+                      std::vector<at::Tensor>& outs) {
+    at::Tensor diff = outs[0].sub(ins[0] * ins[1] * ins[2]);
+    checkRtol(diff, ins, D);
   };
   Check(
       TC,
@@ -283,16 +282,15 @@ TEST_F(TcCudaMapperTest, TensorDot) {
   at::Tensor I0 = at::CUDA(at::kFloat).rand({N, C1, C2, H, W});
   at::Tensor I1 = at::CUDA(at::kFloat).rand({N, C2, C3, H, W});
   std::vector<at::Tensor> inputs = {I0, I1};
-  std::vector<at::Tensor> outputs;
 
   static constexpr auto TC = R"TC(
 def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O) {
   O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w)
 }
 )TC";
   // No defaults for this case
-  auto checkFun = [](const std::vector<at::Tensor>& inputs,
-                     std::vector<at::Tensor>& outputs) { return true; };
+  auto checkFun = [](const std::vector<at::Tensor>& ins,
+                     std::vector<at::Tensor>& outs) { return true; };
   auto options = tc::CudaMappingOptions::makeNaiveMappingOptions();
   auto name = "tensordot";
   Check(TC, name, options, inputs, checkFun);
@@ -309,21 +307,20 @@ TEST_F(TcCudaMapperTest, LUT) {
   at::Tensor I =
       at::CUDA(at::kFloat).rand({B, N}).mul_(B).floor_().toType(at::kInt);
   std::vector<at::Tensor> inputs = {LUT, I};
-  std::vector<at::Tensor> outputs;
 
   static constexpr auto TC = R"TC(
 def fun(float(B, R) LUT, int32(B, N) I) -> (O) {
   O(b, n) +=! LUT(I(b, n), r_r)
 }
 )TC";
 
-  auto checkFun = [=](const std::vector<at::Tensor>& inputs,
-                      std::vector<at::Tensor>& outputs) {
-    at::Tensor LUT = inputs[0].toBackend(at::kCPU);
-    at::Tensor I = inputs[1].toBackend(at::kCPU);
-    at::Tensor O = outputs[0].toBackend(at::kCPU);
-    auto LUTAccessor = LUT.accessor<float, 2>();
-    auto IAccessor = I.accessor<int, 2>();
+  auto checkFun = [=](const std::vector<at::Tensor>& ins,
+                      std::vector<at::Tensor>& outs) {
+    at::Tensor lut = ins[0].toBackend(at::kCPU);
+    at::Tensor in = ins[1].toBackend(at::kCPU);
+    at::Tensor O = outs[0].toBackend(at::kCPU);
+    auto LUTAccessor = lut.accessor<float, 2>();
+    auto IAccessor = in.accessor<int, 2>();
     auto OAccessor = O.accessor<float, 2>();
     for (int b = 0; b < B; b++) {
       for (int n = 0; n < N; n++) {
@@ -337,7 +334,7 @@ def fun(float(B, R) LUT, int32(B, N) I) -> (O) {
       }
     }
 
-    checkRtol(O, inputs, 5e-7);
+    checkRtol(O, ins, 5e-7);
   };
   Check(
       TC,
@@ -361,7 +358,6 @@ TEST_F(TcCudaMapperTest, DISABLED_SpatialBatchNormalization) {
   at::Tensor rMeanIn = at::CUDA(at::kFloat).rand({C2});
   at::Tensor rVarIn = at::CUDA(at::kFloat).rand({C2});
   std::vector<at::Tensor> inputs = {momentum, eps, I, rMeanIn, rVarIn};
-  std::vector<at::Tensor> outputs;
 
   static constexpr auto TC = R"TC(
 def spatial_batch_norm(
@@ -382,8 +378,8 @@ def spatial_batch_norm(
   normalizedOut(n, c, h, w) = O(n, c, h, w)
 })TC";
 
-  auto checkFun = [=](const std::vector<at::Tensor>& inputs,
-                      std::vector<at::Tensor>& outputs) {
+  auto checkFun = [=](const std::vector<at::Tensor>& ins,
+                      std::vector<at::Tensor>& outs) {
     TC_CUDA_RUNTIMEAPI_ENFORCE(cudaDeviceSynchronize());
     double prec = 3e-7;
     std::cout << "Checking expected output relative precision @" << prec;
@@ -400,8 +396,8 @@ def spatial_batch_norm(
         at::Scalar(momentum[0]).toFloat(),
         at::Scalar(eps[0]).toFloat(),
         true);
-    auto diff = O.sub(outputs[0]);
-    checkRtol(diff, inputs, N * H * W, prec);
+    auto diff = O.sub(outs[0]);
+    checkRtol(diff, ins, N * H * W, prec);
   };
 
   auto name = "spatial_batch_norm";
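Two recurring fixes run through this file: the never-used local outputs vectors are deleted, which silences unused-variable diagnostics under stricter flags, and each test's check lambda renames its inputs/outputs parameters to ins/outs so they no longer shadow the enclosing locals of the same names.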

test_python/test_c2.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@
 
 MATMUL_LANG = """
 def matmul(float(M,N) A, float(N,K) B) -> (output) {
-    output(m, n) +=! A(m, r_n) * B(r_n, k)
+    output(m, k) +=! A(m, r_n) * B(r_n, k)
 }
 """
 
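This also fixes a genuine index bug in the TC definition: the right-hand side's only free indices are m and k (r_n is a reduction index), yet the left-hand side wrote output(m, n). Indexing the output as output(m, k) yields the expected M-by-K matmul result.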
