When I try to use GeometricSandwichProductDense like this:
```python
# Keras imports assumed (not shown in the original cell)
from keras.applications import InceptionV3
from keras.layers import Dropout, Flatten, Input, Reshape
from tfga import GeometricAlgebra
from tfga.layers import (
    GeometricProductConv1D,
    GeometricProductDense,
    GeometricSandwichProductDense,
    GeometricToTensor,
    TensorToGeometric,
)

# 4D algebra with a Euclidean metric; its even subalgebra has 8 blades
ga = GeometricAlgebra(metric=[1, 1, 1, 1])
idx = ga.get_kind_blade_indices("even")

model = InceptionV3(classifier_activation=None, weights="imagenet", input_tensor=Input(shape=(224, 224, 3)))
# Output of the layer before the classifier, split into groups of 8 (one value per even blade)
x2 = Dropout(0.3)(model.layers[-2].output)
x2 = Reshape((-1, 8))(x2)
x2 = TensorToGeometric(ga, blade_indices=idx)(x2)
x2 = GeometricSandwichProductDense(
    ga, units=128, activation="relu",
    blade_indices_kernel=idx,
    blade_indices_bias=idx)(x2)
x2 = GeometricSandwichProductDense(
    ga, units=64, activation="relu",
    blade_indices_kernel=idx,
    blade_indices_bias=idx)(x2)
outputs2 = GeometricSandwichProductDense(
    ga, units=1, activation="tanh",
    blade_indices_kernel=idx,
    blade_indices_bias=idx)(x2)
# Alternative ending (disabled):
# x2 = GeometricToTensor(ga, blade_indices=idx)(x2)
# outputs2 = Flatten()(x2)
```
I receive the following error report:
```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[5], line 10
8 x2 = Reshape((-1, 8))(x2)
9 x2 = TensorToGeometric(ga, blade_indices=idx)(x2)
---> 10 x2 = GeometricSandwichProductDense(
11 ga, units=128, activation = "relu",
12 blade_indices_kernel=idx,
13 blade_indices_bias=idx)(x2)
14 x2 = GeometricSandwichProductDense(
15 ga, units=64, activation = "relu",
16 blade_indices_kernel=idx,
17 blade_indices_bias=idx)(x2)
18 outputs2 = GeometricSandwichProductDense(
19 ga, units=1, activation = "tanh",
20 blade_indices_kernel=idx,
21 blade_indices_bias=idx)(x2)
File ~/miniconda3/envs/tf/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.__traceback__)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
File ~/miniconda3/envs/tf/lib/python3.11/site-packages/tfga/layers.py:189, in GeometricProductDense.build(self, input_shape)
183 self.num_input_units = input_shape[-2]
184 shape_kernel = [
185 self.units,
186 self.num_input_units,
187 self.blade_indices_kernel.shape[0],
188 ]
--> 189 self.kernel = self.add_weight(
190 "kernel",
191 shape=shape_kernel,
192 initializer=self.kernel_initializer,
193 regularizer=self.kernel_regularizer,
194 constraint=self.kernel_constraint,
195 dtype=self.dtype,
196 trainable=True,
197 )
198 if self.use_bias:
199 shape_bias = [self.units, self.blade_indices_bias.shape[0]]
TypeError: Layer.add_weight() got multiple values for argument 'shape'
```
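For reference, the shapes this model should produce can be worked out by hand, which suggests the `Reshape((-1, 8))` step is consistent with the even blades. This sketch assumes `model.layers[-2]` is InceptionV3's 2048-wide global-average-pooling layer and that `get_kind_blade_indices("even")` returns the grade-0, grade-2 and grade-4 blades; the kernel and bias shape formulas are read off `GeometricProductDense.build` in the traceback. Note also that the `TypeError` is raised while Python binds the arguments of `add_weight`, before any shape is used, so the shape values themselves are not what triggers it.

```python
from math import comb

DIM = 4                 # metric=[1, 1, 1, 1] gives 4 basis vectors
POOLED_FEATURES = 2048  # assumed width of model.layers[-2].output (InceptionV3 avg_pool)

total_blades = 2 ** DIM
# Even blades are those of grade 0, 2 and 4: C(4,0) + C(4,2) + C(4,4)
even_blades = sum(comb(DIM, k) for k in range(0, DIM + 1, 2))

assert POOLED_FEATURES % even_blades == 0  # Reshape((-1, 8)) needs an exact split
num_input_units = POOLED_FEATURES // even_blades  # input_shape[-2] seen by build()

# Shapes GeometricProductDense.build would request for the first layer (units=128)
kernel_shape = [128, num_input_units, even_blades]
bias_shape = [128, even_blades]

print(total_blades, even_blades, num_input_units, kernel_shape, bias_shape)
```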
Is this issue caused by incorrect usage? Can I fix it simply by adjusting my code?
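The traceback ends inside tfga's own `GeometricProductDense.build`, which calls `self.add_weight("kernel", shape=shape_kernel, ...)`, so the failure is in how that call binds its arguments rather than in the model code. A likely explanation, offered as an assumption: the `keras/src/...` path suggests Keras 3 is installed, and in Keras 3 the first parameter of `Layer.add_weight` is `shape`, whereas in Keras 2 it was `name`. The positional `"kernel"` then lands in `shape`, and the explicit `shape=` keyword collides with it. The sketch below reproduces that collision with two hypothetical plain-Python stand-ins, `add_weight_keras2` and `add_weight_keras3`, which copy only the parameter order and none of the real Keras behaviour.

```python
# Hypothetical stand-ins: they mirror only the parameter ORDER of add_weight
# in Keras 2 vs. Keras 3; they are not the real Keras implementations.


def add_weight_keras2(name=None, shape=None, **kwargs):
    """Keras 2 order: `name` is the first positional parameter."""
    return {"name": name, "shape": shape}


def add_weight_keras3(shape=None, initializer=None, dtype=None, trainable=True, name=None, **kwargs):
    """Keras 3 order: `shape` is first and `name` comes later."""
    return {"name": name, "shape": shape}


SHAPE_KERNEL = [128, 256, 8]  # [units, num_input_units, number of kernel blades]


def positional_name_call(add_weight):
    """Mimic the tfga call: name passed positionally, shape passed by keyword."""
    try:
        return add_weight("kernel", shape=SHAPE_KERNEL, trainable=True)
    except TypeError as err:
        return f"TypeError: {err}"


def keyword_name_call(add_weight):
    """Passing name by keyword binds correctly under either parameter order."""
    return add_weight(name="kernel", shape=SHAPE_KERNEL, trainable=True)


if __name__ == "__main__":
    print(positional_name_call(add_weight_keras2))  # binds as intended
    print(positional_name_call(add_weight_keras3))  # same kind of TypeError as in the report
    print(keyword_name_call(add_weight_keras3))     # binds as intended
```

Under this reading the fix lies outside the model code: either edit the installed `tfga/layers.py` so each `add_weight` call passes its name as a keyword (`name="kernel"`, and likewise for any other call that passes the name positionally), or run tfga on Keras 2, for example by installing `tf_keras`, setting `TF_USE_LEGACY_KERAS=1` before TensorFlow is first imported, and building the model from `tf.keras` layers (this assumes tfga reaches Keras through `tf.keras`). Neither option has been verified against this exact environment.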