
Commit 1b7af54

Fix conversion issue in layernorm fusion (#1483) (#1493)
1 parent fe19455 commit 1b7af54

5 files changed: +41 -37 lines

src/targets/gpu/include/migraphx/gpu/hip.hpp

Lines changed: 3 additions & 3 deletions
@@ -105,7 +105,7 @@ struct hip_copy_to_gpu
     std::string name() const { return "hip::copy_to_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2);
+        check_shapes{inputs, *this}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
@@ -131,7 +131,7 @@ struct hip_copy_from_gpu
     std::string name() const { return "hip::copy_from_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2);
+        check_shapes{inputs, *this}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument
@@ -159,7 +159,7 @@ struct hip_copy
     std::string name() const { return "hip::copy"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2);
+        check_shapes{inputs, *this}.has(2).same_type();
         return inputs.at(1);
     }
     argument compute(context& ctx, const shape&, std::vector<argument> args) const
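
Not part of the diff, but for context: `.same_type()` tightens the shape check so that a copy whose inputs disagree on element type is rejected when shapes are computed, instead of the mismatch surfacing later as a silent conversion problem once the fused layernorm is allowed to change the output type (see prefuse_ops.cpp below). A minimal standalone sketch of that rule, using placeholder names (`dtype`, `require_same_type`) rather than the real `check_shapes` machinery:

// Standalone sketch only; dtype and require_same_type are illustrative names,
// not MIGraphX API -- the real check lives in migraphx::check_shapes.
#include <stdexcept>
#include <vector>

enum class dtype { float32, half };

// Mirrors the intent of check_shapes{...}.same_type(): every input to the copy
// must carry the same element type, otherwise compute_shape throws up front
// instead of deferring the mismatch to kernel launch time.
inline void require_same_type(const std::vector<dtype>& input_types)
{
    for(dtype t : input_types)
        if(t != input_types.front())
            throw std::runtime_error("hip::copy: inputs must have the same type");
}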

src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
 #include <migraphx/kernels/reduce.hpp>
 #include <migraphx/kernels/ops.hpp>
+#include <migraphx/kernels/vec.hpp>
 #include <migraphx/kernels/print.hpp>

 namespace migraphx {
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp

Lines changed: 0 additions & 32 deletions
@@ -33,38 +33,6 @@

 namespace migraphx {

-template <class T>
-struct implicit_conversion_op
-{
-    T x;
-
-    template <index_int N, class U>
-    constexpr operator vec<U, N>() const
-    {
-        if constexpr(vec_size<T>() == 0)
-        {
-            return x;
-        }
-        else
-        {
-            static_assert(vec_size<T>() == N, "Vector mismatch size");
-            return __builtin_convertvector(x, vec<U, N>);
-        }
-    }
-
-    template <class U>
-    constexpr operator U() const
-    {
-        return x;
-    }
-};
-
-template <class T>
-constexpr implicit_conversion_op<T> implicit_conversion(T x)
-{
-    return {x};
-}
-
 template <class F, class T, class... Ts>
 __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
 {

src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp

Lines changed: 32 additions & 0 deletions
@@ -185,5 +185,37 @@ constexpr auto vec_reduce(T x, Op op)
 }
 }

+template <class T>
+struct implicit_conversion_op
+{
+    T x;
+
+    template <index_int N, class U>
+    constexpr operator vec<U, N>() const
+    {
+        if constexpr(vec_size<T>() == 0)
+        {
+            return x;
+        }
+        else
+        {
+            static_assert(vec_size<T>() == N, "Vector mismatch size");
+            return __builtin_convertvector(x, vec<U, N>);
+        }
+    }
+
+    template <class U>
+    constexpr operator U() const
+    {
+        return x;
+    }
+};
+
+template <class T>
+constexpr implicit_conversion_op<T> implicit_conversion(T x)
+{
+    return {x};
+}
+
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
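
The block above is the helper moved out of pointwise.hpp (see the matching deletion earlier in this commit) so that layernorm.hpp can reach it through the new vec.hpp include. It wraps a value so that assigning it to either a scalar or a vector target converts element-wise; for vectors it relies on Clang's `__builtin_convertvector` and static_asserts that the lane counts agree. A standalone sketch of that builtin on Clang extended-vector types, independent of MIGraphX's `vec`/`vec_size` helpers (the `float4_t`/`half4_t` typedefs are illustrative and require a target with `_Float16` support):

// Standalone Clang sketch; float4_t and half4_t are illustrative typedefs,
// not the MIGraphX vec<T, N> alias.
typedef float    float4_t __attribute__((ext_vector_type(4)));
typedef _Float16 half4_t  __attribute__((ext_vector_type(4)));

half4_t narrow(float4_t x)
{
    // Element-wise float -> half conversion; both sides have 4 lanes, which is
    // the same-width requirement implicit_conversion_op enforces via static_assert.
    return __builtin_convertvector(x, half4_t);
}

In the kernels this would presumably be used as something like `vec<half, 4> y = implicit_conversion(x);`, letting the destination type drive the conversion.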

src/targets/gpu/prefuse_ops.cpp

Lines changed: 5 additions & 2 deletions
@@ -51,17 +51,20 @@ struct layernorm_base
         }
         check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
         auto s = inputs.at(0);
+        auto t = s.type();
+        if(not mods.empty())
+            t = mods.front()->get_output_shapes().front().type();
         if(s.scalar())
         {
             return s;
         }
         else if(s.broadcasted())
         {
-            return {s.type(), s.lens()};
+            return {t, s.lens()};
         }
         else
         {
-            return s.with_lens(s.lens());
+            return s.with_lens(t, s.lens());
         }
     }
 };
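
This is the piece that resolves the conversion issue itself: when the fused layernorm carries a sub-module (for example a trailing pointwise convert to half), the output shape now takes its type from that module's result instead of from the first input, so the fused op advertises the same element type the generated kernel actually produces. A rough standalone sketch of the rule, with placeholder types standing in for `migraphx::shape` and the module interface:

// Illustrative only; simple_shape and layernorm_output_shape are placeholders,
// not MIGraphX API.
#include <cstddef>
#include <optional>
#include <vector>

enum class elem_type { float32, half };

struct simple_shape
{
    elem_type type;
    std::vector<std::size_t> lens;
};

// The output keeps the input's dimensions but takes its element type from the
// fused sub-module's result when one exists (mods.front() in the real code),
// falling back to the input type when there is no sub-module.
simple_shape layernorm_output_shape(const simple_shape& input,
                                    const std::optional<elem_type>& fused_module_type)
{
    return {fused_module_type.value_or(input.type), input.lens};
}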
