Skip to content

Commit 1be6e9c

Browse files
authored
[DeviceMSAN] Fix clean shadow is polluted (#19207)
1 parent b5b9813 commit 1be6e9c

File tree

1 file changed

+85
-52
lines changed

1 file changed

+85
-52
lines changed

libdevice/sanitizer/msan_rtl.cpp

Lines changed: 85 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,9 @@ void CopyOrigin(uptr dst, uint32_t dst_as, uptr src, uint32_t src_as,
329329
inline void CopyShadowAndOrigin(uptr dst, uint32_t dst_as, uptr src,
330330
uint32_t src_as, size_t size) {
331331
auto *shadow_dst = (__SYCL_GLOBAL__ char *)MemToShadow(dst, dst_as);
332+
if ((uptr)shadow_dst == GetMsanLaunchInfo->CleanShadow) {
333+
return;
334+
}
332335
auto *shadow_src = (__SYCL_GLOBAL__ char *)MemToShadow(src, src_as);
333336
Memcpy(shadow_dst, shadow_src, size);
334337
CopyOrigin(dst, dst_as, src, src_as, size);
@@ -344,10 +347,13 @@ static __SYCL_CONSTANT__ const char __msan_print_move_shadow[] =
344347
// FIXME: The original implemention only moves the origin of poisoned memories
345348
void MoveOrigin(uptr dst, uint32_t dst_as, uptr src, uint32_t src_as,
346349
uptr size) {
350+
auto *dst_beg = (__SYCL_GLOBAL__ char *)MemToOrigin(dst, dst_as);
351+
if ((uptr)dst_beg == GetMsanLaunchInfo->CleanShadow) {
352+
return;
353+
}
347354
auto *src_beg = (__SYCL_GLOBAL__ char *)MemToOrigin(src, src_as);
348355
auto *src_end = (__SYCL_GLOBAL__ char *)MemToOrigin(src + size - 1, src_as) +
349356
MSAN_ORIGIN_GRANULARITY;
350-
auto *dst_beg = (__SYCL_GLOBAL__ char *)MemToOrigin(dst, dst_as);
351357
Memmove(dst_beg, src_beg, src_end - src_beg);
352358
}
353359

@@ -365,9 +371,19 @@ inline void MoveShadowAndOrigin(uptr dst, uint32_t dst_as, uptr src,
365371

366372
inline void UnpoisonShadow(uptr addr, uint32_t as, size_t size) {
367373
auto *shadow_ptr = (__SYCL_GLOBAL__ char *)MemToShadow(addr, as);
374+
if ((uptr)shadow_ptr == GetMsanLaunchInfo->CleanShadow) {
375+
return;
376+
}
368377
Memset(shadow_ptr, 0, size);
369378
}
370379

380+
// Check if the current work item is the first one in the work group
381+
inline bool IsFirstWorkItemWthinWorkGroup() {
382+
return __spirv_LocalInvocationId_x() + __spirv_LocalInvocationId_y() +
383+
__spirv_LocalInvocationId_z() ==
384+
0;
385+
}
386+
371387
} // namespace
372388

373389
#define MSAN_MAYBE_WARNING(type, size) \
@@ -525,41 +541,40 @@ static __SYCL_CONSTANT__ const char __mem_set_shadow_local[] =
525541
DEVICE_EXTERN_C_NOINLINE void __msan_poison_shadow_static_local(uptr ptr,
526542
size_t size) {
527543
// Update shadow memory of local memory only on first work-item
528-
if (__spirv_LocalInvocationId_x() + __spirv_LocalInvocationId_y() +
529-
__spirv_LocalInvocationId_z() ==
530-
0) {
531-
if (!GetMsanLaunchInfo)
532-
return;
533-
534-
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg,
535-
"__msan_poison_shadow_static_local"));
536-
537-
auto shadow_address = MemToShadow(ptr, ADDRESS_SPACE_LOCAL);
538-
if (shadow_address == GetMsanLaunchInfo->CleanShadow)
539-
return;
540-
Memset((__SYCL_GLOBAL__ char *)shadow_address, 0xff, size);
544+
if (!IsFirstWorkItemWthinWorkGroup())
545+
return;
546+
547+
if (!GetMsanLaunchInfo || GetMsanLaunchInfo->LocalShadowOffset == 0)
548+
return;
549+
550+
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg,
551+
"__msan_poison_shadow_static_local"));
541552

553+
auto shadow_address = MemToShadow(ptr, ADDRESS_SPACE_LOCAL);
554+
if (shadow_address != GetMsanLaunchInfo->CleanShadow) {
555+
Memset((__SYCL_GLOBAL__ char *)shadow_address, 0xff, size);
542556
MSAN_DEBUG(__spirv_ocl_printf(__mem_set_shadow_local, shadow_address,
543557
shadow_address + size, 0xff));
544-
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end,
545-
"__msan_poison_shadow_static_local"));
546558
}
559+
560+
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end,
561+
"__msan_poison_shadow_static_local"));
547562
}
548563

