Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/onnx/parse_layernorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ struct parse_layernorm : op_parser<parse_layernorm>
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
bool stash_type = true;
if(contains(info.attributes, "stash_type"))
{
std::cerr << "WARNING: LAYERNORM does not support stash_type, it will be ignored.\n";
stash_type = (1 == parser.parse_value(info.attributes.at("stash_type")).at<int64_t>());
}

if(args.size() < 2 or args.size() > 3)
Expand All @@ -70,6 +71,17 @@ struct parse_layernorm : op_parser<parse_layernorm>

auto x_shape = x->get_shape();
auto x_dtype = x_shape.type();


std::set<migraphx::shape::type_t> supported_x_types = {migraphx::shape::float_type,
migraphx::shape::bf16_type,
migraphx::shape::half_type};

if(not(contains(supported_x_types, x_dtype)))
{
MIGRAPHX_THROW("PARSE_LAYERNORM: Invalid type for input");
}

int64_t x_rank = x_shape.ndim();

if(x_rank < 2)
Expand All @@ -96,6 +108,12 @@ struct parse_layernorm : op_parser<parse_layernorm>
std::iota(axes.begin(), axes.end(), axis);
auto skipped_axes = x_rank - kdims;

if(stash_type and x_dtype != migraphx::shape::float_type)
{
x = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::float_type}}), x);
}

auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto x_sub_mean = info.add_common_op("sub", x, mean);
auto x_sqdiff_mean = info.add_common_op("sqdiff", x, mean);
Expand All @@ -108,6 +126,17 @@ struct parse_layernorm : op_parser<parse_layernorm>
auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto result = info.add_common_op("mul", x_sub_mean, rsqrt);

if(stash_type and x_dtype != migraphx::shape::float_type)
{
result = info.add_instruction(make_op("convert", {{"target_type", x_dtype}}), result);
}

if(stash_type and x_dtype == migraphx::shape::bf16_type)
{
mean = info.add_instruction(make_op("convert", {{"target_type", x_dtype}}), mean);
rsqrt = info.add_instruction(make_op("convert", {{"target_type", x_dtype}}), rsqrt);
}

instruction_ref scale_bcast = scale;
instruction_ref bias_bcast = bias;
if(skipped_axes > 0)
Expand All @@ -130,6 +159,7 @@ struct parse_layernorm : op_parser<parse_layernorm>
}
auto scaled = info.add_common_op("mul", result, scale_bcast);
auto y = skip_bias ? scaled : info.add_common_op("add", scaled, bias_bcast);

return {y, mean, rsqrt};
}
};
Expand Down
17 changes: 17 additions & 0 deletions test/onnx/include/onnx_test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ make_layer_norm(const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& reduce_axes,
size_t skipped_axis,
bool skip_bias = false,
const bool stash_type = true,
const float eps_value = 1e-5f,
const migraphx::shape::type_t dtype = migraphx::shape::float_type)
{
Expand All @@ -354,6 +355,13 @@ make_layer_norm(const std::vector<int64_t>& input_shape,
{
bias = mm->add_parameter("bias", {dtype, scale_bias_shape});
}

if(stash_type and dtype != migraphx::shape::float_type)
{
x = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x);
}

auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}});
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x);
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
Expand All @@ -363,6 +371,15 @@ make_layer_norm(const std::vector<int64_t>& input_shape,
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), {var_eps});
auto result = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, rsqrt});

if(stash_type and dtype != migraphx::shape::float_type)
{
result =
mm->add_instruction(migraphx::make_op("convert", {{"target_type", dtype}}), result);
mean = mm->add_instruction(migraphx::make_op("convert", {{"target_type", dtype}}), mean);
rsqrt = mm->add_instruction(migraphx::make_op("convert", {{"target_type", dtype}}), rsqrt);
}

migraphx::instruction_ref scale_bcast = scale;
migraphx::instruction_ref bias_bcast = bias;
if(skipped_axis > 0)
Expand Down
2 changes: 1 addition & 1 deletion test/onnx/parse/layer_norm_3d_bf16_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
TEST_CASE(layer_norm_3d_bf16_test)
{
migraphx::program p =
make_layer_norm({1, 4, 2}, {2}, {2}, 2, false, 1e-5f, migraphx::shape::bf16_type);
make_layer_norm({1, 4, 2}, {2}, {2}, 2, false, true, 1e-5f, migraphx::shape::bf16_type);

auto prog = optimize_onnx("layer_norm_3d_bf16_test.onnx");
EXPECT(p == prog);
Expand Down
4 changes: 2 additions & 2 deletions test/onnx/parse/layer_norm_3d_half_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,7 +28,7 @@
TEST_CASE(layer_norm_3d_half_test)
{
migraphx::program p =
make_layer_norm({1, 4, 2}, {2}, {2}, 2, false, 1e-5f, migraphx::shape::half_type);
make_layer_norm({1, 4, 2}, {2}, {2}, 2, false, true, 1e-5f, migraphx::shape::half_type);

auto prog = optimize_onnx("layer_norm_3d_half_test.onnx");
EXPECT(p == prog);
Expand Down
2 changes: 1 addition & 1 deletion test/onnx/parse/layer_norm_4d_bf16_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
TEST_CASE(layer_norm_4d_bf16_test)
{
migraphx::program p =
make_layer_norm({3, 3, 3, 3}, {3}, {3}, 3, false, 1e-5f, migraphx::shape::bf16_type);
make_layer_norm({3, 3, 3, 3}, {3}, {3}, 3, false, true, 1e-5f, migraphx::shape::bf16_type);

auto prog = optimize_onnx("layer_norm_4d_bf16_test.onnx");
EXPECT(p == prog);
Expand Down
4 changes: 2 additions & 2 deletions test/onnx/parse/layer_norm_4d_half_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,7 +28,7 @@
TEST_CASE(layer_norm_4d_half_test)
{
migraphx::program p =
make_layer_norm({3, 3, 3, 3}, {3}, {3}, 3, false, 1e-5f, migraphx::shape::half_type);
make_layer_norm({3, 3, 3, 3}, {3}, {3}, 3, false, true, 1e-5f, migraphx::shape::half_type);

auto prog = optimize_onnx("layer_norm_4d_half_test.onnx");
EXPECT(p == prog);
Expand Down
2 changes: 1 addition & 1 deletion test/onnx/parse/layer_norm_small_eps_bf16_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
TEST_CASE(layer_norm_small_eps_bf16_test)
{
migraphx::program p =
make_layer_norm({1, 2}, {2}, {1}, 1, true, 1e-7, migraphx::shape::bf16_type);
make_layer_norm({1, 2}, {2}, {1}, 1, true, true, 1e-7, migraphx::shape::bf16_type);

auto prog = optimize_onnx("layer_norm_small_eps_bf16_test.onnx");
EXPECT(p == prog);
Expand Down
4 changes: 2 additions & 2 deletions test/onnx/parse/layer_norm_small_eps_half_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,7 +28,7 @@
TEST_CASE(layer_norm_small_eps_half_test)
{
migraphx::program p =
make_layer_norm({1, 2}, {2}, {1}, 1, true, 1e-7, migraphx::shape::half_type);
make_layer_norm({1, 2}, {2}, {1}, 1, true, true, 1e-7, migraphx::shape::half_type);

auto prog = optimize_onnx("layer_norm_small_eps_half_test.onnx");
EXPECT(p == prog);
Expand Down