Skip to content

Commit be9f7da

Browse files
authored
Merge pull request #1902 from nakajee/release/rocm-rel-6.1
Hotfix: Fix WorkspaceCheck implementation when used in rocBLAS
2 parents d0314ce + e61b297 commit be9f7da

31 files changed

+313
-118
lines changed

HostLibraryTests/LibYamlToMsgpack.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#!/usr/bin/env python3
2+
3+
################################################################################
4+
#
5+
# Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
6+
#
7+
# Permission is hereby granted, free of charge, to any person obtaining a copy
8+
# of this software and associated documentation files (the "Software"), to deal
9+
# in the Software without restriction, including without limitation the rights
10+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
# copies of the Software, and to permit persons to whom the Software is
12+
# furnished to do so, subject to the following conditions:
13+
#
14+
# The above copyright notice and this permission notice shall be included in
15+
# all copies or substantial portions of the Software.
16+
#
17+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23+
# SOFTWARE.
24+
#
25+
################################################################################
26+
27+
import sys
28+
import yaml
29+
import msgpack
30+
31+
if __name__ == "__main__":
    # Converts a YAML library file to msgpack.
    # Usage: LibYamlToMsgpack.py <input.yaml> <output file>
    args = sys.argv[1:]
    if len(args) < 2:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit("usage: {} <input.yaml> <output file>".format(sys.argv[0]))
    infile = args[0]
    outfile = args[1]

    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1 and
    # raises TypeError from PyYAML 6.0 on. safe_load() works on all versions and
    # only constructs standard YAML types.
    # NOTE(review): assumes the library files use no Python-specific YAML tags - confirm.
    with open(infile) as f:
        data = yaml.safe_load(f)

    # msgpack output is binary, so the destination is opened in 'wb' mode.
    with open(outfile, 'wb') as f:
        msgpack.dump(data, f)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
Sample libraries can be rebuilt using TensileCreateLibrary and the rocBLAS build script.
2+
3+
To rebuild rocBLAS_Full, run the rocBLAS build script.
4+
The sample library currently includes gfx803, gfx900, gfx906, and gfx908.
5+
To build the YAML version, include the --no-msgpack flag.
6+
7+
./install.sh -dc -t ~/tensile -a "gfx803;gfx900;gfx906;gfx908" --merge-architectures --no-lazy-library-loading
8+
./install.sh -dc -t ~/tensile -a "gfx803;gfx900;gfx906;gfx908" --merge-architectures --no-lazy-library-loading --no-msgpack
9+
10+
SampleTensileKernels are small samples written manually.
11+
To update them, make any required changes to SampleTensileKernels.yaml and run the script below to convert it to msgpack:
12+
13+
cd HostLibraryTests
14+
./LibYamlToMsgpack.py configs/SolutionLibraries/SampleTensileKernels.yaml configs/SolutionLibraries/SampleTensileKernels.dat
15+
16+
The other libraries can be rebuilt by calling TensileCreateLibrary:
17+
18+
KernelsLite:
19+
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=yaml ../HostLibraryTests/configs/lite_configs/ . HIP
20+
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=msgpack ../HostLibraryTests/configs/lite_configs/ . HIP
21+
KernelsLiteMixed:
22+
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=yaml ../HostLibraryTests/configs/lite_configs_mixed/ . HIP
23+
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=msgpack ../HostLibraryTests/configs/lite_configs_mixed/ . HIP
24+
KernelsLiteNavi:
25+
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=yaml ../Tensile/Source/lib/configs/lite_configs/ . HIP
26+
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=msgpack ../Tensile/Source/lib/configs/lite_configs/ . HIP
27+
KernelsTileLite:
28+
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=yaml ../HostLibraryTests/configs/tile_aware_selection/ . HIP
29+
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=msgpack ../HostLibraryTests/configs/tile_aware_selection/ . HIP
30+
31+
All libraries are checked in as .gz files to reduce checkout size.
Binary file not shown.
Binary file not shown.

HostLibraryTests/llvm/LLVMYAMLContraction_test.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ TEST(LLVMYAMLContractionTest, Simple)
5252
"index: 0\n"
5353
"hardwarePredicate: { type: TruePred }\n"
5454
"problemPredicate: { type: TruePred }\n"
55+
"taskPredicate: { type: TruePred }\n"
5556
"debugKernel: false\n"
5657
"problemType:\n"
5758
" operationIdentifier: foo\n"
@@ -119,6 +120,7 @@ TEST(LLVMYAMLContractionTest, ContractionLibrary)
119120
" index: 0\n"
120121
" hardwarePredicate: { type: TruePred }\n"
121122
" problemPredicate: { type: TruePred }\n"
123+
" taskPredicate: { type: TruePred }\n"
122124
" debugKernel: false\n"
123125
" problemType:\n"
124126
" operationIdentifier: foo\n"

HostLibraryTests/sample_library.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ solutions:
2323

2424
hardwarePredicate: { type: TruePred }
2525
problemPredicate: { type: TruePred }
26+
taskPredicate: { type: TruePred }
2627
info: {}
2728
debugKernel: false
2829
index: 0
@@ -51,6 +52,7 @@ solutions:
5152

