Skip to content

Commit d51df8d

Browse files
Test mnist training (#61)
Added MNIST training test case. This requires torch-vision, so added it to the requirement. Furthermore, the change in cpu pass to `torch-to-iree`, which needs the change in IREE pass to be landed and reflected in the IREE release, will fail all the tests with `empty_strided` involved. As soon as IREE's new release is available, the xfails can be removed.
1 parent ce2f5b0 commit d51df8d

File tree

4 files changed

+165
-4
lines changed

4 files changed

+165
-4
lines changed

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
-f https://openxla.github.io/iree/pip-release-links.html
66

77
-r pytorch-cpu-requirements.txt
8+
-r torchvision-requirements.txt
89

9-
iree-compiler==20230914.645
10-
iree-runtime==20230914.645
10+
iree-compiler==20230920.651
11+
iree-runtime==20230920.651

tests/dynamo/importer_basic_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,6 @@ def forward(self, x):
255255

256256
@unittest.expectedFailure
257257
def testImportAtenFull(self):
258-
"""Expected to fail until torch-mlir op: torch.aten.empty_strided is implemented"""
259-
260258
def foo(x):
261259
return torch.full(x.size(), fill_value=float("-inf"))
262260

tests/dynamo/mninst_test.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright 2023 Nod Labs, Inc
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import logging
8+
9+
import math
10+
import unittest
11+
from dataclasses import dataclass
12+
from typing import Any, Optional, Tuple
13+
14+
import torch
15+
import torch.nn.functional as F
16+
from torch import nn
17+
import torch.optim as optim
18+
19+
import torchvision.transforms as transforms
20+
import torchvision.datasets as datasets
21+
from torch.utils.data import DataLoader
22+
23+
24+
# MNIST Data Loader
25+
class MNISTDataLoader:
26+
def __init__(self, batch_size, shuffle=True):
27+
self.batch_size = batch_size
28+
self.shuffle = shuffle
29+
30+
# Data Transformations
31+
transform = transforms.Compose([
32+
transforms.ToTensor(),
33+
transforms.Normalize((0.5,), (0.5,))
34+
])
35+
36+
# Download MNIST dataset
37+
self.mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
38+
self.mnist_testset = datasets.MNIST(root='../data', train=False, download=True, transform=transform)
39+
40+
def get_train_loader(self):
41+
return DataLoader(
42+
dataset=self.mnist_trainset,
43+
batch_size=self.batch_size,
44+
shuffle=self.shuffle
45+
)
46+
47+
def get_test_loader(self):
48+
return DataLoader(
49+
dataset=self.mnist_testset,
50+
batch_size=self.batch_size,
51+
shuffle=False
52+
)
53+
54+
55+
# Simple CNN Model
56+
class CNN(nn.Module):
57+
def __init__(self):
58+
super(CNN, self).__init__()
59+
self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
60+
self.relu = nn.ReLU()
61+
self.maxpool = nn.MaxPool2d(kernel_size=2)
62+
self.fc1 = nn.Linear(32 * 12 * 12, 10)
63+
64+
def forward(self, x):
65+
x = self.conv1(x)
66+
x = self.relu(x)
67+
x = self.maxpool(x)
68+
x = x.view(x.size(0), -1)
69+
x = self.fc1(x)
70+
return x
71+
72+
# Training
73+
def train(model, images, labels, optimizer, criterion):
74+
model.train()
75+
76+
total_loss = 0.0
77+
num_correct = 0.0
78+
79+
optimizer.zero_grad()
80+
# images, labels = images.to(device), labels.to(device)
81+
outputs = model(images)
82+
loss = criterion(outputs, labels)
83+
84+
num_correct += int((torch.argmax(outputs, dim=1) == labels).sum())
85+
total_loss += float(loss.item())
86+
87+
loss.backward()
88+
optimizer.step()
89+
total_loss += loss.item()
90+
91+
# TODO Implement inference func
92+
"""
93+
def test(model, images, labels, criterion):
94+
model.eval()
95+
num_correct = 0.0
96+
total_loss = 0.0
97+
with torch.no_grad():
98+
99+
# images, labels = images.to(device), labels.to(device)
100+
with torch.inference_mode():
101+
outputs = model(images)
102+
loss = criterion(outputs, labels)
103+
104+
num_correct += int((torch.argmax(outputs, dim=1) == labels).sum())
105+
total_loss += float(loss.item())
106+
107+
# acc = 100 * num_correct / (config['batch_size'] * len(test_loader))
108+
# total_loss = float(total_loss / len(test_loader))
109+
# return acc, total_loss
110+
"""
111+
112+
def main():
113+
# Example Hyperparameters
114+
config = {
115+
'batch_size': 64,
116+
'learning_rate': 0.001,
117+
# 'threshold' : 0.001,
118+
# 'factor' : 0.1,
119+
'num_epochs': 10,
120+
}
121+
122+
# Data Loader
123+
custom_data_loader = MNISTDataLoader(config['batch_size'])
124+
train_loader = custom_data_loader.get_train_loader()
125+
# test_loader = MNISTDataLoader.get_test_loader()
126+
127+
# Model, optimizer, loss
128+
model = CNN()
129+
optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
130+
criterion = nn.CrossEntropyLoss()
131+
132+
# Training
133+
train_opt = torch.compile(train, backend="turbine_cpu")
134+
for i, (images, labels) in enumerate(train_loader):
135+
train_opt(model, images, labels, optimizer, criterion)
136+
137+
138+
# TODO: Inference
139+
"""
140+
test_opt = torch.compile(test, backend="turbine_cpu", mode="reduce-overhead")
141+
for i, (images, labels) in enumerate(test_loader):
142+
test(model, images, labels, criterion)
143+
"""
144+
145+
146+
147+
class ModelTests(unittest.TestCase):
148+
@unittest.expectedFailure
149+
def testMNIST(self):
150+
# TODO: Fix the below error
151+
"""
152+
failed to legalize operation 'arith.sitofp' that was explicitly marked illegal
153+
"""
154+
main()
155+
156+
157+
if __name__ == "__main__":
158+
logging.basicConfig(level=logging.DEBUG)
159+
unittest.main()

torchvision-requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
2+
--pre
3+
torchvision==0.16.0.dev20230901

0 commit comments

Comments
 (0)