Skip to content

Commit 5f84668

Browse files
authored
parse_csv also takes file-likes
parse_csv also takes file-likes
2 parents 058ad35 + 1b6f9c7 commit 5f84668

File tree

2 files changed

+58
-24
lines changed

2 files changed

+58
-24
lines changed

cpm/parse.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import TextIO, Union
12
from cpm.exceptions import *
23
from cpm.models import DSM
34
from os import listdir
@@ -31,54 +32,81 @@ def parse_csv_dir(dir_path: str, pattern: str = None, delimiter: str = 'auto',
3132
return dsm_array
3233

3334

34-
def parse_csv(filepath: str, delimiter: str = 'auto', encoding: str = 'utf-8', instigator: str = 'column'):
35+
def parse_csv(file: Union[str, TextIO], delimiter: str = 'auto', encoding: str = 'utf-8', instigator: str = 'column'):
3536
"""
3637
Parse CSV to DSM
37-
:param filepath: Targeted CSV file
38+
:param file: Targeted CSV file or file-like object
3839
:param delimiter: CSV delimiter. Defaults to auto-detection.
3940
:param encoding: text-encoding. Defaults to utf-8
4041
:param instigator: Determines directionality of DSM. Defaults to columns instigating rows.
4142
:return: DSM
4243
"""
43-
44+
45+
content = _read_file(file, encoding)
46+
4447
if delimiter == 'auto':
45-
with open(filepath, 'r', encoding=encoding) as file:
46-
delimiter = detect_delimiter(file.read())
48+
delimiter = detect_delimiter(content)
4749

4850
# Identify number of rows, and separate header row
4951
num_cols = 0
5052
column_names = []
51-
with open(filepath, 'r') as file:
52-
for line in file:
53-
column_names.append(line.split(delimiter)[0])
54-
num_cols += 1
53+
lines = _get_file_lines(file, encoding)
54+
for line in lines:
55+
column_names.append(line.split(delimiter)[0])
56+
num_cols += 1
5557

5658
# We do not want the first column in the header
5759
column_names.pop(0)
5860

5961
data = []
6062

61-
with open(filepath, 'r') as file:
62-
for i, line in enumerate(file):
63-
if i == 0:
63+
for i, line in enumerate(lines):
64+
if i == 0:
65+
continue
66+
data.append([])
67+
for j, col in enumerate(line.split(delimiter)):
68+
if j == 0:
6469
continue
65-
data.append([])
66-
for j, col in enumerate(line.split(delimiter)):
67-
if j == 0:
68-
continue
69-
if col == "":
70+
if col == "":
71+
data[i-1].append(None)
72+
else:
73+
try:
74+
data[i-1].append(float(col))
75+
except ValueError:
7076
data[i-1].append(None)
71-
else:
72-
try:
73-
data[i-1].append(float(col))
74-
except ValueError:
75-
data[i - 1].append(None)
7677

7778
dsm = DSM(matrix=data, columns=column_names, instigator=instigator)
7879

7980
return dsm
8081

8182

83+
def _read_file(file, encoding):
84+
if isinstance(file, str):
85+
with open(file, 'r', encoding=encoding) as f:
86+
return f.read()
87+
elif hasattr(file, 'read'):
88+
position = file.tell()
89+
content = file.read()
90+
file.seek(position)
91+
return content
92+
else:
93+
raise ValueError("Invalid file input. Must be a filepath or a file-like object.")
94+
95+
96+
def _get_file_lines(file, encoding):
97+
if isinstance(file, str):
98+
with open(file, 'r', encoding=encoding) as f:
99+
return f.readlines()
100+
elif hasattr(file, 'read'):
101+
position = file.tell()
102+
file.seek(0)
103+
lines = file.readlines()
104+
file.seek(position)
105+
return lines
106+
else:
107+
raise ValueError("Invalid file input. Must be a filepath or a file-like object.")
108+
109+
82110
def detect_delimiter(text, look_ahead=1000):
83111
"""
84112
Attempts to determine CSV delmiter based on a certain amount of sample characters
@@ -114,4 +142,3 @@ def detect_delimiter(text, look_ahead=1000):
114142
raise AutoDelimiterError('None of the default delimiters matched the file. Is the file empty?')
115143

116144
return best_delimiter
117-

tests/test_parser.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pytest
21
from cpm.parse import parse_csv
32

43

@@ -58,3 +57,11 @@ def test_parse_dsm_network_instigator_row():
5857
assert len(a_neighbours) == 1
5958
assert a_neighbours[0] == 3
6059

60+
61+
def test_parse_file_object():
62+
path = './tests/test-assets/dsm-network-test.csv'
63+
with open(path) as file:
64+
dsm = parse_csv(file)
65+
66+
for col in ['A', 'B', 'C', 'D']:
67+
assert col in dsm.columns

0 commit comments

Comments
 (0)