Skip to content

Commit bbe6261

Browse files
committed
feat: start rough draft of narwhals support
1 parent dc0b294 commit bbe6261

File tree

3 files changed

+68
-7
lines changed

3 files changed

+68
-7
lines changed

shiny/render/_data_frame_utils/_tbl_data.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from functools import singledispatch
66
from typing import Any, List, Tuple, cast
77

8+
import narwhals.stable.v1 as nw
89
from htmltools import TagNode
910

1011
from ..._typing_extensions import TypeIs
@@ -184,7 +185,7 @@ def _(col: PlSeries) -> FrameDtype:
184185

185186
from ._html import col_contains_shiny_html
186187

187-
if col.dtype.is_(pl.String):
188+
if col.dtype == pl.String():
188189
if col_contains_shiny_html(col):
189190
type_ = "html"
190191
else:
@@ -203,6 +204,30 @@ def _(col: PlSeries) -> FrameDtype:
203204
return {"type": type_}
204205

205206

207+
@serialize_dtype.register
208+
def _(col: nw.Series) -> FrameDtype:
209+
210+
from ._html import col_contains_shiny_html
211+
212+
if col.dtype == nw.String():
213+
if col_contains_shiny_html(col):
214+
type_ = "html"
215+
else:
216+
type_ = "string"
217+
elif col.dtype.is_numeric():
218+
type_ = "numeric"
219+
220+
elif col.dtype == nw.Categorical():
221+
categories = col.cat.get_categories().to_list()
222+
return {"type": "categorical", "categories": categories}
223+
else:
224+
type_ = "unknown"
225+
if col_contains_shiny_html(col):
226+
type_ = "html"
227+
228+
return {"type": type_}
229+
230+
206231
# serialize_frame ----------------------------------------------------------------------
207232

208233

@@ -218,6 +243,8 @@ def _(data: PdDataFrame) -> FrameJson:
218243
return serialize_frame_pd(data)
219244

220245

246+
# TODO: test this
247+
@serialize_frame.register(nw.DataFrame)
221248
@serialize_frame.register
222249
def _(data: PlDataFrame) -> FrameJson:
223250
import json
@@ -308,6 +335,7 @@ def _(
308335
return data.iloc[indx_rows, indx_cols]
309336

310337

338+
@subset_frame.register(nw.DataFrame)
311339
@subset_frame.register
312340
def _(
313341
data: PlDataFrame,
@@ -321,7 +349,7 @@ def _(
321349
else slice(None)
322350
)
323351
indx_rows = rows if rows is not None else slice(None)
324-
return data[indx_rows, indx_cols]
352+
return data[indx_cols][indx_rows]
325353

326354

327355
# get_frame_cell -----------------------------------------------------------------------
@@ -341,9 +369,10 @@ def _(data: PdDataFrame, row: int, col: int) -> Any:
341369
)
342370

343371

372+
@get_frame_cell.register(nw.DataFrame)
344373
@get_frame_cell.register
345374
def _(data: PlDataFrame, row: int, col: int) -> Any:
346-
return data[row, col]
375+
return data.item(row, col)
347376

348377

349378
# shape --------------------------------------------------------------------------------
@@ -359,6 +388,7 @@ def _(data: PdDataFrame) -> Tuple[int, ...]:
359388
return data.shape
360389

361390

391+
@frame_shape.register(nw.DataFrame)
362392
@frame_shape.register
363393
def _(data: PlDataFrame) -> Tuple[int, ...]:
364394
return data.shape
@@ -377,6 +407,7 @@ def _(data: PdDataFrame) -> PdDataFrame:
377407
return data.copy()
378408

379409

410+
@copy_frame.register(nw.DataFrame)
380411
@copy_frame.register
381412
def _(data: PlDataFrame) -> PlDataFrame:
382413
return data.clone()
@@ -393,6 +424,7 @@ def _(data: PdDataFrame) -> List[str]:
393424
return data.columns.to_list()
394425

395426

427+
@frame_column_names.register(nw.DataFrame)
396428
@frame_column_names.register
397429
def _(data: PlDataFrame) -> List[str]:
398430
return data.columns

tests/pytest/test_render_data_frame.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import narwhals as nw
12
import pandas as pd
23
import pytest
34

tests/pytest/test_render_data_frame_tbl_data.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from datetime import datetime
55
from typing import Any, Union
66

7+
import narwhals.stable.v1 as nw
78
import pandas as pd
89
import polars as pl
910
import polars.testing as pl_testing
@@ -53,6 +54,9 @@ class D:
5354
params_frames = [
5455
pytest.param(pd.DataFrame, id="pandas"),
5556
pytest.param(pl.DataFrame, id="polars"),
57+
pytest.param(
58+
lambda d: nw.from_native(pl.DataFrame(d), eager_only=True), id="narwhals"
59+
),
5660
]
5761

5862
DataFrameLike: TypeAlias = Union[pd.DataFrame, pl.DataFrame]
@@ -93,6 +97,17 @@ def assert_frame_equal(
9397
raise NotImplementedError(f"Unsupported data type: {type(src)}")
9498

9599

100+
def assert_frame_equal2(
101+
src: pd.DataFrame | pl.DataFrame,
102+
target_dict: dict,
103+
use_index: bool = False,
104+
):
105+
src = nw.to_native(src, strict=False)
106+
target = nw.to_native(src, strict=False).__class__(target_dict)
107+
108+
assert_frame_equal(src, target, use_index)
109+
110+
96111
# TODO: explicitly pass dtype= when doing Series construction
97112
@pytest.mark.parametrize(
98113
"ser, res_type",
@@ -126,6 +141,13 @@ def test_serialize_dtype(
126141
],
127142
res_type: str,
128143
):
144+
if isinstance(ser, pl.Series):
145+
assert (
146+
serialize_dtype(nw.from_native(ser, eager_only=True, allow_series=True))[
147+
"type"
148+
]
149+
== res_type
150+
)
129151
assert serialize_dtype(ser)["type"] == res_type
130152

131153

@@ -158,9 +180,9 @@ def test_serialize_frame(df: DataFrameLike):
158180
def test_subset_frame(df: DataFrameLike):
159181
# TODO: this assumes subset_frame doesn't reset index
160182
res = subset_frame(df, rows=[1], cols=["chr", "num"])
161-
dst = df.__class__({"chr": ["b"], "num": [2]})
183+
dst = {"chr": ["b"], "num": [2]}
162184

163-
assert_frame_equal(res, dst)
185+
assert_frame_equal2(res, dst)
164186

165187

166188
def test_get_frame_cell(df: DataFrameLike):
@@ -176,14 +198,20 @@ def test_copy_frame(df: DataFrameLike):
176198
def test_subset_frame_rows_single(small_df: DataFrameLike):
177199
res = subset_frame(small_df, rows=[1])
178200

179-
assert_frame_equal(res, small_df.__class__({"x": [2], "y": [4]}))
201+
assert_frame_equal2(
202+
res,
203+
{"x": [2], "y": [4]},
204+
)
180205

181206

182207
def test_subset_frame_cols_single(small_df: DataFrameLike):
183208
# TODO: include test of polars
184209
res = subset_frame(small_df, cols=["y"])
185210

186-
assert_frame_equal(res, small_df.__class__({"y": [3, 4]}))
211+
assert_frame_equal2(
212+
res,
213+
{"y": [3, 4]},
214+
)
187215

188216

189217
def test_shape(small_df: DataFrameLike):

0 commit comments

Comments
 (0)