549564
DEVICE_EXTERN_C_NOINLINE void __msan_unpoison_shadow_static_local(uptr ptr,
550565
size_t size) {
551566
// Update shadow memory of local memory only on first work-item
552-
if (__spirv_LocalInvocationId_x() + __spirv_LocalInvocationId_y() +
553-
__spirv_LocalInvocationId_z() ==
554-
0) {
555-
if (!GetMsanLaunchInfo || GetMsanLaunchInfo->LocalShadowOffset == 0)
556-
return;
557-
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg,
558-
"__msan_unpoison_shadow_static_local"));
559-
UnpoisonShadow(ptr, ADDRESS_SPACE_LOCAL, size);
560-
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end,
561-
"__msan_unpoison_shadow_static_local"));
562-
}
567+
if (!IsFirstWorkItemWthinWorkGroup())
568+
return;
569+
570+
if (!GetMsanLaunchInfo || GetMsanLaunchInfo->LocalShadowOffset == 0)
571+
return;
572+
573+
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg,
574+
"__msan_unpoison_shadow_static_local"));
575+
UnpoisonShadow(ptr, ADDRESS_SPACE_LOCAL, size);
576+
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end,
577+
"__msan_unpoison_shadow_static_local"));
563578
}
564579

565580
DEVICE_EXTERN_C_INLINE void __msan_barrier() {
@@ -583,7 +598,11 @@ static __SYCL_CONSTANT__ const char __msan_print_report_arg_count_incorrect[] =
583598

584599
DEVICE_EXTERN_C_NOINLINE void
585600
__msan_poison_shadow_dynamic_local(uptr ptr, uint32_t num_args) {
586-
if (!GetMsanLaunchInfo)
601+
// Update shadow memory of local memory only on first work-item
602+
if (!IsFirstWorkItemWthinWorkGroup())
603+
return;
604+
605+
if (!GetMsanLaunchInfo || GetMsanLaunchInfo->LocalShadowOffset == 0)
587606
return;
588607

589608
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg,
@@ -601,7 +620,12 @@ __msan_poison_shadow_dynamic_local(uptr ptr, uint32_t num_args) {
601620
auto *local_arg = &GetMsanLaunchInfo->LocalArgs[i];
602621
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_local_arg, i, local_arg->Size));
603622

604-
__msan_poison_shadow_static_local(args[i], local_arg->Size);
623+
auto shadow_address = MemToShadow(args[i], ADDRESS_SPACE_LOCAL);
624+
if (shadow_address != GetMsanLaunchInfo->CleanShadow) {
625+
Memset((__SYCL_GLOBAL__ char *)shadow_address, 0xff, local_arg->Size);
626+
MSAN_DEBUG(__spirv_ocl_printf(__mem_set_shadow_local, shadow_address,
627+
shadow_address + local_arg->Size, 0xff));
628+
}
605629
}
606630

607631
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end,
@@ -616,15 +640,17 @@ static __SYCL_CONSTANT__ const char __mem_unpoison_shadow_dynamic_local_end[] =
616640

617641
DEVICE_EXTERN_C_NOINLINE void
618642
__msan_unpoison_shadow_dynamic_local(uptr ptr, uint32_t num_args) {
619-
if (!GetMsanLaunchInfo)
643+
// Update shadow memory of local memory only on first work-item
644+
if (!IsFirstWorkItemWthinWorkGroup())
645+
return;
646+
647+
if (!GetMsanLaunchInfo || GetMsanLaunchInfo->LocalShadowOffset == 0)
620648
return;
621649

622650
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg,
623651
"__msan_unpoison_shadow_dynamic_local"));
624652

