6
6
from pathlib import Path
7
7
from typing import TYPE_CHECKING , Callable , cast
8
8
9
+ import platformdirs
10
+
9
11
from quackosm ._constants import WGS84_CRS
10
12
11
13
if TYPE_CHECKING : # pragma: no cover
@@ -56,14 +58,21 @@ def load_index_decorator(
56
58
57
59
def inner (function : Callable [[], "GeoDataFrame" ]) -> Callable [[], "GeoDataFrame" ]:
58
60
def wrapper () -> "GeoDataFrame" :
59
- cache_file_path = _get_cache_file_path (extract_source )
61
+ global_cache_file_path = _get_global_cache_file_path (extract_source )
60
62
expected_columns = ["id" , "name" , "file_name" , "parent" , "geometry" , "area" , "url" ]
61
63
62
64
# Check if index exists in cache
63
- if cache_file_path .exists ():
65
+ if global_cache_file_path .exists ():
66
+ import geopandas as gpd
67
+
68
+ index_gdf = gpd .read_file (global_cache_file_path )
69
+ elif (local_cache_file_path := _get_local_cache_file_path (extract_source )).exists ():
70
+ import shutil
71
+
64
72
import geopandas as gpd
65
73
66
- index_gdf = gpd .read_file (cache_file_path )
74
+ shutil .copy (local_cache_file_path , global_cache_file_path )
75
+ index_gdf = gpd .read_file (global_cache_file_path )
67
76
# Download index
68
77
else : # pragma: no cover
69
78
index_gdf = function ()
@@ -87,14 +96,14 @@ def wrapper() -> "GeoDataFrame":
87
96
stacklevel = 0 ,
88
97
)
89
98
# Invalidate previous cached index
90
- cache_file_path .replace (cache_file_path .with_suffix (".geojson.old" ))
99
+ global_cache_file_path .replace (global_cache_file_path .with_suffix (".geojson.old" ))
91
100
# Download index again
92
101
index_gdf = wrapper ()
93
102
94
103
# Save index to cache
95
- if not cache_file_path .exists ():
96
- cache_file_path .parent .mkdir (parents = True , exist_ok = True )
97
- index_gdf [expected_columns ].to_file (cache_file_path , driver = "GeoJSON" )
104
+ if not global_cache_file_path .exists ():
105
+ global_cache_file_path .parent .mkdir (parents = True , exist_ok = True )
106
+ index_gdf [expected_columns ].to_file (global_cache_file_path , driver = "GeoJSON" )
98
107
99
108
return index_gdf
100
109
@@ -112,7 +121,14 @@ def extracts_to_geodataframe(extracts: list[OpenStreetMapExtract]) -> "GeoDataFr
112
121
).set_crs (WGS84_CRS )
113
122
114
123
115
- def _get_cache_file_path (extract_source : OsmExtractSource ) -> Path :
124
+ def _get_global_cache_file_path (extract_source : OsmExtractSource ) -> Path :
125
+ return (
126
+ Path (platformdirs .user_cache_dir ("QuackOSM" ))
127
+ / f"{ extract_source .value .lower ()} _index.geojson"
128
+ )
129
+
130
+
131
+ def _get_local_cache_file_path (extract_source : OsmExtractSource ) -> Path :
116
132
return Path (f"cache/{ extract_source .value .lower ()} _index.geojson" )
117
133
118
134
0 commit comments