2025-09-07 22:09:54 +02:00
parent e1b817252c
commit 2fc0d000b6
7796 changed files with 2159515 additions and 933 deletions


@@ -0,0 +1,103 @@
from __future__ import annotations
from narwhals._compliant.dataframe import (
CompliantDataFrame,
CompliantFrame,
CompliantLazyFrame,
EagerDataFrame,
)
from narwhals._compliant.expr import (
CompliantExpr,
DepthTrackingExpr,
EagerExpr,
LazyExpr,
LazyExprNamespace,
)
from narwhals._compliant.group_by import (
CompliantGroupBy,
DepthTrackingGroupBy,
EagerGroupBy,
)
from narwhals._compliant.namespace import (
CompliantNamespace,
DepthTrackingNamespace,
EagerNamespace,
LazyNamespace,
)
from narwhals._compliant.selectors import (
CompliantSelector,
CompliantSelectorNamespace,
EagerSelectorNamespace,
LazySelectorNamespace,
)
from narwhals._compliant.series import (
CompliantSeries,
EagerSeries,
EagerSeriesCatNamespace,
EagerSeriesDateTimeNamespace,
EagerSeriesHist,
EagerSeriesListNamespace,
EagerSeriesNamespace,
EagerSeriesStringNamespace,
EagerSeriesStructNamespace,
)
from narwhals._compliant.typing import (
CompliantExprT,
CompliantFrameT,
CompliantSeriesOrNativeExprT_co,
CompliantSeriesT,
EagerDataFrameT,
EagerSeriesT,
EvalNames,
EvalSeries,
NativeFrameT_co,
NativeSeriesT_co,
)
from narwhals._compliant.when_then import CompliantThen, CompliantWhen, EagerWhen
from narwhals._compliant.window import WindowInputs
__all__ = [
"CompliantDataFrame",
"CompliantExpr",
"CompliantExprT",
"CompliantFrame",
"CompliantFrameT",
"CompliantGroupBy",
"CompliantLazyFrame",
"CompliantNamespace",
"CompliantSelector",
"CompliantSelectorNamespace",
"CompliantSeries",
"CompliantSeriesOrNativeExprT_co",
"CompliantSeriesT",
"CompliantThen",
"CompliantWhen",
"DepthTrackingExpr",
"DepthTrackingGroupBy",
"DepthTrackingNamespace",
"EagerDataFrame",
"EagerDataFrameT",
"EagerExpr",
"EagerGroupBy",
"EagerNamespace",
"EagerSelectorNamespace",
"EagerSeries",
"EagerSeriesCatNamespace",
"EagerSeriesDateTimeNamespace",
"EagerSeriesHist",
"EagerSeriesListNamespace",
"EagerSeriesNamespace",
"EagerSeriesStringNamespace",
"EagerSeriesStructNamespace",
"EagerSeriesT",
"EagerWhen",
"EvalNames",
"EvalSeries",
"LazyExpr",
"LazyExprNamespace",
"LazyNamespace",
"LazySelectorNamespace",
"NativeFrameT_co",
"NativeSeriesT_co",
"WindowInputs",
]
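
A hedged illustration of how this re-export surface is meant to be consumed: backend code imports the protocol layer from `narwhals._compliant` rather than from the individual submodules. The import below is guarded with TYPE_CHECKING so the sketch stays runnable without importing the package at runtime; the names used are taken from the `__all__` list above.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Grouped the same way as __all__ above: generic Compliant* protocols,
    # plus the Eager*/Lazy* specialisations layered on top of them.
    from narwhals._compliant import CompliantExpr, EagerDataFrame, LazyNamespace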


@@ -0,0 +1,94 @@
"""`Expr` and `Series` namespace accessor protocols."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from narwhals._utils import CompliantT_co, _StoresCompliant
if TYPE_CHECKING:
from typing import Callable
from narwhals.typing import NonNestedLiteral, TimeUnit
__all__ = [
"CatNamespace",
"DateTimeNamespace",
"ListNamespace",
"NameNamespace",
"StringNamespace",
"StructNamespace",
]
class CatNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def get_categories(self) -> CompliantT_co: ...
class DateTimeNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def to_string(self, format: str) -> CompliantT_co: ...
def replace_time_zone(self, time_zone: str | None) -> CompliantT_co: ...
def convert_time_zone(self, time_zone: str) -> CompliantT_co: ...
def timestamp(self, time_unit: TimeUnit) -> CompliantT_co: ...
def date(self) -> CompliantT_co: ...
def year(self) -> CompliantT_co: ...
def month(self) -> CompliantT_co: ...
def day(self) -> CompliantT_co: ...
def hour(self) -> CompliantT_co: ...
def minute(self) -> CompliantT_co: ...
def second(self) -> CompliantT_co: ...
def millisecond(self) -> CompliantT_co: ...
def microsecond(self) -> CompliantT_co: ...
def nanosecond(self) -> CompliantT_co: ...
def ordinal_day(self) -> CompliantT_co: ...
def weekday(self) -> CompliantT_co: ...
def total_minutes(self) -> CompliantT_co: ...
def total_seconds(self) -> CompliantT_co: ...
def total_milliseconds(self) -> CompliantT_co: ...
def total_microseconds(self) -> CompliantT_co: ...
def total_nanoseconds(self) -> CompliantT_co: ...
def truncate(self, every: str) -> CompliantT_co: ...
def offset_by(self, by: str) -> CompliantT_co: ...
class ListNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def get(self, index: int) -> CompliantT_co: ...
def len(self) -> CompliantT_co: ...
def unique(self) -> CompliantT_co: ...
def contains(self, item: NonNestedLiteral) -> CompliantT_co: ...
class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def keep(self) -> CompliantT_co: ...
def map(self, function: Callable[[str], str]) -> CompliantT_co: ...
def prefix(self, prefix: str) -> CompliantT_co: ...
def suffix(self, suffix: str) -> CompliantT_co: ...
def to_lowercase(self) -> CompliantT_co: ...
def to_uppercase(self) -> CompliantT_co: ...
class StringNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def len_chars(self) -> CompliantT_co: ...
def replace(
self, pattern: str, value: str, *, literal: bool, n: int
) -> CompliantT_co: ...
def replace_all(
self, pattern: str, value: str, *, literal: bool
) -> CompliantT_co: ...
def strip_chars(self, characters: str | None) -> CompliantT_co: ...
def starts_with(self, prefix: str) -> CompliantT_co: ...
def ends_with(self, suffix: str) -> CompliantT_co: ...
def contains(self, pattern: str, *, literal: bool) -> CompliantT_co: ...
def slice(self, offset: int, length: int | None) -> CompliantT_co: ...
def split(self, by: str) -> CompliantT_co: ...
def to_datetime(self, format: str | None) -> CompliantT_co: ...
def to_date(self, format: str | None) -> CompliantT_co: ...
def to_lowercase(self) -> CompliantT_co: ...
def to_uppercase(self) -> CompliantT_co: ...
def zfill(self, width: int) -> CompliantT_co: ...
class StructNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def field(self, name: str) -> CompliantT_co: ...
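
These accessor namespaces are `Protocol` types, so nothing subclasses them at runtime; any object with a matching shape conforms. Below is a minimal sketch of a struct accessor for a toy series type, assuming `_StoresCompliant` only requires a `compliant` property; all `Toy*` names are hypothetical and exist purely for illustration.

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

@dataclass
class ToySeries:
    name: str
    values: list[Any]  # struct-like rows stored as plain dicts

class ToyStructNamespace:
    # Structurally compatible with StructNamespace[ToySeries]: it stores the
    # compliant series and exposes `field`, returning another ToySeries.
    def __init__(self, series: ToySeries) -> None:
        self._compliant_series = series

    @property
    def compliant(self) -> ToySeries:
        return self._compliant_series

    def field(self, name: str) -> ToySeries:
        # Extract one field from every struct-like row.
        return ToySeries(name, [row[name] for row in self.compliant.values])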


@@ -0,0 +1,213 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from typing_extensions import Self
from narwhals._compliant.any_namespace import (
CatNamespace,
DateTimeNamespace,
ListNamespace,
StringNamespace,
StructNamespace,
)
from narwhals._compliant.namespace import CompliantNamespace
from narwhals._utils import Version
from narwhals.typing import (
ClosedInterval,
FillNullStrategy,
IntoDType,
ModeKeepStrategy,
NonNestedLiteral,
NumericLiteral,
RankMethod,
TemporalLiteral,
)
__all__ = ["CompliantColumn"]
class CompliantColumn(Protocol):
"""Common parts of `Expr`, `Series`."""
_version: Version
def __add__(self, other: Any) -> Self: ...
def __and__(self, other: Any) -> Self: ...
def __eq__(self, other: object) -> Self: ... # type: ignore[override]
def __floordiv__(self, other: Any) -> Self: ...
def __ge__(self, other: Any) -> Self: ...
def __gt__(self, other: Any) -> Self: ...
def __invert__(self) -> Self: ...
def __le__(self, other: Any) -> Self: ...
def __lt__(self, other: Any) -> Self: ...
def __mod__(self, other: Any) -> Self: ...
def __mul__(self, other: Any) -> Self: ...
def __ne__(self, other: object) -> Self: ... # type: ignore[override]
def __or__(self, other: Any) -> Self: ...
def __pow__(self, other: Any) -> Self: ...
def __rfloordiv__(self, other: Any) -> Self: ...
def __rmod__(self, other: Any) -> Self: ...
def __rpow__(self, other: Any) -> Self: ...
def __rsub__(self, other: Any) -> Self: ...
def __rtruediv__(self, other: Any) -> Self: ...
def __sub__(self, other: Any) -> Self: ...
def __truediv__(self, other: Any) -> Self: ...
def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ...
def abs(self) -> Self: ...
def alias(self, name: str) -> Self: ...
def cast(self, dtype: IntoDType) -> Self: ...
def clip(
self,
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
) -> Self: ...
def cum_count(self, *, reverse: bool) -> Self: ...
def cum_max(self, *, reverse: bool) -> Self: ...
def cum_min(self, *, reverse: bool) -> Self: ...
def cum_prod(self, *, reverse: bool) -> Self: ...
def cum_sum(self, *, reverse: bool) -> Self: ...
def diff(self) -> Self: ...
def drop_nulls(self) -> Self: ...
def ewm_mean(
self,
*,
com: float | None,
span: float | None,
half_life: float | None,
alpha: float | None,
adjust: bool,
min_samples: int,
ignore_nulls: bool,
) -> Self: ...
def exp(self) -> Self: ...
def sqrt(self) -> Self: ...
def fill_nan(self, value: float | None) -> Self: ...
def fill_null(
self,
value: Self | NonNestedLiteral,
strategy: FillNullStrategy | None,
limit: int | None,
) -> Self: ...
def is_between(
self, lower_bound: Self, upper_bound: Self, closed: ClosedInterval
) -> Self:
if closed == "left":
return (self >= lower_bound) & (self < upper_bound)
if closed == "right":
return (self > lower_bound) & (self <= upper_bound)
if closed == "none":
return (self > lower_bound) & (self < upper_bound)
return (self >= lower_bound) & (self <= upper_bound)
def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
from decimal import Decimal
other_abs: Self | NumericLiteral
other_is_nan: Self | bool
other_is_inf: Self | bool
other_is_not_inf: Self | bool
if isinstance(other, (float, int, Decimal)):
from math import isinf, isnan
# NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447
other_abs = other.__abs__()
other_is_nan = isnan(other)
other_is_inf = isinf(other)
# Define the other_is_not_inf variable to prevent triggering the following warning:
# > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be
# > removed in Python 3.16.
other_is_not_inf = not other_is_inf
else:
other_abs, other_is_nan = other.abs(), other.is_nan()
other_is_not_inf = other.is_finite() | other_is_nan
other_is_inf = ~other_is_not_inf
rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol
tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None)
self_is_nan = self.is_nan()
self_is_not_inf = self.is_finite() | self_is_nan
# Values are close if abs_diff <= tolerance, and both finite
is_close = (
((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf
)
# Handle infinity cases: infinities are close/equal if they have the same sign
self_sign, other_sign = self > 0, other > 0
is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign)
# Handle nan cases:
# * If any value is NaN, then False (via `& ~either_nan`)
# * However, if `nans_equals = True` and if _both_ values are NaN, then True
either_nan = self_is_nan | other_is_nan
result = (is_close | is_same_inf) & ~either_nan
if nans_equal:
both_nan = self_is_nan & other_is_nan
result = result | both_nan
return result
def is_duplicated(self) -> Self:
return ~self.is_unique()
def is_finite(self) -> Self: ...
def is_first_distinct(self) -> Self: ...
def is_in(self, other: Any) -> Self: ...
def is_last_distinct(self) -> Self: ...
def is_nan(self) -> Self: ...
def is_null(self) -> Self: ...
def is_unique(self) -> Self: ...
def log(self, base: float) -> Self: ...
def mode(self, *, keep: ModeKeepStrategy) -> Self: ...
def rank(self, method: RankMethod, *, descending: bool) -> Self: ...
def replace_strict(
self,
old: Sequence[Any] | Mapping[Any, Any],
new: Sequence[Any],
*,
return_dtype: IntoDType | None,
) -> Self: ...
def rolling_mean(
self, window_size: int, *, min_samples: int, center: bool
) -> Self: ...
def rolling_std(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self: ...
def rolling_sum(
self, window_size: int, *, min_samples: int, center: bool
) -> Self: ...
def rolling_var(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self: ...
def round(self, decimals: int) -> Self: ...
def shift(self, n: int) -> Self: ...
def unique(self) -> Self: ...
@property
def str(self) -> StringNamespace[Self]: ...
@property
def dt(self) -> DateTimeNamespace[Self]: ...
@property
def cat(self) -> CatNamespace[Self]: ...
@property
def list(self) -> ListNamespace[Self]: ...
@property
def struct(self) -> StructNamespace[Self]: ...
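
The `is_close` default above reduces to comparing `|self - other|` against `max(abs_tol, rel_tol * max(|self|, |other|))`, with special handling for infinities and NaNs. A scalar sketch of the same rule in plain Python follows; it is not part of the protocol, purely illustrative.

import math

def is_close_scalar(a: float, b: float, *, abs_tol: float, rel_tol: float, nans_equal: bool) -> bool:
    # NaNs: only equal to each other, and only when nans_equal is set.
    if math.isnan(a) or math.isnan(b):
        return nans_equal and math.isnan(a) and math.isnan(b)
    # Infinities: close only when both are infinite with the same sign.
    if math.isinf(a) or math.isinf(b):
        return math.isinf(a) and math.isinf(b) and (a > 0) == (b > 0)
    # Finite values: combined absolute/relative tolerance, as in the column version above.
    tolerance = max(abs_tol, rel_tol * max(abs(a), abs(b)))
    return abs(a - b) <= tolerance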


@@ -0,0 +1,426 @@
from __future__ import annotations
from collections.abc import Iterator, Mapping, Sequence, Sized
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload
from narwhals._compliant.typing import (
CompliantDataFrameAny,
CompliantExprT_contra,
CompliantLazyFrameAny,
CompliantSeriesT,
EagerExprT,
EagerSeriesT,
NativeDataFrameT,
NativeLazyFrameT,
NativeSeriesT,
)
from narwhals._translate import (
ArrowConvertible,
DictConvertible,
FromNative,
NumpyConvertible,
ToNarwhals,
ToNarwhalsT_co,
)
from narwhals._typing_compat import assert_never
from narwhals._utils import (
ValidateBackendVersion,
Version,
_StoresNative,
check_columns_exist,
is_compliant_series,
is_index_selector,
is_range,
is_sequence_like,
is_sized_multi_index_selector,
is_slice_index,
is_slice_none,
)
if TYPE_CHECKING:
from io import BytesIO
from pathlib import Path
from types import ModuleType
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import Self, TypeAlias
from narwhals._compliant.group_by import CompliantGroupBy, DataFrameGroupBy
from narwhals._compliant.namespace import EagerNamespace
from narwhals._spark_like.utils import SparkSession
from narwhals._translate import IntoArrowTable
from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
from narwhals._utils import Implementation, _LimitedContext
from narwhals.dataframe import DataFrame
from narwhals.dtypes import DType
from narwhals.exceptions import ColumnNotFoundError
from narwhals.typing import (
AsofJoinStrategy,
IntoSchema,
JoinStrategy,
LazyUniqueKeepStrategy,
MultiColSelector,
MultiIndexSelector,
PivotAgg,
SingleIndexSelector,
SizedMultiIndexSelector,
SizedMultiNameSelector,
SizeUnit,
UniqueKeepStrategy,
_2DArray,
_SliceIndex,
_SliceName,
)
Incomplete: TypeAlias = Any
__all__ = ["CompliantDataFrame", "CompliantFrame", "CompliantLazyFrame", "EagerDataFrame"]
T = TypeVar("T")
_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]" # noqa: PYI047
_NativeFrameT = TypeVar("_NativeFrameT")
class CompliantFrame(
_StoresNative[_NativeFrameT],
FromNative[_NativeFrameT],
ToNarwhals[ToNarwhalsT_co],
Protocol[CompliantExprT_contra, _NativeFrameT, ToNarwhalsT_co],
):
"""Common parts of `DataFrame`, `LazyFrame`."""
_native_frame: _NativeFrameT
_implementation: Implementation
_version: Version
def __native_namespace__(self) -> ModuleType: ...
def __narwhals_namespace__(self) -> Any: ...
def _with_version(self, version: Version) -> Self: ...
@classmethod
def from_native(cls, data: _NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
@property
def columns(self) -> Sequence[str]: ...
@property
def native(self) -> _NativeFrameT:
return self._native_frame
@property
def schema(self) -> Mapping[str, DType]: ...
def collect_schema(self) -> Mapping[str, DType]: ...
def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
def explode(self, columns: Sequence[str]) -> Self: ...
def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
def group_by(
self,
keys: Sequence[str] | Sequence[CompliantExprT_contra],
*,
drop_null_keys: bool,
) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
def head(self, n: int) -> Self: ...
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self: ...
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self: ...
def rename(self, mapping: Mapping[str, str]) -> Self: ...
def select(self, *exprs: CompliantExprT_contra) -> Self: ...
def simple_select(self, *column_names: str) -> Self:
"""`select` where all args are column names."""
...
def sort(
self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
) -> Self: ...
def tail(self, n: int) -> Self: ...
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self: ...
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self: ...
def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self: ...
class CompliantDataFrame(
NumpyConvertible["_2DArray", "_2DArray"],
DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
ArrowConvertible["pa.Table", "IntoArrowTable"],
Sized,
CompliantFrame[CompliantExprT_contra, NativeDataFrameT, ToNarwhalsT_co],
Protocol[CompliantSeriesT, CompliantExprT_contra, NativeDataFrameT, ToNarwhalsT_co],
):
def __narwhals_dataframe__(self) -> Self: ...
@classmethod
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_dict(
cls,
data: Mapping[str, Any],
/,
*,
context: _LimitedContext,
schema: IntoSchema | None,
) -> Self: ...
@classmethod
def from_numpy(
cls,
data: _2DArray,
/,
*,
context: _LimitedContext,
schema: IntoSchema | Sequence[str] | None,
) -> Self: ...
def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
def __getitem__(
self,
item: tuple[
SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
MultiColSelector[CompliantSeriesT],
],
) -> Self: ...
@property
def shape(self) -> tuple[int, int]: ...
def clone(self) -> Self: ...
def estimated_size(self, unit: SizeUnit) -> int | float: ...
def gather_every(self, n: int, offset: int) -> Self: ...
def get_column(self, name: str) -> CompliantSeriesT: ...
def group_by(
self,
keys: Sequence[str] | Sequence[CompliantExprT_contra],
*,
drop_null_keys: bool,
) -> DataFrameGroupBy[Self, Any]: ...
def item(self, row: int | None, column: int | str | None) -> Any: ...
def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
def iter_rows(
self, *, named: bool, buffer_size: int
) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
def is_unique(self) -> CompliantSeriesT: ...
def lazy(
self, backend: _LazyAllowedImpl | None, *, session: SparkSession | None
) -> CompliantLazyFrameAny: ...
def pivot(
self,
on: Sequence[str],
*,
index: Sequence[str] | None,
values: Sequence[str] | None,
aggregate_function: PivotAgg | None,
sort_columns: bool,
separator: str,
) -> Self: ...
def row(self, index: int) -> tuple[Any, ...]: ...
def rows(
self, *, named: bool
) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
def sample(
self,
n: int | None,
*,
fraction: float | None,
with_replacement: bool,
seed: int | None,
) -> Self: ...
def to_arrow(self) -> pa.Table: ...
def to_pandas(self) -> pd.DataFrame: ...
def to_polars(self) -> pl.DataFrame: ...
@overload
def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
@overload
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool
) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
def unique(
self,
subset: Sequence[str] | None,
*,
keep: UniqueKeepStrategy,
maintain_order: bool | None = None,
) -> Self: ...
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: ...
@overload
def write_csv(self, file: None) -> str: ...
@overload
def write_csv(self, file: str | Path | BytesIO) -> None: ...
def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
def write_parquet(self, file: str | Path | BytesIO) -> None: ...
class CompliantLazyFrame(
CompliantFrame[CompliantExprT_contra, NativeLazyFrameT, ToNarwhalsT_co],
Protocol[CompliantExprT_contra, NativeLazyFrameT, ToNarwhalsT_co],
):
def __narwhals_lazyframe__(self) -> Self: ...
    # `LazySelectorNamespace._iter_columns` depends on this
def _iter_columns(self) -> Iterator[Any]: ...
def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
"""`select` where all args are aggregations or literals.
(so, no broadcasting is necessary).
"""
...
def collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny: ...
def sink_parquet(self, file: str | Path | BytesIO) -> None: ...
class EagerDataFrame(
CompliantDataFrame[
EagerSeriesT, EagerExprT, NativeDataFrameT, "DataFrame[NativeDataFrameT]"
],
CompliantLazyFrame[EagerExprT, "Incomplete", "DataFrame[NativeDataFrameT]"],
ValidateBackendVersion,
Protocol[EagerSeriesT, EagerExprT, NativeDataFrameT, NativeSeriesT],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
def __narwhals_namespace__(
self,
) -> EagerNamespace[
Self, EagerSeriesT, EagerExprT, NativeDataFrameT, NativeSeriesT
]: ...
def to_narwhals(self) -> DataFrame[NativeDataFrameT]:
return self._version.dataframe(self, level="full")
def aggregate(self, *exprs: EagerExprT) -> Self:
# NOTE: Ignore intermittent [False Negative]
# Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "exprs" of type "EagerExprT@EagerDataFrame" in function "select"
# Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame"
return self.select(*exprs) # pyright: ignore[reportArgumentType]
def _with_native(
self, df: NativeDataFrameT, *, validate_column_names: bool = True
) -> Self: ...
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
return check_columns_exist(subset, available=self.columns)
def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
"""Evaluate `expr` and ensure it has a **single** output."""
result: Sequence[EagerSeriesT] = expr(self)
assert len(result) == 1 # debug assertion # noqa: S101
return result[0]
def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
# NOTE: Ignore intermittent [False Negative]
# Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "expr" of type "EagerExprT@EagerDataFrame" in function "_evaluate_into_expr"
# Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame"
return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType]
def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
"""Return list of raw columns.
For eager backends we alias operations at each step.
As a safety precaution, here we can check that the expected result names match those
we were expecting from the various `evaluate_output_names` / `alias_output_names` calls.
Note that for PySpark / DuckDB, we are less free to liberally set aliases whenever we want.
"""
aliases = expr._evaluate_aliases(self)
result = expr(self)
if list(aliases) != (
result_aliases := [s.name for s in result]
): # pragma: no cover
msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
raise AssertionError(msg)
return result
def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
"""Extract native Series, broadcasting to `len(self)` if necessary."""
...
@staticmethod
def _numpy_column_names(
data: _2DArray, columns: Sequence[str] | None, /
) -> list[str]:
return list(columns or (f"column_{x}" for x in range(data.shape[1])))
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
def _select_multi_index(
self, columns: SizedMultiIndexSelector[NativeSeriesT]
) -> Self: ...
def _select_multi_name(
self, columns: SizedMultiNameSelector[NativeSeriesT]
) -> Self: ...
def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
def _select_slice_name(self, columns: _SliceName) -> Self: ...
def __getitem__( # noqa: C901, PLR0912
self,
item: tuple[
SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
MultiColSelector[EagerSeriesT],
],
) -> Self:
rows, columns = item
compliant = self
if not is_slice_none(columns):
if isinstance(columns, Sized) and len(columns) == 0:
return compliant.select()
if is_index_selector(columns):
if is_slice_index(columns) or is_range(columns):
compliant = compliant._select_slice_index(columns)
elif is_compliant_series(columns):
compliant = self._select_multi_index(columns.native)
else:
compliant = compliant._select_multi_index(columns)
elif isinstance(columns, slice):
compliant = compliant._select_slice_name(columns)
elif is_compliant_series(columns):
compliant = self._select_multi_name(columns.native)
elif is_sequence_like(columns):
compliant = self._select_multi_name(columns)
else:
assert_never(columns)
if not is_slice_none(rows):
if isinstance(rows, int):
compliant = compliant._gather([rows])
elif isinstance(rows, (slice, range)):
compliant = compliant._gather_slice(rows)
elif is_compliant_series(rows):
compliant = compliant._gather(rows.native)
elif is_sized_multi_index_selector(rows):
compliant = compliant._gather(rows)
else:
assert_never(rows)
return compliant
def sink_parquet(self, file: str | Path | BytesIO) -> None:
return self.write_parquet(file)
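
`EagerDataFrame.__getitem__` always receives a two-element `(rows, columns)` tuple and routes each half to the narrow `_gather*` / `_select*` hooks declared above. A rough routing sketch for a hypothetical frame `df` implementing this protocol:

# df[slice(None), ["a", "b"]]   -> columns via _select_multi_name(["a", "b"])
# df[slice(0, 10), slice(None)] -> rows via _gather_slice(slice(0, 10))
# df[3, range(2)]               -> columns via _select_slice_index(range(2)), then rows via _gather([3])
# df[slice(None), []]           -> empty column selection short-circuits to df.select()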

File diff suppressed because it is too large.


@@ -0,0 +1,180 @@
from __future__ import annotations
import re
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol, TypeVar
from narwhals._compliant.typing import (
CompliantDataFrameT,
CompliantDataFrameT_co,
CompliantExprT_contra,
CompliantFrameT,
CompliantFrameT_co,
DepthTrackingExprAny,
DepthTrackingExprT_contra,
EagerExprT_contra,
ImplExprT_contra,
NarwhalsAggregation,
)
from narwhals._utils import is_sequence_of, zip_strict
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from narwhals._compliant.expr import ImplExpr
__all__ = ["CompliantGroupBy", "DepthTrackingGroupBy", "EagerGroupBy"]
NativeAggregationT_co = TypeVar(
"NativeAggregationT_co", bound="str | Callable[..., Any]", covariant=True
)
_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)")
def _evaluate_aliases(
frame: CompliantFrameT, exprs: Iterable[ImplExpr[CompliantFrameT, Any]], /
) -> list[str]:
it = (expr._evaluate_aliases(frame) for expr in exprs)
return list(chain.from_iterable(it))
class CompliantGroupBy(Protocol[CompliantFrameT_co, CompliantExprT_contra]):
_compliant_frame: Any
@property
def compliant(self) -> CompliantFrameT_co:
return self._compliant_frame # type: ignore[no-any-return]
def __init__(
self,
compliant_frame: CompliantFrameT_co,
keys: Sequence[CompliantExprT_contra] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None: ...
def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ...
class DataFrameGroupBy(
CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra],
Protocol[CompliantDataFrameT_co, CompliantExprT_contra],
):
def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ...
class ParseKeysGroupBy(
CompliantGroupBy[CompliantFrameT, ImplExprT_contra],
Protocol[CompliantFrameT, ImplExprT_contra],
):
def _parse_keys(
self,
compliant_frame: CompliantFrameT,
keys: Sequence[ImplExprT_contra] | Sequence[str],
) -> tuple[CompliantFrameT, list[str], list[str]]:
if is_sequence_of(keys, str):
keys_str = list(keys)
return compliant_frame, keys_str, keys_str.copy()
return self._parse_expr_keys(compliant_frame, keys=keys)
@staticmethod
def _parse_expr_keys(
compliant_frame: CompliantFrameT, keys: Sequence[ImplExprT_contra]
) -> tuple[CompliantFrameT, list[str], list[str]]:
"""Parses key expressions to set up `.agg` operation with correct information.
Since keys are expressions, it's possible to alias any such key to match
other dataframe column names.
In order to match polars behavior and not overwrite columns when evaluating keys:
- We evaluate what the output key names should be, in order to remap temporary column
names to the expected ones, and to exclude those from unnamed expressions in
`.agg(...)` context (see https://github.com/narwhals-dev/narwhals/pull/2325#issuecomment-2800004520)
- Create temporary names for evaluated key expressions that are guaranteed to have
no overlap with any existing column name.
- Add these temporary columns to the compliant dataframe.
"""
tmp_name_length = max(len(str(c)) for c in compliant_frame.columns) + 1
def _temporary_name(key: str) -> str:
# 5 is the length of `__tmp`
key_str = str(key) # pandas allows non-string column names :sob:
return f"_{key_str}_tmp{'_' * (tmp_name_length - len(key_str) - 5)}"
keys_aliases = [expr._evaluate_aliases(compliant_frame) for expr in keys]
safe_keys = [
# multi-output expression cannot have duplicate names, hence it's safe to suffix
key.name.map(_temporary_name)
if (metadata := key._metadata) and metadata.expansion_kind.is_multi_output()
# otherwise it's single named and we can use Expr.alias
else key.alias(_temporary_name(new_names[0]))
for key, new_names in zip_strict(keys, keys_aliases)
]
return (
compliant_frame.with_columns(*safe_keys),
_evaluate_aliases(compliant_frame, safe_keys),
list(chain.from_iterable(keys_aliases)),
)
class DepthTrackingGroupBy(
ParseKeysGroupBy[CompliantFrameT, DepthTrackingExprT_contra],
Protocol[CompliantFrameT, DepthTrackingExprT_contra, NativeAggregationT_co],
):
"""`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`."""
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]]
"""Mapping from `narwhals` to native representation.
Note:
- `Dask` *may* return a `Callable` instead of a `str` referring to one.
"""
def _ensure_all_simple(self, exprs: Sequence[DepthTrackingExprT_contra]) -> None:
for expr in exprs:
if not self._is_simple(expr):
name = self.compliant._implementation.name.lower()
msg = (
f"Non-trivial complex aggregation found.\n\n"
f"Hint: you were probably trying to apply a non-elementary aggregation with a"
f"{name!r} table.\n"
"Please rewrite your query such that group-by aggregations "
"are elementary. For example, instead of:\n\n"
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
"use:\n\n"
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
)
raise ValueError(msg)
@classmethod
def _is_simple(cls, expr: DepthTrackingExprAny, /) -> bool:
"""Return `True` is we can efficiently use `expr` in a native `group_by` context."""
return expr._is_elementary() and cls._leaf_name(expr) in cls._REMAP_AGGS
@classmethod
def _remap_expr_name(
cls, name: NarwhalsAggregation | Any, /
) -> NativeAggregationT_co:
"""Replace `name`, with some native representation.
Arguments:
name: Name of a `nw.Expr` aggregation method.
"""
return cls._REMAP_AGGS.get(name, name)
@classmethod
def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any:
"""Return the last function name in the chain defined by `expr`."""
return _RE_LEAF_NAME.sub("", expr._function_name)
class EagerGroupBy(
DepthTrackingGroupBy[CompliantDataFrameT, EagerExprT_contra, NativeAggregationT_co],
DataFrameGroupBy[CompliantDataFrameT, EagerExprT_contra],
Protocol[CompliantDataFrameT, EagerExprT_contra, NativeAggregationT_co],
): ...
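
The temporary-name scheme in `_parse_expr_keys` pads each evaluated key alias so the result is strictly longer than every existing column name, which is what guarantees no collision before the rename back to the intended output names. A standalone sketch of that padding rule (the helper name here is hypothetical):

def temporary_name_demo(existing_columns: list[str], key: str) -> str:
    # Mirrors `_temporary_name` above: "_" + key + "_tmp", padded with underscores
    # until the result is longer than the longest existing column name.
    tmp_name_length = max(len(str(c)) for c in existing_columns) + 1
    return f"_{key}_tmp{'_' * (tmp_name_length - len(key) - 5)}"

# temporary_name_demo(["idx", "count"], "a") == "_a_tmp"  (length 6, longer than "count")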


@@ -0,0 +1,238 @@
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, Protocol, overload
from narwhals._compliant.typing import (
CompliantExprT,
CompliantFrameT,
CompliantLazyFrameT,
DepthTrackingExprT,
EagerDataFrameT,
EagerExprT,
EagerSeriesT,
LazyExprT,
NativeFrameT,
NativeFrameT_co,
NativeSeriesT,
)
from narwhals._expression_parsing import is_expr, is_series
from narwhals._utils import (
exclude_column_names,
get_column_names,
passthrough_column_names,
)
from narwhals.dependencies import is_numpy_array, is_numpy_array_2d
if TYPE_CHECKING:
from collections.abc import Container, Iterable, Sequence
from typing_extensions import TypeAlias
from narwhals._compliant.selectors import CompliantSelectorNamespace
from narwhals._compliant.when_then import CompliantWhen, EagerWhen
from narwhals._utils import Implementation, Version
from narwhals.expr import Expr
from narwhals.series import Series
from narwhals.typing import (
ConcatMethod,
Into1DArray,
IntoDType,
IntoSchema,
NonNestedLiteral,
_1DArray,
_2DArray,
)
Incomplete: TypeAlias = Any
__all__ = [
"CompliantNamespace",
"DepthTrackingNamespace",
"EagerNamespace",
"LazyNamespace",
]
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
# NOTE: `narwhals`
_implementation: Implementation
_version: Version
@property
def _expr(self) -> type[CompliantExprT]: ...
def parse_into_expr(
self, data: Expr | NonNestedLiteral | Any, /, *, str_as_lit: bool
) -> CompliantExprT | NonNestedLiteral:
if is_expr(data):
expr = data._to_compliant_expr(self)
assert isinstance(expr, self._expr) # noqa: S101
return expr
if isinstance(data, str) and not str_as_lit:
return self.col(data)
return data
# NOTE: `polars`
def all(self) -> CompliantExprT:
return self._expr.from_column_names(get_column_names, context=self)
def col(self, *column_names: str) -> CompliantExprT:
return self._expr.from_column_names(
passthrough_column_names(column_names), context=self
)
def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names), context=self
)
def nth(self, *column_indices: int) -> CompliantExprT:
return self._expr.from_column_indices(*column_indices, context=self)
def len(self) -> CompliantExprT: ...
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> CompliantExprT: ...
def all_horizontal(
self, *exprs: CompliantExprT, ignore_nulls: bool
) -> CompliantExprT: ...
def any_horizontal(
self, *exprs: CompliantExprT, ignore_nulls: bool
) -> CompliantExprT: ...
def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def concat(
self, items: Iterable[CompliantFrameT], *, how: ConcatMethod
) -> CompliantFrameT: ...
def when(
self, predicate: CompliantExprT
) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ...
def concat_str(
self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool
) -> CompliantExprT: ...
@property
def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
def coalesce(self, *exprs: CompliantExprT) -> CompliantExprT: ...
class DepthTrackingNamespace(
CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
Protocol[CompliantFrameT, DepthTrackingExprT],
):
def all(self) -> DepthTrackingExprT:
return self._expr.from_column_names(
get_column_names, function_name="all", context=self
)
def col(self, *column_names: str) -> DepthTrackingExprT:
return self._expr.from_column_names(
passthrough_column_names(column_names), function_name="col", context=self
)
def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names),
function_name="exclude",
context=self,
)
class LazyNamespace(
CompliantNamespace[CompliantLazyFrameT, LazyExprT],
Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT_co],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@property
def _lazyframe(self) -> type[CompliantLazyFrameT]: ...
def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:
if self._lazyframe._is_native(data):
return self._lazyframe.from_native(data, context=self)
msg = f"Unsupported type: {type(data).__name__!r}" # pragma: no cover
raise TypeError(msg)
class EagerNamespace(
DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@property
def _dataframe(self) -> type[EagerDataFrameT]: ...
@property
def _series(self) -> type[EagerSeriesT]: ...
def when(
self, predicate: EagerExprT
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...
@overload
def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
@overload
def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ...
def from_native(
self, data: NativeFrameT | NativeSeriesT | Any, /
) -> EagerDataFrameT | EagerSeriesT:
if self._dataframe._is_native(data):
return self._dataframe.from_native(data, context=self)
if self._series._is_native(data):
return self._series.from_native(data, context=self)
msg = f"Unsupported type: {type(data).__name__!r}"
raise TypeError(msg)
def parse_into_expr(
self,
data: Expr | Series[NativeSeriesT] | _1DArray | NonNestedLiteral,
/,
*,
str_as_lit: bool,
) -> EagerExprT | NonNestedLiteral:
if not (is_series(data) or is_numpy_array(data)):
return super().parse_into_expr(data, str_as_lit=str_as_lit)
return self._expr._from_series(
data._compliant_series
if is_series(data)
else self._series.from_numpy(data, context=self)
)
@overload
def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT: ...
@overload
def from_numpy(
self, data: _2DArray, /, schema: IntoSchema | Sequence[str] | None
) -> EagerDataFrameT: ...
def from_numpy(
self,
data: Into1DArray | _2DArray,
/,
schema: IntoSchema | Sequence[str] | None = None,
) -> EagerDataFrameT | EagerSeriesT:
if is_numpy_array_2d(data):
return self._dataframe.from_numpy(data, schema=schema, context=self)
return self._series.from_numpy(data, context=self)
def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
def _concat_horizontal(
self, dfs: Sequence[NativeFrameT | Any], /
) -> NativeFrameT: ...
def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
def concat(
self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod
) -> EagerDataFrameT:
dfs = [item.native for item in items]
if how == "horizontal":
native = self._concat_horizontal(dfs)
elif how == "vertical":
native = self._concat_vertical(dfs)
elif how == "diagonal":
native = self._concat_diagonal(dfs)
else: # pragma: no cover
raise NotImplementedError
return self._dataframe.from_native(native, context=self)
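
`parse_into_expr` above is the single funnel from user input to a compliant expression: `Expr` objects are lowered through `_to_compliant_expr`, bare strings become `col(...)` unless `str_as_lit=True`, and everything else passes through as a literal (the eager override additionally accepts `Series` and 1-D numpy arrays). A behaviour sketch for a hypothetical namespace `ns`:

# ns.parse_into_expr(expr, str_as_lit=False)   -> expr._to_compliant_expr(ns)
# ns.parse_into_expr("a", str_as_lit=False)    -> ns.col("a")
# ns.parse_into_expr("a", str_as_lit=True)     -> "a"        (kept as a literal)
# ns.parse_into_expr(3.14, str_as_lit=False)   -> 3.14       (non-nested literal)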


@@ -0,0 +1,318 @@
"""Almost entirely complete, generic `selectors` implementation."""
from __future__ import annotations
import re
from functools import partial
from typing import TYPE_CHECKING, Protocol, TypeVar, overload
from narwhals._compliant.expr import CompliantExpr
from narwhals._utils import (
_parse_time_unit_and_time_zone,
dtype_matches_time_unit_and_time_zone,
get_column_names,
is_compliant_dataframe,
zip_strict,
)
if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Iterator, Sequence
from datetime import timezone
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.expr import NativeExpr
from narwhals._compliant.typing import (
CompliantDataFrameAny,
CompliantExprAny,
CompliantFrameAny,
CompliantLazyFrameAny,
CompliantSeriesAny,
CompliantSeriesOrNativeExprAny,
EvalNames,
EvalSeries,
)
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.typing import TimeUnit
__all__ = [
"CompliantSelector",
"CompliantSelectorNamespace",
"EagerSelectorNamespace",
"LazySelectorNamespace",
]
SeriesOrExprT = TypeVar("SeriesOrExprT", bound="CompliantSeriesOrNativeExprAny")
SeriesT = TypeVar("SeriesT", bound="CompliantSeriesAny")
ExprT = TypeVar("ExprT", bound="NativeExpr")
FrameT = TypeVar("FrameT", bound="CompliantFrameAny")
DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrameAny")
LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrameAny")
SelectorOrExpr: TypeAlias = (
"CompliantSelector[FrameT, SeriesOrExprT] | CompliantExpr[FrameT, SeriesOrExprT]"
)
class CompliantSelectorNamespace(Protocol[FrameT, SeriesOrExprT]):
# NOTE: `narwhals`
_implementation: Implementation
_version: Version
@property
def _selector(self) -> type[CompliantSelector[FrameT, SeriesOrExprT]]: ...
@classmethod
def from_namespace(cls, context: _LimitedContext, /) -> Self:
obj = cls.__new__(cls)
obj._implementation = context._implementation
obj._version = context._version
return obj
def _iter_columns(self, df: FrameT, /) -> Iterator[SeriesOrExprT]: ...
def _iter_schema(self, df: FrameT, /) -> Iterator[tuple[str, DType]]: ...
def _iter_columns_dtypes(
self, df: FrameT, /
) -> Iterator[tuple[SeriesOrExprT, DType]]: ...
def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]:
yield from zip_strict(self._iter_columns(df), df.columns)
def _is_dtype(
self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], /
) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [
ser for ser, tp in self._iter_columns_dtypes(df) if isinstance(tp, dtype)
]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if isinstance(tp, dtype)]
return self._selector.from_callables(series, names, context=self)
# NOTE: `polars`
def by_dtype(
self, dtypes: Collection[DType | type[DType]]
) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp in dtypes]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if tp in dtypes]
return self._selector.from_callables(series, names, context=self)
def matches(self, pattern: str) -> CompliantSelector[FrameT, SeriesOrExprT]:
p = re.compile(pattern)
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
if (
is_compliant_dataframe(df)
and not self._implementation.is_duckdb()
and not self._implementation.is_ibis()
):
return [df.get_column(col) for col in df.columns if p.search(col)]
return [ser for ser, name in self._iter_columns_names(df) if p.search(name)]
def names(df: FrameT) -> Sequence[str]:
return [col for col in df.columns if p.search(col)]
return self._selector.from_callables(series, names, context=self)
def numeric(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp.is_numeric()]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if tp.is_numeric()]
return self._selector.from_callables(series, names, context=self)
def categorical(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.Categorical)
def string(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.String)
def boolean(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.Boolean)
def all(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return list(self._iter_columns(df))
return self._selector.from_callables(series, get_column_names, context=self)
def datetime(
self,
time_unit: TimeUnit | Iterable[TimeUnit] | None,
time_zone: str | timezone | Iterable[str | timezone | None] | None,
) -> CompliantSelector[FrameT, SeriesOrExprT]:
time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
matches = partial(
dtype_matches_time_unit_and_time_zone,
dtypes=self._version.dtypes,
time_units=time_units,
time_zones=time_zones,
)
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if matches(tp)]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if matches(tp)]
return self._selector.from_callables(series, names, context=self)
class EagerSelectorNamespace(
CompliantSelectorNamespace[DataFrameT, SeriesT], Protocol[DataFrameT, SeriesT]
):
def _iter_schema(self, df: DataFrameT, /) -> Iterator[tuple[str, DType]]:
for ser in self._iter_columns(df):
yield ser.name, ser.dtype
def _iter_columns(self, df: DataFrameT, /) -> Iterator[SeriesT]:
yield from df.iter_columns()
def _iter_columns_dtypes(self, df: DataFrameT, /) -> Iterator[tuple[SeriesT, DType]]:
for ser in self._iter_columns(df):
yield ser, ser.dtype
class LazySelectorNamespace(
CompliantSelectorNamespace[LazyFrameT, ExprT], Protocol[LazyFrameT, ExprT]
):
def _iter_schema(self, df: LazyFrameT) -> Iterator[tuple[str, DType]]:
yield from df.schema.items()
def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]:
yield from df._iter_columns()
def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]:
yield from zip_strict(self._iter_columns(df), df.schema.values())
class CompliantSelector(
CompliantExpr[FrameT, SeriesOrExprT], Protocol[FrameT, SeriesOrExprT]
):
_call: EvalSeries[FrameT, SeriesOrExprT]
_function_name: str
_implementation: Implementation
_version: Version
@classmethod
def from_callables(
cls,
call: EvalSeries[FrameT, SeriesOrExprT],
evaluate_output_names: EvalNames[FrameT],
*,
context: _LimitedContext,
) -> Self:
obj = cls.__new__(cls)
obj._call = call
obj._evaluate_output_names = evaluate_output_names
obj._alias_output_names = None
obj._implementation = context._implementation
obj._version = context._version
return obj
@property
def selectors(self) -> CompliantSelectorNamespace[FrameT, SeriesOrExprT]:
return self.__narwhals_namespace__().selectors
def _to_expr(self) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def _is_selector(
self, other: Self | CompliantExpr[FrameT, SeriesOrExprT]
) -> TypeIs[CompliantSelector[FrameT, SeriesOrExprT]]:
return isinstance(other, type(self))
@overload
def __sub__(self, other: Self) -> Self: ...
@overload
def __sub__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __sub__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
x
for x, name in zip_strict(self(df), lhs_names)
if name not in rhs_names
]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [x for x in lhs_names if x not in rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() - other
@overload
def __or__(self, other: Self) -> Self: ...
@overload
def __or__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __or__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
*(
x
for x, name in zip_strict(self(df), lhs_names)
if name not in rhs_names
),
*other(df),
]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() | other
@overload
def __and__(self, other: Self) -> Self: ...
@overload
def __and__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __and__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
x for x, name in zip_strict(self(df), lhs_names) if name in rhs_names
]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [x for x in lhs_names if x in rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() & other
def __invert__(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self.selectors.all() - self
def _eval_lhs_rhs(
df: CompliantFrameAny, lhs: CompliantExprAny, rhs: CompliantExprAny
) -> tuple[Sequence[str], Sequence[str]]:
return lhs._evaluate_output_names(df), rhs._evaluate_output_names(df)
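
The operator overloads above give selectors set semantics over column names: `a - b` keeps columns selected by `a` but not `b`, `a | b` is the union in `a`-then-`b` order, `a & b` the intersection, and `~a` is `all() - a`; combining a selector with a plain expression instead falls back to the expression-level operator. A hedged illustration for a frame whose numeric columns are assumed to be ["a", "b"], whose only string column is "c", plus one extra column "d":

# (numeric() | string())(df)       -> columns ["a", "b", "c"]
# (numeric() - matches("a"))(df)   -> ["b"]
# (numeric() & matches("a|c"))(df) -> ["a"]
# (~numeric())(df)                 -> ["c", "d"]  (all() minus the numeric columns)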


@@ -0,0 +1,411 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol
from narwhals._compliant.any_namespace import (
CatNamespace,
DateTimeNamespace,
ListNamespace,
StringNamespace,
StructNamespace,
)
from narwhals._compliant.column import CompliantColumn
from narwhals._compliant.typing import (
CompliantSeriesT_co,
EagerDataFrameAny,
EagerSeriesT_co,
NativeSeriesT,
NativeSeriesT_co,
)
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
from narwhals._typing_compat import TypeVar, assert_never
from narwhals._utils import (
_StoresCompliant,
_StoresNative,
is_compliant_series,
is_sized_multi_index_selector,
unstable,
)
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
from types import ModuleType
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import NotRequired, Self, TypedDict
from narwhals._compliant.dataframe import CompliantDataFrame
from narwhals._compliant.namespace import EagerNamespace
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.series import Series
from narwhals.typing import (
Into1DArray,
IntoDType,
MultiIndexSelector,
RollingInterpolationMethod,
SizedMultiIndexSelector,
_1DArray,
_SliceIndex,
)
class HistData(TypedDict, Generic[NativeSeriesT, "_CountsT_co"]):
breakpoint: NotRequired[list[float] | _1DArray | list[Any]]
count: NativeSeriesT | _1DArray | _CountsT_co | list[Any]
_CountsT_co = TypeVar("_CountsT_co", bound="Iterable[Any]", covariant=True)
__all__ = [
"CompliantSeries",
"EagerSeries",
"EagerSeriesCatNamespace",
"EagerSeriesDateTimeNamespace",
"EagerSeriesHist",
"EagerSeriesListNamespace",
"EagerSeriesNamespace",
"EagerSeriesStringNamespace",
"EagerSeriesStructNamespace",
]
class CompliantSeries(
NumpyConvertible["_1DArray", "Into1DArray"],
FromIterable,
FromNative[NativeSeriesT],
ToNarwhals["Series[NativeSeriesT]"],
CompliantColumn,
Protocol[NativeSeriesT],
):
# NOTE: `narwhals`
_implementation: Implementation
@property
def native(self) -> NativeSeriesT: ...
def __narwhals_series__(self) -> Self:
return self
def __native_namespace__(self) -> ModuleType: ...
@classmethod
def from_native(cls, data: NativeSeriesT, /, *, context: _LimitedContext) -> Self: ...
def to_narwhals(self) -> Series[NativeSeriesT]:
return self._version.series(self, level="full")
def _with_native(self, series: Any) -> Self: ...
def _with_version(self, version: Version) -> Self: ...
# NOTE: `polars`
@property
def dtype(self) -> DType: ...
@property
def name(self) -> str: ...
def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: ...
def __contains__(self, other: Any) -> bool: ...
def __getitem__(self, item: MultiIndexSelector[Self]) -> Any: ...
def __iter__(self) -> Iterator[Any]: ...
def __len__(self) -> int:
return len(self.native)
@classmethod
def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_iterable(
cls,
data: Iterable[Any],
/,
*,
context: _LimitedContext,
name: str = "",
dtype: IntoDType | None = None,
) -> Self: ...
def __radd__(self, other: Any) -> Self: ...
def __rand__(self, other: Any) -> Self: ...
def __rmul__(self, other: Any) -> Self: ...
def __ror__(self, other: Any) -> Self: ...
def all(self) -> bool: ...
def any(self) -> bool: ...
def arg_max(self) -> int: ...
def arg_min(self) -> int: ...
def arg_true(self) -> Self: ...
def count(self) -> int: ...
def filter(self, predicate: Any) -> Self: ...
def gather_every(self, n: int, offset: int) -> Self: ...
def head(self, n: int) -> Self: ...
def is_empty(self) -> bool:
return self.len() == 0
def is_sorted(self, *, descending: bool) -> bool: ...
def item(self, index: int | None) -> Any: ...
def kurtosis(self) -> float | None: ...
def len(self) -> int: ...
def max(self) -> Any: ...
def mean(self) -> float: ...
def median(self) -> float: ...
def min(self) -> Any: ...
def n_unique(self) -> int: ...
def null_count(self) -> int: ...
def quantile(
self, quantile: float, interpolation: RollingInterpolationMethod
) -> float: ...
def sample(
self,
n: int | None,
*,
fraction: float | None,
with_replacement: bool,
seed: int | None,
) -> Self: ...
def scatter(self, indices: int | Sequence[int], values: Any) -> Self: ...
def shift(self, n: int) -> Self: ...
def skew(self) -> float | None: ...
def sort(self, *, descending: bool, nulls_last: bool) -> Self: ...
def std(self, *, ddof: int) -> float: ...
def sum(self) -> float: ...
def tail(self, n: int) -> Self: ...
def to_arrow(self) -> pa.Array[Any]: ...
def to_dummies(
self, *, separator: str, drop_first: bool
) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def to_frame(self) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def to_list(self) -> list[Any]: ...
def to_pandas(self) -> pd.Series[Any]: ...
def to_polars(self) -> pl.Series: ...
def unique(self, *, maintain_order: bool = False) -> Self: ...
def value_counts(
self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def var(self, *, ddof: int) -> float: ...
def zip_with(self, mask: Any, other: Any) -> Self: ...
# NOTE: *Technically* `polars`
@unstable
def hist_from_bins(
self, bins: list[float], *, include_breakpoint: bool
) -> CompliantDataFrame[Self, Any, Any, Any]:
"""`Series.hist(bins=..., bin_count=None)`."""
...
@unstable
def hist_from_bin_count(
self, bin_count: int, *, include_breakpoint: bool
) -> CompliantDataFrame[Self, Any, Any, Any]:
"""`Series.hist(bins=None, bin_count=...)`."""
...
class EagerSeries(CompliantSeries[NativeSeriesT], Protocol[NativeSeriesT]):
_native_series: Any
_implementation: Implementation
_version: Version
_broadcast: bool
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@classmethod
def _align_full_broadcast(cls, *series: Self) -> Sequence[Self]:
"""Ensure all of `series` have the same length (and index if `pandas`).
Scalars get broadcasted to the full length of the longest Series.
This is useful when you need to construct a full Series anyway, such as:
DataFrame.select(...)
It should not be used in binary operations, such as:
nw.col("a") - nw.col("a").mean()
because then it's more efficient to extract the right-hand-side's single element as a scalar.
"""
...
def _from_scalar(self, value: Any) -> Self:
return self.from_iterable([value], name=self.name, context=self)
def _with_native(
self, series: NativeSeriesT, *, preserve_broadcast: bool = False
) -> Self:
"""Return a new `CompliantSeries`, wrapping the native `series`.
In cases when operations are known to not affect whether a result should
be broadcast, we can pass `preserve_broadcast=True`.
Set this with care - it should only be set for unary expressions which don't
change length or order, such as `.alias` or `.fill_null`. If in doubt, don't
set it, you probably don't need it.
"""
...
def __narwhals_namespace__(
self,
) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ...
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
def __getitem__(self, item: MultiIndexSelector[Self]) -> Self:
if isinstance(item, (slice, range)):
return self._gather_slice(item)
if is_compliant_series(item):
return self._gather(item.native)
elif is_sized_multi_index_selector(item): # noqa: RET505
return self._gather(item)
assert_never(item)
@property
def str(self) -> EagerSeriesStringNamespace[Self, NativeSeriesT]: ...
@property
def dt(self) -> EagerSeriesDateTimeNamespace[Self, NativeSeriesT]: ...
@property
def cat(self) -> EagerSeriesCatNamespace[Self, NativeSeriesT]: ...
@property
def list(self) -> EagerSeriesListNamespace[Self, NativeSeriesT]: ...
@property
def struct(self) -> EagerSeriesStructNamespace[Self, NativeSeriesT]: ...
class _SeriesNamespace( # type: ignore[misc]
_StoresCompliant[CompliantSeriesT_co],
_StoresNative[NativeSeriesT_co],
Protocol[CompliantSeriesT_co, NativeSeriesT_co],
):
_compliant_series: CompliantSeriesT_co
@property
def compliant(self) -> CompliantSeriesT_co:
return self._compliant_series
@property
def implementation(self) -> Implementation:
return self.compliant._implementation
@property
def backend_version(self) -> tuple[int, ...]:
return self.implementation._backend_version()
@property
def version(self) -> Version:
return self.compliant._version
@property
def native(self) -> NativeSeriesT_co:
return self._compliant_series.native # type: ignore[no-any-return]
def with_native(self, series: Any, /) -> CompliantSeriesT_co:
return self.compliant._with_native(series)
class EagerSeriesNamespace(
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
Generic[EagerSeriesT_co, NativeSeriesT_co],
):
_compliant_series: EagerSeriesT_co
def __init__(self, series: EagerSeriesT_co, /) -> None:
self._compliant_series = series
class EagerSeriesCatNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
CatNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesDateTimeNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
DateTimeNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesListNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
ListNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesStringNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
StringNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesStructNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
StructNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesHist(Protocol[NativeSeriesT, _CountsT_co]):
_series: EagerSeries[NativeSeriesT]
_breakpoint: bool
_data: HistData[NativeSeriesT, _CountsT_co]
@property
def native(self) -> NativeSeriesT:
return self._series.native
@classmethod
def from_series(
cls, series: EagerSeries[NativeSeriesT], *, include_breakpoint: bool
) -> Self:
obj = cls.__new__(cls)
obj._series = series
obj._breakpoint = include_breakpoint
return obj
def to_frame(self) -> EagerDataFrameAny: ...
def _linear_space( # NOTE: Roughly `pl.linear_space`
self,
start: float,
end: float,
num_samples: int,
*,
closed: Literal["both", "none"] = "both",
) -> _1DArray: ...
# NOTE: *Could* be handled at narwhals-level
def is_empty_series(self) -> bool: ...
# NOTE: **Should** be handled at narwhals-level
def data_empty(self) -> HistData[NativeSeriesT, _CountsT_co]:
return {"breakpoint": [], "count": []} if self._breakpoint else {"count": []}
# NOTE: *Could* be handled at narwhals-level, **iff** we add `nw.repeat`, `nw.linear_space`
# See https://github.com/narwhals-dev/narwhals/pull/2839#discussion_r2215630696
def series_empty(
self, arg: int | list[float], /
) -> HistData[NativeSeriesT, _CountsT_co]: ...
def with_bins(self, bins: list[float], /) -> Self:
if len(bins) <= 1:
self._data = self.data_empty()
elif self.is_empty_series():
self._data = self.series_empty(bins)
else:
self._data = self._calculate_hist(bins)
return self
def with_bin_count(self, bin_count: int, /) -> Self:
if bin_count == 0:
self._data = self.data_empty()
elif self.is_empty_series():
self._data = self.series_empty(bin_count)
else:
self._data = self._calculate_hist(self._calculate_bins(bin_count))
return self
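    # Rough usage sketch (the concrete subclass and series are hypothetical):
    #
    #     hist = SomeEagerSeriesHist.from_series(s, include_breakpoint=True)
    #     data = hist.with_bin_count(10)._data   # or: hist.with_bins([0.0, 0.5, 1.0])._data
    #     frame = hist.to_frame()
    #
    # Both builders short-circuit to empty data for degenerate inputs before
    # delegating to the backend-specific `_calculate_bins` / `_calculate_hist` hooks.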
def _calculate_breakpoint(self, arg: int | list[float], /) -> list[float] | _1DArray:
bins = self._linear_space(0, 1, arg + 1) if isinstance(arg, int) else arg
return bins[1:]
def _calculate_bins(self, bin_count: int) -> _1DArray: ...
def _calculate_hist(
self, bins: list[float] | _1DArray
) -> HistData[NativeSeriesT, _CountsT_co]: ...

View File

@ -0,0 +1,206 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from narwhals._compliant.dataframe import (
CompliantDataFrame,
CompliantFrame,
CompliantLazyFrame,
EagerDataFrame,
)
from narwhals._compliant.expr import (
CompliantExpr,
DepthTrackingExpr,
EagerExpr,
ImplExpr,
LazyExpr,
NativeExpr,
)
from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
from narwhals._compliant.series import CompliantSeries, EagerSeries
from narwhals._compliant.window import WindowInputs
from narwhals.typing import (
FillNullStrategy,
IntoLazyFrame,
ModeKeepStrategy,
NativeDataFrame,
NativeFrame,
NativeSeries,
RankMethod,
RollingInterpolationMethod,
)
class ScalarKwargs(TypedDict, total=False):
"""Non-expressifiable args which we may need to reuse in `agg` or `over`."""
adjust: bool
alpha: float | None
center: int
com: float | None
ddof: int
descending: bool
half_life: float | None
ignore_nulls: bool
interpolation: RollingInterpolationMethod
keep: ModeKeepStrategy
limit: int | None
method: RankMethod
min_samples: int
n: int
quantile: float
reverse: bool
span: float | None
strategy: FillNullStrategy | None
window_size: int
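# Illustrative only - e.g. a rolling aggregation might stash its non-expressifiable
# arguments like so, to be replayed later in `agg` or `over` (values are made up):
#
#     kwargs: ScalarKwargs = {"window_size": 3, "min_samples": 1}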
__all__ = [
"AliasName",
"AliasNames",
"CompliantDataFrameT",
"CompliantFrameT",
"CompliantLazyFrameT",
"CompliantSeriesT",
"EvalNames",
"EvalSeries",
"NarwhalsAggregation",
"NativeFrameT_co",
"NativeSeriesT_co",
]
CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]"
CompliantSeriesAny: TypeAlias = "CompliantSeries[Any]"
CompliantSeriesOrNativeExprAny: TypeAlias = "CompliantSeriesAny | NativeExpr"
CompliantDataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any, Any]"
CompliantLazyFrameAny: TypeAlias = "CompliantLazyFrame[Any, Any, Any]"
CompliantFrameAny: TypeAlias = "CompliantFrame[Any, Any, Any]"
CompliantNamespaceAny: TypeAlias = "CompliantNamespace[Any, Any]"
ImplExprAny: TypeAlias = "ImplExpr[Any, Any]"
DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]"
EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any, Any]"
EagerSeriesAny: TypeAlias = "EagerSeries[Any]"
EagerExprAny: TypeAlias = "EagerExpr[Any, Any]"
EagerNamespaceAny: TypeAlias = "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame, NativeSeries]"
LazyExprAny: TypeAlias = "LazyExpr[Any, Any]"
NativeExprT = TypeVar("NativeExprT", bound="NativeExpr")
NativeExprT_co = TypeVar("NativeExprT_co", bound="NativeExpr", covariant=True)
NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries")
NativeSeriesT_co = TypeVar("NativeSeriesT_co", bound="NativeSeries", covariant=True)
NativeSeriesT_contra = TypeVar(
"NativeSeriesT_contra", bound="NativeSeries", contravariant=True
)
NativeDataFrameT = TypeVar("NativeDataFrameT", bound="NativeDataFrame")
NativeLazyFrameT = TypeVar("NativeLazyFrameT", bound="IntoLazyFrame")
NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame")
NativeFrameT_co = TypeVar("NativeFrameT_co", bound="NativeFrame", covariant=True)
NativeFrameT_contra = TypeVar(
"NativeFrameT_contra", bound="NativeFrame", contravariant=True
)
CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny)
CompliantExprT_co = TypeVar("CompliantExprT_co", bound=CompliantExprAny, covariant=True)
CompliantExprT_contra = TypeVar(
"CompliantExprT_contra", bound=CompliantExprAny, contravariant=True
)
CompliantSeriesT = TypeVar("CompliantSeriesT", bound=CompliantSeriesAny)
CompliantSeriesT_co = TypeVar(
"CompliantSeriesT_co", bound=CompliantSeriesAny, covariant=True
)
CompliantSeriesOrNativeExprT = TypeVar(
"CompliantSeriesOrNativeExprT", bound=CompliantSeriesOrNativeExprAny
)
CompliantSeriesOrNativeExprT_co = TypeVar(
"CompliantSeriesOrNativeExprT_co",
bound=CompliantSeriesOrNativeExprAny,
covariant=True,
)
CompliantFrameT = TypeVar("CompliantFrameT", bound=CompliantFrameAny)
CompliantFrameT_co = TypeVar(
"CompliantFrameT_co", bound=CompliantFrameAny, covariant=True
)
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound=CompliantDataFrameAny)
CompliantDataFrameT_co = TypeVar(
"CompliantDataFrameT_co", bound=CompliantDataFrameAny, covariant=True
)
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound=CompliantLazyFrameAny)
CompliantLazyFrameT_co = TypeVar(
"CompliantLazyFrameT_co", bound=CompliantLazyFrameAny, covariant=True
)
CompliantNamespaceT = TypeVar("CompliantNamespaceT", bound=CompliantNamespaceAny)
CompliantNamespaceT_co = TypeVar(
"CompliantNamespaceT_co", bound=CompliantNamespaceAny, covariant=True
)
ImplExprT_contra = TypeVar("ImplExprT_contra", bound=ImplExprAny, contravariant=True)
DepthTrackingExprT = TypeVar("DepthTrackingExprT", bound=DepthTrackingExprAny)
DepthTrackingExprT_contra = TypeVar(
"DepthTrackingExprT_contra", bound=DepthTrackingExprAny, contravariant=True
)
EagerExprT = TypeVar("EagerExprT", bound=EagerExprAny)
EagerExprT_contra = TypeVar("EagerExprT_contra", bound=EagerExprAny, contravariant=True)
EagerSeriesT = TypeVar("EagerSeriesT", bound=EagerSeriesAny)
EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True)
# NOTE: `pyright` gives false (8) positives if this uses `EagerDataFrameAny`?
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any, Any]")
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True)
AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]]
"""A function aliasing a *sequence* of column names."""
AliasName: TypeAlias = Callable[[str], str]
"""A function aliasing a *single* column name."""
EvalSeries: TypeAlias = Callable[
[CompliantFrameT], Sequence[CompliantSeriesOrNativeExprT]
]
"""A function from a `Frame` to a sequence of `Series`*.
See [underwater unicorn magic](https://narwhals-dev.github.io/narwhals/how_it_works/).
"""
EvalNames: TypeAlias = Callable[[CompliantFrameT], Sequence[str]]
"""A function from a `Frame` to a sequence of columns names *before* any aliasing takes place."""
WindowFunction: TypeAlias = (
"Callable[[CompliantFrameT, WindowInputs[NativeExprT]], Sequence[NativeExprT]]"
)
"""A function evaluated with `over(partition_by=..., order_by=...)`."""
NarwhalsAggregation: TypeAlias = Literal[
"sum",
"mean",
"median",
"max",
"min",
"mode",
"std",
"var",
"len",
"n_unique",
"count",
"quantile",
"all",
"any",
]
"""`Expr` methods we aim to support in `DepthTrackingGroupBy`.
Be sure to update me if you're working on one of these:
- https://github.com/narwhals-dev/narwhals/issues/981
- https://github.com/narwhals-dev/narwhals/issues/2385
- https://github.com/narwhals-dev/narwhals/issues/2484
- https://github.com/narwhals-dev/narwhals/issues/2526
- https://github.com/narwhals-dev/narwhals/issues/2660
"""

View File

@ -0,0 +1,130 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.typing import (
CompliantExprAny,
CompliantFrameAny,
CompliantSeriesOrNativeExprAny,
EagerDataFrameT,
EagerExprT,
EagerSeriesT,
LazyExprAny,
NativeSeriesT,
)
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Self, TypeAlias
from narwhals._compliant.typing import EvalSeries
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.typing import NonNestedLiteral
__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen"]
ExprT = TypeVar("ExprT", bound=CompliantExprAny)
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny)
FrameT = TypeVar("FrameT", bound=CompliantFrameAny)
Scalar: TypeAlias = Any
"""A native literal value."""
IntoExpr: TypeAlias = "SeriesT | ExprT | NonNestedLiteral | Scalar"
"""Anything that is convertible into a `CompliantExpr`."""
class CompliantWhen(Protocol[FrameT, SeriesT, ExprT]):
_condition: ExprT
_then_value: IntoExpr[SeriesT, ExprT]
_otherwise_value: IntoExpr[SeriesT, ExprT] | None
_implementation: Implementation
_version: Version
@property
def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT, Self]]: ...
def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]: ...
def then(
self, value: IntoExpr[SeriesT, ExprT], /
) -> CompliantThen[FrameT, SeriesT, ExprT, Self]:
return self._then.from_when(self, value)
@classmethod
def from_expr(cls, condition: ExprT, /, *, context: _LimitedContext) -> Self:
obj = cls.__new__(cls)
obj._condition = condition
obj._then_value = None
obj._otherwise_value = None
obj._implementation = context._implementation
obj._version = context._version
return obj
WhenT_contra = TypeVar(
"WhenT_contra", bound=CompliantWhen[Any, Any, Any], contravariant=True
)
class CompliantThen(
CompliantExpr[FrameT, SeriesT], Protocol[FrameT, SeriesT, ExprT, WhenT_contra]
):
_call: EvalSeries[FrameT, SeriesT]
_when_value: CompliantWhen[FrameT, SeriesT, ExprT]
_implementation: Implementation
_version: Version
@classmethod
def from_when(cls, when: WhenT_contra, then: IntoExpr[SeriesT, ExprT], /) -> Self:
when._then_value = then
obj = cls.__new__(cls)
obj._call = when
obj._when_value = when
obj._evaluate_output_names = getattr(
then, "_evaluate_output_names", lambda _df: ["literal"]
)
obj._alias_output_names = getattr(then, "_alias_output_names", None)
obj._implementation = when._implementation
obj._version = when._version
return obj
def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT:
self._when_value._otherwise_value = otherwise
return cast("ExprT", self)
class EagerWhen(
CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT],
):
def _if_then_else(
self,
when: NativeSeriesT,
then: NativeSeriesT,
otherwise: NativeSeriesT | NonNestedLiteral | Scalar,
/,
) -> NativeSeriesT: ...
def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
is_expr = self._condition._is_expr
when: EagerSeriesT = self._condition(df)[0]
then: EagerSeriesT
align = when._align_full_broadcast
if is_expr(self._then_value):
then = self._then_value(df)[0]
else:
then = when.alias("literal")._from_scalar(self._then_value)
then._broadcast = True
if is_expr(self._otherwise_value):
otherwise = self._otherwise_value(df)[0]
when, then, otherwise = align(when, then, otherwise)
result = self._if_then_else(when.native, then.native, otherwise.native)
else:
when, then = align(when, then)
result = self._if_then_else(when.native, then.native, self._otherwise_value)
return [then._with_native(result)]
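    # Note (descriptive, not normative): when `then` is a scalar literal it is marked
    # `_broadcast = True` so `_align_full_broadcast` can expand it to the length of
    # `when`; the final native result is rewrapped via `then._with_native(...)` so the
    # output keeps the `then` branch's name.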

View File

@ -0,0 +1,20 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Generic
from narwhals._compliant.typing import NativeExprT_co
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["WindowInputs"]
class WindowInputs(Generic[NativeExprT_co]):
__slots__ = ("order_by", "partition_by")
def __init__(
self, partition_by: Sequence[str | NativeExprT_co], order_by: Sequence[str]
) -> None:
self.partition_by = partition_by
self.order_by = order_by
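# Construction sketch (column names are made up):
#
#     inputs = WindowInputs(partition_by=["group"], order_by=["ts"])
#     inputs.partition_by  # ["group"]
#     inputs.order_by      # ["ts"]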