Skip to content

Commit e3f4822

Browse files
authored
Merge pull request #791 from aarongreig/aaron/expandBufferFillTests
Add some parameterization to MemBufferFill tests.
2 parents 404afa8 + 4a02d7b commit e3f4822

File tree

4 files changed

+144
-70
lines changed

4 files changed

+144
-70
lines changed

test/conformance/enqueue/urEnqueueMemBufferFill.cpp

Lines changed: 122 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,90 +4,172 @@
44
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55
#include <uur/fixtures.h>
66

7-
using urEnqueueMemBufferFillTest = uur::urMemBufferQueueTest;
8-
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urEnqueueMemBufferFillTest);
7+
struct testParametersFill {
8+
size_t size;
9+
size_t pattern_size;
10+
};
11+
12+
template <typename T>
13+
inline std::string
14+
printFillTestString(const testing::TestParamInfo<typename T::ParamType> &info) {
15+
const auto device_handle = std::get<0>(info.param);
16+
const auto platform_device_name =
17+
uur::GetPlatformAndDeviceName(device_handle);
18+
std::stringstream test_name;
19+
test_name << platform_device_name << "__size__"
20+
<< std::get<1>(info.param).size << "__patternSize__"
21+
<< std::get<1>(info.param).pattern_size;
22+
return test_name.str();
23+
}
24+
25+
struct urEnqueueMemBufferFillTest
26+
: uur::urQueueTestWithParam<testParametersFill> {
27+
void SetUp() override {
28+
UUR_RETURN_ON_FATAL_FAILURE(
29+
urQueueTestWithParam<testParametersFill>::SetUp());
30+
size = std::get<1>(GetParam()).size;
31+
pattern_size = std::get<1>(GetParam()).pattern_size;
32+
pattern = std::vector<uint8_t>(pattern_size);
33+
uur::generateMemFillPattern(pattern);
34+
ASSERT_SUCCESS(urMemBufferCreate(this->context, UR_MEM_FLAG_READ_WRITE,
35+
size, nullptr, &buffer));
36+
}
37+
38+
void TearDown() override {
39+
if (buffer) {
40+
EXPECT_SUCCESS(urMemRelease(buffer));
41+
}
42+
UUR_RETURN_ON_FATAL_FAILURE(
43+
urQueueTestWithParam<testParametersFill>::TearDown());
44+
}
45+
46+
void verifyData(std::vector<uint8_t> &output, size_t verify_size) {
47+
size_t pattern_index = 0;
48+
for (size_t i = 0; i < verify_size; ++i) {
49+
ASSERT_EQ(output[i], pattern[pattern_index])
50+
<< "Result mismatch at index: " << i;
51+
52+
++pattern_index;
53+
if (pattern_index % pattern_size == 0) {
54+
pattern_index = 0;
55+
}
56+
}
57+
}
58+
59+
ur_mem_handle_t buffer = nullptr;
60+
std::vector<uint8_t> pattern;
61+
size_t size;
62+
size_t pattern_size;
63+
};
64+
65+
static std::vector<testParametersFill> test_cases{
66+
/* Everything set to 1 */
67+
{1, 1},
68+
/* pattern_size == size */
69+
{256, 256},
70+
/* pattern_size < size */
71+
{1024, 256},
72+
/* pattern sizes corresponding to some common scalar and vector types */
73+
{256, 4},
74+
{256, 8},
75+
{256, 16},
76+
{256, 32}};
77+
78+
UUR_TEST_SUITE_P(urEnqueueMemBufferFillTest, testing::ValuesIn(test_cases),
79+
printFillTestString<urEnqueueMemBufferFillTest>);
980

