Skip to content

Commit 2297d48

Browse files
committed
fix fastai tests
1 parent 4f3a31e commit 2297d48

File tree

1 file changed

+11
-26
lines changed

1 file changed

+11
-26
lines changed

tests/test_fastai.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,12 @@
11
import unittest
22

33
import fastai
4-
import pandas as pd
5-
import torch
64

7-
from fastai.tabular import *
8-
from fastai.core import partition
9-
from fastai.torch_core import tensor
5+
from fastai.tabular.all import *
106

117
class TestFastAI(unittest.TestCase):
12-
def test_partition(self):
13-
result = partition([1,2,3,4,5], 2)
14-
15-
self.assertEqual(3, len(result))
16-
178
def test_has_version(self):
18-
self.assertGreater(len(fastai.__version__), 1)
9+
self.assertGreater(len(fastai.__version__), 2)
1910

2011
# based on https://github.com/fastai/fastai/blob/master/tests/test_torch_core.py#L17
2112
def test_torch_tensor(self):
@@ -25,18 +16,12 @@ def test_torch_tensor(self):
2516
self.assertTrue(torch.all(a == b))
2617

2718
def test_tabular(self):
28-
df = pd.read_csv("/input/tests/data/train.csv")
29-
procs = [FillMissing, Categorify, Normalize]
30-
31-
valid_idx = range(len(df)-5, len(df))
32-
dep_var = "label"
33-
cont_names = []
34-
for i in range(784):
35-
cont_names.append("pixel" + str(i))
36-
37-
data = (TabularList.from_df(df, path="", cont_names=cont_names, cat_names=[], procs=procs)
38-
.split_by_idx(valid_idx)
39-
.label_from_df(cols=dep_var)
40-
.databunch())
41-
learn = tabular_learner(data, layers=[200, 100])
42-
learn.fit(epochs=1)
19+
dls = TabularDataLoaders.from_csv(
20+
"/input/tests/data/train.csv",
21+
cont_names=["pixel"+str(i) for i in range(784)],
22+
y_names='label',
23+
procs=[FillMissing, Categorify, Normalize])
24+
learn = tabular_learner(dls, layers=[200, 100])
25+
learn.fit_one_cycle(n_epoch=1)
26+
27+
self.assertGreater(learn.smooth_loss, 0)

0 commit comments

Comments
 (0)