Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 49a2965

Browse files
authored
Merge pull request #516 from facebookresearch/convolution-benchmark
add convolution benchmark
2 parents f6265e0 + fcdfa02 commit 49a2965

File tree

2 files changed

+195
-0
lines changed

2 files changed

+195
-0
lines changed

tc/benchmarks/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ find_library(CUDA_CUDNN_LIBRARIES cudnn
1717
################################################################################
1818
set(BENCHMARKS
1919
batchmatmul
20+
convolution
2021
group_convolution
2122
group_normalization
2223
kronecker

tc/benchmarks/convolution.cc

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <iostream>
17+
#include <string>
18+
#include <vector>
19+
20+
#include <gflags/gflags.h>
21+
#include <glog/logging.h>
22+
#include <gtest/gtest.h>
23+
24+
#include "tc/aten/aten.h"
25+
26+
#include "tc/aten/aten_compiler.h"
27+
#include "tc/core/cuda/cuda_mapping_options.h"
28+
29+
#include "../test/caffe2/cuda/test_harness.h"
30+
#include "../test/caffe2/test_harness.h"
31+
#include "../test/test_harness_aten_cuda.h"
32+
#include "benchmark_fixture.h"
33+
34+
#include "tc/c2/context.h"
35+
#include "tc/core/cuda/cuda.h"
36+
#include "tc/core/flags.h"
37+
38+
using namespace caffe2;
39+
40+
DEFINE_uint32(N, 32, "Batch size (NCHW notation)");
41+
DEFINE_uint32(C, 4, "Input channels (NCHW notation)");
42+
DEFINE_uint32(F, 4, "Output filters (NCHW notation)");
43+
DEFINE_uint32(H, 56, "Image width (NCHW notation)");
44+
DEFINE_uint32(W, 56, "Image height (NCHW notation)");
45+
DEFINE_uint32(KH, 3, "Kernel width (NCHW notation)");
46+
DEFINE_uint32(KW, 3, "Kernel height (NCHW notation)");
47+
48+
class Convolution : public Benchmark {
49+
protected:
50+
uint32_t N, C, F, H, W, KH, KW;
51+
52+
public:
53+
void Init(
54+
uint32_t n,
55+
uint32_t c,
56+
uint32_t f,
57+
uint32_t h,
58+
uint32_t w,
59+
uint32_t kh,
60+
uint32_t kw) {
61+
N = n;
62+
C = c;
63+
F = f;
64+
H = h;
65+
W = w;
66+
KH = kh;
67+
KW = kw;
68+
}
69+
void runConvolution(const tc::CudaMappingOptions& options);
70+
void runATenConvolution();
71+
void runCaffe2Convolution();
72+
};
73+
74+
void Convolution::runConvolution(const tc::CudaMappingOptions& options) {
75+
Workspace w;
76+
auto AddInput = AddDeterministicallyRandomInput<caffe2::CUDABackend, float>;
77+
AddInput(w, vector<TIndex>{N, C, H, W}, "I");
78+
AddInput(w, vector<TIndex>{F, C, KH, KW}, "W");
79+
AddInput(w, {F}, "B");
80+
81+
Argument kernel_h_arg = MakeArgument<int>("kernel_h", KH);
82+
Argument kernel_w_arg = MakeArgument<int>("kernel_w", KW);
83+
Argument group_arg = MakeArgument<int>("group", 1);
84+
OperatorDef op_def = MakeOperatorDef<caffe2::CUDABackend>(
85+
"Conv", {"I", "W", "B"}, {"O"}, {group_arg, kernel_h_arg, kernel_w_arg});
86+
87+
std::unique_ptr<OperatorBase> net(CreateOperator(op_def, &w));
88+
ASSERT_TRUE(net.get());
89+
net->Run();
90+
caffe2::Tensor<caffe2::CUDAContext> expected_blob(
91+
w.GetBlob("O")->Get<caffe2::TensorCUDA>());
92+
93+
at::Tensor ref_output =
94+
MakeAtenTensor(expected_blob, at::Backend::CUDA, at::kFloat)
95+
.resize_({N, F, H - KH + 1, W - KW + 1});
96+
97+
auto check_fun = [&, ref_output](
98+
const std::vector<at::Tensor>& inputs,
99+
const std::vector<at::Tensor>& outputs) {
100+
TC_CUDA_RUNTIMEAPI_ENFORCE(cudaDeviceSynchronize());
101+
double prec = 1e-5; // relax precision to account for CUDNN Winograd kernels
102+
std::cout << "Checking expected output relative precision @" << prec;
103+
checkRtol(outputs[0].sub(ref_output), inputs, C * KH * KW, prec);
104+
return true;
105+
};
106+
107+
// Use the underlying C2 tensors CUDA pointers
108+
auto tI = GetNamedTensor<CUDABackend>(w, "I");
109+
at::Tensor t_i =
110+
MakeAtenTensor(tI, at::Backend::CUDA, at::kFloat).resize_({N, C, H, W});
111+
auto tW = GetNamedTensor<CUDABackend>(w, "W");
112+
at::Tensor t_w =
113+
MakeAtenTensor(tW, at::Backend::CUDA, at::kFloat).resize_({F, C, KH, KW});
114+
auto tB = GetNamedTensor<CUDABackend>(w, "B");
115+
at::Tensor t_b =
116+
MakeAtenTensor(tB, at::Backend::CUDA, at::kFloat).resize_({F});
117+
std::vector<at::Tensor> inputs = {t_i, t_w, t_b};
118+
std::string tc = R"(
119+
def convolution(float(N,C,H,W) I, float(F,C,KH,KW) W1, float(F) B)
120+
-> (O)
121+
{
122+
O(n, f, h, w) +=!
123+
I(n, r_c, h + r_kh, w + r_kw) * W1(f, r_c, r_kh, r_kw)
124+
O(n, f, h, w) = O(n, f, h, w) + B(f)
125+
}
126+
)";
127+
128+
std::string suffix = std::string("_N_") + std::to_string(FLAGS_N) +
129+
std::string("_C_") + std::to_string(FLAGS_C) + std::string("_F_") +
130+
std::to_string(FLAGS_F) + std::string("_W_") + std::to_string(FLAGS_W) +
131+
std::string("_H_") + std::to_string(FLAGS_H) + std::string("_KW_") +
132+
std::to_string(FLAGS_KW) + std::string("_KH_") + std::to_string(FLAGS_KH);
133+
std::vector<tc::CudaMappingOptions> bestOptions{options};
134+
if (FLAGS_autotune) {
135+
autotune(tc, "convolution", inputs, options, check_fun);
136+
}
137+
Check(tc, "convolution", options, inputs, check_fun);
138+
}
139+
140+
void Convolution::runATenConvolution() {
141+
Reference(
142+
[&]() {
143+
at::Tensor I = at::CUDA(at::kFloat).rand({N, C, W, H});
144+
at::Tensor W = at::CUDA(at::kFloat).rand({F, C, KW, KH});
145+
at::Tensor B = at::CUDA(at::kFloat).rand({F});
146+
return std::vector<at::Tensor>{I, W, B};
147+
},
148+
[&](std::vector<at::Tensor>& inputs) {
149+
auto I = inputs[0];
150+
auto W = inputs[1];
151+
auto B = inputs[2];
152+
return at::cudnn_convolution(
153+
I, W, B, {0, 0}, {1, 1}, {1, 1}, 1, true, false);
154+
});
155+
}
156+
157+
void Convolution::runCaffe2Convolution() {
158+
Workspace w;
159+
auto AddInput = AddDeterministicallyRandomInput<caffe2::CUDABackend, float>;
160+
AddInput(w, vector<TIndex>{N, C, W, H}, "I");
161+
AddInput(w, vector<TIndex>{F, C, KW, KH}, "W");
162+
AddInput(w, {F}, "B");
163+
Argument kernel_h_arg = MakeArgument<int>("kernel_h", KH);
164+
Argument kernel_w_arg = MakeArgument<int>("kernel_w", KW);
165+
Argument group_arg = MakeArgument<int>("group", 1);
166+
OperatorDef ndef = MakeOperatorDef<caffe2::CUDABackend>(
167+
"Conv", {"I", "W", "B"}, {"O"}, {group_arg, kernel_h_arg, kernel_w_arg});
168+
std::unique_ptr<OperatorBase> net(CreateOperator(ndef, &w));
169+
Reference([&]() { return true; }, [&](bool flag) { net->Run(); });
170+
}
171+
172+
// Generic
173+
TEST_F(Convolution, Convolution) {
174+
Init(FLAGS_N, FLAGS_C, FLAGS_F, FLAGS_H, FLAGS_W, FLAGS_KH, FLAGS_KW);
175+
runConvolution(tc::CudaMappingOptions::makeNaiveMappingOptions());
176+
}
177+
178+
TEST_F(Convolution, Convolution_Caffe2) {
179+
Init(FLAGS_N, FLAGS_C, FLAGS_F, FLAGS_H, FLAGS_W, FLAGS_KH, FLAGS_KW);
180+
runCaffe2Convolution();
181+
}
182+
183+
TEST_F(Convolution, Convolution_ATen) {
184+
Init(FLAGS_N, FLAGS_C, FLAGS_F, FLAGS_H, FLAGS_W, FLAGS_KH, FLAGS_KW);
185+
runATenConvolution();
186+
}
187+
188+
int main(int argc, char** argv) {
189+
::testing::InitGoogleTest(&argc, argv);
190+
::gflags::ParseCommandLineFlags(&argc, &argv, true);
191+
::google::InitGoogleLogging(argv[0]);
192+
tc::aten::setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
193+
return RUN_ALL_TESTS();
194+
}

0 commit comments

Comments
 (0)