Skip to content

Commit 19f3cc3

Browse files
reed-lau and hwu36 authored
Fix uint128 operator add (#1400)
* fix uint128 operator add for 64-bit hilo implementation * add uint128 test for operator add * make clang happy --------- Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
1 parent f9ece1b commit 19f3cc3

File tree

3 files changed

+123
-1
lines changed

3 files changed

+123
-1
lines changed

include/cutlass/uint128.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ struct alignas(16) uint128_t
138138
y.native = native + rhs.native;
139139
#else
140140
y.hilo_.lo = hilo_.lo + rhs.hilo_.lo;
141-
y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (!y.hilo_.lo && (rhs.hilo_.lo));
141+
y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (y.hilo_.lo < hilo_.lo);
142142
#endif
143143
return y;
144144
}

test/unit/core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ cutlass_test_unit_add_executable(
3434
float8.cu
3535
tfloat32.cu
3636
complex.cu
37+
uint128.cu
3738
quaternion.cu
3839
matrix.cu
3940
predicate_vector.cu

test/unit/core/uint128.cu

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/***************************************************************************************************
2+
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
/*! \file
32+
\brief Tests for basic uint128 functionality
33+
*/
34+
35+
#include "../common/cutlass_unit_test.h"
36+
37+
#include "cutlass/array.h"
38+
#include "cutlass/layout/matrix.h"
39+
#include "cutlass/numeric_types.h"
40+
#include "cutlass/numeric_conversion.h"
41+
#include "cutlass/util/device_memory.h"
42+
#include "cutlass/util/host_tensor.h"
43+
44+
45+
/////////////////////////////////////////////////////////////////////////////////////////////////
46+
//
47+
// Host
48+
//
49+
/////////////////////////////////////////////////////////////////////////////////////////////////
50+
51+
/// Host-side checks of cutlass::uint128_t::operator+.
///
/// Two cases are covered:
///   1. sums of small values, compared through the conversion to uint64_t;
///   2. sums whose low 64-bit word wraps around, so a carry of exactly one must
///      appear in the high word. Both operand orders are checked.
TEST(uint128_t, host_arithmetic) {
  using T = cutlass::uint128_t;

  // only low 64bit
  for (uint64_t i = 0; i < 1024; ++i) {
    for (uint64_t j = 0; j < 1024; ++j) {
      T x = i;
      T y = j;

      // EXPECT_EQ prints both operands when the comparison fails, unlike
      // EXPECT_TRUE(a == b) which only reports "false".
      EXPECT_EQ(static_cast<uint64_t>(x + y), i + j);
    }
  }

  // carry overflow for low uint64_t
  {
    T x = static_cast<uint64_t>(0xFFFFFFFFFFFFFFFF);

    for (uint64_t i = 0; i < 1024; ++i) {
      T y = i + 1;

      // (2^64 - 1) + (i + 1) == 2^64 + i  ->  hi == 1, lo == i
      T z = x + y;

      EXPECT_EQ(z.hilo_.hi, static_cast<uint64_t>(0x1));
      EXPECT_EQ(z.hilo_.lo, i);

      // Same sum with the operands swapped: the carry must not depend on
      // which side holds the value that causes the wrap-around.
      T w = y + x;

      EXPECT_EQ(w.hilo_.hi, static_cast<uint64_t>(0x1));
      EXPECT_EQ(w.hilo_.lo, i);
    }
  }
}
77+
78+
/////////////////////////////////////////////////////////////////////////////////////////////////
79+
//
80+
// Device
81+
//
82+
/////////////////////////////////////////////////////////////////////////////////////////////////
83+
84+
/// Element-wise addition kernel: output[k] = input[k] + base.
///
/// Each thread handles the single element at its global index
/// (blockIdx.x * blockDim.x + threadIdx.x); threads whose index is >= N do nothing.
__global__ void uint128_add_operator(cutlass::uint128_t *output, cutlass::uint128_t const *input, cutlass::uint128_t base, int N) {
  int const idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Guard the tail: the launch may provide more threads than elements.
  if (idx >= N) {
    return;
  }

  output[idx] = input[idx] + base;
}
90+
91+
/// Device-side check of cutlass::uint128_t::operator+.
///
/// Launches uint128_add_operator over N elements holding 1..N and adds
/// 0xFFFFFFFFFFFFFFFF to each, so every low word wraps around and every result
/// must have hi == 1 and lo == i.
TEST(uint128_t, device_arithmetic) {
  using T = cutlass::uint128_t;

  int const N = 1024;

  cutlass::HostTensor<T, cutlass::layout::RowMajor> input({N, 1});
  cutlass::HostTensor<T, cutlass::layout::RowMajor> sum({N, 1});

  for (int i = 0; i < N; ++i) {
    input.at({i, 0}) = static_cast<uint64_t>(i + 1);
  }

  T b = static_cast<uint64_t>(0xFFFFFFFFFFFFFFFF);

  input.sync_device();

  // One block of N threads; the kernel bounds-checks against N.
  uint128_add_operator<<< dim3(1,1), dim3(N, 1) >>>(sum.device_data(), input.device_data(), b, N);

  // cudaGetLastError() only reports launch-time problems (e.g. a bad configuration).
  ASSERT_EQ(cudaGetLastError(), cudaSuccess) << "Kernel launch error.";

  // Faults raised while the kernel runs are asynchronous; synchronize so they are
  // reported here rather than surfacing from the copy back to the host.
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess) << "Kernel execution error.";

  sum.sync_host();

  for (int i = 0; i < N; ++i) {
    T got = sum.at({i, 0});

    // (i + 1) + (2^64 - 1) == 2^64 + i
    uint64_t expected_hi = static_cast<uint64_t>(0x1);
    uint64_t expected_lo = static_cast<uint64_t>(i);

    EXPECT_EQ(got.hilo_.hi, expected_hi);
    EXPECT_EQ(got.hilo_.lo, expected_lo);
  }
}

0 commit comments

Comments
 (0)