7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include < cuda.h>
10
+ #include < map>
11
+ #include < utility>
10
12
11
13
#include " common.hpp"
12
14
#include " context.hpp"
@@ -52,30 +54,33 @@ ur_result_t urCalculateNumChannels(ur_image_channel_order_t order,
52
54
// / Convert a UR image format to a CUDA image format and
53
55
// / get the pixel size in bytes.
54
56
// / /param image_channel_type is the ur_image_channel_type_t.
57
+ // / /param image_channel_order is the ur_image_channel_order_t.
58
+ // / this is used for normalized channel formats, as CUDA
59
+ // / combines the channel format and order for normalized
60
+ // / channel types.
55
61
// / /param return_cuda_format will be set to the equivalent cuda
56
- // / format if not nullptr.
57
- // / /param return_pixel_types_size_bytes will be set to the pixel
58
- // / byte size if not nullptr.
62
+ // / format if not nullptr.
63
+ // / /param return_pixel_size_bytes will be set to the pixel
64
+ // / byte size if not nullptr.
59
65
ur_result_t
60
66
urToCudaImageChannelFormat (ur_image_channel_type_t image_channel_type,
67
+ ur_image_channel_order_t image_channel_order,
61
68
CUarray_format *return_cuda_format,
62
- size_t *return_pixel_types_size_bytes ) {
69
+ size_t *return_pixel_size_bytes ) {
63
70
64
71
CUarray_format cuda_format;
65
- size_t PixelTypeSizeBytes;
72
+ size_t pixel_size_bytes = 0 ;
73
+ unsigned int num_channels = 0 ;
74
+ UR_CHECK_ERROR (urCalculateNumChannels (image_channel_order, &num_channels));
66
75
67
76
switch (image_channel_type) {
68
77
#define CASE (FROM, TO, SIZE ) \
69
78
case FROM: { \
70
79
cuda_format = TO; \
71
- PixelTypeSizeBytes = SIZE; \
80
+ pixel_size_bytes = SIZE * num_channels; \
72
81
break ; \
73
82
}
74
- // These new formats were brought in in CUDA 11.5
75
- #if CUDA_VERSION >= 11050
76
- CASE (UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, CU_AD_FORMAT_UNORM_INT8X1, 1 )
77
- CASE (UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, CU_AD_FORMAT_UNORM_INT16X1, 2 )
78
- #endif
83
+
79
84
CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8, CU_AD_FORMAT_UNSIGNED_INT8, 1 )
80
85
CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8, CU_AD_FORMAT_SIGNED_INT8, 1 )
81
86
CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16, CU_AD_FORMAT_UNSIGNED_INT16, 2 )
@@ -84,16 +89,73 @@ urToCudaImageChannelFormat(ur_image_channel_type_t image_channel_type,
84
89
CASE (UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32, CU_AD_FORMAT_UNSIGNED_INT32, 4 )
85
90
CASE (UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32, CU_AD_FORMAT_SIGNED_INT32, 4 )
86
91
CASE (UR_IMAGE_CHANNEL_TYPE_FLOAT, CU_AD_FORMAT_FLOAT, 4 )
92
+
87
93
#undef CASE
88
94
default :
89
- return UR_RESULT_ERROR_IMAGE_FORMAT_NOT_SUPPORTED;
95
+ break ;
96
+ }
97
+
98
+ // These new formats were brought in in CUDA 11.5
99
+ #if CUDA_VERSION >= 11050
100
+
101
+ // If none of the above channel types were passed, check those below
102
+ if (pixel_size_bytes == 0 ) {
103
+
104
+ // We can't use a switch statement here because these single
105
+ // UR_IMAGE_CHANNEL_TYPEs can correspond to multiple [u/s]norm CU_AD_FORMATs
106
+ // depending on the number of channels. We use a std::map instead to
107
+ // retrieve the correct CUDA format
108
+
109
+ // map < <channel type, num channels> , <CUDA format, data type byte size> >
110
+ const std::map<std::pair<ur_image_channel_type_t , uint32_t >,
111
+ std::pair<CUarray_format, uint32_t >>
112
+ norm_channel_type_map{
113
+ {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 1 },
114
+ {CU_AD_FORMAT_UNORM_INT8X1, 1 }},
115
+ {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 2 },
116
+ {CU_AD_FORMAT_UNORM_INT8X2, 2 }},
117
+ {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT8, 4 },
118
+ {CU_AD_FORMAT_UNORM_INT8X4, 4 }},
119
+
120
+ {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 1 },
121
+ {CU_AD_FORMAT_SNORM_INT8X1, 1 }},
122
+ {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 2 },
123
+ {CU_AD_FORMAT_SNORM_INT8X2, 2 }},
124
+ {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT8, 4 },
125
+ {CU_AD_FORMAT_SNORM_INT8X4, 4 }},
126
+
127
+ {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 1 },
128
+ {CU_AD_FORMAT_UNORM_INT16X1, 2 }},
129
+ {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 2 },
130
+ {CU_AD_FORMAT_UNORM_INT16X2, 4 }},
131
+ {{UR_IMAGE_CHANNEL_TYPE_UNORM_INT16, 4 },
132
+ {CU_AD_FORMAT_UNORM_INT16X4, 8 }},
133
+
134
+ {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 1 },
135
+ {CU_AD_FORMAT_SNORM_INT16X1, 2 }},
136
+ {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 2 },
137
+ {CU_AD_FORMAT_SNORM_INT16X2, 4 }},
138
+ {{UR_IMAGE_CHANNEL_TYPE_SNORM_INT16, 4 },
139
+ {CU_AD_FORMAT_SNORM_INT16X4, 8 }},
140
+ };
141
+
142
+ try {
143
+ auto cuda_format_and_size = norm_channel_type_map.at (
144
+ std::make_pair (image_channel_type, num_channels));
145
+ cuda_format = cuda_format_and_size.first ;
146
+ pixel_size_bytes = cuda_format_and_size.second ;
147
+ } catch (std::out_of_range &e) {
148
+ return UR_RESULT_ERROR_IMAGE_FORMAT_NOT_SUPPORTED;
149
+ }
90
150
}
91
151
152
+ #endif
153
+
92
154
if (return_cuda_format) {
93
155
*return_cuda_format = cuda_format;
94
156
}
95
- if (return_pixel_types_size_bytes ) {
96
- *return_pixel_types_size_bytes = PixelTypeSizeBytes ;
157
+ if (return_pixel_size_bytes ) {
158
+ *return_pixel_size_bytes = pixel_size_bytes ;
97
159
}
98
160
return UR_RESULT_SUCCESS;
99
161
}
@@ -125,10 +187,42 @@ cudaToUrImageChannelFormat(CUarray_format cuda_format,
125
187
CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_FLOAT,
126
188
UR_IMAGE_CHANNEL_TYPE_FLOAT);
127
189
#if CUDA_VERSION >= 11050
190
+
191
+ // Note that the CUDA UNORM and SNORM formats also encode the number of
192
+ // channels.
193
+ // Since UR does not encode this, we map different CUDA formats to the same
194
+ // UR channel type.
195
+ // Since this function is only called from `urBindlessImagesImageGetInfoExp`
196
+ // which has access to `CUDA_ARRAY3D_DESCRIPTOR`, we can determine the
197
+ // number of channels in the calling function.
198
+
128
199
CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT8X1,
129
200
UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
201
+ CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT8X2,
202
+ UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
203
+ CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT8X4,
204
+ UR_IMAGE_CHANNEL_TYPE_UNORM_INT8);
205
+
130
206
CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT16X1,
131
207
UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
208
+ CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT16X2,
209
+ UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
210
+ CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_UNORM_INT16X4,
211
+ UR_IMAGE_CHANNEL_TYPE_UNORM_INT16);
212
+
213
+ CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT8X1,
214
+ UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
215
+ CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT8X2,
216
+ UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
217
+ CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT8X4,
218
+ UR_IMAGE_CHANNEL_TYPE_SNORM_INT8);
219
+
220
+ CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT16X1,
221
+ UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
222
+ CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT16X2,
223
+ UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
224
+ CUDA_TO_UR_IMAGE_CHANNEL_TYPE (CU_AD_FORMAT_SNORM_INT16X4,
225
+ UR_IMAGE_CHANNEL_TYPE_SNORM_INT16);
132
226
#endif
133
227
#undef MAP
134
228
default :
@@ -283,6 +377,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageAllocateExp(
283
377
&array_desc.NumChannels ));
284
378
285
379
UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType ,
380
+ pImageFormat->channelOrder ,
286
381
&array_desc.Format , nullptr ));
287
382
288
383
array_desc.Flags = 0 ; // No flags required
@@ -365,9 +460,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp(
365
460
urCalculateNumChannels (pImageFormat->channelOrder , &NumChannels));
366
461
367
462
CUarray_format format;
368
- size_t PixelTypeSizeBytes;
369
- UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType , &format,
370
- &PixelTypeSizeBytes));
463
+ size_t PixelSizeBytes;
464
+ UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType ,
465
+ pImageFormat->channelOrder , &format,
466
+ &PixelSizeBytes));
371
467
372
468
try {
373
469
@@ -418,9 +514,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
418
514
urCalculateNumChannels (pImageFormat->channelOrder , &NumChannels));
419
515
420
516
CUarray_format format;
421
- size_t PixelTypeSizeBytes;
422
- UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType , &format,
423
- &PixelTypeSizeBytes));
517
+ size_t PixelSizeBytes;
518
+ UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType ,
519
+ pImageFormat->channelOrder , &format,
520
+ &PixelSizeBytes));
424
521
425
522
try {
426
523
CUDA_RESOURCE_DESC image_res_desc = {};
@@ -451,7 +548,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp(
451
548
image_res_desc.res .linear .format = format;
452
549
image_res_desc.res .linear .numChannels = NumChannels;
453
550
image_res_desc.res .linear .sizeInBytes =
454
- pImageDesc->width * PixelTypeSizeBytes * NumChannels ;
551
+ pImageDesc->width * PixelSizeBytes ;
455
552
} else if (pImageDesc->type == UR_MEM_TYPE_IMAGE2D) {
456
553
image_res_desc.resType = CU_RESOURCE_TYPE_PITCH2D;
457
554
image_res_desc.res .pitch2D .devPtr = (CUdeviceptr)hImageMem;
@@ -503,17 +600,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesImageCopyExp(
503
600
UR_RESULT_ERROR_INVALID_VALUE);
504
601
505
602
unsigned int NumChannels = 0 ;
506
- size_t PixelTypeSizeBytes = 0 ;
603
+ size_t PixelSizeBytes = 0 ;
507
604
508
605
UR_CHECK_ERROR (
509
606
urCalculateNumChannels (pImageFormat->channelOrder , &NumChannels));
510
607
511
608
// We need to get this now in bytes for calculating the total image size
512
609
// later.
513
- UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType , nullptr ,
514
- &PixelTypeSizeBytes));
515
-
516
- size_t PixelSizeBytes = PixelTypeSizeBytes * NumChannels;
610
+ UR_CHECK_ERROR (urToCudaImageChannelFormat (pImageFormat->channelType ,
611
+ pImageFormat->channelOrder , nullptr ,
612
+ &PixelSizeBytes));
517
613
518
614
try {
519
615
ScopedContext Active (hQueue->getContext ());
@@ -789,8 +885,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesMapExternalArrayExp(
789
885
urCalculateNumChannels (pImageFormat->channelOrder , &NumChannels));
790
886
791
887
CUarray_format format;
792
- UR_CHECK_ERROR (
793
- urToCudaImageChannelFormat ( pImageFormat->channelType , &format, nullptr ));
888
+ UR_CHECK_ERROR (urToCudaImageChannelFormat (
889
+ pImageFormat->channelType , pImageFormat-> channelOrder , &format, nullptr ));
794
890
795
891
try {
796
892
ScopedContext Active (hDevice->getContext ());
0 commit comments