Skip to content

Commit 1f13d2c

Browse files
authored
Merge pull request #2056 from Seanst98/sean/usm-normalized-fix
[CUDA][Bindless] Address USM normalized type image creation failure and functionality
2 parents 5276c53 + bcf2244 commit 1f13d2c

File tree

2 files changed

+50
-124
lines changed

2 files changed

+50
-124
lines changed

source/adapters/cuda/image.cpp

Lines changed: 45 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -65,101 +65,59 @@ ur_result_t urCalculateNumChannels(ur_image_channel_order_t order,
6565
/// format if not nullptr.
6666
/// /param return_pixel_size_bytes will be set to the pixel
6767
/// byte size if not nullptr.
68+
/// /param return_normalized_dtype_flag will be set if the
69+
/// data type is normalized if not nullptr.
6870
ur_result_t
6971
urToCudaImageChannelFormat(ur_image_channel_type_t image_channel_type,
7072
ur_image_channel_order_t image_channel_order,
7173
CUarray_format *return_cuda_format,
72-
size_t *return_pixel_size_bytes) {
74+
size_t *return_pixel_size_bytes,
75+
unsigned int *return_normalized_dtype_flag) {
7376

74-
CUarray_format cuda_format;
77+
CUarray_format cuda_format = CU_AD_FORMAT_UNSIGNED_INT8;
7578
size_t pixel_size_bytes = 0;
7679
unsigned int num_channels = 0;
80+
unsigned int normalized_dtype_flag = 0;
7781
UR_CHECK_ERROR(urCalculateNumChannels(image_channel_order, &num_channels));
7882

7983
switch (image_channel_type) {
80-
#define CASE(FROM, TO, SIZE) \
84+
#define CASE(FROM, TO, SIZE, NORM) \
8185
case FROM: { \
8286
cuda_format = TO; \
8387
pixel_size_bytes = SIZE * num_channels; \
88+
normalized_dtype_flag = NORM; \
8489
break; \
8590
}
8691

87-
CASE(UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8, CU_AD_FORMAT_UNSIGNED_INT8, 1)
88-
CASE(UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8, CU_AD_FORMAT_SIGNED_INT8, 1)
89-
CASE(UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16, CU_AD_FORMAT_UNSIGNED_INT16, 2)
90-
CASE(UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16, CU_AD_FORMAT_SIGNED_INT16, 2)
91-
CASE(UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT, CU_AD_FORMAT_HALF, 2)
92-
CASE(UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32, CU_AD_FORMAT_UNSIGNED_INT32, 4)
93-
CASE(UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32, CU_AD_FORMAT_SIGNED_INT32, 4)
94-
CASE(UR_IMAGE_CHANNEL_TYPE_FLOAT, CU_AD_FORMAT_FLOAT, 4)
92+
CASE(UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8, CU_AD_FORMAT_UNSIGNED_INT8, 1, 0)
93+
CASE(UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8, CU_AD_FORMAT_SIGNED_INT8, 1, 0)
94+
CASE(UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16, CU_AD_FORMAT_UNSIGNED_INT16, 2,
95+
0)
96+
CASE(UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16, CU_AD_FORMAT_SIGNED_INT16, 2, 0)
97+
CASE(UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT, CU_AD_FORMAT_HALF, 2, 0)
98+
CASE(UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32, CU_AD_FORMAT_UNSIGNED_INT32, 4,
99+
0)
100+
CASE(UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32, CU_AD_FORMAT_SIGNED_INT32, 4, 0)
101+
CASE(UR_IMAGE_CHANNEL_TYPE_FLOAT, CU_AD_FORMAT_FLOAT, 4, 0)
102+
CASE(UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, CU_AD_FORMAT_UNSIGNED_INT8, 1, 1)
103+
CASE(UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, CU_AD_FORMAT_SIGNED_INT8, 1, 1)
104+
CASE(UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, CU_AD_FORMAT_UNSIGNED_INT16, 2, 1)
105+
CASE(UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, CU_AD_FORMAT_SIGNED_INT16, 2, 1)
95106

96107
#undef CASE
97108
default:
98109
break;
99110
}
100111

101-
// These new formats were brought in in CUDA 11.5
102-
#if CUDA_VERSION >= 11050
103-
104-
// If none of the above channel types were passed, check those below
105-
if (pixel_size_bytes == 0) {
106-
107-
// We can't use a switch statement here because these single
108-
// UR_IMAGE_CHANNEL_TYPEs can correspond to multiple [u/s]norm CU_AD_FORMATs
109-
// depending on the number of channels. We use a std::map instead to
110-
// retrieve the correct CUDA format
111-
112-
// map < <channel type, num channels> , <CUDA format, data type byte size> >
113-
const std::map<std::pair<ur_image_channel_type_t, uint32_t>,
114-
std::pair<CUarray_format, uint32_t>>
115-
norm_channel_type_map{
116-
{{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 1},
117-
{CU_AD_FORMAT_UNORM_INT8X1, 1}},
118-
{{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 2},
119-
{CU_AD_FORMAT_UNORM_INT8X2, 2}},
120-
{{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 4},
121-
{CU_AD_FORMAT_UNORM_INT8X4, 4}},
122-
123-
{{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 1},
124-
{CU_AD_FORMAT_SNORM_INT8X1, 1}},
125-
{{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 2},
126-
{CU_AD_FORMAT_SNORM_INT8X2, 2}},
127-
{{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 4},
128-
{CU_AD_FORMAT_SNORM_INT8X4, 4}},
129-
130-
{{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 1},
131-
{CU_AD_FORMAT_UNORM_INT16X1, 2}},
132-
{{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 2},
133-
{CU_AD_FORMAT_UNORM_INT16X2, 4}},
134-
{{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 4},
135-
{CU_AD_FORMAT_UNORM_INT16X4, 8}},
136-
137-
{{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 1},
138-
{CU_AD_FORMAT_SNORM_INT16X1, 2}},
139-
{{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 2},
140-
{CU_AD_FORMAT_SNORM_INT16X2, 4}},
141-
{{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 4},
142-
{CU_AD_FORMAT_SNORM_INT16X4, 8}},
143-
};
144-
145-
try {
146-
auto cuda_format_and_size = norm_channel_type_map.at(
147-
std::make_pair(image_channel_type, num_channels));
148-
cuda_format = cuda_format_and_size.first;
149-
pixel_size_bytes = cuda_format_and_size.second;
150-
} catch (const std::out_of_range &) {
151-
return UR_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT;
152-
}
153-
}
154-
155-
#endif
156-
157112
if (return_cuda_format) {
158113
*return_cuda_format = cuda_format;
159114
}
160115
if (return_pixel_size_bytes) {
161116
*return_pixel_size_bytes = pixel_size_bytes;
162117
}
118+
if (return_normalized_dtype_flag) {
119+
*return_normalized_dtype_flag = normalized_dtype_flag;
120+
}
163121
return UR_RESULT_SUCCESS;
164122
}
165123

@@ -189,53 +147,17 @@ cudaToUrImageChannelFormat(CUarray_format cuda_format,
189147
UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT);
190148
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_FLOAT,
191149
UR_IMAGE_CHANNEL_TYPE_FLOAT);
192-
#if CUDA_VERSION >= 11050
193-
194-
// Note that the CUDA UNORM and SNORM formats also encode the number of
195-
// channels.
196-
// Since UR does not encode this, we map different CUDA formats to the same
197-
// UR channel type.
198-
// Since this function is only called from `urBindlessImagesImageGetInfoExp`
199-
// which has access to `CUDA_ARRAY3D_DESCRIPTOR`, we can determine the
200-
// number of channels in the calling function.
201-
202-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_UNORM_INT8X1,
203-
UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
204-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_UNORM_INT8X2,
205-
UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
206-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_UNORM_INT8X4,
207-
UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
208-
209-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_UNORM_INT16X1,
210-
UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
211-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_UNORM_INT16X2,
212-
UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
213-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_UNORM_INT16X4,
214-
UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
215-
216-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_SNORM_INT8X1,
217-
UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
218-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_SNORM_INT8X2,
219-
UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
220-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_SNORM_INT8X4,
221-
UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
222-
223-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_SNORM_INT16X1,
224-
UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
225-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_SNORM_INT16X2,
226-
UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
227-
CUDA_TO_UR_IMAGE_CHANNEL_TYPE(CU_AD_FORMAT_SNORM_INT16X4,
228-
UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
229-
#endif
230-
#undef MAP
231150
default:
151+
// Default invalid enum
152+
*return_image_channel_type = UR_IMAGE_CHANNEL_TYPE_FORCE_UINT32;
232153
return UR_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT;
233154
}
234155
}
235156

