Skip to content

Commit bf62e0b

Browse files
authored
[DeviceMSAN] Fix urEnqueueUSMMemcpy2D returning UR_RESULT_ERROR_UNSUPPORTED_FEATURE after enabling MSAN (#19286)
urEnqueueUSMMemcpy2D returns UR_RESULT_ERROR_UNSUPPORTED_FEATURE once the MSAN layer is enabled, because most adapters have not implemented urEnqueueUSMFill2D. This change adds a quick fallback implementation of the 2D fill in the MSAN layer, used only by urEnqueueUSMMemcpy2D.
1 parent e74d3a7 commit bf62e0b

File tree

1 file changed

+48
-11
lines changed
  • unified-runtime/source/loader/layers/sanitizer/msan

1 file changed

+48
-11
lines changed

unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,48 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
5151
return UR_RESULT_SUCCESS;
5252
}
5353

54+
/// Enqueue a 2D USM fill, emulating it row by row when the adapter reports
/// that the native 2D fill is unsupported.
///
/// The adapter's pfnUSMFill2D is tried first. Any result other than
/// UR_RESULT_ERROR_UNSUPPORTED_FEATURE (success included) is returned as-is.
/// Otherwise one pfnUSMFill of `width` bytes is enqueued per row, with row
/// `i` starting at `pMem + pitch * i`.
///
/// @param hQueue               Queue the fill commands are enqueued on.
/// @param pMem                 Start of the first row to fill.
/// @param pitch                Byte distance between the starts of two rows.
/// @param patternSize          Size of the pattern in bytes.
/// @param pPattern             Pattern to replicate.
/// @param width                Number of bytes filled in each row.
/// @param height               Number of rows.
/// @param numEventsInWaitList  Number of entries in `phEventWaitList`.
/// @param phEventWaitList      Events every row fill must wait for.
/// @param phEvent              Optional; receives an event that completes
///                             after all row fills.
/// @return UR_RESULT_SUCCESS, or the first error reported by the adapter.
ur_result_t urEnqueueUSMFill2DFallback(ur_queue_handle_t hQueue, void *pMem,
                                       size_t pitch, size_t patternSize,
                                       const void *pPattern, size_t width,
                                       size_t height,
                                       uint32_t numEventsInWaitList,
                                       const ur_event_handle_t *phEventWaitList,
                                       ur_event_handle_t *phEvent) {
  ur_result_t Result = getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
      hQueue, pMem, pitch, patternSize, pPattern, width, height,
      numEventsInWaitList, phEventWaitList, phEvent);
  // Only "unsupported" triggers the emulation; success and every other error
  // are final.
  if (Result != UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
    return Result;
  }

  // fallback code
  auto pfnUSMFill = getContext()->urDdiTable.Enqueue.pfnUSMFill;

  // Events created by the row fills. These are the only events this function
  // owns; the caller's wait-list events are borrowed and never released here.
  std::vector<ur_event_handle_t> RowEvents;
  RowEvents.reserve(height);

  // Wrapped in a lambda so that an early return from UR_CALL still reaches the
  // event cleanup below instead of leaking the row events created so far.
  auto EnqueueRows = [&]() -> ur_result_t {
    for (size_t HeightIndex = 0; HeightIndex < height; ++HeightIndex) {
      ur_event_handle_t Event = nullptr;
      void *RowPtr = static_cast<char *>(pMem) + pitch * HeightIndex;

      // Every row depends directly on the caller's wait list, not on the
      // other rows.
      UR_CALL(pfnUSMFill(hQueue, RowPtr, patternSize, pPattern, width,
                         numEventsInWaitList, phEventWaitList, &Event));

      RowEvents.push_back(Event);
    }

    if (phEvent) {
      if (RowEvents.empty()) {
        // No rows were enqueued: the output event must still honor the
        // caller's dependencies.
        UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
            hQueue, numEventsInWaitList, phEventWaitList, phEvent));
      } else {
        UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
            hQueue, static_cast<uint32_t>(RowEvents.size()), RowEvents.data(),
            phEvent));
      }
    }

    return UR_RESULT_SUCCESS;
  };

  Result = EnqueueRows();

  // Release all row events even after a failure; report the first error seen.
  auto pfnEventRelease = getContext()->urDdiTable.Event.pfnRelease;
  for (ur_event_handle_t Event : RowEvents) {
    const ur_result_t ReleaseResult = pfnEventRelease(Event);
    if (Result == UR_RESULT_SUCCESS) {
      Result = ReleaseResult;
    }
  }

  return Result;
}
95+
5496
} // namespace
5597

5698
///////////////////////////////////////////////////////////////////////////////
@@ -1726,11 +1768,6 @@ ur_result_t urEnqueueUSMMemcpy2D(
17261768
{
17271769
auto pfnUSMMemcpy = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy;
17281770

1729-
std::vector<ur_event_handle_t> WaitEvents(numEventsInWaitList);
1730-
for (uint32_t i = 0; i < numEventsInWaitList; i++) {
1731-
WaitEvents[i] = phEventWaitList[i];
1732-
}
1733-
17341771
for (size_t HeightIndex = 0; HeightIndex < height; HeightIndex++) {
17351772
ur_event_handle_t Event = nullptr;
17361773
const auto DstOrigin =
@@ -1742,8 +1779,8 @@ ur_result_t urEnqueueUSMMemcpy2D(
17421779
width - 1) +
17431780
MSAN_ORIGIN_GRANULARITY;
17441781
pfnUSMMemcpy(hQueue, false, (void *)DstOrigin, (void *)SrcOrigin,
1745-
SrcOriginEnd - SrcOrigin, WaitEvents.size(),
1746-
WaitEvents.data(), &Event);
1782+
SrcOriginEnd - SrcOrigin, numEventsInWaitList,
1783+
phEventWaitList, &Event);
17471784
Events.push_back(Event);
17481785
}
17491786
}
@@ -1756,9 +1793,9 @@ ur_result_t urEnqueueUSMMemcpy2D(
17561793
const auto DstShadow = DstDI->Shadow->MemToShadow((uptr)pDst);
17571794
const char Pattern = 0;
17581795
ur_event_handle_t Event = nullptr;
1759-
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
1760-
hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 0,
1761-
nullptr, &Event));
1796+
UR_CALL(urEnqueueUSMFill2DFallback(hQueue, (void *)DstShadow, dstPitch, 1,
1797+
&Pattern, width, height, 0, nullptr,
1798+
&Event));
17621799
Events.push_back(Event);
17631800
}
17641801

@@ -1767,7 +1804,7 @@ ur_result_t urEnqueueUSMMemcpy2D(
17671804
hQueue, Events.size(), Events.data(), phEvent));
17681805
}
17691806

1770-
for (const auto &E : Events)
1807+
for (const auto E : Events)
17711808
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));
17721809

17731810
return UR_RESULT_SUCCESS;

0 commit comments

Comments
 (0)