diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp index 832c31c43b5d2..12c7cc62905c9 100644 --- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp @@ -3089,15 +3089,16 @@ struct AMDGPUGlobalHandlerTy final : public GenericGlobalHandlerTy { } // Check the size of the symbol. - if (SymbolSize != DeviceGlobal.getSize()) + if (DeviceGlobal.getSize() && SymbolSize != DeviceGlobal.getSize()) return Plugin::error( ErrorCode::INVALID_BINARY, "failed to load global '%s' due to size mismatch (%zu != %zu)", DeviceGlobal.getName().data(), SymbolSize, (size_t)DeviceGlobal.getSize()); - // Store the symbol address on the device global metadata. + // Store the symbol address and size on the device global metadata. DeviceGlobal.setPtr(reinterpret_cast(SymbolAddr)); + DeviceGlobal.setSize(SymbolSize); return Plugin::success(); } diff --git a/offload/plugins-nextgen/common/include/GlobalHandler.h b/offload/plugins-nextgen/common/include/GlobalHandler.h index 5d6109df49da5..af7dac66ca85d 100644 --- a/offload/plugins-nextgen/common/include/GlobalHandler.h +++ b/offload/plugins-nextgen/common/include/GlobalHandler.h @@ -37,6 +37,8 @@ using namespace llvm::object; /// Common abstraction for globals that live on the host and device. /// It simply encapsulates the symbol name, symbol size, and symbol address /// (which might be host or device depending on the context). +/// Both size and address may be absent (signified by 0/nullptr), and can be +/// populated with getGlobalMetadataFromDevice/Image. class GlobalTy { // NOTE: Maybe we can have a pointer to the offload entry name instead of // holding a private copy of the name as a std::string. 
@@ -45,7 +47,7 @@ class GlobalTy { void *Ptr; public: - GlobalTy(const std::string &Name, uint32_t Size, void *Ptr = nullptr) + GlobalTy(const std::string &Name, uint32_t Size = 0, void *Ptr = nullptr) : Name(Name), Size(Size), Ptr(Ptr) {} const std::string &getName() const { return Name; } @@ -139,8 +141,11 @@ class GenericGlobalHandlerTy { bool isSymbolInImage(GenericDeviceTy &Device, DeviceImageTy &Image, StringRef SymName); - /// Get the address and size of a global in the image. Address and size are - /// return in \p ImageGlobal, the global name is passed in \p ImageGlobal. + /// Get the address and size of a global in the image. Address is + /// returned in \p ImageGlobal and the global name is passed in \p + /// ImageGlobal. If no size is present in \p ImageGlobal, then the size of the + /// global will be stored there. If it is present, it will be validated + /// against the real size of the global. Error getGlobalMetadataFromImage(GenericDeviceTy &Device, DeviceImageTy &Image, GlobalTy &ImageGlobal); @@ -149,9 +154,11 @@ class GenericGlobalHandlerTy { Error readGlobalFromImage(GenericDeviceTy &Device, DeviceImageTy &Image, const GlobalTy &HostGlobal); - /// Get the address and size of a global from the device. Address is return in - /// \p DeviceGlobal, the global name and expected size are passed in - /// \p DeviceGlobal. + /// Get the address and size of a global from the device. Address is + /// returned in \p DeviceGlobal and the global name is passed in \p + /// DeviceGlobal. If no size is present in \p DeviceGlobal, then the size of + /// the global will be stored there. If it is present, it will be validated + /// against the real size of the global. 
virtual Error getGlobalMetadataFromDevice(GenericDeviceTy &Device, DeviceImageTy &Image, GlobalTy &DeviceGlobal) = 0; diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp index 53089df2d0f0d..15193de6ae430 100644 --- a/offload/plugins-nextgen/cuda/src/rtl.cpp +++ b/offload/plugins-nextgen/cuda/src/rtl.cpp @@ -1355,13 +1355,15 @@ class CUDAGlobalHandlerTy final : public GenericGlobalHandlerTy { GlobalName)) return Err; - if (CUSize != DeviceGlobal.getSize()) + if (DeviceGlobal.getSize() && CUSize != DeviceGlobal.getSize()) return Plugin::error( ErrorCode::INVALID_BINARY, "failed to load global '%s' due to size mismatch (%zu != %zu)", GlobalName, CUSize, (size_t)DeviceGlobal.getSize()); DeviceGlobal.setPtr(reinterpret_cast(CUPtr)); + DeviceGlobal.setSize(CUSize); + return Plugin::success(); } };