Skip to content

Commit 5e24432

Browse files
committed
add Narwhals Marimo notebook
1 parent 4fbc862 commit 5e24432

File tree

1 file changed

+214
-0
lines changed

1 file changed

+214
-0
lines changed

data_science_tools/narwhals.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import marimo
2+
3+
__generated_with = "0.13.7"
4+
app = marimo.App(width="medium")
5+
6+
7+
@app.cell
8+
def _():
9+
import marimo as mo
10+
return (mo,)
11+
12+
13+
@app.cell
14+
def _(mo):
15+
mo.md(
16+
r"""
17+
# Dataframe-agnostic data science
18+
19+
Let's define a dataframe-agnostic function to calculate monthly average prices. It needs to support pandas, Polars, PySpark, DuckDB, PyArrow, Dask, and cuDF, without doing any conversion between libraries.
20+
21+
## Bad solution: just convert to pandas
22+
23+
This kind of works, but:
24+
25+
- It doesn't return to the user the same class they started with.
26+
- It kills lazy execution.
27+
- It kills GPU acceleration.
28+
- If forces pandas as a required dependency.
29+
"""
30+
)
31+
return
32+
33+
34+
@app.function
35+
def monthly_aggregate_bad(user_df):
36+
if hasattr(user_df, "to_pandas"):
37+
df = user_df.to_pandas()
38+
elif hasattr(user_df, "toPandas"):
39+
df = user_df.toPandas()
40+
elif hasattr(user_df, "_to_pandas"):
41+
df = user_df._to_pandas()
42+
return df.resample("MS", on="date")[["price"]].mean()
43+
44+
45+
@app.cell
46+
def _(mo):
47+
mo.md(
48+
r"""
49+
## Unmaintainable solution: different branches for each library
50+
51+
This works, but is unfeasibly difficult to test and maintain, especially when also factoring in API changes between different versions of the same library (e.g. pandas `1.*` vs pandas `2.*`).
52+
"""
53+
)
54+
return
55+
56+
57+
@app.cell
58+
def _(F):
59+
import pandas as pd
60+
import polars as pl
61+
import duckdb
62+
import pyspark
63+
64+
65+
def monthly_aggregate_unmaintainable(user_df):
66+
if isinstance(user_df, pd.DataFrame):
67+
result = user_df.resample("MS", on="date")[["price"]].mean()
68+
elif isinstance(user_df, pl.DataFrame):
69+
result = (
70+
user_df.group_by(pl.col("date").dt.truncate("1mo"))
71+
.agg(pl.col("price").mean())
72+
.sort("date")
73+
)
74+
elif isinstance(user_df, pyspark.sql.dataframe.DataFrame):
75+
result = (
76+
user_df.groupBy(F.date_trunc("month", F.col("date")))
77+
.agg(F.mean("price"))
78+
.orderBy("date")
79+
)
80+
elif isinstance(user_df, duckdb.DuckDBPyRelation):
81+
result = user_df.aggregate(
82+
[
83+
duckdb.FunctionExpression(
84+
"time_bucket",
85+
duckdb.ConstantExpression("1 month"),
86+
duckdb.FunctionExpression("date"),
87+
).alias("date"),
88+
duckdb.FunctionExpression("mean", "price").alias("price"),
89+
],
90+
).sort("date")
91+
# TODO: more branches for PyArrow, Dask, etc... :sob:
92+
return result
93+
return duckdb, pd, pl
94+
95+
96+
@app.cell
97+
def _(mo):
98+
mo.md(
99+
r"""
100+
## Best solution: Narwhals as a unified dataframe interface
101+
102+
- Preserves lazy execution and GPU acceleration.
103+
- Users get back what they started with.
104+
- Easy to write and maintain.
105+
- Strong and complete static typing.
106+
"""
107+
)
108+
return
109+
110+
111+
@app.cell
112+
def _():
113+
import narwhals as nw
114+
from narwhals.typing import IntoFrameT
115+
116+
117+
def monthly_aggregate(user_df: IntoFrameT) -> IntoFrameT:
118+
return (
119+
nw.from_native(user_df)
120+
.group_by(nw.col("date").dt.truncate("1mo"))
121+
.agg(nw.col("price").mean())
122+
.sort("date")
123+
.to_native()
124+
)
125+
return (monthly_aggregate,)
126+
127+
128+
@app.cell
129+
def _(mo):
130+
mo.md(r"""## Demo: let's verify that it works!""")
131+
return
132+
133+
134+
@app.cell
135+
def _():
136+
from datetime import datetime
137+
138+
data = {
139+
"date": [datetime(2020, 1, 1), datetime(2020, 1, 8), datetime(2020, 2, 3)],
140+
"price": [1, 4, 3],
141+
}
142+
return (data,)
143+
144+
145+
@app.cell
146+
def _(data, monthly_aggregate, pd):
147+
# pandas
148+
df_pd = pd.DataFrame(data)
149+
monthly_aggregate(df_pd)
150+
return (df_pd,)
151+
152+
153+
@app.cell
154+
def _(data, monthly_aggregate, pl):
155+
# Polars
156+
df_pl = pl.DataFrame(data)
157+
monthly_aggregate(df_pl)
158+
return
159+
160+
161+
@app.cell
162+
def _(duckdb, monthly_aggregate):
163+
# DuckDB
164+
rel = duckdb.sql("""
165+
from values (timestamp '2020-01-01', 1),
166+
(timestamp '2020-01-08', 4),
167+
(timestamp '2020-02-03', 3)
168+
df(date, price)
169+
select *
170+
""")
171+
monthly_aggregate(rel)
172+
return
173+
174+
175+
@app.cell
176+
def _(data, monthly_aggregate):
177+
# PyArrow
178+
import pyarrow as pa
179+
180+
tbl = pa.table(data)
181+
monthly_aggregate(tbl)
182+
return
183+
184+
185+
@app.cell
186+
def _(mo):
187+
mo.md(
188+
r"""
189+
## Bonus - can we generate SQL?
190+
191+
Narwhals comes with an extra bonus feature: by combining it with [SQLFrame](https://github.com/eakmanrq/sqlframe), we can easily transpiling the Polars API to any major SQL dialect. For example, to translate to the DataBricks SQL dialect, we can do:
192+
"""
193+
)
194+
return
195+
196+
197+
@app.cell
198+
def _(df_pd, monthly_aggregate):
199+
from sqlframe.duckdb import DuckDBSession
200+
201+
sqlframe = DuckDBSession()
202+
sqlframe_df = sqlframe.createDataFrame(df_pd)
203+
sqlframe_result = monthly_aggregate(sqlframe_df)
204+
print(sqlframe_result.sql(dialect="databricks"))
205+
return
206+
207+
208+
@app.cell
209+
def _():
210+
return
211+
212+
213+
if __name__ == "__main__":
214+
app.run()

0 commit comments

Comments
 (0)