Skip to content

Commit f0d4d21

Browse files
Session.virtualfile_to_dataset: Add new parameters 'dtype'/'index_col' for pandas output (#3140)
Co-authored-by: Yvonne Fröhlich <94163266+yvonnefroehlich@users.noreply.github.com>
1 parent 4b3b3eb commit f0d4d21

File tree

3 files changed: 42 additions (+42) and 18 deletions (-18).

pygmt/clib/session.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1747,6 +1747,8 @@ def virtualfile_to_dataset(
17471747
vfname: str,
17481748
output_type: Literal["pandas", "numpy", "file"] = "pandas",
17491749
column_names: list[str] | None = None,
1750+
dtype: type | dict[str, type] | None = None,
1751+
index_col: str | int | None = None,
17501752
) -> pd.DataFrame | np.ndarray | None:
17511753
"""
17521754
Output a tabular dataset stored in a virtual file to a different format.
@@ -1766,6 +1768,11 @@ def virtualfile_to_dataset(
17661768
- ``"file"`` means the result was saved to a file and will return ``None``.
17671769
column_names
17681770
The column names for the :class:`pandas.DataFrame` output.
1771+
dtype
1772+
Data type for the columns of the :class:`pandas.DataFrame` output. Can be a
1773+
single type for all columns or a dictionary mapping column names to types.
1774+
index_col
1775+
Column to set as the index of the :class:`pandas.DataFrame` output.
17691776
17701777
Returns
17711778
-------
@@ -1854,13 +1861,13 @@ def virtualfile_to_dataset(
18541861
return None
18551862

18561863
# Read the virtual file as a GMT dataset and convert to pandas.DataFrame
1857-
result = self.read_virtualfile(vfname, kind="dataset").contents.to_dataframe()
1864+
result = self.read_virtualfile(vfname, kind="dataset").contents.to_dataframe(
1865+
column_names=column_names,
1866+
dtype=dtype,
1867+
index_col=index_col,
1868+
)
18581869
if output_type == "numpy": # numpy.ndarray output
18591870
return result.to_numpy()
1860-
1861-
# Assign column names
1862-
if column_names is not None:
1863-
result.columns = column_names
18641871
return result # pandas.DataFrame output
18651872

18661873
def extract_region(self):

pygmt/datatypes/dataset.py

Lines changed: 23 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -143,14 +143,29 @@ class _GMT_DATASEGMENT(ctp.Structure): # noqa: N801
143143
("hidden", ctp.c_void_p),
144144
]
145145

146-
def to_dataframe(self) -> pd.DataFrame:
146+
def to_dataframe(
147+
self,
148+
column_names: list[str] | None = None,
149+
dtype: type | dict[str, type] | None = None,
150+
index_col: str | int | None = None,
151+
) -> pd.DataFrame:
147152
"""
148153
Convert a _GMT_DATASET object to a :class:`pandas.DataFrame` object.
149154
150155
Currently, the number of columns in all segments of all tables is assumed to be
151156
the same. The same column in all segments of all tables is concatenated. The
152157
trailing text column is also concatenated as a single string column.
153158
159+
Parameters
160+
----------
161+
column_names
162+
A list of column names.
163+
dtype
164+
Data type. Can be a single type for all columns or a dictionary mapping
165+
column names to types.
166+
index_col
167+
Column to set as index.
168+
154169
Returns
155170
-------
156171
df
@@ -211,5 +226,11 @@ def to_dataframe(self) -> pd.DataFrame:
211226
pd.Series(data=np.char.decode(textvector), dtype=pd.StringDtype())
212227
)
213228

214-
df = pd.concat(objs=vectors, axis=1)
229+
df = pd.concat(objs=vectors, axis="columns")
230+
if column_names is not None: # Assign column names
231+
df.columns = column_names
232+
if dtype is not None:
233+
df = df.astype(dtype)
234+
if index_col is not None:
235+
df = df.set_index(index_col)
215236
return df

pygmt/src/grdhisteq.py

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -238,18 +238,14 @@ def compute_bins(
238238
module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
239239
)
240240

241-
result = lib.virtualfile_to_dataset(
241+
return lib.virtualfile_to_dataset(
242242
vfname=vouttbl,
243243
output_type=output_type,
244244
column_names=["start", "stop", "bin_id"],
245+
dtype={
246+
"start": np.float32,
247+
"stop": np.float32,
248+
"bin_id": np.uint32,
249+
},
250+
index_col="bin_id" if output_type == "pandas" else None,
245251
)
246-
if output_type == "pandas":
247-
result = result.astype(
248-
{
249-
"start": np.float32,
250-
"stop": np.float32,
251-
"bin_id": np.uint32,
252-
}
253-
)
254-
return result.set_index("bin_id")
255-
return result

0 commit comments

Comments
 (0)