11import argparse
2- import io
3- import zipfile
42from pathlib import Path
53from time import perf_counter
64
7- import requests
8-
95import klay
106import numpy as np
117import torch
128import torch .nn as nn
139import torchvision .transforms as transforms
14- from klay .utils import torch_wmc_d4
1510from torch .utils .data import Dataset
1611
1712
18- def download_visudo_dataset (grid_size : int ):
19- data_path = Path (__file__ ).parent / Path ("tmp" )
20- if data_path .exists ():
21- return
22-
23- print ("-> Downloading Visual Sudoku Dataset..." )
24- r = requests .get (f"https://linqs-data.soe.ucsc.edu/public/datasets/ViSudo-PC/v01/"
25- f"ViSudo-PC_dimension::{ grid_size } _datasets::mnist_strategy::simple.zip" )
26- print ("-> Extracting..." )
27- z = zipfile .ZipFile (io .BytesIO (r .content ))
28- z .extractall (Path (__file__ ).parent )
29-
30-
3113class SudokuDataset (Dataset ):
3214 def __init__ (self , partition : str , grid_size : int = 4 , transform = None ):
3315 super ().__init__ ()
34- data_path = Path (__file__ ).parent / (f"tmp/ViSudo-PC/ViSudo-PC_dimension::4_datasets::"
35- f"mnist_strategy::simple/dimension::{ grid_size } /datasets:"
36- f":mnist/strategy::simple/strategy::simple/numTrain::00100/"
37- f"numTest::00100/numValid::00100/corruptChance::0.50/"
38- f"overlap::0.00/split::11" )
39- features_file = Path (data_path ) / f'{ partition } _puzzle_pixels.txt'
40- labels_file = Path (data_path ) / f'{ partition } _puzzle_labels.txt'
16+ data_path = Path (__file__ ).parent / f"visudo{ grid_size } "
17+ features_file = data_path / f'{ partition } _puzzle_pixels.txt'
18+ labels_file = data_path / f'{ partition } _puzzle_labels.txt'
4119 labels = np .loadtxt (labels_file , delimiter = "\t " , dtype = bool )
4220 features = np .loadtxt (features_file , delimiter = "\t " , dtype = np .float32 )
4321 self .images = torch .as_tensor (features )
@@ -54,7 +32,6 @@ def __getitem__(self, idx: int):
5432
5533
5634def get_dataloader (grid_size : int , partition : str , batch_size : int ):
57- download_visudo_dataset (grid_size )
5835 normalize = transforms .Normalize ((0.1307 ,), (0.3081 ,))
5936 train_dataset = SudokuDataset (partition , grid_size , transform = normalize )
6037 return torch .utils .data .DataLoader (
@@ -103,90 +80,83 @@ def __init__(self, grid_size: int):
10380
10481 def forward (self , images ):
10582 shape = images .shape
106- assert not torch .isnan (images ).any ()
10783 images = images .reshape (- 1 , 1 , 28 , 28 )
10884 image_probs = self .net (images )
10985 assert not torch .isnan (image_probs ).any ()
11086 image_probs = image_probs .reshape (shape [0 ], - 1 )
11187 return self .circuit_batched (image_probs , torch .zeros_like (image_probs ))
11288
11389
114- class VisualSudokuNaive (VisualSudokuModule ):
115- def __init__ (self , grid_size : int ):
116- super ().__init__ (grid_size )
117- self .net = LeNet (grid_size )
118- self .circuit = None
119- nnf_file = f"experiments/visual_sudoku/sudoku_{ grid_size } .nnf"
120- self .circuit_batched = lambda x , y : torch_wmc_d4 (nnf_file , x , y )
121- self .grid_size = grid_size
122-
123-
def get_circuit(grid_size: int):
    """Load the compiled sudoku d4 circuit and return it as a torch module.

    Passes every negated input literal (-1 .. -grid_size**3) as `true_lits`
    when loading the NNF file — presumably conditioning all inputs to
    false; verify against the klay docs.
    """
    nnf_path = f"experiments/visual_sudoku/sudoku_{grid_size}.nnf"
    negated_lits = [-lit for lit in range(1, grid_size ** 3 + 1)]
    circuit = klay.Circuit()
    circuit.add_d4_from_file(nnf_path, true_lits=negated_lits)
    print("Nb nodes", circuit.nb_nodes())
    return circuit.to_torch_module()
13096
13197
def nll_loss(preds, targets):
    """Mean negative log-likelihood of boolean `targets` under `preds`.

    `preds` holds log-probabilities of the positive outcome; `log1mexp`
    gives the numerically stable log-probability of the complement, and
    the boolean target selects which of the two is penalised.
    """
    log_neg = klay.torch.log1mexp(preds, eps=1e-7)
    picked = torch.where(targets, preds, log_neg)
    return (-picked).mean()
136102
137103
138- def main (grid_size : int , batch_size : int , nb_epochs : int , learning_rate : float , naive = False , device = "cuda" ):
def train(model, optimizer, dataloader, device="cuda"):
    """Run one training epoch and return the per-batch loss values.

    Args:
        model: module mapping a batch of puzzles to log-probabilities;
            `preds[0]` is taken as the log-probability of puzzle validity.
        optimizer: torch optimizer over `model.parameters()`.
        dataloader: iterable of `(images, labels)` batches.
        device: device the batches are moved to.

    Returns:
        list[float]: the loss of every batch, in iteration order.

    Raises:
        RuntimeError: if a batch produces a NaN loss.
    """
    # Restore train mode in case `evaluate` left the model in eval mode.
    model.train()
    losses = []
    for xs, ys in dataloader:
        xs, ys = xs.to(device), ys.to(device)
        preds = model(xs)
        loss = nll_loss(preds[0], ys)
        losses.append(loss.item())
        # Fail fast on numerical blow-ups; an `assert` here would be
        # stripped under `python -O`.
        if torch.isnan(loss).any():
            raise RuntimeError("NaN loss encountered during training")
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        optimizer.step()
        optimizer.zero_grad()
    return losses
117+
118+
def evaluate(model, dataloader, device="cuda"):
    """Return per-example correctness of thresholded predictions.

    Switches the model to eval mode and disables gradient tracking,
    exponentiates the model's (log-probability) outputs, thresholds
    `preds[0]` at 0.5, and compares against the boolean labels.

    Args:
        model: module mapping a batch of puzzles to log-probabilities.
        dataloader: iterable of `(images, labels)` batches.
        device: device the batches are moved to.

    Returns:
        list[bool]: one entry per example, True where the prediction
        matches the label.
    """
    model = model.eval()
    accs = []
    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        for xs, ys in dataloader:
            xs, ys = xs.to(device), ys.to(device)
            preds = model(xs).exp()
            acc = (preds[0] > 0.5) == ys
            accs += acc.tolist()
    return accs
128+
129+
def main(grid_size: int, batch_size: int, nb_epochs: int, learning_rate: float, device="cuda"):
    """Train a visual sudoku model and report validation accuracy.

    Trains for `nb_epochs` epochs, printing the mean loss per epoch and
    the mean/std epoch wall-clock time, then evaluates on the validation
    partition one puzzle at a time.
    """
    train_dataloader = get_dataloader(grid_size, "train", batch_size)
    model = VisualSudokuModule(grid_size).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-7)

    timings = []
    for epoch in range(nb_epochs):
        start = perf_counter()
        losses = train(model, optimizer, train_dataloader, device)
        timings.append(perf_counter() - start)
        print(f"Epoch {epoch}, Loss {np.mean(losses):.5f}")
    print(f"Mean Epoch Time (s) {np.mean(timings):.3f} ± {np.std(timings):.3f}")

    # Validate with batch size 1.
    val_dataloader = get_dataloader(grid_size, "valid", 1)
    accs = evaluate(model, val_dataloader, device)
    print(f"Validation Accuracy {100 * np.mean(accs):.2f}%")
174146
175147
if __name__ == "__main__":
    # Command-line entry point; the grid size is fixed at 4x4 puzzles.
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--batch_size', type=int, default=1)
    parser.add_argument('-e', '--nb_epochs', type=int, default=10)
    parser.add_argument('-d', '--device', default='cpu')
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.001)
    args = parser.parse_args()

    main(
        grid_size=4,
        batch_size=args.batch_size,
        nb_epochs=args.nb_epochs,
        learning_rate=args.learning_rate,
        device=args.device,
    )
0 commit comments