5253
hardwarePredicate: { type: TruePred }
5354
problemPredicate: { type: TruePred }
55+
taskPredicate: { type: TruePred }
5456
info: {}
5557
debugKernel: false
5658
index: 1
@@ -79,6 +81,7 @@ solutions:
7981

8082
hardwarePredicate: { type: TruePred }
8183
problemPredicate: { type: TruePred }
84+
taskPredicate: { type: TruePred }
8285
info: {}
8386
debugKernel: false
8487
index: 2
@@ -107,6 +110,7 @@ solutions:
107110

108111
hardwarePredicate: { type: TruePred }
109112
problemPredicate: { type: TruePred }
113+
taskPredicate: { type: TruePred }
110114
info: {}
111115
debugKernel: false
112116
index: 3

Tensile/Contractions.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
################################################################################
22
#
3-
# Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved.
3+
# Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
44
#
55
# Permission is hereby granted, free of charge, to any person obtaining a copy
66
# of this software and associated documentation files (the "Software"), to deal
@@ -354,9 +354,6 @@ def FromOriginalKeyPair(cls, pair):
354354

355355
return cls(tag, index=index, value=value)
356356

357-
if key == "_WorkspaceSizePerElemC" and value > 0:
358-
return cls("WorkspaceCheck", index=0, value=value)
359-
360357
if key.startswith('Assert'):
361358
raise RuntimeError("Unknown assertion key: {}".format(key))
362359

@@ -446,6 +443,19 @@ def FromOriginalState(cls, d, problemType, morePreds=[]):
446443
predicates = [p for p in map(cls.FromOriginalKeyPair, d.items()) if p is not None] + extraPreds
447444
return cls.And(predicates)
448445

446+
class TaskPredicate(Properties.Predicate):
    """Predicate built from an original solution state, kept separately from ProblemPredicate.

    The only key recognised here is ``_WorkspaceSizePerElemC``: a positive
    value produces a ``WorkspaceCheck`` predicate.
    """

    @classmethod
    def FromOriginalKeyPair(cls, pair):
        """Translate one ``(key, value)`` pair into a predicate.

        Returns ``cls("WorkspaceCheck")`` when the key is
        ``_WorkspaceSizePerElemC`` and its value is greater than zero;
        returns ``None`` for every other pair.
        """
        key, value = pair
        if key != "_WorkspaceSizePerElemC":
            return None
        if value > 0:
            return cls("WorkspaceCheck")
        return None

    @classmethod
    def FromOriginalState(cls, d, problemType, morePreds=[]):
        """Combine the predicates derived from every item of ``d`` with ``cls.And``.

        ``problemType`` and ``morePreds`` are accepted but not used by this
        implementation.
        """
        predicates = []
        for item in d.items():
            candidate = cls.FromOriginalKeyPair(item)
            # Pairs that do not map to a predicate come back as None and are dropped.
            if candidate is not None:
                predicates.append(candidate)
        return cls.And(predicates)

