11from __future__ import annotations
22
3+ import codecs
34import io
45from contextlib import contextmanager , nullcontext
5- from io import TextIOBase
66from pathlib import Path
77from typing import IO , Any , ContextManager , Dict , List
88
@@ -20,21 +20,20 @@ def detect_encoding(source: str | Path | IO[str] | IO[bytes] | bytes) -> str:
2020
2121 Args:
2222 source (str | Path | IO[str] | IO[bytes] | bytes): The source to detect encoding from.
23- - str: Treated as a file path if it exists, otherwise as text (returns `DEFAULT_ENCODING`)
24- - Path: File path to read and detect encoding
23+ - str or Path: File path.
2524 - IO[str]: Already decoded text stream (returns `DEFAULT_ENCODING`)
2625 - IO[bytes]: Binary stream to detect encoding from
2726 - bytes: Binary data to detect encoding from
2827
2928 Returns:
30- str: The detected encoding name (e.g., 'utf-8', 'iso-8859-1', etc.)
29+ str: The detected encoding name (e.g., 'utf-8', 'iso-8859-1', 'ascii', etc.)
3130
3231 Raises:
3332 TypeError: If the source type is unsupported
3433 FileNotFoundError: If a file path doesn't exist
3534 """
3635 # Set number of bytes to read for detection and required confidence
37- SAMPLE_SIZE = 10000
36+ SAMPLE_SIZE = 1_000_000
3837 REQUIRED_CONFIDENCE = 0.7
3938
4039 def _detect_from_bytes (data : bytes ) -> str :
@@ -47,6 +46,9 @@ def _detect_from_bytes(data: bytes) -> str:
4746 if not encoding or confidence < REQUIRED_CONFIDENCE :
4847 return DEFAULT_ENCODING
4948
49+ if encoding .lower () == "ascii" :
50+ return "utf-8"
51+
5052 return encoding
5153
5254 def _read_from_path (path : Path ):
@@ -79,18 +81,17 @@ def _read_from_path(path: Path):
7981 # IO[str] object
8082 if hasattr (source , "encoding" ):
8183 if source .encoding :
82- # Could be `None`
84+ # Could be `None`, e.g. io.StringIO has an encoding attribute which is None.
8385 return source .encoding
8486 else :
8587 return DEFAULT_ENCODING
8688
8789 # IO[bytes]
88- if isinstance (source , io .BytesIO ):
89- original_position = source .tell ()
90+ if isinstance (source , io .BufferedIOBase ):
9091 try :
92+ original_position = source .tell ()
9193 source .seek (0 )
9294 sample = source .read (SAMPLE_SIZE )
93- encoding = _detect_from_bytes (sample )
9495 if isinstance (sample , bytes ):
9596 encoding = _detect_from_bytes (sample )
9697 else :
@@ -105,9 +106,9 @@ def _read_from_path(path: Path):
105106 raise TypeError (f"Unsupported input type for encoding detection: { type (source )} " )
106107
107108
108- def read_ags_source (
109+ def open_ags_source (
109110 source : str | Path | IO [str ] | IO [bytes ] | bytes , encoding = None
110- ) -> ContextManager [TextIOBase ]:
111+ ) -> ContextManager [io . TextIOBase ]:
111112 """Opens or wraps a given source for reading AGS (text-based) data.
112113
113114 Args:
@@ -124,41 +125,42 @@ def read_ags_source(
124125 Raises:
125126 TypeError: If the source type is unsupported or binary streams are not decoded.
126127 """
128+ try :
129+ codecs .lookup (encoding )
130+ except LookupError :
131+ raise ValueError (f"Unsupported encoding: { encoding } " )
127132
128133 @contextmanager
129- def string_source ( content : str ):
130- string_io = io .StringIO (content )
134+ def _bytes_source ( bytes_content : bytes ):
135+ string_io = io .StringIO (bytes_content . decode ( encoding ) )
131136 try :
132137 yield string_io
133138 finally :
134139 string_io .close ()
135140
136- if isinstance (source , str ):
141+ if isinstance (source , ( str , Path ) ):
137142 path = Path (source )
138143 if path .exists () and path .is_file ():
139144 return open (path , "r" , encoding = encoding )
140145 raise FileNotFoundError (f"Path does not exist or is not a file: { source } " )
141146
142- elif isinstance (source , Path ):
143- if source .exists () and source .is_file ():
144- return open (source , "r" , encoding = encoding )
145- raise FileNotFoundError (f"Path does not exist or is not a file: { source } " )
146-
147- elif isinstance (source , bytes ):
148- return string_source (source .decode (encoding ))
147+ elif isinstance (source , io .TextIOBase ):
148+ source .seek (0 )
149+ return nullcontext (source )
149150
150- elif isinstance (source , io .BytesIO ):
151- return string_source (source .getvalue ().decode (encoding ))
151+ elif isinstance (source , io .BufferedIOBase ):
152+ text_stream = io .TextIOWrapper (source , encoding = encoding )
153+ text_stream .seek (0 )
154+ return nullcontext (text_stream )
152155
153- elif hasattr (source , "read" ):
154- # reset the cursor to the beginning
155- try :
156- source .seek (0 )
157- except (AttributeError , io .UnsupportedOperation ):
158- pass
159- return nullcontext (source )
156+ elif isinstance (source , bytes ):
157+ return _bytes_source (source )
160158
161- raise TypeError (f"Unsupported input type: { type (source )} " )
159+ else :
160+ raise TypeError (
161+ f"Unsupported source type: { type (source )} . "
162+ "Expected str, Path, IO[str], IO[bytes], or bytes."
163+ )
162164
163165
164166def ags_to_dfs (
@@ -179,15 +181,11 @@ def ags_to_dfs(
179181 Dict[str, pd.DataFrame]]: A dictionary where keys represent AGS group
180182 names with corresponding DataFrames for the corresponding group data.
181183 """
182- # if bytes are provided, convert to IO[bytes] to be file-like
183- if isinstance (source , bytes ):
184- source = io .BytesIO (source )
185-
186184 if not encoding :
187185 encoding = detect_encoding (source )
188186
189187 # Get first non-blank line, `None` if all lines are blank
190- with read_ags_source (source , encoding = encoding ) as f :
188+ with open_ags_source (source , encoding = encoding ) as f :
191189 first_line = next ((line .strip () for line in f if line .strip ()), None )
192190
193191 if first_line :
@@ -239,7 +237,7 @@ def ags3_to_dfs(
239237 headers : List [str ] = ["" , "" , "" ]
240238 group_data : List [List [Any ]] = [[], [], []]
241239
242- with read_ags_source (source , encoding = encoding ) as file :
240+ with open_ags_source (source , encoding = encoding ) as file :
243241 for i , line in enumerate (file ):
244242 line = line .strip ()
245243 last_line_type = line_type
@@ -333,7 +331,7 @@ def ags4_to_dfs(
333331 object that represents an AGS4 file.
334332
335333 Returns:
336- Dict[str, pd.DataFrame]: A dictionary of pandas DataFrames, where each key
334+ Dict[str, pd.DataFrame]: A dictionary of pandas DataFrames, where each key
337335 represents a group name from AGS 4 data, and the corresponding value is a
338336 pandas DataFrame containing the data for that group.
339337 """
0 commit comments