done
502
lib/python3.11/site-packages/narwhals/_dask/dataframe.py
Normal file
@@ -0,0 +1,502 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import dask.dataframe as dd

from narwhals._dask.utils import add_row_index, evaluate_exprs
from narwhals._expression_parsing import ExprKind
from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name
from narwhals._typing_compat import assert_never
from narwhals._utils import (
    Implementation,
    ValidateBackendVersion,
    _remap_full_join_keys,
    check_column_names_are_unique,
    check_columns_exist,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
    zip_strict,
)
from narwhals.typing import CompliantLazyFrame

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping, Sequence
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._dask.expr import DaskExpr
    from narwhals._dask.group_by import DaskLazyGroupBy
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._typing import _EagerAllowedImpl
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.exceptions import ColumnNotFoundError
    from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy

Incomplete: TypeAlias = "Any"
"""Using `_pandas_like` utils with `_dask`.

Typing this correctly will complicate the `_pandas_like`-side.
Very low priority until `dask` adds typing.
"""


class DaskLazyFrame(
    CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"],
    ValidateBackendVersion,
):
    _implementation = Implementation.DASK

    def __init__(
        self,
        native_dataframe: dd.DataFrame,
        *,
        version: Version,
        validate_backend_version: bool = False,
    ) -> None:
        self._native_frame: dd.DataFrame = native_dataframe
        self._version = version
        self._cached_schema: dict[str, DType] | None = None
        self._cached_columns: list[str] | None = None
        if validate_backend_version:
            self._validate_backend_version()

    @staticmethod
    def _is_native(obj: dd.DataFrame | Any) -> TypeIs[dd.DataFrame]:
        return isinstance(obj, dd.DataFrame)

    @classmethod
    def from_native(cls, data: dd.DataFrame, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version)

    def to_narwhals(self) -> LazyFrame[dd.DataFrame]:
        return self._version.lazyframe(self, level="lazy")

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.DASK:
            return self._implementation.to_native_namespace()

        msg = f"Expected dask, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_namespace__(self) -> DaskNamespace:
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version)

    def _with_native(self, df: Any) -> Self:
        return self.__class__(df, version=self._version)

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
        return check_columns_exist(subset, available=self.columns)

    def _iter_columns(self) -> Iterator[dx.Series]:
        for _col, ser in self.native.items():  # noqa: PERF102
            yield ser

    def with_columns(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        return self._with_native(self.native.assign(**dict(new_series)))

    def collect(
        self, backend: _EagerAllowedImpl | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        result = self.native.compute(**kwargs)

        if backend is None or backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                result,
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                pl.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
            )

        if backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                pa.Table.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover
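
    # --- Editor's sketch (not part of the commit): how `collect` is reached
    # from the public API. Assumes narwhals, dask, and pyarrow are installed;
    # `backend` may be "pandas" (the default), "polars", or "pyarrow", per the
    # branches above.
    #
    #     import dask.dataframe as dd
    #     import narwhals as nw
    #     import pandas as pd
    #
    #     ldf = nw.from_native(dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=2))
    #     eager = ldf.collect(backend="pyarrow")  # computes the graph, wraps a pyarrow.Table
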
    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_schema is not None
                else self.native.columns.tolist()
            )
        return self._cached_columns

    def filter(self, predicate: DaskExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = predicate(self)[0]
        return self._with_native(self.native.loc[mask])

    def simple_select(self, *column_names: str) -> Self:
        df: Incomplete = self.native
        native = select_columns_by_name(df, list(column_names), self._implementation)
        return self._with_native(native)

    def aggregate(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
        return self._with_native(df)

    def select(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df: Incomplete = self.native
        df = select_columns_by_name(
            df.assign(**dict(new_series)),
            [s[0] for s in new_series],
            self._implementation,
        )
        return self._with_native(df)

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        if subset is None:
            return self._with_native(self.native.dropna())
        plx = self.__narwhals_namespace__()
        mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
        return self.filter(mask)

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_schema is None:
            native_dtypes = self.native.dtypes
            self._cached_schema = {
                col: native_to_narwhals_dtype(
                    native_dtypes[col], self._version, self._implementation
                )
                for col in self.native.columns
            }
        return self._cached_schema

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        to_drop = parse_columns_to_drop(self, columns, strict=strict)

        return self._with_native(self.native.drop(columns=to_drop))

    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        # Implementation is based on the following StackOverflow reply:
        # https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
        if order_by is None:
            return self._with_native(add_row_index(self.native, name))
        plx = self.__narwhals_namespace__()
        columns = self.columns
        const_expr = plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL)
        row_index_expr = (
            plx.col(name).cum_sum(reverse=False).over(partition_by=[], order_by=order_by)
            - 1
        )
        return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns))
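
    # --- Editor's sketch (not part of the commit): the ordered row index is a
    # cumulative sum of a constant 1 column, evaluated over `order_by`, minus
    # one. Usage, assuming a narwhals version where `with_row_index` accepts
    # `order_by` on lazy frames:
    #
    #     ldf.with_row_index("idx", order_by="key")  # 0..n-1, following "key"
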
    def rename(self, mapping: Mapping[str, str]) -> Self:
        return self._with_native(self.native.rename(columns=mapping))

    def head(self, n: int) -> Self:
        return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))

    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        if subset and (error := self._check_columns_exist(subset)):
            raise error
        if keep == "none":
            subset = subset or self.columns
            token = generate_temporary_column_name(n_bytes=8, columns=subset)
            ser = self.native.groupby(subset).size().rename(token)
            ser = ser[ser == 1]
            unique = ser.reset_index().drop(columns=token)
            result = self.native.merge(unique, on=subset, how="inner")
        else:
            mapped_keep = {"any": "first"}.get(keep, keep)
            result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
        return self._with_native(result)
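
    # --- Editor's sketch (not part of the commit): `keep="none"` drops every
    # row whose key occurs more than once. The groupby-size trick above, in
    # plain pandas:
    #
    #     df = pd.DataFrame({"a": [1, 1, 2]})
    #     sizes = df.groupby(["a"]).size().rename("_n")
    #     once = sizes[sizes == 1].reset_index().drop(columns="_n")
    #     df.merge(once, on=["a"], how="inner")  # keeps only the row with a == 2
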
    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            ascending: bool | list[bool] = not descending
        else:
            ascending = [not d for d in descending]
        position = "last" if nulls_last else "first"
        return self._with_native(
            self.native.sort_values(list(by), ascending=ascending, na_position=position)
        )

    def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
        df = self.native
        schema = self.schema
        by = list(by)
        if isinstance(reverse, bool) and all(schema[x].is_numeric() for x in by):
            if reverse:
                return self._with_native(df.nsmallest(k, by))
            return self._with_native(df.nlargest(k, by))
        if isinstance(reverse, bool):
            reverse = [reverse] * len(by)
        return self._with_native(
            df.sort_values(by, ascending=list(reverse)).head(
                n=k, compute=False, npartitions=-1
            )
        )

    def _join_inner(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        return self.native.merge(
            other.native,
            left_on=left_on,
            right_on=right_on,
            how="inner",
            suffixes=("", suffix),
        )

    def _join_left(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        result_native = self.native.merge(
            other.native,
            how="left",
            left_on=left_on,
            right_on=right_on,
            suffixes=("", suffix),
        )
        extra = [
            right_key if right_key not in self.columns else f"{right_key}{suffix}"
            for left_key, right_key in zip_strict(left_on, right_on)
            if right_key != left_key
        ]
        return result_native.drop(columns=extra)

    def _join_full(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        # dask does not retain keys post-join
        # we must append the suffix to each key before-hand

        right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
        other_native = other.native.rename(columns=right_on_mapper)
        check_column_names_are_unique(other_native.columns)
        right_suffixed = list(right_on_mapper.values())
        return self.native.merge(
            other_native,
            left_on=left_on,
            right_on=right_suffixed,
            how="outer",
            suffixes=("", suffix),
        )

    def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
        key_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        return (
            self.native.assign(**{key_token: 0})
            .merge(
                other.native.assign(**{key_token: 0}),
                how="inner",
                left_on=key_token,
                right_on=key_token,
                suffixes=("", suffix),
            )
            .drop(columns=key_token)
        )

    def _join_semi(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        return self.native.merge(
            other_native, how="inner", left_on=left_on, right_on=left_on
        )

    def _join_anti(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        indicator_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        df = self.native.merge(
            other_native,
            how="left",
            indicator=indicator_token,  # pyright: ignore[reportArgumentType]
            left_on=left_on,
            right_on=left_on,
        )
        return df[df[indicator_token] == "left_only"].drop(columns=[indicator_token])
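
    # --- Editor's sketch (not part of the commit): `merge(..., indicator=...)`
    # tags each row with where it matched; keeping "left_only" rows yields the
    # anti join. In plain pandas:
    #
    #     left = pd.DataFrame({"a": [1, 2, 3]})
    #     right = pd.DataFrame({"a": [2]})
    #     m = left.merge(right.drop_duplicates(), on="a", how="left", indicator="_w")
    #     m[m["_w"] == "left_only"].drop(columns="_w")  # rows with a == 1 and a == 3
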
    def _join_filter_rename(
        self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str]
    ) -> dd.DataFrame:
        """Helper function to avoid creating extra columns and row duplication.

        Used in `"anti"` and `"semi"` joins.

        Notice that a native object is returned.
        """
        other_native: Incomplete = other.native
        # rename to avoid creating extra columns in join
        return (
            select_columns_by_name(other_native, columns_to_select, self._implementation)
            .rename(columns=columns_mapping)
            .drop_duplicates()
        )

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        if how == "cross":
            result = self._join_cross(other=other, suffix=suffix)

        elif left_on is None or right_on is None:  # pragma: no cover
            raise ValueError(left_on, right_on)

        elif how == "inner":
            result = self._join_inner(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "anti":
            result = self._join_anti(other=other, left_on=left_on, right_on=right_on)
        elif how == "semi":
            result = self._join_semi(other=other, left_on=left_on, right_on=right_on)
        elif how == "left":
            result = self._join_left(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "full":
            result = self._join_full(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        else:
            assert_never(how)
        return self._with_native(result)

    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self:
        plx = self.__native_namespace__()
        return self._with_native(
            plx.merge_asof(
                self.native,
                other.native,
                left_on=left_on,
                right_on=right_on,
                left_by=by_left,
                right_by=by_right,
                direction=strategy,
                suffixes=("", suffix),
            )
        )

    def group_by(
        self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
    ) -> DaskLazyGroupBy:
        from narwhals._dask.group_by import DaskLazyGroupBy

        return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def tail(self, n: int) -> Self:  # pragma: no cover
        native_frame = self.native
        n_partitions = native_frame.npartitions

        if n_partitions == 1:
            return self._with_native(self.native.tail(n=n, compute=False))
        msg = (
            "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
        )
        raise NotImplementedError(msg)

    def gather_every(self, n: int, offset: int) -> Self:
        row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
        plx = self.__narwhals_namespace__()
        return (
            self.with_row_index(row_index_token, order_by=None)
            .filter(
                (plx.col(row_index_token) >= offset)
                & ((plx.col(row_index_token) - offset) % n == 0)
            )
            .drop([row_index_token], strict=False)
        )
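
    # --- Editor's sketch (not part of the commit): `gather_every` keeps rows
    # offset, offset + n, offset + 2n, ... by filtering on a temporary row
    # index. The same predicate over a plain integer index:
    #
    #     idx = pd.RangeIndex(10)
    #     mask = (idx >= 1) & ((idx - 1) % 3 == 0)  # offset=1, n=3 -> rows 1, 4, 7
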
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        return self._with_native(
            self.native.melt(
                id_vars=index,
                value_vars=on,
                var_name=variable_name,
                value_name=value_name,
            )
        )

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        self.native.to_parquet(file)

    explode = not_implemented()
701
lib/python3.11/site-packages/narwhals/_dask/expr.py
Normal file
@@ -0,0 +1,701 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal, cast

import pandas as pd

from narwhals._compliant import DepthTrackingExpr, LazyExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import (
    add_row_index,
    maybe_evaluate_expr,
    narwhals_to_native_dtype,
)
from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    not_implemented,
)
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
    from collections.abc import Sequence

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self

    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
    from narwhals._dask.dataframe import DaskLazyFrame
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._expression_parsing import ExprKind, ExprMetadata
    from narwhals._utils import Version, _LimitedContext
    from narwhals.typing import (
        FillNullStrategy,
        IntoDType,
        ModeKeepStrategy,
        NonNestedLiteral,
        NumericLiteral,
        RollingInterpolationMethod,
        TemporalLiteral,
    )


class DaskExpr(
    LazyExpr["DaskLazyFrame", "dx.Series"],  # pyright: ignore[reportInvalidTypeArguments]
    DepthTrackingExpr["DaskLazyFrame", "dx.Series"],  # pyright: ignore[reportInvalidTypeArguments]
):
    _implementation: Implementation = Implementation.DASK

    def __init__(
        self,
        call: EvalSeries[DaskLazyFrame, dx.Series],  # pyright: ignore[reportInvalidTypeForm]
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[DaskLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        scalar_kwargs: ScalarKwargs | None = None,
    ) -> None:
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._scalar_kwargs = scalar_kwargs or {}
        self._metadata: ExprMetadata | None = None

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        return self._call(df)

    def __narwhals_namespace__(self) -> DaskNamespace:  # pragma: no cover
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16
            # that raised a KeyError for result[0] during collection.
            return [result.loc[0][0] for result in self(df)]

        return self.__class__(
            func,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            scalar_kwargs=self._scalar_kwargs,
        )

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[DaskLazyFrame],
        /,
        *,
        context: _LimitedContext,
        function_name: str = "",
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            try:
                return [
                    df._native_frame[column_name]
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [df.native.iloc[:, i] for i in column_indices]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    def _with_callable(
        self,
        # First argument to `call` should be `dx.Series`
        call: Callable[..., dx.Series],
        /,
        expr_name: str = "",
        scalar_kwargs: ScalarKwargs | None = None,
        **expressifiable_args: Self | Any,
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            native_results: list[dx.Series] = []
            native_series_list = self._call(df)
            other_native_series = {
                key: maybe_evaluate_expr(df, value)
                for key, value in expressifiable_args.items()
            }
            for native_series in native_series_list:
                result_native = call(native_series, **other_native_series)
                native_results.append(result_native)
            return native_results

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=f"{self._function_name}->{expr_name}",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            scalar_kwargs=scalar_kwargs,
        )
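
    # --- Editor's sketch (not part of the commit): `_with_callable` lifts a
    # Series-level function into a new expression -- evaluate the parent,
    # evaluate any expression-valued kwargs, then apply `call` column by
    # column. The same wrapping pattern in miniature, with plain lists
    # standing in for frames (all names here are illustrative):
    #
    #     def lift(call):
    #         return lambda columns: [call(col) for col in columns]
    #
    #     double_all = lift(lambda x: x * 2)
    #     double_all([1, 2, 3])  # [2, 4, 6]
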
    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        current_alias_output_names = self._alias_output_names
        alias_output_names = (
            None
            if func is None
            else func
            if current_alias_output_names is None
            else lambda output_names: func(current_alias_output_names(output_names))
        )
        return type(self)(
            call=self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=alias_output_names,
            version=self._version,
            scalar_kwargs=self._scalar_kwargs,
        )

    def _with_binary(
        self,
        call: Callable[[dx.Series, Any], dx.Series],
        name: str,
        other: Any,
        *,
        reverse: bool = False,
    ) -> Self:
        result = self._with_callable(
            lambda expr, other: call(expr, other), name, other=other
        )
        if reverse:
            result = result.alias("literal")
        return result

    def _binary_op(self, op_name: str, other: Any) -> Self:
        return self._with_binary(
            lambda expr, other: getattr(expr, op_name)(other), op_name, other
        )

    def _reverse_binary_op(
        self, op_name: str, operator_func: Callable[..., dx.Series], other: Any
    ) -> Self:
        return self._with_binary(
            lambda expr, other: operator_func(other, expr), op_name, other, reverse=True
        )

    def __add__(self, other: Any) -> Self:
        return self._binary_op("__add__", other)

    def __sub__(self, other: Any) -> Self:
        return self._binary_op("__sub__", other)

    def __mul__(self, other: Any) -> Self:
        return self._binary_op("__mul__", other)

    def __truediv__(self, other: Any) -> Self:
        return self._binary_op("__truediv__", other)

    def __floordiv__(self, other: Any) -> Self:
        return self._binary_op("__floordiv__", other)

    def __pow__(self, other: Any) -> Self:
        return self._binary_op("__pow__", other)

    def __mod__(self, other: Any) -> Self:
        return self._binary_op("__mod__", other)

    def __eq__(self, other: object) -> Self:  # type: ignore[override]
        return self._binary_op("__eq__", other)

    def __ne__(self, other: object) -> Self:  # type: ignore[override]
        return self._binary_op("__ne__", other)

    def __ge__(self, other: Any) -> Self:
        return self._binary_op("__ge__", other)

    def __gt__(self, other: Any) -> Self:
        return self._binary_op("__gt__", other)

    def __le__(self, other: Any) -> Self:
        return self._binary_op("__le__", other)

    def __lt__(self, other: Any) -> Self:
        return self._binary_op("__lt__", other)

    def __and__(self, other: Any) -> Self:
        return self._binary_op("__and__", other)

    def __or__(self, other: Any) -> Self:
        return self._binary_op("__or__", other)

    def __rsub__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rsub__", lambda a, b: a - b, other)

    def __rtruediv__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rtruediv__", lambda a, b: a / b, other)

    def __rfloordiv__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rfloordiv__", lambda a, b: a // b, other)

    def __rpow__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rpow__", lambda a, b: a**b, other)

    def __rmod__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rmod__", lambda a, b: a % b, other)

    def __invert__(self) -> Self:
        return self._with_callable(lambda expr: expr.__invert__(), "__invert__")

    def mean(self) -> Self:
        return self._with_callable(lambda expr: expr.mean().to_series(), "mean")

    def median(self) -> Self:
        from narwhals.exceptions import InvalidOperationError

        def func(s: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK)
            if not dtype.is_numeric():
                msg = "`median` operation not supported for non-numeric input type."
                raise InvalidOperationError(msg)
            return s.median_approximate().to_series()

        return self._with_callable(func, "median")

    def min(self) -> Self:
        return self._with_callable(lambda expr: expr.min().to_series(), "min")

    def max(self) -> Self:
        return self._with_callable(lambda expr: expr.max().to_series(), "max")

    def std(self, ddof: int) -> Self:
        return self._with_callable(
            lambda expr: expr.std(ddof=ddof).to_series(),
            "std",
            scalar_kwargs={"ddof": ddof},
        )

    def var(self, ddof: int) -> Self:
        return self._with_callable(
            lambda expr: expr.var(ddof=ddof).to_series(),
            "var",
            scalar_kwargs={"ddof": ddof},
        )

    def skew(self) -> Self:
        return self._with_callable(lambda expr: expr.skew().to_series(), "skew")

    def kurtosis(self) -> Self:
        return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis")

    def shift(self, n: int) -> Self:
        return self._with_callable(lambda expr: expr.shift(n), "shift")

    def cum_sum(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            # https://github.com/dask/dask/issues/11802
            msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumsum(), "cum_sum")

    def cum_count(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_count(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(
            lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count"
        )

    def cum_min(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_min(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummin(), "cum_min")

    def cum_max(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_max(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummax(), "cum_max")

    def cum_prod(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumprod(), "cum_prod")

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).sum(),
            "rolling_sum",
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).mean(),
            "rolling_mean",
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).var(),
                "rolling_var",
            )
        msg = "Dask backend only supports `ddof=1` for `rolling_var`"
        raise NotImplementedError(msg)

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).std(),
                "rolling_std",
            )
        msg = "Dask backend only supports `ddof=1` for `rolling_std`"
        raise NotImplementedError(msg)

    def sum(self) -> Self:
        return self._with_callable(lambda expr: expr.sum().to_series(), "sum")

    def count(self) -> Self:
        return self._with_callable(lambda expr: expr.count().to_series(), "count")

    def round(self, decimals: int) -> Self:
        return self._with_callable(lambda expr: expr.round(decimals), "round")

    def unique(self) -> Self:
        return self._with_callable(lambda expr: expr.unique(), "unique")

    def drop_nulls(self) -> Self:
        return self._with_callable(lambda expr: expr.dropna(), "drop_nulls")

    def abs(self) -> Self:
        return self._with_callable(lambda expr: expr.abs(), "abs")

    def all(self) -> Self:
        return self._with_callable(
            lambda expr: expr.all(
                axis=None, skipna=True, split_every=False, out=None
            ).to_series(),
            "all",
        )

    def any(self) -> Self:
        return self._with_callable(
            lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(),
            "any",
        )

    def fill_nan(self, value: float | None) -> Self:
        value_nullable = pd.NA if value is None else value
        value_numpy = float("nan") if value is None else value

        def func(expr: dx.Series) -> dx.Series:
            # If/when pandas exposes an API which distinguishes NaN vs null, use that.
            mask = cast("dx.Series", expr != expr)  # noqa: PLR0124
            mask = mask.fillna(False)
            fill = (
                value_nullable
                if get_dtype_backend(expr.dtype, self._implementation)
                else value_numpy
            )
            return expr.mask(mask, fill)  # pyright: ignore[reportArgumentType]

        return self._with_callable(func, "fill_nan")
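
    # --- Editor's sketch (not part of the commit): `expr != expr` is the
    # classic NaN test (NaN is the only value unequal to itself); comparing a
    # true null yields null, which `fillna(False)` then excludes from the
    # mask. With a pyarrow-backed dtype, which keeps NaN and null distinct:
    #
    #     s = pd.Series([1.0, float("nan"), None], dtype="float64[pyarrow]")
    #     (s != s).fillna(False)  # True only for the NaN, not the null
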
    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            if value is not None:
                res_ser = expr.fillna(value)
            else:
                res_ser = (
                    expr.ffill(limit=limit)
                    if strategy == "forward"
                    else expr.bfill(limit=limit)
                )
            return res_ser

        return self._with_callable(func, "fill_null")

    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self:
        return self._with_callable(
            lambda expr, lower_bound, upper_bound: expr.clip(
                lower=lower_bound, upper=upper_bound
            ),
            "clip",
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    def diff(self) -> Self:
        return self._with_callable(lambda expr: expr.diff(), "diff")

    def n_unique(self) -> Self:
        return self._with_callable(
            lambda expr: expr.nunique(dropna=False).to_series(), "n_unique"
        )

    def is_null(self) -> Self:
        return self._with_callable(lambda expr: expr.isna(), "is_null")

    def is_nan(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                expr.dtype, self._version, self._implementation
            )
            if dtype.is_numeric():
                return expr != expr  # pyright: ignore[reportReturnType]  # noqa: PLR0124
            msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
            raise InvalidOperationError(msg)

        return self._with_callable(func, "is_nan")

    def len(self) -> Self:
        return self._with_callable(lambda expr: expr.size.to_series(), "len")

    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> Self:
        if interpolation == "linear":

            def func(expr: dx.Series, quantile: float) -> dx.Series:
                if expr.npartitions > 1:
                    msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
                    raise NotImplementedError(msg)
                return expr.quantile(
                    q=quantile, method="dask"
                ).to_series()  # pragma: no cover

            return self._with_callable(func, "quantile", quantile=quantile)
        msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
        raise NotImplementedError(msg)

    def is_first_distinct(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(expr.to_frame(), col_token)
            first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
            return frame[col_token].isin(first_distinct_index)

        return self._with_callable(func, "is_first_distinct")

    def is_last_distinct(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(expr.to_frame(), col_token)
            last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
            return frame[col_token].isin(last_distinct_index)

        return self._with_callable(func, "is_last_distinct")

    def is_unique(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            return (
                expr.to_frame()
                .groupby(_name, dropna=False)
                .transform("size", meta=(_name, int))
                == 1
            )

        return self._with_callable(func, "is_unique")

    def is_in(self, other: Any) -> Self:
        return self._with_callable(lambda expr: expr.isin(other), "is_in")

    def null_count(self) -> Self:
        return self._with_callable(
            lambda expr: expr.isna().sum().to_series(), "null_count"
        )

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        # pandas is a required dependency of dask so it's safe to import this
        from narwhals._pandas_like.group_by import PandasLikeGroupBy

        if not partition_by:
            assert order_by  # noqa: S101

            # This is something like `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                return self(df.sort(*order_by, descending=False, nulls_last=False))
        elif not self._is_elementary():  # pragma: no cover
            msg = (
                "Only elementary expressions are supported for `.over` in dask.\n\n"
                "Please see: "
                "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
            )
            raise NotImplementedError(msg)
        elif order_by:
            # Wrong results https://github.com/dask/dask/issues/11806.
            msg = "`over` with `order_by` is not yet supported in Dask."
            raise NotImplementedError(msg)
        else:
            function_name = PandasLikeGroupBy._leaf_name(self)
            try:
                dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
            except KeyError:
                # window functions are unsupported: https://github.com/dask/dask/issues/11806
                msg = (
                    f"Unsupported function: {function_name} in `over` context.\n\n"
                    f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
                )
                raise NotImplementedError(msg) from None

            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])

                with warnings.catch_warnings():
                    # https://github.com/dask/dask/issues/11804
                    warnings.filterwarnings(
                        "ignore",
                        message=".*`meta` is not specified",
                        category=UserWarning,
                    )
                    grouped = df.native.groupby(partition_by)
                    if dask_function_name == "size":
                        if len(output_names) != 1:  # pragma: no cover
                            msg = "Safety check failed, please report a bug."
                            raise AssertionError(msg)
                        res_native = grouped.transform(
                            dask_function_name, **self._scalar_kwargs
                        ).to_frame(output_names[0])
                    else:
                        res_native = grouped[list(output_names)].transform(
                            dask_function_name, **self._scalar_kwargs
                        )
                result_frame = df._with_native(
                    res_native.rename(columns=dict(zip(output_names, aliases)))
                ).native
                return [result_frame[name] for name in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )
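
    # --- Editor's sketch (not part of the commit): the non-empty
    # `partition_by` branch compiles an elementary aggregation into a grouped
    # transform. This is the pandas equivalent of `nw.col("x").sum().over("g")`:
    #
    #     df = pd.DataFrame({"g": ["a", "a", "b"], "x": [1, 2, 5]})
    #     df.groupby("g")["x"].transform("sum")  # [3, 3, 5]
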
    def cast(self, dtype: IntoDType) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            native_dtype = narwhals_to_native_dtype(dtype, self._version)
            return expr.astype(native_dtype)

        return self._with_callable(func, "cast")

    def is_finite(self) -> Self:
        import dask.array as da

        return self._with_callable(da.isfinite, "is_finite")

    def log(self, base: float) -> Self:
        import dask.array as da

        def _log(expr: dx.Series) -> dx.Series:
            return da.log(expr) / da.log(base)

        return self._with_callable(_log, "log")

    def exp(self) -> Self:
        import dask.array as da

        return self._with_callable(da.exp, "exp")

    def sqrt(self) -> Self:
        import dask.array as da

        return self._with_callable(da.sqrt, "sqrt")

    def mode(self, *, keep: ModeKeepStrategy) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            result = expr.to_frame().mode()[_name]
            return result.head(1) if keep == "any" else result

        return self._with_callable(func, "mode", scalar_kwargs={"keep": keep})

    @property
    def str(self) -> DaskExprStringNamespace:
        return DaskExprStringNamespace(self)

    @property
    def dt(self) -> DaskExprDateTimeNamespace:
        return DaskExprDateTimeNamespace(self)

    arg_max: not_implemented = not_implemented()
    arg_min: not_implemented = not_implemented()
    arg_true: not_implemented = not_implemented()
    ewm_mean: not_implemented = not_implemented()
    gather_every: not_implemented = not_implemented()
    head: not_implemented = not_implemented()
    map_batches: not_implemented = not_implemented()
    sample: not_implemented = not_implemented()
    rank: not_implemented = not_implemented()
    replace_strict: not_implemented = not_implemented()
    sort: not_implemented = not_implemented()
    tail: not_implemented = not_implemented()

    # namespaces
    list: not_implemented = not_implemented()  # type: ignore[assignment]
    cat: not_implemented = not_implemented()  # type: ignore[assignment]
    struct: not_implemented = not_implemented()  # type: ignore[assignment]
175
lib/python3.11/site-packages/narwhals/_dask/expr_dt.py
Normal file
@@ -0,0 +1,175 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import DateTimeNamespace
from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from narwhals._duration import Interval
from narwhals._pandas_like.utils import (
    ALIAS_DICT,
    calculate_timestamp_date,
    calculate_timestamp_datetime,
    native_to_narwhals_dtype,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.expr import DaskExpr
    from narwhals.typing import TimeUnit


class DaskExprDateTimeNamespace(
    LazyExprNamespace["DaskExpr"], DateTimeNamespace["DaskExpr"]
):
    def date(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.date, "date")

    def year(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.year, "year")

    def month(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.month, "month")

    def day(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.day, "day")

    def hour(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.hour, "hour")

    def minute(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.minute, "minute")

    def second(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.second, "second")

    def millisecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond // 1000, "millisecond"
        )

    def microsecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond, "microsecond"
        )

    def nanosecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond, "nanosecond"
        )

    def ordinal_day(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.dayofyear, "ordinal_day"
        )

    def weekday(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.weekday + 1,  # Dask is 0-6
            "weekday",
        )

    def to_string(self, format: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, format: expr.dt.strftime(format.replace("%.f", ".%f")),
            "strftime",
            format=format,
        )

    def replace_time_zone(self, time_zone: str | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, time_zone: expr.dt.tz_localize(None).dt.tz_localize(time_zone)
            if time_zone is not None
            else expr.dt.tz_localize(None),
            "tz_localize",
            time_zone=time_zone,
        )

    def convert_time_zone(self, time_zone: str) -> DaskExpr:
        def func(s: dx.Series, time_zone: str) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                s.dtype, self.compliant._version, Implementation.DASK
            )
            if dtype.time_zone is None:  # type: ignore[attr-defined]
                return s.dt.tz_localize("UTC").dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue]
            return s.dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue]

        return self.compliant._with_callable(func, "tz_convert", time_zone=time_zone)
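
    # --- Editor's sketch (not part of the commit): tz-naive timestamps are
    # treated as UTC before converting, mirroring Polars semantics:
    #
    #     s = pd.Series(pd.to_datetime(["2024-01-01 00:00"]))  # tz-naive
    #     s.dt.tz_localize("UTC").dt.tz_convert("Europe/Paris")  # 01:00+01:00
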
    # ignoring coverage due to https://github.com/narwhals-dev/narwhals/issues/2808.
    def timestamp(self, time_unit: TimeUnit) -> DaskExpr:  # pragma: no cover
        def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                s.dtype, self.compliant._version, Implementation.DASK
            )
            is_pyarrow_dtype = "pyarrow" in str(dtype)
            mask_na = s.isna()
            dtypes = self.compliant._version.dtypes
            if dtype == dtypes.Date:
                # Date is only supported in pandas dtypes if pyarrow-backed
                s_cast = s.astype("Int32[pyarrow]")
                result = calculate_timestamp_date(s_cast, time_unit)
            elif isinstance(dtype, dtypes.Datetime):
                original_time_unit = dtype.time_unit
                s_cast = (
                    s.astype("Int64[pyarrow]") if is_pyarrow_dtype else s.astype("int64")
                )
                result = calculate_timestamp_datetime(
                    s_cast, original_time_unit, time_unit
                )
            else:
                msg = "Input should be either of Date or Datetime type"
                raise TypeError(msg)
            return result.where(~mask_na)  # pyright: ignore[reportReturnType]

        return self.compliant._with_callable(func, "datetime", time_unit=time_unit)

    def total_minutes(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() // 60, "total_minutes"
        )

    def total_seconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() // 1, "total_seconds"
        )

    def total_milliseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1,
            "total_milliseconds",
        )

    def total_microseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1,
            "total_microseconds",
        )

    def total_nanoseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1, "total_nanoseconds"
        )

    def truncate(self, every: str) -> DaskExpr:
        interval = Interval.parse(every)
        unit = interval.unit
        if unit in {"mo", "q", "y"}:
            msg = f"Truncating to {unit} is not yet supported for dask."
            raise NotImplementedError(msg)
        freq = f"{interval.multiple}{ALIAS_DICT.get(unit, unit)}"
        return self.compliant._with_callable(lambda expr: expr.dt.floor(freq), "truncate")
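
    # --- Editor's sketch (not part of the commit): truncation lowers to
    # pandas' `dt.floor` with a frequency alias (the "15min" alias here is
    # illustrative):
    #
    #     s = pd.Series(pd.to_datetime(["2024-01-01 10:37"]))
    #     s.dt.floor("15min")  # 2024-01-01 10:30
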
    def offset_by(self, by: str) -> DaskExpr:
        def func(s: dx.Series, by: str) -> dx.Series:
            interval = Interval.parse_no_constraints(by)
            unit = interval.unit
            if unit in {"y", "q", "mo", "d", "ns"}:
                msg = f"Offsetting by {unit} is not yet supported for dask."
                raise NotImplementedError(msg)
            offset = interval.to_timedelta()
            return s.add(offset)

        return self.compliant._with_callable(func, "offset_by", by=by)
121
lib/python3.11/site-packages/narwhals/_dask/expr_str.py
Normal file
121
lib/python3.11/site-packages/narwhals/_dask/expr_str.py
Normal file
@ -0,0 +1,121 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import dask.dataframe as dd

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StringNamespace
from narwhals._utils import not_implemented

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.expr import DaskExpr


class DaskExprStringNamespace(LazyExprNamespace["DaskExpr"], StringNamespace["DaskExpr"]):
    def len_chars(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.str.len(), "len")

    def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> DaskExpr:
        def _replace(
            expr: dx.Series, pattern: str, value: str, *, literal: bool, n: int
        ) -> dx.Series:
            try:
                return expr.str.replace(  # pyright: ignore[reportAttributeAccessIssue]
                    pattern, value, regex=not literal, n=n
                )
            except TypeError as e:
                if not isinstance(value, str):
                    msg = "dask backed `Expr.str.replace` only supports str replacement values"
                    raise TypeError(msg) from e
                raise

        return self.compliant._with_callable(
            _replace, "replace", pattern=pattern, value=value, literal=literal, n=n
        )

    def replace_all(self, pattern: str, value: str, *, literal: bool) -> DaskExpr:
        def _replace_all(
            expr: dx.Series, pattern: str, value: str, *, literal: bool
        ) -> dx.Series:
            try:
                return expr.str.replace(  # pyright: ignore[reportAttributeAccessIssue]
                    pattern, value, regex=not literal, n=-1
                )
            except TypeError as e:
                if not isinstance(value, str):
                    msg = "dask backed `Expr.str.replace_all` only supports str replacement values."
                    raise TypeError(msg) from e
                raise

        return self.compliant._with_callable(
            _replace_all, "replace", pattern=pattern, value=value, literal=literal
        )
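
    # --- Editor's sketch (not part of the commit): literal vs regex
    # replacement is routed through pandas' `regex=` flag:
    #
    #     s = pd.Series(["a.b", "aXb"])
    #     s.str.replace(".", "-", regex=False)  # only the literal dot: ["a-b", "aXb"]
    #     s.str.replace(".", "-", regex=True)   # every character: ["---", "---"]
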
    def strip_chars(self, characters: str | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, characters: expr.str.strip(characters),
            "strip",
            characters=characters,
        )

    def starts_with(self, prefix: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, prefix: expr.str.startswith(prefix), "starts_with", prefix=prefix
        )

    def ends_with(self, suffix: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, suffix: expr.str.endswith(suffix), "ends_with", suffix=suffix
        )

    def contains(self, pattern: str, *, literal: bool) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, pattern, literal: expr.str.contains(
                pat=pattern, regex=not literal
            ),
            "contains",
            pattern=pattern,
            literal=literal,
        )

    def slice(self, offset: int, length: int | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, offset, length: expr.str.slice(
                start=offset, stop=offset + length if length else None
            ),
            "slice",
            offset=offset,
            length=length,
        )

    def split(self, by: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, by: expr.str.split(pat=by), "split", by=by
        )

    def to_datetime(self, format: str | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, format: dd.to_datetime(expr, format=format),
            "to_datetime",
            format=format,
        )

    def to_uppercase(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.str.upper(), "to_uppercase"
        )

    def to_lowercase(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.str.lower(), "to_lowercase"
        )

    def zfill(self, width: int) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, width: expr.str.zfill(width), "zfill", width=width
        )

    to_date = not_implemented()
147
lib/python3.11/site-packages/narwhals/_dask/group_by.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar
|
||||
|
||||
import dask.dataframe as dd
|
||||
|
||||
from narwhals._compliant import DepthTrackingGroupBy
|
||||
from narwhals._expression_parsing import evaluate_output_names_and_aliases
|
||||
from narwhals._utils import zip_strict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
import pandas as pd
|
||||
from dask.dataframe.api import GroupBy as _DaskGroupBy
|
||||
from pandas.core.groupby import SeriesGroupBy as _PandasSeriesGroupBy
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from narwhals._compliant.typing import NarwhalsAggregation
|
||||
from narwhals._dask.dataframe import DaskLazyFrame
|
||||
from narwhals._dask.expr import DaskExpr
|
||||
|
||||
PandasSeriesGroupBy: TypeAlias = _PandasSeriesGroupBy[Any, Any]
|
||||
_AggFn: TypeAlias = Callable[..., Any]
|
||||
|
||||
else:
|
||||
try:
|
||||
import dask.dataframe.dask_expr as dx
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
import dask_expr as dx
|
||||
_DaskGroupBy = dx._groupby.GroupBy
|
||||
|
||||
Aggregation: TypeAlias = "str | _AggFn"
|
||||
"""The name of an aggregation function, or the function itself."""
|
||||
|
||||
|
||||
def n_unique() -> dd.Aggregation:
    def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s.nunique(dropna=False)

    def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s0.sum()

    return dd.Aggregation(name="nunique", chunk=chunk, agg=agg)


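# Editor's sketch (assumes a toy frame; not part of the module): a `dd.Aggregation`
# runs in two phases, `chunk` per partition and `agg` across the partition results:
#
#     import dask.dataframe as dd
#     import pandas as pd
#
#     ddf = dd.from_pandas(pd.DataFrame({"g": [1, 1, 2], "x": [1, 2, 2]}), npartitions=1)
#     ddf.groupby("g").agg(x=("x", n_unique())).compute()  # x: {1: 2, 2: 1}

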
def _all() -> dd.Aggregation:
    def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s.all(skipna=True)

    def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s0.all(skipna=True)

    return dd.Aggregation(name="all", chunk=chunk, agg=agg)


def _any() -> dd.Aggregation:
    def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s.any(skipna=True)

    def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s0.any(skipna=True)

    return dd.Aggregation(name="any", chunk=chunk, agg=agg)


def var(ddof: int) -> _AggFn:
    return partial(_DaskGroupBy.var, ddof=ddof)


def std(ddof: int) -> _AggFn:
    return partial(_DaskGroupBy.std, ddof=ddof)


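# Editor's sketch: `partial` pre-binds `ddof`, leaving an unbound-method-style
# callable that can later be applied to the grouped data, e.g.
#     std(ddof=1)  # roughly: lambda grouped: _DaskGroupBy.std(grouped, ddof=1)

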
class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]):
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "median",
        "max": "max",
        "min": "min",
        "std": std,
        "var": var,
        "len": "size",
        "n_unique": n_unique,
        "count": "count",
        "quantile": "quantile",
        "all": _all,
        "any": _any,
    }

    def __init__(
        self,
        df: DaskLazyFrame,
        keys: Sequence[DaskExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        self._compliant_frame, self._keys, self._output_key_names = self._parse_keys(
            df, keys=keys
        )
        self._grouped = self.compliant.native.groupby(
            self._keys, dropna=drop_null_keys, observed=True
        )

    def agg(self, *exprs: DaskExpr) -> DaskLazyFrame:
        from narwhals._dask.dataframe import DaskLazyFrame

        if not exprs:
            # No aggregations were provided: return just the unique group keys.
            return (
                self.compliant.simple_select(*self._keys)
                .unique(self._keys, keep="any")
                .rename(dict(zip(self._keys, self._output_key_names)))
            )

        self._ensure_all_simple(exprs)
        # This should be the fastpath, but cuDF is too far behind to use it.
        # - https://github.com/rapidsai/cudf/issues/15118
        # - https://github.com/rapidsai/cudf/issues/15084
        simple_aggregations: dict[str, tuple[str, Aggregation]] = {}
        exclude = (*self._keys, *self._output_key_names)
        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, exclude
            )
            if expr._depth == 0:
                # e.g. `agg(nw.len())`
                column = self._keys[0]
                agg_fn = self._remap_expr_name(expr._function_name)
                simple_aggregations.update(dict.fromkeys(aliases, (column, agg_fn)))
                continue

            # e.g. `agg(nw.mean('a'))`
            agg_fn = self._remap_expr_name(self._leaf_name(expr))
            # Instantiate callable aggregations (e.g. `n_unique`) lazily, so that
            # this module doesn't depend on dask at import time.
            agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn
            simple_aggregations.update(
                (alias, (output_name, agg_fn))
                for alias, output_name in zip_strict(aliases, output_names)
            )
        return DaskLazyFrame(
            self._grouped.agg(**simple_aggregations).reset_index(),
            version=self.compliant._version,
        ).rename(dict(zip(self._keys, self._output_key_names)))
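
# Editor's sketch: `simple_aggregations` is a named-aggregation mapping, so a call
# like `agg(nw.col("a").mean(), nw.col("b").n_unique())` boils down to roughly
#     grouped.agg(a=("a", "mean"), b=("b", n_unique()))
# i.e. output alias -> (source column, dask aggregation spec).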
338
lib/python3.11/site-packages/narwhals/_dask/namespace.py
Normal file
@ -0,0 +1,338 @@
from __future__ import annotations

import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, cast

import dask.dataframe as dd
import pandas as pd

from narwhals._compliant import (
    CompliantThen,
    CompliantWhen,
    DepthTrackingNamespace,
    LazyNamespace,
)
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import (
    align_series_full_broadcast,
    narwhals_to_native_dtype,
    validate_comparand,
)
from narwhals._expression_parsing import (
    ExprKind,
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._utils import Implementation, zip_strict

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Sequence

    import dask.dataframe.dask_expr as dx

    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._utils import Version
    from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral


class DaskNamespace(
    LazyNamespace[DaskLazyFrame, DaskExpr, dd.DataFrame],
    DepthTrackingNamespace[DaskLazyFrame, DaskExpr],
):
    _implementation: Implementation = Implementation.DASK

    @property
    def selectors(self) -> DaskSelectorNamespace:
        return DaskSelectorNamespace.from_namespace(self)

    @property
    def _expr(self) -> type[DaskExpr]:
        return DaskExpr

    @property
    def _lazyframe(self) -> type[DaskLazyFrame]:
        return DaskLazyFrame

    def __init__(self, *, version: Version) -> None:
        self._version = version

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            if dtype is not None:
                native_dtype = narwhals_to_native_dtype(dtype, self._version)
                native_pd_series = pd.Series([value], dtype=native_dtype, name="literal")
            else:
                native_pd_series = pd.Series([value], name="literal")
            npartitions = df._native_frame.npartitions
            dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions)
            return [dask_series[0].to_series()]

        return self._expr(
            func,
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            version=self._version,
        )

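    # Editor's note: a literal becomes a one-row pandas Series wrapped into a dask
    # Series; `dask_series[0]` picks out the scalar at index label 0, and
    # `.to_series()` re-wraps it so later broadcasting can align it with the frame.
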
    def len(self) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # We don't allow dataframes with 0 columns, so `[0]` is safe.
            return [df._native_frame[df.columns[0]].size.to_series()]

        return self._expr(
            func,
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            version=self._version,
        )

    def all_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series: Iterator[dx.Series] = chain.from_iterable(e(df) for e in exprs)
            # Note on `ignore_nulls`: Dask doesn't support storing arbitrary Python
            # objects in `object` dtype, so we don't need the same check we have
            # for pandas-like.
            if ignore_nulls:
                # NumPy-backed 'bool' dtype can't contain nulls, so it doesn't need filling.
                series = (s if s.dtype == "bool" else s.fillna(True) for s in series)
            return [reduce(operator.and_, align_series_full_broadcast(df, *series))]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="all_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

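    # Editor's sketch: with `ignore_nulls=True`, nulls act as True (the identity
    # for `&`), e.g. for hypothetical columns
    #     a = [True, None, False]  --fillna(True)-->  [True, True, False]
    #     b = [True, True,  True]
    #     a & b                    -->                [True, True, False]
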
    def any_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series: Iterator[dx.Series] = chain.from_iterable(e(df) for e in exprs)
            if ignore_nulls:
                series = (s if s.dtype == "bool" else s.fillna(False) for s in series)
            return [reduce(operator.or_, align_series_full_broadcast(df, *series))]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="any_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def sum_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [dd.concat(series, axis=1).sum(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="sum_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def concat(
        self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod
    ) -> DaskLazyFrame:
        if not items:
            msg = "No items to concatenate"  # pragma: no cover
            raise AssertionError(msg)
        dfs = [i._native_frame for i in items]
        cols_0 = dfs[0].columns
        if how == "vertical":
            for i, df in enumerate(dfs[1:], start=1):
                cols_current = df.columns
                if not (
                    (len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
                ):
                    msg = (
                        "unable to vstack, column names don't match:\n"
                        f" - dataframe 0: {cols_0.to_list()}\n"
                        f" - dataframe {i}: {cols_current.to_list()}\n"
                    )
                    raise TypeError(msg)
            return DaskLazyFrame(
                dd.concat(dfs, axis=0, join="inner"), version=self._version
            )
        if how == "diagonal":
            return DaskLazyFrame(
                dd.concat(dfs, axis=0, join="outer"), version=self._version
            )

        raise NotImplementedError

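    # Editor's note: "vertical" demands identical column sets (hence join="inner"
    # after the explicit name check), while "diagonal" unions the columns and fills
    # the gaps with missing values (join="outer"); stacking frames with columns
    # {a, b} and {a, c} yields {a, b, c} under "diagonal" but raises under "vertical".
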
    def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results))
            non_na = align_series_full_broadcast(
                df, *(1 - s.isna() for s in expr_results)
            )
            num = reduce(lambda x, y: x + y, series)  # pyright: ignore[reportOperatorIssue]
            den = reduce(lambda x, y: x + y, non_na)  # pyright: ignore[reportOperatorIssue]
            return [cast("dx.Series", num / den)]  # pyright: ignore[reportOperatorIssue]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="mean_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

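    # Editor's sketch: the null-aware mean is sum(fillna(0)) / count(non-null),
    # so for a hypothetical row with values (1.0, None, 3.0):
    #     num = 1.0 + 0 + 3.0 = 4.0
    #     den = 1 + 0 + 1 = 2
    #     mean = 4.0 / 2 = 2.0
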
    def min_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )

            return [dd.concat(series, axis=1).min(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="min_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def max_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )

            return [dd.concat(series, axis=1).max(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="max_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def when(self, predicate: DaskExpr) -> DaskWhen:
        return DaskWhen.from_expr(predicate, context=self)

    def concat_str(
        self, *exprs: DaskExpr, separator: str, ignore_nulls: bool
    ) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            series = (
                s.astype(str) for s in align_series_full_broadcast(df, *expr_results)
            )
            null_mask = [s.isna() for s in align_series_full_broadcast(df, *expr_results)]

            if not ignore_nulls:
                null_mask_result = reduce(operator.or_, null_mask)
                result = reduce(lambda x, y: x + separator + y, series).where(
                    ~null_mask_result, None
                )
            else:
                init_value, *values = [
                    s.where(~nm, "") for s, nm in zip_strict(series, null_mask)
                ]

                separators = (
                    nm.map({True: "", False: separator}, meta=str)
                    for nm in null_mask[:-1]
                )
                result = reduce(
                    operator.add,
                    (s + v for s, v in zip_strict(separators, values)),
                    init_value,
                )

            return [result]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="concat_str",
            evaluate_output_names=getattr(
                exprs[0], "_evaluate_output_names", lambda _df: ["literal"]
            ),
            alias_output_names=getattr(exprs[0], "_alias_output_names", None),
            version=self._version,
        )

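    # Editor's sketch: with `ignore_nulls=True`, nulls become "" and the separator
    # that follows a null element is suppressed; for separator="-" and a
    # hypothetical row (None, "b", "c"):
    #     init_value = "", separators = ("", "-"), values = ("b", "c")
    #     result = "" + ("" + "b") + ("-" + "c") == "b-c"
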
    def coalesce(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [reduce(lambda x, y: x.fillna(y), series)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="coalesce",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )


class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]):  # pyright: ignore[reportInvalidTypeArguments]
    @property
    def _then(self) -> type[DaskThen]:
        return DaskThen

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        then_value = (
            self._then_value(df)[0]
            if isinstance(self._then_value, DaskExpr)
            else self._then_value
        )
        otherwise_value = (
            self._otherwise_value(df)[0]
            if isinstance(self._otherwise_value, DaskExpr)
            else self._otherwise_value
        )

        condition = self._condition(df)[0]
        # re-evaluate DataFrame if the condition aggregates to force
        # then/otherwise to be evaluated against the aggregated frame
        assert self._condition._metadata is not None  # noqa: S101
        if self._condition._metadata.is_scalar_like:
            new_df = df._with_native(condition.to_frame())
            condition = self._condition.broadcast(ExprKind.AGGREGATION)(df)[0]
            df = new_df

        if self._otherwise_value is None:
            (condition, then_series) = align_series_full_broadcast(
                df, condition, then_value
            )
            validate_comparand(condition, then_series)
            return [then_series.where(condition)]  # pyright: ignore[reportArgumentType]
        (condition, then_series, otherwise_series) = align_series_full_broadcast(
            df, condition, then_value, otherwise_value
        )
        validate_comparand(condition, then_series)
        validate_comparand(condition, otherwise_series)
        return [then_series.where(condition, otherwise_series)]  # pyright: ignore[reportArgumentType]


class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr):  # pyright: ignore[reportInvalidTypeArguments]
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "whenthen"
34
lib/python3.11/site-packages/narwhals/_dask/selectors.py
Normal file
@ -0,0 +1,34 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import CompliantSelector, LazySelectorNamespace
from narwhals._dask.expr import DaskExpr

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx  # noqa: F401

    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._dask.dataframe import DaskLazyFrame  # noqa: F401


class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]):  # pyright: ignore[reportInvalidTypeArguments]
    @property
    def _selector(self) -> type[DaskSelector]:
        return DaskSelector


class DaskSelector(CompliantSelector["DaskLazyFrame", "dx.Series"], DaskExpr):  # type: ignore[misc]
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "selector"

    def _to_expr(self) -> DaskExpr:
        return DaskExpr(
            self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )
139
lib/python3.11/site-packages/narwhals/_dask/utils.py
Normal file
@ -0,0 +1,139 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._utils import Implementation, Version, isinstance_or_issubclass
from narwhals.dependencies import get_pyarrow

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    import dask.dataframe as dd
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.dataframe import DaskLazyFrame, Incomplete
    from narwhals._dask.expr import DaskExpr
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType
else:
    try:
        import dask.dataframe.dask_expr as dx
    except ModuleNotFoundError:  # pragma: no cover
        import dask_expr as dx


def maybe_evaluate_expr(df: DaskLazyFrame, obj: DaskExpr | object) -> dx.Series | object:
    from narwhals._dask.expr import DaskExpr

    if isinstance(obj, DaskExpr):
        results = obj._call(df)
        assert len(results) == 1  # debug assertion  # noqa: S101
        return results[0]
    return obj


def evaluate_exprs(df: DaskLazyFrame, /, *exprs: DaskExpr) -> list[tuple[str, dx.Series]]:
    native_results: list[tuple[str, dx.Series]] = []
    for expr in exprs:
        native_series_list = expr(df)
        aliases = expr._evaluate_aliases(df)
        if len(aliases) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results"
            raise AssertionError(msg)
        native_results.extend(zip(aliases, native_series_list))
    return native_results


def align_series_full_broadcast(
    df: DaskLazyFrame, *series: dx.Series | object
) -> Sequence[dx.Series]:
    return [
        s if isinstance(s, dx.Series) else df._native_frame.assign(_tmp=s)["_tmp"]
        for s in series
    ]  # pyright: ignore[reportReturnType]


def add_row_index(frame: dd.DataFrame, name: str) -> dd.DataFrame:
    original_cols = frame.columns
    df: Incomplete = frame.assign(**{name: 1})
    return select_columns_by_name(
        df.assign(**{name: df[name].cumsum(method="blelloch") - 1}),
        [name, *original_cols],
        Implementation.DASK,
    )


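# Editor's sketch: a column of ones, cumulatively summed, gives 1, 2, 3, ...;
# subtracting 1 yields a 0-based row index, and method="blelloch" requests a
# parallel prefix scan rather than a sequential pass over partitions:
#     [1, 1, 1, 1] --cumsum--> [1, 2, 3, 4] --(-1)--> [0, 1, 2, 3]

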
def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:
    if not dx.expr.are_co_aligned(lhs._expr, rhs._expr):  # pragma: no cover
        # `are_co_aligned` is a method which cheaply checks if two Dask expressions
        # have the same index, and therefore don't require index alignment.
        # If someone only operates on a Dask DataFrame via expressions, then this
        # should always be the case: expression outputs (by definition) all come from the
        # same input dataframe, and Dask Series does not have any operations which
        # change the index. Nonetheless, we perform this safety check anyway.

        # However, we still need to carefully vet which methods we support for Dask, to
        # avoid issues where `are_co_aligned` doesn't do what we want it to do:
        # https://github.com/dask/dask-expr/issues/1112.
        msg = "Objects are not co-aligned, so this operation is not supported for Dask backend"
        raise RuntimeError(msg)


dtypes = Version.MAIN.dtypes
dtypes_v1 = Version.V1.dtypes
NW_TO_DASK_DTYPES: Mapping[type[DType], str] = {
    dtypes.Float64: "float64",
    dtypes.Float32: "float32",
    dtypes.Boolean: "bool",
    dtypes.Categorical: "category",
    dtypes.Date: "date32[day][pyarrow]",
    dtypes.Int8: "int8",
    dtypes.Int16: "int16",
    dtypes.Int32: "int32",
    dtypes.Int64: "int64",
    dtypes.UInt8: "uint8",
    dtypes.UInt16: "uint16",
    dtypes.UInt32: "uint32",
    dtypes.UInt64: "uint64",
    dtypes.Datetime: "datetime64[us]",
    dtypes.Duration: "timedelta64[ns]",
    dtypes_v1.Datetime: "datetime64[us]",
    dtypes_v1.Duration: "timedelta64[ns]",
}
UNSUPPORTED_DTYPES = (
    dtypes.List,
    dtypes.Struct,
    dtypes.Array,
    dtypes.Time,
    dtypes.Binary,
)


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> Any:
    dtypes = version.dtypes
    base_type = dtype.base_type()
    if dask_type := NW_TO_DASK_DTYPES.get(base_type):
        return dask_type
    if isinstance_or_issubclass(dtype, dtypes.String):
        if Implementation.PANDAS._backend_version() >= (2, 0, 0):
            return "string[pyarrow]" if get_pyarrow() else "string[python]"
        return "object"  # pragma: no cover
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            import pandas as pd

            # NOTE: `pandas-stubs.core.dtypes.dtypes.CategoricalDtype.categories` is too narrow
            # Should be one of the `ListLike*` types
            # https://github.com/pandas-dev/pandas-stubs/blob/8434bde95460b996323cc8c0fea7b0a8bb00ea26/pandas-stubs/_typing.pyi#L497-L505
            return pd.CategoricalDtype(dtype.categories, ordered=True)  # type: ignore[arg-type]
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if issubclass(base_type, UNSUPPORTED_DTYPES):  # pragma: no cover
        msg = f"Converting to {base_type.__name__} dtype is not supported for Dask."
        raise NotImplementedError(msg)
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)
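
# Editor's sketch of the lookup path (illustrative; `nw` stands for the public
# narwhals namespace):
#     narwhals_to_native_dtype(nw.Int64(), Version.MAIN)     # -> "int64"
#     narwhals_to_native_dtype(nw.Datetime(), Version.MAIN)  # -> "datetime64[us]"
#     narwhals_to_native_dtype(nw.List(nw.Int64()), Version.MAIN)  # NotImplementedError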