Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions src/generate.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -63,6 +63,13 @@ argument generate_argument(shape s, unsigned long seed, random_mode m)

result = argument(sub_args);
}
// special processing for non-computable type
else if(not s.computable())
{
// NOTE: these values can be wrong (ex. not valid fp4x2)
auto v = generate_tensor_data<uint8_t>(s, seed, m);
result = {s, v};
}
else
{
s.visit_type([&](auto as) {
Expand All @@ -88,11 +95,19 @@ argument generate_argument(shape s, unsigned long seed, random_mode m)
literal generate_literal(shape s, unsigned long seed)
{
literal result;
s.visit_type([&](auto as) {
using type = typename decltype(as)::type;
auto v = generate_tensor_data<type>(s, seed);
result = {s, reinterpret_cast<char*>(v.get())};
});
if(not s.computable())
{
auto v = generate_tensor_data<uint8_t>(s, seed);
result = {s, reinterpret_cast<char*>(v.get())};
}
else
{
s.visit_type([&](auto as) {
using type = typename decltype(as)::type;
auto v = generate_tensor_data<type>(s, seed);
result = {s, reinterpret_cast<char*>(v.get())};
});
}
return result;
}

Expand Down
127 changes: 127 additions & 0 deletions src/include/migraphx/byte.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#ifndef MIGRAPHX_GUARD_BYTE_HPP
#define MIGRAPHX_GUARD_BYTE_HPP

#include <migraphx/config.hpp>
#include <migraphx/requires.hpp>
#include <cstdint>
#include <type_traits>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/**
* Implementation of std::byte for MIGraphX.
* Created to have a custom stream operator so that it prints as an unsigned int.
* This type is essentially a limited unsigned_char to prevent things like trying to add two bytes.
*/
enum class byte : unsigned char
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All your casts are of type uint8_t. However, this enum takes (its base) after unsigned char.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Matching the standard library proposed implementation.

{
};

template <class IntType,
MIGRAPHX_REQUIRES(std::is_integral<IntType>{} and std::is_unsigned<IntType>{})>
constexpr byte operator<<(byte b, IntType shift) noexcept
{
return static_cast<byte>(static_cast<unsigned char>(b) << shift);
};

template <class IntType,
MIGRAPHX_REQUIRES(std::is_integral<IntType>{} and std::is_unsigned<IntType>{})>
constexpr byte operator>>(byte b, IntType shift) noexcept
{
return static_cast<byte>(static_cast<uint8_t>(b) >> shift);
};

template <class IntType,
MIGRAPHX_REQUIRES(std::is_integral<IntType>{} and std::is_unsigned<IntType>{})>
constexpr byte& operator>>=(byte& b, IntType shift) noexcept
{
b = b >> shift;
return b;
};

template <class IntType,
MIGRAPHX_REQUIRES(std::is_integral<IntType>{} and std::is_unsigned<IntType>{})>
constexpr byte& operator<<=(byte& b, IntType shift) noexcept
{
b = b << shift;
return b;
};

constexpr byte operator|(byte l, byte r) noexcept
{
return static_cast<byte>(static_cast<uint8_t>(l) | static_cast<uint8_t>(r));
}

constexpr byte& operator|=(byte& l, byte r) noexcept
{
l = l | r;
return l;
}

constexpr byte operator&(byte l, byte r) noexcept
{
return static_cast<byte>(static_cast<uint8_t>(l) & static_cast<uint8_t>(r));
}

constexpr byte& operator&=(byte& l, byte r) noexcept
{
l = l & r;
return l;
}

constexpr byte operator^(byte l, byte r) noexcept
{
return static_cast<byte>(static_cast<uint8_t>(l) ^ static_cast<uint8_t>(r));
}

constexpr byte& operator^=(byte& l, byte r) noexcept
{
l = l ^ r;
return l;
}

constexpr byte operator~(byte b) noexcept { return static_cast<byte>(~static_cast<uint8_t>(b)); }

template <class IntType,
MIGRAPHX_REQUIRES(std::is_integral<IntType>{} and std::is_unsigned<IntType>{})>
constexpr IntType to_integer(byte b) noexcept
{
return static_cast<IntType>(b);
}

template <class Stream>
Stream& operator<<(Stream& os, const byte& b)
{
os << static_cast<unsigned>(b);
return os;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
64 changes: 55 additions & 9 deletions src/include/migraphx/raw_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <migraphx/tensor_view.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/byte.hpp>
#include <migraphx/config.hpp>
#include <sstream>

Expand Down Expand Up @@ -104,6 +105,36 @@ struct raw_data : raw_data_base
visit(v, [&](const auto&) { MIGRAPHX_THROW("Invalid tuple type"); });
}

/**
* Visit the data using the normal visit function for computable types.
* For non-computable types, use a tensor_view<byte> with shape = {type = uint8_type, lens =
* {num bytes}};
*/
template <class Visitor, class TupleVisitor>
void fallback_visit(Visitor v, TupleVisitor tv) const
{
auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
if(s.computable())
{
visit(v, tv);
}
else
{
auto&& buffer = static_cast<const Derived&>(*this).data();
shape view_shape = {shape::uint8_type, {s.bytes()}};
v(make_view(view_shape, reinterpret_cast<byte*>(buffer)));
}
}

template <class Visitor>
void fallback_visit(Visitor v) const
{
fallback_visit(v, [&](const auto&) { MIGRAPHX_THROW("Invalid tuple type"); });
}

/// Returns true if the raw data is only one element
bool single() const
{
Expand Down Expand Up @@ -163,18 +194,26 @@ struct raw_data : raw_data_base
/// Implicit conversion of raw data pointer
auto_cast implicit() const { return {static_cast<const Derived*>(this)}; }

/// Get a tensor_view to the data
/// Get a tensor_view to the data.
/// For get<byte>() returns a 1D tensor_view<const byte*>.
template <class T>
tensor_view<T> get() const
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
if(s.type() != migraphx::shape::get_type<T>{})
MIGRAPHX_THROW("Incorrect data type for raw data");
return make_view(s, reinterpret_cast<T*>(buffer));
if constexpr(std::is_same<std::remove_cv_t<T>, migraphx::byte>{})
{
shape view_shape = {shape::uint8_type, {s.bytes()}};
return make_view(view_shape, reinterpret_cast<const migraphx::byte*>(buffer));
}
else
{
if(s.computable() and s.type() != migraphx::shape::get_type<T>{})
MIGRAPHX_THROW("Incorrect data type for raw data");
return make_view(s, reinterpret_cast<T*>(buffer));
}
}

/// Cast the data pointer
template <class T>
T* cast() const
{
Expand Down Expand Up @@ -284,10 +323,17 @@ bool operator==(const T& x, const U& y)
bool result = x.empty() and y.empty();
if(not result and xshape == yshape)
{
visit_all(x, y)([&](auto xview, auto yview) { result = xview == yview; },
[&](auto&& xs, auto&& ys) {
result = std::equal(xs.begin(), xs.end(), ys.begin(), ys.end());
});
if(xshape.computable())
{
visit_all(x, y)([&](auto xview, auto yview) { result = xview == yview; },
[&](auto&& xs, auto&& ys) {
result = std::equal(xs.begin(), xs.end(), ys.begin(), ys.end());
});
}
else
{
result = x.template get<const byte>() == y.template get<const byte>();
}
}
return result;
}
Expand Down
4 changes: 4 additions & 0 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ struct MIGRAPHX_EXPORT shape
static bool is_compatible(const shape& actual, const shape& expected);

static bool is_unsigned(type_t t);
static bool is_computable(type_t t);

shape();
shape(type_t t);
Expand Down Expand Up @@ -335,6 +336,9 @@ struct MIGRAPHX_EXPORT shape
/// Return true if this shape or any of the sub_shapes are dynamic
bool any_of_dynamic() const;

/// If type is computable (can do math ops like add or divide) and has a visitor function
bool computable() const;

shape normalize_standard() const;

shape as_standard() const;
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/tensor_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <migraphx/requires.hpp>
#include <migraphx/iota_iterator.hpp>
#include <migraphx/as_number.hpp>
#include <migraphx/byte.hpp>

#include <iostream>
#include <utility>
Expand Down
28 changes: 24 additions & 4 deletions src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <algorithm>
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include <iostream>

namespace migraphx {
Expand Down Expand Up @@ -242,8 +243,11 @@ std::string shape::to_sizes_string(const std::vector<shape>& shapes)
const std::vector<shape::type_t>& shape::types()
{
static const std::vector<shape::type_t> result = {
// clang-format off
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR)
tuple_type};
// clang-format on
return result;
}

Expand Down Expand Up @@ -313,6 +317,8 @@ bool shape::is_unsigned(shape::type_t t)
return result;
}

bool shape::is_computable(shape::type_t t) { return t != shape::fp4x2_type; }

shape::shape() : impl(shape_impl::default_shape()) {}

shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
Expand Down Expand Up @@ -419,7 +425,16 @@ std::size_t shape::type_size() const
{
std::size_t n = 0;
if(this->sub_shapes().empty())
this->visit_type([&](auto as) { n = as.size(); });
{
if(this->computable())
{
this->visit_type([&](auto as) { n = as.size(); });
}
else
{
n = sizeof(uint8_t);
}
}
return n;
}

Expand Down Expand Up @@ -651,6 +666,8 @@ bool shape::any_of_dynamic() const
});
}

bool shape::computable() const { return is_computable(this->type()); }

const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const
{
if(not this->dynamic())
Expand Down Expand Up @@ -814,10 +831,13 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
shape::type_t shape::parse_type(const std::string& s)
{
static const std::unordered_map<std::string, shape::type_t> m = {
// clang-format off
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type",
tuple_type},
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP)
{"fp4x2_type", fp4x2_type},
{"tuple_type", tuple_type},
{"tuple", tuple_type}};
// clang-format on
return m.at(s);
}

Expand Down
Loading
Loading