lib/python3.11/site-packages/narwhals/_dask/dataframe.py (new file, 502 lines)
@@ -0,0 +1,502 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import dask.dataframe as dd

from narwhals._dask.utils import add_row_index, evaluate_exprs
from narwhals._expression_parsing import ExprKind
from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name
from narwhals._typing_compat import assert_never
from narwhals._utils import (
    Implementation,
    ValidateBackendVersion,
    _remap_full_join_keys,
    check_column_names_are_unique,
    check_columns_exist,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
    zip_strict,
)
from narwhals.typing import CompliantLazyFrame

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping, Sequence
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._dask.expr import DaskExpr
    from narwhals._dask.group_by import DaskLazyGroupBy
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._typing import _EagerAllowedImpl
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.exceptions import ColumnNotFoundError
    from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy

Incomplete: TypeAlias = "Any"
"""Using `_pandas_like` utils with `_dask`.

Typing this correctly will complicate the `_pandas_like`-side.
Very low priority until `dask` adds typing.
"""


class DaskLazyFrame(
    CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"],
    ValidateBackendVersion,
):
    _implementation = Implementation.DASK

    def __init__(
        self,
        native_dataframe: dd.DataFrame,
        *,
        version: Version,
        validate_backend_version: bool = False,
    ) -> None:
        self._native_frame: dd.DataFrame = native_dataframe
        self._version = version
        self._cached_schema: dict[str, DType] | None = None
        self._cached_columns: list[str] | None = None
        if validate_backend_version:
            self._validate_backend_version()

    @staticmethod
    def _is_native(obj: dd.DataFrame | Any) -> TypeIs[dd.DataFrame]:
        return isinstance(obj, dd.DataFrame)

    @classmethod
    def from_native(cls, data: dd.DataFrame, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version)

    def to_narwhals(self) -> LazyFrame[dd.DataFrame]:
        return self._version.lazyframe(self, level="lazy")

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.DASK:
            return self._implementation.to_native_namespace()

        msg = f"Expected dask, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_namespace__(self) -> DaskNamespace:
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version)

    def _with_native(self, df: Any) -> Self:
        return self.__class__(df, version=self._version)

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
        return check_columns_exist(subset, available=self.columns)

    def _iter_columns(self) -> Iterator[dx.Series]:
        for _col, ser in self.native.items():  # noqa: PERF102
            yield ser

    def with_columns(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        return self._with_native(self.native.assign(**dict(new_series)))

    def collect(
        self, backend: _EagerAllowedImpl | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        result = self.native.compute(**kwargs)

        if backend is None or backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                result,
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                pl.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
            )

        if backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                pa.Table.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover
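    # A minimal sketch of what the three branches above produce (assuming
    # pandas, polars and pyarrow are installed; `result` is the computed
    # pandas frame, `ddf` a hypothetical dask frame):
    #
    #   result = ddf.compute()          # pandas.DataFrame
    #   pl.from_pandas(result)          # -> wrapped as PolarsDataFrame
    #   pa.Table.from_pandas(result)    # -> wrapped as ArrowDataFrame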

    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_schema is not None
                else self.native.columns.tolist()
            )
        return self._cached_columns

    def filter(self, predicate: DaskExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = predicate(self)[0]
        return self._with_native(self.native.loc[mask])

    def simple_select(self, *column_names: str) -> Self:
        df: Incomplete = self.native
        native = select_columns_by_name(df, list(column_names), self._implementation)
        return self._with_native(native)

    def aggregate(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
        return self._with_native(df)

    def select(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df: Incomplete = self.native
        df = select_columns_by_name(
            df.assign(**dict(new_series)),
            [s[0] for s in new_series],
            self._implementation,
        )
        return self._with_native(df)

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        if subset is None:
            return self._with_native(self.native.dropna())
        plx = self.__narwhals_namespace__()
        mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
        return self.filter(mask)

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_schema is None:
            native_dtypes = self.native.dtypes
            self._cached_schema = {
                col: native_to_narwhals_dtype(
                    native_dtypes[col], self._version, self._implementation
                )
                for col in self.native.columns
            }
        return self._cached_schema

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        to_drop = parse_columns_to_drop(self, columns, strict=strict)

        return self._with_native(self.native.drop(columns=to_drop))

    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        # Implementation is based on the following StackOverflow reply:
        # https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
        if order_by is None:
            return self._with_native(add_row_index(self.native, name))
        plx = self.__narwhals_namespace__()
        columns = self.columns
        const_expr = plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL)
        row_index_expr = (
            plx.col(name).cum_sum(reverse=False).over(partition_by=[], order_by=order_by)
            - 1
        )
        return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns))
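    # The ordered branch above, sketched in plain dask (hypothetical frame
    # `ddf`, already sorted by the `order_by` key): a constant column of ones,
    # cumulatively summed and shifted down by one, gives a 0-based row index.
    #
    #   ddf = ddf.assign(_ones=1)
    #   ddf["row_index"] = ddf["_ones"].cumsum() - 1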

    def rename(self, mapping: Mapping[str, str]) -> Self:
        return self._with_native(self.native.rename(columns=mapping))

    def head(self, n: int) -> Self:
        return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))

    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        if subset and (error := self._check_columns_exist(subset)):
            raise error
        if keep == "none":
            subset = subset or self.columns
            token = generate_temporary_column_name(n_bytes=8, columns=subset)
            ser = self.native.groupby(subset).size().rename(token)
            ser = ser[ser == 1]
            unique = ser.reset_index().drop(columns=token)
            result = self.native.merge(unique, on=subset, how="inner")
        else:
            mapped_keep = {"any": "first"}.get(keep, keep)
            result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
        return self._with_native(result)
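    # Sketch of the `keep="none"` path above in plain dask (hypothetical key
    # column "a"): only rows whose key occurs exactly once survive the
    # inner-merge back onto the original frame.
    #
    #   sizes = ddf.groupby(["a"]).size().rename("_n")
    #   once = sizes[sizes == 1].reset_index().drop(columns="_n")
    #   ddf.merge(once, on=["a"], how="inner")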

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            ascending: bool | list[bool] = not descending
        else:
            ascending = [not d for d in descending]
        position = "last" if nulls_last else "first"
        return self._with_native(
            self.native.sort_values(list(by), ascending=ascending, na_position=position)
        )

    def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
        df = self.native
        schema = self.schema
        by = list(by)
        if isinstance(reverse, bool) and all(schema[x].is_numeric() for x in by):
            if reverse:
                return self._with_native(df.nsmallest(k, by))
            return self._with_native(df.nlargest(k, by))
        if isinstance(reverse, bool):
            reverse = [reverse] * len(by)
        return self._with_native(
            df.sort_values(by, ascending=list(reverse)).head(
                n=k, compute=False, npartitions=-1
            )
        )

    def _join_inner(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        return self.native.merge(
            other.native,
            left_on=left_on,
            right_on=right_on,
            how="inner",
            suffixes=("", suffix),
        )

    def _join_left(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        result_native = self.native.merge(
            other.native,
            how="left",
            left_on=left_on,
            right_on=right_on,
            suffixes=("", suffix),
        )
        extra = [
            right_key if right_key not in self.columns else f"{right_key}{suffix}"
            for left_key, right_key in zip_strict(left_on, right_on)
            if right_key != left_key
        ]
        return result_native.drop(columns=extra)

    def _join_full(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        # dask does not retain keys post-join
        # we must append the suffix to each key before-hand

        right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
        other_native = other.native.rename(columns=right_on_mapper)
        check_column_names_are_unique(other_native.columns)
        right_suffixed = list(right_on_mapper.values())
        return self.native.merge(
            other_native,
            left_on=left_on,
            right_on=right_suffixed,
            how="outer",
            suffixes=("", suffix),
        )

    def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
        key_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        return (
            self.native.assign(**{key_token: 0})
            .merge(
                other.native.assign(**{key_token: 0}),
                how="inner",
                left_on=key_token,
                right_on=key_token,
                suffixes=("", suffix),
            )
            .drop(columns=key_token)
        )

    def _join_semi(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        return self.native.merge(
            other_native, how="inner", left_on=left_on, right_on=left_on
        )

    def _join_anti(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        indicator_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        df = self.native.merge(
            other_native,
            how="left",
            indicator=indicator_token,  # pyright: ignore[reportArgumentType]
            left_on=left_on,
            right_on=left_on,
        )
        return df[df[indicator_token] == "left_only"].drop(columns=[indicator_token])
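    # The anti-join above leans on pandas-style `merge(..., indicator=...)`:
    # the indicator column equals "left_only" exactly where a left row found
    # no match, so filtering on it and dropping the helper column yields the
    # anti-join. Sketch with a hypothetical key "a":
    #
    #   out = left.merge(right_keys, how="left", on="a", indicator="_ind")
    #   out[out["_ind"] == "left_only"].drop(columns=["_ind"])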

    def _join_filter_rename(
        self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str]
    ) -> dd.DataFrame:
        """Helper function to avoid creating extra columns and row duplication.

        Used in "anti" and "semi" joins.

        Notice that a native object is returned.
        """
        other_native: Incomplete = other.native
        # rename to avoid creating extra columns in join
        return (
            select_columns_by_name(other_native, columns_to_select, self._implementation)
            .rename(columns=columns_mapping)
            .drop_duplicates()
        )

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        if how == "cross":
            result = self._join_cross(other=other, suffix=suffix)

        elif left_on is None or right_on is None:  # pragma: no cover
            raise ValueError(left_on, right_on)

        elif how == "inner":
            result = self._join_inner(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "anti":
            result = self._join_anti(other=other, left_on=left_on, right_on=right_on)
        elif how == "semi":
            result = self._join_semi(other=other, left_on=left_on, right_on=right_on)
        elif how == "left":
            result = self._join_left(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "full":
            result = self._join_full(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        else:
            assert_never(how)
        return self._with_native(result)

    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self:
        plx = self.__native_namespace__()
        return self._with_native(
            plx.merge_asof(
                self.native,
                other.native,
                left_on=left_on,
                right_on=right_on,
                left_by=by_left,
                right_by=by_right,
                direction=strategy,
                suffixes=("", suffix),
            )
        )

    def group_by(
        self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
    ) -> DaskLazyGroupBy:
        from narwhals._dask.group_by import DaskLazyGroupBy

        return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def tail(self, n: int) -> Self:  # pragma: no cover
        native_frame = self.native
        n_partitions = native_frame.npartitions

        if n_partitions == 1:
            return self._with_native(self.native.tail(n=n, compute=False))
        msg = (
            "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
        )
        raise NotImplementedError(msg)

    def gather_every(self, n: int, offset: int) -> Self:
        row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
        plx = self.__narwhals_namespace__()
        return (
            self.with_row_index(row_index_token, order_by=None)
            .filter(
                (plx.col(row_index_token) >= offset)
                & ((plx.col(row_index_token) - offset) % n == 0)
            )
            .drop([row_index_token], strict=False)
        )
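    # Worked example of the arithmetic above: with offset=1 and n=2, row
    # indices 1, 3, 5, ... satisfy both `i >= 1` and `(i - 1) % 2 == 0`.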

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        return self._with_native(
            self.native.melt(
                id_vars=index,
                value_vars=on,
                var_name=variable_name,
                value_name=value_name,
            )
        )

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        self.native.to_parquet(file)

    explode = not_implemented()


lib/python3.11/site-packages/narwhals/_dask/expr.py (new file, 701 lines)
@@ -0,0 +1,701 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal, cast

import pandas as pd

from narwhals._compliant import DepthTrackingExpr, LazyExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import (
    add_row_index,
    maybe_evaluate_expr,
    narwhals_to_native_dtype,
)
from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    not_implemented,
)
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
    from collections.abc import Sequence

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self

    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
    from narwhals._dask.dataframe import DaskLazyFrame
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._expression_parsing import ExprKind, ExprMetadata
    from narwhals._utils import Version, _LimitedContext
    from narwhals.typing import (
        FillNullStrategy,
        IntoDType,
        ModeKeepStrategy,
        NonNestedLiteral,
        NumericLiteral,
        RollingInterpolationMethod,
        TemporalLiteral,
    )


class DaskExpr(
    LazyExpr["DaskLazyFrame", "dx.Series"],  # pyright: ignore[reportInvalidTypeArguments]
    DepthTrackingExpr["DaskLazyFrame", "dx.Series"],  # pyright: ignore[reportInvalidTypeArguments]
):
    _implementation: Implementation = Implementation.DASK

    def __init__(
        self,
        call: EvalSeries[DaskLazyFrame, dx.Series],  # pyright: ignore[reportInvalidTypeForm]
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[DaskLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        scalar_kwargs: ScalarKwargs | None = None,
    ) -> None:
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._scalar_kwargs = scalar_kwargs or {}
        self._metadata: ExprMetadata | None = None

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        return self._call(df)

    def __narwhals_namespace__(self) -> DaskNamespace:  # pragma: no cover
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16
            #   that raised a KeyError for result[0] during collection.
            return [result.loc[0][0] for result in self(df)]

        return self.__class__(
            func,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            scalar_kwargs=self._scalar_kwargs,
        )
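    # Sketch of what broadcasting enables (hypothetical column "a"): an
    # aggregation collapses to a single-row series, and extracting that row as
    # a scalar (via the `.loc[0][0]` workaround above) lets dask assign it
    # across every row:
    #
    #   scalar = ddf["a"].mean().to_series().loc[0][0]
    #   ddf.assign(a_mean=scalar)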

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[DaskLazyFrame],
        /,
        *,
        context: _LimitedContext,
        function_name: str = "",
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            try:
                return [
                    df._native_frame[column_name]
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [df.native.iloc[:, i] for i in column_indices]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    def _with_callable(
        self,
        # First argument to `call` should be `dx.Series`
        call: Callable[..., dx.Series],
        /,
        expr_name: str = "",
        scalar_kwargs: ScalarKwargs | None = None,
        **expressifiable_args: Self | Any,
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            native_results: list[dx.Series] = []
            native_series_list = self._call(df)
            other_native_series = {
                key: maybe_evaluate_expr(df, value)
                for key, value in expressifiable_args.items()
            }
            for native_series in native_series_list:
                result_native = call(native_series, **other_native_series)
                native_results.append(result_native)
            return native_results

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=f"{self._function_name}->{expr_name}",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            scalar_kwargs=scalar_kwargs,
        )

    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        current_alias_output_names = self._alias_output_names
        alias_output_names = (
            None
            if func is None
            else func
            if current_alias_output_names is None
            else lambda output_names: func(current_alias_output_names(output_names))
        )
        return type(self)(
            call=self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=alias_output_names,
            version=self._version,
            scalar_kwargs=self._scalar_kwargs,
        )

    def _with_binary(
        self,
        call: Callable[[dx.Series, Any], dx.Series],
        name: str,
        other: Any,
        *,
        reverse: bool = False,
    ) -> Self:
        result = self._with_callable(
            lambda expr, other: call(expr, other), name, other=other
        )
        if reverse:
            result = result.alias("literal")
        return result

    def _binary_op(self, op_name: str, other: Any) -> Self:
        return self._with_binary(
            lambda expr, other: getattr(expr, op_name)(other), op_name, other
        )

    def _reverse_binary_op(
        self, op_name: str, operator_func: Callable[..., dx.Series], other: Any
    ) -> Self:
        return self._with_binary(
            lambda expr, other: operator_func(other, expr), op_name, other, reverse=True
        )

    def __add__(self, other: Any) -> Self:
        return self._binary_op("__add__", other)

    def __sub__(self, other: Any) -> Self:
        return self._binary_op("__sub__", other)

    def __mul__(self, other: Any) -> Self:
        return self._binary_op("__mul__", other)

    def __truediv__(self, other: Any) -> Self:
        return self._binary_op("__truediv__", other)

    def __floordiv__(self, other: Any) -> Self:
        return self._binary_op("__floordiv__", other)

    def __pow__(self, other: Any) -> Self:
        return self._binary_op("__pow__", other)

    def __mod__(self, other: Any) -> Self:
        return self._binary_op("__mod__", other)

    def __eq__(self, other: object) -> Self:  # type: ignore[override]
        return self._binary_op("__eq__", other)

    def __ne__(self, other: object) -> Self:  # type: ignore[override]
        return self._binary_op("__ne__", other)

    def __ge__(self, other: Any) -> Self:
        return self._binary_op("__ge__", other)

    def __gt__(self, other: Any) -> Self:
        return self._binary_op("__gt__", other)

    def __le__(self, other: Any) -> Self:
        return self._binary_op("__le__", other)

    def __lt__(self, other: Any) -> Self:
        return self._binary_op("__lt__", other)

    def __and__(self, other: Any) -> Self:
        return self._binary_op("__and__", other)

    def __or__(self, other: Any) -> Self:
        return self._binary_op("__or__", other)

    def __rsub__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rsub__", lambda a, b: a - b, other)

    def __rtruediv__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rtruediv__", lambda a, b: a / b, other)

    def __rfloordiv__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rfloordiv__", lambda a, b: a // b, other)

    def __rpow__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rpow__", lambda a, b: a**b, other)

    def __rmod__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rmod__", lambda a, b: a % b, other)

    def __invert__(self) -> Self:
        return self._with_callable(lambda expr: expr.__invert__(), "__invert__")

    def mean(self) -> Self:
        return self._with_callable(lambda expr: expr.mean().to_series(), "mean")

    def median(self) -> Self:
        from narwhals.exceptions import InvalidOperationError

        def func(s: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK)
            if not dtype.is_numeric():
                msg = "`median` operation not supported for non-numeric input type."
                raise InvalidOperationError(msg)
            return s.median_approximate().to_series()

        return self._with_callable(func, "median")

    def min(self) -> Self:
        return self._with_callable(lambda expr: expr.min().to_series(), "min")

    def max(self) -> Self:
        return self._with_callable(lambda expr: expr.max().to_series(), "max")

    def std(self, ddof: int) -> Self:
        return self._with_callable(
            lambda expr: expr.std(ddof=ddof).to_series(),
            "std",
            scalar_kwargs={"ddof": ddof},
        )

    def var(self, ddof: int) -> Self:
        return self._with_callable(
            lambda expr: expr.var(ddof=ddof).to_series(),
            "var",
            scalar_kwargs={"ddof": ddof},
        )

    def skew(self) -> Self:
        return self._with_callable(lambda expr: expr.skew().to_series(), "skew")

    def kurtosis(self) -> Self:
        return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis")

    def shift(self, n: int) -> Self:
        return self._with_callable(lambda expr: expr.shift(n), "shift")

    def cum_sum(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            # https://github.com/dask/dask/issues/11802
            msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumsum(), "cum_sum")

    def cum_count(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_count(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(
            lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count"
        )

    def cum_min(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_min(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummin(), "cum_min")

    def cum_max(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_max(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummax(), "cum_max")

    def cum_prod(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumprod(), "cum_prod")

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).sum(),
            "rolling_sum",
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).mean(),
            "rolling_mean",
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).var(),
                "rolling_var",
            )
        msg = "Dask backend only supports `ddof=1` for `rolling_var`"
        raise NotImplementedError(msg)

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).std(),
                "rolling_std",
            )
        msg = "Dask backend only supports `ddof=1` for `rolling_std`"
        raise NotImplementedError(msg)

    def sum(self) -> Self:
        return self._with_callable(lambda expr: expr.sum().to_series(), "sum")

    def count(self) -> Self:
        return self._with_callable(lambda expr: expr.count().to_series(), "count")

    def round(self, decimals: int) -> Self:
        return self._with_callable(lambda expr: expr.round(decimals), "round")

    def unique(self) -> Self:
        return self._with_callable(lambda expr: expr.unique(), "unique")

    def drop_nulls(self) -> Self:
        return self._with_callable(lambda expr: expr.dropna(), "drop_nulls")

    def abs(self) -> Self:
        return self._with_callable(lambda expr: expr.abs(), "abs")

    def all(self) -> Self:
        return self._with_callable(
            lambda expr: expr.all(
                axis=None, skipna=True, split_every=False, out=None
            ).to_series(),
            "all",
        )

    def any(self) -> Self:
        return self._with_callable(
            lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(),
            "any",
        )

    def fill_nan(self, value: float | None) -> Self:
        value_nullable = pd.NA if value is None else value
        value_numpy = float("nan") if value is None else value

        def func(expr: dx.Series) -> dx.Series:
            # If/when pandas exposes an API which distinguishes NaN vs null, use that.
            mask = cast("dx.Series", expr != expr)  # noqa: PLR0124
            mask = mask.fillna(False)
            fill = (
                value_nullable
                if get_dtype_backend(expr.dtype, self._implementation)
                else value_numpy
            )
            return expr.mask(mask, fill)  # pyright: ignore[reportArgumentType]

        return self._with_callable(func, "fill_nan")
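    # Why two fill values above: the NaN test itself is dtype-dependent.
    # `expr != expr` is True only at NaN for numpy-backed floats, but yields
    # missing (<NA>) at nulls for nullable dtypes; hence `fillna(False)` on
    # the mask, and `pd.NA` vs `float("nan")` as the replacement value.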

    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            if value is not None:
                res_ser = expr.fillna(value)
            else:
                res_ser = (
                    expr.ffill(limit=limit)
                    if strategy == "forward"
                    else expr.bfill(limit=limit)
                )
            return res_ser

        return self._with_callable(func, "fill_null")

    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self:
        return self._with_callable(
            lambda expr, lower_bound, upper_bound: expr.clip(
                lower=lower_bound, upper=upper_bound
            ),
            "clip",
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    def diff(self) -> Self:
        return self._with_callable(lambda expr: expr.diff(), "diff")

    def n_unique(self) -> Self:
        return self._with_callable(
            lambda expr: expr.nunique(dropna=False).to_series(), "n_unique"
        )

    def is_null(self) -> Self:
        return self._with_callable(lambda expr: expr.isna(), "is_null")

    def is_nan(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                expr.dtype, self._version, self._implementation
            )
            if dtype.is_numeric():
                return expr != expr  # pyright: ignore[reportReturnType] # noqa: PLR0124
            msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
            raise InvalidOperationError(msg)

        return self._with_callable(func, "is_nan")

    def len(self) -> Self:
        return self._with_callable(lambda expr: expr.size.to_series(), "len")

    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> Self:
        if interpolation == "linear":

            def func(expr: dx.Series, quantile: float) -> dx.Series:
                if expr.npartitions > 1:
                    msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
                    raise NotImplementedError(msg)
                return expr.quantile(
                    q=quantile, method="dask"
                ).to_series()  # pragma: no cover

            return self._with_callable(func, "quantile", quantile=quantile)
        msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
        raise NotImplementedError(msg)

    def is_first_distinct(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(expr.to_frame(), col_token)
            first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
            return frame[col_token].isin(first_distinct_index)

        return self._with_callable(func, "is_first_distinct")

    def is_last_distinct(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(expr.to_frame(), col_token)
            last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
            return frame[col_token].isin(last_distinct_index)

        return self._with_callable(func, "is_last_distinct")

    def is_unique(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            return (
                expr.to_frame()
                .groupby(_name, dropna=False)
                .transform("size", meta=(_name, int))
                == 1
            )

        return self._with_callable(func, "is_unique")

    def is_in(self, other: Any) -> Self:
        return self._with_callable(lambda expr: expr.isin(other), "is_in")

    def null_count(self) -> Self:
        return self._with_callable(
            lambda expr: expr.isna().sum().to_series(), "null_count"
        )

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        # pandas is a required dependency of dask so it's safe to import this
        from narwhals._pandas_like.group_by import PandasLikeGroupBy

        if not partition_by:
            assert order_by  # noqa: S101

            # This is something like `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                return self(df.sort(*order_by, descending=False, nulls_last=False))
        elif not self._is_elementary():  # pragma: no cover
            msg = (
                "Only elementary expressions are supported for `.over` in dask.\n\n"
                "Please see: "
                "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
            )
            raise NotImplementedError(msg)
        elif order_by:
            # Wrong results https://github.com/dask/dask/issues/11806.
            msg = "`over` with `order_by` is not yet supported in Dask."
            raise NotImplementedError(msg)
        else:
            function_name = PandasLikeGroupBy._leaf_name(self)
            try:
                dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
            except KeyError:
                # window functions are unsupported: https://github.com/dask/dask/issues/11806
                msg = (
                    f"Unsupported function: {function_name} in `over` context.\n\n"
                    f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
                )
                raise NotImplementedError(msg) from None

            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])

                with warnings.catch_warnings():
                    # https://github.com/dask/dask/issues/11804
                    warnings.filterwarnings(
                        "ignore",
                        message=".*`meta` is not specified",
                        category=UserWarning,
                    )
                    grouped = df.native.groupby(partition_by)
                    if dask_function_name == "size":
                        if len(output_names) != 1:  # pragma: no cover
                            msg = "Safety check failed, please report a bug."
                            raise AssertionError(msg)
                        res_native = grouped.transform(
                            dask_function_name, **self._scalar_kwargs
                        ).to_frame(output_names[0])
                    else:
                        res_native = grouped[list(output_names)].transform(
                            dask_function_name, **self._scalar_kwargs
                        )
                result_frame = df._with_native(
                    res_native.rename(columns=dict(zip(output_names, aliases)))
                ).native
                return [result_frame[name] for name in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )
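    # Sketch of the grouped branch above in plain dask (hypothetical columns
    # "a" and "g"): `nw.col("a").sum().over("g")` lowers to a groupby
    # transform, which preserves the row count instead of collapsing groups:
    #
    #   ddf.groupby(["g"])["a"].transform("sum")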

    def cast(self, dtype: IntoDType) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            native_dtype = narwhals_to_native_dtype(dtype, self._version)
            return expr.astype(native_dtype)

        return self._with_callable(func, "cast")

    def is_finite(self) -> Self:
        import dask.array as da

        return self._with_callable(da.isfinite, "is_finite")

    def log(self, base: float) -> Self:
        import dask.array as da

        def _log(expr: dx.Series) -> dx.Series:
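            # change of base: log_base(x) = ln(x) / ln(base)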
            return da.log(expr) / da.log(base)

        return self._with_callable(_log, "log")

    def exp(self) -> Self:
        import dask.array as da

        return self._with_callable(da.exp, "exp")

    def sqrt(self) -> Self:
        import dask.array as da

        return self._with_callable(da.sqrt, "sqrt")

    def mode(self, *, keep: ModeKeepStrategy) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            result = expr.to_frame().mode()[_name]
            return result.head(1) if keep == "any" else result

        return self._with_callable(func, "mode", scalar_kwargs={"keep": keep})

    @property
    def str(self) -> DaskExprStringNamespace:
        return DaskExprStringNamespace(self)

    @property
    def dt(self) -> DaskExprDateTimeNamespace:
        return DaskExprDateTimeNamespace(self)

    arg_max: not_implemented = not_implemented()
    arg_min: not_implemented = not_implemented()
    arg_true: not_implemented = not_implemented()
    ewm_mean: not_implemented = not_implemented()
    gather_every: not_implemented = not_implemented()
    head: not_implemented = not_implemented()
    map_batches: not_implemented = not_implemented()
    sample: not_implemented = not_implemented()
    rank: not_implemented = not_implemented()
    replace_strict: not_implemented = not_implemented()
    sort: not_implemented = not_implemented()
    tail: not_implemented = not_implemented()

    # namespaces
    list: not_implemented = not_implemented()  # type: ignore[assignment]
    cat: not_implemented = not_implemented()  # type: ignore[assignment]
    struct: not_implemented = not_implemented()  # type: ignore[assignment]


lib/python3.11/site-packages/narwhals/_dask/expr_dt.py (new file, 175 lines)
@@ -0,0 +1,175 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import DateTimeNamespace
from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from narwhals._duration import Interval
from narwhals._pandas_like.utils import (
    ALIAS_DICT,
    calculate_timestamp_date,
    calculate_timestamp_datetime,
    native_to_narwhals_dtype,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.expr import DaskExpr
    from narwhals.typing import TimeUnit


class DaskExprDateTimeNamespace(
    LazyExprNamespace["DaskExpr"], DateTimeNamespace["DaskExpr"]
):
    def date(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.date, "date")

    def year(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.year, "year")

    def month(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.month, "month")

    def day(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.day, "day")

    def hour(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.hour, "hour")

    def minute(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.minute, "minute")

    def second(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.second, "second")

    def millisecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond // 1000, "millisecond"
        )

    def microsecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond, "microsecond"
        )

    def nanosecond(self) -> DaskExpr:
|         return self.compliant._with_callable( | ||||
|             lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond, "nanosecond" | ||||
|         ) | ||||
|  | ||||
|     def ordinal_day(self) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr: expr.dt.dayofyear, "ordinal_day" | ||||
|         ) | ||||
|  | ||||
|     def weekday(self) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr: expr.dt.weekday + 1,  # Dask is 0-6 | ||||
|             "weekday", | ||||
|         ) | ||||
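|     # For illustration: pandas/Dask `.dt.weekday` runs 0 (Monday) through 6 | ||||
|     # (Sunday); adding 1 yields the ISO numbering narwhals exposes, 1-7. | ||||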
|  | ||||
|     def to_string(self, format: str) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr, format: expr.dt.strftime(format.replace("%.f", ".%f")), | ||||
|             "strftime", | ||||
|             format=format, | ||||
|         ) | ||||
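|     # For illustration: the chrono-style "%.f" fractional-second token is | ||||
|     # rewritten to pandas' ".%f", e.g. "%H:%M:%S%.f" -> "%H:%M:%S.%f". | ||||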
|  | ||||
|     def replace_time_zone(self, time_zone: str | None) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr, time_zone: expr.dt.tz_localize(None).dt.tz_localize(time_zone) | ||||
|             if time_zone is not None | ||||
|             else expr.dt.tz_localize(None), | ||||
|             "tz_localize", | ||||
|             time_zone=time_zone, | ||||
|         ) | ||||
|  | ||||
|     def convert_time_zone(self, time_zone: str) -> DaskExpr: | ||||
|         def func(s: dx.Series, time_zone: str) -> dx.Series: | ||||
|             dtype = native_to_narwhals_dtype( | ||||
|                 s.dtype, self.compliant._version, Implementation.DASK | ||||
|             ) | ||||
|             if dtype.time_zone is None:  # type: ignore[attr-defined] | ||||
|                 return s.dt.tz_localize("UTC").dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue] | ||||
|             return s.dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue] | ||||
|  | ||||
|         return self.compliant._with_callable(func, "tz_convert", time_zone=time_zone) | ||||
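|     # Editor's sketch of the semantics above: tz-naive inputs are treated as | ||||
|     # UTC, so 2020-01-01 00:00 (naive) converted to "Europe/Paris" becomes | ||||
|     # 2020-01-01 01:00:00+01:00. | ||||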
|  | ||||
|     # ignoring coverage due to https://github.com/narwhals-dev/narwhals/issues/2808. | ||||
|     def timestamp(self, time_unit: TimeUnit) -> DaskExpr:  # pragma: no cover | ||||
|         def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series: | ||||
|             dtype = native_to_narwhals_dtype( | ||||
|                 s.dtype, self.compliant._version, Implementation.DASK | ||||
|             ) | ||||
|             is_pyarrow_dtype = "pyarrow" in str(dtype) | ||||
|             mask_na = s.isna() | ||||
|             dtypes = self.compliant._version.dtypes | ||||
|             if dtype == dtypes.Date: | ||||
|                 # Date is only supported in pandas dtypes if pyarrow-backed | ||||
|                 s_cast = s.astype("Int32[pyarrow]") | ||||
|                 result = calculate_timestamp_date(s_cast, time_unit) | ||||
|             elif isinstance(dtype, dtypes.Datetime): | ||||
|                 original_time_unit = dtype.time_unit | ||||
|                 s_cast = ( | ||||
|                     s.astype("Int64[pyarrow]") if is_pyarrow_dtype else s.astype("int64") | ||||
|                 ) | ||||
|                 result = calculate_timestamp_datetime( | ||||
|                     s_cast, original_time_unit, time_unit | ||||
|                 ) | ||||
|             else: | ||||
|                 msg = "Input should be either of Date or Datetime type" | ||||
|                 raise TypeError(msg) | ||||
|             return result.where(~mask_na)  # pyright: ignore[reportReturnType] | ||||
|  | ||||
|         return self.compliant._with_callable(func, "datetime", time_unit=time_unit) | ||||
|  | ||||
|     def total_minutes(self) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr: expr.dt.total_seconds() // 60, "total_minutes" | ||||
|         ) | ||||
|  | ||||
|     def total_seconds(self) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr: expr.dt.total_seconds() // 1, "total_seconds" | ||||
|         ) | ||||
|  | ||||
|     def total_milliseconds(self) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1, | ||||
|             "total_milliseconds", | ||||
|         ) | ||||
|  | ||||
|     def total_microseconds(self) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1, | ||||
|             "total_microseconds", | ||||
|         ) | ||||
|  | ||||
|     def total_nanoseconds(self) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1, "total_nanoseconds" | ||||
|         ) | ||||
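|     # For illustration: each `total_*` scales `total_seconds()` by a constant | ||||
|     # and floors with `// 1`, so a duration of 1.5ms gives | ||||
|     # total_milliseconds() == 1. | ||||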
|  | ||||
|     def truncate(self, every: str) -> DaskExpr: | ||||
|         interval = Interval.parse(every) | ||||
|         unit = interval.unit | ||||
|         if unit in {"mo", "q", "y"}: | ||||
|             msg = f"Truncating to {unit} is not yet supported for dask." | ||||
|             raise NotImplementedError(msg) | ||||
|         freq = f"{interval.multiple}{ALIAS_DICT.get(unit, unit)}" | ||||
|         return self.compliant._with_callable(lambda expr: expr.dt.floor(freq), "truncate") | ||||
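|     # Hypothetical usage (assuming ALIAS_DICT maps "m" to "min"): | ||||
|     # truncate("15m") floors to a "15min" frequency, while "1mo", "1q" and | ||||
|     # "1y" raise NotImplementedError above. | ||||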
|  | ||||
|     def offset_by(self, by: str) -> DaskExpr: | ||||
|         def func(s: dx.Series, by: str) -> dx.Series: | ||||
|             interval = Interval.parse_no_constraints(by) | ||||
|             unit = interval.unit | ||||
|             if unit in {"y", "q", "mo", "d", "ns"}: | ||||
|                 msg = f"Offsetting by {unit} is not yet supported for dask." | ||||
|                 raise NotImplementedError(msg) | ||||
|             offset = interval.to_timedelta() | ||||
|             return s.add(offset) | ||||
|  | ||||
|         return self.compliant._with_callable(func, "offset_by", by=by) | ||||
lib/python3.11/site-packages/narwhals/_dask/expr_str.py (121 lines, new file)
							| @ -0,0 +1,121 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| import dask.dataframe as dd | ||||
|  | ||||
| from narwhals._compliant import LazyExprNamespace | ||||
| from narwhals._compliant.any_namespace import StringNamespace | ||||
| from narwhals._utils import not_implemented | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     import dask.dataframe.dask_expr as dx | ||||
|  | ||||
|     from narwhals._dask.expr import DaskExpr | ||||
|  | ||||
|  | ||||
| class DaskExprStringNamespace(LazyExprNamespace["DaskExpr"], StringNamespace["DaskExpr"]): | ||||
|     def len_chars(self) -> DaskExpr: | ||||
|         return self.compliant._with_callable(lambda expr: expr.str.len(), "len") | ||||
|  | ||||
|     def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> DaskExpr: | ||||
|         def _replace( | ||||
|             expr: dx.Series, pattern: str, value: str, *, literal: bool, n: int | ||||
|         ) -> dx.Series: | ||||
|             try: | ||||
|                 return expr.str.replace(  # pyright: ignore[reportAttributeAccessIssue] | ||||
|                     pattern, value, regex=not literal, n=n | ||||
|                 ) | ||||
|             except TypeError as e: | ||||
|                 if not isinstance(value, str): | ||||
|                     msg = "dask backed `Expr.str.replace` only supports str replacement values" | ||||
|                     raise TypeError(msg) from e | ||||
|                 raise | ||||
|  | ||||
|         return self.compliant._with_callable( | ||||
|             _replace, "replace", pattern=pattern, value=value, literal=literal, n=n | ||||
|         ) | ||||
|  | ||||
|     def replace_all(self, pattern: str, value: str, *, literal: bool) -> DaskExpr: | ||||
|         def _replace_all( | ||||
|             expr: dx.Series, pattern: str, value: str, *, literal: bool | ||||
|         ) -> dx.Series: | ||||
|             try: | ||||
|                 return expr.str.replace(  # pyright: ignore[reportAttributeAccessIssue] | ||||
|                     pattern, value, regex=not literal, n=-1 | ||||
|                 ) | ||||
|             except TypeError as e: | ||||
|                 if not isinstance(value, str): | ||||
|                     msg = "dask backed `Expr.str.replace_all` only supports str replacement values." | ||||
|                     raise TypeError(msg) from e | ||||
|                 raise | ||||
|  | ||||
|         return self.compliant._with_callable( | ||||
|             _replace_all, "replace", pattern=pattern, value=value, literal=literal | ||||
|         ) | ||||
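|     # Editor's note: Dask forwards to pandas' `str.replace`, which only | ||||
|     # accepts string replacement values; the try/except above turns the | ||||
|     # resulting TypeError into a clearer, backend-specific message. | ||||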
|  | ||||
|     def strip_chars(self, characters: str | None) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr, characters: expr.str.strip(characters), | ||||
|             "strip", | ||||
|             characters=characters, | ||||
|         ) | ||||
|  | ||||
|     def starts_with(self, prefix: str) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr, prefix: expr.str.startswith(prefix), "starts_with", prefix=prefix | ||||
|         ) | ||||
|  | ||||
|     def ends_with(self, suffix: str) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr, suffix: expr.str.endswith(suffix), "ends_with", suffix=suffix | ||||
|         ) | ||||
|  | ||||
|     def contains(self, pattern: str, *, literal: bool) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr, pattern, literal: expr.str.contains( | ||||
|                 pat=pattern, regex=not literal | ||||
|             ), | ||||
|             "contains", | ||||
|             pattern=pattern, | ||||
|             literal=literal, | ||||
|         ) | ||||
|  | ||||
|     def slice(self, offset: int, length: int | None) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr, offset, length: expr.str.slice( | ||||
|                 start=offset, stop=offset + length if length is not None else None | ||||
|             ), | ||||
|             "slice", | ||||
|             offset=offset, | ||||
|             length=length, | ||||
|         ) | ||||
|  | ||||
|     def split(self, by: str) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr, by: expr.str.split(pat=by), "split", by=by | ||||
|         ) | ||||
|  | ||||
|     def to_datetime(self, format: str | None) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr, format: dd.to_datetime(expr, format=format), | ||||
|             "to_datetime", | ||||
|             format=format, | ||||
|         ) | ||||
|  | ||||
|     def to_uppercase(self) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr: expr.str.upper(), "to_uppercase" | ||||
|         ) | ||||
|  | ||||
|     def to_lowercase(self) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr: expr.str.lower(), "to_lowercase" | ||||
|         ) | ||||
|  | ||||
|     def zfill(self, width: int) -> DaskExpr: | ||||
|         return self.compliant._with_callable( | ||||
|             lambda expr, width: expr.str.zfill(width), "zfill", width=width | ||||
|         ) | ||||
|  | ||||
|     to_date = not_implemented() | ||||
lib/python3.11/site-packages/narwhals/_dask/group_by.py (147 lines, new file)
							| @ -0,0 +1,147 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from functools import partial | ||||
| from typing import TYPE_CHECKING, Any, Callable, ClassVar | ||||
|  | ||||
| import dask.dataframe as dd | ||||
|  | ||||
| from narwhals._compliant import DepthTrackingGroupBy | ||||
| from narwhals._expression_parsing import evaluate_output_names_and_aliases | ||||
| from narwhals._utils import zip_strict | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from collections.abc import Mapping, Sequence | ||||
|  | ||||
|     import pandas as pd | ||||
|     from dask.dataframe.api import GroupBy as _DaskGroupBy | ||||
|     from pandas.core.groupby import SeriesGroupBy as _PandasSeriesGroupBy | ||||
|     from typing_extensions import TypeAlias | ||||
|  | ||||
|     from narwhals._compliant.typing import NarwhalsAggregation | ||||
|     from narwhals._dask.dataframe import DaskLazyFrame | ||||
|     from narwhals._dask.expr import DaskExpr | ||||
|  | ||||
|     PandasSeriesGroupBy: TypeAlias = _PandasSeriesGroupBy[Any, Any] | ||||
|     _AggFn: TypeAlias = Callable[..., Any] | ||||
|  | ||||
| else: | ||||
|     try: | ||||
|         import dask.dataframe.dask_expr as dx | ||||
|     except ModuleNotFoundError:  # pragma: no cover | ||||
|         import dask_expr as dx | ||||
|     _DaskGroupBy = dx._groupby.GroupBy | ||||
|  | ||||
| Aggregation: TypeAlias = "str | _AggFn" | ||||
| """The name of an aggregation function, or the function itself.""" | ||||
|  | ||||
|  | ||||
| def n_unique() -> dd.Aggregation: | ||||
|     def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]: | ||||
|         return s.nunique(dropna=False) | ||||
|  | ||||
|     def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]: | ||||
|         return s0.sum() | ||||
|  | ||||
|     return dd.Aggregation(name="nunique", chunk=chunk, agg=agg) | ||||
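| # Editor's sketch of the `dd.Aggregation` protocol used above: `chunk` runs | ||||
| # on each partition's groups and `agg` combines those partial results, so | ||||
| # per-partition unique counts are summed across partitions here. | ||||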
|  | ||||
|  | ||||
| def _all() -> dd.Aggregation: | ||||
|     def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]: | ||||
|         return s.all(skipna=True) | ||||
|  | ||||
|     def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]: | ||||
|         return s0.all(skipna=True) | ||||
|  | ||||
|     return dd.Aggregation(name="all", chunk=chunk, agg=agg) | ||||
|  | ||||
|  | ||||
| def _any() -> dd.Aggregation: | ||||
|     def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]: | ||||
|         return s.any(skipna=True) | ||||
|  | ||||
|     def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]: | ||||
|         return s0.any(skipna=True) | ||||
|  | ||||
|     return dd.Aggregation(name="any", chunk=chunk, agg=agg) | ||||
|  | ||||
|  | ||||
| def var(ddof: int) -> _AggFn: | ||||
|     return partial(_DaskGroupBy.var, ddof=ddof) | ||||
|  | ||||
|  | ||||
| def std(ddof: int) -> _AggFn: | ||||
|     return partial(_DaskGroupBy.std, ddof=ddof) | ||||
|  | ||||
|  | ||||
| class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]): | ||||
|     _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = { | ||||
|         "sum": "sum", | ||||
|         "mean": "mean", | ||||
|         "median": "median", | ||||
|         "max": "max", | ||||
|         "min": "min", | ||||
|         "std": std, | ||||
|         "var": var, | ||||
|         "len": "size", | ||||
|         "n_unique": n_unique, | ||||
|         "count": "count", | ||||
|         "quantile": "quantile", | ||||
|         "all": _all, | ||||
|         "any": _any, | ||||
|     } | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         df: DaskLazyFrame, | ||||
|         keys: Sequence[DaskExpr] | Sequence[str], | ||||
|         /, | ||||
|         *, | ||||
|         drop_null_keys: bool, | ||||
|     ) -> None: | ||||
|         self._compliant_frame, self._keys, self._output_key_names = self._parse_keys( | ||||
|             df, keys=keys | ||||
|         ) | ||||
|         self._grouped = self.compliant.native.groupby( | ||||
|             self._keys, dropna=drop_null_keys, observed=True | ||||
|         ) | ||||
|  | ||||
|     def agg(self, *exprs: DaskExpr) -> DaskLazyFrame: | ||||
|         from narwhals._dask.dataframe import DaskLazyFrame | ||||
|  | ||||
|         if not exprs: | ||||
|             # No aggregation provided | ||||
|             return ( | ||||
|                 self.compliant.simple_select(*self._keys) | ||||
|                 .unique(self._keys, keep="any") | ||||
|                 .rename(dict(zip(self._keys, self._output_key_names))) | ||||
|             ) | ||||
|  | ||||
|         self._ensure_all_simple(exprs) | ||||
|         # This should be the fastpath, but cuDF is too far behind to use it. | ||||
|         # - https://github.com/rapidsai/cudf/issues/15118 | ||||
|         # - https://github.com/rapidsai/cudf/issues/15084 | ||||
|         simple_aggregations: dict[str, tuple[str, Aggregation]] = {} | ||||
|         exclude = (*self._keys, *self._output_key_names) | ||||
|         for expr in exprs: | ||||
|             output_names, aliases = evaluate_output_names_and_aliases( | ||||
|                 expr, self.compliant, exclude | ||||
|             ) | ||||
|             if expr._depth == 0: | ||||
|                 # e.g. `agg(nw.len())` | ||||
|                 column = self._keys[0] | ||||
|                 agg_fn = self._remap_expr_name(expr._function_name) | ||||
|                 simple_aggregations.update(dict.fromkeys(aliases, (column, agg_fn))) | ||||
|                 continue | ||||
|  | ||||
|             # e.g. `agg(nw.mean('a'))` | ||||
|             agg_fn = self._remap_expr_name(self._leaf_name(expr)) | ||||
|             # build callable aggregations (e.g. n_unique) lazily, so we don't depend on dask globally | ||||
|             agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn | ||||
|             simple_aggregations.update( | ||||
|                 (alias, (output_name, agg_fn)) | ||||
|                 for alias, output_name in zip_strict(aliases, output_names) | ||||
|             ) | ||||
|         return DaskLazyFrame( | ||||
|             self._grouped.agg(**simple_aggregations).reset_index(), | ||||
|             version=self.compliant._version, | ||||
|         ).rename(dict(zip(self._keys, self._output_key_names))) | ||||
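|     # Hypothetical usage sketch: `group_by("g").agg(nw.col("a").mean())` | ||||
|     # lowers to roughly `native.groupby(["g"]).agg(a=("a", "mean"))`, the | ||||
|     # named-aggregation form assembled in `simple_aggregations` above. | ||||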
lib/python3.11/site-packages/narwhals/_dask/namespace.py (338 lines, new file)
							| @ -0,0 +1,338 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import operator | ||||
| from functools import reduce | ||||
| from itertools import chain | ||||
| from typing import TYPE_CHECKING, cast | ||||
|  | ||||
| import dask.dataframe as dd | ||||
| import pandas as pd | ||||
|  | ||||
| from narwhals._compliant import ( | ||||
|     CompliantThen, | ||||
|     CompliantWhen, | ||||
|     DepthTrackingNamespace, | ||||
|     LazyNamespace, | ||||
| ) | ||||
| from narwhals._dask.dataframe import DaskLazyFrame | ||||
| from narwhals._dask.expr import DaskExpr | ||||
| from narwhals._dask.selectors import DaskSelectorNamespace | ||||
| from narwhals._dask.utils import ( | ||||
|     align_series_full_broadcast, | ||||
|     narwhals_to_native_dtype, | ||||
|     validate_comparand, | ||||
| ) | ||||
| from narwhals._expression_parsing import ( | ||||
|     ExprKind, | ||||
|     combine_alias_output_names, | ||||
|     combine_evaluate_output_names, | ||||
| ) | ||||
| from narwhals._utils import Implementation, zip_strict | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from collections.abc import Iterable, Iterator, Sequence | ||||
|  | ||||
|     import dask.dataframe.dask_expr as dx | ||||
|  | ||||
|     from narwhals._compliant.typing import ScalarKwargs | ||||
|     from narwhals._utils import Version | ||||
|     from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral | ||||
|  | ||||
|  | ||||
| class DaskNamespace( | ||||
|     LazyNamespace[DaskLazyFrame, DaskExpr, dd.DataFrame], | ||||
|     DepthTrackingNamespace[DaskLazyFrame, DaskExpr], | ||||
| ): | ||||
|     _implementation: Implementation = Implementation.DASK | ||||
|  | ||||
|     @property | ||||
|     def selectors(self) -> DaskSelectorNamespace: | ||||
|         return DaskSelectorNamespace.from_namespace(self) | ||||
|  | ||||
|     @property | ||||
|     def _expr(self) -> type[DaskExpr]: | ||||
|         return DaskExpr | ||||
|  | ||||
|     @property | ||||
|     def _lazyframe(self) -> type[DaskLazyFrame]: | ||||
|         return DaskLazyFrame | ||||
|  | ||||
|     def __init__(self, *, version: Version) -> None: | ||||
|         self._version = version | ||||
|  | ||||
|     def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DaskExpr: | ||||
|         def func(df: DaskLazyFrame) -> list[dx.Series]: | ||||
|             if dtype is not None: | ||||
|                 native_dtype = narwhals_to_native_dtype(dtype, self._version) | ||||
|                 native_pd_series = pd.Series([value], dtype=native_dtype, name="literal") | ||||
|             else: | ||||
|                 native_pd_series = pd.Series([value], name="literal") | ||||
|             npartitions = df._native_frame.npartitions | ||||
|             dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions) | ||||
|             return [dask_series[0].to_series()] | ||||
|  | ||||
|         return self._expr( | ||||
|             func, | ||||
|             depth=0, | ||||
|             function_name="lit", | ||||
|             evaluate_output_names=lambda _df: ["literal"], | ||||
|             alias_output_names=None, | ||||
|             version=self._version, | ||||
|         ) | ||||
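|     # For illustration: the literal is materialised as a one-row pandas | ||||
|     # Series named "literal", wrapped with `dd.from_pandas` using the frame's | ||||
|     # partition count, and broadcast later where needed. | ||||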
|  | ||||
|     def len(self) -> DaskExpr: | ||||
|         def func(df: DaskLazyFrame) -> list[dx.Series]: | ||||
|             # We don't allow dataframes with 0 columns, so `[0]` is safe. | ||||
|             return [df._native_frame[df.columns[0]].size.to_series()] | ||||
|  | ||||
|         return self._expr( | ||||
|             func, | ||||
|             depth=0, | ||||
|             function_name="len", | ||||
|             evaluate_output_names=lambda _df: ["len"], | ||||
|             alias_output_names=None, | ||||
|             version=self._version, | ||||
|         ) | ||||
|  | ||||
|     def all_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr: | ||||
|         def func(df: DaskLazyFrame) -> list[dx.Series]: | ||||
|             series: Iterator[dx.Series] = chain.from_iterable(e(df) for e in exprs) | ||||
|             # Note on `ignore_nulls`: Dask doesn't support storing arbitrary Python | ||||
|             # objects in `object` dtype, so we don't need the same check we have for pandas-like. | ||||
|             if ignore_nulls: | ||||
|                 # NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling. | ||||
|                 series = (s if s.dtype == "bool" else s.fillna(True) for s in series) | ||||
|             return [reduce(operator.and_, align_series_full_broadcast(df, *series))] | ||||
|  | ||||
|         return self._expr( | ||||
|             call=func, | ||||
|             depth=max(x._depth for x in exprs) + 1, | ||||
|             function_name="all_horizontal", | ||||
|             evaluate_output_names=combine_evaluate_output_names(*exprs), | ||||
|             alias_output_names=combine_alias_output_names(*exprs), | ||||
|             version=self._version, | ||||
|         ) | ||||
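|     # For illustration: with ignore_nulls=True a null boolean is filled with | ||||
|     # True (the identity of `&`), so all_horizontal(True, None) is True; | ||||
|     # `any_horizontal` below symmetrically fills with False. | ||||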
|  | ||||
|     def any_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr: | ||||
|         def func(df: DaskLazyFrame) -> list[dx.Series]: | ||||
|             series: Iterator[dx.Series] = chain.from_iterable(e(df) for e in exprs) | ||||
|             if ignore_nulls: | ||||
|                 series = (s if s.dtype == "bool" else s.fillna(False) for s in series) | ||||
|             return [reduce(operator.or_, align_series_full_broadcast(df, *series))] | ||||
|  | ||||
|         return self._expr( | ||||
|             call=func, | ||||
|             depth=max(x._depth for x in exprs) + 1, | ||||
|             function_name="any_horizontal", | ||||
|             evaluate_output_names=combine_evaluate_output_names(*exprs), | ||||
|             alias_output_names=combine_alias_output_names(*exprs), | ||||
|             version=self._version, | ||||
|         ) | ||||
|  | ||||
|     def sum_horizontal(self, *exprs: DaskExpr) -> DaskExpr: | ||||
|         def func(df: DaskLazyFrame) -> list[dx.Series]: | ||||
|             series = align_series_full_broadcast( | ||||
|                 df, *(s for _expr in exprs for s in _expr(df)) | ||||
|             ) | ||||
|             return [dd.concat(series, axis=1).sum(axis=1)] | ||||
|  | ||||
|         return self._expr( | ||||
|             call=func, | ||||
|             depth=max(x._depth for x in exprs) + 1, | ||||
|             function_name="sum_horizontal", | ||||
|             evaluate_output_names=combine_evaluate_output_names(*exprs), | ||||
|             alias_output_names=combine_alias_output_names(*exprs), | ||||
|             version=self._version, | ||||
|         ) | ||||
|  | ||||
|     def concat( | ||||
|         self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod | ||||
|     ) -> DaskLazyFrame: | ||||
|         if not items: | ||||
|             msg = "No items to concatenate"  # pragma: no cover | ||||
|             raise AssertionError(msg) | ||||
|         dfs = [i._native_frame for i in items] | ||||
|         cols_0 = dfs[0].columns | ||||
|         if how == "vertical": | ||||
|             for i, df in enumerate(dfs[1:], start=1): | ||||
|                 cols_current = df.columns | ||||
|                 if not ( | ||||
|                     (len(cols_current) == len(cols_0)) and (cols_current == cols_0).all() | ||||
|                 ): | ||||
|                     msg = ( | ||||
|                         "unable to vstack, column names don't match:\n" | ||||
|                         f"   - dataframe 0: {cols_0.to_list()}\n" | ||||
|                         f"   - dataframe {i}: {cols_current.to_list()}\n" | ||||
|                     ) | ||||
|                     raise TypeError(msg) | ||||
|             return DaskLazyFrame( | ||||
|                 dd.concat(dfs, axis=0, join="inner"), version=self._version | ||||
|             ) | ||||
|         if how == "diagonal": | ||||
|             return DaskLazyFrame( | ||||
|                 dd.concat(dfs, axis=0, join="outer"), version=self._version | ||||
|             ) | ||||
|  | ||||
|         raise NotImplementedError | ||||
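|     # Editor's note: "vertical" insists on identical column names (checked | ||||
|     # above) and concatenates row-wise; "diagonal" takes the union of columns | ||||
|     # via `join="outer"`, filling the gaps with missing values. | ||||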
|  | ||||
|     def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr: | ||||
|         def func(df: DaskLazyFrame) -> list[dx.Series]: | ||||
|             expr_results = [s for _expr in exprs for s in _expr(df)] | ||||
|             series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results)) | ||||
|             non_na = align_series_full_broadcast( | ||||
|                 df, *(1 - s.isna() for s in expr_results) | ||||
|             ) | ||||
|             num = reduce(lambda x, y: x + y, series)  # pyright: ignore[reportOperatorIssue] | ||||
|             den = reduce(lambda x, y: x + y, non_na)  # pyright: ignore[reportOperatorIssue] | ||||
|             return [cast("dx.Series", num / den)]  # pyright: ignore[reportOperatorIssue] | ||||
|  | ||||
|         return self._expr( | ||||
|             call=func, | ||||
|             depth=max(x._depth for x in exprs) + 1, | ||||
|             function_name="mean_horizontal", | ||||
|             evaluate_output_names=combine_evaluate_output_names(*exprs), | ||||
|             alias_output_names=combine_alias_output_names(*exprs), | ||||
|             version=self._version, | ||||
|         ) | ||||
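|     # For illustration: nulls are dropped from both numerator and | ||||
|     # denominator, so mean_horizontal over (1, None, 3) is (1 + 0 + 3) / 2. | ||||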
|  | ||||
|     def min_horizontal(self, *exprs: DaskExpr) -> DaskExpr: | ||||
|         def func(df: DaskLazyFrame) -> list[dx.Series]: | ||||
|             series = align_series_full_broadcast( | ||||
|                 df, *(s for _expr in exprs for s in _expr(df)) | ||||
|             ) | ||||
|  | ||||
|             return [dd.concat(series, axis=1).min(axis=1)] | ||||
|  | ||||
|         return self._expr( | ||||
|             call=func, | ||||
|             depth=max(x._depth for x in exprs) + 1, | ||||
|             function_name="min_horizontal", | ||||
|             evaluate_output_names=combine_evaluate_output_names(*exprs), | ||||
|             alias_output_names=combine_alias_output_names(*exprs), | ||||
|             version=self._version, | ||||
|         ) | ||||
|  | ||||
|     def max_horizontal(self, *exprs: DaskExpr) -> DaskExpr: | ||||
|         def func(df: DaskLazyFrame) -> list[dx.Series]: | ||||
|             series = align_series_full_broadcast( | ||||
|                 df, *(s for _expr in exprs for s in _expr(df)) | ||||
|             ) | ||||
|  | ||||
|             return [dd.concat(series, axis=1).max(axis=1)] | ||||
|  | ||||
|         return self._expr( | ||||
|             call=func, | ||||
|             depth=max(x._depth for x in exprs) + 1, | ||||
|             function_name="max_horizontal", | ||||
|             evaluate_output_names=combine_evaluate_output_names(*exprs), | ||||
|             alias_output_names=combine_alias_output_names(*exprs), | ||||
|             version=self._version, | ||||
|         ) | ||||
|  | ||||
|     def when(self, predicate: DaskExpr) -> DaskWhen: | ||||
|         return DaskWhen.from_expr(predicate, context=self) | ||||
|  | ||||
|     def concat_str( | ||||
|         self, *exprs: DaskExpr, separator: str, ignore_nulls: bool | ||||
|     ) -> DaskExpr: | ||||
|         def func(df: DaskLazyFrame) -> list[dx.Series]: | ||||
|             expr_results = [s for _expr in exprs for s in _expr(df)] | ||||
|             series = ( | ||||
|                 s.astype(str) for s in align_series_full_broadcast(df, *expr_results) | ||||
|             ) | ||||
|             null_mask = [s.isna() for s in align_series_full_broadcast(df, *expr_results)] | ||||
|  | ||||
|             if not ignore_nulls: | ||||
|                 null_mask_result = reduce(operator.or_, null_mask) | ||||
|                 result = reduce(lambda x, y: x + separator + y, series).where( | ||||
|                     ~null_mask_result, None | ||||
|                 ) | ||||
|             else: | ||||
|                 init_value, *values = [ | ||||
|                     s.where(~nm, "") for s, nm in zip_strict(series, null_mask) | ||||
|                 ] | ||||
|  | ||||
|                 separators = ( | ||||
|                     nm.map({True: "", False: separator}, meta=str) | ||||
|                     for nm in null_mask[:-1] | ||||
|                 ) | ||||
|                 result = reduce( | ||||
|                     operator.add, | ||||
|                     (s + v for s, v in zip_strict(separators, values)), | ||||
|                     init_value, | ||||
|                 ) | ||||
|  | ||||
|             return [result] | ||||
|  | ||||
|         return self._expr( | ||||
|             call=func, | ||||
|             depth=max(x._depth for x in exprs) + 1, | ||||
|             function_name="concat_str", | ||||
|             evaluate_output_names=getattr( | ||||
|                 exprs[0], "_evaluate_output_names", lambda _df: ["literal"] | ||||
|             ), | ||||
|             alias_output_names=getattr(exprs[0], "_alias_output_names", None), | ||||
|             version=self._version, | ||||
|         ) | ||||
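|     # For illustration: with ignore_nulls=True, null values become "" and the | ||||
|     # separator following a null is dropped, so ("a", None, "b") with | ||||
|     # separator="-" gives "a-b"; with ignore_nulls=False the result is null. | ||||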
|  | ||||
|     def coalesce(self, *exprs: DaskExpr) -> DaskExpr: | ||||
|         def func(df: DaskLazyFrame) -> list[dx.Series]: | ||||
|             series = align_series_full_broadcast( | ||||
|                 df, *(s for _expr in exprs for s in _expr(df)) | ||||
|             ) | ||||
|             return [reduce(lambda x, y: x.fillna(y), series)] | ||||
|  | ||||
|         return self._expr( | ||||
|             call=func, | ||||
|             depth=max(x._depth for x in exprs) + 1, | ||||
|             function_name="coalesce", | ||||
|             evaluate_output_names=combine_evaluate_output_names(*exprs), | ||||
|             alias_output_names=combine_alias_output_names(*exprs), | ||||
|             version=self._version, | ||||
|         ) | ||||
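|     # For illustration: coalesce is a left-to-right `fillna` chain taking the | ||||
|     # first non-null value per row, e.g. (None, 2, 3) -> 2. | ||||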
|  | ||||
|  | ||||
| class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]):  # pyright: ignore[reportInvalidTypeArguments] | ||||
|     @property | ||||
|     def _then(self) -> type[DaskThen]: | ||||
|         return DaskThen | ||||
|  | ||||
|     def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: | ||||
|         then_value = ( | ||||
|             self._then_value(df)[0] | ||||
|             if isinstance(self._then_value, DaskExpr) | ||||
|             else self._then_value | ||||
|         ) | ||||
|         otherwise_value = ( | ||||
|             self._otherwise_value(df)[0] | ||||
|             if isinstance(self._otherwise_value, DaskExpr) | ||||
|             else self._otherwise_value | ||||
|         ) | ||||
|  | ||||
|         condition = self._condition(df)[0] | ||||
|         # re-evaluate DataFrame if the condition aggregates to force | ||||
|         #   then/otherwise to be evaluated against the aggregated frame | ||||
|         assert self._condition._metadata is not None  # noqa: S101 | ||||
|         if self._condition._metadata.is_scalar_like: | ||||
|             new_df = df._with_native(condition.to_frame()) | ||||
|             condition = self._condition.broadcast(ExprKind.AGGREGATION)(df)[0] | ||||
|             df = new_df | ||||
|  | ||||
|         if self._otherwise_value is None: | ||||
|             (condition, then_series) = align_series_full_broadcast( | ||||
|                 df, condition, then_value | ||||
|             ) | ||||
|             validate_comparand(condition, then_series) | ||||
|             return [then_series.where(condition)]  # pyright: ignore[reportArgumentType] | ||||
|         (condition, then_series, otherwise_series) = align_series_full_broadcast( | ||||
|             df, condition, then_value, otherwise_value | ||||
|         ) | ||||
|         validate_comparand(condition, then_series) | ||||
|         validate_comparand(condition, otherwise_series) | ||||
|         return [then_series.where(condition, otherwise_series)]  # pyright: ignore[reportArgumentType] | ||||
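|     # Editor's note on the scalar branch above: for a condition like | ||||
|     # `nw.col("a").sum() > 0`, the frame is swapped for the one-row aggregated | ||||
|     # frame and the condition re-broadcast, so then/otherwise are evaluated | ||||
|     # against the aggregate rather than the original rows. | ||||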
|  | ||||
|  | ||||
| class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr):  # pyright: ignore[reportInvalidTypeArguments] | ||||
|     _depth: int = 0 | ||||
|     _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012 | ||||
|     _function_name: str = "whenthen" | ||||
lib/python3.11/site-packages/narwhals/_dask/selectors.py (34 lines, new file)
							| @ -0,0 +1,34 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| from narwhals._compliant import CompliantSelector, LazySelectorNamespace | ||||
| from narwhals._dask.expr import DaskExpr | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     import dask.dataframe.dask_expr as dx  # noqa: F401 | ||||
|  | ||||
|     from narwhals._compliant.typing import ScalarKwargs | ||||
|     from narwhals._dask.dataframe import DaskLazyFrame  # noqa: F401 | ||||
|  | ||||
|  | ||||
| class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]):  # pyright: ignore[reportInvalidTypeArguments] | ||||
|     @property | ||||
|     def _selector(self) -> type[DaskSelector]: | ||||
|         return DaskSelector | ||||
|  | ||||
|  | ||||
| class DaskSelector(CompliantSelector["DaskLazyFrame", "dx.Series"], DaskExpr):  # type: ignore[misc] | ||||
|     _depth: int = 0 | ||||
|     _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012 | ||||
|     _function_name: str = "selector" | ||||
|  | ||||
|     def _to_expr(self) -> DaskExpr: | ||||
|         return DaskExpr( | ||||
|             self._call, | ||||
|             depth=self._depth, | ||||
|             function_name=self._function_name, | ||||
|             evaluate_output_names=self._evaluate_output_names, | ||||
|             alias_output_names=self._alias_output_names, | ||||
|             version=self._version, | ||||
|         ) | ||||
lib/python3.11/site-packages/narwhals/_dask/utils.py (139 lines, new file)
							| @ -0,0 +1,139 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import TYPE_CHECKING, Any | ||||
|  | ||||
| from narwhals._pandas_like.utils import select_columns_by_name | ||||
| from narwhals._utils import Implementation, Version, isinstance_or_issubclass | ||||
| from narwhals.dependencies import get_pyarrow | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from collections.abc import Mapping, Sequence | ||||
|  | ||||
|     import dask.dataframe as dd | ||||
|     import dask.dataframe.dask_expr as dx | ||||
|  | ||||
|     from narwhals._dask.dataframe import DaskLazyFrame, Incomplete | ||||
|     from narwhals._dask.expr import DaskExpr | ||||
|     from narwhals.dtypes import DType | ||||
|     from narwhals.typing import IntoDType | ||||
| else: | ||||
|     try: | ||||
|         import dask.dataframe.dask_expr as dx | ||||
|     except ModuleNotFoundError:  # pragma: no cover | ||||
|         import dask_expr as dx | ||||
|  | ||||
|  | ||||
| def maybe_evaluate_expr(df: DaskLazyFrame, obj: DaskExpr | object) -> dx.Series | object: | ||||
|     from narwhals._dask.expr import DaskExpr | ||||
|  | ||||
|     if isinstance(obj, DaskExpr): | ||||
|         results = obj._call(df) | ||||
|         assert len(results) == 1  # debug assertion  # noqa: S101 | ||||
|         return results[0] | ||||
|     return obj | ||||
|  | ||||
|  | ||||
| def evaluate_exprs(df: DaskLazyFrame, /, *exprs: DaskExpr) -> list[tuple[str, dx.Series]]: | ||||
|     native_results: list[tuple[str, dx.Series]] = [] | ||||
|     for expr in exprs: | ||||
|         native_series_list = expr(df) | ||||
|         aliases = expr._evaluate_aliases(df) | ||||
|         if len(aliases) != len(native_series_list):  # pragma: no cover | ||||
|             msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results" | ||||
|             raise AssertionError(msg) | ||||
|         native_results.extend(zip(aliases, native_series_list)) | ||||
|     return native_results | ||||
|  | ||||
|  | ||||
| def align_series_full_broadcast( | ||||
|     df: DaskLazyFrame, *series: dx.Series | object | ||||
| ) -> Sequence[dx.Series]: | ||||
|     return [ | ||||
|         s if isinstance(s, dx.Series) else df._native_frame.assign(_tmp=s)["_tmp"] | ||||
|         for s in series | ||||
|     ]  # pyright: ignore[reportReturnType] | ||||
|  | ||||
|  | ||||
| def add_row_index(frame: dd.DataFrame, name: str) -> dd.DataFrame: | ||||
|     original_cols = frame.columns | ||||
|     df: Incomplete = frame.assign(**{name: 1}) | ||||
|     return select_columns_by_name( | ||||
|         df.assign(**{name: df[name].cumsum(method="blelloch") - 1}), | ||||
|         [name, *original_cols], | ||||
|         Implementation.DASK, | ||||
|     ) | ||||
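| # For illustration: a column of ones is scanned with a cumulative sum | ||||
| # (Blelloch method, which parallelises across partitions) and decremented, | ||||
| # yielding a 0-based row index placed ahead of the original columns. | ||||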
|  | ||||
|  | ||||
| def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None: | ||||
|     if not dx.expr.are_co_aligned(lhs._expr, rhs._expr):  # pragma: no cover | ||||
|         # are_co_aligned is a method which cheaply checks if two Dask expressions | ||||
|         # have the same index, and therefore don't require index alignment. | ||||
|         # If someone only operates on a Dask DataFrame via expressions, then this | ||||
|         # should always be the case: expression outputs (by definition) all come from the | ||||
|         # same input dataframe, and Dask Series does not have any operations which | ||||
|         # change the index. Nonetheless, we perform this safety check anyway. | ||||
|  | ||||
|         # However, we still need to carefully vet which methods we support for Dask, to | ||||
|         # avoid issues where `are_co_aligned` doesn't do what we want it to do: | ||||
|         # https://github.com/dask/dask-expr/issues/1112. | ||||
|         msg = "Objects are not co-aligned, so this operation is not supported for Dask backend" | ||||
|         raise RuntimeError(msg) | ||||
|  | ||||
|  | ||||
| dtypes = Version.MAIN.dtypes | ||||
| dtypes_v1 = Version.V1.dtypes | ||||
| NW_TO_DASK_DTYPES: Mapping[type[DType], str] = { | ||||
|     dtypes.Float64: "float64", | ||||
|     dtypes.Float32: "float32", | ||||
|     dtypes.Boolean: "bool", | ||||
|     dtypes.Categorical: "category", | ||||
|     dtypes.Date: "date32[day][pyarrow]", | ||||
|     dtypes.Int8: "int8", | ||||
|     dtypes.Int16: "int16", | ||||
|     dtypes.Int32: "int32", | ||||
|     dtypes.Int64: "int64", | ||||
|     dtypes.UInt8: "uint8", | ||||
|     dtypes.UInt16: "uint16", | ||||
|     dtypes.UInt32: "uint32", | ||||
|     dtypes.UInt64: "uint64", | ||||
|     dtypes.Datetime: "datetime64[us]", | ||||
|     dtypes.Duration: "timedelta64[ns]", | ||||
|     dtypes_v1.Datetime: "datetime64[us]", | ||||
|     dtypes_v1.Duration: "timedelta64[ns]", | ||||
| } | ||||
| UNSUPPORTED_DTYPES = ( | ||||
|     dtypes.List, | ||||
|     dtypes.Struct, | ||||
|     dtypes.Array, | ||||
|     dtypes.Time, | ||||
|     dtypes.Binary, | ||||
| ) | ||||
|  | ||||
|  | ||||
| def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> Any: | ||||
|     dtypes = version.dtypes | ||||
|     base_type = dtype.base_type() | ||||
|     if dask_type := NW_TO_DASK_DTYPES.get(base_type): | ||||
|         return dask_type | ||||
|     if isinstance_or_issubclass(dtype, dtypes.String): | ||||
|         if Implementation.PANDAS._backend_version() >= (2, 0, 0): | ||||
|             return "string[pyarrow]" if get_pyarrow() else "string[python]" | ||||
|         return "object"  # pragma: no cover | ||||
|     if isinstance_or_issubclass(dtype, dtypes.Enum): | ||||
|         if version is Version.V1: | ||||
|             msg = "Converting to Enum is not supported in narwhals.stable.v1" | ||||
|             raise NotImplementedError(msg) | ||||
|         if isinstance(dtype, dtypes.Enum): | ||||
|             import pandas as pd | ||||
|  | ||||
|             # NOTE: `pandas-stubs.core.dtypes.dtypes.CategoricalDtype.categories` is too narrow | ||||
|             # Should be one of the `ListLike*` types | ||||
|             # https://github.com/pandas-dev/pandas-stubs/blob/8434bde95460b996323cc8c0fea7b0a8bb00ea26/pandas-stubs/_typing.pyi#L497-L505 | ||||
|             return pd.CategoricalDtype(dtype.categories, ordered=True)  # type: ignore[arg-type] | ||||
|         msg = "Can not cast / initialize Enum without categories present" | ||||
|         raise ValueError(msg) | ||||
|     if issubclass(base_type, UNSUPPORTED_DTYPES):  # pragma: no cover | ||||
|         msg = f"Converting to {base_type.__name__} dtype is not supported for Dask." | ||||
|         raise NotImplementedError(msg) | ||||
|     msg = f"Unknown dtype: {dtype}"  # pragma: no cover | ||||
|     raise AssertionError(msg) | ||||
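|  | ||||
| # Hypothetical usage: narwhals_to_native_dtype(nw.Int64, version) returns | ||||
| # "int64"; String resolves to "string[pyarrow]" on pandas >= 2.0 when pyarrow | ||||
| # is installed, and parametrised Datetime/Duration hit the base-type entries | ||||
| # in NW_TO_DASK_DTYPES above. | ||||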