@@ -65,101 +65,59 @@ ur_result_t urCalculateNumChannels(ur_image_channel_order_t order,
65
65
// / format if not nullptr.
66
66
// / /param return_pixel_size_bytes will be set to the pixel
67
67
// / byte size if not nullptr.
68
+ // / /param return_normalized_dtype_flag will be set if the
69
+ // / data type is normalized if not nullptr.
68
70
ur_result_t
69
71
urToCudaImageChannelFormat (ur_image_channel_type_t image_channel_type,
70
72
ur_image_channel_order_t image_channel_order,
71
73
CUarray_format *return_cuda_format,
72
- size_t *return_pixel_size_bytes) {
74
+ size_t *return_pixel_size_bytes,
75
+ unsigned int *return_normalized_dtype_flag) {
73
76
74
- CUarray_format cuda_format;
77
+ CUarray_format cuda_format = CU_AD_FORMAT_UNSIGNED_INT8 ;
75
78
size_t pixel_size_bytes = 0 ;
76
79
unsigned int num_channels = 0 ;
80
+ unsigned int normalized_dtype_flag = 0 ;
77
81
UR_CHECK_ERROR (urCalculateNumChannels (image_channel_order, &num_channels));
78
82
79
83
switch (image_channel_type) {
80
- #define CASE (FROM, TO, SIZE ) \
84
+ #define CASE (FROM, TO, SIZE, NORM ) \
81
85
case FROM: { \
82
86
cuda_format = TO; \
83
87
pixel_size_bytes = SIZE * num_channels; \
88
+ normalized_dtype_flag = NORM; \
84
89
break ; \
85
90
}
86
91
87
- CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8, CU_AD_FORMAT_UNSIGNED_INT8, 1 )
88
- CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8, CU_AD_FORMAT_SIGNED_INT8, 1 )
89
- CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16, CU_AD_FORMAT_UNSIGNED_INT16, 2 )
90
- CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16, CU_AD_FORMAT_SIGNED_INT16, 2 )
91
- CASE (UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT, CU_AD_FORMAT_HALF, 2 )
92
- CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32, CU_AD_FORMAT_UNSIGNED_INT32, 4 )
93
- CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32, CU_AD_FORMAT_SIGNED_INT32, 4 )
94
- CASE (UR_IMAGE_CHANNEL_TYPE_FLOAT, CU_AD_FORMAT_FLOAT, 4 )
92
+ CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8, CU_AD_FORMAT_UNSIGNED_INT8, 1 , 0 )
93
+ CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8, CU_AD_FORMAT_SIGNED_INT8, 1 , 0 )
94
+ CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16, CU_AD_FORMAT_UNSIGNED_INT16, 2 ,
95
+ 0 )
96
+ CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16, CU_AD_FORMAT_SIGNED_INT16, 2 , 0 )
97
+ CASE (UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT, CU_AD_FORMAT_HALF, 2 , 0 )
98
+ CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32, CU_AD_FORMAT_UNSIGNED_INT32, 4 ,
99
+ 0 )
100
+ CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32, CU_AD_FORMAT_SIGNED_INT32, 4 , 0 )
101
+ CASE (UR_IMAGE_CHANNEL_TYPE_FLOAT, CU_AD_FORMAT_FLOAT, 4 , 0 )
102
+ CASE (UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, CU_AD_FORMAT_UNSIGNED_INT8, 1 , 1 )
103
+ CASE (UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, CU_AD_FORMAT_SIGNED_INT8, 1 , 1 )
104
+ CASE (UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, CU_AD_FORMAT_UNSIGNED_INT16, 2 , 1 )
105
+ CASE (UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, CU_AD_FORMAT_SIGNED_INT16, 2 , 1 )
95
106
96
107
#undef CASE
97
108
default :
98
109
break ;
99
110
}
100
111
101
- // These new formats were brought in in CUDA 11.5
102
- #if CUDA_VERSION >= 11050
103
-
104
- // If none of the above channel types were passed, check those below
105
- if (pixel_size_bytes == 0 ) {
106
-
107
- // We can't use a switch statement here because these single
108
- // UR_IMAGE_CHANNEL_TYPEs can correspond to multiple [u/s]norm CU_AD_FORMATs
109
- // depending on the number of channels. We use a std::map instead to
110
- // retrieve the correct CUDA format
111
-
112
- // map < <channel type, num channels> , <CUDA format, data type byte size> >
113
- const std::map<std::pair<ur_image_channel_type_t , uint32_t >,
114
- std::pair<CUarray_format, uint32_t >>
115
- norm_channel_type_map{
116
- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 1 },
117
- {CU_AD_FORMAT_UNORM_INT8X1, 1 }},
118
- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 2 },
119
- {CU_AD_FORMAT_UNORM_INT8X2, 2 }},
120
- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 4 },
121
- {CU_AD_FORMAT_UNORM_INT8X4, 4 }},
122
-
123
- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 1 },
124
- {CU_AD_FORMAT_SNORM_INT8X1, 1 }},
125
- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 2 },
126
- {CU_AD_FORMAT_SNORM_INT8X2, 2 }},
127
- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 4 },
128
- {CU_AD_FORMAT_SNORM_INT8X4, 4 }},
129
-
130
- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 1 },
131
- {CU_AD_FORMAT_UNORM_INT16X1, 2 }},
132
- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 2 },
133
- {CU_AD_FORMAT_UNORM_INT16X2, 4 }},
134
- {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 4 },
135
- {CU_AD_FORMAT_UNORM_INT16X4, 8 }},
136
-
137
- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 1 },
138
- {CU_AD_FORMAT_SNORM_INT16X1, 2 }},
139
- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 2 },
140
- {CU_AD_FORMAT_SNORM_INT16X2, 4 }},
141
- {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 4 },
142
- {CU_AD_FORMAT_SNORM_INT16X4, 8 }},
143
- };
144
-
145
- try {
146
- auto cuda_format_and_size = norm_channel_type_map.at (
147
- std::make_pair (image_channel_type, num_channels));
148
- cuda_format = cuda_format_and_size.first ;
149
- pixel_size_bytes = cuda_format_and_size.second ;
150
- } catch (const std::out_of_range &) {
151
- return UR_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT;
152
- }
153
- }
154
-
155
- #endif
156
-
157
112
if (return_cuda_format) {
158
113
*return_cuda_format = cuda_format;
159
114
}
160
115
if (return_pixel_size_bytes) {
161
116
*return_pixel_size_bytes = pixel_size_bytes;
162
117
}
118
+ if (return_normalized_dtype_flag) {
119
+ *return_normalized_dtype_flag = normalized_dtype_flag;
120
+ }
163
121
return UR_RESULT_SUCCESS;
164
122
}
165
123
@@ -189,53 +147,17 @@ cudaToUrImageChannelFormat(CUarray_format cuda_format,
189
147
UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT);
190
148
CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_FLOAT,
191
149
UR_IMAGE_CHANNEL_TYPE_FLOAT);
192
- #if CUDA_VERSION >= 11050
193
-
194
- // Note that the CUDA UNORM and SNORM formats also encode the number of
195
- // channels.
196
- // Since UR does not encode this, we map different CUDA formats to the same
197
- // UR channel type.
198
- // Since this function is only called from `urBindlessImagesImageGetInfoExp`
199
- // which has access to `CUDA_ARRAY3D_DESCRIPTOR`, we can determine the
200
- // number of channels in the calling function.
201
-
202
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT8X1,
203
- UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
204
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT8X2,
205
- UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
206
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT8X4,
207
- UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
208
-
209
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT16X1,
210
- UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
211
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT16X2,
212
- UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
213
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT16X4,
214
- UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
215
-
216
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT8X1,
217
- UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
218
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT8X2,
219
- UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
220
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT8X4,
221
- UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
222
-
223
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT16X1,
224
- UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
225
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT16X2,
226
- UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
227
- CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT16X4,
228
- UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
229
- #endif
230
- #undef MAP
231
150
default :
151
+ // Default invalid enum
152
+ *return_image_channel_type = UR_IMAGE_CHANNEL_TYPE_FORCE_UINT32;
232
153
return UR_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT;
233
154
}
234
155
}
235
156
236
157
ur_result_t urTextureCreate (ur_sampler_handle_t hSampler,
237
158
const ur_image_desc_t *pImageDesc,
238
159
const CUDA_RESOURCE_DESC &ResourceDesc,
160
+ const unsigned int normalized_dtype_flag,
239
161
ur_exp_image_native_handle_t *phRetImage) {
240
162
241
163
try {
@@ -306,8 +228,9 @@ ur_result_t urTextureCreate(ur_sampler_handle_t hSampler,
306
228
307
229
// CUDA default promotes 8-bit and 16-bit integers to float between [0,1]
308
230
// This flag prevents this behaviour.
309
- ImageTexDesc.flags |= CU_TRSF_READ_AS_INTEGER;
310
-
231
+ if (!normalized_dtype_flag) {
232
+ ImageTexDesc.flags |= CU_TRSF_READ_AS_INTEGER;
233
+ }
311
234
// Cubemap attributes
312
235
ur_exp_sampler_cubemap_filter_mode_t CubemapFilterModeProp =
313
236
hSampler->getCubemapFilterMode ();
@@ -413,9 +336,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageAllocateExp(
413
336
UR_CHECK_ERROR (urCalculateNumChannels (pImageFormat->channelOrder ,
414
337
&array_desc.NumChannels ));
415
338
416
- UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat-> channelType ,
417
- pImageFormat->channelOrder ,
418
- &array_desc. Format , nullptr ));
339
+ UR_CHECK_ERROR (urToCudaImageChannelFormat (
340
+ pImageFormat-> channelType , pImageFormat->channelOrder , &array_desc. Format ,
341
+ nullptr , nullptr ));
419
342
420
343
array_desc.Flags = 0 ; // No flags required
421
344
array_desc.Width = pImageDesc->width ;
@@ -534,7 +457,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp(
534
457
size_t PixelSizeBytes;
535
458
UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType ,
536
459
pImageFormat->channelOrder , &format,
537
- &PixelSizeBytes));
460
+ &PixelSizeBytes, nullptr ));
538
461
539
462
try {
540
463
@@ -579,9 +502,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
579
502
580
503
CUarray_format format;
581
504
size_t PixelSizeBytes;
582
- UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType ,
583
- pImageFormat->channelOrder , &format,
584
- &PixelSizeBytes));
505
+ unsigned int normalized_dtype_flag;
506
+ UR_CHECK_ERROR (urToCudaImageChannelFormat (
507
+ pImageFormat->channelType , pImageFormat->channelOrder , &format,
508
+ &PixelSizeBytes, &normalized_dtype_flag));
585
509
586
510
try {
587
511
CUDA_RESOURCE_DESC image_res_desc = {};
@@ -630,8 +554,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
630
554
return UR_RESULT_ERROR_INVALID_VALUE;
631
555
}
632
556
633
- UR_CHECK_ERROR (
634
- urTextureCreate (hSampler, pImageDesc, image_res_desc , phImage));
557
+ UR_CHECK_ERROR (urTextureCreate (hSampler, pImageDesc, image_res_desc,
558
+ normalized_dtype_flag , phImage));
635
559
636
560
} catch (ur_result_t Err) {
637
561
return Err;
@@ -671,7 +595,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageCopyExp(
671
595
// later.
672
596
UR_CHECK_ERROR (urToCudaImageChannelFormat (pSrcImageFormat->channelType ,
673
597
pSrcImageFormat->channelOrder ,
674
- nullptr , &PixelSizeBytes));
598
+ nullptr , &PixelSizeBytes, nullptr ));
675
599
676
600
try {
677
601
ScopedContext Active (hQueue->getDevice ());
@@ -1150,8 +1074,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesMapExternalArrayExp(
1150
1074
urCalculateNumChannels (pImageFormat->channelOrder , &NumChannels));
1151
1075
1152
1076
CUarray_format format;
1153
- UR_CHECK_ERROR (urToCudaImageChannelFormat (
1154
- pImageFormat->channelType , pImageFormat->channelOrder , &format, nullptr ));
1077
+ UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType ,
1078
+ pImageFormat->channelOrder , &format,
1079
+ nullptr , nullptr ));
1155
1080
1156
1081
try {
1157
1082
ScopedContext Active (hDevice);
0 commit comments