Skip to content

Commit d23be0a

Browse files
committed
Use function instead of class for temp_dir, fix type issue
1 parent 1ed631f commit d23be0a

File tree

8 files changed

+47
-48
lines changed

8 files changed

+47
-48
lines changed

tests/data-files/get_examples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import pystac
1313
from pystac.serialization import identify_stac_object
14-
from tests.utils import TemporaryDirectory
14+
from tests.utils import get_temp_dir
1515

1616

1717
def remove_bad_collection(js: Dict[str, Any]) -> Dict[str, Any]:
@@ -50,7 +50,7 @@ def remove_bad_collection(js: Dict[str, Any]) -> Dict[str, Any]:
5050

5151
examples_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "examples"))
5252

53-
with TemporaryDirectory() as tmp_dir:
53+
with get_temp_dir() as tmp_dir:
5454
call(
5555
[
5656
"git",

tests/extensions/test_label.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616
import pystac.validation
1717
from pystac.utils import get_opt
18-
from tests.utils import TestCases, assert_to_from_dict, TemporaryDirectory
18+
from tests.utils import TestCases, assert_to_from_dict, get_temp_dir
1919

2020

2121
class LabelTest(unittest.TestCase):
@@ -85,7 +85,7 @@ def test_validate_label(self) -> None:
8585
label_example_1_dict, pystac.STACObjectType.ITEM
8686
)
8787

88-
with TemporaryDirectory() as tmp_dir:
88+
with get_temp_dir() as tmp_dir:
8989
cat_dir = os.path.join(tmp_dir, "catalog")
9090
catalog = TestCases.test_case_1()
9191
catalog.normalize_and_save(cat_dir, catalog_type=CatalogType.SELF_CONTAINED)

tests/test_catalog.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
ARBITRARY_GEOM,
2323
ARBITRARY_BBOX,
2424
MockStacIO,
25-
TemporaryDirectory,
25+
get_temp_dir,
2626
)
2727

2828

