diff --git a/examples/trompt.py b/examples/trompt.py
index 342ab653e..41d53d379 100644
--- a/examples/trompt.py
+++ b/examples/trompt.py
@@ -14,7 +14,6 @@
 helena : 37.90
 jannis : 72.98
 """
-
 import argparse
 import os.path as osp
 
@@ -27,6 +26,10 @@
 from torch_frame.datasets import TabularBenchmark
 from torch_frame.nn import Trompt
 
+# Use TF32 for faster matrix multiplication on Ampere GPUs.
+# https://dev-discuss.pytorch.org/t/pytorch-and-tensorfloat32/504
+torch.set_float32_matmul_precision('high')
+
 parser = argparse.ArgumentParser()
 parser.add_argument("--dataset", type=str, default="california")
 parser.add_argument("--channels", type=int, default=128)
@@ -64,12 +67,23 @@
 train_tensor_frame = train_dataset.tensor_frame
 val_tensor_frame = val_dataset.tensor_frame
 test_tensor_frame = test_dataset.tensor_frame
-train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
-                          shuffle=True)
-val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
-test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)
+train_loader = DataLoader(
+    train_tensor_frame,
+    batch_size=args.batch_size,
+    shuffle=True,
+    pin_memory=True,
+)
+val_loader = DataLoader(
+    val_tensor_frame,
+    batch_size=args.batch_size,
+    pin_memory=True,
+)
+test_loader = DataLoader(
+    test_tensor_frame,
+    batch_size=args.batch_size,
+    pin_memory=True,
+)
 
-# Set up model and optimizer
 model = Trompt(
     channels=args.channels,
     out_channels=dataset.num_classes,
@@ -79,59 +93,69 @@
     col_names_dict=train_tensor_frame.col_names_dict,
 ).to(device)
 model = torch.compile(model, dynamic=True) if args.compile else model
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, fused=True)
 lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
 
 
-def train(epoch: int) -> float:
+def train(epoch: int) -> torch.Tensor:
     model.train()
-    loss_accum = total_count = 0
+    loss_accum = torch.zeros(1, device=device, dtype=torch.float32).squeeze_()
+    total_count = 0
 
-    for tf in tqdm(train_loader, desc=f"Epoch: {epoch}"):
-        tf = tf.to(device)
+    for tf in tqdm(train_loader, desc=f"Epoch {epoch:3d}"):
+        tf = tf.to(device, non_blocking=True)
         # [batch_size, num_layers, num_classes]
         out = model(tf)
-        num_layers = out.size(1)
+        batch_size, num_layers, num_classes = out.size()
         # [batch_size * num_layers, num_classes]
-        pred = out.view(-1, dataset.num_classes)
-        y = tf.y.repeat_interleave(num_layers)
+        pred = out.view(-1, num_classes)
+        y = tf.y.repeat_interleave(
+            num_layers,
+            output_size=num_layers * batch_size,
+        )
         # Layer-wise logit loss
         loss = F.cross_entropy(pred, y)
-        optimizer.zero_grad()
         loss.backward()
-        loss_accum += float(loss) * len(tf.y)
-        total_count += len(tf.y)
         optimizer.step()
+        optimizer.zero_grad()
+
+        total_count += len(tf.y)
+        loss *= len(tf.y)
+        loss_accum += loss
+
+    lr_scheduler.step()
     return loss_accum / total_count
 
 
 @torch.no_grad()
-def test(loader: DataLoader) -> float:
+def test(loader: DataLoader, desc: str) -> torch.Tensor:
     model.eval()
-    accum = total_count = 0
+    accum = torch.zeros(1, device=device, dtype=torch.long).squeeze_()
+    total_count = 0
 
-    for tf in loader:
-        tf = tf.to(device)
+    for tf in tqdm(loader, desc=desc):
+        tf = tf.to(device, non_blocking=True)
         pred = model(tf).mean(dim=1)
         pred_class = pred.argmax(dim=-1)
-        accum += float((tf.y == pred_class).sum())
+        accum += (tf.y == pred_class).sum()
         total_count += len(tf.y)
 
     return accum / total_count
 
 
-best_val_acc = 0
-best_test_acc = 0
+best_val_acc = 0.0
+best_test_acc = 0.0
 for epoch in range(1, args.epochs + 1):
     train_loss = train(epoch)
-    train_acc = test(train_loader)
-    val_acc = test(val_loader)
-    test_acc = test(test_loader)
+    train_acc = test(train_loader, "Eval (train)")
+    val_acc = test(val_loader, "Eval (val)")
     if best_val_acc < val_acc:
         best_val_acc = val_acc
-        best_test_acc = test_acc
-    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
-          f"Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")
-    lr_scheduler.step()
+        best_test_acc = test(test_loader, "Eval (test)")
+
+    print(f"Train Loss: {train_loss:.4f}, "
+          f"Train Acc: {train_acc:.4f}, "
+          f"Val Acc: {val_acc:.4f}, "
+          f"Test Acc: {best_test_acc:.4f}")
 
 print(f"Best Val Acc: {best_val_acc:.4f}, Best Test Acc: {best_test_acc:.4f}")
diff --git a/torch_frame/nn/conv/trompt_conv.py b/torch_frame/nn/conv/trompt_conv.py
index 1627ccc31..a41f1136a 100644
--- a/torch_frame/nn/conv/trompt_conv.py
+++ b/torch_frame/nn/conv/trompt_conv.py
@@ -92,8 +92,7 @@ def forward(self, x: Tensor, x_prompt: Tensor) -> Tensor:
         # M_importance
         # [batch_size, num_prompts, channels], [batch_size, num_cols, channels]
         # -> [batch_size, num_prompts, num_cols]
-        m_importance = torch.einsum('ijl,ikl->ijk', stacked_e_prompt,
-                                    stacked_e_column)
+        m_importance = stacked_e_prompt @ stacked_e_column.transpose(1, 2)
         m_importance = F.softmax(m_importance, dim=-1)
         # [batch_size, num_prompts, num_cols, 1]
         m_importance = m_importance.unsqueeze(dim=-1)
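For review purposes, a minimal standalone sanity check (not part of the diff; tensor names and sizes below are illustrative only) for the two behavior-relevant substitutions above: the batched matmul introduced in trompt_conv.py should match the einsum it replaces, and repeat_interleave with an explicit output_size should return the same labels as before.

import torch

# Hypothetical shapes, chosen only for this check.
batch_size, num_prompts, num_cols, channels = 4, 8, 16, 32
stacked_e_prompt = torch.randn(batch_size, num_prompts, channels)
stacked_e_column = torch.randn(batch_size, num_cols, channels)

# einsum('ijl,ikl->ijk', a, b) contracts the channel dim, which is exactly
# a batched matmul against b with its last two dims transposed.
ref = torch.einsum('ijl,ikl->ijk', stacked_e_prompt, stacked_e_column)
out = stacked_e_prompt @ stacked_e_column.transpose(1, 2)
assert out.shape == (batch_size, num_prompts, num_cols)
assert torch.allclose(ref, out, atol=1e-6)

# output_size is only a size hint for repeat_interleave; values are unchanged.
y = torch.tensor([0, 1, 2, 1])
num_layers = 6
assert torch.equal(
    y.repeat_interleave(num_layers),
    y.repeat_interleave(num_layers, output_size=num_layers * len(y)),
)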