Skip to content

Commit a8e3d64

Browse files
committed
refactor: use Path objects instead of strings
Don't convert Paths to strings and back to Paths. Simplify code and take advantage of Path object methods. Signed-off-by: Rafał Ilnicki <r.ilnicki@welotec.com>
1 parent 13c5b53 commit a8e3d64

18 files changed

+58
-76
lines changed

cve_bin_tool/cli.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -830,8 +830,7 @@ def main(argv=None):
830830
error_mode=error_mode,
831831
)
832832

833-
# if OLD_CACHE_DIR (from cvedb.py) exists, print warning
834-
if Path(OLD_CACHE_DIR).exists():
833+
if OLD_CACHE_DIR.exists():
835834
LOGGER.warning(
836835
f"Obsolete cache dir {OLD_CACHE_DIR} is no longer needed and can be removed."
837836
)

cve_bin_tool/cve_scanner.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import sys
66
from collections import defaultdict
77
from logging import Logger
8-
from pathlib import Path
98
from string import ascii_lowercase
109
from typing import DefaultDict, Dict, List
1110

@@ -31,7 +30,7 @@ class CVEScanner:
3130
all_cve_version_info: Dict[str, VersionInfo]
3231

3332
RANGE_UNSET: str = ""
34-
dbname: str = str(Path(DISK_LOCATION_DEFAULT) / DBNAME)
33+
dbname: str = str(DISK_LOCATION_DEFAULT / DBNAME)
3534
CONSOLE: Console = Console(file=sys.stderr, theme=cve_theme)
3635
ALPHA_TO_NUM: Dict[str, int] = dict(zip(ascii_lowercase, range(26)))
3736

cve_bin_tool/data_sources/curl_source.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import json
77
import logging
8-
from pathlib import Path
98

109
import aiohttp
1110

@@ -66,7 +65,7 @@ async def download_curl_vulnerabilities(self, session: RateLimiter) -> None:
6665
async with await session.get(self.DATA_SOURCE_LINK) as response:
6766
response.raise_for_status()
6867
self.vulnerability_data = await response.json()
69-
path = Path(str(Path(self.cachedir) / "vuln.json"))
68+
path = self.cachedir / "vuln.json"
7069
filepath = path.resolve()
7170
async with FileIO(filepath, "w") as f:
7271
await f.write(json.dumps(self.vulnerability_data, indent=4))

cve_bin_tool/data_sources/epss_source.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import csv
77
import gzip
88
import logging
9-
import os
109
from datetime import datetime, timedelta
1110
from io import StringIO
1211
from pathlib import Path
@@ -34,8 +33,8 @@ def __init__(self, error_mode=ErrorMode.TruncTrace):
3433
self.error_mode = error_mode
3534
self.cachedir = self.CACHEDIR
3635
self.backup_cachedir = self.BACKUPCACHEDIR
37-
self.epss_path = str(Path(self.cachedir) / "epss")
38-
self.file_name = os.path.join(self.epss_path, "epss_scores-current.csv")
36+
self.epss_path = self.cachedir / "epss"
37+
self.file_name = self.epss_path / "epss_scores-current.csv"
3938
self.source_name = self.SOURCE
4039

4140
async def update_epss(self):
@@ -58,11 +57,11 @@ async def download_epss_data(self):
5857
"""Downloads the EPSS CSV file and saves it to the local filesystem.
5958
The download is only performed if the file is older than 24 hours.
6059
"""
61-
os.makedirs(self.epss_path, exist_ok=True)
60+
self.epss_path.mkdir(parents=True, exist_ok=True)
6261
# Check if the file exists
63-
if os.path.exists(self.file_name):
62+
if self.file_name.exists():
6463
# Get the modification time of the file
65-
modified_time = os.path.getmtime(self.file_name)
64+
modified_time = self.file_name.stat().st_mtime
6665
last_modified = datetime.fromtimestamp(modified_time)
6766

6867
# Calculate the time difference between now and the last modified time
@@ -80,8 +79,7 @@ async def download_epss_data(self):
8079
decompressed_data = gzip.decompress(await response.read())
8180

