diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index 6ee66c096..0060bcce7 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -20,7 +20,7 @@ from tf2onnx import constants, logging, utils, optimizer from tf2onnx import tf_loader from tf2onnx.graph import ExternalTensorStorage -from tf2onnx.tf_utils import compress_graph_def, get_tf_version +from tf2onnx.tf_utils import compress_graph_def, get_tf_version, get_keras_version @@ -408,6 +408,106 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None, return model_proto, external_tensor_storage +def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None, + custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None, + target=None, large_model=False, output_path=None, optimizers=None): + """ + Convert a Keras 3 model to ONNX using tf2onnx. + + Args: + model: Keras 3 Functional or Sequential model + name: Name for the converted model + input_signature: Optional list of tf.TensorSpec + opset: ONNX opset version + custom_ops: Dictionary of custom ops + custom_op_handlers: Dictionary of custom op handlers + custom_rewriter: List of graph rewriters + inputs_as_nchw: List of input names to convert to NCHW + extra_opset: Additional opset imports + shape_override: Dictionary to override input shapes + target: Target platforms (for workarounds) + large_model: Whether to use external tensor storage + output_path: Optional path to write ONNX model to file + + Returns: + A tuple (model_proto, external_tensor_storage_dict) + """ + + + if not input_signature: + + input_signature = [ + tf.TensorSpec(tensor.shape, tensor.dtype, name=tensor.name.split(":")[0]) + for tensor in model.inputs + ] + + # Trace model + function = tf.function(model) + concrete_func = function.get_concrete_function(*input_signature) + + # These inputs will be removed during freezing (includes resources, etc.) + if hasattr(concrete_func.graph, '_captures'): + graph_captures = concrete_func.graph._captures # pylint: disable=protected-access + captured_inputs = [t_name.name for _, t_name in graph_captures.values()] + else: + graph_captures = concrete_func.graph.function_captures.by_val_internal + captured_inputs = [t.name for t in graph_captures.values()] + input_names = [input_tensor.name for input_tensor in concrete_func.inputs + if input_tensor.name not in captured_inputs] + output_names = [output_tensor.name for output_tensor in concrete_func.outputs + if output_tensor.dtype != tf.dtypes.resource] + + + tensors_to_rename = tensor_names_from_structed(concrete_func, input_names, output_names) + reverse_lookup = {v: k for k, v in tensors_to_rename.items()} + + + + valid_names = [] + for out in model.output_names: + if out in reverse_lookup: + valid_names.append(reverse_lookup[out]) + else: + print(f"Warning: Output name '{out}' not found in reverse_lookup.") + # Fallback: verwende TensorFlow-Ausgangsnamen direkt + valid_names = [t.name for t in concrete_func.outputs if t.dtype != tf.dtypes.resource] + break + output_names = valid_names + + + #if old_out_names is not None: + #model.output_names = old_out_names + + with tf.device("/cpu:0"): + frozen_graph, initialized_tables = \ + tf_loader.from_trackable(model, concrete_func, input_names, output_names, large_model) + + for node in frozen_graph.node: + print(node.name, node.op) + model_proto, external_tensor_storage = _convert_common( + frozen_graph, + name=model.name, + continue_on_error=True, + target=target, + opset=opset, + custom_ops=custom_ops, + custom_op_handlers=custom_op_handlers, + optimizers=optimizers, + custom_rewriter=custom_rewriter, + extra_opset=extra_opset, + shape_override=shape_override, + input_names=input_names, + output_names=output_names, + inputs_as_nchw=inputs_as_nchw, + outputs_as_nchw=outputs_as_nchw, + large_model=large_model, + tensors_to_rename=tensors_to_rename, + initialized_tables=initialized_tables, + output_path=output_path) + + #print(model_proto) + + return model_proto, external_tensor_storage def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None, @@ -438,6 +538,10 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_ if get_tf_version() < Version("2.0"): return _from_keras_tf1(model, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw, outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path) + if get_keras_version() > Version("3.0"): + return from_keras3(model, input_signature, opset, custom_ops, custom_op_handlers, + custom_rewriter, inputs_as_nchw, outputs_as_nchw, extra_opset, shape_override, + target, large_model, output_path, optimizers) old_out_names = _rename_duplicate_keras_model_names(model) from tensorflow.python.keras.saving import saving_utils as _saving_utils # pylint: disable=import-outside-toplevel diff --git a/tf2onnx/tf_utils.py b/tf2onnx/tf_utils.py index 16cb76344..0dd157fb0 100644 --- a/tf2onnx/tf_utils.py +++ b/tf2onnx/tf_utils.py @@ -10,6 +10,7 @@ import numpy as np import tensorflow as tf +import keras from tensorflow.core.framework import types_pb2, tensor_pb2, graph_pb2 from tensorflow.python.framework import tensor_util @@ -124,6 +125,9 @@ def get_tf_node_attr(node, name): def get_tf_version(): return Version(tf.__version__) +def get_keras_version(): + return Version(keras.__version__) + def compress_graph_def(graph_def): """ Remove large const values from graph. This lets us import the graph and run shape inference without TF crashing.