1081
TEST_P(urEnqueueMemBufferFillTest, Success) {
11-
const uint32_t pattern = 0xdeadbeef;
12-
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, &pattern,
13-
sizeof(pattern), 0, size, 0, nullptr,
82+
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, pattern.data(),
83+
pattern_size, 0, size, 0, nullptr,
1484
nullptr));
15-
std::vector<uint32_t> output(count, 1);
85+
std::vector<uint8_t> output(size, 1);
1686
ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, buffer, true, 0, size,
1787
output.data(), 0, nullptr, nullptr));
18-
for (unsigned i = 0; i < count; ++i) {
19-
ASSERT_EQ(output[i], pattern) << "Result mismatch at index: " << i;
20-
}
88+
verifyData(output, size);
2189
}
22-
2390
TEST_P(urEnqueueMemBufferFillTest, SuccessPartialFill) {
24-
const std::vector<uint32_t> input(count, 42);
91+
if (size == 1) {
92+
// Can't partially fill one byte
93+
GTEST_SKIP();
94+
}
95+
const std::vector<uint8_t> input(size, 0);
2596
ASSERT_SUCCESS(urEnqueueMemBufferWrite(queue, buffer, true, 0, size,
2697
input.data(), 0, nullptr, nullptr));
27-
const uint32_t pattern = 0xdeadbeef;
2898
const size_t partial_fill_size = size / 2;
29-
const size_t fill_count = count / 2;
30-
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, &pattern,
31-
sizeof(pattern), 0, partial_fill_size,
32-
0, nullptr, nullptr));
33-
std::vector<uint32_t> output(count, 1);
99+
// Make sure we don't end up with pattern_size > size
100+
pattern_size = pattern_size / 2;
101+
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, pattern.data(),
102+
pattern_size, 0, partial_fill_size, 0,
103+
nullptr, nullptr));
104+
std::vector<uint8_t> output(size, 1);
34105
ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, buffer, true, 0, size,
35106
output.data(), 0, nullptr, nullptr));
36-
for (size_t i = 0; i < count - fill_count; ++i) {
37-
ASSERT_EQ(output[i], pattern) << "Result mismatch at index: " << i;
38-
}
107+
// Check the first half matches the pattern and the second half remains untouched.
108+
verifyData(output, partial_fill_size);
39109

40-
for (size_t i = fill_count; i < count; ++i) {
41-
ASSERT_EQ(output[i], 42) << "Result mismatch at index: " << i;
110+
for (size_t i = partial_fill_size; i < size; ++i) {
111+
ASSERT_EQ(output[i], input[i]) << "Result mismatch at index: " << i;
42112
}
43113
}
44114

45115
TEST_P(urEnqueueMemBufferFillTest, SuccessOffset) {
46-
const std::vector<uint32_t> input(count, 42);
116+
if (size == 1) {
117+
// No room for an offset
118+
GTEST_SKIP();
119+
}
120+
const std::vector<uint8_t> input(size, 0);
47121
ASSERT_SUCCESS(urEnqueueMemBufferWrite(queue, buffer, true, 0, size,
48122
input.data(), 0, nullptr, nullptr));
49-
const uint32_t pattern = 0xdeadbeef;
123+
50124
const size_t offset_size = size / 2;
51-
const size_t offset_count = count / 2;
52-
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, &pattern,
53-
sizeof(pattern), offset_size,
125+
// Make sure we don't end up with pattern_size > size
126+
pattern_size = pattern_size / 2;
127+
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, pattern.data(),
128+
pattern_size, offset_size,
54129
offset_size, 0, nullptr, nullptr));
55-
std::vector<uint32_t> output(count, 1);
56-
ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, buffer, true, 0, size,
57-
output.data(), 0, nullptr, nullptr));
58-
for (size_t i = 0; i < offset_count; ++i) {
59-
ASSERT_EQ(output[i], 42) << "Result mismatch at index: " << i;
60-
}
61130

