Skip to content

Commit 0490421

Browse files
authored
Merge pull request #934 from timeseriesAI/tsai_v0.4.0
Tsai v0.4.0
2 parents 3f461e5 + 07ab9f5 commit 0490421

17 files changed

+698
-633
lines changed

nbs/003_data.validation.ipynb

Lines changed: 113 additions & 113 deletions
Large diffs are not rendered by default.

nbs/021_calibration.ipynb

Lines changed: 17 additions & 25 deletions
Large diffs are not rendered by default.

nbs/022_tslearner.ipynb

Lines changed: 102 additions & 18 deletions
Large diffs are not rendered by default.

nbs/029_models.layers.ipynb

Lines changed: 15 additions & 15 deletions
Large diffs are not rendered by default.

nbs/030_models.utils.ipynb

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,46 @@
329329
"create_model = build_ts_model"
330330
]
331331
},
332+
{
333+
"cell_type": "code",
334+
"execution_count": null,
335+
"metadata": {},
336+
"outputs": [],
337+
"source": [
338+
"from tsai.data.core import get_ts_dls, TSClassification\n",
339+
"from tsai.models.TSiTPlus import TSiTPlus\n",
340+
"from fastai.losses import CrossEntropyLossFlat"
341+
]
342+
},
343+
{
344+
"cell_type": "code",
345+
"execution_count": null,
346+
"metadata": {},
347+
"outputs": [
348+
{
349+
"name": "stdout",
350+
"output_type": "stream",
351+
"text": [
352+
"arch: TSiTPlus(c_in=3 c_out=2 seq_len=128 arch_config={} kwargs={'custom_head': functools.partial(<class 'tsai.models.layers.lin_nd_head'>, d=3)})\n",
353+
"torch.Size([13, 3, 2])\n",
354+
"TensorBase(0.8796, grad_fn=<AliasBackward0>)\n"
355+
]
356+
}
357+
],
358+
"source": [
359+
"X = np.random.rand(16, 3, 128)\n",
360+
"y = np.random.randint(0, 2, (16, 3))\n",
361+
"tfms = [None, [TSClassification()]]\n",
362+
"dls = get_ts_dls(X, y, splits=RandomSplitter()(range_of(X)), tfms=tfms)\n",
363+
"model = build_ts_model(TSiTPlus, dls=dls, pretrained=False, verbose=True)\n",
364+
"xb, yb = dls.one_batch()\n",
365+
"output = model(xb)\n",
366+
"print(output.shape)\n",
367+
"loss = CrossEntropyLossFlat()(output, yb)\n",
368+
"print(loss)\n",
369+
"assert output.shape == (dls.bs, dls.d, dls.c)"
370+
]
371+
},
332372
{
333373
"cell_type": "code",
334374
"execution_count": null,
@@ -495,15 +535,7 @@
495535
"cell_type": "code",
496536
"execution_count": null,
497537
"metadata": {},
498-
"outputs": [
499-
{
500-
"name": "stderr",
501-
"output_type": "stream",
502-
"text": [
503-
"[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.\n"
504-
]
505-
}
506-
],
538+
"outputs": [],
507539
"source": [
508540
"c_in = 3\n",
509541
"seq_len = 30\n",
@@ -555,14 +587,14 @@
555587
{
556588
"data": {
557589
"text/plain": [
558-
"(array([ 0.74775537, 1.41245663, 2.12445924, 2.8943163 , 3.56384351,\n",
559-
" 4.23789602, 4.83134182, 5.18560431, 5.30551186, 6.29076506,\n",
560-
" 6.58873471, 7.03661275, 7.0884361 , 7.57927022, 8.21911791,\n",
561-
" 8.59726773, 9.37382718, 10.17298849, 10.40118308, 10.82265631]),\n",
562-
" array([ 6.29076506, 6.58873471, 7.03661275, 7.0884361 , 7.57927022,\n",
563-
" 8.21911791, 8.59726773, 9.37382718, 10.17298849, 10.40118308]),\n",
564-
" array([ 6.58873471, 7.03661275, 7.0884361 , 7.57927022, 8.21911791,\n",
565-
" 8.59726773, 9.37382718, 10.17298849, 10.40118308, 10.82265631]))"
590+
"(array([0.99029138, 1.68463991, 2.21744589, 2.65448222, 2.85159354,\n",
591+
" 3.26171729, 3.67986707, 4.04343956, 4.3077458 , 4.44585435,\n",
592+
" 4.76876866, 4.85844441, 4.93256093, 5.52415845, 6.10704489,\n",
593+
" 6.74848957, 7.31920741, 8.20198208, 8.78954283, 9.0402 ]),\n",
594+
" array([4.44585435, 4.76876866, 4.85844441, 4.93256093, 5.52415845,\n",
595+
" 6.10704489, 6.74848957, 7.31920741, 8.20198208, 8.78954283]),\n",
596+
" array([4.76876866, 4.85844441, 4.93256093, 5.52415845, 6.10704489,\n",
597+
" 6.74848957, 7.31920741, 8.20198208, 8.78954283, 9.0402 ]))"
566598
]
567599
},
568600
"execution_count": null,
@@ -595,9 +627,9 @@
595627
"name": "stdout",
596628
"output_type": "stream",
597629
"text": [
598-
"/Users/nacho/notebooks/tsai/nbs/030_models.utils.ipynb saved at 2024-01-31 13:03:06\n",
630+
"/Users/nacho/notebooks/tsai/nbs/030_models.utils.ipynb saved at 2025-01-22 18:23:18\n",
599631
"Correct notebook to script conversion! 😃\n",
600-
"Wednesday 31/01/24 13:03:08 CET\n"
632+
"Wednesday 22/01/25 18:23:21 CET\n"
601633
]
602634
},
603635
{

nbs/068_models.TSiTPlus.ipynb

Lines changed: 154 additions & 57 deletions
Large diffs are not rendered by default.

nbs/069_models.TSSequencerPlus.ipynb

Lines changed: 7 additions & 7 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)