From be9c4fd38686f9453de1c993e69ead68c3e95f45 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Sun, 8 Jun 2025 22:53:29 +0200 Subject: [PATCH] Support for floating point types --- hls4ml/backends/fpga/fpga_backend.py | 53 ++++++++- hls4ml/backends/fpga/fpga_types.py | 55 ++++++++- hls4ml/backends/oneapi/oneapi_types.py | 10 +- hls4ml/model/types.py | 123 +++++++++++++++++++- hls4ml/templates/oneapi/firmware/defines.h | 1 + hls4ml/templates/quartus/firmware/defines.h | 2 + test/pytest/test_types.py | 103 +++++++++++++++- 7 files changed, 338 insertions(+), 9 deletions(-) diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index 4896c25f9f..21dfee082a 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -38,10 +38,12 @@ from hls4ml.model.types import ( ExponentPrecisionType, FixedPrecisionType, + FloatPrecisionType, IntegerPrecisionType, PrecisionType, RoundingMode, SaturationMode, + StandardFloatPrecisionType, UnspecifiedPrecisionType, XnorPrecisionType, ) @@ -343,11 +345,22 @@ def convert_precision_string(cls, precision): if precision.lower() == 'auto': return cls._convert_auto_type(precision) + if precision in ['float', 'double', 'half', 'bfloat16'] or precision.startswith( + ('ap_float', 'ac_std_float', 'std_float') + ): + return cls._convert_standard_float_type(precision) + + if precision.startswith('ac_float'): + return cls._convert_ac_float_type(precision) + if precision.startswith('ac_'): return cls._convert_ac_type(precision) - else: + + if precision.startswith(('ap_', 'fixed', 'ufixed', 'int', 'uint')): # We parse AP notation even without 'ap_' prefix return cls._convert_ap_type(precision) + raise ValueError(f'Unsupported precision type: {precision}') + @classmethod def _convert_ap_type(cls, precision): ''' @@ -416,6 +429,44 @@ def _convert_ac_type(cls, precision): elif 'int' in precision: return IntegerPrecisionType(width, signed) + @classmethod + def _convert_standard_float_type(cls, precision): + # Some default values + if precision == 'float': + return StandardFloatPrecisionType(width=32, exponent=8, use_cpp_type=True) + if precision == 'double': + return StandardFloatPrecisionType(width=64, exponent=11, use_cpp_type=True) + if precision == 'half': + return StandardFloatPrecisionType(width=16, exponent=5, use_cpp_type=True) + if precision == 'bfloat16': + return StandardFloatPrecisionType(width=16, exponent=8, use_cpp_type=True) + + # If it is a float type, parse the width and exponent + bits = re.search('.+<(.+?)>', precision).group(1).split(',') + if len(bits) == 2: + width = int(bits[0]) + exponent = int(bits[1]) + return StandardFloatPrecisionType(width=width, exponent=exponent, use_cpp_type=False) + else: + raise ValueError(f'Invalid standard float precision format: {precision}') + + @classmethod + def _convert_ac_float_type(cls, precision): + # If it is a float type, parse the width and exponent + bits = re.search('.+<(.+?)>', precision).group(1).split(',') + if len(bits) == 3 or len(bits) == 4: + mantissa = int(bits[0]) + integer = int(bits[1]) + exponent = int(bits[2]) + width = mantissa + exponent + if len(bits) == 4: + round_mode = RoundingMode.from_string(bits[3]) + else: + round_mode = None + return FloatPrecisionType(width=width, integer=integer, exponent=exponent, rounding_mode=round_mode) + else: + raise ValueError(f'Invalid ac_float precision format: {precision}') + @classmethod def _convert_auto_type(cls, precision): ''' diff --git a/hls4ml/backends/fpga/fpga_types.py b/hls4ml/backends/fpga/fpga_types.py index 37abffc4b8..b2be271dfe 100644 --- a/hls4ml/backends/fpga/fpga_types.py +++ b/hls4ml/backends/fpga/fpga_types.py @@ -5,9 +5,11 @@ ExponentPrecisionType, ExponentType, FixedPrecisionType, + FloatPrecisionType, IntegerPrecisionType, NamedType, PackedType, + StandardFloatPrecisionType, XnorPrecisionType, ) @@ -51,6 +53,21 @@ def definition_cpp(self): return typestring +class APFloatPrecisionDefinition(PrecisionDefinition): + def definition_cpp(self): + raise NotImplementedError( + 'FloatPrecisionType is not supported in AP type precision definitions. Use StandardFloatPrecisionType instead.' + ) + + +class APStandardFloatPrecisionDefinition(PrecisionDefinition): + def definition_cpp(self): + typestring = str(self) + if typestring.startswith('std_float'): + typestring = typestring.replace('std_float', 'ap_float') + return typestring + + class ACIntegerPrecisionDefinition(PrecisionDefinition): def definition_cpp(self): typestring = f'ac_int<{self.width}, {str(self.signed).lower()}>' @@ -90,12 +107,40 @@ def definition_cpp(self): return typestring +class ACFloatPrecisionDefinition(PrecisionDefinition): + def _rounding_mode_cpp(self, mode): + if mode is not None: + return 'AC_' + str(mode) + + def definition_cpp(self): + args = [ + self.width, + self.integer, + self.exponent, + self._rounding_mode_cpp(self.rounding_mode), + ] + if args[3] == 'AC_TRN': + # This is the default, so we won't write the full definition for brevity + args[3] = None + args = ','.join([str(arg) for arg in args[:5] if arg is not None]) + typestring = f'ac_float<{args}>' + return typestring + + +class ACStandardFloatPrecisionDefinition(PrecisionDefinition): + def definition_cpp(self): + typestring = str(self) + if typestring.startswith('std_float'): + typestring = 'ac_' + typestring + return typestring + + class PrecisionConverter: def convert(self, precision_type): raise NotImplementedError -class FixedPrecisionConverter(PrecisionConverter): +class FPGAPrecisionConverter(PrecisionConverter): def __init__(self, type_map, prefix): self.type_map = type_map self.prefix = prefix @@ -120,12 +165,14 @@ def convert(self, precision_type): raise Exception(f'Cannot convert precision type to {self.prefix}: {precision_type.__class__.__name__}') -class APTypeConverter(FixedPrecisionConverter): +class APTypeConverter(FPGAPrecisionConverter): def __init__(self): super().__init__( type_map={ FixedPrecisionType: APFixedPrecisionDefinition, IntegerPrecisionType: APIntegerPrecisionDefinition, + FloatPrecisionType: APFloatPrecisionDefinition, + StandardFloatPrecisionType: APStandardFloatPrecisionDefinition, ExponentPrecisionType: APIntegerPrecisionDefinition, XnorPrecisionType: APIntegerPrecisionDefinition, }, @@ -133,12 +180,14 @@ def __init__(self): ) -class ACTypeConverter(FixedPrecisionConverter): +class ACTypeConverter(FPGAPrecisionConverter): def __init__(self): super().__init__( type_map={ FixedPrecisionType: ACFixedPrecisionDefinition, IntegerPrecisionType: ACIntegerPrecisionDefinition, + FloatPrecisionType: ACFloatPrecisionDefinition, + StandardFloatPrecisionType: ACStandardFloatPrecisionDefinition, ExponentPrecisionType: ACIntegerPrecisionDefinition, XnorPrecisionType: ACIntegerPrecisionDefinition, }, diff --git a/hls4ml/backends/oneapi/oneapi_types.py b/hls4ml/backends/oneapi/oneapi_types.py index 261bd9bdfa..a18351ad6d 100644 --- a/hls4ml/backends/oneapi/oneapi_types.py +++ b/hls4ml/backends/oneapi/oneapi_types.py @@ -6,11 +6,15 @@ from hls4ml.backends.fpga.fpga_types import ( ACFixedPrecisionDefinition, + ACFloatPrecisionDefinition, ACIntegerPrecisionDefinition, - FixedPrecisionConverter, + ACStandardFloatPrecisionDefinition, + FloatPrecisionType, + FPGAPrecisionConverter, HLSTypeConverter, NamedTypeConverter, PrecisionDefinition, + StandardFloatPrecisionType, TypeDefinition, TypePrecisionConverter, VariableDefinition, @@ -35,12 +39,14 @@ def definition_cpp(self): return typestring -class OneAPIACTypeConverter(FixedPrecisionConverter): +class OneAPIACTypeConverter(FPGAPrecisionConverter): def __init__(self): super().__init__( type_map={ FixedPrecisionType: ACFixedPrecisionDefinition, IntegerPrecisionType: ACIntegerPrecisionDefinition, + FloatPrecisionType: ACFloatPrecisionDefinition, + StandardFloatPrecisionType: ACStandardFloatPrecisionDefinition, ExponentPrecisionType: ACExponentPrecisionDefinition, XnorPrecisionType: ACIntegerPrecisionDefinition, }, diff --git a/hls4ml/model/types.py b/hls4ml/model/types.py index 874effaadc..6503476e78 100644 --- a/hls4ml/model/types.py +++ b/hls4ml/model/types.py @@ -291,6 +291,126 @@ def __str__(self): return typestring +class FloatPrecisionType(PrecisionType): + """ + Class representing a floating-point precision type. + + This type is equivalent to ac_float HLS types. If the use of C++ equivalent types is required, see + ``StandardFloatPrecisionType``. + + Args: + width (int, optional): Total number of bits used. Defaults to 33. + integer (int, optional): Number of bits used for the integer part. Defaults to 2. + exponent (int, optional): Number of bits used for the exponent. Defaults to 8. + """ + + def __init__(self, width=33, integer=2, exponent=8, rounding_mode=None): + super().__init__(width=width, signed=True) + self.exponent = exponent + self.integer = integer # If None, will be set to width - exponent - 1 + self.rounding_mode = rounding_mode + + @property + def rounding_mode(self): + return self._rounding_mode + + @rounding_mode.setter + def rounding_mode(self, mode): + if mode is None: + self._rounding_mode = RoundingMode.TRN + elif isinstance(mode, str): + self._rounding_mode = RoundingMode.from_string(mode) + else: + self._rounding_mode = mode + + def __str__(self): + args = [self.width - self.exponent, self.integer, self.exponent, self.rounding_mode] + args = ','.join([str(arg) for arg in args]) + typestring = f'float<{args}>' + return typestring + + def __eq__(self, other: object) -> bool: + if isinstance(other, FloatPrecisionType): + eq = super().__eq__(other) + eq = eq and self.integer == other.integer + eq = eq and self.exponent == other.exponent + eq = eq and self.rounding_mode == other.rounding_mode + return eq + + return False + + def __hash__(self) -> int: + return super().__hash__() ^ hash((self.integer, self.exponent, self.rounding_mode)) + + def serialize_state(self): + state = super().serialize_state() + state.update( + { + 'integer': self.integer, + 'exponent': self.exponent, + 'rounding_mode': str(self.rounding_mode), + } + ) + return state + + +class StandardFloatPrecisionType(PrecisionType): + """ + Class representing a floating-point precision type. + + This type is equivalent to ap_float and ac_std_float HLS types. <32,8> corresponds to a 'float' type in C/C++. <64,11> + corresponds to a 'double' type in C/C++. <16,5> corresponds to a 'half' type in C/C++. <16,8> corresponds to a + 'bfloat16' type in C/C++. + + Args: + width (int, optional): Total number of bits used. Defaults to 32. + exponent (int, optional): Number of bits used for the exponent. Defaults to 8. + use_cpp_type (bool, optional): Use C++ equivalent types if available. Defaults to ``True``. + """ + + def __init__(self, width=32, exponent=8, use_cpp_type=True): + super().__init__(width=width, signed=True) + self.exponent = exponent + self.use_cpp_type = use_cpp_type + + def __str__(self): + if self._check_cpp_type(32, 8): + typestring = 'float' + elif self._check_cpp_type(64, 11): + typestring = 'double' + elif self._check_cpp_type(16, 5): + typestring = 'half' + elif self._check_cpp_type(16, 8): + typestring = 'bfloat16' + else: + typestring = f'std_float<{self.width},{self.exponent}>' + return typestring + + def _check_cpp_type(self, width, exponent): + return self.use_cpp_type and self.width == width and self.exponent == exponent + + def __eq__(self, other: object) -> bool: + if isinstance(other, FloatPrecisionType): + eq = super().__eq__(other) + eq = eq and self.exponent == other.exponent + return eq + + return False + + def __hash__(self) -> int: + return super().__hash__() ^ hash(self.exponent) + + def serialize_state(self): + state = super().serialize_state() + state.update( + { + 'exponent': self.exponent, + 'use_cpp_type': self.use_cpp_type, + } + ) + return state + + class UnspecifiedPrecisionType(PrecisionType): """ Class representing an unspecified precision type. @@ -592,7 +712,8 @@ def update_precision(self, new_precision): elif isinstance(new_precision, FixedPrecisionType): decimal_spaces = max(0, new_precision.fractional) self.precision_fmt = f'{{:.{decimal_spaces}f}}' - + elif isinstance(new_precision, (FloatPrecisionType, StandardFloatPrecisionType)): + self.precision_fmt = '{:.16f}' # Not ideal, but should be enough for most cases else: raise RuntimeError(f"Unexpected new precision type: {new_precision}") diff --git a/hls4ml/templates/oneapi/firmware/defines.h b/hls4ml/templates/oneapi/firmware/defines.h index 05de507dcd..b2fc5bdd9a 100644 --- a/hls4ml/templates/oneapi/firmware/defines.h +++ b/hls4ml/templates/oneapi/firmware/defines.h @@ -2,6 +2,7 @@ #define DEFINES_H_ #include +#include #include #include #include diff --git a/hls4ml/templates/quartus/firmware/defines.h b/hls4ml/templates/quartus/firmware/defines.h index c3fe4ec402..e74fde83e9 100644 --- a/hls4ml/templates/quartus/firmware/defines.h +++ b/hls4ml/templates/quartus/firmware/defines.h @@ -13,6 +13,7 @@ #ifndef __INTELFPGA_COMPILER__ #include "ac_fixed.h" +#include "ac_float.h" #include "ac_int.h" #define hls_register @@ -24,6 +25,7 @@ template using stream_out = nnet::stream; #else #include "HLS/ac_fixed.h" +#include "HLS/ac_float.h" #include "HLS/ac_int.h" #include "HLS/hls.h" diff --git a/test/pytest/test_types.py b/test/pytest/test_types.py index 8f4857fec9..ad6aa97a12 100644 --- a/test/pytest/test_types.py +++ b/test/pytest/test_types.py @@ -5,9 +5,11 @@ from hls4ml.model.types import ( ExponentPrecisionType, FixedPrecisionType, + FloatPrecisionType, IntegerPrecisionType, RoundingMode, SaturationMode, + StandardFloatPrecisionType, XnorPrecisionType, ) @@ -80,8 +82,105 @@ def test_precision_type_creation(capsys): ) def test_sign_parsing(prec_pair): '''Test that convert_precisions_string determines the signedness correctly''' - strprec = prec_pair[0] - signed = prec_pair[1] + strprec, signed = prec_pair evalprec = FPGABackend.convert_precision_string(strprec) assert evalprec.signed == signed + + +@pytest.mark.parametrize( + 'prec_tuple', + [ + # Notation without the prefix + ('fixed<16,6>', 16, 6, True, RoundingMode.TRN, SaturationMode.WRAP), + ('ufixed<18,7>', 18, 7, False, RoundingMode.TRN, SaturationMode.WRAP), + ('fixed<14, 5, RND>', 14, 5, True, RoundingMode.RND, SaturationMode.WRAP), + ('ufixed<13, 8, RND, SAT>', 13, 8, False, RoundingMode.RND, SaturationMode.SAT), + # Prefixed notation + ('ap_ufixed<17,6>', 17, 6, False, RoundingMode.TRN, SaturationMode.WRAP), + ('ac_fixed<15, 4, false, RND, SAT>', 15, 4, False, RoundingMode.RND, SaturationMode.SAT), + ], +) +def test_fixed_type_parsing(prec_tuple): + '''Test that convert_precision_string correctly parses specified fixed-point types''' + prec_str, width, integer, signed, round_mode, saturation_mode = prec_tuple + + evalprec = FPGABackend.convert_precision_string(prec_str) + + assert isinstance(evalprec, FixedPrecisionType) + assert evalprec.width == width + assert evalprec.integer == integer + assert evalprec.signed == signed + assert evalprec.rounding_mode == round_mode + assert evalprec.saturation_mode == saturation_mode + + +@pytest.mark.parametrize( + 'prec_tuple', + [ + # Notation without the prefix + ('int<16>', 16, True), + ('uint<18>', 18, False), + # Prefixed notation + ('ap_uint<8>', 8, False), + ('ac_int<14, false>', 14, False), + ], +) +def test_int_type_parsing(prec_tuple): + '''Test that convert_precision_string correctly parses specified fixed-point types''' + prec_str, width, signed = prec_tuple + + evalprec = FPGABackend.convert_precision_string(prec_str) + + assert isinstance(evalprec, IntegerPrecisionType) + assert evalprec.width == width + assert evalprec.signed == signed + + +@pytest.mark.parametrize( + 'prec_pair', + [ + # Standard floating-point types, should be parsed as C++ types + ('float', True), + ('double', True), + ('half', True), + ('bfloat16', True), + # Standard bitwidths, but should result in ap_float or ac_std_float, not standard C++ types + ('std_float<32,8>', False), + ('std_float<64,11>', False), + ('std_float<16,5>', False), + ('std_float<16,8>', False), + # Non-standard bitwidths, should not be parsed as C++ types + ('std_float<16,6>', False), + ('std_float<64,10>', False), + ], +) +def test_float_cpp_parsing(prec_pair): + '''Test that convert_precision_string correctly parses C++ types''' + prec_str, is_cpp = prec_pair + + evalprec = FPGABackend.convert_precision_string(prec_str) + assert isinstance(evalprec, StandardFloatPrecisionType) + assert evalprec.use_cpp_type == is_cpp and prec_str in str(evalprec) + + +@pytest.mark.parametrize( + 'prec_tuple', + [ + # Should result in ap_float + ('ac_float<25,2, 8>', 33, 2, 8, RoundingMode.TRN), + ('ac_float<54,2,11, AC_RND>', 65, 2, 11, RoundingMode.RND), + ('ac_float<25,4, 8>', 33, 4, 8, RoundingMode.TRN), + ], +) +def test_ac_float_parsing(prec_tuple): + '''Test that convert_precision_string correctly parses ac_float types''' + prec_str, width, integer, exponent, round_mode = prec_tuple + + evalprec = FPGABackend.convert_precision_string(prec_str) + + assert isinstance(evalprec, FloatPrecisionType) + assert evalprec.width == width + assert evalprec.integer == integer + assert evalprec.exponent == exponent + assert evalprec.rounding_mode == round_mode