62-
for (size_t i = offset_count; i < count; ++i) {
63-
ASSERT_EQ(output[i], pattern) << "Result mismatch at index: " << i;
131+
// Check the second half matches the pattern and the first half remains untouched.
132+
std::vector<uint8_t> output(offset_size);
133+
ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, buffer, true, offset_size,
134+
offset_size, output.data(), 0,
135+
nullptr, nullptr));
136+
verifyData(output, offset_size);
137+
138+
ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, buffer, true, 0, offset_size,
139+
output.data(), 0, nullptr, nullptr));
140+
for (size_t i = 0; i < offset_size; ++i) {
141+
ASSERT_EQ(output[i], input[i]) << "Result mismatch at index: " << i;
64142
}
65143
}
66144

67-
TEST_P(urEnqueueMemBufferFillTest, InvalidNullHandleQueue) {
145+
using urEnqueueMemBufferFillNegativeTest = uur::urMemBufferQueueTest;
146+
147+
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urEnqueueMemBufferFillNegativeTest);
148+
149+
TEST_P(urEnqueueMemBufferFillNegativeTest, InvalidNullHandleQueue) {
68150
const uint32_t pattern = 0xdeadbeef;
69151
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE,
70152
urEnqueueMemBufferFill(nullptr, buffer, &pattern,
71153
sizeof(pattern), 0, size, 0,
72154
nullptr, nullptr));
73155
}
74156

75-
TEST_P(urEnqueueMemBufferFillTest, InvalidNullHandleBuffer) {
157+
TEST_P(urEnqueueMemBufferFillNegativeTest, InvalidNullHandleBuffer) {
76158
const uint32_t pattern = 0xdeadbeef;
77159
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE,
78160
urEnqueueMemBufferFill(queue, nullptr, &pattern,
79161
sizeof(pattern), 0, size, 0,
80162
nullptr, nullptr));
81163
}
82164

83-
TEST_P(urEnqueueMemBufferFillTest, InvalidNullHandlePointerPattern) {
165+
TEST_P(urEnqueueMemBufferFillNegativeTest, InvalidNullHandlePointerPattern) {
84166
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_POINTER,
85167
urEnqueueMemBufferFill(queue, buffer, nullptr,
86168
sizeof(uint32_t), 0, size, 0,
87169
nullptr, nullptr));
88170
}
89171

