|
329 | 329 | "create_model = build_ts_model"
|
330 | 330 | ]
|
331 | 331 | },
|
| 332 | + { |
| 333 | + "cell_type": "code", |
| 334 | + "execution_count": null, |
| 335 | + "metadata": {}, |
| 336 | + "outputs": [], |
| 337 | + "source": [ |
| 338 | + "from tsai.data.core import get_ts_dls, TSClassification\n", |
| 339 | + "from tsai.models.TSiTPlus import TSiTPlus\n", |
| 340 | + "from fastai.losses import CrossEntropyLossFlat" |
| 341 | + ] |
| 342 | + }, |
| 343 | + { |
| 344 | + "cell_type": "code", |
| 345 | + "execution_count": null, |
| 346 | + "metadata": {}, |
| 347 | + "outputs": [ |
| 348 | + { |
| 349 | + "name": "stdout", |
| 350 | + "output_type": "stream", |
| 351 | + "text": [ |
| 352 | + "arch: TSiTPlus(c_in=3 c_out=2 seq_len=128 arch_config={} kwargs={'custom_head': functools.partial(<class 'tsai.models.layers.lin_nd_head'>, d=3)})\n", |
| 353 | + "torch.Size([13, 3, 2])\n", |
| 354 | + "TensorBase(0.8796, grad_fn=<AliasBackward0>)\n" |
| 355 | + ] |
| 356 | + } |
| 357 | + ], |
| 358 | + "source": [ |
| 359 | + "X = np.random.rand(16, 3, 128)\n", |
| 360 | + "y = np.random.randint(0, 2, (16, 3))\n", |
| 361 | + "tfms = [None, [TSClassification()]]\n", |
| 362 | + "dls = get_ts_dls(X, y, splits=RandomSplitter()(range_of(X)), tfms=tfms)\n", |
| 363 | + "model = build_ts_model(TSiTPlus, dls=dls, pretrained=False, verbose=True)\n", |
| 364 | + "xb, yb = dls.one_batch()\n", |
| 365 | + "output = model(xb)\n", |
| 366 | + "print(output.shape)\n", |
| 367 | + "loss = CrossEntropyLossFlat()(output, yb)\n", |
| 368 | + "print(loss)\n", |
| 369 | + "assert output.shape == (dls.bs, dls.d, dls.c)" |
| 370 | + ] |
| 371 | + }, |
332 | 372 | {
|
333 | 373 | "cell_type": "code",
|
334 | 374 | "execution_count": null,
|
|
495 | 535 | "cell_type": "code",
|
496 | 536 | "execution_count": null,
|
497 | 537 | "metadata": {},
|
498 |
| - "outputs": [ |
499 |
| - { |
500 |
| - "name": "stderr", |
501 |
| - "output_type": "stream", |
502 |
| - "text": [ |
503 |
| - "[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.\n" |
504 |
| - ] |
505 |
| - } |
506 |
| - ], |
| 538 | + "outputs": [], |
507 | 539 | "source": [
|
508 | 540 | "c_in = 3\n",
|
509 | 541 | "seq_len = 30\n",
|
|
555 | 587 | {
|
556 | 588 | "data": {
|
557 | 589 | "text/plain": [
|
558 |
| - "(array([ 0.74775537, 1.41245663, 2.12445924, 2.8943163 , 3.56384351,\n", |
559 |
| - " 4.23789602, 4.83134182, 5.18560431, 5.30551186, 6.29076506,\n", |
560 |
| - " 6.58873471, 7.03661275, 7.0884361 , 7.57927022, 8.21911791,\n", |
561 |
| - " 8.59726773, 9.37382718, 10.17298849, 10.40118308, 10.82265631]),\n", |
562 |
| - " array([ 6.29076506, 6.58873471, 7.03661275, 7.0884361 , 7.57927022,\n", |
563 |
| - " 8.21911791, 8.59726773, 9.37382718, 10.17298849, 10.40118308]),\n", |
564 |
| - " array([ 6.58873471, 7.03661275, 7.0884361 , 7.57927022, 8.21911791,\n", |
565 |
| - " 8.59726773, 9.37382718, 10.17298849, 10.40118308, 10.82265631]))" |
| 590 | + "(array([0.99029138, 1.68463991, 2.21744589, 2.65448222, 2.85159354,\n", |
| 591 | + " 3.26171729, 3.67986707, 4.04343956, 4.3077458 , 4.44585435,\n", |
| 592 | + " 4.76876866, 4.85844441, 4.93256093, 5.52415845, 6.10704489,\n", |
| 593 | + " 6.74848957, 7.31920741, 8.20198208, 8.78954283, 9.0402 ]),\n", |
| 594 | + " array([4.44585435, 4.76876866, 4.85844441, 4.93256093, 5.52415845,\n", |
| 595 | + " 6.10704489, 6.74848957, 7.31920741, 8.20198208, 8.78954283]),\n", |
| 596 | + " array([4.76876866, 4.85844441, 4.93256093, 5.52415845, 6.10704489,\n", |
| 597 | + " 6.74848957, 7.31920741, 8.20198208, 8.78954283, 9.0402 ]))" |
566 | 598 | ]
|
567 | 599 | },
|
568 | 600 | "execution_count": null,
|
|
595 | 627 | "name": "stdout",
|
596 | 628 | "output_type": "stream",
|
597 | 629 | "text": [
|
598 |
| - "/Users/nacho/notebooks/tsai/nbs/030_models.utils.ipynb saved at 2024-01-31 13:03:06\n", |
| 630 | + "/Users/nacho/notebooks/tsai/nbs/030_models.utils.ipynb saved at 2025-01-22 18:23:18\n", |
599 | 631 | "Correct notebook to script conversion! 😃\n",
|
600 |
| - "Wednesday 31/01/24 13:03:08 CET\n" |
| 632 | + "Wednesday 22/01/25 18:23:21 CET\n" |
601 | 633 | ]
|
602 | 634 | },
|
603 | 635 | {
|
|
0 commit comments