@@ -356,26 +356,61 @@ size_t largestDim(const std::vector<const DLConstTensor*>& inputs) {
356
356
return (*maxElement)->ndim ;
357
357
}
358
358
359
- // Creates well-chosen parameter sizes to match the input shapes.
360
- void setupTuningParameters (
359
+ // Creates well-chosen generic parameter sizes to match the input shapes.
360
+ template <typename MappingOptionsType>
361
+ inline std::pair<TuningConfiguration, std::vector<size_t >>
362
+ setupGenericTuningParametersAndGetRange (
361
363
const std::vector<const DLConstTensor*>& inputs,
362
- TuningConfiguration& configuration ) {
364
+ const std::vector<MappingOptionsType>& baseMappings ) {
363
365
TC_CHECK_GE (inputs.size (), 1u );
364
366
auto range = inputDivisorsAndPowers2 (inputs);
365
367
// 0 is a valid tiling annotation and signals no tiling of that dimension
366
368
// 0 is not a valid block / grid annotation
367
369
auto nTilesDim = largestDim (inputs) + 1 ;
368
370
auto tileRange = range;
369
371
tileRange.push_back (0 );
372
+
373
+ TuningConfiguration configuration;
370
374
configuration.tilingParams .setRange (nTilesDim, tileRange);
371
- configuration.blockParams .setRange (range, " b" );
372
- configuration.gridParams .setRange (range, " g" );
373
375
configuration.unrollFactor =
374
376
RangeParameter (powers2 (FLAGS_tuner_max_unroll_size), " unroll" );
377
+
378
+ return {configuration, range};
379
+ }
380
+
381
+ // Creates well-chosen parameter sizes to match the input shapes.
382
+ inline TuningConfiguration setupTuningParameters (
383
+ const std::vector<const DLConstTensor*>& inputs,
384
+ const std::vector<CudaMappingOptions>& baseMappings) {
385
+ std::vector<size_t > range;
386
+ TuningConfiguration configuration;
387
+ std::tie (configuration, range) =
388
+ setupGenericTuningParametersAndGetRange (inputs, baseMappings);
389
+ auto blockRange = range;
390
+ auto gridRange = range;
391
+
392
+ for (const auto & baseMapping : baseMappings) {
393
+ blockRange =
394
+ mergeVectors (std::move (blockRange), baseMapping.block .extractVector ());
395
+ gridRange =
396
+ mergeVectors (std::move (gridRange), baseMapping.grid .extractVector ());
397
+ }
398
+
399
+ configuration.blockParams .setRange (blockRange, " b" );
400
+ configuration.gridParams .setRange (gridRange, " g" );
375
401
configuration.privateDepth =
376
402
RangeParameter ({0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 }, " pdepth" );
377
403
configuration.sharedDepth =
378
404
RangeParameter ({0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 }, " sdepth" );
405
+
406
+ return configuration;
407
+ }
408
+
409
+ // Creates well-chosen parameter sizes to match the input shapes.
410
+ inline TuningConfiguration setupTuningParameters (
411
+ const std::vector<const DLConstTensor*>& inputs,
412
+ const std::vector<CpuMappingOptions>& baseMappings) {
413
+ return setupGenericTuningParametersAndGetRange (inputs, baseMappings).first ;
379
414
}
380
415
} // namespace
381
416
@@ -397,9 +432,9 @@ Autotuner<Backend, SearchStrategy>::tune(
397
432
<< " Error looking up " << tcEntryPoint;
398
433
399
434
// Initialize a model configuration
400
- TuningConfiguration modelConfiguration;
401
435
TC_CHECK_GE (inputs.size (), 1u );
402
- setupTuningParameters (inputs.begin ()->second , modelConfiguration);
436
+ auto modelConfiguration =
437
+ setupTuningParameters (inputs.begin ()->second , baseMappings);
403
438
modelConfiguration.fixParameters (fixedParams);
404
439
405
440
// Create initial configs based on options + model configuration
0 commit comments