90-
TEST_P(urEnqueueMemBufferFillTest, InvalidNullPtrEventWaitList) {
172+
TEST_P(urEnqueueMemBufferFillNegativeTest, InvalidNullPtrEventWaitList) {
91173
const uint32_t pattern = 0xdeadbeef;
92174
ASSERT_EQ_RESULT(urEnqueueMemBufferFill(queue, buffer, &pattern,
93175
sizeof(uint32_t), 0, size, 1,
@@ -103,7 +185,7 @@ TEST_P(urEnqueueMemBufferFillTest, InvalidNullPtrEventWaitList) {
103185
UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
104186
}
105187

106-
TEST_P(urEnqueueMemBufferFillTest, InvalidSize) {
188+
TEST_P(urEnqueueMemBufferFillNegativeTest, InvalidSize) {
107189
const uint32_t pattern = 0xdeadbeef;
108190
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_SIZE,
109191
urEnqueueMemBufferFill(queue, buffer, &pattern,

test/conformance/enqueue/urEnqueueUSMFill.cpp

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// See LICENSE.TXT
44
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55

6-
#include <random>
76
#include <uur/fixtures.h>
87

98
struct testParametersFill {
@@ -34,7 +33,7 @@ struct urEnqueueUSMFillTestWithParam
3433
host_mem = std::vector<uint8_t>(size);
3534
pattern_size = std::get<1>(GetParam()).pattern_size;
3635
pattern = std::vector<uint8_t>(pattern_size);
37-
generatePattern();
36+
uur::generateMemFillPattern(pattern);
3837

3938
ur_device_usm_access_capability_flags_t device_usm = 0;
4039
ASSERT_SUCCESS(uur::GetDeviceUSMDeviceSupport(device, device_usm));
@@ -54,19 +53,6 @@ struct urEnqueueUSMFillTestWithParam
5453
UUR_RETURN_ON_FATAL_FAILURE(urQueueTestWithParam::TearDown());
5554
}
5655

57-
void generatePattern() {
58-
59-
const size_t seed = 1;
60-
std::mt19937 mersenne_engine{seed};
61-
std::uniform_int_distribution<int> dist{0, 255};
62-
63-
auto gen = [&dist, &mersenne_engine]() {
64-
return static_cast<uint8_t>(dist(mersenne_engine));
65-
};
66-
67-
std::generate(begin(pattern), end(pattern), gen);
68-
}
69-
7056
void verifyData() {
7157
ASSERT_SUCCESS(urEnqueueUSMMemcpy(queue, true, host_mem.data(), ptr,
7258
size, 0, nullptr, nullptr));
@@ -98,7 +84,11 @@ static std::vector<testParametersFill> test_cases{
9884
{256, 256},
9985
/* pattern_size < size */
10086
{1024, 256},
101-
};
87+
/* pattern sizes corresponding to some common scalar and vector types */
88+
{256, 4},
89+
{256, 8},
90+
{256, 16},
91+
{256, 32}};
10292

10393
UUR_TEST_SUITE_P(urEnqueueUSMFillTestWithParam, testing::ValuesIn(test_cases),
10494
printFillTestString<urEnqueueUSMFillTestWithParam>);

test/conformance/enqueue/urEnqueueUSMFill2D.cpp

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct urEnqueueUSMFill2DTestWithParam
3838
height = std::get<1>(GetParam()).height;
3939
pattern_size = std::get<1>(GetParam()).pattern_size;
4040
pattern = std::vector<uint8_t>(pattern_size);
41-
generatePattern();
41+
uur::generateMemFillPattern(pattern);
4242
allocation_size = pitch * height;
4343
host_mem = std::vector<uint8_t>(allocation_size);
4444

@@ -60,19 +60,6 @@ struct urEnqueueUSMFill2DTestWithParam
6060
UUR_RETURN_ON_FATAL_FAILURE(urQueueTestWithParam::TearDown());
6161
}
6262

63-
void generatePattern() {
64-
65-
const size_t seed = 1;
66-
std::mt19937 mersenne_engine{seed};
67-
std::uniform_int_distribution<int> dist{0, 255};
68-
69-
auto gen = [&dist, &mersenne_engine]() {
70-
return static_cast<uint8_t>(dist(mersenne_engine));
71-
};
72-
73-
std::generate(begin(pattern), end(pattern), gen);
74-
}
75-
7663
void verifyData() {
7764
ASSERT_SUCCESS(urEnqueueUSMMemcpy2D(queue, true, host_mem.data(), pitch,
7865
ptr, pitch, width, height, 0,

test/conformance/testing/include/uur/fixtures.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include <uur/environment.h>
1212
#include <uur/utils.h>
1313

14+
#include <random>
15+
1416
#define UUR_RETURN_ON_FATAL_FAILURE(...) \
1517
__VA_ARGS__; \
1618
if (this->HasFatalFailure() || this->IsSkipped()) { \
@@ -873,6 +875,19 @@ struct urUSMDeviceAllocTestWithParam : urQueueTestWithParam<T> {
873875
ur_usm_pool_handle_t pool = nullptr;
874876
};
875877

878+
// Generates a random byte pattern for MemFill type entry-points.
879+
inline void generateMemFillPattern(std::vector<uint8_t> &pattern) {
880+
const size_t seed = 1;
881+
std::mt19937 mersenne_engine{seed};
882+
std::uniform_int_distribution<int> dist{0, 255};
883+
884+
auto gen = [&dist, &mersenne_engine]() {
885+
return static_cast<uint8_t>(dist(mersenne_engine));
886+
};
887+
888+
std::generate(begin(pattern), end(pattern), gen);
889+
}
890+
876891
/// @brief
877892
/// @tparam T
878893
/// @param info

0 commit comments

Comments
 (0)