
Is there a bug in this code that computes eval performance? #7

@PeterPanUnderhill


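I went through the code logic: accuracy is computed from the probabilities and labels after they have been split per aspect. In the code block below, shouldn't the line `label_ids_split=tf.split(logits,FLAGS.num_aspects,axis=-1)` be changed to `label_ids_split=tf.split(label_ids,FLAGS.num_aspects,axis=-1)`? Could this be related to the wrong eval_accuracy values that other people have posted about?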
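A minimal pure-Python sketch of the per-aspect accuracy logic, to check the reasoning above without TensorFlow. It assumes `logits` and `label_ids` are both shaped `[batch, num_aspects * num_classes]` with one-hot labels per aspect (the `(?,80)` shape comments in the snippet below suggest this, but it is an assumption). `split_last`, `argmax`, `mean_aspect_accuracy`, and the `split_labels_from_logits` flag are hypothetical names that only mimic `tf.split(..., axis=-1)` and `tf.argmax(..., axis=-1)`. If the labels are split from `logits`, every label slice has the same argmax as the prediction, so accuracy comes out as 1.0 whatever the real labels are.

```python
# Hypothetical pure-Python model of the per-aspect accuracy in metric_fn.
# Assumption: logits and label_ids are [batch, num_aspects * num_classes],
# and label_ids is one-hot within each aspect's slice.

def split_last(rows, num_splits):
    """Mimic tf.split(x, num_splits, axis=-1) on a list of equal-length rows."""
    width = len(rows[0]) // num_splits
    return [[row[k * width:(k + 1) * width] for row in rows] for k in range(num_splits)]

def argmax(values):
    """Index of the first maximum, like tf.argmax on a 1-D slice."""
    return max(range(len(values)), key=lambda i: values[i])

def mean_aspect_accuracy(logits, label_ids, num_aspects, split_labels_from_logits=False):
    logits_split = split_last(logits, num_aspects)
    # split_labels_from_logits=True reproduces the line in question:
    # the "labels" are taken from logits instead of label_ids.
    source = logits if split_labels_from_logits else label_ids
    label_ids_split = split_last(source, num_aspects)
    total = 0.0
    for j, aspect_logits in enumerate(logits_split):
        predictions = [argmax(row) for row in aspect_logits]
        labels = [argmax(row) for row in label_ids_split[j]]
        total += sum(p == l for p, l in zip(predictions, labels)) / len(labels)
    # Average the per-aspect accuracies, as the snippet does.
    return total / num_aspects

# Two examples, two aspects, two classes per aspect.
logits = [[0.9, 0.1, 0.2, 0.8], [0.3, 0.7, 0.6, 0.4]]
label_ids = [[1, 0, 1, 0], [1, 0, 1, 0]]
print(mean_aspect_accuracy(logits, label_ids, num_aspects=2))  # 0.5
print(mean_aspect_accuracy(logits, label_ids, num_aspects=2, split_labels_from_logits=True))  # 1.0
```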

```python
def metric_fn(per_example_loss, label_ids, logits):
    #print("###metric_fn.logits:",logits.shape) # (?,80)
    #predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    #print("###metric_fn.label_ids:",label_ids.shape,";predictions:",predictions.shape) # label_ids: (?,80);predictions:(?,)
    logits_split=tf.split(logits,FLAGS.num_aspects,axis=-1) # a list. length is num_aspects
    label_ids_split=tf.split(logits,FLAGS.num_aspects,axis=-1) # a list. length is num_aspects  <- the line in question
    accuracy=tf.constant(0.0,dtype=tf.float64)

    for j,logits in enumerate(logits_split):
        #  accuracy = tf.metrics.accuracy(label_ids, predictions)

        predictions=tf.argmax(logits, axis=-1, output_type=tf.int32) # should be [batch_size,]
        label_id_=tf.cast(tf.argmax(label_ids_split[j],axis=-1),dtype=tf.int32)
        print("label_ids_split[j]:",label_ids_split[j],";predictions:",predictions,";label_id_:",label_id_)
        current_accuracy,update_op_accuracy=tf.metrics.accuracy(label_id_,predictions)
        accuracy+=tf.cast(current_accuracy,dtype=tf.float64)
    accuracy=accuracy/tf.constant(FLAGS.num_aspects,dtype=tf.float64)
    loss = tf.metrics.mean(per_example_loss)
    return {
        "eval_accuracy": (accuracy,update_op_accuracy),
        "eval_loss": loss,
    }
```
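A related point in the same `metric_fn`, offered as an observation rather than a confirmed diagnosis: `tf.metrics.accuracy` returns a `(value, update_op)` pair, but only the last loop iteration's `update_op_accuracy` is returned under `"eval_accuracy"`, so the counters behind the other aspects' accuracies would never be updated during evaluation and would read 0. The toy model below is pure Python; `StreamingAccuracy` and `evaluate` are hypothetical and assume only that a streaming metric's value is total/count, is 0 while count is 0, and changes only when its update step runs. One common TF1 remedy is to collect every `update_op_accuracy` in a list and return `(accuracy, tf.group(*update_ops))`.

```python
# Toy model of streaming-metric semantics (assumed, not TensorFlow code).

class StreamingAccuracy:
    """Running total/count; value() is 0.0 until update() has been called."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, labels, predictions):
        self.total += sum(l == p for l, p in zip(labels, predictions))
        self.count += len(labels)

    def value(self):
        return self.total / self.count if self.count else 0.0

def evaluate(batches, num_aspects, run_all_updates):
    """batches: list of batches; each batch[j] is (labels, predictions) for aspect j."""
    metrics = [StreamingAccuracy() for _ in range(num_aspects)]
    for per_aspect in batches:
        for j, (labels, predictions) in enumerate(per_aspect):
            # run_all_updates=False models returning only the last aspect's update op.
            if run_all_updates or j == num_aspects - 1:
                metrics[j].update(labels, predictions)
    # Average the per-aspect values, as metric_fn does.
    return sum(m.value() for m in metrics) / num_aspects

# One batch, two aspects, every prediction correct.
batches = [[([0, 1], [0, 1]), ([2, 3], [2, 3])]]
print(evaluate(batches, 2, run_all_updates=True))   # 1.0
print(evaluate(batches, 2, run_all_updates=False))  # 0.5
```

Under this model, the reported average equals the last aspect's accuracy divided by `num_aspects`, which would combine with the label-split issue above to give a misleading eval_accuracy.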
