2025-09-07 22:09:54 +02:00
parent e1b817252c
commit 2fc0d000b6
7796 changed files with 2159515 additions and 933 deletions

narwhals/_dask/dataframe.py
@@ -0,0 +1,502 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import dask.dataframe as dd
from narwhals._dask.utils import add_row_index, evaluate_exprs
from narwhals._expression_parsing import ExprKind
from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name
from narwhals._typing_compat import assert_never
from narwhals._utils import (
Implementation,
ValidateBackendVersion,
_remap_full_join_keys,
check_column_names_are_unique,
check_columns_exist,
generate_temporary_column_name,
not_implemented,
parse_columns_to_drop,
zip_strict,
)
from narwhals.typing import CompliantLazyFrame
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from io import BytesIO
from pathlib import Path
from types import ModuleType
import dask.dataframe.dask_expr as dx
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.typing import CompliantDataFrameAny
from narwhals._dask.expr import DaskExpr
from narwhals._dask.group_by import DaskLazyGroupBy
from narwhals._dask.namespace import DaskNamespace
from narwhals._typing import _EagerAllowedImpl
from narwhals._utils import Version, _LimitedContext
from narwhals.dataframe import LazyFrame
from narwhals.dtypes import DType
from narwhals.exceptions import ColumnNotFoundError
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
Incomplete: TypeAlias = "Any"
"""Using `_pandas_like` utils with `_dask`.
Typing this correctly would complicate the `_pandas_like` side.
Very low priority until `dask` adds typing.
"""
class DaskLazyFrame(
CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"],
ValidateBackendVersion,
):
_implementation = Implementation.DASK
def __init__(
self,
native_dataframe: dd.DataFrame,
*,
version: Version,
validate_backend_version: bool = False,
) -> None:
self._native_frame: dd.DataFrame = native_dataframe
self._version = version
self._cached_schema: dict[str, DType] | None = None
self._cached_columns: list[str] | None = None
if validate_backend_version:
self._validate_backend_version()
@staticmethod
def _is_native(obj: dd.DataFrame | Any) -> TypeIs[dd.DataFrame]:
return isinstance(obj, dd.DataFrame)
@classmethod
def from_native(cls, data: dd.DataFrame, /, *, context: _LimitedContext) -> Self:
return cls(data, version=context._version)
def to_narwhals(self) -> LazyFrame[dd.DataFrame]:
return self._version.lazyframe(self, level="lazy")
def __native_namespace__(self) -> ModuleType:
if self._implementation is Implementation.DASK:
return self._implementation.to_native_namespace()
msg = f"Expected dask, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)
def __narwhals_namespace__(self) -> DaskNamespace:
from narwhals._dask.namespace import DaskNamespace
return DaskNamespace(version=self._version)
def __narwhals_lazyframe__(self) -> Self:
return self
def _with_version(self, version: Version) -> Self:
return self.__class__(self.native, version=version)
def _with_native(self, df: Any) -> Self:
return self.__class__(df, version=self._version)
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
return check_columns_exist(subset, available=self.columns)
def _iter_columns(self) -> Iterator[dx.Series]:
for _col, ser in self.native.items(): # noqa: PERF102
yield ser
def with_columns(self, *exprs: DaskExpr) -> Self:
new_series = evaluate_exprs(self, *exprs)
return self._with_native(self.native.assign(**dict(new_series)))
def collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny:
result = self.native.compute(**kwargs)
if backend is None or backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
result,
implementation=Implementation.PANDAS,
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
if backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame(
pl.from_pandas(result),
validate_backend_version=True,
version=self._version,
)
if backend is Implementation.PYARROW:
import pyarrow as pa # ignore-banned-import
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
pa.Table.from_pandas(result),
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise ValueError(msg) # pragma: no cover
@property
def columns(self) -> list[str]:
if self._cached_columns is None:
self._cached_columns = (
list(self.schema)
if self._cached_schema is not None
else self.native.columns.tolist()
)
return self._cached_columns
def filter(self, predicate: DaskExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
mask = predicate(self)[0]
return self._with_native(self.native.loc[mask])
def simple_select(self, *column_names: str) -> Self:
df: Incomplete = self.native
native = select_columns_by_name(df, list(column_names), self._implementation)
return self._with_native(native)
def aggregate(self, *exprs: DaskExpr) -> Self:
new_series = evaluate_exprs(self, *exprs)
df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
return self._with_native(df)
def select(self, *exprs: DaskExpr) -> Self:
new_series = evaluate_exprs(self, *exprs)
df: Incomplete = self.native
df = select_columns_by_name(
df.assign(**dict(new_series)),
[s[0] for s in new_series],
self._implementation,
)
return self._with_native(df)
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
if subset is None:
return self._with_native(self.native.dropna())
plx = self.__narwhals_namespace__()
mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
return self.filter(mask)
@property
def schema(self) -> dict[str, DType]:
if self._cached_schema is None:
native_dtypes = self.native.dtypes
self._cached_schema = {
col: native_to_narwhals_dtype(
native_dtypes[col], self._version, self._implementation
)
for col in self.native.columns
}
return self._cached_schema
def collect_schema(self) -> dict[str, DType]:
return self.schema
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
to_drop = parse_columns_to_drop(self, columns, strict=strict)
return self._with_native(self.native.drop(columns=to_drop))
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
# Implementation is based on the following StackOverflow reply:
# https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
if order_by is None:
return self._with_native(add_row_index(self.native, name))
plx = self.__narwhals_namespace__()
columns = self.columns
const_expr = plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL)
row_index_expr = (
plx.col(name).cum_sum(reverse=False).over(partition_by=[], order_by=order_by)
- 1
)
return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns))
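# Illustrative sketch (not part of the original file): the order-aware branch
# computes each row's 0-based rank in the requested ordering, e.g. for an
# `order_by` column with values [30, 10, 20]:
#
#     lit(1)                 -> [1, 1, 1]
#     cum_sum over order_by  -> [3, 1, 2]
#     - 1                    -> [2, 0, 1]  # row index respecting the order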
def rename(self, mapping: Mapping[str, str]) -> Self:
return self._with_native(self.native.rename(columns=mapping))
def head(self, n: int) -> Self:
return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self:
if subset and (error := self._check_columns_exist(subset)):
raise error
if keep == "none":
subset = subset or self.columns
token = generate_temporary_column_name(n_bytes=8, columns=subset)
ser = self.native.groupby(subset).size().rename(token)
ser = ser[ser == 1]
unique = ser.reset_index().drop(columns=token)
result = self.native.merge(unique, on=subset, how="inner")
else:
mapped_keep = {"any": "first"}.get(keep, keep)
result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
return self._with_native(result)
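# Illustrative sketch of the `keep="none"` branch (pandas used as a stand-in,
# since dask mirrors its semantics; hypothetical data):
#
#     import pandas as pd
#     df = pd.DataFrame({"a": [1, 1, 2]})
#     sizes = df.groupby(["a"]).size().rename("_n")    # {1: 2, 2: 1}
#     unique = sizes[sizes == 1].reset_index().drop(columns="_n")
#     df.merge(unique, on=["a"], how="inner")          # keeps only the row a == 2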
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
if isinstance(descending, bool):
ascending: bool | list[bool] = not descending
else:
ascending = [not d for d in descending]
position = "last" if nulls_last else "first"
return self._with_native(
self.native.sort_values(list(by), ascending=ascending, na_position=position)
)
def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
df = self.native
schema = self.schema
by = list(by)
if isinstance(reverse, bool) and all(schema[x].is_numeric() for x in by):
if reverse:
return self._with_native(df.nsmallest(k, by))
return self._with_native(df.nlargest(k, by))
if isinstance(reverse, bool):
reverse = [reverse] * len(by)
return self._with_native(
df.sort_values(by, ascending=list(reverse)).head(
n=k, compute=False, npartitions=-1
)
)
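# Illustrative note: when every key is numeric and `reverse` is a single bool,
# dask's `nlargest`/`nsmallest` avoid a full sort; otherwise we fall back to
# `sort_values(...).head(k)`, where `ascending=reverse` means `reverse=True`
# selects the k smallest values per key.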
def _join_inner(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
return self.native.merge(
other.native,
left_on=left_on,
right_on=right_on,
how="inner",
suffixes=("", suffix),
)
def _join_left(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
result_native = self.native.merge(
other.native,
how="left",
left_on=left_on,
right_on=right_on,
suffixes=("", suffix),
)
extra = [
right_key if right_key not in self.columns else f"{right_key}{suffix}"
for left_key, right_key in zip_strict(left_on, right_on)
if right_key != left_key
]
return result_native.drop(columns=extra)
def _join_full(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
# Dask does not retain the join keys post-join,
# so we must append the suffix to each key beforehand.
right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
other_native = other.native.rename(columns=right_on_mapper)
check_column_names_are_unique(other_native.columns)
right_suffixed = list(right_on_mapper.values())
return self.native.merge(
other_native,
left_on=left_on,
right_on=right_suffixed,
how="outer",
suffixes=("", suffix),
)
def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
key_token = generate_temporary_column_name(
n_bytes=8, columns=(*self.columns, *other.columns)
)
return (
self.native.assign(**{key_token: 0})
.merge(
other.native.assign(**{key_token: 0}),
how="inner",
left_on=key_token,
right_on=key_token,
suffixes=("", suffix),
)
.drop(columns=key_token)
)
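# Illustrative sketch (pandas stand-in, hypothetical data): a cross join is an
# inner equi-join on a constant key,
#
#     import pandas as pd
#     left = pd.DataFrame({"a": [1, 2]}).assign(_k=0)
#     right = pd.DataFrame({"b": ["x"]}).assign(_k=0)
#     left.merge(right, on="_k").drop(columns="_k")   # 2 x 1 = 2 rows
#
# `generate_temporary_column_name` only guards against `_k` shadowing a real column.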
def _join_semi(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
) -> dd.DataFrame:
other_native = self._join_filter_rename(
other=other,
columns_to_select=list(right_on),
columns_mapping=dict(zip(right_on, left_on)),
)
return self.native.merge(
other_native, how="inner", left_on=left_on, right_on=left_on
)
def _join_anti(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
) -> dd.DataFrame:
indicator_token = generate_temporary_column_name(
n_bytes=8, columns=(*self.columns, *other.columns)
)
other_native = self._join_filter_rename(
other=other,
columns_to_select=list(right_on),
columns_mapping=dict(zip(right_on, left_on)),
)
df = self.native.merge(
other_native,
how="left",
indicator=indicator_token, # pyright: ignore[reportArgumentType]
left_on=left_on,
right_on=left_on,
)
return df[df[indicator_token] == "left_only"].drop(columns=[indicator_token])
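# Illustrative sketch (pandas stand-in, hypothetical data): the indicator
# column tags each left row with "both" or "left_only":
#
#     import pandas as pd
#     left = pd.DataFrame({"a": [1, 2, 3]})
#     right = pd.DataFrame({"a": [2]})
#     out = left.merge(right, how="left", indicator="_m", on="a")
#     out[out["_m"] == "left_only"].drop(columns=["_m"])   # rows with a in {1, 3}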
def _join_filter_rename(
self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str]
) -> dd.DataFrame:
"""Helper function to avoid creating extra columns and row duplication.
Used in `"anti"` and `"semi"` joins.
Notice that a native object is returned.
"""
other_native: Incomplete = other.native
# rename to avoid creating extra columns in join
return (
select_columns_by_name(other_native, columns_to_select, self._implementation)
.rename(columns=columns_mapping)
.drop_duplicates()
)
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
if how == "cross":
result = self._join_cross(other=other, suffix=suffix)
elif left_on is None or right_on is None: # pragma: no cover
raise ValueError(left_on, right_on)
elif how == "inner":
result = self._join_inner(
other=other, left_on=left_on, right_on=right_on, suffix=suffix
)
elif how == "anti":
result = self._join_anti(other=other, left_on=left_on, right_on=right_on)
elif how == "semi":
result = self._join_semi(other=other, left_on=left_on, right_on=right_on)
elif how == "left":
result = self._join_left(
other=other, left_on=left_on, right_on=right_on, suffix=suffix
)
elif how == "full":
result = self._join_full(
other=other, left_on=left_on, right_on=right_on, suffix=suffix
)
else:
assert_never(how)
return self._with_native(result)
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self:
plx = self.__native_namespace__()
return self._with_native(
plx.merge_asof(
self.native,
other.native,
left_on=left_on,
right_on=right_on,
left_by=by_left,
right_by=by_right,
direction=strategy,
suffixes=("", suffix),
)
)
def group_by(
self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
) -> DaskLazyGroupBy:
from narwhals._dask.group_by import DaskLazyGroupBy
return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)
def tail(self, n: int) -> Self: # pragma: no cover
native_frame = self.native
n_partitions = native_frame.npartitions
if n_partitions == 1:
return self._with_native(self.native.tail(n=n, compute=False))
msg = (
"`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
)
raise NotImplementedError(msg)
def gather_every(self, n: int, offset: int) -> Self:
row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
plx = self.__narwhals_namespace__()
return (
self.with_row_index(row_index_token, order_by=None)
.filter(
(plx.col(row_index_token) >= offset)
& ((plx.col(row_index_token) - offset) % n == 0)
)
.drop([row_index_token], strict=False)
)
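# Worked example (illustrative): with `n=3, offset=1` the filter keeps row
# indices i where `i >= 1` and `(i - 1) % 3 == 0`, i.e. rows 1, 4, 7, and so on:
# every 3rd row starting at the offset.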
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
return self._with_native(
self.native.melt(
id_vars=index,
value_vars=on,
var_name=variable_name,
value_name=value_name,
)
)
def sink_parquet(self, file: str | Path | BytesIO) -> None:
self.native.to_parquet(file)
explode = not_implemented()

narwhals/_dask/expr.py
@@ -0,0 +1,701 @@
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal, cast
import pandas as pd
from narwhals._compliant import DepthTrackingExpr, LazyExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import (
add_row_index,
maybe_evaluate_expr,
narwhals_to_native_dtype,
)
from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype
from narwhals._utils import (
Implementation,
generate_temporary_column_name,
not_implemented,
)
from narwhals.exceptions import InvalidOperationError
if TYPE_CHECKING:
from collections.abc import Sequence
import dask.dataframe.dask_expr as dx
from typing_extensions import Self
from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.namespace import DaskNamespace
from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._utils import Version, _LimitedContext
from narwhals.typing import (
FillNullStrategy,
IntoDType,
ModeKeepStrategy,
NonNestedLiteral,
NumericLiteral,
RollingInterpolationMethod,
TemporalLiteral,
)
class DaskExpr(
LazyExpr["DaskLazyFrame", "dx.Series"], # pyright: ignore[reportInvalidTypeArguments]
DepthTrackingExpr["DaskLazyFrame", "dx.Series"], # pyright: ignore[reportInvalidTypeArguments]
):
_implementation: Implementation = Implementation.DASK
def __init__(
self,
call: EvalSeries[DaskLazyFrame, dx.Series], # pyright: ignore[reportInvalidTypeForm]
*,
depth: int,
function_name: str,
evaluate_output_names: EvalNames[DaskLazyFrame],
alias_output_names: AliasNames | None,
version: Version,
scalar_kwargs: ScalarKwargs | None = None,
) -> None:
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._version = version
self._scalar_kwargs = scalar_kwargs or {}
self._metadata: ExprMetadata | None = None
def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
return self._call(df)
def __narwhals_namespace__(self) -> DaskNamespace: # pragma: no cover
from narwhals._dask.namespace import DaskNamespace
return DaskNamespace(version=self._version)
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
# result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16
# that raised a KeyError for result[0] during collection.
return [result.loc[0][0] for result in self(df)]
return self.__class__(
func,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
scalar_kwargs=self._scalar_kwargs,
)
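# Illustrative note: `broadcast` turns an aggregation such as
# `nw.col("a").sum()` into a single-value extraction (`result.loc[0][0]`); when
# that scalar is later `assign`-ed onto a frame, dask broadcasts it to the full
# column length.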
@classmethod
def from_column_names(
cls: type[Self],
evaluate_column_names: EvalNames[DaskLazyFrame],
/,
*,
context: _LimitedContext,
function_name: str = "",
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
try:
return [
df._native_frame[column_name]
for column_name in evaluate_column_names(df)
]
except KeyError as e:
if error := df._check_columns_exist(evaluate_column_names(df)):
raise error from e
raise
return cls(
func,
depth=0,
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
version=context._version,
)
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
return [df.native.iloc[:, i] for i in column_indices]
return cls(
func,
depth=0,
function_name="nth",
evaluate_output_names=cls._eval_names_indices(column_indices),
alias_output_names=None,
version=context._version,
)
def _with_callable(
self,
# First argument to `call` should be `dx.Series`
call: Callable[..., dx.Series],
/,
expr_name: str = "",
scalar_kwargs: ScalarKwargs | None = None,
**expressifiable_args: Self | Any,
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
native_results: list[dx.Series] = []
native_series_list = self._call(df)
other_native_series = {
key: maybe_evaluate_expr(df, value)
for key, value in expressifiable_args.items()
}
for native_series in native_series_list:
result_native = call(native_series, **other_native_series)
native_results.append(result_native)
return native_results
return self.__class__(
func,
depth=self._depth + 1,
function_name=f"{self._function_name}->{expr_name}",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
scalar_kwargs=scalar_kwargs,
)
def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
current_alias_output_names = self._alias_output_names
alias_output_names = (
None
if func is None
else func
if current_alias_output_names is None
else lambda output_names: func(current_alias_output_names(output_names))
)
return type(self)(
call=self._call,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=alias_output_names,
version=self._version,
scalar_kwargs=self._scalar_kwargs,
)
def _with_binary(
self,
call: Callable[[dx.Series, Any], dx.Series],
name: str,
other: Any,
*,
reverse: bool = False,
) -> Self:
result = self._with_callable(
lambda expr, other: call(expr, other), name, other=other
)
if reverse:
result = result.alias("literal")
return result
def _binary_op(self, op_name: str, other: Any) -> Self:
return self._with_binary(
lambda expr, other: getattr(expr, op_name)(other), op_name, other
)
def _reverse_binary_op(
self, op_name: str, operator_func: Callable[..., dx.Series], other: Any
) -> Self:
return self._with_binary(
lambda expr, other: operator_func(other, expr), op_name, other, reverse=True
)
def __add__(self, other: Any) -> Self:
return self._binary_op("__add__", other)
def __sub__(self, other: Any) -> Self:
return self._binary_op("__sub__", other)
def __mul__(self, other: Any) -> Self:
return self._binary_op("__mul__", other)
def __truediv__(self, other: Any) -> Self:
return self._binary_op("__truediv__", other)
def __floordiv__(self, other: Any) -> Self:
return self._binary_op("__floordiv__", other)
def __pow__(self, other: Any) -> Self:
return self._binary_op("__pow__", other)
def __mod__(self, other: Any) -> Self:
return self._binary_op("__mod__", other)
def __eq__(self, other: object) -> Self: # type: ignore[override]
return self._binary_op("__eq__", other)
def __ne__(self, other: object) -> Self: # type: ignore[override]
return self._binary_op("__ne__", other)
def __ge__(self, other: Any) -> Self:
return self._binary_op("__ge__", other)
def __gt__(self, other: Any) -> Self:
return self._binary_op("__gt__", other)
def __le__(self, other: Any) -> Self:
return self._binary_op("__le__", other)
def __lt__(self, other: Any) -> Self:
return self._binary_op("__lt__", other)
def __and__(self, other: Any) -> Self:
return self._binary_op("__and__", other)
def __or__(self, other: Any) -> Self:
return self._binary_op("__or__", other)
def __rsub__(self, other: Any) -> Self:
return self._reverse_binary_op("__rsub__", lambda a, b: a - b, other)
def __rtruediv__(self, other: Any) -> Self:
return self._reverse_binary_op("__rtruediv__", lambda a, b: a / b, other)
def __rfloordiv__(self, other: Any) -> Self:
return self._reverse_binary_op("__rfloordiv__", lambda a, b: a // b, other)
def __rpow__(self, other: Any) -> Self:
return self._reverse_binary_op("__rpow__", lambda a, b: a**b, other)
def __rmod__(self, other: Any) -> Self:
return self._reverse_binary_op("__rmod__", lambda a, b: a % b, other)
def __invert__(self) -> Self:
return self._with_callable(lambda expr: expr.__invert__(), "__invert__")
def mean(self) -> Self:
return self._with_callable(lambda expr: expr.mean().to_series(), "mean")
def median(self) -> Self:
def func(s: dx.Series) -> dx.Series:
dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK)
if not dtype.is_numeric():
msg = "`median` operation not supported for non-numeric input type."
raise InvalidOperationError(msg)
return s.median_approximate().to_series()
return self._with_callable(func, "median")
def min(self) -> Self:
return self._with_callable(lambda expr: expr.min().to_series(), "min")
def max(self) -> Self:
return self._with_callable(lambda expr: expr.max().to_series(), "max")
def std(self, ddof: int) -> Self:
return self._with_callable(
lambda expr: expr.std(ddof=ddof).to_series(),
"std",
scalar_kwargs={"ddof": ddof},
)
def var(self, ddof: int) -> Self:
return self._with_callable(
lambda expr: expr.var(ddof=ddof).to_series(),
"var",
scalar_kwargs={"ddof": ddof},
)
def skew(self) -> Self:
return self._with_callable(lambda expr: expr.skew().to_series(), "skew")
def kurtosis(self) -> Self:
return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis")
def shift(self, n: int) -> Self:
return self._with_callable(lambda expr: expr.shift(n), "shift")
def cum_sum(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
# https://github.com/dask/dask/issues/11802
msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cumsum(), "cum_sum")
def cum_count(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_count(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(
lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count"
)
def cum_min(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_min(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cummin(), "cum_min")
def cum_max(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_max(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cummax(), "cum_max")
def cum_prod(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cumprod(), "cum_prod")
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).sum(),
"rolling_sum",
)
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).mean(),
"rolling_mean",
)
def rolling_var(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
if ddof == 1:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).var(),
"rolling_var",
)
msg = "Dask backend only supports `ddof=1` for `rolling_var`"
raise NotImplementedError(msg)
def rolling_std(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
if ddof == 1:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).std(),
"rolling_std",
)
msg = "Dask backend only supports `ddof=1` for `rolling_std`"
raise NotImplementedError(msg)
def sum(self) -> Self:
return self._with_callable(lambda expr: expr.sum().to_series(), "sum")
def count(self) -> Self:
return self._with_callable(lambda expr: expr.count().to_series(), "count")
def round(self, decimals: int) -> Self:
return self._with_callable(lambda expr: expr.round(decimals), "round")
def unique(self) -> Self:
return self._with_callable(lambda expr: expr.unique(), "unique")
def drop_nulls(self) -> Self:
return self._with_callable(lambda expr: expr.dropna(), "drop_nulls")
def abs(self) -> Self:
return self._with_callable(lambda expr: expr.abs(), "abs")
def all(self) -> Self:
return self._with_callable(
lambda expr: expr.all(
axis=None, skipna=True, split_every=False, out=None
).to_series(),
"all",
)
def any(self) -> Self:
return self._with_callable(
lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(),
"any",
)
def fill_nan(self, value: float | None) -> Self:
value_nullable = pd.NA if value is None else value
value_numpy = float("nan") if value is None else value
def func(expr: dx.Series) -> dx.Series:
# If/when pandas exposes an API which distinguishes NaN vs null, use that.
mask = cast("dx.Series", expr != expr) # noqa: PLR0124
mask = mask.fillna(False)
fill = (
value_nullable
if get_dtype_backend(expr.dtype, self._implementation)
else value_numpy
)
return expr.mask(mask, fill) # pyright: ignore[reportArgumentType]
return self._with_callable(func, "fill_nan")
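# Illustrative note (pandas stand-in): `expr != expr` is True only for NaN,
# since NaN compares unequal to itself, while a missing value propagates and is
# then dropped by `fillna(False)`. So for a nullable column [1.0, NaN, <NA>]
# and `value=0.0`, the result is [1.0, 0.0, <NA>]: `fill_nan` deliberately
# leaves nulls untouched.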
def fill_null(
self,
value: Self | NonNestedLiteral,
strategy: FillNullStrategy | None,
limit: int | None,
) -> Self:
def func(expr: dx.Series) -> dx.Series:
if value is not None:
res_ser = expr.fillna(value)
else:
res_ser = (
expr.ffill(limit=limit)
if strategy == "forward"
else expr.bfill(limit=limit)
)
return res_ser
return self._with_callable(func, "fill_null")
def clip(
self,
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
) -> Self:
return self._with_callable(
lambda expr, lower_bound, upper_bound: expr.clip(
lower=lower_bound, upper=upper_bound
),
"clip",
lower_bound=lower_bound,
upper_bound=upper_bound,
)
def diff(self) -> Self:
return self._with_callable(lambda expr: expr.diff(), "diff")
def n_unique(self) -> Self:
return self._with_callable(
lambda expr: expr.nunique(dropna=False).to_series(), "n_unique"
)
def is_null(self) -> Self:
return self._with_callable(lambda expr: expr.isna(), "is_null")
def is_nan(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
dtype = native_to_narwhals_dtype(
expr.dtype, self._version, self._implementation
)
if dtype.is_numeric():
return expr != expr # pyright: ignore[reportReturnType] # noqa: PLR0124
msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
raise InvalidOperationError(msg)
return self._with_callable(func, "is_nan")
def len(self) -> Self:
return self._with_callable(lambda expr: expr.size.to_series(), "len")
def quantile(
self, quantile: float, interpolation: RollingInterpolationMethod
) -> Self:
if interpolation == "linear":
def func(expr: dx.Series, quantile: float) -> dx.Series:
if expr.npartitions > 1:
msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
raise NotImplementedError(msg)
return expr.quantile(
q=quantile, method="dask"
).to_series() # pragma: no cover
return self._with_callable(func, "quantile", quantile=quantile)
msg = "`higher`, `lower`, `midpoint` and `nearest` interpolation methods are not supported by Dask. Please use `linear` instead."
raise NotImplementedError(msg)
def is_first_distinct(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
frame = add_row_index(expr.to_frame(), col_token)
first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
return frame[col_token].isin(first_distinct_index)
return self._with_callable(func, "is_first_distinct")
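# Worked example (illustrative): for values ["a", "b", "a"] the appended row
# index is [0, 1, 2]; the per-group minimum gives {"a": 0, "b": 1}, and
# `isin({0, 1})` marks the first occurrences -> [True, True, False].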
def is_last_distinct(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
frame = add_row_index(expr.to_frame(), col_token)
last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
return frame[col_token].isin(last_distinct_index)
return self._with_callable(func, "is_last_distinct")
def is_unique(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
return (
expr.to_frame()
.groupby(_name, dropna=False)
.transform("size", meta=(_name, int))
== 1
)
return self._with_callable(func, "is_unique")
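# Illustrative note: `transform("size")` broadcasts each group's count back to
# its rows, so ["a", "b", "a"] yields [2, 1, 2] and `== 1` -> [False, True, False].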
def is_in(self, other: Any) -> Self:
return self._with_callable(lambda expr: expr.isin(other), "is_in")
def null_count(self) -> Self:
return self._with_callable(
lambda expr: expr.isna().sum().to_series(), "null_count"
)
def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
# pandas is a required dependency of dask so it's safe to import this
from narwhals._pandas_like.group_by import PandasLikeGroupBy
if not partition_by:
assert order_by # noqa: S101
# This is something like `nw.col('a').cum_sum().order_by(key)`
# which we can always easily support, as it doesn't require grouping.
def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
return self(df.sort(*order_by, descending=False, nulls_last=False))
elif not self._is_elementary(): # pragma: no cover
msg = (
"Only elementary expressions are supported for `.over` in dask.\n\n"
"Please see: "
"https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
)
raise NotImplementedError(msg)
elif order_by:
# Wrong results https://github.com/dask/dask/issues/11806.
msg = "`over` with `order_by` is not yet supported in Dask."
raise NotImplementedError(msg)
else:
function_name = PandasLikeGroupBy._leaf_name(self)
try:
dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
except KeyError:
# window functions are unsupported: https://github.com/dask/dask/issues/11806
msg = (
f"Unsupported function: {function_name} in `over` context.\n\n"
f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
)
raise NotImplementedError(msg) from None
def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
with warnings.catch_warnings():
# https://github.com/dask/dask/issues/11804
warnings.filterwarnings(
"ignore",
message=".*`meta` is not specified",
category=UserWarning,
)
grouped = df.native.groupby(partition_by)
if dask_function_name == "size":
if len(output_names) != 1: # pragma: no cover
msg = "Safety check failed, please report a bug."
raise AssertionError(msg)
res_native = grouped.transform(
dask_function_name, **self._scalar_kwargs
).to_frame(output_names[0])
else:
res_native = grouped[list(output_names)].transform(
dask_function_name, **self._scalar_kwargs
)
result_frame = df._with_native(
res_native.rename(columns=dict(zip(output_names, aliases)))
).native
return [result_frame[name] for name in aliases]
return self.__class__(
func,
depth=self._depth + 1,
function_name=self._function_name + "->over",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)
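# Illustrative sketch: `nw.col("a").sum().over("g")` reaches the `else` branch,
# where the leaf name "sum" is remapped and evaluated roughly as
#
#     df.groupby(["g"])["a"].transform("sum")
#
# i.e. each group's aggregate is broadcast back onto the group's rows, which is
# the window semantics `over` promises.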
def cast(self, dtype: IntoDType) -> Self:
def func(expr: dx.Series) -> dx.Series:
native_dtype = narwhals_to_native_dtype(dtype, self._version)
return expr.astype(native_dtype)
return self._with_callable(func, "cast")
def is_finite(self) -> Self:
import dask.array as da
return self._with_callable(da.isfinite, "is_finite")
def log(self, base: float) -> Self:
import dask.array as da
def _log(expr: dx.Series) -> dx.Series:
return da.log(expr) / da.log(base)
return self._with_callable(_log, "log")
def exp(self) -> Self:
import dask.array as da
return self._with_callable(da.exp, "exp")
def sqrt(self) -> Self:
import dask.array as da
return self._with_callable(da.sqrt, "sqrt")
def mode(self, *, keep: ModeKeepStrategy) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
result = expr.to_frame().mode()[_name]
return result.head(1) if keep == "any" else result
return self._with_callable(func, "mode", scalar_kwargs={"keep": keep})
@property
def str(self) -> DaskExprStringNamespace:
return DaskExprStringNamespace(self)
@property
def dt(self) -> DaskExprDateTimeNamespace:
return DaskExprDateTimeNamespace(self)
arg_max: not_implemented = not_implemented()
arg_min: not_implemented = not_implemented()
arg_true: not_implemented = not_implemented()
ewm_mean: not_implemented = not_implemented()
gather_every: not_implemented = not_implemented()
head: not_implemented = not_implemented()
map_batches: not_implemented = not_implemented()
sample: not_implemented = not_implemented()
rank: not_implemented = not_implemented()
replace_strict: not_implemented = not_implemented()
sort: not_implemented = not_implemented()
tail: not_implemented = not_implemented()
# namespaces
list: not_implemented = not_implemented() # type: ignore[assignment]
cat: not_implemented = not_implemented() # type: ignore[assignment]
struct: not_implemented = not_implemented() # type: ignore[assignment]

narwhals/_dask/expr_dt.py
@@ -0,0 +1,175 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import DateTimeNamespace
from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from narwhals._duration import Interval
from narwhals._pandas_like.utils import (
ALIAS_DICT,
calculate_timestamp_date,
calculate_timestamp_datetime,
native_to_narwhals_dtype,
)
from narwhals._utils import Implementation
if TYPE_CHECKING:
import dask.dataframe.dask_expr as dx
from narwhals._dask.expr import DaskExpr
from narwhals.typing import TimeUnit
class DaskExprDateTimeNamespace(
LazyExprNamespace["DaskExpr"], DateTimeNamespace["DaskExpr"]
):
def date(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.date, "date")
def year(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.year, "year")
def month(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.month, "month")
def day(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.day, "day")
def hour(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.hour, "hour")
def minute(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.minute, "minute")
def second(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.second, "second")
def millisecond(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.microsecond // 1000, "millisecond"
)
def microsecond(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.microsecond, "microsecond"
)
def nanosecond(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond, "nanosecond"
)
def ordinal_day(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.dayofyear, "ordinal_day"
)
def weekday(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.weekday + 1, # Dask is 0-6 (Monday=0); narwhals expects 1-7
"weekday",
)
def to_string(self, format: str) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, format: expr.dt.strftime(format.replace("%.f", ".%f")),
"strftime",
format=format,
)
def replace_time_zone(self, time_zone: str | None) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, time_zone: expr.dt.tz_localize(None).dt.tz_localize(time_zone)
if time_zone is not None
else expr.dt.tz_localize(None),
"tz_localize",
time_zone=time_zone,
)
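# Illustrative note: `tz_localize(None)` drops the zone but keeps the
# wall-clock time, so replacing the zone of 2024-01-01 00:00+01:00 with "UTC"
# yields 2024-01-01 00:00+00:00 (same clock time, different instant), unlike
# `convert_time_zone` below, which preserves the instant.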
def convert_time_zone(self, time_zone: str) -> DaskExpr:
def func(s: dx.Series, time_zone: str) -> dx.Series:
dtype = native_to_narwhals_dtype(
s.dtype, self.compliant._version, Implementation.DASK
)
if dtype.time_zone is None: # type: ignore[attr-defined]
return s.dt.tz_localize("UTC").dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue]
return s.dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue]
return self.compliant._with_callable(func, "tz_convert", time_zone=time_zone)
# ignoring coverage due to https://github.com/narwhals-dev/narwhals/issues/2808.
def timestamp(self, time_unit: TimeUnit) -> DaskExpr: # pragma: no cover
def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series:
dtype = native_to_narwhals_dtype(
s.dtype, self.compliant._version, Implementation.DASK
)
is_pyarrow_dtype = "pyarrow" in str(dtype)
mask_na = s.isna()
dtypes = self.compliant._version.dtypes
if dtype == dtypes.Date:
# Date is only supported in pandas dtypes if pyarrow-backed
s_cast = s.astype("Int32[pyarrow]")
result = calculate_timestamp_date(s_cast, time_unit)
elif isinstance(dtype, dtypes.Datetime):
original_time_unit = dtype.time_unit
s_cast = (
s.astype("Int64[pyarrow]") if is_pyarrow_dtype else s.astype("int64")
)
result = calculate_timestamp_datetime(
s_cast, original_time_unit, time_unit
)
else:
msg = "Input should be either of Date or Datetime type"
raise TypeError(msg)
return result.where(~mask_na) # pyright: ignore[reportReturnType]
return self.compliant._with_callable(func, "datetime", time_unit=time_unit)
def total_minutes(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() // 60, "total_minutes"
)
def total_seconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() // 1, "total_seconds"
)
def total_milliseconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1,
"total_milliseconds",
)
def total_microseconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1,
"total_microseconds",
)
def total_nanoseconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1, "total_nanoseconds"
)
def truncate(self, every: str) -> DaskExpr:
interval = Interval.parse(every)
unit = interval.unit
if unit in {"mo", "q", "y"}:
msg = f"Truncating to {unit} is not yet supported for dask."
raise NotImplementedError(msg)
freq = f"{interval.multiple}{ALIAS_DICT.get(unit, unit)}"
return self.compliant._with_callable(lambda expr: expr.dt.floor(freq), "truncate")
def offset_by(self, by: str) -> DaskExpr:
def func(s: dx.Series, by: str) -> dx.Series:
interval = Interval.parse_no_constraints(by)
unit = interval.unit
if unit in {"y", "q", "mo", "d", "ns"}:
msg = f"Offsetting by {unit} is not yet supported for dask."
raise NotImplementedError(msg)
offset = interval.to_timedelta()
return s.add(offset)
return self.compliant._with_callable(func, "offset_by", by=by)

narwhals/_dask/expr_str.py
@@ -0,0 +1,121 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import dask.dataframe as dd
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StringNamespace
from narwhals._utils import not_implemented
if TYPE_CHECKING:
import dask.dataframe.dask_expr as dx
from narwhals._dask.expr import DaskExpr
class DaskExprStringNamespace(LazyExprNamespace["DaskExpr"], StringNamespace["DaskExpr"]):
def len_chars(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.str.len(), "len")
def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> DaskExpr:
def _replace(
expr: dx.Series, pattern: str, value: str, *, literal: bool, n: int
) -> dx.Series:
try:
return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue]
pattern, value, regex=not literal, n=n
)
except TypeError as e:
if not isinstance(value, str):
msg = "dask backed `Expr.str.replace` only supports str replacement values"
raise TypeError(msg) from e
raise
return self.compliant._with_callable(
_replace, "replace", pattern=pattern, value=value, literal=literal, n=n
)
def replace_all(self, pattern: str, value: str, *, literal: bool) -> DaskExpr:
def _replace_all(
expr: dx.Series, pattern: str, value: str, *, literal: bool
) -> dx.Series:
try:
return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue]
pattern, value, regex=not literal, n=-1
)
except TypeError as e:
if not isinstance(value, str):
msg = "dask backed `Expr.str.replace_all` only supports str replacement values."
raise TypeError(msg) from e
raise
return self.compliant._with_callable(
_replace_all, "replace_all", pattern=pattern, value=value, literal=literal
)
def strip_chars(self, characters: str | None) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, characters: expr.str.strip(characters),
"strip",
characters=characters,
)
def starts_with(self, prefix: str) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, prefix: expr.str.startswith(prefix), "starts_with", prefix=prefix
)
def ends_with(self, suffix: str) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, suffix: expr.str.endswith(suffix), "ends_with", suffix=suffix
)
def contains(self, pattern: str, *, literal: bool) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, pattern, literal: expr.str.contains(
pat=pattern, regex=not literal
),
"contains",
pattern=pattern,
literal=literal,
)
def slice(self, offset: int, length: int | None) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, offset, length: expr.str.slice(
start=offset, stop=offset + length if length else None
),
"slice",
offset=offset,
length=length,
)
def split(self, by: str) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, by: expr.str.split(pat=by), "split", by=by
)
def to_datetime(self, format: str | None) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, format: dd.to_datetime(expr, format=format),
"to_datetime",
format=format,
)
def to_uppercase(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.str.upper(), "to_uppercase"
)
def to_lowercase(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.str.lower(), "to_lowercase"
)
def zfill(self, width: int) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, width: expr.str.zfill(width), "zfill", width=width
)
to_date = not_implemented()

narwhals/_dask/group_by.py
@@ -0,0 +1,147 @@
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, ClassVar
import dask.dataframe as dd
from narwhals._compliant import DepthTrackingGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import zip_strict
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
import pandas as pd
from dask.dataframe.api import GroupBy as _DaskGroupBy
from pandas.core.groupby import SeriesGroupBy as _PandasSeriesGroupBy
from typing_extensions import TypeAlias
from narwhals._compliant.typing import NarwhalsAggregation
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
PandasSeriesGroupBy: TypeAlias = _PandasSeriesGroupBy[Any, Any]
_AggFn: TypeAlias = Callable[..., Any]
else:
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError: # pragma: no cover
import dask_expr as dx
_DaskGroupBy = dx._groupby.GroupBy
Aggregation: TypeAlias = "str | _AggFn"
"""The name of an aggregation function, or the function itself."""
def n_unique() -> dd.Aggregation:
def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
return s.nunique(dropna=False)
def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
return s0.sum()
return dd.Aggregation(name="nunique", chunk=chunk, agg=agg)
def _all() -> dd.Aggregation:
def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
return s.all(skipna=True)
def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
return s0.all(skipna=True)
return dd.Aggregation(name="all", chunk=chunk, agg=agg)
def _any() -> dd.Aggregation:
def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
return s.any(skipna=True)
def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
return s0.any(skipna=True)
return dd.Aggregation(name="any", chunk=chunk, agg=agg)
def var(ddof: int) -> _AggFn:
return partial(_DaskGroupBy.var, ddof=ddof)
def std(ddof: int) -> _AggFn:
return partial(_DaskGroupBy.std, ddof=ddof)
class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]):
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
"sum": "sum",
"mean": "mean",
"median": "median",
"max": "max",
"min": "min",
"std": std,
"var": var,
"len": "size",
"n_unique": n_unique,
"count": "count",
"quantile": "quantile",
"all": _all,
"any": _any,
}
def __init__(
self,
df: DaskLazyFrame,
keys: Sequence[DaskExpr] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
self._compliant_frame, self._keys, self._output_key_names = self._parse_keys(
df, keys=keys
)
self._grouped = self.compliant.native.groupby(
self._keys, dropna=drop_null_keys, observed=True
)
def agg(self, *exprs: DaskExpr) -> DaskLazyFrame:
from narwhals._dask.dataframe import DaskLazyFrame
if not exprs:
# No aggregation provided
return (
self.compliant.simple_select(*self._keys)
.unique(self._keys, keep="any")
.rename(dict(zip(self._keys, self._output_key_names)))
)
self._ensure_all_simple(exprs)
# This should be the fastpath, but cuDF is too far behind to use it.
# - https://github.com/rapidsai/cudf/issues/15118
# - https://github.com/rapidsai/cudf/issues/15084
simple_aggregations: dict[str, tuple[str, Aggregation]] = {}
exclude = (*self._keys, *self._output_key_names)
for expr in exprs:
output_names, aliases = evaluate_output_names_and_aliases(
expr, self.compliant, exclude
)
if expr._depth == 0:
# e.g. `agg(nw.len())`
column = self._keys[0]
agg_fn = self._remap_expr_name(expr._function_name)
simple_aggregations.update(dict.fromkeys(aliases, (column, agg_fn)))
continue
# e.g. `agg(nw.mean('a'))`
agg_fn = self._remap_expr_name(self._leaf_name(expr))
# callable aggregations (e.g. `n_unique`, `std`) are built lazily here so we don't depend on dask globally
agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn
simple_aggregations.update(
(alias, (output_name, agg_fn))
for alias, output_name in zip_strict(aliases, output_names)
)
return DaskLazyFrame(
self._grouped.agg(**simple_aggregations).reset_index(),
version=self.compliant._version,
).rename(dict(zip(self._keys, self._output_key_names)))
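# Illustrative sketch: `agg(nw.col("a").mean().alias("m"), nw.len())` builds
#
#     simple_aggregations = {"m": ("a", "mean"), "len": (<first key>, "size")}
#
# which is passed to dask's named-aggregation API as `grouped.agg(m=("a", "mean"), ...)`.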

narwhals/_dask/namespace.py
@@ -0,0 +1,338 @@
from __future__ import annotations
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, cast
import dask.dataframe as dd
import pandas as pd
from narwhals._compliant import (
CompliantThen,
CompliantWhen,
DepthTrackingNamespace,
LazyNamespace,
)
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import (
align_series_full_broadcast,
narwhals_to_native_dtype,
validate_comparand,
)
from narwhals._expression_parsing import (
ExprKind,
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._utils import Implementation, zip_strict
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
import dask.dataframe.dask_expr as dx
from narwhals._compliant.typing import ScalarKwargs
from narwhals._utils import Version
from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral
class DaskNamespace(
LazyNamespace[DaskLazyFrame, DaskExpr, dd.DataFrame],
DepthTrackingNamespace[DaskLazyFrame, DaskExpr],
):
_implementation: Implementation = Implementation.DASK
@property
def selectors(self) -> DaskSelectorNamespace:
return DaskSelectorNamespace.from_namespace(self)
@property
def _expr(self) -> type[DaskExpr]:
return DaskExpr
@property
def _lazyframe(self) -> type[DaskLazyFrame]:
return DaskLazyFrame
def __init__(self, *, version: Version) -> None:
self._version = version
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
if dtype is not None:
native_dtype = narwhals_to_native_dtype(dtype, self._version)
native_pd_series = pd.Series([value], dtype=native_dtype, name="literal")
else:
native_pd_series = pd.Series([value], name="literal")
npartitions = df._native_frame.npartitions
dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions)
return [dask_series[0].to_series()]
return self._expr(
func,
depth=0,
function_name="lit",
evaluate_output_names=lambda _df: ["literal"],
alias_output_names=None,
version=self._version,
)
def len(self) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
# We don't allow dataframes with 0 columns, so `[0]` is safe.
return [df._native_frame[df.columns[0]].size.to_series()]
return self._expr(
func,
depth=0,
function_name="len",
evaluate_output_names=lambda _df: ["len"],
alias_output_names=None,
version=self._version,
)
def all_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series: Iterator[dx.Series] = chain.from_iterable(e(df) for e in exprs)
# Note on `ignore_nulls`: Dask doesn't support storing arbitrary Python
# objects in `object` dtype, so we don't need the same check we have for pandas-like.
if ignore_nulls:
# NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling.
series = (s if s.dtype == "bool" else s.fillna(True) for s in series)
return [reduce(operator.and_, align_series_full_broadcast(df, *series))]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="all_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
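# Worked example (illustrative): with `ignore_nulls=True` a nullable-boolean
# column [True, None] is first `fillna(True)`, so reducing with `&` against
# [True, True] gives [True, True]; nulls act as the identity for `all`.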
def any_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series: Iterator[dx.Series] = chain.from_iterable(e(df) for e in exprs)
if ignore_nulls:
series = (s if s.dtype == "bool" else s.fillna(False) for s in series)
return [reduce(operator.or_, align_series_full_broadcast(df, *series))]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="any_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def sum_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = align_series_full_broadcast(
df, *(s for _expr in exprs for s in _expr(df))
)
return [dd.concat(series, axis=1).sum(axis=1)]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="sum_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def concat(
self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod
) -> DaskLazyFrame:
if not items:
msg = "No items to concatenate" # pragma: no cover
raise AssertionError(msg)
dfs = [i._native_frame for i in items]
cols_0 = dfs[0].columns
if how == "vertical":
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.columns
if not (
(len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
):
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0.to_list()}\n"
f" - dataframe {i}: {cols_current.to_list()}\n"
)
raise TypeError(msg)
return DaskLazyFrame(
dd.concat(dfs, axis=0, join="inner"), version=self._version
)
if how == "diagonal":
return DaskLazyFrame(
dd.concat(dfs, axis=0, join="outer"), version=self._version
)
raise NotImplementedError
def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
expr_results = [s for _expr in exprs for s in _expr(df)]
series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results))
non_na = align_series_full_broadcast(
df, *(1 - s.isna() for s in expr_results)
)
num = reduce(lambda x, y: x + y, series) # pyright: ignore[reportOperatorIssue]
den = reduce(lambda x, y: x + y, non_na) # pyright: ignore[reportOperatorIssue]
return [cast("dx.Series", num / den)] # pyright: ignore[reportOperatorIssue]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="mean_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def min_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = align_series_full_broadcast(
df, *(s for _expr in exprs for s in _expr(df))
)
return [dd.concat(series, axis=1).min(axis=1)]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="min_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def max_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = align_series_full_broadcast(
df, *(s for _expr in exprs for s in _expr(df))
)
return [dd.concat(series, axis=1).max(axis=1)]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="max_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def when(self, predicate: DaskExpr) -> DaskWhen:
return DaskWhen.from_expr(predicate, context=self)
def concat_str(
self, *exprs: DaskExpr, separator: str, ignore_nulls: bool
) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
expr_results = [s for _expr in exprs for s in _expr(df)]
series = (
s.astype(str) for s in align_series_full_broadcast(df, *expr_results)
)
null_mask = [s.isna() for s in align_series_full_broadcast(df, *expr_results)]
if not ignore_nulls:
null_mask_result = reduce(operator.or_, null_mask)
result = reduce(lambda x, y: x + separator + y, series).where(
~null_mask_result, None
)
else:
init_value, *values = [
s.where(~nm, "") for s, nm in zip_strict(series, null_mask)
]
separators = (
nm.map({True: "", False: separator}, meta=str)
for nm in null_mask[:-1]
)
result = reduce(
operator.add,
(s + v for s, v in zip_strict(separators, values)),
init_value,
)
return [result]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="concat_str",
evaluate_output_names=getattr(
exprs[0], "_evaluate_output_names", lambda _df: ["literal"]
),
alias_output_names=getattr(exprs[0], "_alias_output_names", None),
version=self._version,
)
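# Worked example (illustrative, ignore_nulls=True): for columns ["a", None] and
# ["x", "y"] with separator "-", nulls become "" and each separator is emitted
# only when the value before it is non-null, so the rows concatenate to
# ["a-x", "y"] rather than ["a-x", "-y"].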
def coalesce(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = align_series_full_broadcast(
df, *(s for _expr in exprs for s in _expr(df))
)
return [reduce(lambda x, y: x.fillna(y), series)]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="coalesce",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
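

# A minimal usage sketch (not part of the upstream source): `coalesce` is a
# left-to-right `fillna` chain, so the first non-null value per row wins.
# The `_demo_coalesce` helper and its column names are hypothetical.
def _demo_coalesce() -> None:
    from functools import reduce

    import dask.dataframe as dd
    import pandas as pd

    pdf = pd.DataFrame({"a": [None, 2.0], "b": [1.0, None], "c": [9.0, 9.0]})
    ddf = dd.from_pandas(pdf, npartitions=1)
    # chain fillna: each step only fills rows still missing after the last
    out = reduce(lambda x, y: x.fillna(y), [ddf["a"], ddf["b"], ddf["c"]])
    print(out.compute().tolist())  # [1.0, 2.0]
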
class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]): # pyright: ignore[reportInvalidTypeArguments]
@property
def _then(self) -> type[DaskThen]:
return DaskThen

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
then_value = (
self._then_value(df)[0]
if isinstance(self._then_value, DaskExpr)
else self._then_value
)
otherwise_value = (
self._otherwise_value(df)[0]
if isinstance(self._otherwise_value, DaskExpr)
else self._otherwise_value
)
condition = self._condition(df)[0]
# re-evaluate DataFrame if the condition aggregates to force
# then/otherwise to be evaluated against the aggregated frame
assert self._condition._metadata is not None # noqa: S101
if self._condition._metadata.is_scalar_like:
new_df = df._with_native(condition.to_frame())
condition = self._condition.broadcast(ExprKind.AGGREGATION)(df)[0]
df = new_df
if self._otherwise_value is None:
(condition, then_series) = align_series_full_broadcast(
df, condition, then_value
)
validate_comparand(condition, then_series)
return [then_series.where(condition)] # pyright: ignore[reportArgumentType]
(condition, then_series, otherwise_series) = align_series_full_broadcast(
df, condition, then_value, otherwise_value
)
validate_comparand(condition, then_series)
validate_comparand(condition, otherwise_series)
return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType]
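

# A minimal sketch of the `where` semantics used above (not part of the
# upstream source): `then.where(cond, other)` reads as "if cond then `then`
# else `other`"; omitting `other` leaves unmatched rows as missing values.
# The `_demo_where_ternary` helper and its column names are hypothetical.
def _demo_where_ternary() -> None:
    import dask.dataframe as dd
    import pandas as pd

    pdf = pd.DataFrame({"cond": [True, False], "then": [1, 2], "other": [10, 20]})
    ddf = dd.from_pandas(pdf, npartitions=1)
    print(ddf["then"].where(ddf["cond"], ddf["other"]).compute().tolist())  # [1, 20]
    # without an otherwise, unmatched rows become NaN (dtype widens to float)
    print(ddf["then"].where(ddf["cond"]).compute().tolist())  # [1.0, nan]
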
class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr): # pyright: ignore[reportInvalidTypeArguments]
_depth: int = 0
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
_function_name: str = "whenthen"

View File

@ -0,0 +1,34 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import CompliantSelector, LazySelectorNamespace
from narwhals._dask.expr import DaskExpr

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx  # noqa: F401

    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._dask.dataframe import DaskLazyFrame  # noqa: F401


class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]):  # pyright: ignore[reportInvalidTypeArguments]
@property
def _selector(self) -> type[DaskSelector]:
return DaskSelector


class DaskSelector(CompliantSelector["DaskLazyFrame", "dx.Series"], DaskExpr):  # type: ignore[misc]
_depth: int = 0
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
_function_name: str = "selector"

    def _to_expr(self) -> DaskExpr:
return DaskExpr(
self._call,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)

View File

@ -0,0 +1,139 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._utils import Implementation, Version, isinstance_or_issubclass
from narwhals.dependencies import get_pyarrow

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    import dask.dataframe as dd
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.dataframe import DaskLazyFrame, Incomplete
    from narwhals._dask.expr import DaskExpr
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType
else:
    try:
        import dask.dataframe.dask_expr as dx
    except ModuleNotFoundError:  # pragma: no cover
        import dask_expr as dx


def maybe_evaluate_expr(
    df: DaskLazyFrame, obj: DaskExpr | object
) -> dx.Series | object:
from narwhals._dask.expr import DaskExpr
if isinstance(obj, DaskExpr):
results = obj._call(df)
assert len(results) == 1 # debug assertion # noqa: S101
return results[0]
return obj


def evaluate_exprs(
    df: DaskLazyFrame, /, *exprs: DaskExpr
) -> list[tuple[str, dx.Series]]:
native_results: list[tuple[str, dx.Series]] = []
for expr in exprs:
native_series_list = expr(df)
aliases = expr._evaluate_aliases(df)
if len(aliases) != len(native_series_list): # pragma: no cover
msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results"
raise AssertionError(msg)
native_results.extend(zip(aliases, native_series_list))
return native_results


def align_series_full_broadcast(
df: DaskLazyFrame, *series: dx.Series | object
) -> Sequence[dx.Series]:
return [
s if isinstance(s, dx.Series) else df._native_frame.assign(_tmp=s)["_tmp"]
for s in series
] # pyright: ignore[reportReturnType]


def add_row_index(frame: dd.DataFrame, name: str) -> dd.DataFrame:
original_cols = frame.columns
df: Incomplete = frame.assign(**{name: 1})
return select_columns_by_name(
df.assign(**{name: df[name].cumsum(method="blelloch") - 1}),
[name, *original_cols],
Implementation.DASK,
)
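

# A minimal sketch of the cumulative-sum trick above (not part of the
# upstream source). Assumes a dask version whose expression backend supports
# `cumsum(method="blelloch")`, as `add_row_index` itself already does; the
# `_demo_add_row_index` helper is hypothetical.
def _demo_add_row_index() -> None:
    import dask.dataframe as dd
    import pandas as pd

    ddf = dd.from_pandas(pd.DataFrame({"a": [10, 20, 30]}), npartitions=2)
    tmp = ddf.assign(idx=1)
    # cumsum over a column of ones is 1-based; subtracting 1 yields a 0-based
    # row index that is consistent across partitions
    tmp = tmp.assign(idx=tmp["idx"].cumsum(method="blelloch") - 1)
    print(tmp["idx"].compute().tolist())  # [0, 1, 2]
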
def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:
    if not dx.expr.are_co_aligned(lhs._expr, rhs._expr):  # pragma: no cover
        # `are_co_aligned` cheaply checks whether two Dask expressions share the
        # same index and therefore need no index alignment. If a Dask DataFrame
        # is only ever operated on via expressions, this should always hold:
        # expression outputs (by definition) all come from the same input
        # dataframe, and Dask Series has no operations which change the index.
        # Nonetheless, we perform this safety check anyway. We still need to
        # carefully vet which methods we support for Dask, to avoid cases where
        # `are_co_aligned` doesn't do what we want it to do:
        # https://github.com/dask/dask-expr/issues/1112.
        msg = "Objects are not co-aligned, so this operation is not supported for Dask backend"
        raise RuntimeError(msg)
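

# A minimal sketch (not part of the upstream source): two series drawn from
# the same native frame are co-aligned, so the check above passes without an
# index join. The `_demo_co_aligned` helper is hypothetical.
def _demo_co_aligned() -> None:
    import dask.dataframe as dd
    import pandas as pd

    ddf = dd.from_pandas(pd.DataFrame({"a": [1], "b": [2]}), npartitions=1)
    lhs, rhs = ddf["a"], ddf["b"]
    # same source frame, no index-changing operations applied
    print(dx.expr.are_co_aligned(lhs._expr, rhs._expr))  # True
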
dtypes = Version.MAIN.dtypes
dtypes_v1 = Version.V1.dtypes
NW_TO_DASK_DTYPES: Mapping[type[DType], str] = {
dtypes.Float64: "float64",
dtypes.Float32: "float32",
dtypes.Boolean: "bool",
dtypes.Categorical: "category",
dtypes.Date: "date32[day][pyarrow]",
dtypes.Int8: "int8",
dtypes.Int16: "int16",
dtypes.Int32: "int32",
dtypes.Int64: "int64",
dtypes.UInt8: "uint8",
dtypes.UInt16: "uint16",
dtypes.UInt32: "uint32",
dtypes.UInt64: "uint64",
dtypes.Datetime: "datetime64[us]",
dtypes.Duration: "timedelta64[ns]",
dtypes_v1.Datetime: "datetime64[us]",
dtypes_v1.Duration: "timedelta64[ns]",
}

UNSUPPORTED_DTYPES = (
dtypes.List,
dtypes.Struct,
dtypes.Array,
dtypes.Time,
dtypes.Binary,
)


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> Any:
dtypes = version.dtypes
base_type = dtype.base_type()
if dask_type := NW_TO_DASK_DTYPES.get(base_type):
return dask_type
if isinstance_or_issubclass(dtype, dtypes.String):
if Implementation.PANDAS._backend_version() >= (2, 0, 0):
return "string[pyarrow]" if get_pyarrow() else "string[python]"
return "object" # pragma: no cover
if isinstance_or_issubclass(dtype, dtypes.Enum):
if version is Version.V1:
msg = "Converting to Enum is not supported in narwhals.stable.v1"
raise NotImplementedError(msg)
if isinstance(dtype, dtypes.Enum):
import pandas as pd
# NOTE: `pandas-stubs.core.dtypes.dtypes.CategoricalDtype.categories` is too narrow
# Should be one of the `ListLike*` types
# https://github.com/pandas-dev/pandas-stubs/blob/8434bde95460b996323cc8c0fea7b0a8bb00ea26/pandas-stubs/_typing.pyi#L497-L505
return pd.CategoricalDtype(dtype.categories, ordered=True) # type: ignore[arg-type]
msg = "Can not cast / initialize Enum without categories present"
raise ValueError(msg)
if issubclass(base_type, UNSUPPORTED_DTYPES): # pragma: no cover
msg = f"Converting to {base_type.__name__} dtype is not supported for Dask."
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
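

# A minimal usage sketch (not part of the upstream source): most dtypes
# resolve through the lookup table above, while `String` depends on the
# installed pandas/pyarrow versions. The `_demo_dtype_mapping` helper is
# hypothetical.
def _demo_dtype_mapping() -> None:
    import narwhals as nw

    print(narwhals_to_native_dtype(nw.Int64, Version.MAIN))  # 'int64'
    print(narwhals_to_native_dtype(nw.Boolean, Version.MAIN))  # 'bool'
    print(narwhals_to_native_dtype(nw.String, Version.MAIN))  # e.g. 'string[pyarrow]'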