236157
ur_result_t urTextureCreate(ur_sampler_handle_t hSampler,
237158
const ur_image_desc_t *pImageDesc,
238159
const CUDA_RESOURCE_DESC &ResourceDesc,
160+
const unsigned int normalized_dtype_flag,
239161
ur_exp_image_native_handle_t *phRetImage) {
240162

241163
try {
@@ -306,8 +228,9 @@ ur_result_t urTextureCreate(ur_sampler_handle_t hSampler,
306228

307229
// CUDA default promotes 8-bit and 16-bit integers to float between [0,1]
308230
// This flag prevents this behaviour.
309-
ImageTexDesc.flags |= CU_TRSF_READ_AS_INTEGER;
310-
231+
if (!normalized_dtype_flag) {
232+
ImageTexDesc.flags |= CU_TRSF_READ_AS_INTEGER;
233+
}
311234
// Cubemap attributes
312235
ur_exp_sampler_cubemap_filter_mode_t CubemapFilterModeProp =
313236
hSampler->getCubemapFilterMode();
@@ -413,9 +336,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageAllocateExp(
413336
UR_CHECK_ERROR(urCalculateNumChannels(pImageFormat->channelOrder,
414337
&array_desc.NumChannels));
415338

416-
UR_CHECK_ERROR(urToCudaImageChannelFormat(pImageFormat->channelType,
417-
pImageFormat->channelOrder,
418-
&array_desc.Format, nullptr));
339+
UR_CHECK_ERROR(urToCudaImageChannelFormat(
340+
pImageFormat->channelType, pImageFormat->channelOrder, &array_desc.Format,
341+
nullptr, nullptr));
419342

420343
array_desc.Flags = 0; // No flags required
421344
array_desc.Width = pImageDesc->width;
@@ -534,7 +457,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp(
534457
size_t PixelSizeBytes;
535458
UR_CHECK_ERROR(urToCudaImageChannelFormat(pImageFormat->channelType,
536459
pImageFormat->channelOrder, &format,
537-
&PixelSizeBytes));
460+
&PixelSizeBytes, nullptr));
538461

539462
try {
540463

@@ -579,9 +502,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
579502

580503
CUarray_format format;
581504
size_t PixelSizeBytes;
582-
UR_CHECK_ERROR(urToCudaImageChannelFormat(pImageFormat->channelType,
583-
pImageFormat->channelOrder, &format,
584-
&PixelSizeBytes));
505+
unsigned int normalized_dtype_flag;
506+
UR_CHECK_ERROR(urToCudaImageChannelFormat(
507+
pImageFormat->channelType, pImageFormat->channelOrder, &format,
508+
&PixelSizeBytes, &normalized_dtype_flag));
585509

586510
try {
587511
CUDA_RESOURCE_DESC image_res_desc = {};
@@ -630,8 +554,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
630554
return UR_RESULT_ERROR_INVALID_VALUE;
631555
}
632556

633-
UR_CHECK_ERROR(
634-
urTextureCreate(hSampler, pImageDesc, image_res_desc, phImage));
557+
UR_CHECK_ERROR(urTextureCreate(hSampler, pImageDesc, image_res_desc,
558+
normalized_dtype_flag, phImage));
635559

636560
} catch (ur_result_t Err) {
637561
return Err;
@@ -671,7 +595,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageCopyExp(
671595
// later.
672596
UR_CHECK_ERROR(urToCudaImageChannelFormat(pSrcImageFormat->channelType,
673597
pSrcImageFormat->channelOrder,
674-
nullptr, &PixelSizeBytes));
598+
nullptr, &PixelSizeBytes, nullptr));
675599

676600
try {
677601
ScopedContext Active(hQueue->getDevice());
@@ -1150,8 +1074,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesMapExternalArrayExp(
11501074
urCalculateNumChannels(pImageFormat->channelOrder, &NumChannels));
11511075

11521076
CUarray_format format;
1153-
UR_CHECK_ERROR(urToCudaImageChannelFormat(
1154-
pImageFormat->channelType, pImageFormat->channelOrder, &format, nullptr));
1077+
UR_CHECK_ERROR(urToCudaImageChannelFormat(pImageFormat->channelType,
1078+
pImageFormat->channelOrder, &format,
1079+
nullptr, nullptr));
11551080

11561081
try {
11571082
ScopedContext Active(hDevice);

source/adapters/cuda/image.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@ ur_result_t
2121
urToCudaImageChannelFormat(ur_image_channel_type_t image_channel_type,
2222
ur_image_channel_order_t image_channel_order,
2323
CUarray_format *return_cuda_format,
24-
size_t *return_pixel_types_size_bytes);
24+
size_t *return_pixel_types_size_bytes,
25+
unsigned int *return_normalized_dtype_flag);
2526

2627
ur_result_t
2728
cudaToUrImageChannelFormat(CUarray_format cuda_format,
2829
ur_image_channel_type_t *return_image_channel_type);
2930

30-
ur_result_t urTextureCreate(ur_context_handle_t hContext,
31-
ur_sampler_desc_t SamplerDesc,
31+
ur_result_t urTextureCreate(ur_sampler_handle_t hSampler,
3232
const ur_image_desc_t *pImageDesc,
33-
CUDA_RESOURCE_DESC ResourceDesc,
33+
const CUDA_RESOURCE_DESC &ResourceDesc,
34+
const unsigned int normalized_dtype_flag,
3435
ur_exp_image_native_handle_t *phRetImage);

0 commit comments

Comments
 (0)