done
This commit is contained in:
103
lib/python3.11/site-packages/narwhals/_compliant/__init__.py
Normal file
103
lib/python3.11/site-packages/narwhals/_compliant/__init__.py
Normal file
@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from narwhals._compliant.dataframe import (
|
||||
CompliantDataFrame,
|
||||
CompliantFrame,
|
||||
CompliantLazyFrame,
|
||||
EagerDataFrame,
|
||||
)
|
||||
from narwhals._compliant.expr import (
|
||||
CompliantExpr,
|
||||
DepthTrackingExpr,
|
||||
EagerExpr,
|
||||
LazyExpr,
|
||||
LazyExprNamespace,
|
||||
)
|
||||
from narwhals._compliant.group_by import (
|
||||
CompliantGroupBy,
|
||||
DepthTrackingGroupBy,
|
||||
EagerGroupBy,
|
||||
)
|
||||
from narwhals._compliant.namespace import (
|
||||
CompliantNamespace,
|
||||
DepthTrackingNamespace,
|
||||
EagerNamespace,
|
||||
LazyNamespace,
|
||||
)
|
||||
from narwhals._compliant.selectors import (
|
||||
CompliantSelector,
|
||||
CompliantSelectorNamespace,
|
||||
EagerSelectorNamespace,
|
||||
LazySelectorNamespace,
|
||||
)
|
||||
from narwhals._compliant.series import (
|
||||
CompliantSeries,
|
||||
EagerSeries,
|
||||
EagerSeriesCatNamespace,
|
||||
EagerSeriesDateTimeNamespace,
|
||||
EagerSeriesHist,
|
||||
EagerSeriesListNamespace,
|
||||
EagerSeriesNamespace,
|
||||
EagerSeriesStringNamespace,
|
||||
EagerSeriesStructNamespace,
|
||||
)
|
||||
from narwhals._compliant.typing import (
|
||||
CompliantExprT,
|
||||
CompliantFrameT,
|
||||
CompliantSeriesOrNativeExprT_co,
|
||||
CompliantSeriesT,
|
||||
EagerDataFrameT,
|
||||
EagerSeriesT,
|
||||
EvalNames,
|
||||
EvalSeries,
|
||||
NativeFrameT_co,
|
||||
NativeSeriesT_co,
|
||||
)
|
||||
from narwhals._compliant.when_then import CompliantThen, CompliantWhen, EagerWhen
|
||||
from narwhals._compliant.window import WindowInputs
|
||||
|
||||
__all__ = [
|
||||
"CompliantDataFrame",
|
||||
"CompliantExpr",
|
||||
"CompliantExprT",
|
||||
"CompliantFrame",
|
||||
"CompliantFrameT",
|
||||
"CompliantGroupBy",
|
||||
"CompliantLazyFrame",
|
||||
"CompliantNamespace",
|
||||
"CompliantSelector",
|
||||
"CompliantSelectorNamespace",
|
||||
"CompliantSeries",
|
||||
"CompliantSeriesOrNativeExprT_co",
|
||||
"CompliantSeriesT",
|
||||
"CompliantThen",
|
||||
"CompliantWhen",
|
||||
"DepthTrackingExpr",
|
||||
"DepthTrackingGroupBy",
|
||||
"DepthTrackingNamespace",
|
||||
"EagerDataFrame",
|
||||
"EagerDataFrameT",
|
||||
"EagerExpr",
|
||||
"EagerGroupBy",
|
||||
"EagerNamespace",
|
||||
"EagerSelectorNamespace",
|
||||
"EagerSeries",
|
||||
"EagerSeriesCatNamespace",
|
||||
"EagerSeriesDateTimeNamespace",
|
||||
"EagerSeriesHist",
|
||||
"EagerSeriesListNamespace",
|
||||
"EagerSeriesNamespace",
|
||||
"EagerSeriesStringNamespace",
|
||||
"EagerSeriesStructNamespace",
|
||||
"EagerSeriesT",
|
||||
"EagerWhen",
|
||||
"EvalNames",
|
||||
"EvalSeries",
|
||||
"LazyExpr",
|
||||
"LazyExprNamespace",
|
||||
"LazyNamespace",
|
||||
"LazySelectorNamespace",
|
||||
"NativeFrameT_co",
|
||||
"NativeSeriesT_co",
|
||||
"WindowInputs",
|
||||
]
|
@ -0,0 +1,94 @@
|
||||
"""`Expr` and `Series` namespace accessor protocols."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from narwhals._utils import CompliantT_co, _StoresCompliant
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable
|
||||
|
||||
from narwhals.typing import NonNestedLiteral, TimeUnit
|
||||
|
||||
__all__ = [
|
||||
"CatNamespace",
|
||||
"DateTimeNamespace",
|
||||
"ListNamespace",
|
||||
"NameNamespace",
|
||||
"StringNamespace",
|
||||
"StructNamespace",
|
||||
]
|
||||
|
||||
|
||||
class CatNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
|
||||
def get_categories(self) -> CompliantT_co: ...
|
||||
|
||||
|
||||
class DateTimeNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
|
||||
def to_string(self, format: str) -> CompliantT_co: ...
|
||||
def replace_time_zone(self, time_zone: str | None) -> CompliantT_co: ...
|
||||
def convert_time_zone(self, time_zone: str) -> CompliantT_co: ...
|
||||
def timestamp(self, time_unit: TimeUnit) -> CompliantT_co: ...
|
||||
def date(self) -> CompliantT_co: ...
|
||||
def year(self) -> CompliantT_co: ...
|
||||
def month(self) -> CompliantT_co: ...
|
||||
def day(self) -> CompliantT_co: ...
|
||||
def hour(self) -> CompliantT_co: ...
|
||||
def minute(self) -> CompliantT_co: ...
|
||||
def second(self) -> CompliantT_co: ...
|
||||
def millisecond(self) -> CompliantT_co: ...
|
||||
def microsecond(self) -> CompliantT_co: ...
|
||||
def nanosecond(self) -> CompliantT_co: ...
|
||||
def ordinal_day(self) -> CompliantT_co: ...
|
||||
def weekday(self) -> CompliantT_co: ...
|
||||
def total_minutes(self) -> CompliantT_co: ...
|
||||
def total_seconds(self) -> CompliantT_co: ...
|
||||
def total_milliseconds(self) -> CompliantT_co: ...
|
||||
def total_microseconds(self) -> CompliantT_co: ...
|
||||
def total_nanoseconds(self) -> CompliantT_co: ...
|
||||
def truncate(self, every: str) -> CompliantT_co: ...
|
||||
def offset_by(self, by: str) -> CompliantT_co: ...
|
||||
|
||||
|
||||
class ListNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
|
||||
def get(self, index: int) -> CompliantT_co: ...
|
||||
|
||||
def len(self) -> CompliantT_co: ...
|
||||
|
||||
def unique(self) -> CompliantT_co: ...
|
||||
def contains(self, item: NonNestedLiteral) -> CompliantT_co: ...
|
||||
|
||||
|
||||
class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
|
||||
def keep(self) -> CompliantT_co: ...
|
||||
def map(self, function: Callable[[str], str]) -> CompliantT_co: ...
|
||||
def prefix(self, prefix: str) -> CompliantT_co: ...
|
||||
def suffix(self, suffix: str) -> CompliantT_co: ...
|
||||
def to_lowercase(self) -> CompliantT_co: ...
|
||||
def to_uppercase(self) -> CompliantT_co: ...
|
||||
|
||||
|
||||
class StringNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
|
||||
def len_chars(self) -> CompliantT_co: ...
|
||||
def replace(
|
||||
self, pattern: str, value: str, *, literal: bool, n: int
|
||||
) -> CompliantT_co: ...
|
||||
def replace_all(
|
||||
self, pattern: str, value: str, *, literal: bool
|
||||
) -> CompliantT_co: ...
|
||||
def strip_chars(self, characters: str | None) -> CompliantT_co: ...
|
||||
def starts_with(self, prefix: str) -> CompliantT_co: ...
|
||||
def ends_with(self, suffix: str) -> CompliantT_co: ...
|
||||
def contains(self, pattern: str, *, literal: bool) -> CompliantT_co: ...
|
||||
def slice(self, offset: int, length: int | None) -> CompliantT_co: ...
|
||||
def split(self, by: str) -> CompliantT_co: ...
|
||||
def to_datetime(self, format: str | None) -> CompliantT_co: ...
|
||||
def to_date(self, format: str | None) -> CompliantT_co: ...
|
||||
def to_lowercase(self) -> CompliantT_co: ...
|
||||
def to_uppercase(self) -> CompliantT_co: ...
|
||||
def zfill(self, width: int) -> CompliantT_co: ...
|
||||
|
||||
|
||||
class StructNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
|
||||
def field(self, name: str) -> CompliantT_co: ...
|
213
lib/python3.11/site-packages/narwhals/_compliant/column.py
Normal file
213
lib/python3.11/site-packages/narwhals/_compliant/column.py
Normal file
@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from narwhals._compliant.any_namespace import (
|
||||
CatNamespace,
|
||||
DateTimeNamespace,
|
||||
ListNamespace,
|
||||
StringNamespace,
|
||||
StructNamespace,
|
||||
)
|
||||
from narwhals._compliant.namespace import CompliantNamespace
|
||||
from narwhals._utils import Version
|
||||
from narwhals.typing import (
|
||||
ClosedInterval,
|
||||
FillNullStrategy,
|
||||
IntoDType,
|
||||
ModeKeepStrategy,
|
||||
NonNestedLiteral,
|
||||
NumericLiteral,
|
||||
RankMethod,
|
||||
TemporalLiteral,
|
||||
)
|
||||
|
||||
__all__ = ["CompliantColumn"]
|
||||
|
||||
|
||||
class CompliantColumn(Protocol):
|
||||
"""Common parts of `Expr`, `Series`."""
|
||||
|
||||
_version: Version
|
||||
|
||||
def __add__(self, other: Any) -> Self: ...
|
||||
def __and__(self, other: Any) -> Self: ...
|
||||
def __eq__(self, other: object) -> Self: ... # type: ignore[override]
|
||||
def __floordiv__(self, other: Any) -> Self: ...
|
||||
def __ge__(self, other: Any) -> Self: ...
|
||||
def __gt__(self, other: Any) -> Self: ...
|
||||
def __invert__(self) -> Self: ...
|
||||
def __le__(self, other: Any) -> Self: ...
|
||||
def __lt__(self, other: Any) -> Self: ...
|
||||
def __mod__(self, other: Any) -> Self: ...
|
||||
def __mul__(self, other: Any) -> Self: ...
|
||||
def __ne__(self, other: object) -> Self: ... # type: ignore[override]
|
||||
def __or__(self, other: Any) -> Self: ...
|
||||
def __pow__(self, other: Any) -> Self: ...
|
||||
def __rfloordiv__(self, other: Any) -> Self: ...
|
||||
def __rmod__(self, other: Any) -> Self: ...
|
||||
def __rpow__(self, other: Any) -> Self: ...
|
||||
def __rsub__(self, other: Any) -> Self: ...
|
||||
def __rtruediv__(self, other: Any) -> Self: ...
|
||||
def __sub__(self, other: Any) -> Self: ...
|
||||
def __truediv__(self, other: Any) -> Self: ...
|
||||
|
||||
def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ...
|
||||
|
||||
def abs(self) -> Self: ...
|
||||
def alias(self, name: str) -> Self: ...
|
||||
def cast(self, dtype: IntoDType) -> Self: ...
|
||||
def clip(
|
||||
self,
|
||||
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
|
||||
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
|
||||
) -> Self: ...
|
||||
def cum_count(self, *, reverse: bool) -> Self: ...
|
||||
def cum_max(self, *, reverse: bool) -> Self: ...
|
||||
def cum_min(self, *, reverse: bool) -> Self: ...
|
||||
def cum_prod(self, *, reverse: bool) -> Self: ...
|
||||
def cum_sum(self, *, reverse: bool) -> Self: ...
|
||||
def diff(self) -> Self: ...
|
||||
def drop_nulls(self) -> Self: ...
|
||||
def ewm_mean(
|
||||
self,
|
||||
*,
|
||||
com: float | None,
|
||||
span: float | None,
|
||||
half_life: float | None,
|
||||
alpha: float | None,
|
||||
adjust: bool,
|
||||
min_samples: int,
|
||||
ignore_nulls: bool,
|
||||
) -> Self: ...
|
||||
def exp(self) -> Self: ...
|
||||
def sqrt(self) -> Self: ...
|
||||
def fill_nan(self, value: float | None) -> Self: ...
|
||||
def fill_null(
|
||||
self,
|
||||
value: Self | NonNestedLiteral,
|
||||
strategy: FillNullStrategy | None,
|
||||
limit: int | None,
|
||||
) -> Self: ...
|
||||
def is_between(
|
||||
self, lower_bound: Self, upper_bound: Self, closed: ClosedInterval
|
||||
) -> Self:
|
||||
if closed == "left":
|
||||
return (self >= lower_bound) & (self < upper_bound)
|
||||
if closed == "right":
|
||||
return (self > lower_bound) & (self <= upper_bound)
|
||||
if closed == "none":
|
||||
return (self > lower_bound) & (self < upper_bound)
|
||||
return (self >= lower_bound) & (self <= upper_bound)
|
||||
|
||||
def is_close(
|
||||
self,
|
||||
other: Self | NumericLiteral,
|
||||
*,
|
||||
abs_tol: float,
|
||||
rel_tol: float,
|
||||
nans_equal: bool,
|
||||
) -> Self:
|
||||
from decimal import Decimal
|
||||
|
||||
other_abs: Self | NumericLiteral
|
||||
other_is_nan: Self | bool
|
||||
other_is_inf: Self | bool
|
||||
other_is_not_inf: Self | bool
|
||||
|
||||
if isinstance(other, (float, int, Decimal)):
|
||||
from math import isinf, isnan
|
||||
|
||||
# NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447
|
||||
other_abs = other.__abs__()
|
||||
other_is_nan = isnan(other)
|
||||
other_is_inf = isinf(other)
|
||||
|
||||
# Define the other_is_not_inf variable to prevent triggering the following warning:
|
||||
# > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be
|
||||
# > removed in Python 3.16.
|
||||
other_is_not_inf = not other_is_inf
|
||||
|
||||
else:
|
||||
other_abs, other_is_nan = other.abs(), other.is_nan()
|
||||
other_is_not_inf = other.is_finite() | other_is_nan
|
||||
other_is_inf = ~other_is_not_inf
|
||||
|
||||
rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol
|
||||
tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None)
|
||||
|
||||
self_is_nan = self.is_nan()
|
||||
self_is_not_inf = self.is_finite() | self_is_nan
|
||||
|
||||
# Values are close if abs_diff <= tolerance, and both finite
|
||||
is_close = (
|
||||
((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf
|
||||
)
|
||||
|
||||
# Handle infinity cases: infinities are close/equal if they have the same sign
|
||||
self_sign, other_sign = self > 0, other > 0
|
||||
is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign)
|
||||
|
||||
# Handle nan cases:
|
||||
# * If any value is NaN, then False (via `& ~either_nan`)
|
||||
# * However, if `nans_equals = True` and if _both_ values are NaN, then True
|
||||
either_nan = self_is_nan | other_is_nan
|
||||
result = (is_close | is_same_inf) & ~either_nan
|
||||
|
||||
if nans_equal:
|
||||
both_nan = self_is_nan & other_is_nan
|
||||
result = result | both_nan
|
||||
|
||||
return result
|
||||
|
||||
def is_duplicated(self) -> Self:
|
||||
return ~self.is_unique()
|
||||
|
||||
def is_finite(self) -> Self: ...
|
||||
def is_first_distinct(self) -> Self: ...
|
||||
def is_in(self, other: Any) -> Self: ...
|
||||
def is_last_distinct(self) -> Self: ...
|
||||
def is_nan(self) -> Self: ...
|
||||
def is_null(self) -> Self: ...
|
||||
def is_unique(self) -> Self: ...
|
||||
def log(self, base: float) -> Self: ...
|
||||
def mode(self, *, keep: ModeKeepStrategy) -> Self: ...
|
||||
def rank(self, method: RankMethod, *, descending: bool) -> Self: ...
|
||||
def replace_strict(
|
||||
self,
|
||||
old: Sequence[Any] | Mapping[Any, Any],
|
||||
new: Sequence[Any],
|
||||
*,
|
||||
return_dtype: IntoDType | None,
|
||||
) -> Self: ...
|
||||
def rolling_mean(
|
||||
self, window_size: int, *, min_samples: int, center: bool
|
||||
) -> Self: ...
|
||||
def rolling_std(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self: ...
|
||||
def rolling_sum(
|
||||
self, window_size: int, *, min_samples: int, center: bool
|
||||
) -> Self: ...
|
||||
def rolling_var(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self: ...
|
||||
def round(self, decimals: int) -> Self: ...
|
||||
def shift(self, n: int) -> Self: ...
|
||||
def unique(self) -> Self: ...
|
||||
|
||||
@property
|
||||
def str(self) -> StringNamespace[Self]: ...
|
||||
@property
|
||||
def dt(self) -> DateTimeNamespace[Self]: ...
|
||||
@property
|
||||
def cat(self) -> CatNamespace[Self]: ...
|
||||
@property
|
||||
def list(self) -> ListNamespace[Self]: ...
|
||||
@property
|
||||
def struct(self) -> StructNamespace[Self]: ...
|
426
lib/python3.11/site-packages/narwhals/_compliant/dataframe.py
Normal file
426
lib/python3.11/site-packages/narwhals/_compliant/dataframe.py
Normal file
@ -0,0 +1,426 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Mapping, Sequence, Sized
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload
|
||||
|
||||
from narwhals._compliant.typing import (
|
||||
CompliantDataFrameAny,
|
||||
CompliantExprT_contra,
|
||||
CompliantLazyFrameAny,
|
||||
CompliantSeriesT,
|
||||
EagerExprT,
|
||||
EagerSeriesT,
|
||||
NativeDataFrameT,
|
||||
NativeLazyFrameT,
|
||||
NativeSeriesT,
|
||||
)
|
||||
from narwhals._translate import (
|
||||
ArrowConvertible,
|
||||
DictConvertible,
|
||||
FromNative,
|
||||
NumpyConvertible,
|
||||
ToNarwhals,
|
||||
ToNarwhalsT_co,
|
||||
)
|
||||
from narwhals._typing_compat import assert_never
|
||||
from narwhals._utils import (
|
||||
ValidateBackendVersion,
|
||||
Version,
|
||||
_StoresNative,
|
||||
check_columns_exist,
|
||||
is_compliant_series,
|
||||
is_index_selector,
|
||||
is_range,
|
||||
is_sequence_like,
|
||||
is_sized_multi_index_selector,
|
||||
is_slice_index,
|
||||
is_slice_none,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import pyarrow as pa
|
||||
from typing_extensions import Self, TypeAlias
|
||||
|
||||
from narwhals._compliant.group_by import CompliantGroupBy, DataFrameGroupBy
|
||||
from narwhals._compliant.namespace import EagerNamespace
|
||||
from narwhals._spark_like.utils import SparkSession
|
||||
from narwhals._translate import IntoArrowTable
|
||||
from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
|
||||
from narwhals._utils import Implementation, _LimitedContext
|
||||
from narwhals.dataframe import DataFrame
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.exceptions import ColumnNotFoundError
|
||||
from narwhals.typing import (
|
||||
AsofJoinStrategy,
|
||||
IntoSchema,
|
||||
JoinStrategy,
|
||||
LazyUniqueKeepStrategy,
|
||||
MultiColSelector,
|
||||
MultiIndexSelector,
|
||||
PivotAgg,
|
||||
SingleIndexSelector,
|
||||
SizedMultiIndexSelector,
|
||||
SizedMultiNameSelector,
|
||||
SizeUnit,
|
||||
UniqueKeepStrategy,
|
||||
_2DArray,
|
||||
_SliceIndex,
|
||||
_SliceName,
|
||||
)
|
||||
|
||||
Incomplete: TypeAlias = Any
|
||||
|
||||
__all__ = ["CompliantDataFrame", "CompliantFrame", "CompliantLazyFrame", "EagerDataFrame"]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]" # noqa: PYI047
|
||||
|
||||
_NativeFrameT = TypeVar("_NativeFrameT")
|
||||
|
||||
|
||||
class CompliantFrame(
|
||||
_StoresNative[_NativeFrameT],
|
||||
FromNative[_NativeFrameT],
|
||||
ToNarwhals[ToNarwhalsT_co],
|
||||
Protocol[CompliantExprT_contra, _NativeFrameT, ToNarwhalsT_co],
|
||||
):
|
||||
"""Common parts of `DataFrame`, `LazyFrame`."""
|
||||
|
||||
_native_frame: _NativeFrameT
|
||||
_implementation: Implementation
|
||||
_version: Version
|
||||
|
||||
def __native_namespace__(self) -> ModuleType: ...
|
||||
def __narwhals_namespace__(self) -> Any: ...
|
||||
def _with_version(self, version: Version) -> Self: ...
|
||||
@classmethod
|
||||
def from_native(cls, data: _NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
|
||||
@property
|
||||
def columns(self) -> Sequence[str]: ...
|
||||
@property
|
||||
def native(self) -> _NativeFrameT:
|
||||
return self._native_frame
|
||||
|
||||
@property
|
||||
def schema(self) -> Mapping[str, DType]: ...
|
||||
|
||||
def collect_schema(self) -> Mapping[str, DType]: ...
|
||||
def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
|
||||
def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
|
||||
def explode(self, columns: Sequence[str]) -> Self: ...
|
||||
def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
|
||||
def group_by(
|
||||
self,
|
||||
keys: Sequence[str] | Sequence[CompliantExprT_contra],
|
||||
*,
|
||||
drop_null_keys: bool,
|
||||
) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
|
||||
def head(self, n: int) -> Self: ...
|
||||
def join(
|
||||
self,
|
||||
other: Self,
|
||||
*,
|
||||
how: JoinStrategy,
|
||||
left_on: Sequence[str] | None,
|
||||
right_on: Sequence[str] | None,
|
||||
suffix: str,
|
||||
) -> Self: ...
|
||||
def join_asof(
|
||||
self,
|
||||
other: Self,
|
||||
*,
|
||||
left_on: str,
|
||||
right_on: str,
|
||||
by_left: Sequence[str] | None,
|
||||
by_right: Sequence[str] | None,
|
||||
strategy: AsofJoinStrategy,
|
||||
suffix: str,
|
||||
) -> Self: ...
|
||||
def rename(self, mapping: Mapping[str, str]) -> Self: ...
|
||||
def select(self, *exprs: CompliantExprT_contra) -> Self: ...
|
||||
def simple_select(self, *column_names: str) -> Self:
|
||||
"""`select` where all args are column names."""
|
||||
...
|
||||
|
||||
def sort(
|
||||
self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
|
||||
) -> Self: ...
|
||||
def tail(self, n: int) -> Self: ...
|
||||
def unique(
|
||||
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
|
||||
) -> Self: ...
|
||||
def unpivot(
|
||||
self,
|
||||
on: Sequence[str] | None,
|
||||
index: Sequence[str] | None,
|
||||
variable_name: str,
|
||||
value_name: str,
|
||||
) -> Self: ...
|
||||
def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
|
||||
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self: ...
|
||||
|
||||
|
||||
class CompliantDataFrame(
|
||||
NumpyConvertible["_2DArray", "_2DArray"],
|
||||
DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
|
||||
ArrowConvertible["pa.Table", "IntoArrowTable"],
|
||||
Sized,
|
||||
CompliantFrame[CompliantExprT_contra, NativeDataFrameT, ToNarwhalsT_co],
|
||||
Protocol[CompliantSeriesT, CompliantExprT_contra, NativeDataFrameT, ToNarwhalsT_co],
|
||||
):
|
||||
def __narwhals_dataframe__(self) -> Self: ...
|
||||
@classmethod
|
||||
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self: ...
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
data: Mapping[str, Any],
|
||||
/,
|
||||
*,
|
||||
context: _LimitedContext,
|
||||
schema: IntoSchema | None,
|
||||
) -> Self: ...
|
||||
@classmethod
|
||||
def from_numpy(
|
||||
cls,
|
||||
data: _2DArray,
|
||||
/,
|
||||
*,
|
||||
context: _LimitedContext,
|
||||
schema: IntoSchema | Sequence[str] | None,
|
||||
) -> Self: ...
|
||||
def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
|
||||
def __getitem__(
|
||||
self,
|
||||
item: tuple[
|
||||
SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
|
||||
MultiColSelector[CompliantSeriesT],
|
||||
],
|
||||
) -> Self: ...
|
||||
|
||||
@property
|
||||
def shape(self) -> tuple[int, int]: ...
|
||||
def clone(self) -> Self: ...
|
||||
def estimated_size(self, unit: SizeUnit) -> int | float: ...
|
||||
def gather_every(self, n: int, offset: int) -> Self: ...
|
||||
def get_column(self, name: str) -> CompliantSeriesT: ...
|
||||
def group_by(
|
||||
self,
|
||||
keys: Sequence[str] | Sequence[CompliantExprT_contra],
|
||||
*,
|
||||
drop_null_keys: bool,
|
||||
) -> DataFrameGroupBy[Self, Any]: ...
|
||||
def item(self, row: int | None, column: int | str | None) -> Any: ...
|
||||
def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
|
||||
def iter_rows(
|
||||
self, *, named: bool, buffer_size: int
|
||||
) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
|
||||
def is_unique(self) -> CompliantSeriesT: ...
|
||||
def lazy(
|
||||
self, backend: _LazyAllowedImpl | None, *, session: SparkSession | None
|
||||
) -> CompliantLazyFrameAny: ...
|
||||
def pivot(
|
||||
self,
|
||||
on: Sequence[str],
|
||||
*,
|
||||
index: Sequence[str] | None,
|
||||
values: Sequence[str] | None,
|
||||
aggregate_function: PivotAgg | None,
|
||||
sort_columns: bool,
|
||||
separator: str,
|
||||
) -> Self: ...
|
||||
def row(self, index: int) -> tuple[Any, ...]: ...
|
||||
def rows(
|
||||
self, *, named: bool
|
||||
) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
|
||||
def sample(
|
||||
self,
|
||||
n: int | None,
|
||||
*,
|
||||
fraction: float | None,
|
||||
with_replacement: bool,
|
||||
seed: int | None,
|
||||
) -> Self: ...
|
||||
def to_arrow(self) -> pa.Table: ...
|
||||
def to_pandas(self) -> pd.DataFrame: ...
|
||||
def to_polars(self) -> pl.DataFrame: ...
|
||||
@overload
|
||||
def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
|
||||
@overload
|
||||
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
|
||||
def to_dict(
|
||||
self, *, as_series: bool
|
||||
) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
|
||||
def unique(
|
||||
self,
|
||||
subset: Sequence[str] | None,
|
||||
*,
|
||||
keep: UniqueKeepStrategy,
|
||||
maintain_order: bool | None = None,
|
||||
) -> Self: ...
|
||||
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: ...
|
||||
@overload
|
||||
def write_csv(self, file: None) -> str: ...
|
||||
@overload
|
||||
def write_csv(self, file: str | Path | BytesIO) -> None: ...
|
||||
def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
|
||||
def write_parquet(self, file: str | Path | BytesIO) -> None: ...
|
||||
|
||||
|
||||
class CompliantLazyFrame(
|
||||
CompliantFrame[CompliantExprT_contra, NativeLazyFrameT, ToNarwhalsT_co],
|
||||
Protocol[CompliantExprT_contra, NativeLazyFrameT, ToNarwhalsT_co],
|
||||
):
|
||||
def __narwhals_lazyframe__(self) -> Self: ...
|
||||
# `LazySelectorNamespace._iter_columns` depends
|
||||
def _iter_columns(self) -> Iterator[Any]: ...
|
||||
def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
|
||||
"""`select` where all args are aggregations or literals.
|
||||
|
||||
(so, no broadcasting is necessary).
|
||||
"""
|
||||
...
|
||||
|
||||
def collect(
|
||||
self, backend: _EagerAllowedImpl | None, **kwargs: Any
|
||||
) -> CompliantDataFrameAny: ...
|
||||
def sink_parquet(self, file: str | Path | BytesIO) -> None: ...
|
||||
|
||||
|
||||
class EagerDataFrame(
|
||||
CompliantDataFrame[
|
||||
EagerSeriesT, EagerExprT, NativeDataFrameT, "DataFrame[NativeDataFrameT]"
|
||||
],
|
||||
CompliantLazyFrame[EagerExprT, "Incomplete", "DataFrame[NativeDataFrameT]"],
|
||||
ValidateBackendVersion,
|
||||
Protocol[EagerSeriesT, EagerExprT, NativeDataFrameT, NativeSeriesT],
|
||||
):
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
def __narwhals_namespace__(
|
||||
self,
|
||||
) -> EagerNamespace[
|
||||
Self, EagerSeriesT, EagerExprT, NativeDataFrameT, NativeSeriesT
|
||||
]: ...
|
||||
|
||||
def to_narwhals(self) -> DataFrame[NativeDataFrameT]:
|
||||
return self._version.dataframe(self, level="full")
|
||||
|
||||
def aggregate(self, *exprs: EagerExprT) -> Self:
|
||||
# NOTE: Ignore intermittent [False Negative]
|
||||
# Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "exprs" of type "EagerExprT@EagerDataFrame" in function "select"
|
||||
# Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame"
|
||||
return self.select(*exprs) # pyright: ignore[reportArgumentType]
|
||||
|
||||
def _with_native(
|
||||
self, df: NativeDataFrameT, *, validate_column_names: bool = True
|
||||
) -> Self: ...
|
||||
|
||||
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
|
||||
return check_columns_exist(subset, available=self.columns)
|
||||
|
||||
def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
|
||||
"""Evaluate `expr` and ensure it has a **single** output."""
|
||||
result: Sequence[EagerSeriesT] = expr(self)
|
||||
assert len(result) == 1 # debug assertion # noqa: S101
|
||||
return result[0]
|
||||
|
||||
def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
|
||||
# NOTE: Ignore intermittent [False Negative]
|
||||
# Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "expr" of type "EagerExprT@EagerDataFrame" in function "_evaluate_into_expr"
|
||||
# Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame"
|
||||
return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType]
|
||||
|
||||
def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
|
||||
"""Return list of raw columns.
|
||||
|
||||
For eager backends we alias operations at each step.
|
||||
|
||||
As a safety precaution, here we can check that the expected result names match those
|
||||
we were expecting from the various `evaluate_output_names` / `alias_output_names` calls.
|
||||
|
||||
Note that for PySpark / DuckDB, we are less free to liberally set aliases whenever we want.
|
||||
"""
|
||||
aliases = expr._evaluate_aliases(self)
|
||||
result = expr(self)
|
||||
if list(aliases) != (
|
||||
result_aliases := [s.name for s in result]
|
||||
): # pragma: no cover
|
||||
msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
|
||||
raise AssertionError(msg)
|
||||
return result
|
||||
|
||||
def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
|
||||
"""Extract native Series, broadcasting to `len(self)` if necessary."""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def _numpy_column_names(
|
||||
data: _2DArray, columns: Sequence[str] | None, /
|
||||
) -> list[str]:
|
||||
return list(columns or (f"column_{x}" for x in range(data.shape[1])))
|
||||
|
||||
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
|
||||
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
|
||||
def _select_multi_index(
|
||||
self, columns: SizedMultiIndexSelector[NativeSeriesT]
|
||||
) -> Self: ...
|
||||
def _select_multi_name(
|
||||
self, columns: SizedMultiNameSelector[NativeSeriesT]
|
||||
) -> Self: ...
|
||||
def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
|
||||
def _select_slice_name(self, columns: _SliceName) -> Self: ...
|
||||
def __getitem__( # noqa: C901, PLR0912
|
||||
self,
|
||||
item: tuple[
|
||||
SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
|
||||
MultiColSelector[EagerSeriesT],
|
||||
],
|
||||
) -> Self:
|
||||
rows, columns = item
|
||||
compliant = self
|
||||
if not is_slice_none(columns):
|
||||
if isinstance(columns, Sized) and len(columns) == 0:
|
||||
return compliant.select()
|
||||
if is_index_selector(columns):
|
||||
if is_slice_index(columns) or is_range(columns):
|
||||
compliant = compliant._select_slice_index(columns)
|
||||
elif is_compliant_series(columns):
|
||||
compliant = self._select_multi_index(columns.native)
|
||||
else:
|
||||
compliant = compliant._select_multi_index(columns)
|
||||
elif isinstance(columns, slice):
|
||||
compliant = compliant._select_slice_name(columns)
|
||||
elif is_compliant_series(columns):
|
||||
compliant = self._select_multi_name(columns.native)
|
||||
elif is_sequence_like(columns):
|
||||
compliant = self._select_multi_name(columns)
|
||||
else:
|
||||
assert_never(columns)
|
||||
|
||||
if not is_slice_none(rows):
|
||||
if isinstance(rows, int):
|
||||
compliant = compliant._gather([rows])
|
||||
elif isinstance(rows, (slice, range)):
|
||||
compliant = compliant._gather_slice(rows)
|
||||
elif is_compliant_series(rows):
|
||||
compliant = compliant._gather(rows.native)
|
||||
elif is_sized_multi_index_selector(rows):
|
||||
compliant = compliant._gather(rows)
|
||||
else:
|
||||
assert_never(rows)
|
||||
|
||||
return compliant
|
||||
|
||||
def sink_parquet(self, file: str | Path | BytesIO) -> None:
|
||||
return self.write_parquet(file)
|
1140
lib/python3.11/site-packages/narwhals/_compliant/expr.py
Normal file
1140
lib/python3.11/site-packages/narwhals/_compliant/expr.py
Normal file
File diff suppressed because it is too large
Load Diff
180
lib/python3.11/site-packages/narwhals/_compliant/group_by.py
Normal file
180
lib/python3.11/site-packages/narwhals/_compliant/group_by.py
Normal file
@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol, TypeVar
|
||||
|
||||
from narwhals._compliant.typing import (
|
||||
CompliantDataFrameT,
|
||||
CompliantDataFrameT_co,
|
||||
CompliantExprT_contra,
|
||||
CompliantFrameT,
|
||||
CompliantFrameT_co,
|
||||
DepthTrackingExprAny,
|
||||
DepthTrackingExprT_contra,
|
||||
EagerExprT_contra,
|
||||
ImplExprT_contra,
|
||||
NarwhalsAggregation,
|
||||
)
|
||||
from narwhals._utils import is_sequence_of, zip_strict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
||||
|
||||
from narwhals._compliant.expr import ImplExpr
|
||||
|
||||
|
||||
__all__ = ["CompliantGroupBy", "DepthTrackingGroupBy", "EagerGroupBy"]
|
||||
|
||||
NativeAggregationT_co = TypeVar(
|
||||
"NativeAggregationT_co", bound="str | Callable[..., Any]", covariant=True
|
||||
)
|
||||
|
||||
|
||||
_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)")
|
||||
|
||||
|
||||
def _evaluate_aliases(
|
||||
frame: CompliantFrameT, exprs: Iterable[ImplExpr[CompliantFrameT, Any]], /
|
||||
) -> list[str]:
|
||||
it = (expr._evaluate_aliases(frame) for expr in exprs)
|
||||
return list(chain.from_iterable(it))
|
||||
|
||||
|
||||
class CompliantGroupBy(Protocol[CompliantFrameT_co, CompliantExprT_contra]):
|
||||
_compliant_frame: Any
|
||||
|
||||
@property
|
||||
def compliant(self) -> CompliantFrameT_co:
|
||||
return self._compliant_frame # type: ignore[no-any-return]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
compliant_frame: CompliantFrameT_co,
|
||||
keys: Sequence[CompliantExprT_contra] | Sequence[str],
|
||||
/,
|
||||
*,
|
||||
drop_null_keys: bool,
|
||||
) -> None: ...
|
||||
|
||||
def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ...
|
||||
|
||||
|
||||
class DataFrameGroupBy(
|
||||
CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra],
|
||||
Protocol[CompliantDataFrameT_co, CompliantExprT_contra],
|
||||
):
|
||||
def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ...
|
||||
|
||||
|
||||
class ParseKeysGroupBy(
|
||||
CompliantGroupBy[CompliantFrameT, ImplExprT_contra],
|
||||
Protocol[CompliantFrameT, ImplExprT_contra],
|
||||
):
|
||||
def _parse_keys(
|
||||
self,
|
||||
compliant_frame: CompliantFrameT,
|
||||
keys: Sequence[ImplExprT_contra] | Sequence[str],
|
||||
) -> tuple[CompliantFrameT, list[str], list[str]]:
|
||||
if is_sequence_of(keys, str):
|
||||
keys_str = list(keys)
|
||||
return compliant_frame, keys_str, keys_str.copy()
|
||||
return self._parse_expr_keys(compliant_frame, keys=keys)
|
||||
|
||||
@staticmethod
|
||||
def _parse_expr_keys(
|
||||
compliant_frame: CompliantFrameT, keys: Sequence[ImplExprT_contra]
|
||||
) -> tuple[CompliantFrameT, list[str], list[str]]:
|
||||
"""Parses key expressions to set up `.agg` operation with correct information.
|
||||
|
||||
Since keys are expressions, it's possible to alias any such key to match
|
||||
other dataframe column names.
|
||||
|
||||
In order to match polars behavior and not overwrite columns when evaluating keys:
|
||||
|
||||
- We evaluate what the output key names should be, in order to remap temporary column
|
||||
names to the expected ones, and to exclude those from unnamed expressions in
|
||||
`.agg(...)` context (see https://github.com/narwhals-dev/narwhals/pull/2325#issuecomment-2800004520)
|
||||
- Create temporary names for evaluated key expressions that are guaranteed to have
|
||||
no overlap with any existing column name.
|
||||
- Add these temporary columns to the compliant dataframe.
|
||||
"""
|
||||
tmp_name_length = max(len(str(c)) for c in compliant_frame.columns) + 1
|
||||
|
||||
def _temporary_name(key: str) -> str:
|
||||
# 5 is the length of `__tmp`
|
||||
key_str = str(key) # pandas allows non-string column names :sob:
|
||||
return f"_{key_str}_tmp{'_' * (tmp_name_length - len(key_str) - 5)}"
|
||||
|
||||
keys_aliases = [expr._evaluate_aliases(compliant_frame) for expr in keys]
|
||||
safe_keys = [
|
||||
# multi-output expression cannot have duplicate names, hence it's safe to suffix
|
||||
key.name.map(_temporary_name)
|
||||
if (metadata := key._metadata) and metadata.expansion_kind.is_multi_output()
|
||||
# otherwise it's single named and we can use Expr.alias
|
||||
else key.alias(_temporary_name(new_names[0]))
|
||||
for key, new_names in zip_strict(keys, keys_aliases)
|
||||
]
|
||||
return (
|
||||
compliant_frame.with_columns(*safe_keys),
|
||||
_evaluate_aliases(compliant_frame, safe_keys),
|
||||
list(chain.from_iterable(keys_aliases)),
|
||||
)
|
||||
|
||||
|
||||
class DepthTrackingGroupBy(
|
||||
ParseKeysGroupBy[CompliantFrameT, DepthTrackingExprT_contra],
|
||||
Protocol[CompliantFrameT, DepthTrackingExprT_contra, NativeAggregationT_co],
|
||||
):
|
||||
"""`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`."""
|
||||
|
||||
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]]
|
||||
"""Mapping from `narwhals` to native representation.
|
||||
|
||||
Note:
|
||||
- `Dask` *may* return a `Callable` instead of a `str` referring to one.
|
||||
"""
|
||||
|
||||
def _ensure_all_simple(self, exprs: Sequence[DepthTrackingExprT_contra]) -> None:
|
||||
for expr in exprs:
|
||||
if not self._is_simple(expr):
|
||||
name = self.compliant._implementation.name.lower()
|
||||
msg = (
|
||||
f"Non-trivial complex aggregation found.\n\n"
|
||||
f"Hint: you were probably trying to apply a non-elementary aggregation with a"
|
||||
f"{name!r} table.\n"
|
||||
"Please rewrite your query such that group-by aggregations "
|
||||
"are elementary. For example, instead of:\n\n"
|
||||
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
|
||||
"use:\n\n"
|
||||
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
@classmethod
|
||||
def _is_simple(cls, expr: DepthTrackingExprAny, /) -> bool:
|
||||
"""Return `True` is we can efficiently use `expr` in a native `group_by` context."""
|
||||
return expr._is_elementary() and cls._leaf_name(expr) in cls._REMAP_AGGS
|
||||
|
||||
@classmethod
|
||||
def _remap_expr_name(
|
||||
cls, name: NarwhalsAggregation | Any, /
|
||||
) -> NativeAggregationT_co:
|
||||
"""Replace `name`, with some native representation.
|
||||
|
||||
Arguments:
|
||||
name: Name of a `nw.Expr` aggregation method.
|
||||
"""
|
||||
return cls._REMAP_AGGS.get(name, name)
|
||||
|
||||
@classmethod
|
||||
def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any:
|
||||
"""Return the last function name in the chain defined by `expr`."""
|
||||
return _RE_LEAF_NAME.sub("", expr._function_name)
|
||||
|
||||
|
||||
class EagerGroupBy(
|
||||
DepthTrackingGroupBy[CompliantDataFrameT, EagerExprT_contra, NativeAggregationT_co],
|
||||
DataFrameGroupBy[CompliantDataFrameT, EagerExprT_contra],
|
||||
Protocol[CompliantDataFrameT, EagerExprT_contra, NativeAggregationT_co],
|
||||
): ...
|
238
lib/python3.11/site-packages/narwhals/_compliant/namespace.py
Normal file
238
lib/python3.11/site-packages/narwhals/_compliant/namespace.py
Normal file
@ -0,0 +1,238 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Protocol, overload
|
||||
|
||||
from narwhals._compliant.typing import (
|
||||
CompliantExprT,
|
||||
CompliantFrameT,
|
||||
CompliantLazyFrameT,
|
||||
DepthTrackingExprT,
|
||||
EagerDataFrameT,
|
||||
EagerExprT,
|
||||
EagerSeriesT,
|
||||
LazyExprT,
|
||||
NativeFrameT,
|
||||
NativeFrameT_co,
|
||||
NativeSeriesT,
|
||||
)
|
||||
from narwhals._expression_parsing import is_expr, is_series
|
||||
from narwhals._utils import (
|
||||
exclude_column_names,
|
||||
get_column_names,
|
||||
passthrough_column_names,
|
||||
)
|
||||
from narwhals.dependencies import is_numpy_array, is_numpy_array_2d
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Container, Iterable, Sequence
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from narwhals._compliant.selectors import CompliantSelectorNamespace
|
||||
from narwhals._compliant.when_then import CompliantWhen, EagerWhen
|
||||
from narwhals._utils import Implementation, Version
|
||||
from narwhals.expr import Expr
|
||||
from narwhals.series import Series
|
||||
from narwhals.typing import (
|
||||
ConcatMethod,
|
||||
Into1DArray,
|
||||
IntoDType,
|
||||
IntoSchema,
|
||||
NonNestedLiteral,
|
||||
_1DArray,
|
||||
_2DArray,
|
||||
)
|
||||
|
||||
Incomplete: TypeAlias = Any
|
||||
|
||||
__all__ = [
|
||||
"CompliantNamespace",
|
||||
"DepthTrackingNamespace",
|
||||
"EagerNamespace",
|
||||
"LazyNamespace",
|
||||
]
|
||||
|
||||
|
||||
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
|
||||
# NOTE: `narwhals`
|
||||
_implementation: Implementation
|
||||
_version: Version
|
||||
|
||||
@property
|
||||
def _expr(self) -> type[CompliantExprT]: ...
|
||||
def parse_into_expr(
|
||||
self, data: Expr | NonNestedLiteral | Any, /, *, str_as_lit: bool
|
||||
) -> CompliantExprT | NonNestedLiteral:
|
||||
if is_expr(data):
|
||||
expr = data._to_compliant_expr(self)
|
||||
assert isinstance(expr, self._expr) # noqa: S101
|
||||
return expr
|
||||
if isinstance(data, str) and not str_as_lit:
|
||||
return self.col(data)
|
||||
return data
|
||||
|
||||
# NOTE: `polars`
|
||||
def all(self) -> CompliantExprT:
|
||||
return self._expr.from_column_names(get_column_names, context=self)
|
||||
|
||||
def col(self, *column_names: str) -> CompliantExprT:
|
||||
return self._expr.from_column_names(
|
||||
passthrough_column_names(column_names), context=self
|
||||
)
|
||||
|
||||
def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
|
||||
return self._expr.from_column_names(
|
||||
partial(exclude_column_names, names=excluded_names), context=self
|
||||
)
|
||||
|
||||
def nth(self, *column_indices: int) -> CompliantExprT:
|
||||
return self._expr.from_column_indices(*column_indices, context=self)
|
||||
|
||||
def len(self) -> CompliantExprT: ...
|
||||
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> CompliantExprT: ...
|
||||
def all_horizontal(
|
||||
self, *exprs: CompliantExprT, ignore_nulls: bool
|
||||
) -> CompliantExprT: ...
|
||||
def any_horizontal(
|
||||
self, *exprs: CompliantExprT, ignore_nulls: bool
|
||||
) -> CompliantExprT: ...
|
||||
def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
|
||||
def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
|
||||
def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
|
||||
def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
|
||||
def concat(
|
||||
self, items: Iterable[CompliantFrameT], *, how: ConcatMethod
|
||||
) -> CompliantFrameT: ...
|
||||
def when(
|
||||
self, predicate: CompliantExprT
|
||||
) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ...
|
||||
def concat_str(
|
||||
self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool
|
||||
) -> CompliantExprT: ...
|
||||
@property
|
||||
def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
|
||||
def coalesce(self, *exprs: CompliantExprT) -> CompliantExprT: ...
|
||||
|
||||
|
||||
class DepthTrackingNamespace(
|
||||
CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
|
||||
Protocol[CompliantFrameT, DepthTrackingExprT],
|
||||
):
|
||||
def all(self) -> DepthTrackingExprT:
|
||||
return self._expr.from_column_names(
|
||||
get_column_names, function_name="all", context=self
|
||||
)
|
||||
|
||||
def col(self, *column_names: str) -> DepthTrackingExprT:
|
||||
return self._expr.from_column_names(
|
||||
passthrough_column_names(column_names), function_name="col", context=self
|
||||
)
|
||||
|
||||
def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT:
|
||||
return self._expr.from_column_names(
|
||||
partial(exclude_column_names, names=excluded_names),
|
||||
function_name="exclude",
|
||||
context=self,
|
||||
)
|
||||
|
||||
|
||||
class LazyNamespace(
|
||||
CompliantNamespace[CompliantLazyFrameT, LazyExprT],
|
||||
Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT_co],
|
||||
):
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def _lazyframe(self) -> type[CompliantLazyFrameT]: ...
|
||||
|
||||
def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:
|
||||
if self._lazyframe._is_native(data):
|
||||
return self._lazyframe.from_native(data, context=self)
|
||||
msg = f"Unsupported type: {type(data).__name__!r}" # pragma: no cover
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
class EagerNamespace(
|
||||
DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
|
||||
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
|
||||
):
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def _dataframe(self) -> type[EagerDataFrameT]: ...
|
||||
@property
|
||||
def _series(self) -> type[EagerSeriesT]: ...
|
||||
def when(
|
||||
self, predicate: EagerExprT
|
||||
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...
|
||||
|
||||
@overload
|
||||
def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
|
||||
@overload
|
||||
def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ...
|
||||
def from_native(
|
||||
self, data: NativeFrameT | NativeSeriesT | Any, /
|
||||
) -> EagerDataFrameT | EagerSeriesT:
|
||||
if self._dataframe._is_native(data):
|
||||
return self._dataframe.from_native(data, context=self)
|
||||
if self._series._is_native(data):
|
||||
return self._series.from_native(data, context=self)
|
||||
msg = f"Unsupported type: {type(data).__name__!r}"
|
||||
raise TypeError(msg)
|
||||
|
||||
def parse_into_expr(
|
||||
self,
|
||||
data: Expr | Series[NativeSeriesT] | _1DArray | NonNestedLiteral,
|
||||
/,
|
||||
*,
|
||||
str_as_lit: bool,
|
||||
) -> EagerExprT | NonNestedLiteral:
|
||||
if not (is_series(data) or is_numpy_array(data)):
|
||||
return super().parse_into_expr(data, str_as_lit=str_as_lit)
|
||||
return self._expr._from_series(
|
||||
data._compliant_series
|
||||
if is_series(data)
|
||||
else self._series.from_numpy(data, context=self)
|
||||
)
|
||||
|
||||
@overload
|
||||
def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT: ...
|
||||
|
||||
@overload
|
||||
def from_numpy(
|
||||
self, data: _2DArray, /, schema: IntoSchema | Sequence[str] | None
|
||||
) -> EagerDataFrameT: ...
|
||||
|
||||
def from_numpy(
|
||||
self,
|
||||
data: Into1DArray | _2DArray,
|
||||
/,
|
||||
schema: IntoSchema | Sequence[str] | None = None,
|
||||
) -> EagerDataFrameT | EagerSeriesT:
|
||||
if is_numpy_array_2d(data):
|
||||
return self._dataframe.from_numpy(data, schema=schema, context=self)
|
||||
return self._series.from_numpy(data, context=self)
|
||||
|
||||
def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
|
||||
def _concat_horizontal(
|
||||
self, dfs: Sequence[NativeFrameT | Any], /
|
||||
) -> NativeFrameT: ...
|
||||
def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
|
||||
def concat(
|
||||
self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod
|
||||
) -> EagerDataFrameT:
|
||||
dfs = [item.native for item in items]
|
||||
if how == "horizontal":
|
||||
native = self._concat_horizontal(dfs)
|
||||
elif how == "vertical":
|
||||
native = self._concat_vertical(dfs)
|
||||
elif how == "diagonal":
|
||||
native = self._concat_diagonal(dfs)
|
||||
else: # pragma: no cover
|
||||
raise NotImplementedError
|
||||
return self._dataframe.from_native(native, context=self)
|
318
lib/python3.11/site-packages/narwhals/_compliant/selectors.py
Normal file
318
lib/python3.11/site-packages/narwhals/_compliant/selectors.py
Normal file
@ -0,0 +1,318 @@
|
||||
"""Almost entirely complete, generic `selectors` implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Protocol, TypeVar, overload
|
||||
|
||||
from narwhals._compliant.expr import CompliantExpr
|
||||
from narwhals._utils import (
|
||||
_parse_time_unit_and_time_zone,
|
||||
dtype_matches_time_unit_and_time_zone,
|
||||
get_column_names,
|
||||
is_compliant_dataframe,
|
||||
zip_strict,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Collection, Iterable, Iterator, Sequence
|
||||
from datetime import timezone
|
||||
|
||||
from typing_extensions import Self, TypeAlias, TypeIs
|
||||
|
||||
from narwhals._compliant.expr import NativeExpr
|
||||
from narwhals._compliant.typing import (
|
||||
CompliantDataFrameAny,
|
||||
CompliantExprAny,
|
||||
CompliantFrameAny,
|
||||
CompliantLazyFrameAny,
|
||||
CompliantSeriesAny,
|
||||
CompliantSeriesOrNativeExprAny,
|
||||
EvalNames,
|
||||
EvalSeries,
|
||||
)
|
||||
from narwhals._utils import Implementation, Version, _LimitedContext
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.typing import TimeUnit
|
||||
|
||||
__all__ = [
|
||||
"CompliantSelector",
|
||||
"CompliantSelectorNamespace",
|
||||
"EagerSelectorNamespace",
|
||||
"LazySelectorNamespace",
|
||||
]
|
||||
|
||||
|
||||
SeriesOrExprT = TypeVar("SeriesOrExprT", bound="CompliantSeriesOrNativeExprAny")
|
||||
SeriesT = TypeVar("SeriesT", bound="CompliantSeriesAny")
|
||||
ExprT = TypeVar("ExprT", bound="NativeExpr")
|
||||
FrameT = TypeVar("FrameT", bound="CompliantFrameAny")
|
||||
DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrameAny")
|
||||
LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrameAny")
|
||||
SelectorOrExpr: TypeAlias = (
|
||||
"CompliantSelector[FrameT, SeriesOrExprT] | CompliantExpr[FrameT, SeriesOrExprT]"
|
||||
)
|
||||
|
||||
|
||||
class CompliantSelectorNamespace(Protocol[FrameT, SeriesOrExprT]):
|
||||
# NOTE: `narwhals`
|
||||
_implementation: Implementation
|
||||
_version: Version
|
||||
|
||||
@property
|
||||
def _selector(self) -> type[CompliantSelector[FrameT, SeriesOrExprT]]: ...
|
||||
@classmethod
|
||||
def from_namespace(cls, context: _LimitedContext, /) -> Self:
|
||||
obj = cls.__new__(cls)
|
||||
obj._implementation = context._implementation
|
||||
obj._version = context._version
|
||||
return obj
|
||||
|
||||
def _iter_columns(self, df: FrameT, /) -> Iterator[SeriesOrExprT]: ...
|
||||
def _iter_schema(self, df: FrameT, /) -> Iterator[tuple[str, DType]]: ...
|
||||
def _iter_columns_dtypes(
|
||||
self, df: FrameT, /
|
||||
) -> Iterator[tuple[SeriesOrExprT, DType]]: ...
|
||||
def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]:
|
||||
yield from zip_strict(self._iter_columns(df), df.columns)
|
||||
|
||||
def _is_dtype(
|
||||
self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], /
|
||||
) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
||||
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
||||
return [
|
||||
ser for ser, tp in self._iter_columns_dtypes(df) if isinstance(tp, dtype)
|
||||
]
|
||||
|
||||
def names(df: FrameT) -> Sequence[str]:
|
||||
return [name for name, tp in self._iter_schema(df) if isinstance(tp, dtype)]
|
||||
|
||||
return self._selector.from_callables(series, names, context=self)
|
||||
|
||||
# NOTE: `polars`
|
||||
def by_dtype(
|
||||
self, dtypes: Collection[DType | type[DType]]
|
||||
) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
||||
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
||||
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp in dtypes]
|
||||
|
||||
def names(df: FrameT) -> Sequence[str]:
|
||||
return [name for name, tp in self._iter_schema(df) if tp in dtypes]
|
||||
|
||||
return self._selector.from_callables(series, names, context=self)
|
||||
|
||||
def matches(self, pattern: str) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
||||
p = re.compile(pattern)
|
||||
|
||||
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
||||
if (
|
||||
is_compliant_dataframe(df)
|
||||
and not self._implementation.is_duckdb()
|
||||
and not self._implementation.is_ibis()
|
||||
):
|
||||
return [df.get_column(col) for col in df.columns if p.search(col)]
|
||||
|
||||
return [ser for ser, name in self._iter_columns_names(df) if p.search(name)]
|
||||
|
||||
def names(df: FrameT) -> Sequence[str]:
|
||||
return [col for col in df.columns if p.search(col)]
|
||||
|
||||
return self._selector.from_callables(series, names, context=self)
|
||||
|
||||
def numeric(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
||||
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
||||
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp.is_numeric()]
|
||||
|
||||
def names(df: FrameT) -> Sequence[str]:
|
||||
return [name for name, tp in self._iter_schema(df) if tp.is_numeric()]
|
||||
|
||||
return self._selector.from_callables(series, names, context=self)
|
||||
|
||||
def categorical(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
||||
return self._is_dtype(self._version.dtypes.Categorical)
|
||||
|
||||
def string(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
||||
return self._is_dtype(self._version.dtypes.String)
|
||||
|
||||
def boolean(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
||||
return self._is_dtype(self._version.dtypes.Boolean)
|
||||
|
||||
def all(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
|
||||
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
|
||||
            return list(self._iter_columns(df))

        return self._selector.from_callables(series, get_column_names, context=self)

    def datetime(
        self,
        time_unit: TimeUnit | Iterable[TimeUnit] | None,
        time_zone: str | timezone | Iterable[str | timezone | None] | None,
    ) -> CompliantSelector[FrameT, SeriesOrExprT]:
        time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
        matches = partial(
            dtype_matches_time_unit_and_time_zone,
            dtypes=self._version.dtypes,
            time_units=time_units,
            time_zones=time_zones,
        )

        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return [ser for ser, tp in self._iter_columns_dtypes(df) if matches(tp)]

        def names(df: FrameT) -> Sequence[str]:
            return [name for name, tp in self._iter_schema(df) if matches(tp)]

        return self._selector.from_callables(series, names, context=self)


class EagerSelectorNamespace(
    CompliantSelectorNamespace[DataFrameT, SeriesT], Protocol[DataFrameT, SeriesT]
):
    def _iter_schema(self, df: DataFrameT, /) -> Iterator[tuple[str, DType]]:
        for ser in self._iter_columns(df):
            yield ser.name, ser.dtype

    def _iter_columns(self, df: DataFrameT, /) -> Iterator[SeriesT]:
        yield from df.iter_columns()

    def _iter_columns_dtypes(self, df: DataFrameT, /) -> Iterator[tuple[SeriesT, DType]]:
        for ser in self._iter_columns(df):
            yield ser, ser.dtype


class LazySelectorNamespace(
    CompliantSelectorNamespace[LazyFrameT, ExprT], Protocol[LazyFrameT, ExprT]
):
    def _iter_schema(self, df: LazyFrameT) -> Iterator[tuple[str, DType]]:
        yield from df.schema.items()

    def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]:
        yield from df._iter_columns()

    def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]:
        yield from zip_strict(self._iter_columns(df), df.schema.values())


class CompliantSelector(
    CompliantExpr[FrameT, SeriesOrExprT], Protocol[FrameT, SeriesOrExprT]
):
    _call: EvalSeries[FrameT, SeriesOrExprT]
    _function_name: str
    _implementation: Implementation
    _version: Version

    @classmethod
    def from_callables(
        cls,
        call: EvalSeries[FrameT, SeriesOrExprT],
        evaluate_output_names: EvalNames[FrameT],
        *,
        context: _LimitedContext,
    ) -> Self:
        obj = cls.__new__(cls)
        obj._call = call
        obj._evaluate_output_names = evaluate_output_names
        obj._alias_output_names = None
        obj._implementation = context._implementation
        obj._version = context._version
        return obj

    @property
    def selectors(self) -> CompliantSelectorNamespace[FrameT, SeriesOrExprT]:
        return self.__narwhals_namespace__().selectors

    def _to_expr(self) -> CompliantExpr[FrameT, SeriesOrExprT]: ...

    def _is_selector(
        self, other: Self | CompliantExpr[FrameT, SeriesOrExprT]
    ) -> TypeIs[CompliantSelector[FrameT, SeriesOrExprT]]:
        return isinstance(other, type(self))

    @overload
    def __sub__(self, other: Self) -> Self: ...
    @overload
    def __sub__(
        self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __sub__(
        self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [
                    x
                    for x, name in zip_strict(self(df), lhs_names)
                    if name not in rhs_names
                ]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [x for x in lhs_names if x not in rhs_names]

            return self.selectors._selector.from_callables(series, names, context=self)
        return self._to_expr() - other

    @overload
    def __or__(self, other: Self) -> Self: ...
    @overload
    def __or__(
        self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __or__(
        self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [
                    *(
                        x
                        for x, name in zip_strict(self(df), lhs_names)
                        if name not in rhs_names
                    ),
                    *other(df),
                ]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]

            return self.selectors._selector.from_callables(series, names, context=self)
        return self._to_expr() | other

    @overload
    def __and__(self, other: Self) -> Self: ...
    @overload
    def __and__(
        self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __and__(
        self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [
                    x for x, name in zip_strict(self(df), lhs_names) if name in rhs_names
                ]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [x for x in lhs_names if x in rhs_names]

            return self.selectors._selector.from_callables(series, names, context=self)
        return self._to_expr() & other

    def __invert__(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        return self.selectors.all() - self


def _eval_lhs_rhs(
    df: CompliantFrameAny, lhs: CompliantExprAny, rhs: CompliantExprAny
) -> tuple[Sequence[str], Sequence[str]]:
    return lhs._evaluate_output_names(df), rhs._evaluate_output_names(df)
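
# --- Illustrative sketch (not part of this commit) --------------------------------
# The selector dunders above implement name-based set algebra while preserving
# column order: `-` keeps left-hand columns whose names are absent on the right,
# `|` appends the right-hand columns after the remaining left-hand ones, `&` keeps
# the left-hand columns whose names also appear on the right, and `~` is
# `selectors.all() - self`.  A minimal stand-alone model of those semantics:
def difference(lhs: list[str], rhs: list[str]) -> list[str]:
    return [x for x in lhs if x not in rhs]


def union(lhs: list[str], rhs: list[str]) -> list[str]:
    return [*(x for x in lhs if x not in rhs), *rhs]


def intersection(lhs: list[str], rhs: list[str]) -> list[str]:
    return [x for x in lhs if x in rhs]


assert difference(["a", "b", "c"], ["b"]) == ["a", "c"]
assert union(["a", "b"], ["b", "c"]) == ["a", "b", "c"]
assert intersection(["a", "b", "c"], ["c", "a"]) == ["a", "c"]
# -----------------------------------------------------------------------------------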
411
lib/python3.11/site-packages/narwhals/_compliant/series.py
Normal file
@ -0,0 +1,411 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol

from narwhals._compliant.any_namespace import (
    CatNamespace,
    DateTimeNamespace,
    ListNamespace,
    StringNamespace,
    StructNamespace,
)
from narwhals._compliant.column import CompliantColumn
from narwhals._compliant.typing import (
    CompliantSeriesT_co,
    EagerDataFrameAny,
    EagerSeriesT_co,
    NativeSeriesT,
    NativeSeriesT_co,
)
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
from narwhals._typing_compat import TypeVar, assert_never
from narwhals._utils import (
    _StoresCompliant,
    _StoresNative,
    is_compliant_series,
    is_sized_multi_index_selector,
    unstable,
)

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Sequence
    from types import ModuleType

    import pandas as pd
    import polars as pl
    import pyarrow as pa
    from typing_extensions import NotRequired, Self, TypedDict

    from narwhals._compliant.dataframe import CompliantDataFrame
    from narwhals._compliant.namespace import EagerNamespace
    from narwhals._utils import Implementation, Version, _LimitedContext
    from narwhals.dtypes import DType
    from narwhals.series import Series
    from narwhals.typing import (
        Into1DArray,
        IntoDType,
        MultiIndexSelector,
        RollingInterpolationMethod,
        SizedMultiIndexSelector,
        _1DArray,
        _SliceIndex,
    )

    class HistData(TypedDict, Generic[NativeSeriesT, "_CountsT_co"]):
        breakpoint: NotRequired[list[float] | _1DArray | list[Any]]
        count: NativeSeriesT | _1DArray | _CountsT_co | list[Any]


_CountsT_co = TypeVar("_CountsT_co", bound="Iterable[Any]", covariant=True)

__all__ = [
    "CompliantSeries",
    "EagerSeries",
    "EagerSeriesCatNamespace",
    "EagerSeriesDateTimeNamespace",
    "EagerSeriesHist",
    "EagerSeriesListNamespace",
    "EagerSeriesNamespace",
    "EagerSeriesStringNamespace",
    "EagerSeriesStructNamespace",
]


class CompliantSeries(
    NumpyConvertible["_1DArray", "Into1DArray"],
    FromIterable,
    FromNative[NativeSeriesT],
    ToNarwhals["Series[NativeSeriesT]"],
    CompliantColumn,
    Protocol[NativeSeriesT],
):
    # NOTE: `narwhals`
    _implementation: Implementation

    @property
    def native(self) -> NativeSeriesT: ...
    def __narwhals_series__(self) -> Self:
        return self

    def __native_namespace__(self) -> ModuleType: ...
    @classmethod
    def from_native(cls, data: NativeSeriesT, /, *, context: _LimitedContext) -> Self: ...
    def to_narwhals(self) -> Series[NativeSeriesT]:
        return self._version.series(self, level="full")

    def _with_native(self, series: Any) -> Self: ...
    def _with_version(self, version: Version) -> Self: ...

    # NOTE: `polars`
    @property
    def dtype(self) -> DType: ...
    @property
    def name(self) -> str: ...
    def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: ...
    def __contains__(self, other: Any) -> bool: ...
    def __getitem__(self, item: MultiIndexSelector[Self]) -> Any: ...
    def __iter__(self) -> Iterator[Any]: ...
    def __len__(self) -> int:
        return len(self.native)

    @classmethod
    def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self: ...
    @classmethod
    def from_iterable(
        cls,
        data: Iterable[Any],
        /,
        *,
        context: _LimitedContext,
        name: str = "",
        dtype: IntoDType | None = None,
    ) -> Self: ...
    def __radd__(self, other: Any) -> Self: ...
    def __rand__(self, other: Any) -> Self: ...
    def __rmul__(self, other: Any) -> Self: ...
    def __ror__(self, other: Any) -> Self: ...
    def all(self) -> bool: ...
    def any(self) -> bool: ...
    def arg_max(self) -> int: ...
    def arg_min(self) -> int: ...
    def arg_true(self) -> Self: ...
    def count(self) -> int: ...
    def filter(self, predicate: Any) -> Self: ...
    def gather_every(self, n: int, offset: int) -> Self: ...
    def head(self, n: int) -> Self: ...
    def is_empty(self) -> bool:
        return self.len() == 0

    def is_sorted(self, *, descending: bool) -> bool: ...
    def item(self, index: int | None) -> Any: ...
    def kurtosis(self) -> float | None: ...
    def len(self) -> int: ...
    def max(self) -> Any: ...
    def mean(self) -> float: ...
    def median(self) -> float: ...
    def min(self) -> Any: ...
    def n_unique(self) -> int: ...
    def null_count(self) -> int: ...
    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> float: ...
    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self: ...
    def scatter(self, indices: int | Sequence[int], values: Any) -> Self: ...
    def shift(self, n: int) -> Self: ...
    def skew(self) -> float | None: ...
    def sort(self, *, descending: bool, nulls_last: bool) -> Self: ...
    def std(self, *, ddof: int) -> float: ...
    def sum(self) -> float: ...
    def tail(self, n: int) -> Self: ...
    def to_arrow(self) -> pa.Array[Any]: ...
    def to_dummies(
        self, *, separator: str, drop_first: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]: ...
    def to_frame(self) -> CompliantDataFrame[Self, Any, Any, Any]: ...
    def to_list(self) -> list[Any]: ...
    def to_pandas(self) -> pd.Series[Any]: ...
    def to_polars(self) -> pl.Series: ...
    def unique(self, *, maintain_order: bool = False) -> Self: ...
    def value_counts(
        self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]: ...
    def var(self, *, ddof: int) -> float: ...
    def zip_with(self, mask: Any, other: Any) -> Self: ...

    # NOTE: *Technically* `polars`
    @unstable
    def hist_from_bins(
        self, bins: list[float], *, include_breakpoint: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]:
        """`Series.hist(bins=..., bin_count=None)`."""
        ...

    @unstable
    def hist_from_bin_count(
        self, bin_count: int, *, include_breakpoint: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]:
        """`Series.hist(bins=None, bin_count=...)`."""
        ...


class EagerSeries(CompliantSeries[NativeSeriesT], Protocol[NativeSeriesT]):
    _native_series: Any
    _implementation: Implementation
    _version: Version
    _broadcast: bool

    @property
    def _backend_version(self) -> tuple[int, ...]:
        return self._implementation._backend_version()

    @classmethod
    def _align_full_broadcast(cls, *series: Self) -> Sequence[Self]:
        """Ensure all of `series` have the same length (and index if `pandas`).

        Scalars get broadcasted to the full length of the longest Series.

        This is useful when you need to construct a full Series anyway, such as:

            DataFrame.select(...)

        It should not be used in binary operations, such as:

            nw.col("a") - nw.col("a").mean()

        because then it's more efficient to extract the right-hand-side's single element as a scalar.
        """
        ...

    def _from_scalar(self, value: Any) -> Self:
        return self.from_iterable([value], name=self.name, context=self)

    def _with_native(
        self, series: NativeSeriesT, *, preserve_broadcast: bool = False
    ) -> Self:
        """Return a new `CompliantSeries`, wrapping the native `series`.

        In cases when operations are known to not affect whether a result should
        be broadcast, we can pass `preserve_broadcast=True`.
        Set this with care - it should only be set for unary expressions which don't
        change length or order, such as `.alias` or `.fill_null`. If in doubt, don't
        set it, you probably don't need it.
        """
        ...

    def __narwhals_namespace__(
        self,
    ) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ...

    def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
    def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
    def __getitem__(self, item: MultiIndexSelector[Self]) -> Self:
        if isinstance(item, (slice, range)):
            return self._gather_slice(item)
        if is_compliant_series(item):
            return self._gather(item.native)
        elif is_sized_multi_index_selector(item):  # noqa: RET505
            return self._gather(item)
        assert_never(item)

    @property
    def str(self) -> EagerSeriesStringNamespace[Self, NativeSeriesT]: ...
    @property
    def dt(self) -> EagerSeriesDateTimeNamespace[Self, NativeSeriesT]: ...
    @property
    def cat(self) -> EagerSeriesCatNamespace[Self, NativeSeriesT]: ...
    @property
    def list(self) -> EagerSeriesListNamespace[Self, NativeSeriesT]: ...
    @property
    def struct(self) -> EagerSeriesStructNamespace[Self, NativeSeriesT]: ...
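
# --- Illustrative sketch (not part of this commit) --------------------------------
# `EagerSeries` marks length-1 results of reductions (e.g. `.mean()`) with
# `_broadcast = True`; `_align_full_broadcast` then stretches such scalars to the
# length of the longest series before a full-length result is materialised.
# A toy, list-based model of that behaviour (the helper name is hypothetical):
def align_full_broadcast(*series: list[float]) -> list[list[float]]:
    max_len = max(len(s) for s in series)
    # Length-1 inputs stand in for broadcast scalars; everything else must already match.
    return [s * max_len if len(s) == 1 else s for s in series]


values = [1.0, 2.0, 3.0]
mean = [2.0]  # a reduction result, kept as a length-1 "broadcast" series
assert align_full_broadcast(values, mean) == [[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]
# -----------------------------------------------------------------------------------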


class _SeriesNamespace(  # type: ignore[misc]
    _StoresCompliant[CompliantSeriesT_co],
    _StoresNative[NativeSeriesT_co],
    Protocol[CompliantSeriesT_co, NativeSeriesT_co],
):
    _compliant_series: CompliantSeriesT_co

    @property
    def compliant(self) -> CompliantSeriesT_co:
        return self._compliant_series

    @property
    def implementation(self) -> Implementation:
        return self.compliant._implementation

    @property
    def backend_version(self) -> tuple[int, ...]:
        return self.implementation._backend_version()

    @property
    def version(self) -> Version:
        return self.compliant._version

    @property
    def native(self) -> NativeSeriesT_co:
        return self._compliant_series.native  # type: ignore[no-any-return]

    def with_native(self, series: Any, /) -> CompliantSeriesT_co:
        return self.compliant._with_native(series)


class EagerSeriesNamespace(
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    Generic[EagerSeriesT_co, NativeSeriesT_co],
):
    _compliant_series: EagerSeriesT_co

    def __init__(self, series: EagerSeriesT_co, /) -> None:
        self._compliant_series = series


class EagerSeriesCatNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    CatNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesDateTimeNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    DateTimeNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesListNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    ListNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesStringNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    StringNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesStructNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    StructNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesHist(Protocol[NativeSeriesT, _CountsT_co]):
    _series: EagerSeries[NativeSeriesT]
    _breakpoint: bool
    _data: HistData[NativeSeriesT, _CountsT_co]

    @property
    def native(self) -> NativeSeriesT:
        return self._series.native

    @classmethod
    def from_series(
        cls, series: EagerSeries[NativeSeriesT], *, include_breakpoint: bool
    ) -> Self:
        obj = cls.__new__(cls)
        obj._series = series
        obj._breakpoint = include_breakpoint
        return obj

    def to_frame(self) -> EagerDataFrameAny: ...
    def _linear_space(  # NOTE: Roughly `pl.linear_space`
        self,
        start: float,
        end: float,
        num_samples: int,
        *,
        closed: Literal["both", "none"] = "both",
    ) -> _1DArray: ...

    # NOTE: *Could* be handled at narwhals-level
    def is_empty_series(self) -> bool: ...

    # NOTE: **Should** be handled at narwhals-level
    def data_empty(self) -> HistData[NativeSeriesT, _CountsT_co]:
        return {"breakpoint": [], "count": []} if self._breakpoint else {"count": []}

    # NOTE: *Could* be handled at narwhals-level, **iff** we add `nw.repeat`, `nw.linear_space`
    # See https://github.com/narwhals-dev/narwhals/pull/2839#discussion_r2215630696
    def series_empty(
        self, arg: int | list[float], /
    ) -> HistData[NativeSeriesT, _CountsT_co]: ...

    def with_bins(self, bins: list[float], /) -> Self:
        if len(bins) <= 1:
            self._data = self.data_empty()
        elif self.is_empty_series():
            self._data = self.series_empty(bins)
        else:
            self._data = self._calculate_hist(bins)
        return self

    def with_bin_count(self, bin_count: int, /) -> Self:
        if bin_count == 0:
            self._data = self.data_empty()
        elif self.is_empty_series():
            self._data = self.series_empty(bin_count)
        else:
            self._data = self._calculate_hist(self._calculate_bins(bin_count))
        return self

    def _calculate_breakpoint(self, arg: int | list[float], /) -> list[float] | _1DArray:
        bins = self._linear_space(0, 1, arg + 1) if isinstance(arg, int) else arg
        return bins[1:]

    def _calculate_bins(self, bin_count: int) -> _1DArray: ...
    def _calculate_hist(
        self, bins: list[float] | _1DArray
    ) -> HistData[NativeSeriesT, _CountsT_co]: ...
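
# --- Illustrative sketch (not part of this commit) --------------------------------
# One plausible concrete backend for the `EagerSeriesHist` hooks above, using NumPy:
# `calculate_bins` spreads `bin_count + 1` edges across the data range and
# `calculate_hist` returns the per-bin counts plus the right-hand bin edges
# (mirroring `_calculate_breakpoint`, which keeps `bins[1:]`).  Names are hypothetical.
from typing import Any

import numpy as np


def calculate_bins(values: np.ndarray, bin_count: int) -> np.ndarray:
    return np.linspace(values.min(), values.max(), num=bin_count + 1)


def calculate_hist(values: np.ndarray, bins: np.ndarray) -> dict[str, list[Any]]:
    counts, edges = np.histogram(values, bins=bins)
    return {"breakpoint": edges[1:].tolist(), "count": counts.tolist()}


data = np.array([0.0, 0.5, 1.0, 1.5, 2.0])
hist = calculate_hist(data, calculate_bins(data, bin_count=2))
assert hist == {"breakpoint": [1.0, 2.0], "count": [2, 3]}
# -----------------------------------------------------------------------------------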
206
lib/python3.11/site-packages/narwhals/_compliant/typing.py
Normal file
@ -0,0 +1,206 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar

if TYPE_CHECKING:
    from typing_extensions import TypeAlias

    from narwhals._compliant.dataframe import (
        CompliantDataFrame,
        CompliantFrame,
        CompliantLazyFrame,
        EagerDataFrame,
    )
    from narwhals._compliant.expr import (
        CompliantExpr,
        DepthTrackingExpr,
        EagerExpr,
        ImplExpr,
        LazyExpr,
        NativeExpr,
    )
    from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
    from narwhals._compliant.series import CompliantSeries, EagerSeries
    from narwhals._compliant.window import WindowInputs
    from narwhals.typing import (
        FillNullStrategy,
        IntoLazyFrame,
        ModeKeepStrategy,
        NativeDataFrame,
        NativeFrame,
        NativeSeries,
        RankMethod,
        RollingInterpolationMethod,
    )

    class ScalarKwargs(TypedDict, total=False):
        """Non-expressifiable args which we may need to reuse in `agg` or `over`."""

        adjust: bool
        alpha: float | None
        center: int
        com: float | None
        ddof: int
        descending: bool
        half_life: float | None
        ignore_nulls: bool
        interpolation: RollingInterpolationMethod
        keep: ModeKeepStrategy
        limit: int | None
        method: RankMethod
        min_samples: int
        n: int
        quantile: float
        reverse: bool
        span: float | None
        strategy: FillNullStrategy | None
        window_size: int
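
# --- Illustrative sketch (not part of this commit) --------------------------------
# Why `ScalarKwargs` exists: values such as `ddof` are not expressions, so a backend
# can capture them when an expression like `Expr.std(ddof=1)` is constructed and
# replay them verbatim later, e.g. when that expression is re-evaluated inside
# `agg(...)` or `over(...)`.  The helper below is hypothetical.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from narwhals._compliant.typing import ScalarKwargs


def capture_std_kwargs(ddof: int) -> "ScalarKwargs":
    return {"ddof": ddof}  # captured once, at expression-construction time


kwargs = capture_std_kwargs(1)
assert kwargs == {"ddof": 1}  # ...and reused when the backend finally computes
# -----------------------------------------------------------------------------------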


__all__ = [
    "AliasName",
    "AliasNames",
    "CompliantDataFrameT",
    "CompliantFrameT",
    "CompliantLazyFrameT",
    "CompliantSeriesT",
    "EvalNames",
    "EvalSeries",
    "NarwhalsAggregation",
    "NativeFrameT_co",
    "NativeSeriesT_co",
]
CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]"
CompliantSeriesAny: TypeAlias = "CompliantSeries[Any]"
CompliantSeriesOrNativeExprAny: TypeAlias = "CompliantSeriesAny | NativeExpr"
CompliantDataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any, Any]"
CompliantLazyFrameAny: TypeAlias = "CompliantLazyFrame[Any, Any, Any]"
CompliantFrameAny: TypeAlias = "CompliantFrame[Any, Any, Any]"
CompliantNamespaceAny: TypeAlias = "CompliantNamespace[Any, Any]"

ImplExprAny: TypeAlias = "ImplExpr[Any, Any]"

DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]"

EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any, Any]"
EagerSeriesAny: TypeAlias = "EagerSeries[Any]"
EagerExprAny: TypeAlias = "EagerExpr[Any, Any]"
EagerNamespaceAny: TypeAlias = "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame, NativeSeries]"

LazyExprAny: TypeAlias = "LazyExpr[Any, Any]"

NativeExprT = TypeVar("NativeExprT", bound="NativeExpr")
NativeExprT_co = TypeVar("NativeExprT_co", bound="NativeExpr", covariant=True)
NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries")
NativeSeriesT_co = TypeVar("NativeSeriesT_co", bound="NativeSeries", covariant=True)
NativeSeriesT_contra = TypeVar(
    "NativeSeriesT_contra", bound="NativeSeries", contravariant=True
)
NativeDataFrameT = TypeVar("NativeDataFrameT", bound="NativeDataFrame")
NativeLazyFrameT = TypeVar("NativeLazyFrameT", bound="IntoLazyFrame")
NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame")
NativeFrameT_co = TypeVar("NativeFrameT_co", bound="NativeFrame", covariant=True)
NativeFrameT_contra = TypeVar(
    "NativeFrameT_contra", bound="NativeFrame", contravariant=True
)

CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny)
CompliantExprT_co = TypeVar("CompliantExprT_co", bound=CompliantExprAny, covariant=True)
CompliantExprT_contra = TypeVar(
    "CompliantExprT_contra", bound=CompliantExprAny, contravariant=True
)
CompliantSeriesT = TypeVar("CompliantSeriesT", bound=CompliantSeriesAny)
CompliantSeriesT_co = TypeVar(
    "CompliantSeriesT_co", bound=CompliantSeriesAny, covariant=True
)
CompliantSeriesOrNativeExprT = TypeVar(
    "CompliantSeriesOrNativeExprT", bound=CompliantSeriesOrNativeExprAny
)
CompliantSeriesOrNativeExprT_co = TypeVar(
    "CompliantSeriesOrNativeExprT_co",
    bound=CompliantSeriesOrNativeExprAny,
    covariant=True,
)
CompliantFrameT = TypeVar("CompliantFrameT", bound=CompliantFrameAny)
CompliantFrameT_co = TypeVar(
    "CompliantFrameT_co", bound=CompliantFrameAny, covariant=True
)
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound=CompliantDataFrameAny)
CompliantDataFrameT_co = TypeVar(
    "CompliantDataFrameT_co", bound=CompliantDataFrameAny, covariant=True
)
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound=CompliantLazyFrameAny)
CompliantLazyFrameT_co = TypeVar(
    "CompliantLazyFrameT_co", bound=CompliantLazyFrameAny, covariant=True
)
CompliantNamespaceT = TypeVar("CompliantNamespaceT", bound=CompliantNamespaceAny)
CompliantNamespaceT_co = TypeVar(
    "CompliantNamespaceT_co", bound=CompliantNamespaceAny, covariant=True
)

ImplExprT_contra = TypeVar("ImplExprT_contra", bound=ImplExprAny, contravariant=True)

DepthTrackingExprT = TypeVar("DepthTrackingExprT", bound=DepthTrackingExprAny)
DepthTrackingExprT_contra = TypeVar(
    "DepthTrackingExprT_contra", bound=DepthTrackingExprAny, contravariant=True
)

EagerExprT = TypeVar("EagerExprT", bound=EagerExprAny)
EagerExprT_contra = TypeVar("EagerExprT_contra", bound=EagerExprAny, contravariant=True)
EagerSeriesT = TypeVar("EagerSeriesT", bound=EagerSeriesAny)
EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True)

# NOTE: `pyright` gives false (8) positives if this uses `EagerDataFrameAny`?
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any, Any]")

LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True)

AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]]
"""A function aliasing a *sequence* of column names."""

AliasName: TypeAlias = Callable[[str], str]
"""A function aliasing a *single* column name."""

EvalSeries: TypeAlias = Callable[
    [CompliantFrameT], Sequence[CompliantSeriesOrNativeExprT]
]
"""A function from a `Frame` to a sequence of `Series`*.

See [underwater unicorn magic](https://narwhals-dev.github.io/narwhals/how_it_works/).
"""

EvalNames: TypeAlias = Callable[[CompliantFrameT], Sequence[str]]
"""A function from a `Frame` to a sequence of column names *before* any aliasing takes place."""

WindowFunction: TypeAlias = (
    "Callable[[CompliantFrameT, WindowInputs[NativeExprT]], Sequence[NativeExprT]]"
)
"""A function evaluated with `over(partition_by=..., order_by=...)`."""

NarwhalsAggregation: TypeAlias = Literal[
    "sum",
    "mean",
    "median",
    "max",
    "min",
    "mode",
    "std",
    "var",
    "len",
    "n_unique",
    "count",
    "quantile",
    "all",
    "any",
]
"""`Expr` methods we aim to support in `DepthTrackingGroupBy`.

Be sure to update me if you're working on one of these:
- https://github.com/narwhals-dev/narwhals/issues/981
- https://github.com/narwhals-dev/narwhals/issues/2385
- https://github.com/narwhals-dev/narwhals/issues/2484
- https://github.com/narwhals-dev/narwhals/issues/2526
- https://github.com/narwhals-dev/narwhals/issues/2660
"""
130
lib/python3.11/site-packages/narwhals/_compliant/when_then.py
Normal file
@ -0,0 +1,130 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast

from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.typing import (
    CompliantExprAny,
    CompliantFrameAny,
    CompliantSeriesOrNativeExprAny,
    EagerDataFrameT,
    EagerExprT,
    EagerSeriesT,
    LazyExprAny,
    NativeSeriesT,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import Self, TypeAlias

    from narwhals._compliant.typing import EvalSeries
    from narwhals._utils import Implementation, Version, _LimitedContext
    from narwhals.typing import NonNestedLiteral


__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen"]

ExprT = TypeVar("ExprT", bound=CompliantExprAny)
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny)
FrameT = TypeVar("FrameT", bound=CompliantFrameAny)

Scalar: TypeAlias = Any
"""A native literal value."""

IntoExpr: TypeAlias = "SeriesT | ExprT | NonNestedLiteral | Scalar"
"""Anything that is convertible into a `CompliantExpr`."""


class CompliantWhen(Protocol[FrameT, SeriesT, ExprT]):
    _condition: ExprT
    _then_value: IntoExpr[SeriesT, ExprT]
    _otherwise_value: IntoExpr[SeriesT, ExprT] | None
    _implementation: Implementation
    _version: Version

    @property
    def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT, Self]]: ...
    def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]: ...
    def then(
        self, value: IntoExpr[SeriesT, ExprT], /
    ) -> CompliantThen[FrameT, SeriesT, ExprT, Self]:
        return self._then.from_when(self, value)

    @classmethod
    def from_expr(cls, condition: ExprT, /, *, context: _LimitedContext) -> Self:
        obj = cls.__new__(cls)
        obj._condition = condition
        obj._then_value = None
        obj._otherwise_value = None
        obj._implementation = context._implementation
        obj._version = context._version
        return obj


WhenT_contra = TypeVar(
    "WhenT_contra", bound=CompliantWhen[Any, Any, Any], contravariant=True
)


class CompliantThen(
    CompliantExpr[FrameT, SeriesT], Protocol[FrameT, SeriesT, ExprT, WhenT_contra]
):
    _call: EvalSeries[FrameT, SeriesT]
    _when_value: CompliantWhen[FrameT, SeriesT, ExprT]
    _implementation: Implementation
    _version: Version

    @classmethod
    def from_when(cls, when: WhenT_contra, then: IntoExpr[SeriesT, ExprT], /) -> Self:
        when._then_value = then
        obj = cls.__new__(cls)
        obj._call = when
        obj._when_value = when
        obj._evaluate_output_names = getattr(
            then, "_evaluate_output_names", lambda _df: ["literal"]
        )
        obj._alias_output_names = getattr(then, "_alias_output_names", None)
        obj._implementation = when._implementation
        obj._version = when._version
        return obj

    def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT:
        self._when_value._otherwise_value = otherwise
        return cast("ExprT", self)


class EagerWhen(
    CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT],
    Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT],
):
    def _if_then_else(
        self,
        when: NativeSeriesT,
        then: NativeSeriesT,
        otherwise: NativeSeriesT | NonNestedLiteral | Scalar,
        /,
    ) -> NativeSeriesT: ...

    def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
        is_expr = self._condition._is_expr
        when: EagerSeriesT = self._condition(df)[0]
        then: EagerSeriesT
        align = when._align_full_broadcast

        if is_expr(self._then_value):
            then = self._then_value(df)[0]
        else:
            then = when.alias("literal")._from_scalar(self._then_value)
            then._broadcast = True

        if is_expr(self._otherwise_value):
            otherwise = self._otherwise_value(df)[0]
            when, then, otherwise = align(when, then, otherwise)
            result = self._if_then_else(when.native, then.native, otherwise.native)
        else:
            when, then = align(when, then)
            result = self._if_then_else(when.native, then.native, self._otherwise_value)
        return [then._with_native(result)]
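
# --- Illustrative sketch (not part of this commit) --------------------------------
# `EagerWhen.__call__` reduces `when/then/otherwise` to a single native ternary:
# evaluate the condition, broadcast-align it with the `then` (and `otherwise`)
# values, then hand the native objects to `_if_then_else`.  For a NumPy-backed
# series, `_if_then_else` could simply be `np.where` (hypothetical backend):
import numpy as np


def if_then_else(when: np.ndarray, then: np.ndarray, otherwise: object) -> np.ndarray:
    return np.where(when, then, otherwise)


condition = np.array([True, False, True])
then = np.array([1, 1, 1])  # a broadcast literal, already aligned to length 3
result = if_then_else(condition, then, 0)
assert result.tolist() == [1, 0, 1]
# -----------------------------------------------------------------------------------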
20
lib/python3.11/site-packages/narwhals/_compliant/window.py
Normal file
@ -0,0 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic

from narwhals._compliant.typing import NativeExprT_co

if TYPE_CHECKING:
    from collections.abc import Sequence

__all__ = ["WindowInputs"]


class WindowInputs(Generic[NativeExprT_co]):
    __slots__ = ("order_by", "partition_by")

    def __init__(
        self, partition_by: Sequence[str | NativeExprT_co], order_by: Sequence[str]
    ) -> None:
        self.partition_by = partition_by
        self.order_by = order_by
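
# --- Illustrative sketch (not part of this commit) --------------------------------
# `WindowInputs` is a plain carrier for the two pieces of an `over(...)` call: the
# partitioning keys and the ordering columns.  A lazy backend's window function
# (see `WindowFunction` in `_compliant/typing.py`) receives it alongside the frame.
# Column names below are made up for illustration.
from narwhals._compliant.window import WindowInputs

inputs = WindowInputs(partition_by=["store"], order_by=["date"])
assert inputs.partition_by == ["store"]
assert inputs.order_by == ["date"]
# -----------------------------------------------------------------------------------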