@@ -51,6 +51,48 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
51
51
return UR_RESULT_SUCCESS;
52
52
}
53
53
54
+ ur_result_t urEnqueueUSMFill2DFallback (ur_queue_handle_t hQueue, void *pMem,
55
+ size_t pitch, size_t patternSize,
56
+ const void *pPattern, size_t width,
57
+ size_t height,
58
+ uint32_t numEventsInWaitList,
59
+ const ur_event_handle_t *phEventWaitList,
60
+ ur_event_handle_t *phEvent) {
61
+ ur_result_t Result = getContext ()->urDdiTable .Enqueue .pfnUSMFill2D (
62
+ hQueue, pMem, pitch, patternSize, pPattern, width, height,
63
+ numEventsInWaitList, phEventWaitList, phEvent);
64
+ if (Result == UR_RESULT_SUCCESS ||
65
+ Result != UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
66
+ return Result;
67
+ }
68
+
69
+ // fallback code
70
+ auto pfnUSMFill = getContext ()->urDdiTable .Enqueue .pfnUSMFill ;
71
+
72
+ std::vector<ur_event_handle_t > WaitEvents (numEventsInWaitList);
73
+
74
+ for (size_t HeightIndex = 0 ; HeightIndex < height; HeightIndex++) {
75
+ ur_event_handle_t Event = nullptr ;
76
+
77
+ UR_CALL (pfnUSMFill (hQueue, (void *)((char *)pMem + pitch * HeightIndex),
78
+ patternSize, pPattern, width, WaitEvents.size (),
79
+ WaitEvents.data (), &Event));
80
+
81
+ WaitEvents.push_back (Event);
82
+ }
83
+
84
+ if (phEvent) {
85
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnEventsWait (
86
+ hQueue, WaitEvents.size (), WaitEvents.data (), phEvent));
87
+ }
88
+
89
+ for (const auto Event : WaitEvents) {
90
+ UR_CALL (getContext ()->urDdiTable .Event .pfnRelease (Event));
91
+ }
92
+
93
+ return UR_RESULT_SUCCESS;
94
+ }
95
+
54
96
} // namespace
55
97
56
98
// /////////////////////////////////////////////////////////////////////////////
@@ -1726,11 +1768,6 @@ ur_result_t urEnqueueUSMMemcpy2D(
1726
1768
{
1727
1769
auto pfnUSMMemcpy = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy ;
1728
1770
1729
- std::vector<ur_event_handle_t > WaitEvents (numEventsInWaitList);
1730
- for (uint32_t i = 0 ; i < numEventsInWaitList; i++) {
1731
- WaitEvents[i] = phEventWaitList[i];
1732
- }
1733
-
1734
1771
for (size_t HeightIndex = 0 ; HeightIndex < height; HeightIndex++) {
1735
1772
ur_event_handle_t Event = nullptr ;
1736
1773
const auto DstOrigin =
@@ -1742,8 +1779,8 @@ ur_result_t urEnqueueUSMMemcpy2D(
1742
1779
width - 1 ) +
1743
1780
MSAN_ORIGIN_GRANULARITY;
1744
1781
pfnUSMMemcpy (hQueue, false , (void *)DstOrigin, (void *)SrcOrigin,
1745
- SrcOriginEnd - SrcOrigin, WaitEvents. size () ,
1746
- WaitEvents. data () , &Event);
1782
+ SrcOriginEnd - SrcOrigin, numEventsInWaitList ,
1783
+ phEventWaitList , &Event);
1747
1784
Events.push_back (Event);
1748
1785
}
1749
1786
}
@@ -1756,9 +1793,9 @@ ur_result_t urEnqueueUSMMemcpy2D(
1756
1793
const auto DstShadow = DstDI->Shadow ->MemToShadow ((uptr)pDst);
1757
1794
const char Pattern = 0 ;
1758
1795
ur_event_handle_t Event = nullptr ;
1759
- UR_CALL (getContext ()-> urDdiTable . Enqueue . pfnUSMFill2D (
1760
- hQueue, ( void *)DstShadow, dstPitch, 1 , &Pattern, width, height, 0 ,
1761
- nullptr , &Event));
1796
+ UR_CALL (urEnqueueUSMFill2DFallback (hQueue, ( void *)DstShadow, dstPitch, 1 ,
1797
+ &Pattern, width, height, 0 , nullptr ,
1798
+ &Event));
1762
1799
Events.push_back (Event);
1763
1800
}
1764
1801
@@ -1767,7 +1804,7 @@ ur_result_t urEnqueueUSMMemcpy2D(
1767
1804
hQueue, Events.size (), Events.data (), phEvent));
1768
1805
}
1769
1806
1770
- for (const auto & E : Events)
1807
+ for (const auto E : Events)
1771
1808
UR_CALL (getContext ()->urDdiTable .Event .pfnRelease (E));
1772
1809
1773
1810
return UR_RESULT_SUCCESS;
0 commit comments