625653
if (num_args != GetMsanLaunchInfo->NumLocalArgs) {
626-
__spirv_ocl_printf(__msan_print_report_arg_count_incorrect, num_args,
627-
GetMsanLaunchInfo->NumLocalArgs);
628654
return;
629655
}
630656

@@ -634,7 +660,7 @@ __msan_unpoison_shadow_dynamic_local(uptr ptr, uint32_t num_args) {
634660
auto *local_arg = &GetMsanLaunchInfo->LocalArgs[i];
635661
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_local_arg, i, local_arg->Size));
636662

637-
__msan_unpoison_shadow_static_local(args[i], local_arg->Size);
663+
UnpoisonShadow(args[i], ADDRESS_SPACE_LOCAL, local_arg->Size);
638664
}
639665

640666
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end,
@@ -658,7 +684,9 @@ DEVICE_EXTERN_C_NOINLINE void __msan_poison_stack(__SYCL_PRIVATE__ void *ptr,
658684
(void *)shadow_address,
659685
(void *)(shadow_address + size), 0xff));
660686

661-
Memset((__SYCL_GLOBAL__ char *)shadow_address, 0xff, size);
687+
if (shadow_address != GetMsanLaunchInfo->CleanShadow) {
688+
Memset((__SYCL_GLOBAL__ char *)shadow_address, 0xff, size);
689+
}
662690

663691
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end, "__msan_poison_stack"));
664692
}
@@ -676,7 +704,9 @@ DEVICE_EXTERN_C_NOINLINE void __msan_unpoison_stack(__SYCL_PRIVATE__ void *ptr,
676704
(void *)shadow_address,
677705
(void *)(shadow_address + size), 0x0));
678706

679-
Memset((__SYCL_GLOBAL__ char *)shadow_address, 0, size);
707+
if (shadow_address != GetMsanLaunchInfo->CleanShadow) {
708+
Memset((__SYCL_GLOBAL__ char *)shadow_address, 0, size);
709+
}
680710

681711
MSAN_DEBUG(
682712
__spirv_ocl_printf(__msan_print_func_end, "__msan_unpoison_stack"));
@@ -713,23 +743,26 @@ __msan_unpoison_strided_copy(uptr dest, uint32_t dest_as, uptr src,
713743
"__msan_unpoison_strided_copy"));
714744

715745
uptr shadow_dest = (uptr)__msan_get_shadow(dest, dest_as);
716-
uptr shadow_src = (uptr)__msan_get_shadow(src, src_as);
717-
718-
switch (element_size) {
719-
case 1:
720-
GroupAsyncCopy<int8_t>(shadow_dest, shadow_src, counts, stride);
721-
break;
722-
case 2:
723-
GroupAsyncCopy<int16_t>(shadow_dest, shadow_src, counts, stride);
724-
break;
725-
case 4:
726-
GroupAsyncCopy<int32_t>(shadow_dest, shadow_src, counts, stride);
727-
break;
728-
case 8:
729-
GroupAsyncCopy<int64_t>(shadow_dest, shadow_src, counts, stride);
730-
break;
731-
default:
732-
__spirv_ocl_printf(__msan_print_strided_copy_unsupport_type, element_size);
746+
if (shadow_dest != GetMsanLaunchInfo->CleanShadow) {
747+
uptr shadow_src = (uptr)__msan_get_shadow(src, src_as);
748+
749+
switch (element_size) {
750+
case 1:
751+
GroupAsyncCopy<int8_t>(shadow_dest, shadow_src, counts, stride);
752+
break;
753+
case 2:
754+
GroupAsyncCopy<int16_t>(shadow_dest, shadow_src, counts, stride);
755+
break;
756+
case 4:
757+
GroupAsyncCopy<int32_t>(shadow_dest, shadow_src, counts, stride);
758+
break;
759+
case 8:
760+
GroupAsyncCopy<int64_t>(shadow_dest, shadow_src, counts, stride);
761+
break;
762+
default:
763+
__spirv_ocl_printf(__msan_print_strided_copy_unsupport_type,
764+
element_size);
765+
}
733766
}
734767

735768
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end,

0 commit comments

Comments
 (0)