@@ -272,32 +272,6 @@ void Scop::promoteEverythingAt(std::vector<size_t> pos) {
272
272
insertSyncsAroundCopies (tree);
273
273
}
274
274
275
- namespace {
276
- typedef std::unordered_map<isl::id, long , isl::IslIdIslHash> IslParamValueMap;
277
-
278
- // Extract the fixed values of the parameters from the given (context) set.
279
- IslParamValueMap extractParamValueMap (isl::set set) {
280
- CHECK (set.is_singleton ()) << " set must be singleton to extract fixed values" ;
281
-
282
- auto ctx = set.get_ctx ();
283
- auto longMax = isl::val (ctx, std::numeric_limits<long >::max ());
284
- auto p = set.sample_point ();
285
- auto space = p.get_space ();
286
-
287
- IslParamValueMap paramValueMap;
288
- int i = 0 ;
289
- for (auto id : isl::DimIds<isl::space, isl::dim_type::param>(space)) {
290
- auto val = p.get_coordinate_val (isl::dim_type::param, i);
291
- CHECK_EQ (val.get_den_si (), 1 ) << " fractional parameters unsupported" ;
292
- CHECK (val.le (longMax)) << " parameter value overflows long" ;
293
- paramValueMap[id] = val.get_num_si ();
294
- ++i;
295
- }
296
-
297
- return paramValueMap;
298
- }
299
- } // namespace
300
-
301
275
// Compute the values of parameters based on the effective sizes of the
302
276
// tensors provided as arguments and their parametric expressions stored in
303
277
// Halide InputImage. We only know input sizes, output sizes are inferred.
@@ -325,24 +299,24 @@ isl::set Scop::makeContextFromInputs(
325
299
}
326
300
327
301
std::vector<long > Scop::getParameterValues (isl::set context) const {
328
- IslParamValueMap pvm = extractParamValueMap (context);
302
+ auto ctx = context.get_ctx ();
303
+ auto longMax = isl::val (ctx, std::numeric_limits<long >::max ());
304
+ auto space = context.get_space ();
305
+ auto p = context.sample_point ();
306
+ CHECK (context.is_equal (p));
329
307
330
308
// Scop holds a vector of Variables.
331
309
// Iterate over parameters in order, checking if the
332
- // ParamValueMap contains an id whose name corresponds to that
310
+ // context contains a parameter whose name corresponds to that
333
311
// Variable and push respective parameter values.
334
312
std::vector<long > paramValues;
335
313
for (auto const & param : halide.params ) {
336
- size_t previousSize = paramValues.size ();
337
- for (auto p : pvm) {
338
- isl::id id = p.first ;
339
- if (id.get_name () == param.name ()) {
340
- paramValues.push_back (p.second );
341
- }
342
- }
343
- CHECK_EQ (previousSize + 1 , paramValues.size ())
344
- << " parameter " << param.name () << " is not present in the context "
345
- << context << " ; mind identical names in Halide." ;
314
+ isl::id id (ctx, param.name ());
315
+ CHECK (context.involves_param (id));
316
+ auto val = isl::aff::param_on_domain_space (space, id).eval (p);
317
+ CHECK (val.is_int ()) << " fractional parameters unsupported" ;
318
+ CHECK (val.le (longMax)) << " parameter value overflows long" ;
319
+ paramValues.push_back (val.get_num_si ());
346
320
}
347
321
return paramValues;
348
322
}
0 commit comments