Replies: 1 comment
Got it! The problem went away after I restarted the computer.
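The thread does not say why a restart helped. One plausible explanation, offered here only as an assumption, is that the running Python session still held the version of `losses.py` imported before the edit, since Python keeps imported modules cached in `sys.modules`. A minimal sketch of that behavior, using a hypothetical throwaway module named `toy_losses` rather than DeepXDE itself:

```python
import importlib
import pathlib
import sys
import tempfile

# Avoid writing .pyc files so the demo always re-reads the source on reload.
sys.dont_write_bytecode = True

# Create a throwaway module on disk (hypothetical name, not part of DeepXDE).
workdir = tempfile.mkdtemp()
module_path = pathlib.Path(workdir) / "toy_losses.py"
module_path.write_text('LOSS_DICT = {"mse": "mse_fn"}\n')
sys.path.insert(0, workdir)
importlib.invalidate_caches()

toy_losses = importlib.import_module("toy_losses")

# Edit the file on disk, as one would edit losses.py in site-packages.
module_path.write_text(
    'LOSS_DICT = {"mse": "mse_fn", "binary cross entropy": "bce_fn"}\n'
)

# Importing again returns the cached module object: the new key is missing.
stale = importlib.import_module("toy_losses")
stale_has_key = "binary cross entropy" in stale.LOSS_DICT

# Reloading re-executes the edited source, so the new key appears.
fresh = importlib.reload(toy_losses)
fresh_has_key = "binary cross entropy" in fresh.LOSS_DICT

print(stale_has_key, fresh_has_key)
```

If this is what happened, restarting the notebook kernel would likely have been enough, since a fresh interpreter imports the edited file from scratch.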
Hi all,
I'm trying to solve a binary classification problem, so it seems to me that TensorFlow's binary cross-entropy loss is the one to use. However, this loss isn't included in DeepXDE's losses.py.
I've added it to losses.py myself. Here is the code:
def binary_cross_entropy(y_true, y_pred):
    # TODO: pytorch
    return tf.keras.losses.BinaryCrossentropy(from_logits=True)(y_true, y_pred)
I also added an entry for it to LOSS_DICT.
Unfortunately, it doesn't work: compiling the model raises a KeyError.
Could you please help me? Thanks in advance.
Here is the error message:
KeyError Traceback (most recent call last)
Cell In[74], line 18
9 net = dde.nn.DeepONet(
10 [m, 40, p], # dimensions of the fully connected branch net
11 [n, 40, p], # dimensions of the fully connected trunk net
12 "sigmoid",
13 "Glorot normal", # initialization of parameters
14 )
16 model = dde.Model(data, net)
---> 18 model.compile("adam", lr=0.001, loss="binary cross entropy", metrics=['accuracy']) # accuracy is the mean of matches between predictions and labels
19 model.train(iterations=ITERATIONS)
20 model.compile("L-BFGS-B", metrics=['accuracy'])
File c:\Users\Paco\anaconda3\envs\deeponetcontrol\lib\site-packages\deepxde\utils\internal.py:22, in timing.<locals>.wrapper(*args, **kwargs)
19 @wraps(f)
20 def wrapper(*args, **kwargs):
21 ts = timeit.default_timer()
---> 22 result = f(*args, **kwargs)
23 te = timeit.default_timer()
24 if config.rank == 0:
File c:\Users\Paco\anaconda3\envs\deeponetcontrol\lib\site-packages\deepxde\model.py:121, in Model.compile(self, optimizer, lr, loss, metrics, decay, loss_weights, external_trainable_variables)
119 print("Compiling model...")
120 self.opt_name = optimizer
--> 121 loss_fn = losses_module.get(loss)
122 self.loss_weights = loss_weights
123 if external_trainable_variables is None:
File c:\Users\Paco\anaconda3\envs\deeponetcontrol\lib\site-packages\deepxde\losses.py:69, in get(identifier)
66 return list(map(get, identifier))
68 if isinstance(identifier, str):
---> 69 return LOSS_DICT[identifier]
70 if callable(identifier):
71 return identifier
KeyError: 'binary cross entropy'
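The traceback shows that `losses.get` looks string identifiers up in `LOSS_DICT` and returns callables unchanged, so this `KeyError` means the string "binary cross entropy" was not a key of the `LOSS_DICT` that was actually loaded. Below is a minimal sketch of that lookup logic in plain Python. The dictionary, the hand-written losses, and the final `ValueError` are illustrative stand-ins and assumptions, not DeepXDE's or TensorFlow's real code; only the two branches visible in the traceback are mirrored.

```python
import math


def mean_squared_error(y_true, y_pred):
    """Plain-Python mean squared error (illustrative stand-in)."""
    return sum((a - b) ** 2 for a, b in zip(y_true, y_pred)) / len(y_true)


def binary_cross_entropy(y_true, y_logits):
    """Mean binary cross entropy computed from raw logits (illustrative stand-in)."""
    total = 0.0
    for y, z in zip(y_true, y_logits):
        # Numerically stable form: max(z, 0) - z * y + log(1 + exp(-|z|))
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(y_true)


# Hypothetical registry standing in for LOSS_DICT; the key is not yet registered.
LOSS_DICT = {"mean squared error": mean_squared_error}


def get(identifier):
    # Mirrors the two branches visible in the traceback.
    if isinstance(identifier, str):
        return LOSS_DICT[identifier]  # raises KeyError for unknown names
    if callable(identifier):
        return identifier
    raise ValueError("unsupported loss identifier")  # assumption, not from the source


# 1) Unregistered name: same failure mode as in the traceback.
try:
    get("binary cross entropy")
    missing = False
except KeyError:
    missing = True

# 2) After registering the key, the string lookup succeeds.
LOSS_DICT["binary cross entropy"] = binary_cross_entropy
registered = get("binary cross entropy")

# 3) A callable is returned unchanged, with no dictionary lookup at all.
passthrough = get(binary_cross_entropy)

# Logits of 0 correspond to probability 0.5, so the loss equals log(2).
loss_value = registered([1.0, 0.0], [0.0, 0.0])
print(missing, loss_value)
```

Under this reading, passing the function object itself as `loss` would sidestep the dictionary, since the `callable` branch shown in the traceback returns it as-is. Whether that suits the rest of the training setup is not confirmed by the thread.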