Skip to content

Commit bd0a132

Browse files
committed
fix parseDisjointPoolConfig and add tests
1 parent e7366f9 commit bd0a132

File tree

3 files changed

+86
-146
lines changed

3 files changed

+86
-146
lines changed

source/common/umf_pools/disjoint_pool_config_parser.cpp

Lines changed: 21 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -72,151 +72,35 @@ DisjointPoolAllConfigs::DisjointPoolAllConfigs(int trace) {
7272

7373
DisjointPoolAllConfigs parseDisjointPoolConfig(const std::string &config,
7474
int trace) {
75+
// TODO: avoid creating a copy of 'config'
7576
DisjointPoolAllConfigs AllConfigs;
7677

77-
// TODO: replace with UR ENV var parser and avoid creating a copy of 'config'
78-
auto GetValue = [](std::string &Param, size_t Length, size_t &Setting) {
79-
size_t Multiplier = 1;
80-
if (tolower(Param[Length - 1]) == 'k') {
81-
Length--;
82-
Multiplier = 1_KB;
83-
}
84-
if (tolower(Param[Length - 1]) == 'm') {
85-
Length--;
86-
Multiplier = 1_MB;
87-
}
88-
if (tolower(Param[Length - 1]) == 'g') {
89-
Length--;
90-
Multiplier = 1_GB;
91-
}
92-
std::string TheNumber = Param.substr(0, Length);
93-
if (TheNumber.find_first_not_of("0123456789") == std::string::npos) {
94-
Setting = std::stoi(TheNumber) * Multiplier;
95-
}
96-
};
78+
size_t MaxSize = (std::numeric_limits<size_t>::max)();
79+
size_t EnableBuffers = 1;
9780

98-
auto ParamParser = [GetValue](std::string &Params, size_t &Setting,
99-
bool &ParamWasSet) {
100-
bool More;
101-
if (Params.size() == 0) {
102-
ParamWasSet = false;
103-
return false;
104-
}
105-
size_t Pos = Params.find(',');
106-
if (Pos != std::string::npos) {
107-
if (Pos > 0) {
108-
GetValue(Params, Pos, Setting);
109-
ParamWasSet = true;
110-
}
111-
Params.erase(0, Pos + 1);
112-
More = true;
81+
// Update pool settings if specified in environment.
82+
bool EnableBuffersSet = false;
83+
bool MaxSizeSet = false;
84+
size_t Start = 0;
85+
size_t End = config.find(';');
86+
while (true) {
87+
std::string Param = config.substr(Start, End - Start);
88+
if (!EnableBuffersSet && isdigit(Param[0])) {
89+
GetValue(Param, Param.size(), EnableBuffers);
90+
EnableBuffersSet = true;
91+
} else if (!MaxSizeSet && isdigit(Param[0])) {
92+
GetValue(Param, Param.size(), MaxSize);
93+
MaxSizeSet = true;
11394
} else {
114-
GetValue(Params, Params.size(), Setting);
115-
ParamWasSet = true;
116-
More = false;
95+
EnvVarMap map = string_to_map
11796
}
118-
return More;
119-
};
12097

121-
auto MemParser = [&AllConfigs, ParamParser](std::string &Params,
122-
DisjointPoolMemType memType =
123-
DisjointPoolMemType::All) {
124-
bool ParamWasSet;
125-
DisjointPoolMemType LM = memType;
126-
if (memType == DisjointPoolMemType::All) {
127-
LM = DisjointPoolMemType::Host;
98+
if (End == std::string::npos) {
99+
break;
128100
}
129101

130-
bool More = ParamParser(Params, AllConfigs.Configs[LM].MaxPoolableSize,
131-
ParamWasSet);
132-
if (ParamWasSet && memType == DisjointPoolMemType::All) {
133-
for (auto &Config : AllConfigs.Configs) {
134-
Config.MaxPoolableSize = AllConfigs.Configs[LM].MaxPoolableSize;
135-
}
136-
}
137-
if (More) {
138-
More = ParamParser(Params, AllConfigs.Configs[LM].Capacity,
139-
ParamWasSet);
140-
if (ParamWasSet && memType == DisjointPoolMemType::All) {
141-
for (auto &Config : AllConfigs.Configs) {
142-
Config.Capacity = AllConfigs.Configs[LM].Capacity;
143-
}
144-
}
145-
}
146-
if (More) {
147-
ParamParser(Params, AllConfigs.Configs[LM].SlabMinSize,
148-
ParamWasSet);
149-
if (ParamWasSet && memType == DisjointPoolMemType::All) {
150-
for (auto &Config : AllConfigs.Configs) {
151-
Config.SlabMinSize = AllConfigs.Configs[LM].SlabMinSize;
152-
}
153-
}
154-
}
155-
};
156-
157-
auto MemTypeParser = [MemParser](std::string &Params) {
158-
int Pos = 0;
159-
DisjointPoolMemType M(DisjointPoolMemType::All);
160-
if (Params.compare(0, 5, "host:") == 0) {
161-
Pos = 5;
162-
M = DisjointPoolMemType::Host;
163-
} else if (Params.compare(0, 7, "device:") == 0) {
164-
Pos = 7;
165-
M = DisjointPoolMemType::Device;
166-
} else if (Params.compare(0, 7, "shared:") == 0) {
167-
Pos = 7;
168-
M = DisjointPoolMemType::Shared;
169-
} else if (Params.compare(0, 17, "read_only_shared:") == 0) {
170-
Pos = 17;
171-
M = DisjointPoolMemType::SharedReadOnly;
172-
}
173-
if (Pos > 0) {
174-
Params.erase(0, Pos);
175-
}
176-
MemParser(Params, M);
177-
};
178-
179-
size_t MaxSize = (std::numeric_limits<size_t>::max)();
180-
181-
// Update pool settings if specified in environment.
182-
size_t EnableBuffers = 1;
183-
if (config != "") {
184-
std::string Params = config;
185-
size_t Pos = Params.find(';');
186-
if (Pos != std::string::npos) {
187-
if (Pos > 0) {
188-
GetValue(Params, Pos, EnableBuffers);
189-
}
190-
Params.erase(0, Pos + 1);
191-
size_t Pos = Params.find(';');
192-
if (Pos != std::string::npos) {
193-
if (Pos > 0) {
194-
GetValue(Params, Pos, MaxSize);
195-
}
196-
Params.erase(0, Pos + 1);
197-
do {
198-
size_t Pos = Params.find(';');
199-
if (Pos != std::string::npos) {
200-
if (Pos > 0) {
201-
std::string MemParams = Params.substr(0, Pos);
202-
MemTypeParser(MemParams);
203-
}
204-
Params.erase(0, Pos + 1);
205-
if (Params.size() == 0) {
206-
break;
207-
}
208-
} else {
209-
MemTypeParser(Params);
210-
break;
211-
}
212-
} while (true);
213-
} else {
214-
// set MaxPoolSize for all configs
215-
GetValue(Params, Params.size(), MaxSize);
216-
}
217-
} else {
218-
GetValue(Params, Params.size(), EnableBuffers);
219-
}
102+
Start = End + 1;
103+
End = config.find(';', Start);
220104
}
221105

222106
AllConfigs.EnableBuffers = EnableBuffers;

source/common/ur_util.hpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -209,18 +209,27 @@ using EnvVarMap = std::map<std::string, std::vector<std::string>>;
209209
/// Otherwise, optional is set to std::nullopt when the environment variable
210210
/// is not set or is empty.
211211
/// @throws std::invalid_argument() when the parsed environment variable has wrong format
212+
212213
inline std::optional<EnvVarMap> getenv_to_map(const char *env_var_name,
213214
bool reject_empty = true) {
214-
char main_delim = ';';
215-
char key_value_delim = ':';
216-
char values_delim = ',';
217-
EnvVarMap map;
218-
219215
auto env_var = ur_getenv(env_var_name);
220216
if (!env_var.has_value()) {
221217
return std::nullopt;
222218
}
223219

220+
try {
221+
auto map = string_to_map(*env_var, reject_empty);
222+
} catch (...) {
223+
throw_wrong_format_map(env_var_name, env_var);
224+
}
225+
}
226+
227+
inline EnvVarMap string_to_map(std::string input, bool reject_empty = true) {
228+
char main_delim = ';';
229+
char key_value_delim = ':';
230+
char values_delim = ',';
231+
EnvVarMap map;
232+
224233
auto is_quoted = [](std::string &str) {
225234
return (str.front() == '\'' && str.back() == '\'') ||
226235
(str.front() == '"' && str.back() == '"');
@@ -229,30 +238,30 @@ inline std::optional<EnvVarMap> getenv_to_map(const char *env_var_name,
229238
return str.find(':') != std::string::npos;
230239
};
231240

232-
std::stringstream ss(*env_var);
241+
std::stringstream ss(input);
233242
std::string key_value;
234243
while (std::getline(ss, key_value, main_delim)) {
235244
std::string key;
236245
std::string values;
237246
std::stringstream kv_ss(key_value);
238247

239248
if (reject_empty && !has_colon(key_value)) {
240-
throw_wrong_format_map(env_var_name, *env_var);
249+
throw_wrong_format_map(env_var_name, input);
241250
}
242251

243252
std::getline(kv_ss, key, key_value_delim);
244253
std::getline(kv_ss, values);
245254
if (key.empty() || (reject_empty && values.empty()) ||
246255
map.find(key) != map.end()) {
247-
throw_wrong_format_map(env_var_name, *env_var);
256+
throw;
248257
}
249258

250259
std::vector<std::string> values_vec;
251260
std::stringstream values_ss(values);
252261
std::string value;
253262
while (std::getline(values_ss, value, values_delim)) {
254263
if (value.empty() || (has_colon(value) && !is_quoted(value))) {
255-
throw_wrong_format_map(env_var_name, *env_var);
264+
throw;
256265
}
257266
if (is_quoted(value)) {
258267
value.erase(value.cbegin());

test/usm/usmPoolManager.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See LICENSE.TXT
44
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55

6+
#include "umf_pools/disjoint_pool_config_parser.hpp"
67
#include "ur_pool_manager.hpp"
78

89
#include <uur/fixtures.h>
@@ -18,6 +19,27 @@ auto createMockPoolHandle() {
1819
[](umf_memory_pool_t *) {});
1920
}
2021

22+
bool compareConfig(const usm::umf_disjoint_pool_config_t &left,
23+
usm::umf_disjoint_pool_config_t &right) {
24+
return left.MaxPoolableSize == right.MaxPoolableSize &&
25+
left.Capacity == right.Capacity &&
26+
left.SlabMinSize == right.SlabMinSize;
27+
}
28+
29+
bool compareConfigs(const usm::DisjointPoolAllConfigs &left,
30+
usm::DisjointPoolAllConfigs &right) {
31+
return left.EnableBuffers == right.EnableBuffers &&
32+
compareConfig(left.Configs[usm::DisjointPoolMemType::Host],
33+
right.Configs[usm::DisjointPoolMemType::Host]) &&
34+
compareConfig(left.Configs[usm::DisjointPoolMemType::Device],
35+
right.Configs[usm::DisjointPoolMemType::Device]) &&
36+
compareConfig(left.Configs[usm::DisjointPoolMemType::Shared],
37+
right.Configs[usm::DisjointPoolMemType::Shared]) &&
38+
compareConfig(
39+
left.Configs[usm::DisjointPoolMemType::SharedReadOnly],
40+
right.Configs[usm::DisjointPoolMemType::SharedReadOnly]);
41+
}
42+
2143
TEST_P(urUsmPoolDescriptorTest, poolIsPerContextTypeAndDevice) {
2244
auto &devices = uur::DevicesEnvironment::instance->devices;
2345
auto poolHandle = this->GetParam();
@@ -115,4 +137,29 @@ TEST_P(urUsmPoolManagerTest, poolManagerGetNonexistant) {
115137
}
116138
}
117139

140+
TEST_P(urUsmPoolManagerTest, config) {
141+
// Check default config
142+
usm::DisjointPoolAllConfigs def;
143+
usm::DisjointPoolAllConfigs parsed1 =
144+
usm::parseDisjointPoolConfig("1;host:2M,4,64K;device:4M,4,64K;"
145+
"shared:0,0,2M;read_only_shared:4M,4,2M",
146+
0);
147+
ASSERT_EQ(compareConfigs(def, parsed1), true);
148+
149+
// Check partially set config
150+
usm::DisjointPoolAllConfigs parsed2 =
151+
usm::parseDisjointPoolConfig("1;device:4M;shared:0,0,2M", 0);
152+
ASSERT_EQ(compareConfigs(def, parsed2), true);
153+
154+
// Check non-default config
155+
usm::DisjointPoolAllConfigs test(def);
156+
test.Configs[usm::DisjointPoolMemType::Shared].MaxPoolableSize = 128 * 1024;
157+
test.Configs[usm::DisjointPoolMemType::Shared].Capacity = 4;
158+
test.Configs[usm::DisjointPoolMemType::Shared].SlabMinSize = 64 * 1024;
159+
160+
usm::DisjointPoolAllConfigs parsed3 =
161+
usm::parseDisjointPoolConfig("1;shared:128K,4,64K", 0);
162+
ASSERT_EQ(compareConfigs(test, parsed3), true);
163+
}
164+
118165
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urUsmPoolManagerTest);

0 commit comments

Comments
 (0)