Skip to content

Commit 4308c82

Browse files
edit narwhals
1 parent 4fbc862 commit 4308c82

File tree

2 files changed

+286
-9
lines changed

2 files changed

+286
-9
lines changed

data_science_tools/narwhals.py

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
# /// script
2+
# requires-python = ">=3.11"
3+
# dependencies = [
4+
# "duckdb==1.2.2",
5+
# "marimo",
6+
# "narwhals==1.39.0",
7+
# "pandas==2.2.3",
8+
# "polars==1.29.0",
9+
# "pyarrow==20.0.0",
10+
# "pyspark==3.5.5",
11+
# "sqlframe==3.32.1",
12+
# ]
13+
# ///
14+
15+
import marimo

# Version of marimo this notebook was generated with.
__generated_with = "0.13.6"
app = marimo.App(width="medium")
19+
20+
21+
@app.cell
def _():
    # Bring marimo's API (markdown rendering, UI helpers) into the notebook graph.
    import marimo as mo

    return (mo,)
26+
27+
28+
@app.cell(hide_code=True)
def _(mo):
    # Section header rendered as markdown.
    mo.md(r"""# Motivation""")
    return
32+
33+
34+
@app.cell
def _():
    from datetime import datetime

    import pandas as pd

    # Three observations spanning two calendar months (Jan and Feb 2020).
    dates = [datetime(2020, 1, 1), datetime(2020, 1, 8), datetime(2020, 2, 3)]
    prices = [1, 4, 3]
    df = pd.DataFrame({"date": dates, "price": prices})
    df
    return datetime, df, pd
48+
49+
50+
@app.cell
def _(df):
    def monthly_aggregate_pandas(user_df):
        # "MS" buckets rows at each month start; mean() averages prices per bucket.
        monthly = user_df.resample("MS", on="date")[["price"]]
        return monthly.mean()

    monthly_aggregate_pandas(df)
    return
57+
58+
59+
@app.cell(hide_code=True)
def _(mo):
    # Explain why "convert everything to pandas" is a poor compatibility layer.
    # Fix: "If forces" -> "It forces" (typo in user-facing markdown).
    mo.md(
        r"""
        # Dataframe-agnostic data science

        Let's define a dataframe-agnostic function to calculate monthly average prices. It needs to support pandas, Polars, PySpark, DuckDB, PyArrow, Dask, and cuDF, without doing any conversion between libraries.

        ## Bad solution: just convert to pandas

        This kind of works, but:

        - It doesn't return to the user the same class they started with.
        - It kills lazy execution.
        - It kills GPU acceleration.
        - It forces pandas as a required dependency.
        """
    )
    return
78+
79+
80+
@app.cell
def _():
    # Third-party dataframe libraries used to demonstrate the compatibility problem.
    import duckdb
    import polars as pl
    import pyarrow as pa
    import pyspark
    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession

    return F, SparkSession, duckdb, pa, pl, pyspark
90+
91+
92+
@app.cell
def _(duckdb, pa, pd, pl, pyspark):
    def monthly_aggregate_bad(user_df):
        """Naively coerce any supported dataframe to pandas, then aggregate."""
        # Ordered (type, to-pandas converter) pairs; the first match wins,
        # mirroring the original if/elif chain's check order.
        converters = [
            (pd.DataFrame, lambda d: d),
            (pl.DataFrame, lambda d: d.to_pandas()),
            (duckdb.DuckDBPyRelation, lambda d: d.df()),
            (pa.Table, lambda d: d.to_pandas()),
            (pyspark.sql.dataframe.DataFrame, lambda d: d.toPandas()),
        ]
        for frame_type, convert in converters:
            if isinstance(user_df, frame_type):
                df = convert(user_df)
                break
        else:
            raise TypeError("Unsupported DataFrame type: cannot convert to pandas")

        return df.resample("MS", on="date")[["price"]].mean()

    return (monthly_aggregate_bad,)
111+
112+
113+
@app.cell
def _(datetime):
    # Shared sample data: three prices across January and February 2020.
    data = {
        "date": [datetime(2020, 1, 1), datetime(2020, 1, 8), datetime(2020, 2, 3)],
        "price": [1, 4, 3],
    }
    return (data,)
120+
121+
122+
@app.cell
def _(SparkSession, data, duckdb, monthly_aggregate_bad, pa, pd, pl):
    # Exercise the convert-to-pandas approach against each supported library.
    # pandas
    pandas_df = pd.DataFrame(data)
    monthly_aggregate_bad(pandas_df)

    # polars
    polars_df = pl.DataFrame(data)
    monthly_aggregate_bad(polars_df)

    # duckdb
    duckdb_df = duckdb.from_df(pandas_df)
    monthly_aggregate_bad(duckdb_df)

    # pyspark
    spark = SparkSession.builder.getOrCreate()
    spark_df = spark.createDataFrame(pandas_df)
    monthly_aggregate_bad(spark_df)

    # pyarrow
    arrow_table = pa.table(data)
    monthly_aggregate_bad(arrow_table)
    return arrow_table, duckdb_df, pandas_df, polars_df, spark_df
