|
11 | 11 | }, |
12 | 12 | { |
13 | 13 | "cell_type": "code", |
14 | | - "execution_count": 3, |
| 14 | + "execution_count": null, |
15 | 15 | "id": "904f6e6c", |
16 | 16 | "metadata": {}, |
17 | 17 | "outputs": [], |
|
21 | 21 | "from timm import create_model\n", |
22 | 22 | "\n", |
23 | 23 | "from wildlife_tools.features import DeepFeatures\n", |
24 | | - "from wildlife_tools.data import WildlifeDataset, SplitMetadata\n", |
| 24 | + "from wildlife_tools.data import WildlifeDataset\n", |
25 | 25 | "from wildlife_tools.similarity import CosineSimilarity\n", |
26 | 26 | "from wildlife_tools.inference import KnnClassifier\n", |
27 | 27 | "\n", |
|
67 | 67 | }, |
68 | 68 | { |
69 | 69 | "cell_type": "code", |
70 | | - "execution_count": 2, |
| 70 | + "execution_count": null, |
71 | 71 | "id": "e4381543", |
72 | 72 | "metadata": {}, |
73 | 73 | "outputs": [ |
|
519 | 519 | " ])\n", |
520 | 520 | "\n", |
521 | 521 | " database = WildlifeDataset(\n", |
522 | | - " metadata=metadata,\n", |
| 522 | + " metadata=metadata.query('split == \"train\"'),\n", |
523 | 523 | " root=f'{root_images}/{name}/',\n", |
524 | 524 | " transform=transform,\n", |
525 | | - " split=SplitMetadata('split', 'train'),\n", |
526 | 525 | " )\n", |
527 | 526 | "\n", |
528 | 527 | " query = WildlifeDataset(\n", |
529 | | - " metadata=metadata,\n", |
| 528 | + " metadata=metadata.query('split == \"test\"'),\n", |
530 | 529 | " root=f'{root_images}/{name}/',\n", |
531 | 530 | " transform=transform,\n", |
532 | | - " split=SplitMetadata('split', 'test'),\n", |
533 | 531 | " )\n", |
534 | 532 | "\n", |
| 533 | + "\n", |
535 | 534 | " matcher = CosineSimilarity()\n", |
536 | 535 | " similarity = matcher(query=extractor(query), database=extractor(database))\n", |
537 | 536 | " preds = KnnClassifier(k=1, database_labels=database.labels_string)(similarity['cosine'])\n", |
|
0 commit comments