-
Notifications
You must be signed in to change notification settings - Fork 163
Open
Description
我看了一下代码逻辑,由split后的probabilities和labels来计算accuracy, 下面这个代码块中加粗部分是否应该改成label_ids_split=tf.split(label_ids,FLAGS.num_aspects,axis=-1)? 这个是否与其他人po的eval_accuracy出错有关?
`def metric_fn(per_example_loss, label_ids, logits):
#print("###metric_fn.logits:",logits.shape) # (?,80)
#predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
#print("###metric_fn.label_ids:",label_ids.shape,";predictions:",predictions.shape) # label_ids: (?,80);predictions:(?,)
logits_split=tf.split(logits,FLAGS.num_aspects,axis=-1) # a list. length is num_aspects
label_ids_split=tf.split(logits,FLAGS.num_aspects,axis=-1) # a list. length is num_aspects
accuracy=tf.constant(0.0,dtype=tf.float64)
for j,logits in enumerate(logits_split): #
# accuracy = tf.metrics.accuracy(label_ids, predictions)
predictions=tf.argmax(logits, axis=-1, output_type=tf.int32) # should be [batch_size,]
label_id_=tf.cast(tf.argmax(label_ids_split[j],axis=-1),dtype=tf.int32)
print("label_ids_split[j]:",label_ids_split[j],";predictions:",predictions,";label_id_:",label_id_)
current_accuracy,update_op_accuracy=tf.metrics.accuracy(label_id_,predictions)
accuracy+=tf.cast(current_accuracy,dtype=tf.float64)
accuracy=accuracy/tf.constant(FLAGS.num_aspects,dtype=tf.float64)
loss = tf.metrics.mean(per_example_loss)
return {
"eval_accuracy": (accuracy,update_op_accuracy),
"eval_loss": loss,
}`
Metadata
Metadata
Assignees
Labels
No labels