|
21 | 21 |
|
22 | 22 | #include <nvbench/nvbench.cuh>
|
23 | 23 |
|
24 |
| -#include <map> |
25 |
| -#include <stdexcept> |
26 |
| -#include <tuple> |
27 |
| - |
28 |
| -inline static std::tuple<NVCVColorConversionCode, NVCVImageFormat, NVCVImageFormat> StringToFormats( |
29 |
| - const std::string &str) |
30 |
| -{ |
31 |
| - // clang-format off |
32 |
| - static const std::map<std::string, std::tuple<NVCVColorConversionCode, NVCVImageFormat, NVCVImageFormat>> codeMap{ |
33 |
| - { "RGB2BGR", {NVCV_COLOR_RGB2BGR, NVCV_IMAGE_FORMAT_RGB8, NVCV_IMAGE_FORMAT_BGR8 }}, |
34 |
| - { "RGB2RGBA", {NVCV_COLOR_RGB2RGBA, NVCV_IMAGE_FORMAT_RGB8, NVCV_IMAGE_FORMAT_RGBA8}}, |
35 |
| - { "RGBA2RGB", {NVCV_COLOR_RGBA2RGB, NVCV_IMAGE_FORMAT_RGBA8, NVCV_IMAGE_FORMAT_RGB8 }}, |
36 |
| - { "RGB2GRAY", {NVCV_COLOR_RGB2GRAY, NVCV_IMAGE_FORMAT_RGB8, NVCV_IMAGE_FORMAT_Y8 }}, |
37 |
| - { "GRAY2RGB", {NVCV_COLOR_GRAY2RGB, NVCV_IMAGE_FORMAT_Y8, NVCV_IMAGE_FORMAT_RGB8 }}, |
38 |
| - { "RGB2HSV", {NVCV_COLOR_RGB2HSV, NVCV_IMAGE_FORMAT_RGB8, NVCV_IMAGE_FORMAT_HSV8 }}, |
39 |
| - { "HSV2RGB", {NVCV_COLOR_HSV2RGB, NVCV_IMAGE_FORMAT_HSV8, NVCV_IMAGE_FORMAT_RGB8 }}, |
40 |
| - { "RGB2YUV", {NVCV_COLOR_RGB2YUV, NVCV_IMAGE_FORMAT_RGB8, NVCV_IMAGE_FORMAT_YUV8 }}, |
41 |
| - { "YUV2RGB", {NVCV_COLOR_YUV2RGB, NVCV_IMAGE_FORMAT_YUV8, NVCV_IMAGE_FORMAT_RGB8 }}, |
42 |
| - {"RGB2YUV_NV12", {NVCV_COLOR_RGB2YUV_NV12, NVCV_IMAGE_FORMAT_RGB8, NVCV_IMAGE_FORMAT_NV12 }}, |
43 |
| - {"YUV2RGB_NV12", {NVCV_COLOR_YUV2RGB_NV12, NVCV_IMAGE_FORMAT_NV12, NVCV_IMAGE_FORMAT_RGB8 }}, |
44 |
| - }; |
45 |
| - // clang-format on |
46 |
| - |
47 |
| - if (auto it = codeMap.find(str); it != codeMap.end()) |
48 |
| - { |
49 |
| - return it->second; |
50 |
| - } |
51 |
| - else |
52 |
| - { |
53 |
| - throw std::invalid_argument("Unrecognized color code"); |
54 |
| - } |
55 |
| -} |
56 |
| - |
57 |
| -template<typename BT> |
58 |
| -inline float BytesPerPixel(NVCVImageFormat imgFormat) |
59 |
| -{ |
60 |
| - switch (imgFormat) |
61 |
| - { |
62 |
| -#define CVCUDA_BYTES_PER_PIXEL_CASE(FORMAT, BYTES) \ |
63 |
| - case FORMAT: \ |
64 |
| - return BYTES * sizeof(BT) |
65 |
| - CVCUDA_BYTES_PER_PIXEL_CASE(NVCV_IMAGE_FORMAT_RGB8, 3); |
66 |
| - CVCUDA_BYTES_PER_PIXEL_CASE(NVCV_IMAGE_FORMAT_BGR8, 3); |
67 |
| - CVCUDA_BYTES_PER_PIXEL_CASE(NVCV_IMAGE_FORMAT_HSV8, 3); |
68 |
| - CVCUDA_BYTES_PER_PIXEL_CASE(NVCV_IMAGE_FORMAT_RGBA8, 4); |
69 |
| - CVCUDA_BYTES_PER_PIXEL_CASE(NVCV_IMAGE_FORMAT_YUV8, 3); |
70 |
| - CVCUDA_BYTES_PER_PIXEL_CASE(NVCV_IMAGE_FORMAT_NV12, 1.5f); |
71 |
| - CVCUDA_BYTES_PER_PIXEL_CASE(NVCV_IMAGE_FORMAT_Y8, 1); |
72 |
| -#undef CVCUDA_BYTES_PER_PIXEL_CASE |
73 |
| - default: |
74 |
| - throw std::invalid_argument("Unrecognized format"); |
75 |
| - } |
76 |
| -} |
77 |
| - |
78 |
| -// Adapted from src/util/TensorDataUtils.hpp |
79 |
| -inline static nvcv::Tensor CreateTensor(int numImages, int imgWidth, int imgHeight, const nvcv::ImageFormat &imgFormat) |
80 |
| -{ |
81 |
| - if (imgFormat == NVCV_IMAGE_FORMAT_NV12 || imgFormat == NVCV_IMAGE_FORMAT_NV12_ER |
82 |
| - || imgFormat == NVCV_IMAGE_FORMAT_NV21 || imgFormat == NVCV_IMAGE_FORMAT_NV21_ER) |
83 |
| - { |
84 |
| - int height420 = (imgHeight * 3) / 2; |
85 |
| - if (height420 % 3 != 0 || imgWidth % 2 != 0) |
86 |
| - { |
87 |
| - throw std::invalid_argument("Invalid height"); |
88 |
| - } |
89 |
| - |
90 |
| - return nvcv::Tensor(numImages, {imgWidth, height420}, nvcv::ImageFormat(NVCV_IMAGE_FORMAT_Y8)); |
91 |
| - } |
92 |
| - else |
93 |
| - { |
94 |
| - return nvcv::Tensor(numImages, {imgWidth, imgHeight}, imgFormat); |
95 |
| - } |
96 |
| -} |
97 |
| - |
98 |
| -template<typename BT> |
99 |
| -inline void CvtColor(nvbench::state &state, nvbench::type_list<BT>) |
| 24 | +template<typename T> |
| 25 | +inline void CvtColor(nvbench::state &state, nvbench::type_list<T>) |
100 | 26 | try
|
101 | 27 | {
|
102 | 28 | long3 shape = benchutils::GetShape<3>(state.get_string("shape"));
|
103 | 29 | long varShape = state.get_int64("varShape");
|
104 |
| - std::tuple<NVCVColorConversionCode, NVCVImageFormat, NVCVImageFormat> formats |
105 |
| - = StringToFormats(state.get_string("code")); |
106 | 30 |
|
107 |
| - NVCVColorConversionCode code = std::get<0>(formats); |
108 |
| - nvcv::ImageFormat inFormat{std::get<1>(formats)}; |
109 |
| - nvcv::ImageFormat outFormat{std::get<2>(formats)}; |
| 31 | + using BT = typename nvcv::cuda::BaseType<T>; |
| 32 | + |
| 33 | + int ch = nvcv::cuda::NumElements<T>; |
110 | 34 |
|
111 |
| - state.add_global_memory_reads(shape.x * shape.y * shape.z * BytesPerPixel<BT>(inFormat)); |
112 |
| - state.add_global_memory_writes(shape.x * shape.y * shape.z * BytesPerPixel<BT>(outFormat)); |
| 35 | + NVCVColorConversionCode code = ch == 3 ? NVCV_COLOR_BGR2RGB : NVCV_COLOR_BGRA2RGBA; |
| 36 | + |
| 37 | + state.add_global_memory_reads(shape.x * shape.y * shape.z * sizeof(T)); |
| 38 | + state.add_global_memory_writes(shape.x * shape.y * shape.z * sizeof(T)); |
113 | 39 |
|
114 | 40 | cvcuda::CvtColor op;
|
115 | 41 |
|
| 42 | + // clang-format off |
| 43 | + |
116 | 44 | if (varShape < 0) // negative var shape means use Tensor
|
117 | 45 | {
|
118 |
| - nvcv::Tensor src = CreateTensor(shape.x, shape.z, shape.y, inFormat); |
119 |
| - nvcv::Tensor dst = CreateTensor(shape.x, shape.z, shape.y, outFormat); |
| 46 | + nvcv::Tensor src({{shape.x, shape.y, shape.z, ch}, "NHWC"}, benchutils::GetDataType<BT>()); |
| 47 | + nvcv::Tensor dst({{shape.x, shape.y, shape.z, ch}, "NHWC"}, benchutils::GetDataType<BT>()); |
120 | 48 |
|
121 | 49 | benchutils::FillTensor<BT>(src, benchutils::RandomValues<BT>());
|
122 | 50 |
|
123 |
| - state.exec(nvbench::exec_tag::sync, |
124 |
| - [&op, &src, &dst, &code](nvbench::launch &launch) { op(launch.get_stream(), src, dst, code); }); |
| 51 | + state.exec(nvbench::exec_tag::sync, [&op, &src, &dst, &code](nvbench::launch &launch) |
| 52 | + { |
| 53 | + op(launch.get_stream(), src, dst, code); |
| 54 | + }); |
125 | 55 | }
|
126 | 56 | else // zero and positive var shape means use ImageBatchVarShape
|
127 | 57 | {
|
128 |
| - if (inFormat.chromaSubsampling() != nvcv::ChromaSubsampling::CSS_444 |
129 |
| - || outFormat.chromaSubsampling() != nvcv::ChromaSubsampling::CSS_444) |
130 |
| - { |
131 |
| - state.skip("Skipping formats that have subsampled planes for the varshape benchmark"); |
132 |
| - } |
133 |
| - |
134 |
| - std::vector<nvcv::Image> imgSrc; |
135 |
| - std::vector<nvcv::Image> imgDst; |
136 |
| - nvcv::ImageBatchVarShape src(shape.x); |
137 |
| - nvcv::ImageBatchVarShape dst(shape.x); |
138 |
| - std::vector<std::vector<uint8_t>> srcVec(shape.x); |
| 58 | + nvcv::ImageBatchVarShape src(shape.x); |
| 59 | + nvcv::ImageBatchVarShape dst(shape.x); |
139 | 60 |
|
140 |
| - auto randomValuesU8 = benchutils::RandomValues<uint8_t>(); |
| 61 | + benchutils::FillImageBatch<T>(src, long2{shape.z, shape.y}, long2{varShape, varShape}, |
| 62 | + benchutils::RandomValues<T>()); |
| 63 | + dst.pushBack(src.begin(), src.end()); |
141 | 64 |
|
142 |
| - for (int i = 0; i < shape.x; i++) |
| 65 | + state.exec(nvbench::exec_tag::sync, [&op, &src, &dst, &code](nvbench::launch &launch) |
143 | 66 | {
|
144 |
| - imgSrc.emplace_back(nvcv::Size2D{(int)shape.z, (int)shape.y}, inFormat); |
145 |
| - imgDst.emplace_back(nvcv::Size2D{(int)shape.z, (int)shape.y}, outFormat); |
146 |
| - |
147 |
| - int srcRowStride = imgSrc[i].size().w * inFormat.planePixelStrideBytes(0); |
148 |
| - int srcBufSize = imgSrc[i].size().h * srcRowStride; |
149 |
| - srcVec[i].resize(srcBufSize); |
150 |
| - for (int idx = 0; idx < srcBufSize; idx++) |
151 |
| - { |
152 |
| - srcVec[i][idx] = randomValuesU8(); |
153 |
| - } |
154 |
| - |
155 |
| - auto imgData = imgSrc[i].exportData<nvcv::ImageDataStridedCuda>(); |
156 |
| - CUDA_CHECK_ERROR(cudaMemcpy2D(imgData->plane(0).basePtr, imgData->plane(0).rowStride, srcVec[i].data(), |
157 |
| - srcRowStride, srcRowStride, imgSrc[i].size().h, cudaMemcpyHostToDevice)); |
158 |
| - } |
159 |
| - src.pushBack(imgSrc.begin(), imgSrc.end()); |
160 |
| - dst.pushBack(imgDst.begin(), imgDst.end()); |
161 |
| - |
162 |
| - state.exec(nvbench::exec_tag::sync, |
163 |
| - [&op, &src, &dst, &code](nvbench::launch &launch) { op(launch.get_stream(), src, dst, code); }); |
| 67 | + op(launch.get_stream(), src, dst, code); |
| 68 | + }); |
164 | 69 | }
|
165 | 70 | }
|
166 | 71 | catch (const std::exception &err)
|
167 | 72 | {
|
168 | 73 | state.skip(err.what());
|
169 | 74 | }
|
170 | 75 |
|
171 |
| -using BaseTypes = nvbench::type_list<uint8_t>; |
| 76 | +// clang-format on |
| 77 | + |
| 78 | +using CvtColorTypes = nvbench::type_list<uchar3, uchar4>; |
172 | 79 |
|
173 |
| -NVBENCH_BENCH_TYPES(CvtColor, NVBENCH_TYPE_AXES(BaseTypes)) |
174 |
| - .set_type_axes_names({"BaseType"}) |
175 |
| - .add_string_axis("shape", {"1x1080x1920", "64x720x1280"}) |
176 |
| - .add_string_axis("code", {"RGB2BGR", "RGB2RGBA", "RGBA2RGB", "RGB2GRAY", "GRAY2RGB", "RGB2HSV", "HSV2RGB", |
177 |
| - "RGB2YUV", "YUV2RGB", "RGB2YUV_NV12", "YUV2RGB_NV12"}) |
| 80 | +NVBENCH_BENCH_TYPES(CvtColor, NVBENCH_TYPE_AXES(CvtColorTypes)) |
| 81 | + .set_type_axes_names({"InOutDataType"}) |
| 82 | + .add_string_axis("shape", {"1x1080x1920"}) |
178 | 83 | .add_int64_axis("varShape", {-1, 0});
|
0 commit comments