-
Notifications
You must be signed in to change notification settings - Fork 131
Open
Description
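Is this the right way to do transfer learning? The code below loads a pretrained regressor, attaches a new MLP head to its predictor, and trains the combined model. It runs, but I'm not sure this is the correct process.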
# Set up the regressor for transfer learning: reuse the predictor of a
# pretrained regressor and attach a new MLP head to it.
device = chainer.get_device(args.device)

# Load the pretrained regressor from the path built from the command-line
# arguments, so the pickle (e.g. result/pretrain_qm9.pkl) is selected via
# args.in_dir / args.model_filename instead of being hard-coded.
model_path = os.path.join(args.in_dir, args.model_filename)
regressor = Regressor.load_pickle(model_path, device=device)

# Metrics reported as main/mae and main/rmse (and plotted below).
metrics_fun = {'mae': F.mean_absolute_error, 'rmse': rmse}

# New output head. NOTE(review): out_dim=class_num presumably equals the
# number of regression targets in `dataset` -- TODO confirm.
mlp = MLP(out_dim=class_num, hidden_dim=args.unit_num)

# Combine the pretrained predictor with the new MLP head.
# NOTE(review): nothing here freezes the pretrained layers, so presumably
# they are fine-tuned together with the MLP. To train only the head,
# consider calling predictor.disable_update() -- verify it fits your setup.
predictor = regressor.predictor
new_predictor = GraphConvPredictor(predictor, mlp=mlp)
new_regressor = Regressor(new_predictor, lossfun=F.mean_squared_error,
                          metrics_fun=metrics_fun, device=device)

# One PlotReport per reported quantity. All files share the
# 'trans-megnet-full-' prefix.
# NOTE(review): valid=None is passed to run_train below, so the
# 'validation/main/*' series will presumably be empty unless a validation
# set is supplied.
plot_reports = [
    extensions.PlotReport(
        ['main/{}'.format(key), 'validation/main/{}'.format(key)], 'epoch',
        filename='trans-megnet-full-{}.svg'.format(key), marker='None')
    for key in ('loss', 'rmse', 'mae')
]

print('Training...')
run_train(new_regressor, dataset, valid=None, batch_size=args.batchsize,
          epoch=args.epoch, out=args.out, device=device,
          converter=megnet_converter, resume_path=None,
          extensions_list=plot_reports)
Metadata
Metadata
Assignees
Labels
No labels