458+
449459
class SizeMapping:
450460
StateKeys = ['workGroup',
451461
'macroTile',
@@ -514,6 +524,7 @@ class Solution:
514524
'problemType',
515525
'hardwarePredicate',
516526
'problemPredicate',
527+
'taskPredicate',
517528
'sizeMapping',
518529
'debugKernel',
519530
'libraryLogicIndex',
@@ -537,6 +548,7 @@ def FromOriginalState(cls, d, deviceInfo=None):
537548
rv.problemType = ProblemType.FromOriginalState(d['ProblemType'])
538549

539550
rv.problemPredicate = ProblemPredicate.FromOriginalState(d, rv.problemType)
551+
rv.taskPredicate = TaskPredicate.FromOriginalState(d, rv.problemType)
540552

541553
if 'DebugKernel' in d:
542554
rv.debugKernel = d['DebugKernel']
@@ -579,6 +591,7 @@ def __init__(self, **kwargs):
579591
self.problemType = None
580592
self.hardwarePredicate = Hardware.HardwarePredicate('TruePred')
581593
self.problemPredicate = ProblemPredicate('TruePred')
594+
self.taskPredicate = TaskPredicate('TruePred')
582595
self.sizeMapping = None
583596
self.debugKernel = False
584597
self.libraryLogicIndex = {}

Tensile/Source/client/source/SolutionIterator.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
*
33
* MIT License
44
*
5-
* Copyright (C) 2020-2022 Advanced Micro Devices, Inc. All rights reserved.
5+
* Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved.
66
*
77
* Permission is hereby granted, free of charge, to any person obtaining a copy
88
* of this software and associated documentation files (the "Software"), to deal
@@ -88,15 +88,17 @@ namespace Tensile
8888

8989
// Test if the persistent kernel is eligible for the current hw and solution
9090
m_problem.checkPersistentKernelEligibility(solution, *m_hardware);
91-
m_problem.checkRequiredWorkspaceSize(solution, *m_hardware);
92-
if(!(*solution.problemPredicate)(m_problem))
91+
Task task(*m_hardware, m_problem, solution);
92+
if(!(*solution.problemPredicate)(m_problem) || !(*solution.taskPredicate)(task))
9393
{
9494
m_reporter->report(ResultKey::Validation, "DID_NOT_SATISFY_ASSERTS");
9595
if(m_reporter->logAtLevel(LogLevel::Verbose))
9696
{
9797
std::ostringstream msg;
9898
solution.problemPredicate->debugEval(m_problem, msg);
9999
msg << std::endl;
100+
solution.taskPredicate->debugEval(task, msg);
101+
msg << std::endl;
100102
m_reporter->log(LogLevel::Verbose, msg.str());
101103
}
102104

Tensile/Source/lib/include/Tensile/ContractionProblem.hpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
*
33
* MIT License
44
*
5-
* Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
5+
* Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
66
*
77
* Permission is hereby granted, free of charge, to any person obtaining a copy
88
* of this software and associated documentation files (the "Software"), to deal
@@ -820,19 +820,12 @@ namespace Tensile
820820

821821
void checkPersistentKernelEligibility(ContractionSolution const& solution,
822822
Hardware const& hardware);
823-
void checkRequiredWorkspaceSize(ContractionSolution const& solution,
824-
Hardware const& hardware);
825823

826824
bool getPersistentKernelEligibility() const
827825
{
828826
return m_eligibleForPK;
829827
}
830828

831-
size_t getRequiredWorkspaceSize() const
832-
{
833-
return m_requiredWorkspaceSize;
834-
}
835-
836829
private:
837830
TensorDescriptor m_a;
838831
TensorDescriptor m_b;
@@ -860,7 +853,6 @@ namespace Tensile
860853
bool m_fp16AltImpl = false;
861854
bool m_fp16AltImplRound = false;
862855
bool m_stochasticRounding = false;
863-
size_t m_requiredWorkspaceSize = 0;
864856
DataType m_f32XdlMathOp = DataType::Float;
865857
ArithmeticUnit m_arithmeticUnit = ArithmeticUnit::Any;
866858
KernelLanguage m_kernelLanguage = KernelLanguage::Any;

Tensile/Source/lib/include/Tensile/ContractionProblemPredicates.hpp

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
*
33
* MIT License
44
*
5-
* Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
5+
* Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
66
*
77
* Permission is hereby granted, free of charge, to any person obtaining a copy
88
* of this software and associated documentation files (the "Software"), to deal
@@ -1236,45 +1236,6 @@ namespace Tensile
12361236
}
12371237
};
12381238

1239-
/**
 * Predicate over a ContractionProblem that holds when
 * problem.getRequiredWorkspaceSize() does not exceed problem.workspaceSize().
 *
 * The index and value members are stored by the constructor but are not
 * consulted by the check itself.
 */
struct WorkspaceCheck : public Predicate_CRTP<WorkspaceCheck, ContractionProblem>
{
    enum
    {
        HasIndex = true,
        HasValue = true
    };

    size_t index;
    size_t value;

    WorkspaceCheck() = default;

    WorkspaceCheck(size_t index, size_t value)
        : index(index)
        , value(value)
    {
    }

    /// Type name of this predicate.
    static std::string Type()
    {
        return std::string("WorkspaceCheck");
    }

    /// True when the required workspace size is at most problem.workspaceSize().
    virtual bool operator()(ContractionProblem const& problem) const override
    {
        auto const required = problem.getRequiredWorkspaceSize();
        return required <= problem.workspaceSize();
    }

    /// Evaluates the predicate and writes both compared sizes and the result to stream.
    virtual bool debugEval(ContractionProblem const& problem,
                           std::ostream&             stream) const override
    {
        bool const satisfied = (*this)(problem);

        stream << *this << ": (" << problem.getRequiredWorkspaceSize();
        stream << " <= " << problem.workspaceSize() << ") == " << satisfied;

        return satisfied;
    }
};

1277-
12781239
struct PersistentKernelCheck
12791240
: public Predicate_CRTP<PersistentKernelCheck, ContractionProblem>
12801241
{

Tensile/Source/lib/include/Tensile/ContractionSolution.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
*
33
* MIT License
44
*
5-
* Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved.
5+
* Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
66
*
77
* Permission is hereby granted, free of charge, to any person obtaining a copy
88
* of this software and associated documentation files (the "Software"), to deal
@@ -37,6 +37,7 @@
3737
#include <Tensile/ContractionProblem_fwd.hpp>
3838
#include <Tensile/DataTypes.hpp>
3939
#include <Tensile/Predicates.hpp>
40+
#include <Tensile/Task.hpp>
4041
#include <Tensile/Utils.hpp>
4142

4243
namespace Tensile
@@ -324,6 +325,8 @@ namespace Tensile
324325
bool debugKernel = false;
325326
bool kernelArgsLog = false;
326327

328+
std::shared_ptr<Predicates::Predicate<Task>> taskPredicate
329+
= std::make_shared<Predicates::True<Task>>();
327330
std::shared_ptr<Predicates::Predicate<Problem>> problemPredicate
328331
= std::make_shared<Predicates::True<Problem>>();
329332
std::shared_ptr<Predicates::Predicate<Hardware>> hardwarePredicate

0 commit comments

Comments
 (0)