diff --git a/docs/cudf/source/conf.py b/docs/cudf/source/conf.py
index f55dfe8b36c..9fe76619585 100644
--- a/docs/cudf/source/conf.py
+++ b/docs/cudf/source/conf.py
@@ -405,6 +405,7 @@ def _generate_namespaces(namespaces):
     "type_id",
     # Unknown base types
     "int32_t",
+    "uint64_t",
     "void",
 }
 
diff --git a/python/cudf_polars/cudf_polars/experimental/dask_serialize.py b/python/cudf_polars/cudf_polars/experimental/dask_registers.py
similarity index 79%
rename from python/cudf_polars/cudf_polars/experimental/dask_serialize.py
rename to python/cudf_polars/cudf_polars/experimental/dask_registers.py
index aed6e5d6177..c9f2b0be72b 100644
--- a/python/cudf_polars/cudf_polars/experimental/dask_serialize.py
+++ b/python/cudf_polars/cudf_polars/experimental/dask_registers.py
@@ -1,12 +1,13 @@
 # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 
-"""Dask serialization."""
+"""Dask function registrations such as serializers and dispatch implementations."""
 
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, ClassVar, overload
 
+from dask.sizeof import sizeof as sizeof_dispatch
 from distributed.protocol import dask_deserialize, dask_serialize
 from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
 from distributed.utils import log_errors
@@ -21,35 +22,35 @@
     from cudf_polars.typing import ColumnHeader, DataFrameHeader
 
 
-__all__ = ["SerializerManager", "register"]
+__all__ = ["DaskRegisterManager", "register"]
 
 
-class SerializerManager:  # pragma: no cover; Only used with Distributed scheduler
+class DaskRegisterManager:  # pragma: no cover; Only used with Distributed scheduler
     """Manager to ensure ensure serializer is only registered once."""
 
-    _serializer_registered: bool = False
+    _registered: bool = False
     _client_run_executed: ClassVar[set[str]] = set()
 
     @classmethod
-    def register_serialize(cls) -> None:
+    def register_once(cls) -> None:
         """Register Dask/cudf-polars serializers in calling process."""
-        if not cls._serializer_registered:
-            from cudf_polars.experimental.dask_serialize import register
+        if not cls._registered:
+            from cudf_polars.experimental.dask_registers import register
 
             register()
-            cls._serializer_registered = True
+            cls._registered = True
 
     @classmethod
     def run_on_cluster(cls, client: Client) -> None:
-        """Run serializer registration on the workers and scheduler."""
+        """Run register on the workers and scheduler once."""
         if client.id not in cls._client_run_executed:
-            client.run(cls.register_serialize)
-            client.run_on_scheduler(cls.register_serialize)
+            client.run(cls.register_once)
+            client.run_on_scheduler(cls.register_once)
             cls._client_run_executed.add(client.id)
 
 
 def register() -> None:
-    """Register dask serialization routines for DataFrames."""
+    """Register dask serialization and dispatch functions."""
 
     @overload
     def serialize_column_or_frame(
@@ -128,3 +129,13 @@ def _(header: ColumnHeader, frames: tuple[memoryview, memoryview]) -> Column:
         # Copy the second frame (the gpudata in host memory) back to the gpu
         frames = frames[0], plc.gpumemoryview(rmm.DeviceBuffer.to_device(frames[1]))
         return Column.deserialize(header, frames)
+
+    @sizeof_dispatch.register(Column)
+    def _(x: Column) -> int:
+        """The total size of the device buffers used by the DataFrame or Column."""
+        return x.obj.device_buffer_size()
+
+    @sizeof_dispatch.register(DataFrame)
+    def _(x: DataFrame) -> int:
+        """The total size of the device buffers used by the DataFrame or Column."""
+        return sum(c.obj.device_buffer_size() for c in x.columns)
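
A minimal sketch of what the renamed module provides once `register()` has run: Distributed picks up the serializer hooks, and `dask.sizeof.sizeof` reports device-buffer bytes for cudf-polars containers instead of a generic host-side size. The table below is illustrative only and not part of the diff.

    import pyarrow as pa
    from dask.sizeof import sizeof

    import pylibcudf as plc

    from cudf_polars.containers import DataFrame
    from cudf_polars.experimental.dask_registers import register

    register()  # install the serializers and the sizeof dispatch in this process

    tbl = pa.table({"a": [1, 2, 3]})  # hypothetical input
    df = DataFrame.from_table(plc.interop.from_arrow(tbl), names=tbl.column_names)
    assert sizeof(df) == 3 * 8  # one INT64 data buffer, no null mask
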
diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py
index acb8ab2bb13..aee7590c4b2 100644
--- a/python/cudf_polars/cudf_polars/experimental/parallel.py
+++ b/python/cudf_polars/cudf_polars/experimental/parallel.py
@@ -145,11 +145,11 @@ def get_scheduler(config_options: ConfigOptions) -> Any:
     ):  # pragma: no cover; block depends on executor type and Distributed cluster
         from distributed import get_client
 
-        from cudf_polars.experimental.dask_serialize import SerializerManager
+        from cudf_polars.experimental.dask_registers import DaskRegisterManager
 
         client = get_client()
-        SerializerManager.register_serialize()
-        SerializerManager.run_on_cluster(client)
+        DaskRegisterManager.register_once()
+        DaskRegisterManager.run_on_cluster(client)
         return client.get
     elif scheduler == "synchronous":
         from cudf_polars.experimental.scheduler import synchronous_scheduler
diff --git a/python/cudf_polars/tests/experimental/test_dask_serialize.py b/python/cudf_polars/tests/experimental/test_dask_serialize.py
index ce907f394d3..1a529288701 100644
--- a/python/cudf_polars/tests/experimental/test_dask_serialize.py
+++ b/python/cudf_polars/tests/experimental/test_dask_serialize.py
@@ -13,7 +13,7 @@
 import rmm
 
 from cudf_polars.containers import DataFrame
-from cudf_polars.experimental.dask_serialize import register
+from cudf_polars.experimental.dask_registers import register
 
 # Must register serializers before running tests
 register()
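
For reference, the `get_scheduler` change above reduces to the following registration pattern; the bare `Client()` construction here is a placeholder for whatever Distributed cluster the executor is actually attached to, not part of the diff.

    from distributed import Client

    from cudf_polars.experimental.dask_registers import DaskRegisterManager

    client = Client()  # assumes an already-running Distributed cluster
    DaskRegisterManager.register_once()          # register in the client process
    DaskRegisterManager.run_on_cluster(client)   # register once on each worker and the scheduler
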
diff --git a/python/cudf_polars/tests/experimental/test_dask_sizeof.py b/python/cudf_polars/tests/experimental/test_dask_sizeof.py
new file mode 100644
index 00000000000..48a718707b8
--- /dev/null
+++ b/python/cudf_polars/tests/experimental/test_dask_sizeof.py
@@ -0,0 +1,33 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import pyarrow as pa
+import pytest
+from dask.sizeof import sizeof
+
+import pylibcudf as plc
+
+from cudf_polars.containers import DataFrame
+from cudf_polars.experimental.dask_registers import register
+
+# Must register sizeof dispatch before running tests
+register()
+
+
+@pytest.mark.parametrize(
+    "arrow_tbl, size",
+    [
+        (pa.table([]), 0),
+        (pa.table({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), 9 * 8),
+        (pa.table({"a": [1, 2, 3]}), 3 * 8),
+        (pa.table({"a": ["a"], "b": ["bc"]}), 2 * 8 + 3),
+        (pa.table({"a": [1, 2, None]}), 88),
+    ],
+)
+def test_dask_sizeof(arrow_tbl, size):
+    plc_tbl = plc.interop.from_arrow(arrow_tbl)
+    df = DataFrame.from_table(plc_tbl, names=arrow_tbl.column_names)
+    assert sizeof(df) == size
+    assert sum(sizeof(c) for c in df.columns) == size
diff --git a/python/pylibcudf/pylibcudf/column.pxd b/python/pylibcudf/pylibcudf/column.pxd
index d4d5eefbf27..61b2f32594f 100644
--- a/python/pylibcudf/pylibcudf/column.pxd
+++ b/python/pylibcudf/pylibcudf/column.pxd
@@ -2,6 +2,8 @@
 from libcpp.memory cimport unique_ptr
 from libcpp.vector cimport vector
 
+from libc.stdint cimport uint64_t
+
 from rmm.librmm.device_buffer cimport device_buffer
 from rmm.pylibrmm.stream cimport Stream
 from pylibcudf.libcudf.column.column cimport column
@@ -68,6 +70,7 @@ cdef class Column:
     cpdef gpumemoryview null_mask(self)
     cpdef list children(self)
     cpdef Column copy(self)
+    cpdef uint64_t device_buffer_size(self)
     cpdef Column with_mask(self, gpumemoryview, size_type)
 
     cpdef ListColumnView list_view(self)
diff --git a/python/pylibcudf/pylibcudf/column.pyi b/python/pylibcudf/pylibcudf/column.pyi
index 17a49ec4b0b..75f1e30858b 100644
--- a/python/pylibcudf/pylibcudf/column.pyi
+++ b/python/pylibcudf/pylibcudf/column.pyi
@@ -50,6 +50,7 @@ class Column:
     def null_mask(self) -> gpumemoryview | None: ...
     def children(self) -> list[Column]: ...
     def copy(self) -> Column: ...
+    def device_buffer_size(self) -> int: ...
    def with_mask(
         self, mask: gpumemoryview | None, null_count: int
     ) -> Column: ...
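
The expected sizes in test_dask_sizeof.py follow directly from the buffers each column owns. A worked version of that arithmetic, under the assumptions that string columns carry 4-byte (INT32) offsets and that the null mask allocation is padded to a 64-byte boundary (as bitmask_allocation_size_bytes does):

    int64_column = 3 * 8                     # [1, 2, 3]: one data buffer, no null mask
    three_int64_columns = 3 * int64_column   # columns "a", "b", "c" -> 9 * 8 bytes
    string_columns = 2 * (2 * 4) + (1 + 2)   # two offsets buffers (2 offsets each) + chars of "a" and "bc"
    int64_with_null = 3 * 8 + 64             # data buffer + padded null-mask allocation
    assert (int64_column, three_int64_columns, string_columns, int64_with_null) == (24, 72, 19, 88)
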
+ """ + cdef uint64_t ret = 0 + if self.data() is not None: + ret += self.data().nbytes + if self.null_mask() is not None: + ret += self.null_mask().nbytes + if self.children() is not None: + for child in self.children(): + ret += (child).device_buffer_size() + return ret + def _create_nested_column_metadata(self): return ColumnMetadata( children_meta=[ @@ -855,34 +879,6 @@ cdef class ListColumnView: return lists_column_view(self._column.view()) -@functools.cache -def _datatype_from_dtype_desc(desc): - mapping = { - 'u1': type_id.UINT8, - 'u2': type_id.UINT16, - 'u4': type_id.UINT32, - 'u8': type_id.UINT64, - 'i1': type_id.INT8, - 'i2': type_id.INT16, - 'i4': type_id.INT32, - 'i8': type_id.INT64, - 'f4': type_id.FLOAT32, - 'f8': type_id.FLOAT64, - 'b1': type_id.BOOL8, - 'M8[s]': type_id.TIMESTAMP_SECONDS, - 'M8[ms]': type_id.TIMESTAMP_MILLISECONDS, - 'M8[us]': type_id.TIMESTAMP_MICROSECONDS, - 'M8[ns]': type_id.TIMESTAMP_NANOSECONDS, - 'm8[s]': type_id.DURATION_SECONDS, - 'm8[ms]': type_id.DURATION_MILLISECONDS, - 'm8[us]': type_id.DURATION_MICROSECONDS, - 'm8[ns]': type_id.DURATION_NANOSECONDS, - } - if desc not in mapping: - raise ValueError(f"Unsupported dtype: {desc}") - return DataType(mapping[desc]) - - def is_c_contiguous( shape: Sequence[int], strides: None | Sequence[int], itemsize: int ) -> bool: diff --git a/python/pylibcudf/pylibcudf/gpumemoryview.pxd b/python/pylibcudf/pylibcudf/gpumemoryview.pxd index c5cd32920dd..02bb69dbbe4 100644 --- a/python/pylibcudf/pylibcudf/gpumemoryview.pxd +++ b/python/pylibcudf/pylibcudf/gpumemoryview.pxd @@ -1,5 +1,5 @@ # Copyright (c) 2023-2025, NVIDIA CORPORATION. -from libc.stdint cimport uintptr_t +from libc.stdint cimport uint64_t, uintptr_t cdef class gpumemoryview: # TODO: Eventually probably want to make this opaque, but for now it's fine @@ -7,3 +7,4 @@ cdef class gpumemoryview: cdef readonly uintptr_t ptr cdef readonly object obj cdef readonly dict cai + cdef readonly uint64_t nbytes diff --git a/python/pylibcudf/pylibcudf/gpumemoryview.pyx b/python/pylibcudf/pylibcudf/gpumemoryview.pyx index 6c13cac4f3f..0a1a8f33ad9 100644 --- a/python/pylibcudf/pylibcudf/gpumemoryview.pyx +++ b/python/pylibcudf/pylibcudf/gpumemoryview.pyx @@ -3,8 +3,42 @@ import functools import operator +from .types cimport DataType, size_of, type_id + +from pylibcudf.libcudf.types cimport size_type + + __all__ = ["gpumemoryview"] + +@functools.cache +def _datatype_from_dtype_desc(desc): + mapping = { + 'u1': type_id.UINT8, + 'u2': type_id.UINT16, + 'u4': type_id.UINT32, + 'u8': type_id.UINT64, + 'i1': type_id.INT8, + 'i2': type_id.INT16, + 'i4': type_id.INT32, + 'i8': type_id.INT64, + 'f4': type_id.FLOAT32, + 'f8': type_id.FLOAT64, + 'b1': type_id.BOOL8, + 'M8[s]': type_id.TIMESTAMP_SECONDS, + 'M8[ms]': type_id.TIMESTAMP_MILLISECONDS, + 'M8[us]': type_id.TIMESTAMP_MICROSECONDS, + 'M8[ns]': type_id.TIMESTAMP_NANOSECONDS, + 'm8[s]': type_id.DURATION_SECONDS, + 'm8[ms]': type_id.DURATION_MILLISECONDS, + 'm8[us]': type_id.DURATION_MICROSECONDS, + 'm8[ns]': type_id.DURATION_NANOSECONDS, + } + if desc not in mapping: + raise ValueError(f"Unsupported dtype: {desc}") + return DataType(mapping[desc]) + + cdef class gpumemoryview: """Minimal representation of a memory buffer. @@ -27,6 +61,14 @@ cdef class gpumemoryview: # TODO: Need to respect readonly self.ptr = cai["data"][0] + # Compute the buffer size. + cdef size_type itemsize = size_of( + _datatype_from_dtype_desc( + cai["typestr"][1:] # ignore the byteorder (the first char). 
diff --git a/python/pylibcudf/pylibcudf/gpumemoryview.pyx b/python/pylibcudf/pylibcudf/gpumemoryview.pyx
index 6c13cac4f3f..0a1a8f33ad9 100644
--- a/python/pylibcudf/pylibcudf/gpumemoryview.pyx
+++ b/python/pylibcudf/pylibcudf/gpumemoryview.pyx
@@ -3,8 +3,42 @@
 import functools
 import operator
 
+from .types cimport DataType, size_of, type_id
+
+from pylibcudf.libcudf.types cimport size_type
+
+
 __all__ = ["gpumemoryview"]
 
 
+
+@functools.cache
+def _datatype_from_dtype_desc(desc):
+    mapping = {
+        'u1': type_id.UINT8,
+        'u2': type_id.UINT16,
+        'u4': type_id.UINT32,
+        'u8': type_id.UINT64,
+        'i1': type_id.INT8,
+        'i2': type_id.INT16,
+        'i4': type_id.INT32,
+        'i8': type_id.INT64,
+        'f4': type_id.FLOAT32,
+        'f8': type_id.FLOAT64,
+        'b1': type_id.BOOL8,
+        'M8[s]': type_id.TIMESTAMP_SECONDS,
+        'M8[ms]': type_id.TIMESTAMP_MILLISECONDS,
+        'M8[us]': type_id.TIMESTAMP_MICROSECONDS,
+        'M8[ns]': type_id.TIMESTAMP_NANOSECONDS,
+        'm8[s]': type_id.DURATION_SECONDS,
+        'm8[ms]': type_id.DURATION_MILLISECONDS,
+        'm8[us]': type_id.DURATION_MICROSECONDS,
+        'm8[ns]': type_id.DURATION_NANOSECONDS,
+    }
+    if desc not in mapping:
+        raise ValueError(f"Unsupported dtype: {desc}")
+    return DataType(mapping[desc])
+
+
 cdef class gpumemoryview:
     """Minimal representation of a memory buffer.
@@ -27,6 +61,14 @@ cdef class gpumemoryview:
         # TODO: Need to respect readonly
         self.ptr = cai["data"][0]
 
+        # Compute the buffer size.
+        cdef size_type itemsize = size_of(
+            _datatype_from_dtype_desc(
+                cai["typestr"][1:]  # ignore the byteorder (the first char).
+            )
+        )
+        self.nbytes = functools.reduce(operator.mul, cai["shape"]) * itemsize
+
     @property
     def __cuda_array_interface__(self):
         return self.cai
@@ -34,16 +76,4 @@
     def __len__(self):
         return self.obj.__cuda_array_interface__["shape"][0]
 
-    @property
-    def nbytes(self):
-        cai = self.obj.__cuda_array_interface__
-        shape, typestr = cai["shape"], cai["typestr"]
-
-        # Get element size from typestr, format is two character specifying
-        # the type and the latter part is the number of bytes. E.g., '