2929
class CatalogTypeTest(unittest.TestCase):
3030
def test_determine_type_for_absolute_published(self) -> None:
3131
cat = TestCases.test_case_1()
32-
with TemporaryDirectory() as tmp_dir:
32+
with get_temp_dir() as tmp_dir:
3333
cat.normalize_and_save(tmp_dir, catalog_type=CatalogType.ABSOLUTE_PUBLISHED)
3434
cat_json = pystac.StacIO.default().read_json(
3535
os.path.join(tmp_dir, "catalog.json")
@@ -40,7 +40,7 @@ def test_determine_type_for_absolute_published(self) -> None:
4040

4141
def test_determine_type_for_relative_published(self) -> None:
4242
cat = TestCases.test_case_2()
43-
with TemporaryDirectory() as tmp_dir:
43+
with get_temp_dir() as tmp_dir:
4444
cat.normalize_and_save(tmp_dir, catalog_type=CatalogType.RELATIVE_PUBLISHED)
4545
cat_json = pystac.StacIO.default().read_json(
4646
os.path.join(tmp_dir, "catalog.json")
@@ -68,7 +68,7 @@ def test_determine_type_for_unknown(self) -> None:
6868

6969
class CatalogTest(unittest.TestCase):
7070
def test_create_and_read(self) -> None:
71-
with TemporaryDirectory() as tmp_dir:
71+
with get_temp_dir() as tmp_dir:
7272
cat_dir = os.path.join(tmp_dir, "catalog")
7373
catalog = TestCases.test_case_1()
7474

@@ -288,7 +288,7 @@ def test_clone_generates_correct_links(self) -> None:
288288
def test_save_uses_previous_catalog_type(self) -> None:
289289
catalog = TestCases.test_case_1()
290290
assert catalog.catalog_type == CatalogType.SELF_CONTAINED
291-
with TemporaryDirectory() as tmp_dir:
291+
with get_temp_dir() as tmp_dir:
292292
catalog.normalize_hrefs(tmp_dir)
293293
href = catalog.self_href
294294
catalog.save()
@@ -365,7 +365,7 @@ def test_generate_subcatalogs_does_not_change_item_count(self) -> None:
365365

366366
catalog.generate_subcatalogs("${year}/${day}")
367367

368-
with TemporaryDirectory() as tmp_dir:
368+
with get_temp_dir() as tmp_dir:
369369
catalog.normalize_hrefs(tmp_dir)
370370
catalog.save(pystac.CatalogType.SELF_CONTAINED)
371371

@@ -494,7 +494,7 @@ def item_mapper(item: pystac.Item) -> pystac.Item:
494494
item.properties["ITEM_MAPPER"] = "YEP"
495495
return item
496496

497-
with TemporaryDirectory() as tmp_dir:
497+
with get_temp_dir() as tmp_dir:
498498
catalog = TestCases.test_case_1()
499499

500500
new_cat = catalog.map_items(item_mapper)
@@ -518,7 +518,7 @@ def item_mapper(item: pystac.Item) -> List[pystac.Item]:
518518
item2.properties["ITEM_MAPPER_2"] = "YEP"
519519
return [item, item2]
520520

521-
with TemporaryDirectory() as tmp_dir:
521+
with get_temp_dir() as tmp_dir:
522522
catalog = TestCases.test_case_1()
523523
catalog_items = catalog.get_all_items()
524524

@@ -623,7 +623,7 @@ def asset_mapper(key: str, asset: pystac.Asset) -> pystac.Asset:
623623

624624
return asset
625625

626-
with TemporaryDirectory() as tmp_dir:
626+
with get_temp_dir() as tmp_dir:
627627
catalog = TestCases.test_case_2()
628628

629629
new_cat = catalog.map_assets(asset_mapper)
@@ -656,7 +656,7 @@ def asset_mapper(
656656
else:
657657
return asset
658658

659-
with TemporaryDirectory() as tmp_dir:
659+
with get_temp_dir() as tmp_dir:
660660
catalog = TestCases.test_case_2()
661661

662662
new_cat = catalog.map_assets(asset_mapper)
@@ -696,7 +696,7 @@ def asset_mapper(
696696
else:
697697
return asset
698698

699-
with TemporaryDirectory() as tmp_dir:
699+
with get_temp_dir() as tmp_dir:
700700
catalog = TestCases.test_case_2()
701701

702702
new_cat = catalog.map_assets(asset_mapper)
@@ -771,7 +771,7 @@ def check_all_absolute(cat: Catalog) -> None:
771771
test_cases = TestCases.all_test_catalogs()
772772

773773
for catalog in test_cases:
774-
with TemporaryDirectory() as tmp_dir:
774+
with get_temp_dir() as tmp_dir:
775775
c2 = catalog.full_copy()
776776
c2.normalize_hrefs(tmp_dir)
777777
c2.catalog_type = CatalogType.RELATIVE_PUBLISHED
@@ -797,7 +797,7 @@ def test_extra_fields(self) -> None:
797797

798798
catalog.extra_fields["type"] = "FeatureCollection"
799799

800-
with TemporaryDirectory() as tmp_dir:
800+
with get_temp_dir() as tmp_dir:
801801
p = os.path.join(tmp_dir, "catalog.json")
802802
catalog.save_object(include_self_link=False, dest_href=p)
803803
with open(p) as f:
@@ -822,7 +822,7 @@ def test_validate_all(self) -> None:
822822
item = cat.get_item("area-1-1-labels", recursive=True)
823823
assert item is not None
824824
item.geometry = {"type": "INVALID", "coordinates": "NONE"}
825-
with TemporaryDirectory() as tmp_dir:
825+
with get_temp_dir() as tmp_dir:
826826
cat.normalize_hrefs(tmp_dir)
827827
cat.save(catalog_type=pystac.CatalogType.SELF_CONTAINED)
828828

@@ -843,7 +843,7 @@ def test_set_hrefs_manually(self) -> None:
843843
year += 1
844844
month += 1
845845

846-
with TemporaryDirectory() as tmp_dir:
846+
with get_temp_dir() as tmp_dir:
847847
for root, _, items in catalog.walk():
848848

849849
# Set root's HREF based off the parent
@@ -933,7 +933,7 @@ def test_reading_iterating_and_writing_works_as_expected(self) -> None:
933933
for item in cat.get_all_items():
934934
pass
935935

936-
with TemporaryDirectory() as tmp_dir:
936+
with get_temp_dir() as tmp_dir:
937937
new_stac_uri = os.path.join(tmp_dir, "test-case-6")
938938
cat.normalize_hrefs(new_stac_uri)
939939
cat.save(catalog_type=CatalogType.SELF_CONTAINED)
@@ -1003,7 +1003,7 @@ def check_catalog(self, c: Catalog, tag: str) -> None:
10031003
self.check_item(item, tag)
10041004

10051005
def test_full_copy_1(self) -> None:
1006-
with TemporaryDirectory() as tmp_dir:
1006+
with get_temp_dir() as tmp_dir:
10071007
cat = Catalog(id="test", description="test catalog")
10081008

10091009
item = Item(
@@ -1024,7 +1024,7 @@ def test_full_copy_1(self) -> None:
10241024
self.check_catalog(cat2, "dest")
10251025

10261026
def test_full_copy_2(self) -> None:
1027-
with TemporaryDirectory() as tmp_dir:
1027+
with get_temp_dir() as tmp_dir:
10281028
cat = Catalog(id="test", description="test catalog")
10291029
image_item = Item(
10301030
id="Imagery",
@@ -1071,7 +1071,7 @@ def test_full_copy_2(self) -> None:
10711071
self.check_catalog(cat2, "dest")
10721072

10731073
def test_full_copy_3(self) -> None:
1074-
with TemporaryDirectory() as tmp_dir:
1074+
with get_temp_dir() as tmp_dir:
10751075
root_cat = TestCases.test_case_1()
10761076
root_cat.normalize_hrefs(
10771077
os.path.join(tmp_dir, "catalog-full-copy-3-source")
@@ -1085,7 +1085,7 @@ def test_full_copy_3(self) -> None:
10851085
self.check_catalog(cat2, "dest")
10861086

10871087
def test_full_copy_4(self) -> None:
1088-
with TemporaryDirectory() as tmp_dir:
1088+
with get_temp_dir() as tmp_dir:
10891089
root_cat = TestCases.test_case_2()
10901090
root_cat.normalize_hrefs(
10911091
os.path.join(tmp_dir, "catalog-full-copy-4-source")

tests/test_collection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pystac.validation import validate_dict
1010
from pystac import Collection, Item, Extent, SpatialExtent, TemporalExtent, CatalogType
1111
from pystac.utils import datetime_to_str
12-
from tests.utils import TestCases, ARBITRARY_GEOM, ARBITRARY_BBOX, TemporaryDirectory
12+
from tests.utils import TestCases, ARBITRARY_GEOM, ARBITRARY_BBOX, get_temp_dir
1313

1414
TEST_DATETIME = datetime(2020, 3, 14, 16, 32)
1515

@@ -34,7 +34,7 @@ def test_save_uses_previous_catalog_type(self) -> None:
3434
collection = TestCases.test_case_8()
3535
assert collection.STAC_OBJECT_TYPE == pystac.STACObjectType.COLLECTION
3636
self.assertEqual(collection.catalog_type, CatalogType.SELF_CONTAINED)
37-
with TemporaryDirectory() as tmp_dir:
37+
with get_temp_dir() as tmp_dir:
3838
collection.normalize_hrefs(tmp_dir)
3939
href = collection.self_href
4040
collection.save()
@@ -83,7 +83,7 @@ def test_extra_fields(self) -> None:
8383

8484
collection.extra_fields["test"] = "extra"
8585

86-
with TemporaryDirectory() as tmp_dir:
86+
with get_temp_dir() as tmp_dir:
8787
p = os.path.join(tmp_dir, "collection.json")
8888
collection.save_object(include_self_link=False, dest_href=p)
8989
with open(p) as f:

tests/test_item.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pystac.serialization.common_properties
1111
from pystac.item import CommonMetadata
1212
from pystac.utils import datetime_to_str, get_opt, str_to_datetime, is_absolute_href
13-
from tests.utils import TestCases, assert_to_from_dict, TemporaryDirectory
13+
from tests.utils import TestCases, assert_to_from_dict, get_temp_dir
1414

1515

1616
class ItemTest(unittest.TestCase):
@@ -72,7 +72,7 @@ def test_extra_fields(self) -> None:
7272

7373
item.extra_fields["test"] = "extra"
7474

75-
with TemporaryDirectory() as tmp_dir:
75+
with get_temp_dir() as tmp_dir:
7676
p = os.path.join(tmp_dir, "item.json")
7777
item.save_object(include_self_link=False, dest_href=p)
7878
with open(p) as f:

tests/test_writing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pystac.utils import is_absolute_href, make_absolute_href, make_relative_href
77
from pystac.validation import validate_dict
88

9-
from tests.utils import TestCases, TemporaryDirectory
9+
from tests.utils import TestCases, get_temp_dir
1010

1111

1212
class STACWritingTest(unittest.TestCase):
@@ -104,7 +104,7 @@ def validate_catalog_link_type(
104104
def do_test(
105105
self, catalog: pystac.Catalog, catalog_type: pystac.CatalogType
106106
) -> None:
107-
with TemporaryDirectory() as tmp_dir:
107+
with get_temp_dir() as tmp_dir:
108108
catalog.normalize_hrefs(tmp_dir)
109109
self.validate_catalog(catalog)
110110

tests/utils/__init__.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# flake8: noqa
22

33
import os
4-
import tempfile
5-
from typing import Any, Dict, Type, Optional
4+
from tempfile import TemporaryDirectory
5+
from typing import Any, AnyStr, Dict, TYPE_CHECKING, Type
66
import unittest
77
from tests.utils.test_cases import (
88
TestCases,
@@ -18,6 +18,9 @@
1818
import pystac
1919
from tests.utils.stac_io_mock import MockStacIO
2020

21+
if TYPE_CHECKING:
22+
from tempfile import TemporaryDirectory as TemporaryDirectory_Type
23+
2124

2225
def assert_to_from_dict(
2326
test_class: unittest.TestCase,
@@ -45,16 +48,12 @@ def _parse_times(a_dict: Dict[str, Any]) -> None:
4548
test_class.assertDictEqual(d1, d2)
4649

4750

48-
# Mypy raises an error for this class:
49-
# error: Missing type parameters for generic type "TemporaryDirectory" [type-arg]
50-
# Trying to add a concrete type (e.g. TemporaryDirectory[str]) satisfies mypy, but
51-
# raises a runtime exception:
52-
# TypeError: 'type' object is not subscriptable
53-
class TemporaryDirectory(tempfile.TemporaryDirectory): # type: ignore [type-arg]
54-
def __init__(self, suffix: Optional[str] = None, prefix: Optional[str] = None):
55-
"""In the GitHub Actions Windows runner the default TMPDIR directory is on a
56-
different drive (C:\\) than the code and test data files (D:\\). This was causing
57-
failures in os.path.relpath on Windows, so we put the temp directories in the
58-
current working directory instead. There os a "tmp*" line in the .gitignore file
59-
that ignores these directories."""
60-
super().__init__(suffix=suffix, prefix=prefix, dir=os.getcwd())
51+
# Use suggestion from https://github.com/python/mypy/issues/5264#issuecomment-399407428
52+
# to solve type errors.
53+
def get_temp_dir() -> "TemporaryDirectory_Type[str]":
54+
"""In the GitHub Actions Windows runner the default TMPDIR directory is on a
55+
different drive (C:\\) than the code and test data files (D:\\). This was causing
56+
failures in os.path.relpath on Windows, so we put the temp directories in the
57+
current working directory instead. There os a "tmp*" line in the .gitignore file
58+
that ignores these directories."""
59+
return TemporaryDirectory(dir=os.getcwd())

tests/validation/test_validate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pystac.validation
1313
from pystac.cache import CollectionCache
1414
from pystac.serialization.common_properties import merge_common_properties
15-
from tests.utils import TestCases, TemporaryDirectory
15+
from tests.utils import TestCases, get_temp_dir
1616

1717

1818
class ValidateTest(unittest.TestCase):
@@ -98,7 +98,7 @@ def test_validate_all(self) -> None:
9898
# Modify a 0.8.1 collection in a catalog to be invalid with a
9999
# since-renamed extension and make sure it catches the validation error.
100100

101-
with TemporaryDirectory() as tmp_dir:
101+
with get_temp_dir() as tmp_dir:
102102
dst_dir = os.path.join(tmp_dir, "catalog")
103103
# Copy test case 7 to the temporary directory
104104
catalog_href = get_opt(TestCases.test_case_7().get_self_href())

0 commit comments

Comments
 (0)