8281
# Save the downloaded data to the file
83-
with open(self.file_name, "wb") as file:
84-
file.write(decompressed_data)
82+
self.file_name.write_bytes(decompressed_data)
8583

8684
except aiohttp.ClientError as e:
8785
self.LOGGER.error(f"An error occurred during updating epss {e}")
@@ -102,8 +100,7 @@ async def download_epss_data(self):
102100
decompressed_data = gzip.decompress(await response.read())
103101

104102
# Save the downloaded data to the file
105-
with open(self.file_name, "wb") as file:
106-
file.write(decompressed_data)
103+
self.file_name.write_bytes(decompressed_data)
107104

108105
except aiohttp.ClientError as e:
109106
self.LOGGER.error(f"An error occurred during downloading epss {e}")
@@ -114,9 +111,8 @@ def parse_epss_data(self, file_path=None):
114111
if file_path is None:
115112
file_path = self.file_name
116113

117-
with open(file_path) as file:
118-
# Read the content of the CSV file
119-
decoded_data = file.read()
114+
# Read the content of the CSV file
115+
decoded_data = Path(file_path).read_text()
120116

121117
# Create a CSV reader to read the data from the decoded CSV content
122118
reader = csv.reader(StringIO(decoded_data), delimiter=",")

cve_bin_tool/data_sources/gad_source.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import io
99
import re
1010
import zipfile
11-
from pathlib import Path
1211

1312
import aiohttp
1413
import yaml
@@ -39,7 +38,7 @@ def __init__(
3938
):
4039
self.cachedir = self.CACHEDIR
4140
self.slugs = None
42-
self.gad_path = str(Path(self.cachedir) / "gad")
41+
self.gad_path = self.cachedir / "gad"
4342
self.source_name = self.SOURCE
4443

4544
self.error_mode = error_mode
@@ -90,8 +89,8 @@ async def fetch_cves(self):
9089

9190
self.db = cvedb.CVEDB()
9291

93-
if not Path(self.gad_path).exists():
94-
Path(self.gad_path).mkdir()
92+
if not self.gad_path.exists():
93+
self.gad_path.mkdir()
9594
# As no data, force full update
9695
self.incremental_update = False
9796

@@ -155,7 +154,7 @@ async def fetch_cves(self):
155154
async def update_cve_entries(self):
156155
"""Updates CVE entries from CVEs in cache."""
157156

