@@ -891,6 +891,7 @@ const Elf64_Sym *elf_lookup(Elf *elf, char *base, Elf64_Shdr *section_hash,
891
891
typedef struct {
892
892
void *addr = nullptr ;
893
893
uint32_t size = UINT32_MAX;
894
+ uint32_t sh_type = SHT_NULL;
894
895
} symbol_info;
895
896
896
897
int get_symbol_info_without_loading (Elf *elf, char *base, const char *symname,
@@ -913,8 +914,23 @@ int get_symbol_info_without_loading(Elf *elf, char *base, const char *symname,
913
914
return 1 ;
914
915
}
915
916
916
- res->size = static_cast <uint32_t >(sym->st_size );
917
+ if (sym->st_shndx == SHN_UNDEF) {
918
+ return 1 ;
919
+ }
920
+
921
+ Elf_Scn *section = elf_getscn (elf, sym->st_shndx );
922
+ if (!section) {
923
+ return 1 ;
924
+ }
925
+
926
+ Elf64_Shdr *header = elf64_getshdr (section);
927
+ if (!header) {
928
+ return 1 ;
929
+ }
930
+
917
931
res->addr = sym->st_value + base;
932
+ res->size = static_cast <uint32_t >(sym->st_size );
933
+ res->sh_type = header->sh_type ;
918
934
return 0 ;
919
935
}
920
936
@@ -992,6 +1008,99 @@ __tgt_target_table *__tgt_rtl_load_binary(int32_t device_id,
992
1008
return res;
993
1009
}
994
1010
1011
+ struct device_environment {
1012
+ // initialise an omptarget_device_environmentTy in the deviceRTL
1013
+ // patches around differences in the deviceRTL between trunk, aomp,
1014
+ // rocmcc. Over time these differences will tend to zero and this class
1015
+ // simplified.
1016
+ // Symbol may be in .data or .bss, and may be missing fields:
1017
+ // - aomp has debug_level, num_devices, device_num
1018
+ // - trunk has debug_level
1019
+ // - under review in trunk is debug_level, device_num
1020
+ // - rocmcc matches aomp, patch to swap num_devices and device_num
1021
+
1022
+ // If the symbol is in .data (aomp, rocm) it can be written directly.
1023
+ // If it is in .bss, we must wait for it to be allocated space on the
1024
+ // gpu (trunk) and initialize after loading.
1025
+ const char *sym () { return " omptarget_device_environment" ; }
1026
+
1027
+ omptarget_device_environmentTy host_device_env;
1028
+ symbol_info si;
1029
+ bool valid = false ;
1030
+
1031
+ __tgt_device_image *image;
1032
+ const size_t img_size;
1033
+
1034
+ device_environment (int device_id, int number_devices,
1035
+ __tgt_device_image *image, const size_t img_size)
1036
+ : image(image), img_size(img_size) {
1037
+
1038
+ host_device_env.num_devices = number_devices;
1039
+ host_device_env.device_num = device_id;
1040
+ host_device_env.debug_level = 0 ;
1041
+ #ifdef OMPTARGET_DEBUG
1042
+ if (char *envStr = getenv (" LIBOMPTARGET_DEVICE_RTL_DEBUG" )) {
1043
+ host_device_env.debug_level = std::stoi (envStr);
1044
+ }
1045
+ #endif
1046
+
1047
+ int rc = get_symbol_info_without_loading ((char *)image->ImageStart ,
1048
+ img_size, sym (), &si);
1049
+ if (rc != 0 ) {
1050
+ DP (" Finding global device environment '%s' - symbol missing.\n " , sym ());
1051
+ return ;
1052
+ }
1053
+
1054
+ if (si.size > sizeof (host_device_env)) {
1055
+ DP (" Symbol '%s' has size %u, expected at most %zu.\n " , sym (), si.size ,
1056
+ sizeof (host_device_env));
1057
+ return ;
1058
+ }
1059
+
1060
+ valid = true ;
1061
+ }
1062
+
1063
+ bool in_image () { return si.sh_type != SHT_NOBITS; }
1064
+
1065
+ atmi_status_t before_loading (void *data, size_t size) {
1066
+ assert (valid);
1067
+ if (in_image ()) {
1068
+ DP (" Setting global device environment before load (%u bytes)\n " , si.size );
1069
+ uint64_t offset = (char *)si.addr - (char *)image->ImageStart ;
1070
+ void *pos = (char *)data + offset;
1071
+ memcpy (pos, &host_device_env, si.size );
1072
+ }
1073
+ return ATMI_STATUS_SUCCESS;
1074
+ }
1075
+
1076
+ atmi_status_t after_loading () {
1077
+ assert (valid);
1078
+ if (!in_image ()) {
1079
+ DP (" Setting global device environment after load (%u bytes)\n " , si.size );
1080
+ int device_id = host_device_env.device_num ;
1081
+
1082
+ void *state_ptr;
1083
+ uint32_t state_ptr_size;
1084
+ atmi_status_t err = atmi_interop_hsa_get_symbol_info (
1085
+ get_gpu_mem_place (device_id), sym (), &state_ptr, &state_ptr_size);
1086
+ if (err != ATMI_STATUS_SUCCESS) {
1087
+ DP (" failed to find %s in loaded image\n " , sym ());
1088
+ return err;
1089
+ }
1090
+
1091
+ if (state_ptr_size != si.size ) {
1092
+ DP (" Symbol had size %u before loading, %u after\n " , state_ptr_size,
1093
+ si.size );
1094
+ return ATMI_STATUS_ERROR;
1095
+ }
1096
+
1097
+ return DeviceInfo.freesignalpool_memcpy_h2d (state_ptr, &host_device_env,
1098
+ state_ptr_size, device_id);
1099
+ }
1100
+ return ATMI_STATUS_SUCCESS;
1101
+ }
1102
+ };
1103
+
995
1104
static atmi_status_t atmi_calloc (void **ret_ptr, size_t size,
996
1105
atmi_mem_place_t place) {
997
1106
uint64_t rounded = 4 * ((size + 3 ) / 4 );
@@ -1047,41 +1156,18 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id,
1047
1156
return NULL ;
1048
1157
}
1049
1158
1050
- omptarget_device_environmentTy host_device_env;
1051
- host_device_env.num_devices = DeviceInfo.NumberOfDevices ;
1052
- host_device_env.device_num = device_id;
1053
- host_device_env.debug_level = 0 ;
1054
- #ifdef OMPTARGET_DEBUG
1055
- if (char *envStr = getenv (" LIBOMPTARGET_DEVICE_RTL_DEBUG" )) {
1056
- host_device_env.debug_level = std::stoi (envStr);
1057
- }
1058
- #endif
1059
-
1060
- auto on_deserialized_data = [&](void *data, size_t size) -> atmi_status_t {
1061
- const char *device_env_Name = " omptarget_device_environment" ;
1062
- symbol_info si;
1063
- int rc = get_symbol_info_without_loading ((char *)image->ImageStart ,
1064
- img_size, device_env_Name, &si);
1065
- if (rc != 0 ) {
1066
- DP (" Finding global device environment '%s' - symbol missing.\n " ,
1067
- device_env_Name);
1068
- // no need to return FAIL, consider this is a not a device debug build.
1069
- return ATMI_STATUS_SUCCESS;
1070
- }
1071
- if (si.size != sizeof (host_device_env)) {
1072
- return ATMI_STATUS_ERROR;
1159
+ {
1160
+ auto env = device_environment (device_id, DeviceInfo.NumberOfDevices , image,
1161
+ img_size);
1162
+ if (!env.valid ) {
1163
+ return NULL ;
1073
1164
}
1074
- DP (" Setting global device environment %u bytes\n " , si.size );
1075
- uint64_t offset = (char *)si.addr - (char *)image->ImageStart ;
1076
- void *pos = (char *)data + offset;
1077
- memcpy (pos, &host_device_env, sizeof (host_device_env));
1078
- return ATMI_STATUS_SUCCESS;
1079
- };
1080
1165
1081
- {
1082
1166
atmi_status_t err = module_register_from_memory_to_place (
1083
1167
(void *)image->ImageStart , img_size, get_gpu_place (device_id),
1084
- on_deserialized_data);
1168
+ [&](void *data, size_t size) {
1169
+ return env.before_loading (data, size);
1170
+ });
1085
1171
1086
1172
check (" Module registering" , err);
1087
1173
if (err != ATMI_STATUS_SUCCESS) {
@@ -1092,6 +1178,11 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id,
1092
1178
get_elf_mach_gfx_name (elf_e_flags (image)));
1093
1179
return NULL ;
1094
1180
}
1181
+
1182
+ err = env.after_loading ();
1183
+ if (err != ATMI_STATUS_SUCCESS) {
1184
+ return NULL ;
1185
+ }
1095
1186
}
1096
1187
1097
1188
DP (" ATMI module successfully loaded!\n " );
0 commit comments