141 changes: 118 additions & 23 deletions src/onnx/parse_layernorm.cpp
@@ -34,43 +34,53 @@ struct parse_layernorm : op_parser<parse_layernorm>
{
std::vector<op_desc> operators() const { return {{"LayerNormalization"}}; }

std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
static int64_t handle_axis(const onnx_parser& parser, const onnx_parser::node_info& info)
{
int64_t axis = -1;
if(contains(info.attributes, "axis"))
{
axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
}
return axis;
}

static float handle_epsilon(const onnx_parser& parser, const onnx_parser::node_info& info)
{
float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
return epsilon;
}

static bool handle_stash_type(const onnx_parser& parser, const onnx_parser::node_info& info)
{
bool stash_type = true;
if(contains(info.attributes, "stash_type"))
{
std::cerr << "WARNING: LAYERNORM does not support stash_type, it will be ignored.\n";
stash_type = (1 == parser.parse_value(info.attributes.at("stash_type")).at<int64_t>());
}
return stash_type;
}

if(args.size() < 2 or args.size() > 3)
{
MIGRAPHX_THROW("PARSE_LAYERNORM: invalid input count");
}
static void is_type_valid(const migraphx::shape::type_t& dtype, const std::string& var_name)
{
std::set<migraphx::shape::type_t> valid_types = {
migraphx::shape::float_type, migraphx::shape::bf16_type, migraphx::shape::half_type};

auto x = args.at(0);
auto scale = args.at(1);
bool skip_bias = args.size() == 2;
instruction_ref bias;
if(not skip_bias)
if(not(contains(valid_types, dtype)))
{
bias = args.at(2);
MIGRAPHX_THROW("PARSE_LAYERNORM: Invalid type for " + var_name);
}
}

static void check_x_input(const instruction_ref& x, const int64_t& axis)
{
auto x_shape = x->get_shape();
auto x_dtype = x_shape.type();
int64_t x_rank = x_shape.ndim();
is_type_valid(x_dtype, "input");

if(x_rank < 2)
{
@@ -83,31 +93,63 @@
{
MIGRAPHX_THROW("PARSE_LAYERNORM: invalid axis");
}
}

static std::tuple<instruction_ref, instruction_ref, instruction_ref>
stage_one_calculation(const onnx_parser::node_info& info,
const instruction_ref& input,
const float& epsilon,
const int64_t& axis,
const int64_t& kdims,
bool stash_type)
{
// y = (x - mean) * rsqrt(variance + epsilon) * scale + bias
// mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)

// axis can be negative
axis = axis < 0 ? axis + x_rank : axis;

auto kdims = x_rank - axis;
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), axis);
auto skipped_axes = x_rank - kdims;
auto x_shape = input->get_shape();
auto x_dtype = x_shape.type();

auto x = input;
if(stash_type and x_dtype != migraphx::shape::float_type)
{
x = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::float_type}}), input);
}

auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto x_sub_mean = info.add_common_op("sub", x, mean);
auto x_sqdiff_mean = info.add_common_op("sqdiff", x, mean);
auto variance =
info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_sqdiff_mean);
epsilon =
auto epsilon_val =
(x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon_val}});
auto var_eps = info.add_common_op("add", variance, eps);
auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto result = info.add_common_op("mul", x_sub_mean, rsqrt);

if(stash_type and x_dtype != migraphx::shape::float_type)
{
result = info.add_instruction(make_op("convert", {{"target_type", x_dtype}}), result);
}

return {result, mean, rsqrt};
}

static instruction_ref stage_two_calculation(const onnx_parser::node_info& info,
const instruction_ref& x,
const instruction_ref& scale,
const instruction_ref& bias,
const instruction_ref& result,
const int64_t& kdims,
bool skip_bias)
{
auto x_shape = x->get_shape();
auto x_rank = x_shape.ndim();
auto skipped_axes = x_rank - kdims;
instruction_ref scale_bcast = scale;
instruction_ref bias_bcast = bias;
if(skipped_axes > 0)
@@ -129,7 +171,60 @@ struct parse_layernorm : op_parser<parse_layernorm>
}
}
auto scaled = info.add_common_op("mul", result, scale_bcast);
auto y = skip_bias ? scaled : info.add_common_op("add", scaled, bias_bcast);
return skip_bias ? scaled : info.add_common_op("add", scaled, bias_bcast);
}

std::tuple<instruction_ref, instruction_ref, instruction_ref, bool>
handle_inputs(std::vector<instruction_ref>& args, const int64_t& axis) const
{
if(args.size() < 2 or args.size() > 3)
{
MIGRAPHX_THROW("PARSE_LAYERNORM: invalid input count");
}
auto x = args.at(0);
check_x_input(x, axis);

auto scale = args.at(1);
is_type_valid(scale->get_shape().type(), "scale");

bool skip_bias = args.size() == 2;
instruction_ref bias;
if(not skip_bias)
{
bias = args.at(2);
is_type_valid(bias->get_shape().type(), "bias");
}
return {x, scale, bias, skip_bias};
}

std::tuple<int64_t, float, bool> handle_attributes(const onnx_parser& parser,
const onnx_parser::node_info& info) const
{
auto axis = handle_axis(parser, info);
auto epsilon = handle_epsilon(parser, info);
auto stash_type = handle_stash_type(parser, info);

return {axis, epsilon, stash_type};
}

std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto [axis, epsilon, stash_type] = handle_attributes(parser, info);

auto [x, scale, bias, skip_bias] = handle_inputs(args, axis);

auto x_rank = x->get_shape().ndim();
// axis can be negative
axis = axis < 0 ? axis + x_rank : axis;
auto kdims = x_rank - axis;

auto [result, mean, rsqrt] =
stage_one_calculation(info, x, epsilon, axis, kdims, stash_type);
auto y = stage_two_calculation(info, x, scale, bias, result, kdims, skip_bias);

return {y, mean, rsqrt};
}
};
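For reference, here is a minimal NumPy sketch (illustrative only, not part of this change and not the MIGraphX implementation) of the decomposition the parser now emits: stage one normalizes the input, optionally stashing intermediates in float32 when stash_type is set, and stage two applies scale and bias.

```python
# Minimal NumPy sketch of the decomposition parse_layernorm emits (illustrative).
# stash_type=True mirrors the new behaviour: statistics are computed in float32
# and the normalized result is converted back to the input dtype.
import numpy as np

def layernorm_reference(x, scale, bias=None, axis=-1, epsilon=1e-5, stash_type=True):
    axis = axis + x.ndim if axis < 0 else axis
    axes = tuple(range(axis, x.ndim))                 # {D1, D2, ... Dk}

    x_dtype = x.dtype
    xs = x.astype(np.float32) if (stash_type and x_dtype != np.float32) else x

    # the parser floors epsilon at 1e-7 for fp16 inputs to avoid underflow
    if x_dtype == np.float16 and abs(epsilon) < 1e-7:
        epsilon = 1e-7

    mean = xs.mean(axis=axes, keepdims=True)
    variance = ((xs - mean) ** 2).mean(axis=axes, keepdims=True)
    rsqrt = 1.0 / np.sqrt(variance + epsilon)
    result = (xs - mean) * rsqrt                      # stage one

    if stash_type and x_dtype != np.float32:
        result = result.astype(x_dtype)

    y = result * scale if bias is None else result * scale + bias  # stage two
    return y, mean, rsqrt
```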
53 changes: 48 additions & 5 deletions test/onnx/gen_onnx.py
@@ -7495,21 +7495,40 @@ def make_layer_norm(shape,
axis=-1,
dtype=TensorProto.FLOAT,
scale_shape=None,
bias_shape=None):
bias_shape=None,
stash_type=None,
epsilon=None,
scale_type=None,
bias_type=None):

if scale_type is None:
scale_type = dtype

if bias_type is None:
bias_type = dtype

norm_axis = axis + len(shape) if axis < 0 else axis
x = helper.make_tensor_value_info('x', dtype, shape)
if scale_shape is None:
scale_shape = shape[norm_axis:]
if bias_shape is None:
bias_shape = shape[norm_axis:]
scale = helper.make_tensor_value_info('scale', dtype, scale_shape)
bias = helper.make_tensor_value_info('bias', dtype, bias_shape)
scale = helper.make_tensor_value_info('scale', scale_type, scale_shape)
bias = helper.make_tensor_value_info('bias', bias_type, bias_shape)
y = helper.make_tensor_value_info('y', dtype, shape)

node = onnx.helper.make_node('LayerNormalization',
inputs=['x', 'scale', 'bias'],
outputs=['y'],
axis=axis)
outputs=['y'])

# Attributes
node.attribute.append(onnx.helper.make_attribute("axis", axis))

if stash_type is not None:
node.attribute.append(onnx.helper.make_attribute("stash_type", stash_type))

if epsilon is not None:
node.attribute.append(onnx.helper.make_attribute("epsilon", epsilon))

return ([node], [x, scale, bias], [y])

@@ -7545,12 +7564,36 @@ def layer_norm_3d_scale_bias_test():
scale_shape=[2, 1, 7],
bias_shape=[2, 1, 7])

@onnx_test()
def layer_norm_3d_invalid_int8_test():
return make_layer_norm([1, 4, 2], -1, TensorProto.INT8)


@onnx_test()
def layer_norm_3d_invalid_scale_test():
return make_layer_norm([1, 4, 2], -1, scale_type=TensorProto.INT8)


@onnx_test()
def layer_norm_3d_invalid_bias_test():
return make_layer_norm([1, 4, 2], -1, bias_type=TensorProto.INT8)


@onnx_test()
def layer_norm_3d_half_test():
return make_layer_norm([1, 4, 2], -1, TensorProto.FLOAT16)


@onnx_test()
def layer_norm_3d_half_stash_off_test():
return make_layer_norm([1, 4, 2], -1, TensorProto.FLOAT16, stash_type=int(0))


@onnx_test()
def layer_norm_3d_half_stash_off_epsilon_test():
return make_layer_norm([1, 4, 2], -1, TensorProto.FLOAT16, stash_type=int(0), epsilon=float(1e-4))


@onnx_test()
def layer_norm_3d_bf16_test():
return make_layer_norm([1, 4, 2], -1, TensorProto.BFLOAT16)
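A quick, hypothetical way to sanity-check the new make_layer_norm keyword arguments (not part of the test suite) is to build a node and inspect the attributes it now carries:

```python
# Hypothetical usage sketch; assumes the helpers in gen_onnx.py are importable.
from onnx import TensorProto, helper

nodes, inputs, outputs = make_layer_norm([1, 4, 2], -1, TensorProto.FLOAT16,
                                         stash_type=0, epsilon=1e-4)
print([(a.name, helper.get_attribute_value(a)) for a in nodes[0].attribute])
# expected, roughly: [('axis', -1), ('stash_type', 0), ('epsilon', ~1e-04)]
```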
15 changes: 15 additions & 0 deletions test/onnx/include/onnx_test_utils.hpp
@@ -342,6 +342,7 @@ make_layer_norm(const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& reduce_axes,
size_t skipped_axis,
bool skip_bias = false,
const bool stash_type = true,
const float eps_value = 1e-5f,
const migraphx::shape::type_t dtype = migraphx::shape::float_type)
{
@@ -354,6 +355,13 @@
{
bias = mm->add_parameter("bias", {dtype, scale_bias_shape});
}

if(stash_type and dtype != migraphx::shape::float_type)
{
x = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x);
}

auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}});
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x);
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
@@ -363,6 +371,13 @@
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), {var_eps});
auto result = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, rsqrt});

if(stash_type and dtype != migraphx::shape::float_type)
{
result =
mm->add_instruction(migraphx::make_op("convert", {{"target_type", dtype}}), result);
}

migraphx::instruction_ref scale_bcast = scale;
migraphx::instruction_ref bias_bcast = bias;
if(skipped_axis > 0)
Binary file added test/onnx/layer_norm_3d_half_stash_off_test.onnx
24 changes: 24 additions & 0 deletions test/onnx/layer_norm_3d_invalid_bias_test.onnx
Binary ONNX protobuf for the layer_norm_3d_invalid_bias_test case in gen_onnx.py; raw dump omitted.
24 changes: 24 additions & 0 deletions test/onnx/layer_norm_3d_invalid_int8_test.onnx
Binary ONNX protobuf for the layer_norm_3d_invalid_int8_test case in gen_onnx.py; raw dump omitted.
24 changes: 24 additions & 0 deletions test/onnx/layer_norm_3d_invalid_scale_test.onnx
Binary ONNX protobuf for the layer_norm_3d_invalid_scale_test case in gen_onnx.py; raw dump omitted.