
Commit 2c094fc

updated visudo experiment
1 parent 2a417a0 commit 2c094fc

File tree

9 files changed: +2088 additions, -70 deletions


experiments/visual_sudoku/run.py

Lines changed: 40 additions & 70 deletions
@@ -1,43 +1,21 @@
 import argparse
-import io
-import zipfile
 from pathlib import Path
 from time import perf_counter

-import requests
-
 import klay
 import numpy as np
 import torch
 import torch.nn as nn
 import torchvision.transforms as transforms
-from klay.utils import torch_wmc_d4
 from torch.utils.data import Dataset


-def download_visudo_dataset(grid_size: int):
-    data_path = Path(__file__).parent / Path("tmp")
-    if data_path.exists():
-        return
-
-    print("-> Downloading Visual Sudoku Dataset...")
-    r = requests.get(f"https://linqs-data.soe.ucsc.edu/public/datasets/ViSudo-PC/v01/"
-                     f"ViSudo-PC_dimension::{grid_size}_datasets::mnist_strategy::simple.zip")
-    print("-> Extracting...")
-    z = zipfile.ZipFile(io.BytesIO(r.content))
-    z.extractall(Path(__file__).parent)
-
-
 class SudokuDataset(Dataset):
     def __init__(self, partition: str, grid_size: int = 4, transform=None):
         super().__init__()
-        data_path = Path(__file__).parent / (f"tmp/ViSudo-PC/ViSudo-PC_dimension::4_datasets::"
-                                             f"mnist_strategy::simple/dimension::{grid_size}/datasets:"
-                                             f":mnist/strategy::simple/strategy::simple/numTrain::00100/"
-                                             f"numTest::00100/numValid::00100/corruptChance::0.50/"
-                                             f"overlap::0.00/split::11")
-        features_file = Path(data_path) / f'{partition}_puzzle_pixels.txt'
-        labels_file = Path(data_path) / f'{partition}_puzzle_labels.txt'
+        data_path = Path(__file__).parent / f"visudo{grid_size}"
+        features_file = data_path / f'{partition}_puzzle_pixels.txt'
+        labels_file = data_path / f'{partition}_puzzle_labels.txt'
         labels = np.loadtxt(labels_file, delimiter="\t", dtype=bool)
         features = np.loadtxt(features_file, delimiter="\t", dtype=np.float32)
         self.images = torch.as_tensor(features)
@@ -54,7 +32,6 @@ def __getitem__(self, idx: int):


 def get_dataloader(grid_size: int, partition: str, batch_size: int):
-    download_visudo_dataset(grid_size)
     normalize = transforms.Normalize((0.1307,), (0.3081,))
     train_dataset = SudokuDataset(partition, grid_size, transform=normalize)
     return torch.utils.data.DataLoader(
@@ -103,90 +80,83 @@ def __init__(self, grid_size: int):

     def forward(self, images):
         shape = images.shape
-        assert not torch.isnan(images).any()
         images = images.reshape(-1, 1, 28, 28)
         image_probs = self.net(images)
         assert not torch.isnan(image_probs).any()
         image_probs = image_probs.reshape(shape[0], -1)
         return self.circuit_batched(image_probs, torch.zeros_like(image_probs))


-class VisualSudokuNaive(VisualSudokuModule):
-    def __init__(self, grid_size: int):
-        super().__init__(grid_size)
-        self.net = LeNet(grid_size)
-        self.circuit = None
-        nnf_file = f"experiments/visual_sudoku/sudoku_{grid_size}.nnf"
-        self.circuit_batched = lambda x, y: torch_wmc_d4(nnf_file, x, y)
-        self.grid_size = grid_size
-
-
 def get_circuit(grid_size: int):
     circuit = klay.Circuit()
-    const_lits = [] # [-x for x in range(1, grid_size**3+1)]
-    circuit.add_d4_from_file(f"experiments/visual_sudoku/sudoku_{grid_size}.nnf", true_lits = const_lits)
+    const_lits = [-x for x in range(1, grid_size ** 3 + 1)]
+    circuit.add_d4_from_file(f"experiments/visual_sudoku/sudoku_{grid_size}.nnf", true_lits=const_lits)
     print("Nb nodes", circuit.nb_nodes())
     return circuit.to_torch_module()


 def nll_loss(preds, targets):
-    neg_preds = klay.backends.torch_backend.log1mexp(preds)
+    neg_preds = klay.torch.log1mexp(preds, eps=1e-7)
     nll = -torch.where(targets, preds, neg_preds)
     return nll.mean()


-def main(grid_size: int, batch_size: int, nb_epochs: int, learning_rate: float, naive=False, device="cuda"):
+def train(model, optimizer, dataloader, device="cuda"):
+    losses = []
+    for xs, ys in dataloader:
+        xs, ys = xs.to(device), ys.to(device)
+        preds = model(xs)
+        loss = nll_loss(preds[0], ys)
+        losses.append(loss.item())
+        assert not torch.isnan(loss).any()
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
+        optimizer.step()
+        optimizer.zero_grad()
+    return losses
+
+
+def evaluate(model, dataloader, device="cuda"):
+    model = model.eval()
+    accs = []
+    for xs, ys in dataloader:
+        xs, ys = xs.to(device), ys.to(device)
+        preds = model(xs).exp()
+        acc = (preds[0] > 0.5) == ys
+        accs += acc.tolist()
+    return accs
+
+
+def main(grid_size: int, batch_size: int, nb_epochs: int, learning_rate: float, device="cuda"):
     train_dataloader = get_dataloader(grid_size, "train", batch_size)
-    if naive:
-        model = VisualSudokuNaive(grid_size).to(device)
-    else:
-        model = VisualSudokuModule(grid_size).to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.00001)
+    model = VisualSudokuModule(grid_size).to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-7)
     timings = []

     for epoch in range(nb_epochs):
-        losses = []
         t1 = perf_counter()
-        for xs, ys in train_dataloader:
-            xs, ys = xs.to(device), ys.to(device)
-            preds = model(xs)
-            loss = nll_loss(preds[0], ys)
-            losses.append(loss.item())
-            assert not torch.isnan(loss).any()
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
-            optimizer.step()
-            optimizer.zero_grad()
+        losses = train(model, optimizer, train_dataloader, device)
         timings.append(perf_counter() - t1)
         print(f"Epoch {epoch}, Loss {np.mean(losses):.5f}")
-
     print(f"Mean Epoch Time (s) {np.mean(timings):.3f} ± {np.std(timings):.3f}")

-    model = model.eval()
     val_dataloader = get_dataloader(grid_size, "valid", 1)
-    accs = []
-    for xs, ys in val_dataloader:
-        xs, ys = xs.to(device), ys.to(device)
-        preds = model(xs).exp()
-        acc = (preds[0] > 0.5) == ys
-        accs += acc.tolist()
-    print(f"Validation Accuracy {np.mean(accs):.5f}")
+    accs = evaluate(model, val_dataloader, device)
+    print(f"Validation Accuracy {100*np.mean(accs):.2f}%")


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('-b', '--batch_size', type=int, default=4)
+    parser.add_argument('-b', '--batch_size', type=int, default=1)
     parser.add_argument('-e', '--nb_epochs', type=int, default=10)
     parser.add_argument('-d', '--device', default='cpu')
-    parser.add_argument('-lr', '--learning_rate', type=float, default=0.0003)
-    parser.add_argument("-n", '--naive', action=argparse.BooleanOptionalAction, default=False)
+    parser.add_argument('-lr', '--learning_rate', type=float, default=0.001)
     args = parser.parse_args()

     main(
         grid_size=4,
         batch_size=args.batch_size,
         nb_epochs=args.nb_epochs,
         learning_rate=args.learning_rate,
-        naive=args.naive,
         device=args.device
     )
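
Aside on the loss change above: nll_loss scores a false label with log(1 - exp(x)), where x <= 0 is the circuit's log-probability output, and the commit swaps in klay.torch.log1mexp(preds, eps=1e-7); the eps argument presumably keeps the input strictly below zero so the complement never hits log(0). The snippet below is only an illustrative, self-contained sketch of such a numerically stable log1mexp (the usual two-branch formulation), not KLay's actual implementation.

import math

import torch


def log1mexp_sketch(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # Stable log(1 - exp(x)) for log-probabilities x <= 0 (illustrative sketch only).
    # The clamp is an assumption about what eps does: it keeps x strictly negative
    # so the complement probability never collapses to exactly 0.
    x = x.clamp(max=-eps)
    # Two-branch rule: log1p(-exp(x)) is accurate for very negative x,
    # log(-expm1(x)) is accurate for x close to 0.
    return torch.where(
        x < -math.log(2),
        torch.log1p(-torch.exp(x)),
        torch.log(-torch.expm1(x)),
    )


# Example: complement of probability 0.9, computed entirely in log-space.
log_p = torch.tensor([0.9]).log()
print(log1mexp_sketch(log_p).exp())  # ~tensor([0.1000])

With a helper like this, the torch.where(targets, preds, neg_preds) call in nll_loss simply picks the log-probability of the observed label: preds for positive puzzles and its log-complement for negative ones.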

0 commit comments