158-
p = Path(self.gad_path).glob("**/*")
157+
p = self.gad_path.glob("**/*")
159158
# Need to find files which are new to the cache
160159
last_update_timestamp = (
161160
self.time_of_last_update.timestamp()

cve_bin_tool/data_sources/nvd_source.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import logging
1414
import re
1515
import sqlite3
16-
from pathlib import Path
1716

1817
import aiohttp
1918
from rich.progress import track
@@ -27,7 +26,6 @@
2726
NVD_FILENAME_TEMPLATE,
2827
)
2928
from cve_bin_tool.error_handler import (
30-
AttemptedToWriteOutsideCachedir,
3129
CVEDataForYearNotInCache,
3230
ErrorHandler,
3331
ErrorMode,
@@ -78,7 +76,7 @@ def __init__(
7876
self.source_name = self.SOURCE
7977

8078
# set up the db if needed
81-
self.dbpath = str(Path(self.cachedir) / DBNAME)
79+
self.dbpath = self.cachedir / DBNAME
8280
self.connection: sqlite3.Connection | None = None
8381
self.session = session
8482
self.cve_count = -1
@@ -544,12 +542,9 @@ async def cache_update(
544542
Update the cache for a single year of NVD data.
545543
"""
546544
filename = url.split("/")[-1]
547-
# Ensure we only write to files within the cachedir
548-
cache_path = Path(self.cachedir)
549-
filepath = Path(str(cache_path / filename)).resolve()
550-
if not str(filepath).startswith(str(cache_path.resolve())):
551-
with ErrorHandler(mode=self.error_mode, logger=self.LOGGER):
552-
raise AttemptedToWriteOutsideCachedir(filepath)
545+
cache_path = self.cachedir
546+
filepath = cache_path / filename
547+
553548
# Validate the contents of the cached file
554549
if filepath.is_file():
555550
# Validate the sha and write out
@@ -604,7 +599,7 @@ def load_nvd_year(self, year: int) -> dict[str, str | object]:
604599
Return the dict of CVE data for the given year.
605600
"""
606601

607-
filename = Path(self.cachedir) / self.NVDCVE_FILENAME_TEMPLATE.format(year)
602+
filename = self.cachedir / self.NVDCVE_FILENAME_TEMPLATE.format(year)
608603
# Check if file exists
609604
if not filename.is_file():
610605
with ErrorHandler(mode=self.error_mode, logger=self.LOGGER):

cve_bin_tool/data_sources/osv_source.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import datetime
88
import io
99
import json
10-
import os
1110
import shutil
1211
import zipfile
1312
from pathlib import Path
@@ -25,7 +24,7 @@
2524

2625
def find_gsutil():
2726
gsutil_path = shutil.which("gsutil")
28-
if not os.path.exists(gsutil_path):
27+
if not Path(gsutil_path).exists():
2928
raise FileNotFoundError(
3029
"gsutil not found. Did you need to install requirements or activate a venv where gsutil is installed?"
3130
)
@@ -46,7 +45,7 @@ def __init__(
4645
):
4746
self.cachedir = self.CACHEDIR
4847
self.ecosystems = None
49-
self.osv_path = str(Path(self.cachedir) / "osv")
48+
self.osv_path = self.cachedir / "osv"
5049
self.source_name = self.SOURCE
5150

5251
self.error_mode = error_mode
@@ -104,7 +103,7 @@ async def get_ecosystem_incremental(self, ecosystem, time_of_last_update, sessio
104103
tasks.append(task)
105104

106105
for r in await asyncio.gather(*tasks):
107-
filepath = Path(self.osv_path) / (r.get("id") + ".json")
106+
filepath = self.osv_path / (r.get("id") + ".json")
108107
r = json.dumps(r)
109108

110109
async with FileIO(filepath, "w") as f:
@@ -149,9 +148,9 @@ async def get_totalfiles(self, ecosystem):
149148

150149
gsutil_path = find_gsutil() # use helper function
151150
gs_file = self.gs_url + ecosystem + "/all.zip"
152-
await aio_run_command([gsutil_path, "cp", gs_file, self.osv_path])
151+
await aio_run_command([gsutil_path, "cp", gs_file, str(self.osv_path)])
153152

154-
zip_path = Path(self.osv_path) / "all.zip"
153+
zip_path = self.osv_path / "all.zip"
155154
totalfiles = 0
156155

157156
with zipfile.ZipFile(zip_path, "r") as z:
@@ -170,8 +169,8 @@ async def fetch_cves(self):
170169

171170
self.db = cvedb.CVEDB()
172171

173-
if not Path(self.osv_path).exists():
174-
Path(self.osv_path).mkdir()
172+
if not self.osv_path.exists():
173+
self.osv_path.mkdir()
175174
# As no data, force full update
176175
self.incremental_update = False
177176

@@ -230,7 +229,7 @@ async def fetch_cves(self):
230229
async def update_cve_entries(self):
231230
"""Updates CVE entries from CVEs in cache"""
232231

233-
p = Path(self.osv_path).glob("**/*")
232+
p = self.osv_path.glob("**/*")
234233
# Need to find files which are new to the cache
235234
last_update_timestamp = (
236235
self.time_of_last_update.timestamp()

cve_bin_tool/data_sources/purl2cpe_source.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import zipfile
44
from io import BytesIO
5-
from pathlib import Path
65

76
import aiohttp
87

@@ -25,7 +24,7 @@ def __init__(
2524
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
2625
):
2726
self.cachedir = self.CACHEDIR
28-
self.purl2cpe_path = str(Path(self.cachedir) / "purl2cpe")
27+
self.purl2cpe_path = self.cachedir / "purl2cpe"
2928
self.source_name = self.SOURCE
3029
self.error_mode = error_mode
3130
self.incremental_update = incremental_update
@@ -36,8 +35,8 @@ async def fetch_cves(self):
3635
"""Fetches PURL2CPE database and places it in purl2cpe_path."""
3736
LOGGER.info("Getting PURL2CPE data...")
3837

39-
if not Path(self.purl2cpe_path).exists():
40-
Path(self.purl2cpe_path).mkdir()
38+
if not self.purl2cpe_path.exists():
39+
self.purl2cpe_path.mkdir()
4140

4241
if not self.session:
4342
connector = aiohttp.TCPConnector(limit_per_host=10)

cve_bin_tool/data_sources/redhat_source.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import datetime
55
import json
6-
from pathlib import Path
76

87
import aiohttp
98

@@ -28,7 +27,7 @@ def __init__(
2827
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
2928
):
3029
self.cachedir = self.CACHEDIR
31-
self.redhat_path = str(Path(self.cachedir) / "redhat")
30+
self.redhat_path = self.cachedir / "redhat"
3231
self.source_name = self.SOURCE
3332

3433
self.error_mode = error_mode
@@ -57,7 +56,7 @@ async def store_data(self, content):
5756
"""Asynchronously stores CVE data in separate JSON files, excluding entries without a CVE ID."""
5857
for c in content:
5958
if c["CVE"] != "":
60-
filepath = Path(self.redhat_path) / (str(c["CVE"]) + ".json")
59+
filepath = self.redhat_path / (str(c["CVE"]) + ".json")
6160
r = json.dumps(c)
6261
async with FileIO(filepath, "w") as f:
6362
await f.write(r)
@@ -73,8 +72,8 @@ async def fetch_cves(self):
7372

7473
self.db = cvedb.CVEDB()
7574

76-
if not Path(self.redhat_path).exists():
77-
Path(self.redhat_path).mkdir()
75+
if not self.redhat_path.exists():
76+
self.redhat_path.mkdir()
7877
# As no data, force full update
7978
self.incremental_update = False
8079

@@ -121,7 +120,7 @@ async def fetch_cves(self):
121120
async def update_cve_entries(self):
122121
"""Updates CVE entries from CVEs in cache."""
123122

124-
p = Path(self.redhat_path).glob("**/*")
123+
p = self.redhat_path.glob("**/*")
125124
# Need to find files which are new to the cache
126125
last_update_timestamp = (
127126
self.time_of_last_update.timestamp()

cve_bin_tool/data_sources/rsd_source.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import io
99
import json
1010
import zipfile
11-
from pathlib import Path
1211

1312
import aiohttp
1413
from cvss import CVSS2, CVSS3
@@ -36,7 +35,7 @@ def __init__(
3635
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
3736
):
3837
self.cachedir = self.CACHEDIR
39-
self.rsd_path = str(Path(self.cachedir) / "rsd")
38+
self.rsd_path = self.cachedir / "rsd"
4039
self.source_name = self.SOURCE
4140

4241
self.error_mode = error_mode
@@ -71,8 +70,8 @@ async def fetch_cves(self):
7170

7271
self.db = cvedb.CVEDB()
7372

74-
if not Path(self.rsd_path).exists():
75-
Path(self.rsd_path).mkdir()
73+
if not self.rsd_path.exists():
74+
self.rsd_path.mkdir()
7675

7776
if not self.session:
7877
connector = aiohttp.TCPConnector(limit_per_host=19)
@@ -133,7 +132,7 @@ async def fetch_cves(self):
133132
async def update_cve_entries(self):
134133
"""Updates CVE entries from CVEs in cache."""
135134

136-
p = Path(self.rsd_path).glob("**/*")
135+
p = self.rsd_path.glob("**/*")
137136
# Need to find files which are new to the cache
138137
last_update_timestamp = (
139138
self.time_of_last_update.timestamp()

0 commit comments

Comments
 (0)