/**
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <string>
#include <vector>

#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>

#include "tc/aten/aten.h"

#include "tc/aten/aten_compiler.h"
#include "tc/core/cuda/cuda_mapping_options.h"

#include "../test/caffe2/cuda/test_harness.h"
#include "../test/caffe2/test_harness.h"
#include "../test/test_harness_aten_cuda.h"
#include "benchmark_fixture.h"

#include "tc/c2/context.h"
#include "tc/core/cuda/cuda.h"
#include "tc/core/flags.h"

using namespace caffe2;

DEFINE_uint32(N, 32, "Batch size (NCHW notation)");
DEFINE_uint32(C, 4, "Input channels (NCHW notation)");
DEFINE_uint32(F, 4, "Output filters (NCHW notation)");
DEFINE_uint32(H, 56, "Image height (NCHW notation)");
DEFINE_uint32(W, 56, "Image width (NCHW notation)");
DEFINE_uint32(KH, 3, "Kernel height (NCHW notation)");
DEFINE_uint32(KW, 3, "Kernel width (NCHW notation)");

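// Benchmark fixture: Init() records the problem sizes, and each run*()
// method times one backend (TC-generated CUDA, ATen/cuDNN, Caffe2) on the
// same convolution.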
class Convolution : public Benchmark {
 protected:
  uint32_t N, C, F, H, W, KH, KW;

 public:
  void Init(
      uint32_t n,
      uint32_t c,
      uint32_t f,
      uint32_t h,
      uint32_t w,
      uint32_t kh,
      uint32_t kw) {
    N = n;
    C = c;
    F = f;
    H = h;
    W = w;
    KH = kh;
    KW = kw;
  }
  void runConvolution(const tc::CudaMappingOptions& options);
  void runATenConvolution();
  void runCaffe2Convolution();
};

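// Runs the Caffe2 Conv operator once as a CUDA reference, wraps its buffers
// as ATen tensors, then checks and benchmarks the TC kernel against it.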
void Convolution::runConvolution(const tc::CudaMappingOptions& options) {
  Workspace w;
  auto AddInput = AddDeterministicallyRandomInput<caffe2::CUDABackend, float>;
  AddInput(w, vector<TIndex>{N, C, H, W}, "I");
  AddInput(w, vector<TIndex>{F, C, KH, KW}, "W");
  AddInput(w, {F}, "B");

  Argument kernel_h_arg = MakeArgument<int>("kernel_h", KH);
  Argument kernel_w_arg = MakeArgument<int>("kernel_w", KW);
  Argument group_arg = MakeArgument<int>("group", 1);
  OperatorDef op_def = MakeOperatorDef<caffe2::CUDABackend>(
      "Conv", {"I", "W", "B"}, {"O"}, {group_arg, kernel_h_arg, kernel_w_arg});

  std::unique_ptr<OperatorBase> net(CreateOperator(op_def, &w));
  ASSERT_TRUE(net.get());
  net->Run();
  caffe2::Tensor<caffe2::CUDAContext> expected_blob(
      w.GetBlob("O")->Get<caffe2::TensorCUDA>());

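  // The reference output of a valid (unpadded, stride-1) convolution has
  // spatial dims (H - KH + 1) x (W - KW + 1); with the default flags that is
  // 56 - 3 + 1 = 54.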
  at::Tensor ref_output =
      MakeAtenTensor(expected_blob, at::Backend::CUDA, at::kFloat)
          .resize_({N, F, H - KH + 1, W - KW + 1});

  auto check_fun = [&, ref_output](
                       const std::vector<at::Tensor>& inputs,
                       const std::vector<at::Tensor>& outputs) {
    TC_CUDA_RUNTIMEAPI_ENFORCE(cudaDeviceSynchronize());
    // Relaxed precision to account for cuDNN Winograd kernels; checkRtol
    // scales the tolerance by the reduction size (C * KH * KW multiply-adds
    // accumulate into each output element).
    double prec = 1e-5;
    std::cout << "Checking expected output relative precision @" << prec;
    checkRtol(outputs[0].sub(ref_output), inputs, C * KH * KW, prec);
    return true;
  };

  // Use the underlying Caffe2 tensors' CUDA pointers directly.
  auto tI = GetNamedTensor<CUDABackend>(w, "I");
  at::Tensor t_i =
      MakeAtenTensor(tI, at::Backend::CUDA, at::kFloat).resize_({N, C, H, W});
  auto tW = GetNamedTensor<CUDABackend>(w, "W");
  at::Tensor t_w =
      MakeAtenTensor(tW, at::Backend::CUDA, at::kFloat).resize_({F, C, KH, KW});
  auto tB = GetNamedTensor<CUDABackend>(w, "B");
  at::Tensor t_b =
      MakeAtenTensor(tB, at::Backend::CUDA, at::kFloat).resize_({F});
  std::vector<at::Tensor> inputs = {t_i, t_w, t_b};
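  // TC definition: +=! zero-initializes the reduction over the inferred
  // ranges of r_c, r_kh, r_kw; the second statement adds the bias. The
  // weight tensor is named W1 because W is already the width size parameter.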
  std::string tc = R"(
def convolution(float(N,C,H,W) I, float(F,C,KH,KW) W1, float(F) B)
-> (O)
{
    O(n, f, h, w) +=!
        I(n, r_c, h + r_kh, w + r_kw) * W1(f, r_c, r_kh, r_kw)
    O(n, f, h, w) = O(n, f, h, w) + B(f)
}
)";

  std::string suffix = std::string("_N_") + std::to_string(FLAGS_N) +
      std::string("_C_") + std::to_string(FLAGS_C) + std::string("_F_") +
      std::to_string(FLAGS_F) + std::string("_W_") + std::to_string(FLAGS_W) +
      std::string("_H_") + std::to_string(FLAGS_H) + std::string("_KW_") +
      std::to_string(FLAGS_KW) + std::string("_KH_") + std::to_string(FLAGS_KH);
  std::vector<tc::CudaMappingOptions> bestOptions{options};
  if (FLAGS_autotune) {
    // Assumes autotune() (see benchmark_fixture.h) returns the best options
    // it found; the seed options remain the fallback when tuning is off.
    bestOptions = autotune(tc, "convolution", inputs, options, check_fun);
    CHECK_GE(bestOptions.size(), 1u);
  }
  Check(tc, "convolution", bestOptions[0], inputs, check_fun);
}

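// Baseline: cuDNN convolution called through ATen on freshly generated
// random CUDA inputs of the same shapes.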
void Convolution::runATenConvolution() {
  Reference(
      [&]() {
        at::Tensor I = at::CUDA(at::kFloat).rand({N, C, H, W});
        at::Tensor W = at::CUDA(at::kFloat).rand({F, C, KH, KW});
        at::Tensor B = at::CUDA(at::kFloat).rand({F});
        return std::vector<at::Tensor>{I, W, B};
      },
      [&](std::vector<at::Tensor>& inputs) {
        auto I = inputs[0];
        auto W = inputs[1];
        auto B = inputs[2];
        return at::cudnn_convolution(
            I, W, B, {0, 0}, {1, 1}, {1, 1}, 1, true, false);
      });
}

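// Baseline: the Caffe2 Conv operator on CUDA, timed through the same
// Reference() harness.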
void Convolution::runCaffe2Convolution() {
  Workspace w;
  auto AddInput = AddDeterministicallyRandomInput<caffe2::CUDABackend, float>;
  AddInput(w, vector<TIndex>{N, C, H, W}, "I");
  AddInput(w, vector<TIndex>{F, C, KH, KW}, "W");
  AddInput(w, {F}, "B");
  Argument kernel_h_arg = MakeArgument<int>("kernel_h", KH);
  Argument kernel_w_arg = MakeArgument<int>("kernel_w", KW);
  Argument group_arg = MakeArgument<int>("group", 1);
  OperatorDef ndef = MakeOperatorDef<caffe2::CUDABackend>(
      "Conv", {"I", "W", "B"}, {"O"}, {group_arg, kernel_h_arg, kernel_w_arg});
  std::unique_ptr<OperatorBase> net(CreateOperator(ndef, &w));
  Reference([&]() { return true; }, [&](bool flag) { net->Run(); });
}

// Generic benchmark entry points; problem sizes come from the command-line
// flags defined above.
TEST_F(Convolution, Convolution) {
  Init(FLAGS_N, FLAGS_C, FLAGS_F, FLAGS_H, FLAGS_W, FLAGS_KH, FLAGS_KW);
  runConvolution(tc::CudaMappingOptions::makeNaiveMappingOptions());
}

TEST_F(Convolution, Convolution_Caffe2) {
  Init(FLAGS_N, FLAGS_C, FLAGS_F, FLAGS_H, FLAGS_W, FLAGS_KH, FLAGS_KW);
  runCaffe2Convolution();
}

TEST_F(Convolution, Convolution_ATen) {
  Init(FLAGS_N, FLAGS_C, FLAGS_F, FLAGS_H, FLAGS_W, FLAGS_KH, FLAGS_KW);
  runATenConvolution();
}

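// Example invocation (the binary name is illustrative):
//   ./benchmark_convolution --N=32 --C=4 --F=4 --H=56 --W=56 --KH=3 --KW=3 \
//       --autotune=true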
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  ::gflags::ParseCommandLineFlags(&argc, &argv, true);
  ::google::InitGoogleLogging(argv[0]);
  tc::aten::setAtenSeed(tc::initRandomSeed(), at::Backend::CUDA);
  return RUN_ALL_TESTS();
}