Skip to content

Commit 87e52d1

Browse files
Refactor UDF utils and add a hook to enable NRT when necessary (#18823)
This PR makes a few changes in preparation for #18453: - Consolidates the shared logic among the various similar `apply` kernels (series, dataframe, groupby) into a base class and subclasses that contain the particulars - Creates a mechanism for detecting if NRT must be enabled for a particular kernel compilation - Deletes redundant tests/code relevant to strings and potentially other types that would need to be refcounted Authors: - https://github.com/brandon-b-miller Approvers: - Matthew Murray (https://github.com/Matt711) - Vyas Ramasubramani (https://github.com/vyasr) URL: #18823
1 parent 54546be commit 87e52d1

File tree

11 files changed

+433
-692
lines changed

11 files changed

+433
-692
lines changed

python/cudf/cudf/core/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
from cudf.core.multiindex import MultiIndex
8383
from cudf.core.resample import DataFrameResampler
8484
from cudf.core.series import Series
85-
from cudf.core.udf.row_function import _get_row_kernel
85+
from cudf.core.udf.row_function import DataFrameApplyKernel
8686
from cudf.errors import MixedTypeError
8787
from cudf.utils import applyutils, docutils, ioutils, queryutils
8888
from cudf.utils.docutils import copy_docstring
@@ -4903,7 +4903,7 @@ def apply(
49034903
if by_row != "compat":
49044904
raise NotImplementedError("by_row is currently not supported.")
49054905

4906-
return self._apply(func, _get_row_kernel, *args, **kwargs)
4906+
return self._apply(func, DataFrameApplyKernel, *args, **kwargs)
49074907

49084908
def applymap(
49094909
self,

python/cudf/cudf/core/indexed_frame.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
from cudf.core.resample import _Resampler
5454
from cudf.core.scalar import pa_scalar_to_plc_scalar
5555
from cudf.core.udf.utils import (
56-
_compile_or_get,
5756
_get_input_args_from_frame,
5857
_post_process_output_col,
5958
_return_arr_from_dtype,
@@ -3471,14 +3470,13 @@ def add_suffix(self, suffix, axis=None):
34713470

34723471
@acquire_spill_lock()
34733472
@_performance_tracking
3474-
def _apply(self, func, kernel_getter, *args, **kwargs):
3473+
def _apply(self, func, kernel_class, *args, **kwargs):
34753474
"""Apply `func` across the rows of the frame."""
34763475
if kwargs:
34773476
raise ValueError("UDFs using **kwargs are not yet supported.")
34783477
try:
3479-
kernel, retty = _compile_or_get(
3480-
self, func, args, kernel_getter=kernel_getter
3481-
)
3478+
kr = kernel_class(self, func, args)
3479+
kernel, retty = kr.get_kernel()
34823480
except Exception as e:
34833481
raise ValueError(
34843482
"user defined function compilation failed."

python/cudf/cudf/core/series.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
)
6060
from cudf.core.resample import SeriesResampler
6161
from cudf.core.single_column_frame import SingleColumnFrame
62-
from cudf.core.udf.scalar_function import _get_scalar_kernel
62+
from cudf.core.udf.scalar_function import SeriesApplyKernel
6363
from cudf.utils import docutils
6464
from cudf.utils.docutils import copy_docstring
6565
from cudf.utils.dtypes import (
@@ -2636,7 +2636,7 @@ def apply(
26362636
elif by_row != "compat":
26372637
raise NotImplementedError("by_row is currently not supported.")
26382638

2639-
result = self._apply(func, _get_scalar_kernel, *args, **kwargs)
2639+
result = self._apply(func, SeriesApplyKernel, *args, **kwargs)
26402640
result.name = self.name
26412641
return result
26422642

python/cudf/cudf/core/udf/groupby_utils.py

Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Copyright (c) 2022-2025, NVIDIA CORPORATION.
22

33

4+
from functools import cache
5+
46
import cupy as cp
57
import numpy as np
68
from numba import cuda, types
@@ -19,15 +21,12 @@
1921
group_initializer_template,
2022
groupby_apply_kernel_template,
2123
)
24+
from cudf.core.udf.udf_kernel_base import ApplyKernelBase
2225
from cudf.core.udf.utils import (
2326
UDFError,
2427
_all_dtypes_from_frame,
25-
_compile_or_get,
2628
_get_extensionty_size,
27-
_get_kernel,
28-
_get_udf_return_type,
2929
_supported_cols_from_frame,
30-
_supported_dtypes_from_frame,
3130
)
3231
from cudf.utils._numba import _CUDFNumbaConfig
3332
from cudf.utils.performance_tracking import _performance_tracking
@@ -105,27 +104,6 @@ def _groupby_apply_kernel_string_from_template(frame, args):
105104
)
106105

107106

108-
def _get_groupby_apply_kernel(frame, func, args):
109-
np_field_types = np.dtype(list(_all_dtypes_from_frame(frame).items()))
110-
dataframe_group_type = _get_frame_groupby_type(
111-
np_field_types, frame.index.dtype
112-
)
113-
114-
return_type = _get_udf_return_type(dataframe_group_type, func, args)
115-
116-
# Dict of 'local' variables into which `_kernel` is defined
117-
global_exec_context = {
118-
"cuda": cuda,
119-
"Group": Group,
120-
"dataframe_group_type": dataframe_group_type,
121-
"types": types,
122-
}
123-
kernel_string = _groupby_apply_kernel_string_from_template(frame, args)
124-
kernel = _get_kernel(kernel_string, global_exec_context, None, func)
125-
126-
return kernel, return_type
127-
128-
129107
@_performance_tracking
130108
def jit_groupby_apply(offsets, grouped_values, function, *args):
131109
"""
@@ -143,13 +121,8 @@ def jit_groupby_apply(offsets, grouped_values, function, *args):
143121
The user-defined function to execute
144122
"""
145123

146-
kernel, return_type = _compile_or_get(
147-
grouped_values,
148-
function,
149-
args,
150-
kernel_getter=_get_groupby_apply_kernel,
151-
suffix="__GROUPBY_APPLY_UDF",
152-
)
124+
kr = GroupByApplyKernel(grouped_values, function, args)
125+
kernel, return_type = kr.get_kernel()
153126

154127
offsets = cp.asarray(offsets)
155128
ngroups = len(offsets) - 1
@@ -211,18 +184,70 @@ def _can_be_jitted(frame, func, args):
211184

212185
if any(col.has_nulls() for col in frame._columns):
213186
return False
214-
np_field_types = np.dtype(
215-
list(
216-
_supported_dtypes_from_frame(
217-
frame, supported_types=SUPPORTED_GROUPBY_NUMPY_TYPES
218-
).items()
219-
)
220-
)
221-
dataframe_group_type = _get_frame_groupby_type(
222-
np_field_types, frame.index.dtype
223-
)
187+
kr = GroupByApplyKernel(frame, func, args)
224188
try:
225-
_get_udf_return_type(dataframe_group_type, func, args)
189+
kr._get_udf_return_type()
226190
return True
227191
except (UDFError, TypingError):
228192
return False
193+
194+
195+
class GroupByApplyKernel(ApplyKernelBase):
196+
"""
197+
Class representing a kernel that computes the result of
198+
a GroupBy.apply operation. Expects that the user passed
199+
a function that operates on a single group of the data,
200+
for example
201+
202+
def f(group):
203+
return group['x'].sum() + group['y'].sum()
204+
"""
205+
206+
@property
207+
def kernel_type(self):
208+
return "groupby_apply"
209+
210+
def _get_frame_type(self):
211+
return _get_frame_groupby_type(
212+
np.dtype(list(_all_dtypes_from_frame(self.frame).items())),
213+
self.frame.index.dtype,
214+
)
215+
216+
def _get_kernel_string(self):
217+
# Create argument list for kernel
218+
frame = _supported_cols_from_frame(
219+
self.frame, supported_types=SUPPORTED_GROUPBY_NUMPY_TYPES
220+
)
221+
input_columns = ", ".join(
222+
[f"input_col_{i}" for i in range(len(frame))]
223+
)
224+
extra_args = ", ".join(
225+
[f"extra_arg_{i}" for i in range(len(self.args))]
226+
)
227+
228+
# Generate the initializers for each device function argument
229+
initializers = []
230+
for i, colname in enumerate(frame.keys()):
231+
initializers.append(
232+
group_initializer_template.format(idx=i, name=colname)
233+
)
234+
235+
return groupby_apply_kernel_template.format(
236+
input_columns=input_columns,
237+
extra_args=extra_args,
238+
group_initializers="\n".join(initializers),
239+
)
240+
241+
@cache
242+
def _get_kernel_string_exec_context(self):
243+
dataframe_group_type = self._get_frame_type()
244+
global_exec_context = {
245+
"cuda": cuda,
246+
"Group": Group,
247+
"dataframe_group_type": dataframe_group_type,
248+
"types": types,
249+
}
250+
return global_exec_context
251+
252+
def _construct_signature(self, return_type):
253+
return None

python/cudf/cudf/core/udf/masked_typing.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
1+
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
22

33
import operator
44

55
import numpy as np
66
from numba import types
7+
from numba.core.datamodel import default_manager
78
from numba.core.extending import (
89
make_attribute_wrapper,
910
models,
@@ -28,6 +29,7 @@
2829
comparison_ops,
2930
unary_ops,
3031
)
32+
from cudf.core.udf.nrt_utils import _current_nrt_context
3133
from cudf.core.udf.strings_typing import (
3234
StringView,
3335
UDFString,
@@ -108,6 +110,12 @@ class MaskedType(types.Type):
108110
def __init__(self, value):
109111
# MaskedType in Numba shall be parameterized
110112
# with a value type
113+
if default_manager[value].has_nrt_meminfo():
114+
ctx = _current_nrt_context.get(None)
115+
if ctx is not None:
116+
# we're in a compilation that is determining
117+
# if NRT must be linked
118+
ctx.use_nrt = True
111119
self.value_type = _type_to_masked_type(value)
112120
super().__init__(name=f"Masked({self.value_type})")
113121

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION.
2+
3+
import contextvars
4+
from contextlib import contextmanager
5+
6+
from numba import config as numba_config
7+
8+
_current_nrt_context: contextvars.ContextVar = contextvars.ContextVar(
9+
"current_nrt_context"
10+
)
11+
12+
13+
class CaptureNRTUsage:
14+
"""
15+
Context manager for determining if NRT is needed.
16+
Managed types may set use_nrt to be true during
17+
instantiation to signal that NRT must be enabled
18+
during code generation.
19+
"""
20+
21+
def __init__(self):
22+
self.use_nrt = False
23+
24+
def __enter__(self):
25+
self._token = _current_nrt_context.set(self)
26+
return self
27+
28+
def __exit__(self, exc_type, exc_val, exc_tb):
29+
_current_nrt_context.reset(self._token)
30+
31+
32+
@contextmanager
33+
def nrt_enabled():
34+
"""
35+
Context manager for enabling NRT via the numba
36+
config. CUDA_ENABLE_NRT may be toggled dynamically
37+
for a single kernel launch, so we use this context
38+
to enable it for those that we know need it.
39+
"""
40+
original_value = numba_config.CUDA_ENABLE_NRT
41+
numba_config.CUDA_ENABLE_NRT = True
42+
try:
43+
yield
44+
finally:
45+
numba_config.CUDA_ENABLE_NRT = original_value

0 commit comments

Comments
 (0)