145+
146+
147+
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Unmaintainable solution: different branches for each library

        This works, but is unfeasibly difficult to test and maintain, especially when also factoring in API changes between different versions of the same library (e.g. pandas `1.*` vs pandas `2.*`).
        """
    )
    return
157+
158+
159+
@app.cell
def _(F, pd, pl, pyspark):
    def monthly_aggregate_unmaintainable(user_df):
        """Monthly mean price with a hand-written branch per dataframe library."""
        if isinstance(user_df, pd.DataFrame):
            result = user_df.resample("MS", on="date")[["price"]].mean()
        elif isinstance(user_df, pl.DataFrame):
            result = (
                user_df.group_by(pl.col("date").dt.truncate("1mo"))
                .agg(pl.col("price").mean())
                .sort("date")
            )
        elif isinstance(user_df, pyspark.sql.dataframe.DataFrame):
            result = (
                user_df.withColumn("date_month", F.date_trunc("month", F.col("date")))
                .groupBy("date_month")
                .agg(F.mean("price").alias("price_mean"))
                .orderBy("date_month")
            )
        # TODO: more branches for DuckDB, PyArrow, Dask, etc... :sob:
        else:
            # Fix: without this branch, an unsupported frame type fell through
            # and raised UnboundLocalError on `result`. Fail loudly instead,
            # consistent with monthly_aggregate_bad.
            raise TypeError("Unsupported DataFrame type")
        return result

    return (monthly_aggregate_unmaintainable,)
181+
182+
183+
@app.cell
def _(monthly_aggregate_unmaintainable, pandas_df, polars_df, spark_df):
    # Same aggregation, once per library the hand-written branches cover.
    # pandas
    monthly_aggregate_unmaintainable(pandas_df)

    # polars
    monthly_aggregate_unmaintainable(polars_df)

    # pyspark
    monthly_aggregate_unmaintainable(spark_df)
    return
194+
195+
196+
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Best solution: Narwhals as a unified dataframe interface

        - Preserves lazy execution and GPU acceleration.
        - Users get back what they started with.
        - Easy to write and maintain.
        - Strong and complete static typing.
        """
    )
    return
209+
210+
211+
@app.cell
def _():
    import narwhals as nw
    from narwhals.typing import IntoFrameT

    def monthly_aggregate(user_df: IntoFrameT) -> IntoFrameT:
        """Mean price per calendar month, returned as the caller's native type."""
        frame = nw.from_native(user_df)
        month = nw.col("date").dt.truncate("1mo")
        aggregated = frame.group_by(month).agg(nw.col("price").mean()).sort("date")
        return aggregated.to_native()

    return (monthly_aggregate,)
226+
227+
228+
@app.cell
def _(
    arrow_table,
    duckdb_df,
    monthly_aggregate,
    pandas_df,
    polars_df,
    spark_df,
):
    # One narwhals-based function handles every library, no conversions needed.
    # pandas
    monthly_aggregate(pandas_df)

    # polars
    monthly_aggregate(polars_df)

    # duckdb
    monthly_aggregate(duckdb_df)

    # pyarrow
    monthly_aggregate(arrow_table)

    # pyspark
    monthly_aggregate(spark_df)
    return
252+
253+
254+
@app.cell(hide_code=True)
def _(mo):
    # Fix grammar in user-facing markdown: "we can easily transpiling" ->
    # "we can easily transpile"; brand name "DataBricks" -> "Databricks".
    mo.md(
        r"""
        ## Bonus - can we generate SQL?

        Narwhals comes with an extra bonus feature: by combining it with [SQLFrame](https://github.com/eakmanrq/sqlframe), we can easily transpile the Polars API to any major SQL dialect. For example, to translate to the Databricks SQL dialect, we can do:
        """
    )
    return
264+
265+
266+
@app.cell
def _(monthly_aggregate, pandas_df):
    from sqlframe.duckdb import DuckDBSession

    # Run the same narwhals function through SQLFrame, then emit SQL text.
    session = DuckDBSession()
    lazy_df = session.createDataFrame(pandas_df)
    result = monthly_aggregate(lazy_df)
    print(result.sql(dialect="databricks"))
    return
275+
276+
277+
@app.cell
def _():
    # Intentionally empty scratch cell.
    return
280+
281+
282+
# Run the notebook as a script (e.g. `python narwhals.py`).
if __name__ == "__main__":
    app.run()

pyproject.toml

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,9 @@ description = "Add your description here"
55
readme = "README.md"
66
requires-python = ">=3.11"
77
dependencies = [
8-
"loguru>=0.7.3",
9-
"marimo==0.13.6",
10-
"narwhals==1.36.0",
11-
"nbformat>=5.10.4",
12-
"pandas>=2.2.3",
13-
"pyspark[sql]>=3.5.5",
8+
"marimo>=0.13.7",
9+
"pre-commit>=4.2.0",
1410
]
1511

1612
[dependency-groups]
17-
dev = [
18-
"pytest>=8.3.5",
19-
]
13+
dev = []

0 commit comments

Comments
 (0)