
TestPredNet Loss Calculation is NOT in the Loop #5

@Henry-Louis


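In several places where a testpred network is optimized, the forward pass and loss computation sit before the training loop instead of inside it. The loss is computed only once, so every iteration calls backward() on the same graph built from the initial parameters. The loss value never changes, and the gradients come from that stale graph rather than from a fresh forward pass with the updated weights. With recent PyTorch versions the second backward() may instead fail with an in-place modification error, because optimizer.step() modifies tensors saved in that graph. For example: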

    # below, currently only predict using the learned representation, consider including vaez too.
    baselinevaetestpred_trainX, baselinevaetestpred_trainy = baselinevaemlp(envs[1]['vaez'][::2])[0], envs[1]['labels'][::2]
    baselinevaetestpred_testX, baselinevaetestpred_testy = baselinevaemlp(envs[1]['vaez'][1::2])[0], envs[1]['labels'][1::2]
    baselinevaeprediction = baselinevaetestprednet(baselinevaetestpred_trainX)     # input x and predict based on x
    baselinevaeloss = loss_func(nn.Sigmoid()(baselinevaeprediction), baselinevaetestpred_trainy)     # must be (1. nn output, 2. target)


    for t in range(200):
        optimizer_baselinevaetestpred.zero_grad()   # clear gradients for next train
        baselinevaeloss.backward(retain_graph=True)         # backpropagation, compute gradients
        optimizer_baselinevaetestpred.step()        # apply gradients

It should be (prediction and loss recomputed on every iteration):

    # below, currently only predict using the learned representation, consider including vaez too.
    baselinevaetestpred_trainX, baselinevaetestpred_trainy = baselinevaemlp(envs[1]['vaez'][::2])[0], envs[1]['labels'][::2]
    baselinevaetestpred_testX, baselinevaetestpred_testy = baselinevaemlp(envs[1]['vaez'][1::2])[0], envs[1]['labels'][1::2]

    for t in range(200):
        # forward pass and loss are recomputed every iteration so the gradients reflect the current weights
        baselinevaeprediction = baselinevaetestprednet(baselinevaetestpred_trainX)     # input x and predict based on x
        baselinevaeloss = loss_func(nn.Sigmoid()(baselinevaeprediction), baselinevaetestpred_trainy)     # must be (1. nn output, 2. target)
        optimizer_baselinevaetestpred.zero_grad()   # clear gradients for next train
        # retain_graph=True is likely still needed: trainX is not detached, so every iteration backpropagates through baselinevaemlp's shared graph
        baselinevaeloss.backward(retain_graph=True)         # backpropagation, compute gradients
        optimizer_baselinevaetestpred.step()        # apply gradients
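
To make the difference concrete outside PyTorch, here is a minimal pure-Python sketch. It is an analogy, not the project's code: a one-parameter least-squares model on hypothetical data (y = 2x) stands in for baselinevaetestprednet, and loss_and_grad, train_stale and train_fresh are hypothetical helpers. train_stale mimics computing the loss once before the loop, so (with zero_grad() each step) the same initial gradient is re-applied on every step; train_fresh mimics the corrected loop. It assumes plain gradient descent, whereas the real optimizer may differ, and it ignores that PyTorch may raise an error instead of silently reusing the graph.

```python
from typing import List, Tuple

# Hypothetical toy data (not from the original project): y = 2 * x exactly.
XS: List[float] = [1.0, 2.0, 3.0]
YS: List[float] = [2.0, 4.0, 6.0]


def loss_and_grad(w: float) -> Tuple[float, float]:
    """Forward + backward pass for the model y_hat = w * x.

    Returns the mean squared error and its derivative with respect to w.
    """
    n = len(XS)
    loss = sum((w * x - y) ** 2 for x, y in zip(XS, YS)) / n
    grad = sum(2.0 * x * (w * x - y) for x, y in zip(XS, YS)) / n
    return loss, grad


def train_stale(w: float, lr: float, steps: int) -> float:
    # Mirrors the buggy loop: loss and gradient computed ONCE, before the loop,
    # so every step applies the gradient taken at the initial parameters.
    _, grad = loss_and_grad(w)
    for _ in range(steps):
        w -= lr * grad
    return w


def train_fresh(w: float, lr: float, steps: int) -> float:
    # Mirrors the fixed loop: forward pass and loss recomputed every iteration,
    # so each step uses the gradient at the current parameters.
    for _ in range(steps):
        _, grad = loss_and_grad(w)
        w -= lr * grad
    return w


if __name__ == "__main__":
    w0, lr, steps = 0.0, 0.01, 200
    w_stale = train_stale(w0, lr, steps)
    w_fresh = train_fresh(w0, lr, steps)
    print(f"initial loss: {loss_and_grad(w0)[0]:.4f}")
    print(f"stale loop:   w = {w_stale:.4f}, loss = {loss_and_grad(w_stale)[0]:.4f}")
    print(f"fresh loop:   w = {w_fresh:.4f}, loss = {loss_and_grad(w_fresh)[0]:.4f}")
```

With these settings the stale version keeps stepping in the initial direction, overshoots w = 2 and ends with a much higher loss than it started with, while the fresh version converges to w = 2.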
