done
This commit is contained in:
678
lib/python3.11/site-packages/narwhals/_polars/dataframe.py
Normal file
678
lib/python3.11/site-packages/narwhals/_polars/dataframe.py
Normal file
@ -0,0 +1,678 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Mapping, Sequence, Sized
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast, overload
|
||||
|
||||
import polars as pl
|
||||
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
from narwhals._polars.series import PolarsSeries
|
||||
from narwhals._polars.utils import (
|
||||
catch_polars_exception,
|
||||
extract_args_kwargs,
|
||||
native_to_narwhals_dtype,
|
||||
)
|
||||
from narwhals._utils import (
|
||||
Implementation,
|
||||
_into_arrow_table,
|
||||
convert_str_slice_to_int_slice,
|
||||
is_compliant_series,
|
||||
is_index_selector,
|
||||
is_range,
|
||||
is_sequence_like,
|
||||
is_slice_index,
|
||||
is_slice_none,
|
||||
parse_columns_to_drop,
|
||||
requires,
|
||||
)
|
||||
from narwhals.dependencies import is_numpy_array_1d
|
||||
from narwhals.exceptions import ColumnNotFoundError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
from types import ModuleType
|
||||
from typing import Callable
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from typing_extensions import Self, TypeAlias, TypeIs
|
||||
|
||||
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
from narwhals._polars.group_by import PolarsGroupBy, PolarsLazyGroupBy
|
||||
from narwhals._spark_like.utils import SparkSession
|
||||
from narwhals._translate import IntoArrowTable
|
||||
from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
|
||||
from narwhals._utils import Version, _LimitedContext
|
||||
from narwhals.dataframe import DataFrame, LazyFrame
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.typing import (
|
||||
IntoSchema,
|
||||
JoinStrategy,
|
||||
MultiColSelector,
|
||||
MultiIndexSelector,
|
||||
PivotAgg,
|
||||
SingleIndexSelector,
|
||||
_2DArray,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
|
||||
Method: TypeAlias = "Callable[..., R]"
|
||||
"""Generic alias representing all methods implemented via `__getattr__`.
|
||||
|
||||
Where `R` is the return type.
|
||||
"""
|
||||
|
||||
# Frame method names that the Polars wrappers do not re-implement: `__getattr__`
# looks these up on the wrapped native Polars object and calls them directly.
INHERITED_METHODS = frozenset(
    {
        "clone",
        "drop_nulls",
        "estimated_size",
        "explode",
        "filter",
        "gather_every",
        "head",
        "is_unique",
        "item",
        "iter_rows",
        "join_asof",
        "rename",
        "row",
        "rows",
        "sample",
        "select",
        "sink_parquet",
        "sort",
        "tail",
        "to_arrow",
        "to_pandas",
        "unique",
        "with_columns",
        "write_csv",
        "write_parquet",
    }
)
|
||||
|
||||
NativePolarsFrame = TypeVar("NativePolarsFrame", pl.DataFrame, pl.LazyFrame)
|
||||
|
||||
|
||||
class PolarsBaseFrame(Generic[NativePolarsFrame]):
|
||||
drop_nulls: Method[Self]
|
||||
explode: Method[Self]
|
||||
filter: Method[Self]
|
||||
gather_every: Method[Self]
|
||||
head: Method[Self]
|
||||
join_asof: Method[Self]
|
||||
rename: Method[Self]
|
||||
select: Method[Self]
|
||||
sort: Method[Self]
|
||||
tail: Method[Self]
|
||||
unique: Method[Self]
|
||||
with_columns: Method[Self]
|
||||
|
||||
_native_frame: NativePolarsFrame
|
||||
_implementation = Implementation.POLARS
|
||||
_version: Version
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
df: NativePolarsFrame,
|
||||
*,
|
||||
version: Version,
|
||||
validate_backend_version: bool = False,
|
||||
) -> None:
|
||||
self._native_frame = df
|
||||
self._version = version
|
||||
if validate_backend_version:
|
||||
self._validate_backend_version()
|
||||
|
||||
def _validate_backend_version(self) -> None:
|
||||
"""Raise if installed version below `nw._utils.MIN_VERSIONS`.
|
||||
|
||||
**Only use this when moving between backends.**
|
||||
Otherwise, the validation will have taken place already.
|
||||
"""
|
||||
_ = self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def native(self) -> NativePolarsFrame:
|
||||
return self._native_frame
|
||||
|
||||
@property
|
||||
def columns(self) -> list[str]:
|
||||
return self.native.columns
|
||||
|
||||
def __narwhals_namespace__(self) -> PolarsNamespace:
|
||||
return PolarsNamespace(version=self._version)
|
||||
|
||||
def __native_namespace__(self) -> ModuleType:
|
||||
if self._implementation is Implementation.POLARS:
|
||||
return self._implementation.to_native_namespace()
|
||||
|
||||
msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover
|
||||
raise AssertionError(msg)
|
||||
|
||||
def _with_native(self, df: NativePolarsFrame) -> Self:
|
||||
return self.__class__(df, version=self._version)
|
||||
|
||||
def _with_version(self, version: Version) -> Self:
|
||||
return self.__class__(self.native, version=version)
|
||||
|
||||
@classmethod
|
||||
def from_native(cls, data: NativePolarsFrame, /, *, context: _LimitedContext) -> Self:
|
||||
return cls(data, version=context._version)
|
||||
|
||||
def simple_select(self, *column_names: str) -> Self:
|
||||
return self._with_native(self.native.select(*column_names))
|
||||
|
||||
def aggregate(self, *exprs: Any) -> Self:
|
||||
return self.select(*exprs)
|
||||
|
||||
@property
|
||||
def schema(self) -> dict[str, DType]:
|
||||
return self.collect_schema()
|
||||
|
||||
def join(
    self,
    other: Self,
    *,
    how: JoinStrategy,
    left_on: Sequence[str] | None,
    right_on: Sequence[str] | None,
    suffix: str,
) -> Self:
    """Join this frame with `other` through the native Polars `join`.

    Arguments are forwarded unchanged, except that Polars releases older than
    0.20.29 are given `how="outer"` whenever the caller asked for `"full"`.
    """
    if self._backend_version < (0, 20, 29) and how == "full":
        native_how = "outer"
    else:
        native_how = how
    joined = self.native.join(
        other=other.native,
        how=native_how,  # type: ignore[arg-type]
        left_on=left_on,
        right_on=right_on,
        suffix=suffix,
    )
    return self._with_native(joined)
|
||||
|
||||
def top_k(
    self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool]
) -> Self:
    """Return the native `top_k` result for `k` rows ordered by `by`.

    The ordering flag is passed as `descending` to Polars < 1.0.0 and as
    `reverse` to later releases; everything else is forwarded unchanged.
    """
    flag_name = "descending" if self._backend_version < (1, 0, 0) else "reverse"
    selected = self.native.top_k(k=k, by=by, **{flag_name: reverse})
    return self._with_native(selected)
|
||||
|
||||
def unpivot(
    self,
    on: Sequence[str] | None,
    index: Sequence[str] | None,
    variable_name: str,
    value_name: str,
) -> Self:
    """Unpivot the frame using the native Polars method.

    Polars < 1.0.0 is called through `melt` (`id_vars=index`, `value_vars=on`);
    later releases are called through `unpivot` (`on=on`, `index=index`).
    """
    output_names = {"variable_name": variable_name, "value_name": value_name}
    if self._backend_version < (1, 0, 0):
        reshaped = self.native.melt(id_vars=index, value_vars=on, **output_names)
    else:
        reshaped = self.native.unpivot(on=on, index=index, **output_names)
    return self._with_native(reshaped)
|
||||
|
||||
def collect_schema(self) -> dict[str, DType]:
    """Return a mapping of column name to Narwhals dtype.

    The native schema is read from the `schema` attribute on Polars < 1 and
    from `collect_schema()` otherwise; each native dtype is then converted
    with `native_to_narwhals_dtype`.
    """
    frame = self.native
    if self._backend_version < (1,):
        native_schema = frame.schema
    else:
        native_schema = frame.collect_schema()

    converted = {}
    for column_name, native_dtype in native_schema.items():
        converted[column_name] = native_to_narwhals_dtype(native_dtype, self._version)
    return converted
|
||||
|
||||
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
    """Add a row-index column called `name` in front of the existing columns.

    With `order_by=None` the native `with_row_index` is used. Otherwise the
    index is an `int_range` from 0 to the row count, passed through
    `sort_by(order_by)`, and selected ahead of all other columns.
    """
    if order_by is None:
        return self._with_native(self.native.with_row_index(name))

    # `pl.len()` is used from Polars 0.20.5 onwards; older releases use `pl.count()`.
    row_count = pl.count() if self._backend_version < (0, 20, 5) else pl.len()
    index_expr = pl.int_range(start=0, end=row_count).sort_by(order_by).alias(name)
    return self._with_native(self.native.select(index_expr, pl.all()))
|
||||
|
||||
|
||||
class PolarsDataFrame(PolarsBaseFrame[pl.DataFrame]):
|
||||
clone: Method[Self]
|
||||
collect: Method[CompliantDataFrameAny]
|
||||
estimated_size: Method[int | float]
|
||||
gather_every: Method[Self]
|
||||
item: Method[Any]
|
||||
iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]]
|
||||
is_unique: Method[PolarsSeries]
|
||||
row: Method[tuple[Any, ...]]
|
||||
rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]]
|
||||
sample: Method[Self]
|
||||
to_arrow: Method[pa.Table]
|
||||
to_pandas: Method[pd.DataFrame]
|
||||
# NOTE: `write_csv` requires an `@overload` for `str | None`
|
||||
# Can't do that here 😟
|
||||
write_csv: Method[Any]
|
||||
write_parquet: Method[None]
|
||||
|
||||
@classmethod
|
||||
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
|
||||
if context._implementation._backend_version() >= (1, 3):
|
||||
native = pl.DataFrame(data)
|
||||
else: # pragma: no cover
|
||||
native = cast("pl.DataFrame", pl.from_arrow(_into_arrow_table(data, context)))
|
||||
return cls.from_native(native, context=context)
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
data: Mapping[str, Any],
|
||||
/,
|
||||
*,
|
||||
context: _LimitedContext,
|
||||
schema: IntoSchema | None,
|
||||
) -> Self:
|
||||
from narwhals.schema import Schema
|
||||
|
||||
pl_schema = Schema(schema).to_polars() if schema is not None else schema
|
||||
return cls.from_native(pl.from_dict(data, pl_schema), context=context)
|
||||
|
||||
@staticmethod
|
||||
def _is_native(obj: pl.DataFrame | Any) -> TypeIs[pl.DataFrame]:
|
||||
return isinstance(obj, pl.DataFrame)
|
||||
|
||||
@classmethod
|
||||
def from_numpy(
|
||||
cls,
|
||||
data: _2DArray,
|
||||
/,
|
||||
*,
|
||||
context: _LimitedContext, # NOTE: Maybe only `Implementation`?
|
||||
schema: IntoSchema | Sequence[str] | None,
|
||||
) -> Self:
|
||||
from narwhals.schema import Schema
|
||||
|
||||
pl_schema = (
|
||||
Schema(schema).to_polars()
|
||||
if isinstance(schema, (Mapping, Schema))
|
||||
else schema
|
||||
)
|
||||
return cls.from_native(pl.from_numpy(data, pl_schema), context=context)
|
||||
|
||||
def to_narwhals(self) -> DataFrame[pl.DataFrame]:
|
||||
return self._version.dataframe(self, level="full")
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return "PolarsDataFrame"
|
||||
|
||||
def __narwhals_dataframe__(self) -> Self:
|
||||
return self
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, obj: pl.Series) -> PolarsSeries: ...
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, obj: pl.DataFrame) -> Self: ...
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, obj: T) -> T: ...
|
||||
|
||||
def _from_native_object(
|
||||
self, obj: pl.Series | pl.DataFrame | T
|
||||
) -> Self | PolarsSeries | T:
|
||||
if isinstance(obj, pl.Series):
|
||||
return PolarsSeries.from_native(obj, context=self)
|
||||
if self._is_native(obj):
|
||||
return self._with_native(obj)
|
||||
# scalar
|
||||
return obj
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.native)
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
    """Forward whitelisted method names to the native Polars DataFrame.

    Only names listed in `INHERITED_METHODS` are forwarded; any other name
    raises `AttributeError`. The returned callable passes its arguments
    through `extract_args_kwargs`, invokes the native method of the same name
    and hands the result to `_from_native_object`.

    Raises:
        AttributeError: If `attr` is not in `INHERITED_METHODS`.
        ColumnNotFoundError: (from the returned callable) when Polars reports a
            missing column; the message lists this frame's columns as a hint.
    """
    if attr not in INHERITED_METHODS:  # pragma: no cover
        msg = f"{self.__class__.__name__} has no attribute '{attr}'."
        raise AttributeError(msg)

    def func(*args: Any, **kwargs: Any) -> Any:
        pos, kwds = extract_args_kwargs(args, kwargs)
        try:
            return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))
        except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
            msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
            raise ColumnNotFoundError(msg) from e
        except Exception as e:  # noqa: BLE001
            # Any other failure is converted by `catch_polars_exception`.
            raise catch_polars_exception(e) from None

    return func
|
||||
|
||||
def __array__(
    self, dtype: Any | None = None, *, copy: bool | None = None
) -> _2DArray:
    """Convert the frame to an array via the native frame's `__array__`.

    Args:
        dtype: Forwarded to the native `__array__` as its first argument.
        copy: Forwarded to the native `__array__` on Polars >= 0.20.28. Older
            releases are called without it, so only `None` is accepted there.

    Raises:
        NotImplementedError: If `copy` is not `None` on Polars < 0.20.28.
    """
    if self._backend_version < (0, 20, 28):
        if copy is not None:
            msg = "`copy` in `__array__` is only supported for 'polars>=0.20.28'"
            raise NotImplementedError(msg)
        return self.native.__array__(dtype)
    # Previously this branch dropped `copy`, so the caller's request was ignored.
    return self.native.__array__(dtype, copy=copy)
|
||||
|
||||
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
|
||||
return self.native.to_numpy()
|
||||
|
||||
@property
|
||||
def shape(self) -> tuple[int, int]:
|
||||
return self.native.shape
|
||||
|
||||
def __getitem__( # noqa: C901, PLR0912
|
||||
self,
|
||||
item: tuple[
|
||||
SingleIndexSelector | MultiIndexSelector[PolarsSeries],
|
||||
MultiColSelector[PolarsSeries],
|
||||
],
|
||||
) -> Any:
|
||||
rows, columns = item
|
||||
if self._backend_version > (0, 20, 30):
|
||||
rows_native = rows.native if is_compliant_series(rows) else rows
|
||||
columns_native = columns.native if is_compliant_series(columns) else columns
|
||||
selector = rows_native, columns_native
|
||||
selected = self.native.__getitem__(selector) # type: ignore[index]
|
||||
return self._from_native_object(selected)
|
||||
else: # pragma: no cover # noqa: RET505
|
||||
# TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
|
||||
# Polars version we support
|
||||
# This mostly mirrors the logic in `EagerDataFrame.__getitem__`.
|
||||
rows = list(rows) if isinstance(rows, tuple) else rows
|
||||
columns = list(columns) if isinstance(columns, tuple) else columns
|
||||
if is_numpy_array_1d(columns):
|
||||
columns = columns.tolist()
|
||||
|
||||
native = self.native
|
||||
if not is_slice_none(columns):
|
||||
if isinstance(columns, Sized) and len(columns) == 0:
|
||||
return self.select()
|
||||
if is_index_selector(columns):
|
||||
if is_slice_index(columns) or is_range(columns):
|
||||
native = native.select(
|
||||
self.columns[slice(columns.start, columns.stop, columns.step)]
|
||||
)
|
||||
# NOTE: `mypy` loses track of `PolarsSeries` when `is_compliant_series` is used here
|
||||
# `pyright` is fine
|
||||
elif isinstance(columns, PolarsSeries):
|
||||
native = native[:, columns.native.to_list()]
|
||||
else:
|
||||
native = native[:, columns]
|
||||
elif isinstance(columns, slice):
|
||||
native = native.select(
|
||||
self.columns[
|
||||
slice(*convert_str_slice_to_int_slice(columns, self.columns))
|
||||
]
|
||||
)
|
||||
elif is_compliant_series(columns):
|
||||
native = native.select(columns.native.to_list())
|
||||
elif is_sequence_like(columns):
|
||||
native = native.select(columns)
|
||||
else:
|
||||
msg = f"Unreachable code, got unexpected type: {type(columns)}"
|
||||
raise AssertionError(msg)
|
||||
|
||||
if not is_slice_none(rows):
|
||||
if isinstance(rows, int):
|
||||
native = native[[rows], :]
|
||||
elif isinstance(rows, (slice, range)):
|
||||
native = native[rows, :]
|
||||
elif is_compliant_series(rows):
|
||||
native = native[rows.native, :]
|
||||
elif is_sequence_like(rows):
|
||||
native = native[rows, :]
|
||||
else:
|
||||
msg = f"Unreachable code, got unexpected type: {type(rows)}"
|
||||
raise AssertionError(msg)
|
||||
|
||||
return self._with_native(native)
|
||||
|
||||
def get_column(self, name: str) -> PolarsSeries:
|
||||
return PolarsSeries.from_native(self.native.get_column(name), context=self)
|
||||
|
||||
def iter_columns(self) -> Iterator[PolarsSeries]:
|
||||
for series in self.native.iter_columns():
|
||||
yield PolarsSeries.from_native(series, context=self)
|
||||
|
||||
def lazy(
|
||||
self,
|
||||
backend: _LazyAllowedImpl | None = None,
|
||||
*,
|
||||
session: SparkSession | None = None,
|
||||
) -> CompliantLazyFrameAny:
|
||||
if backend is None or backend is Implementation.POLARS:
|
||||
return PolarsLazyFrame.from_native(self.native.lazy(), context=self)
|
||||
if backend is Implementation.DUCKDB:
|
||||
import duckdb # ignore-banned-import
|
||||
|
||||
from narwhals._duckdb.dataframe import DuckDBLazyFrame
|
||||
|
||||
_df = self.native
|
||||
return DuckDBLazyFrame(
|
||||
duckdb.table("_df"), validate_backend_version=True, version=self._version
|
||||
)
|
||||
if backend is Implementation.DASK:
|
||||
import dask.dataframe as dd # ignore-banned-import
|
||||
|
||||
from narwhals._dask.dataframe import DaskLazyFrame
|
||||
|
||||
return DaskLazyFrame(
|
||||
dd.from_pandas(self.native.to_pandas()),
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
)
|
||||
if backend is Implementation.IBIS:
|
||||
import ibis # ignore-banned-import
|
||||
|
||||
from narwhals._ibis.dataframe import IbisLazyFrame
|
||||
|
||||
return IbisLazyFrame(
|
||||
ibis.memtable(self.native, columns=self.columns),
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
)
|
||||
|
||||
if backend.is_spark_like():
|
||||
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
|
||||
|
||||
if session is None:
|
||||
msg = "Spark like backends require `session` to be not None."
|
||||
raise ValueError(msg)
|
||||
|
||||
return SparkLikeLazyFrame._from_compliant_dataframe(
|
||||
self, # pyright: ignore[reportArgumentType]
|
||||
session=session,
|
||||
implementation=backend,
|
||||
version=self._version,
|
||||
)
|
||||
|
||||
raise AssertionError # pragma: no cover
|
||||
|
||||
@overload
def to_dict(self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ...

@overload
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...

def to_dict(
    self, *, as_series: bool
) -> dict[str, PolarsSeries] | dict[str, list[Any]]:
    """Return the frame as a dictionary keyed by column name.

    With `as_series=False` the native `to_dict(as_series=False)` result is
    returned as-is. Otherwise each value of the native `to_dict()` result is
    wrapped with `PolarsSeries.from_native`.
    """
    if not as_series:
        return self.native.to_dict(as_series=False)

    wrapped = {}
    for column_name, native_series in self.native.to_dict().items():
        wrapped[column_name] = PolarsSeries.from_native(native_series, context=self)
    return wrapped
|
||||
|
||||
def group_by(
|
||||
self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
|
||||
) -> PolarsGroupBy:
|
||||
from narwhals._polars.group_by import PolarsGroupBy
|
||||
|
||||
return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys)
|
||||
|
||||
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
|
||||
to_drop = parse_columns_to_drop(self, columns, strict=strict)
|
||||
return self._with_native(self.native.drop(to_drop))
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def pivot(
|
||||
self,
|
||||
on: Sequence[str],
|
||||
*,
|
||||
index: Sequence[str] | None,
|
||||
values: Sequence[str] | None,
|
||||
aggregate_function: PivotAgg | None,
|
||||
sort_columns: bool,
|
||||
separator: str,
|
||||
) -> Self:
|
||||
try:
|
||||
result = self.native.pivot(
|
||||
on,
|
||||
index=index,
|
||||
values=values,
|
||||
aggregate_function=aggregate_function,
|
||||
sort_columns=sort_columns,
|
||||
separator=separator,
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
return self._from_native_object(result)
|
||||
|
||||
def to_polars(self) -> pl.DataFrame:
|
||||
return self.native
|
||||
|
||||
def join(
|
||||
self,
|
||||
other: Self,
|
||||
*,
|
||||
how: JoinStrategy,
|
||||
left_on: Sequence[str] | None,
|
||||
right_on: Sequence[str] | None,
|
||||
suffix: str,
|
||||
) -> Self:
|
||||
try:
|
||||
return super().join(
|
||||
other=other, how=how, left_on=left_on, right_on=right_on, suffix=suffix
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
def top_k(
|
||||
self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool]
|
||||
) -> Self:
|
||||
try:
|
||||
return super().top_k(k=k, by=by, reverse=reverse)
|
||||
except Exception as e: # noqa: BLE001 # pragma: no cover
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
|
||||
class PolarsLazyFrame(PolarsBaseFrame[pl.LazyFrame]):
|
||||
sink_parquet: Method[None]
|
||||
|
||||
@staticmethod
|
||||
def _is_native(obj: pl.LazyFrame | Any) -> TypeIs[pl.LazyFrame]:
|
||||
return isinstance(obj, pl.LazyFrame)
|
||||
|
||||
def to_narwhals(self) -> LazyFrame[pl.LazyFrame]:
|
||||
return self._version.lazyframe(self, level="lazy")
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return "PolarsLazyFrame"
|
||||
|
||||
def __narwhals_lazyframe__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
    """Forward whitelisted method names to the native Polars LazyFrame.

    Only names listed in `INHERITED_METHODS` are forwarded; any other name
    raises `AttributeError`. The returned callable passes its arguments
    through `extract_args_kwargs`, invokes the native method of the same name
    and re-wraps the result with `_with_native`.

    Raises:
        AttributeError: If `attr` is not in `INHERITED_METHODS`.
        ColumnNotFoundError: (from the returned callable) when Polars reports a
            missing column.
    """
    if attr not in INHERITED_METHODS:  # pragma: no cover
        msg = f"{self.__class__.__name__} has no attribute '{attr}'."
        raise AttributeError(msg)

    def func(*args: Any, **kwargs: Any) -> Any:
        pos, kwds = extract_args_kwargs(args, kwargs)
        try:
            return self._with_native(getattr(self.native, attr)(*pos, **kwds))
        except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
            raise ColumnNotFoundError(str(e)) from e

    return func
|
||||
|
||||
def _iter_columns(self) -> Iterator[PolarsSeries]: # pragma: no cover
|
||||
yield from self.collect(Implementation.POLARS).iter_columns()
|
||||
|
||||
def collect_schema(self) -> dict[str, DType]:
|
||||
try:
|
||||
return super().collect_schema()
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
def collect(
    self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny:
    """Collect the native lazy frame and wrap the result for `backend`.

    `kwargs` are forwarded to the native `collect`. The collected frame is
    wrapped as a `PolarsDataFrame` when `backend` is `None` or Polars,
    converted with `to_pandas()` for pandas, or with `to_arrow()` for PyArrow.

    Raises:
        ValueError: If `backend` is none of the values handled above.
    """
    try:
        materialized = self.native.collect(**kwargs)
    except Exception as e:  # noqa: BLE001
        raise catch_polars_exception(e) from None

    if backend is None or backend is Implementation.POLARS:
        return PolarsDataFrame.from_native(materialized, context=self)

    # Options shared by both cross-backend constructors below.
    handoff = {
        "validate_backend_version": True,
        "version": self._version,
        "validate_column_names": False,
    }

    if backend is Implementation.PANDAS:
        from narwhals._pandas_like.dataframe import PandasLikeDataFrame

        return PandasLikeDataFrame(
            materialized.to_pandas(), implementation=Implementation.PANDAS, **handoff
        )

    if backend is Implementation.PYARROW:
        from narwhals._arrow.dataframe import ArrowDataFrame

        return ArrowDataFrame(materialized.to_arrow(), **handoff)

    msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
    raise ValueError(msg)  # pragma: no cover
|
||||
|
||||
def group_by(
|
||||
self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
|
||||
) -> PolarsLazyGroupBy:
|
||||
from narwhals._polars.group_by import PolarsLazyGroupBy
|
||||
|
||||
return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)
|
||||
|
||||
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
    """Drop `columns` from the lazy frame via the native `drop`.

    `strict` is forwarded only on Polars >= 1.0.0; older releases are called
    with the column list alone.
    """
    extra = {} if self._backend_version < (1, 0, 0) else {"strict": strict}
    return self._with_native(self.native.drop(columns, **extra))
|
479
lib/python3.11/site-packages/narwhals/_polars/expr.py
Normal file
479
lib/python3.11/site-packages/narwhals/_polars/expr.py
Normal file
@ -0,0 +1,479 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal
|
||||
|
||||
import polars as pl
|
||||
|
||||
from narwhals._polars.utils import (
|
||||
PolarsAnyNamespace,
|
||||
PolarsCatNamespace,
|
||||
PolarsDateTimeNamespace,
|
||||
PolarsListNamespace,
|
||||
PolarsStringNamespace,
|
||||
PolarsStructNamespace,
|
||||
extract_args_kwargs,
|
||||
extract_native,
|
||||
narwhals_to_native_dtype,
|
||||
)
|
||||
from narwhals._utils import Implementation, requires
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from narwhals._expression_parsing import ExprKind, ExprMetadata
|
||||
from narwhals._polars.dataframe import Method
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
from narwhals._utils import Version
|
||||
from narwhals.typing import IntoDType, ModeKeepStrategy, NumericLiteral
|
||||
|
||||
|
||||
class PolarsExpr:
|
||||
# CompliantExpr
|
||||
_implementation: Implementation = Implementation.POLARS
|
||||
_version: Version
|
||||
_native_expr: pl.Expr
|
||||
_metadata: ExprMetadata | None = None
|
||||
_evaluate_output_names: Any
|
||||
_alias_output_names: Any
|
||||
__call__: Any
|
||||
|
||||
# CompliantExpr + builtin descriptor
|
||||
# TODO @dangotbanned: Remove in #2713
|
||||
@classmethod
|
||||
def from_column_names(cls, *_: Any, **__: Any) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_column_indices(cls, *_: Any, **__: Any) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _eval_names_indices(*_: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def __narwhals_expr__(self) -> Self: # pragma: no cover
|
||||
return self
|
||||
|
||||
def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
|
||||
return PolarsNamespace(version=self._version)
|
||||
|
||||
def __init__(self, expr: pl.Expr, version: Version) -> None:
|
||||
self._native_expr = expr
|
||||
self._version = version
|
||||
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def native(self) -> pl.Expr:
|
||||
return self._native_expr
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return "PolarsExpr"
|
||||
|
||||
def _with_native(self, expr: pl.Expr) -> Self:
|
||||
return self.__class__(expr, self._version)
|
||||
|
||||
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
|
||||
# Let Polars do its thing.
|
||||
return self
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self._with_native(getattr(self.native, attr)(*pos, **kwds))
|
||||
|
||||
return func
|
||||
|
||||
def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]:
    """Return `min_samples` as a one-item kwargs dict under the right keyword.

    The keyword is `min_periods` for Polars < 1.21.0 and `min_samples` otherwise.
    """
    if self._backend_version < (1, 21, 0):
        return {"min_periods": min_samples}
    return {"min_samples": min_samples}
|
||||
|
||||
def cast(self, dtype: IntoDType) -> Self:
|
||||
dtype_pl = narwhals_to_native_dtype(dtype, self._version)
|
||||
return self._with_native(self.native.cast(dtype_pl))
|
||||
|
||||
def ewm_mean(
    self,
    *,
    com: float | None,
    span: float | None,
    half_life: float | None,
    alpha: float | None,
    adjust: bool,
    min_samples: int,
    ignore_nulls: bool,
) -> Self:
    """Build the native `ewm_mean` expression.

    `min_samples` is passed under the keyword chosen by `_renamed_min_periods`.
    On Polars < 1 the result is additionally replaced with null wherever the
    input expression is null.
    """
    options = {
        "com": com,
        "span": span,
        "half_life": half_life,
        "alpha": alpha,
        "adjust": adjust,
        "ignore_nulls": ignore_nulls,
    }
    options.update(self._renamed_min_periods(min_samples))
    smoothed = self.native.ewm_mean(**options)
    if self._backend_version < (1,):  # pragma: no cover
        smoothed = pl.when(~self.native.is_null()).then(smoothed).otherwise(None)
    return self._with_native(smoothed)
|
||||
|
||||
def is_nan(self) -> Self:
    """Return the NaN mask of the expression.

    From Polars 1.18 the native `is_nan()` is used directly; on older releases
    it is wrapped in `pl.when(<expr>.is_not_null()).then(...)`.
    """
    if self._backend_version < (1, 18):  # pragma: no cover
        expr = self.native
        return self._with_native(pl.when(expr.is_not_null()).then(expr.is_nan()))
    return self._with_native(self.native.is_nan())
|
||||
|
||||
def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
    """Apply the expression as a window over `partition_by`.

    An empty `partition_by` is replaced with `pl.lit(1)`, and an empty
    `order_by` is passed to Polars as `None`.

    Raises:
        NotImplementedError: If `order_by` is non-empty on a Polars release
            below the version guard.
    """
    # NOTE(review): the guard is `< (1, 9)` while the message says 1.10 —
    # confirm which release actually introduced `order_by`.
    supports_order_by = self._backend_version >= (1, 9)
    if order_by and not supports_order_by:
        msg = "`order_by` in Polars requires version 1.10 or greater"
        raise NotImplementedError(msg)

    groups = partition_by or pl.lit(1)
    if supports_order_by:
        windowed = self.native.over(groups, order_by=order_by or None)
    else:
        windowed = self.native.over(groups)
    return self._with_native(windowed)
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def rolling_var(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self:
|
||||
kwds = self._renamed_min_periods(min_samples)
|
||||
native = self.native.rolling_var(
|
||||
window_size=window_size, center=center, ddof=ddof, **kwds
|
||||
)
|
||||
return self._with_native(native)
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def rolling_std(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self:
|
||||
kwds = self._renamed_min_periods(min_samples)
|
||||
native = self.native.rolling_std(
|
||||
window_size=window_size, center=center, ddof=ddof, **kwds
|
||||
)
|
||||
return self._with_native(native)
|
||||
|
||||
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
    """Build the native `rolling_sum` expression.

    `min_samples` is passed under the keyword chosen by `_renamed_min_periods`.
    """
    options = {"window_size": window_size, "center": center}
    options.update(self._renamed_min_periods(min_samples))
    return self._with_native(self.native.rolling_sum(**options))
|
||||
|
||||
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
    """Build the native `rolling_mean` expression.

    `min_samples` is passed under the keyword chosen by `_renamed_min_periods`.
    """
    options = {"window_size": window_size, "center": center}
    options.update(self._renamed_min_periods(min_samples))
    return self._with_native(self.native.rolling_mean(**options))
|
||||
|
||||
def map_batches(
    self,
    function: Callable[[Any], Any],
    return_dtype: IntoDType | None,
    *,
    returns_scalar: bool,
) -> Self:
    """Build the native `map_batches` expression for `function`.

    The return dtype handed to Polars is:

    * the converted `return_dtype` when one is given;
    * `None` when none is given and Polars is older than 1.32;
    * `pl.self_dtype()` when none is given on Polars >= 1.32.

    `returns_scalar` is forwarded only from Polars 0.20.31 onwards.
    """
    pl_version = self._backend_version
    if return_dtype is not None:
        native_dtype = narwhals_to_native_dtype(return_dtype, self._version)
    elif pl_version < (1, 32):
        native_dtype = None
    else:
        native_dtype = pl.self_dtype()

    extra = {} if pl_version < (0, 20, 31) else {"returns_scalar": returns_scalar}
    mapped = self.native.map_batches(function, native_dtype, **extra)
    return self._with_native(mapped)
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def replace_strict(
|
||||
self,
|
||||
old: Sequence[Any] | Mapping[Any, Any],
|
||||
new: Sequence[Any],
|
||||
*,
|
||||
return_dtype: IntoDType | None,
|
||||
) -> Self:
|
||||
return_dtype_pl = (
|
||||
narwhals_to_native_dtype(return_dtype, self._version)
|
||||
if return_dtype
|
||||
else None
|
||||
)
|
||||
native = self.native.replace_strict(old, new, return_dtype=return_dtype_pl)
|
||||
return self._with_native(native)
|
||||
|
||||
def __eq__(self, other: object) -> Self: # type: ignore[override]
|
||||
return self._with_native(self.native.__eq__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __ne__(self, other: object) -> Self: # type: ignore[override]
|
||||
return self._with_native(self.native.__ne__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __ge__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__ge__(extract_native(other)))
|
||||
|
||||
def __gt__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__gt__(extract_native(other)))
|
||||
|
||||
def __le__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__le__(extract_native(other)))
|
||||
|
||||
def __lt__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__lt__(extract_native(other)))
|
||||
|
||||
def __and__(self, other: PolarsExpr | bool | Any) -> Self:
|
||||
return self._with_native(self.native.__and__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __or__(self, other: PolarsExpr | bool | Any) -> Self:
|
||||
return self._with_native(self.native.__or__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __add__(self, other: Any) -> Self:
    """Wrap the native ``__add__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__add__(rhs))

def __sub__(self, other: Any) -> Self:
    """Wrap the native ``__sub__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__sub__(rhs))

def __mul__(self, other: Any) -> Self:
    """Wrap the native ``__mul__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__mul__(rhs))

def __pow__(self, other: Any) -> Self:
    """Wrap the native ``__pow__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__pow__(rhs))

def __truediv__(self, other: Any) -> Self:
    """Wrap the native ``__truediv__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__truediv__(rhs))

def __floordiv__(self, other: Any) -> Self:
    """Wrap the native ``__floordiv__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__floordiv__(rhs))

def __mod__(self, other: Any) -> Self:
    """Wrap the native ``__mod__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__mod__(rhs))
|
||||
|
||||
def __invert__(self) -> Self:
    """Wrap the result of the native expression's ``__invert__``."""
    inverted = self.native.__invert__()
    return self._with_native(inverted)
|
||||
|
||||
def cum_count(self, *, reverse: bool) -> Self:
    """Wrap the native ``cum_count`` expression, forwarding ``reverse``."""
    counted = self.native.cum_count(reverse=reverse)
    return self._with_native(counted)
|
||||
|
||||
def is_close(
    self,
    other: Self | NumericLiteral,
    *,
    abs_tol: float,
    rel_tol: float,
    nans_equal: bool,
) -> Self:
    """Build an element-wise closeness check between this expression and ``other``.

    Backends at version 1.32.0 or newer delegate to the native ``is_close``;
    older backends assemble the check by hand from basic expressions.
    """
    lhs = self.native
    rhs = other.native if isinstance(other, PolarsExpr) else pl.lit(other)

    if self._backend_version >= (1, 32, 0):
        native_close = lhs.is_close(
            rhs, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
        )
        return self._with_native(native_close)

    # Tolerance: |lhs| clipped from below by |rhs|, scaled by `rel_tol`,
    # then clipped from below by `abs_tol`.
    magnitude_floor = rhs.abs()
    allowed = (lhs.abs().clip(magnitude_floor) * rel_tol).clip(abs_tol)

    # Values are close if the absolute difference is within tolerance and both are finite.
    distance = (lhs - rhs).abs()
    every = pl.all_horizontal
    finite_close = every((distance <= allowed), lhs.is_finite(), rhs.is_finite())

    # Infinities are "close" only when both sides are infinite with the same sign.
    matching_inf = every(
        lhs.is_infinite(), rhs.is_infinite(), (lhs.sign() == rhs.sign())
    )

    # NaN handling:
    # * nans_equal = True  => True when both values are NaN
    # * nans_equal = False => False when any value is NaN
    lhs_nan, rhs_nan = lhs.is_nan(), rhs.is_nan()
    any_nan = lhs_nan | rhs_nan
    outcome = (finite_close | matching_inf) & any_nan.not_()

    if nans_equal:
        outcome = outcome | (lhs_nan & rhs_nan)
    return self._with_native(outcome)
|
||||
|
||||
def mode(self, *, keep: ModeKeepStrategy) -> Self:
    """Wrap the native ``mode``; with ``keep == "any"`` only ``first()`` of it is kept."""
    modes = self.native.mode()
    if keep == "any":
        modes = modes.first()
    return self._with_native(modes)
|
||||
|
||||
@property
def dt(self) -> PolarsExprDateTimeNamespace:
    """Datetime accessor bound to this expression."""
    accessor = PolarsExprDateTimeNamespace(self)
    return accessor

@property
def str(self) -> PolarsExprStringNamespace:
    """String accessor bound to this expression."""
    accessor = PolarsExprStringNamespace(self)
    return accessor

@property
def cat(self) -> PolarsExprCatNamespace:
    """Categorical accessor bound to this expression."""
    accessor = PolarsExprCatNamespace(self)
    return accessor

@property
def name(self) -> PolarsExprNameNamespace:
    """Name accessor bound to this expression."""
    accessor = PolarsExprNameNamespace(self)
    return accessor

@property
def list(self) -> PolarsExprListNamespace:
    """List accessor bound to this expression."""
    accessor = PolarsExprListNamespace(self)
    return accessor

@property
def struct(self) -> PolarsExprStructNamespace:
    """Struct accessor bound to this expression."""
    accessor = PolarsExprStructNamespace(self)
    return accessor
|
||||
|
||||
# Polars
|
||||
abs: Method[Self]
|
||||
all: Method[Self]
|
||||
any: Method[Self]
|
||||
alias: Method[Self]
|
||||
arg_max: Method[Self]
|
||||
arg_min: Method[Self]
|
||||
arg_true: Method[Self]
|
||||
clip: Method[Self]
|
||||
count: Method[Self]
|
||||
cum_max: Method[Self]
|
||||
cum_min: Method[Self]
|
||||
cum_prod: Method[Self]
|
||||
cum_sum: Method[Self]
|
||||
diff: Method[Self]
|
||||
drop_nulls: Method[Self]
|
||||
exp: Method[Self]
|
||||
fill_null: Method[Self]
|
||||
fill_nan: Method[Self]
|
||||
gather_every: Method[Self]
|
||||
head: Method[Self]
|
||||
is_between: Method[Self]
|
||||
is_duplicated: Method[Self]
|
||||
is_finite: Method[Self]
|
||||
is_first_distinct: Method[Self]
|
||||
is_in: Method[Self]
|
||||
is_last_distinct: Method[Self]
|
||||
is_null: Method[Self]
|
||||
is_unique: Method[Self]
|
||||
kurtosis: Method[Self]
|
||||
len: Method[Self]
|
||||
log: Method[Self]
|
||||
max: Method[Self]
|
||||
mean: Method[Self]
|
||||
median: Method[Self]
|
||||
min: Method[Self]
|
||||
n_unique: Method[Self]
|
||||
null_count: Method[Self]
|
||||
quantile: Method[Self]
|
||||
rank: Method[Self]
|
||||
round: Method[Self]
|
||||
sample: Method[Self]
|
||||
shift: Method[Self]
|
||||
skew: Method[Self]
|
||||
sqrt: Method[Self]
|
||||
std: Method[Self]
|
||||
sum: Method[Self]
|
||||
sort: Method[Self]
|
||||
tail: Method[Self]
|
||||
unique: Method[Self]
|
||||
var: Method[Self]
|
||||
__rfloordiv__: Method[Self]
|
||||
__rsub__: Method[Self]
|
||||
__rmod__: Method[Self]
|
||||
__rpow__: Method[Self]
|
||||
__rtruediv__: Method[Self]
|
||||
|
||||
|
||||
class PolarsExprNamespace(PolarsAnyNamespace[PolarsExpr, pl.Expr]):
    """Base accessor namespace that holds on to the ``PolarsExpr`` it was built from."""

    def __init__(self, expr: PolarsExpr) -> None:
        """Store the owning compliant expression."""
        self._expr = expr

    @property
    def compliant(self) -> PolarsExpr:
        """The ``PolarsExpr`` this namespace was created with."""
        owner = self._expr
        return owner

    @property
    def native(self) -> pl.Expr:
        """The native expression held by the owning ``PolarsExpr``."""
        owner = self._expr
        return owner.native
|
||||
|
||||
|
||||
class PolarsExprDateTimeNamespace(
    PolarsExprNamespace, PolarsDateTimeNamespace[PolarsExpr, pl.Expr]
):
    """Datetime accessor for ``PolarsExpr``; adds nothing beyond its base classes."""
|
||||
|
||||
|
||||
class PolarsExprStringNamespace(
    PolarsExprNamespace, PolarsStringNamespace[PolarsExpr, pl.Expr]
):
    """String accessor for ``PolarsExpr`` with a version-aware ``zfill``."""

    def zfill(self, width: int) -> PolarsExpr:
        """Zero-pad strings to ``width`` via the native ``str.zfill``.

        For backends up to and including 1.30.0, values that start with ``"+"``
        and are shorter than ``width`` are rebuilt by hand: the first character
        is sliced off, the rest is zero-filled to ``width - 1`` and the plus
        sign is padded back onto the start.

        Raises:
            NotImplementedError: If the backend version is below 0.20.5.
        """
        compliant = self.compliant
        version = compliant._backend_version
        expr = self.native
        padded = expr.str.zfill(width)

        if version < (0, 20, 5):  # pragma: no cover
            # Reason:
            # `TypeError: argument 'length': 'Expr' object cannot be interpreted as an integer`
            # in `native_expr.str.slice(1, length)`
            msg = "`zfill` is only available in 'polars>=0.20.5', found version '0.20.4'."
            raise NotImplementedError(msg)

        if version > (1, 30, 0):
            return compliant._with_native(padded)

        sign = "+"
        n_chars = expr.str.len_chars()
        too_short = n_chars < width
        signed = expr.str.starts_with(sign)
        resigned = (
            expr.str.slice(1, n_chars).str.zfill(width - 1).str.pad_start(width, sign)
        )
        padded = pl.when(signed & too_short).then(resigned).otherwise(padded)
        return compliant._with_native(padded)
|
||||
|
||||
|
||||
class PolarsExprCatNamespace(
    PolarsExprNamespace, PolarsCatNamespace[PolarsExpr, pl.Expr]
):
    """Categorical accessor for ``PolarsExpr``; adds nothing beyond its base classes."""
|
||||
|
||||
|
||||
class PolarsExprNameNamespace(PolarsExprNamespace):
    """``name`` accessor for ``PolarsExpr``."""

    # NOTE(review): presumably read by a base class to pick the native
    # accessor; the consumer is not visible in this class.
    _accessor = "name"

    # Annotations only: no implementations for these names are defined here.
    keep: Method[PolarsExpr]
    map: Method[PolarsExpr]
    prefix: Method[PolarsExpr]
    suffix: Method[PolarsExpr]
    to_lowercase: Method[PolarsExpr]
    to_uppercase: Method[PolarsExpr]
|
||||
|
||||
|
||||
class PolarsExprListNamespace(
    PolarsExprNamespace, PolarsListNamespace[PolarsExpr, pl.Expr]
):
    """List accessor for ``PolarsExpr`` with version-specific adjustments."""

    def len(self) -> PolarsExpr:
        """Return the native ``list.len`` expression, adjusted for older backends."""
        expr = self.native
        version = self.compliant._backend_version
        lengths = expr.list.len()

        if version < (1, 16):  # pragma: no cover
            # Only produce a length where the list itself is not null, then cast.
            lengths = pl.when(~expr.is_null()).then(lengths).cast(pl.UInt32())
        elif version < (1, 17):  # pragma: no cover
            lengths = lengths.cast(pl.UInt32())

        return self.compliant._with_native(lengths)

    def contains(self, item: Any) -> PolarsExpr:
        """Return the native ``list.contains`` expression for ``item``.

        Before backend version 1.28 the check is wrapped in a ``when`` on
        ``is_not_null`` of the list expression.
        """
        expr = self.native
        membership: pl.Expr = expr.list.contains(item)
        if self.compliant._backend_version < (1, 28):
            membership = pl.when(expr.is_not_null()).then(membership)
        return self.compliant._with_native(membership)
|
||||
|
||||
|
||||
class PolarsExprStructNamespace(
    PolarsExprNamespace, PolarsStructNamespace[PolarsExpr, pl.Expr]
):
    """Struct accessor for ``PolarsExpr``; adds nothing beyond its base classes."""
|
76
lib/python3.11/site-packages/narwhals/_polars/group_by.py
Normal file
76
lib/python3.11/site-packages/narwhals/_polars/group_by.py
Normal file
@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from narwhals._utils import is_sequence_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
|
||||
from polars.dataframe.group_by import GroupBy as NativeGroupBy
|
||||
from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy
|
||||
|
||||
from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
|
||||
|
||||
class PolarsGroupBy:
    """Group-by wrapper around a ``PolarsDataFrame`` and its native grouped object."""

    _compliant_frame: PolarsDataFrame
    _grouped: NativeGroupBy

    @property
    def compliant(self) -> PolarsDataFrame:
        """The frame being grouped (after any null-key dropping)."""
        return self._compliant_frame

    def __init__(
        self,
        df: PolarsDataFrame,
        keys: Sequence[PolarsExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Group ``df`` by ``keys``, first dropping null keys when requested."""
        self._keys = list(keys)
        if drop_null_keys:
            self._compliant_frame = df.drop_nulls(keys)
        else:
            self._compliant_frame = df

        native_frame = self.compliant.native
        if is_sequence_of(keys, str):
            # Column names go to the native `group_by` unchanged.
            self._grouped = native_frame.group_by(keys)
        else:
            # Expression keys are unwrapped to their native form.
            self._grouped = native_frame.group_by(arg.native for arg in keys)

    def agg(self, *aggs: PolarsExpr) -> PolarsDataFrame:
        """Aggregate with the native expressions of ``aggs`` and wrap the result."""
        aggregated = self._grouped.agg(arg.native for arg in aggs)
        return self.compliant._with_native(aggregated)

    def __iter__(self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]:
        """Yield ``(key tuple, wrapped group frame)`` pairs from the native group-by."""
        for group_key, group_frame in self._grouped:
            wrapped = self.compliant._with_native(group_frame)
            yield tuple(cast("str", group_key)), wrapped
|
||||
|
||||
|
||||
class PolarsLazyGroupBy:
    """Group-by wrapper around a ``PolarsLazyFrame`` and its native grouped object."""

    _compliant_frame: PolarsLazyFrame
    _grouped: NativeLazyGroupBy

    @property
    def compliant(self) -> PolarsLazyFrame:
        """The frame being grouped (after any null-key dropping)."""
        return self._compliant_frame

    def __init__(
        self,
        df: PolarsLazyFrame,
        keys: Sequence[PolarsExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Group ``df`` by ``keys``, first dropping null keys when requested."""
        self._keys = list(keys)
        if drop_null_keys:
            self._compliant_frame = df.drop_nulls(keys)
        else:
            self._compliant_frame = df

        native_frame = self.compliant.native
        if is_sequence_of(keys, str):
            # Column names go to the native `group_by` unchanged.
            self._grouped = native_frame.group_by(keys)
        else:
            # Expression keys are unwrapped to their native form.
            self._grouped = native_frame.group_by(arg.native for arg in keys)

    def agg(self, *aggs: PolarsExpr) -> PolarsLazyFrame:
        """Aggregate with the native expressions of ``aggs`` and wrap the result."""
        aggregated = self._grouped.agg(arg.native for arg in aggs)
        return self.compliant._with_native(aggregated)
|
281
lib/python3.11/site-packages/narwhals/_polars/namespace.py
Normal file
281
lib/python3.11/site-packages/narwhals/_polars/namespace.py
Normal file
@ -0,0 +1,281 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, overload
|
||||
|
||||
import polars as pl
|
||||
|
||||
from narwhals._expression_parsing import is_expr, is_series
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
from narwhals._polars.series import PolarsSeries
|
||||
from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype
|
||||
from narwhals._utils import Implementation, requires, zip_strict
|
||||
from narwhals.dependencies import is_numpy_array_2d
|
||||
from narwhals.dtypes import DType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Sequence
|
||||
from datetime import timezone
|
||||
|
||||
from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen
|
||||
from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame
|
||||
from narwhals._polars.typing import FrameT
|
||||
from narwhals._utils import Version, _LimitedContext
|
||||
from narwhals.expr import Expr
|
||||
from narwhals.series import Series
|
||||
from narwhals.typing import (
|
||||
Into1DArray,
|
||||
IntoDType,
|
||||
IntoSchema,
|
||||
NonNestedLiteral,
|
||||
TimeUnit,
|
||||
_1DArray,
|
||||
_2DArray,
|
||||
)
|
||||
|
||||
|
||||
class PolarsNamespace:
    """Namespace whose functions return ``PolarsExpr`` / Polars-backed frames and series."""

    # Annotations only: no values are assigned to these names in this class,
    # so attribute access for them falls through to `__getattr__`.
    all: Method[PolarsExpr]
    coalesce: Method[PolarsExpr]
    col: Method[PolarsExpr]
    exclude: Method[PolarsExpr]
    sum_horizontal: Method[PolarsExpr]
    min_horizontal: Method[PolarsExpr]
    max_horizontal: Method[PolarsExpr]

    when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]]

    _implementation: Implementation = Implementation.POLARS
    _version: Version

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Backend version as reported by the implementation enum."""
        return self._implementation._backend_version()

    def __init__(self, *, version: Version) -> None:
        """Remember the narwhals ``version`` used for everything this namespace creates."""
        self._version = version

    def __getattr__(self, attr: str) -> Any:
        """Return a callable that forwards to ``pl.<attr>`` and wraps the result."""

        def func(*args: Any, **kwargs: Any) -> Any:
            positional, keyword = extract_args_kwargs(args, kwargs)
            native_fn = getattr(pl, attr)
            return self._expr(native_fn(*positional, **keyword), version=self._version)

        return func

    @property
    def _dataframe(self) -> type[PolarsDataFrame]:
        """The eager frame class (imported inside the property body)."""
        from narwhals._polars.dataframe import PolarsDataFrame

        return PolarsDataFrame

    @property
    def _lazyframe(self) -> type[PolarsLazyFrame]:
        """The lazy frame class (imported inside the property body)."""
        from narwhals._polars.dataframe import PolarsLazyFrame

        return PolarsLazyFrame

    @property
    def _expr(self) -> type[PolarsExpr]:
        """The expression class."""
        return PolarsExpr

    @property
    def _series(self) -> type[PolarsSeries]:
        """The series class."""
        return PolarsSeries

    def parse_into_expr(
        self,
        data: Expr | NonNestedLiteral | Series[pl.Series] | _1DArray,
        /,
        *,
        str_as_lit: bool,
    ) -> PolarsExpr | None:
        """Convert ``data`` into a ``PolarsExpr``; ``None`` is passed through.

        Strings become column references unless ``str_as_lit`` is true; any
        other non-expression input becomes a literal.
        """
        if data is None:
            # NOTE: To avoid `pl.lit(None)` failing this `None` check
            # https://github.com/pola-rs/polars/blob/58dd8e5770f16a9bef9009a1c05f00e15a5263c7/py-polars/polars/expr/expr.py#L2870-L2872
            return None
        if is_expr(data):
            compliant_expr = data._to_compliant_expr(self)
            assert isinstance(compliant_expr, self._expr)  # noqa: S101
            return compliant_expr
        if isinstance(data, str) and not str_as_lit:
            return self.col(data)
        literal_value = data.to_native() if is_series(data) else data
        return self.lit(literal_value, None)

    @overload
    def from_native(self, data: pl.DataFrame, /) -> PolarsDataFrame: ...
    @overload
    def from_native(self, data: pl.LazyFrame, /) -> PolarsLazyFrame: ...
    @overload
    def from_native(self, data: pl.Series, /) -> PolarsSeries: ...
    def from_native(
        self, data: pl.DataFrame | pl.LazyFrame | pl.Series | Any, /
    ) -> PolarsDataFrame | PolarsLazyFrame | PolarsSeries:
        """Wrap a native frame, series or lazy frame in its compliant class.

        Raises:
            TypeError: If ``data`` is none of the supported native types.
        """
        # Checked in this order: eager frame, series, lazy frame.
        for wrapper in (self._dataframe, self._series, self._lazyframe):
            if wrapper._is_native(data):
                return wrapper.from_native(data, context=self)
        msg = f"Unsupported type: {type(data).__name__!r}"  # pragma: no cover
        raise TypeError(msg)  # pragma: no cover

    @overload
    def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> PolarsSeries: ...

    @overload
    def from_numpy(
        self, data: _2DArray, /, schema: IntoSchema | Sequence[str] | None
    ) -> PolarsDataFrame: ...

    def from_numpy(
        self,
        data: Into1DArray | _2DArray,
        /,
        schema: IntoSchema | Sequence[str] | None = None,
    ) -> PolarsDataFrame | PolarsSeries:
        """Build a frame from a 2D array, or a series from anything else."""
        if not is_numpy_array_2d(data):
            return self._series.from_numpy(data, context=self)  # pragma: no cover
        return self._dataframe.from_numpy(data, schema=schema, context=self)

    @requires.backend_version(
        (1, 0, 0), "Please use `col` for columns selection instead."
    )
    def nth(self, *indices: int) -> PolarsExpr:
        """Wrap ``pl.nth`` for the given column indices."""
        selected = pl.nth(*indices)
        return self._expr(selected, version=self._version)

    def len(self) -> PolarsExpr:
        """Wrap ``pl.len``; before 0.20.5, ``pl.count`` aliased to ``"len"`` is used."""
        if self._backend_version < (0, 20, 5):
            native = pl.count().alias("len")
        else:
            native = pl.len()
        return self._expr(native, self._version)

    def all_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr:
        """Wrap ``pl.all_horizontal``; with ``ignore_nulls`` nulls are filled with ``True``."""
        operands = [e.fill_null(True) for e in exprs] if ignore_nulls else exprs
        natives = [e.native for e in operands]
        return self._expr(pl.all_horizontal(*natives), self._version)

    def any_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr:
        """Wrap ``pl.any_horizontal``; with ``ignore_nulls`` nulls are filled with ``False``."""
        operands = [e.fill_null(False) for e in exprs] if ignore_nulls else exprs
        natives = [e.native for e in operands]
        return self._expr(pl.any_horizontal(*natives), self._version)

    def concat(
        self,
        items: Iterable[FrameT],
        *,
        how: Literal["vertical", "horizontal", "diagonal"],
    ) -> PolarsDataFrame | PolarsLazyFrame:
        """Concatenate the native frames of ``items`` and wrap the result."""
        combined = pl.concat((item.native for item in items), how=how)
        if not isinstance(combined, pl.DataFrame):
            return self._lazyframe.from_native(combined, context=self)
        return self._dataframe(combined, version=self._version)

    def lit(self, value: Any, dtype: IntoDType | None) -> PolarsExpr:
        """Wrap ``pl.lit(value)``, passing a converted ``dtype`` when one is given."""
        if dtype is None:
            return self._expr(pl.lit(value), version=self._version)
        native_dtype = narwhals_to_native_dtype(dtype, self._version)
        return self._expr(pl.lit(value, dtype=native_dtype), version=self._version)

    def mean_horizontal(self, *exprs: PolarsExpr) -> PolarsExpr:
        """Wrap ``pl.mean_horizontal``.

        Before 0.20.8 the mean is computed as the horizontal sum divided by the
        horizontal sum of ``1 - is_null`` for each expression.
        """
        if self._backend_version >= (0, 20, 8):
            native = pl.mean_horizontal(e._native_expr for e in exprs)
            return self._expr(native, version=self._version)

        total = pl.sum_horizontal(e._native_expr for e in exprs)
        non_null_count = pl.sum_horizontal(
            1 - e.is_null()._native_expr for e in exprs
        )
        return self._expr(total / non_null_count, version=self._version)

    def concat_str(
        self, *exprs: PolarsExpr, separator: str, ignore_nulls: bool
    ) -> PolarsExpr:
        """Join expressions as strings with ``separator``.

        From 0.20.6 on this is ``pl.concat_str``; older backends build the
        result from ``reduce`` / ``fold`` over string-cast expressions.
        """
        natives: list[pl.Expr] = [expr._native_expr for expr in exprs]

        if self._backend_version >= (0, 20, 6):
            joined = pl.concat_str(
                natives, separator=separator, ignore_nulls=ignore_nulls
            )
            return self._expr(joined, version=self._version)

        null_flags = [expr.is_null() for expr in natives]
        sep = pl.lit(separator)

        if not ignore_nulls:
            # Any null input leaves the row without a value.
            any_null = pl.any_horizontal(*null_flags)
            glued = pl.reduce(
                lambda x, y: x.cast(pl.String()) + sep + y.cast(pl.String()),  # type: ignore[arg-type,return-value]
                natives,
            )
            result = pl.when(~any_null).then(glued)
        else:
            # Nulls become empty strings; the separator built from an
            # expression's null flag is likewise emptied when it is null.
            first, *rest = [
                pl.when(flag).then(pl.lit("")).otherwise(expr.cast(pl.String()))
                for expr, flag in zip_strict(natives, null_flags)
            ]
            gaps = [
                pl.when(~flag).then(sep).otherwise(pl.lit(""))
                for flag in null_flags[:-1]
            ]
            result = pl.fold(  # type: ignore[assignment]
                acc=first,
                function=operator.add,
                exprs=[gap + value for gap, value in zip_strict(gaps, rest)],
            )

        return self._expr(result, version=self._version)

    # NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`)
    # 1. Others have lots of private stuff for code reuse
    #    i. None of that is useful here
    # 2. We don't have a `PolarsSelector` abstraction, and just use `PolarsExpr`
    @property
    def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]:
        """A ``PolarsSelectorNamespace`` for this namespace, cast to the compliant type."""
        namespace = PolarsSelectorNamespace(self)
        return cast(
            "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", namespace
        )
|
||||
|
||||
|
||||
class PolarsSelectorNamespace:
    """Column selectors from ``pl.selectors``, each wrapped as a ``PolarsExpr``."""

    _implementation = Implementation.POLARS

    def __init__(self, context: _LimitedContext, /) -> None:
        """Take the narwhals version from ``context``."""
        self._version = context._version

    def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr:
        """Select columns by dtype.

        Each entry is converted with ``narwhals_to_native_dtype``; entries that
        are ``DType`` subclasses (not instances) contribute the class of the
        converted dtype instead of the converted instance.
        """
        native_dtypes = []
        for dtype in dtypes:
            is_dtype_class = isinstance(dtype, type) and issubclass(dtype, DType)
            converted = narwhals_to_native_dtype(dtype, self._version)
            native_dtypes.append(converted.__class__ if is_dtype_class else converted)
        selector = pl.selectors.by_dtype(native_dtypes)
        return PolarsExpr(selector, version=self._version)

    def matches(self, pattern: str) -> PolarsExpr:
        """Wrap ``pl.selectors.matches`` for ``pattern``."""
        selector = pl.selectors.matches(pattern=pattern)
        return PolarsExpr(selector, version=self._version)

    def numeric(self) -> PolarsExpr:
        """Wrap ``pl.selectors.numeric``."""
        selector = pl.selectors.numeric()
        return PolarsExpr(selector, version=self._version)

    def boolean(self) -> PolarsExpr:
        """Wrap ``pl.selectors.boolean``."""
        selector = pl.selectors.boolean()
        return PolarsExpr(selector, version=self._version)

    def string(self) -> PolarsExpr:
        """Wrap ``pl.selectors.string``."""
        selector = pl.selectors.string()
        return PolarsExpr(selector, version=self._version)

    def categorical(self) -> PolarsExpr:
        """Wrap ``pl.selectors.categorical``."""
        selector = pl.selectors.categorical()
        return PolarsExpr(selector, version=self._version)

    def all(self) -> PolarsExpr:
        """Wrap ``pl.selectors.all``."""
        selector = pl.selectors.all()
        return PolarsExpr(selector, version=self._version)

    def datetime(
        self,
        time_unit: TimeUnit | Iterable[TimeUnit] | None,
        time_zone: str | timezone | Iterable[str | timezone | None] | None,
    ) -> PolarsExpr:
        """Wrap ``pl.selectors.datetime`` with the given unit and zone filters."""
        selector = pl.selectors.datetime(time_unit=time_unit, time_zone=time_zone)  # type: ignore[arg-type]
        return PolarsExpr(selector, version=self._version)
|
795
lib/python3.11/site-packages/narwhals/_polars/series.py
Normal file
795
lib/python3.11/site-packages/narwhals/_polars/series.py
Normal file
@ -0,0 +1,795 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast, overload
|
||||
|
||||
import polars as pl
|
||||
|
||||
from narwhals._polars.utils import (
|
||||
BACKEND_VERSION,
|
||||
SERIES_ACCEPTS_PD_INDEX,
|
||||
SERIES_RESPECTS_DTYPE,
|
||||
PolarsAnyNamespace,
|
||||
PolarsCatNamespace,
|
||||
PolarsDateTimeNamespace,
|
||||
PolarsListNamespace,
|
||||
PolarsStringNamespace,
|
||||
PolarsStructNamespace,
|
||||
catch_polars_exception,
|
||||
extract_args_kwargs,
|
||||
extract_native,
|
||||
narwhals_to_native_dtype,
|
||||
native_to_narwhals_dtype,
|
||||
)
|
||||
from narwhals._utils import Implementation, requires
|
||||
from narwhals.dependencies import is_numpy_array_1d, is_pandas_index
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
||||
from types import ModuleType
|
||||
from typing import Literal, TypeVar
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from typing_extensions import Self, TypeAlias, TypeIs
|
||||
|
||||
from narwhals._polars.dataframe import Method, PolarsDataFrame
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
from narwhals._utils import Version, _LimitedContext
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.series import Series
|
||||
from narwhals.typing import (
|
||||
Into1DArray,
|
||||
IntoDType,
|
||||
ModeKeepStrategy,
|
||||
MultiIndexSelector,
|
||||
NonNestedLiteral,
|
||||
NumericLiteral,
|
||||
_1DArray,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
IncludeBreakpoint: TypeAlias = Literal[False, True]
|
||||
|
||||
Incomplete: TypeAlias = Any
|
||||
|
||||
# Names that `PolarsSeries` forwards straight to the native `pl.Series`.
INHERITED_METHODS = frozenset(
    {
        # Operators and protocol methods.
        "__add__", "__and__", "__floordiv__", "__invert__", "__iter__",
        "__mod__", "__mul__", "__or__", "__pow__",
        "__radd__", "__rand__", "__rfloordiv__", "__rmod__", "__rmul__",
        "__ror__", "__rsub__", "__rtruediv__", "__sub__", "__truediv__",
        # Regular methods, alphabetical.
        "abs", "all", "any", "arg_max", "arg_min", "arg_true",
        "clip", "count", "cum_max", "cum_min", "cum_prod", "cum_sum",
        "diff", "drop_nulls",
        "exp",
        "fill_null", "fill_nan", "filter",
        "gather_every",
        "head",
        "is_between", "is_close", "is_duplicated", "is_empty", "is_finite",
        "is_first_distinct", "is_in", "is_last_distinct", "is_null",
        "is_sorted", "is_unique", "item",
        "kurtosis",
        "len", "log",
        "max", "mean", "min", "mode",
        "n_unique", "null_count",
        "quantile",
        "rank", "round",
        "sample", "shift", "skew", "sqrt", "std", "sum",
        "tail", "to_arrow", "to_frame", "to_list", "to_pandas",
        "unique",
        "var",
        "zip_with",
    }
)
|
||||
|
||||
|
||||
class PolarsSeries:
|
||||
_implementation: Implementation = Implementation.POLARS
|
||||
_native_series: pl.Series
|
||||
_version: Version
|
||||
|
||||
_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
|
||||
True: ["breakpoint", "count"],
|
||||
False: ["count"],
|
||||
}
|
||||
|
||||
def __init__(self, series: pl.Series, *, version: Version) -> None:
    """Store the native series together with the narwhals version."""
    self._version = version
    self._native_series = series
|
||||
|
||||
@property
def _backend_version(self) -> tuple[int, ...]:
    """Backend version as reported by the implementation enum."""
    implementation = self._implementation
    return implementation._backend_version()
|
||||
|
||||
def __repr__(self) -> str:  # pragma: no cover
    """Return the fixed label ``"PolarsSeries"``."""
    label = "PolarsSeries"
    return label
|
||||
|
||||
def __narwhals_namespace__(self) -> PolarsNamespace:
    """Create a ``PolarsNamespace`` carrying this series' narwhals version."""
    # Imported inside the function rather than at module level.
    from narwhals._polars.namespace import PolarsNamespace

    namespace = PolarsNamespace(version=self._version)
    return namespace
|
||||
|
||||
def __narwhals_series__(self) -> Self:
    """Return this object itself."""
    return self
|
||||
|
||||
def __native_namespace__(self) -> ModuleType:
    """Return the native namespace module of the implementation.

    Raises:
        AssertionError: If the implementation is not ``Implementation.POLARS``.
    """
    implementation = self._implementation
    if implementation is not Implementation.POLARS:
        msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)
    return implementation.to_native_namespace()
|
||||
|
||||
def _with_version(self, version: Version) -> Self:
    """Return a new instance of the same class over the same native series, with ``version``."""
    same_class = self.__class__
    return same_class(self.native, version=version)
|
||||
|
||||
@classmethod
def from_iterable(
    cls,
    data: Iterable[Any],
    *,
    context: _LimitedContext,
    name: str = "",
    dtype: IntoDType | None = None,
) -> Self:
    """Build a series from ``data``, optionally with a requested ``dtype``.

    When ``SERIES_RESPECTS_DTYPE`` is false the dtype is not passed to the
    ``pl.Series`` constructor; the series is cast afterwards instead.
    """
    native_dtype = None
    if dtype:
        native_dtype = narwhals_to_native_dtype(dtype, context._version)

    values: Incomplete = data
    if not SERIES_RESPECTS_DTYPE:  # pragma: no cover
        if (not SERIES_ACCEPTS_PD_INDEX) and is_pandas_index(values):
            # pandas indexes are converted to a pandas series first.
            values = values.to_series()
        series = pl.Series(name, values)
        if native_dtype:
            series = series.cast(native_dtype)
    else:
        series = pl.Series(name, values, dtype=native_dtype)
    return cls.from_native(series, context=context)
|
||||
|
||||
@staticmethod
def _is_native(obj: pl.Series | Any) -> TypeIs[pl.Series]:
    """Tell whether ``obj`` is a ``pl.Series`` instance."""
    native_type = pl.Series
    return isinstance(obj, native_type)
|
||||
|
||||
@classmethod
def from_native(cls, data: pl.Series, /, *, context: _LimitedContext) -> Self:
    """Wrap ``data`` using the narwhals version carried by ``context``."""
    version = context._version
    return cls(data, version=version)
|
||||
|
||||
@classmethod
def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self:
    """Build a series from ``data``; anything that is not a 1D array is wrapped in a list."""
    values = data if is_numpy_array_1d(data) else [data]
    return cls.from_native(pl.Series(values), context=context)
|
||||
|
||||
def to_narwhals(self) -> Series[pl.Series]:
    """Build the narwhals-level series via ``self._version.series`` at level ``"full"``."""
    make_series = self._version.series
    return make_series(self, level="full")
|
||||
|
||||
def _with_native(self, series: pl.Series) -> Self:
    """Return a new instance of the same class over ``series``, keeping the version."""
    same_class = self.__class__
    return same_class(series, version=self._version)
|
||||
|
||||
@overload
def _from_native_object(self, series: pl.Series) -> Self: ...

@overload
def _from_native_object(self, series: pl.DataFrame) -> PolarsDataFrame: ...

@overload
def _from_native_object(self, series: T) -> T: ...

def _from_native_object(
    self, series: pl.Series | pl.DataFrame | T
) -> Self | PolarsDataFrame | T:
    """Wrap native results: series and frames are wrapped, anything else is returned as-is."""
    if self._is_native(series):
        return self._with_native(series)
    if not isinstance(series, pl.DataFrame):
        # scalar
        return series

    from narwhals._polars.dataframe import PolarsDataFrame

    return PolarsDataFrame.from_native(series, context=self)
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
    """Forward whitelisted attribute lookups to the native series.

    Python calls this only after normal attribute lookup has failed. Names in
    ``INHERITED_METHODS`` resolve to a wrapper that calls the same-named
    attribute on the native series and passes the result through
    ``_from_native_object``.

    Raises:
        AttributeError: If ``attr`` is not in ``INHERITED_METHODS``.
    """
    if attr not in INHERITED_METHODS:
        # Wording fixed from "has not attribute" to "has no attribute".
        msg = f"{self.__class__.__name__} has no attribute '{attr}'."
        raise AttributeError(msg)

    def func(*args: Any, **kwargs: Any) -> Any:
        # Arguments go through `extract_args_kwargs` before the native call.
        pos, kwds = extract_args_kwargs(args, kwargs)
        return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))

    return func
|
||||
|
||||
def __len__(self) -> int:
    """Return ``len`` of the native series."""
    native = self.native
    return len(native)
|
||||
|
||||
@property
def name(self) -> str:
    """Name of the native series."""
    native = self.native
    return native.name
|
||||
|
||||
@property
def dtype(self) -> DType:
    """Native dtype converted with ``native_to_narwhals_dtype``."""
    native_dtype = self.native.dtype
    return native_to_narwhals_dtype(native_dtype, self._version)
|
||||
|
||||
@property
def native(self) -> pl.Series:
    """The stored native series."""
    series = self._native_series
    return series
|
||||
|
||||
def alias(self, name: str) -> Self:
    """Call the native ``alias`` with ``name`` and wrap the result."""
    renamed = self.native.alias(name)
    return self._from_native_object(renamed)
|
||||
|
||||
def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self:
    """Index the native series, unwrapping ``item`` first when it is a ``PolarsSeries``."""
    key = item.native if isinstance(item, PolarsSeries) else item
    selected = self.native.__getitem__(key)
    return self._from_native_object(selected)
|
||||
|
||||
def cast(self, dtype: IntoDType) -> Self:
    """Cast the native series to ``dtype`` after converting it to a native dtype."""
    target = narwhals_to_native_dtype(dtype, self._version)
    converted = self.native.cast(target)
    return self._with_native(converted)
|
||||
|
||||
@requires.backend_version((1,))
def replace_strict(
    self,
    old: Sequence[Any] | Mapping[Any, Any],
    new: Sequence[Any],
    *,
    return_dtype: IntoDType | None,
) -> Self:
    """Call the native ``replace_strict`` and wrap the resulting series.

    ``return_dtype`` is converted with ``narwhals_to_native_dtype`` when it is
    truthy; otherwise ``None`` is forwarded to the native call.
    """
    target_dtype = None
    if return_dtype:
        target_dtype = narwhals_to_native_dtype(return_dtype, self._version)
    replaced = self.native.replace_strict(old, new, return_dtype=target_dtype)
    return self._with_native(replaced)
|
||||
|
||||
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray:
    """Delegate to ``self.__array__`` with the same ``dtype`` and ``copy``."""
    as_array = self.__array__
    return as_array(dtype, copy=copy)
|
||||
|
||||
def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray:
    """Call the native ``__array__``; ``copy`` is only forwarded from backend 0.20.29 on."""
    to_array = self.native.__array__
    forwards_copy = self._backend_version >= (0, 20, 29)
    if forwards_copy:
        return to_array(dtype=dtype, copy=copy)
    return to_array(dtype=dtype)
|
||||
|
||||
def __eq__(self, other: object) -> Self:  # type: ignore[override]
    """Wrap the native ``__eq__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__eq__(rhs))

def __ne__(self, other: object) -> Self:  # type: ignore[override]
    """Wrap the native ``__ne__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__ne__(rhs))

# NOTE: These need to be anything that can't match `PolarsExpr`, due to overload order
def __ge__(self, other: Self) -> Self:
    """Wrap the native ``__ge__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__ge__(rhs))

def __gt__(self, other: Self) -> Self:
    """Wrap the native ``__gt__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__gt__(rhs))

def __le__(self, other: Self) -> Self:
    """Wrap the native ``__le__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__le__(rhs))

def __lt__(self, other: Self) -> Self:
    """Wrap the native ``__lt__`` applied to the unwrapped ``other``."""
    rhs = extract_native(other)
    return self._with_native(self.native.__lt__(rhs))
|
||||
|
||||
def __rpow__(self, other: PolarsSeries | Any) -> Self:
    """Reflected power via the native `__rpow__`.

    For backends older than Polars 1.16.1 the result is re-aliased to this
    series' name.
    """
    powered = self.native.__rpow__(extract_native(other))
    needs_alias = self._backend_version < (1, 16, 1)
    if needs_alias:
        # Explicitly set alias to work around https://github.com/pola-rs/polars/issues/20071
        powered = powered.alias(self.name)
    return self._with_native(powered)
|
||||
|
||||
def is_nan(self) -> Self:
    """Return a NaN mask for the series.

    Any exception raised by the native `is_nan` is passed through
    `catch_polars_exception` and re-raised without the original context.
    For backends older than Polars 1.18 the mask is only evaluated where
    the series is not null.
    """
    try:
        nan_mask = self.native.is_nan()
    except Exception as e:  # noqa: BLE001
        raise catch_polars_exception(e) from None
    if self._backend_version < (1, 18):  # pragma: no cover
        masked = pl.when(self.native.is_not_null()).then(nan_mask)
        frame = pl.select(masked)
        return self._with_native(frame[self.name])
    return self._with_native(nan_mask)
|
||||
|
||||
def median(self) -> Any:
    """Return the native median of the series.

    Raises:
        InvalidOperationError: If `self.dtype.is_numeric()` is false.
    """
    from narwhals.exceptions import InvalidOperationError

    if self.dtype.is_numeric():
        return self.native.median()

    msg = "`median` operation not supported for non-numeric input type."
    raise InvalidOperationError(msg)
|
||||
|
||||
def to_dummies(self, *, separator: str, drop_first: bool) -> PolarsDataFrame:
    """Build indicator columns for the series, cast to `Int8`.

    Backends from Polars 0.20.15 on receive `drop_first` directly. Older
    backends get the full set of dummies, from which one column is removed
    by position when `drop_first` is set.
    """
    from narwhals._polars.dataframe import PolarsDataFrame

    if self._backend_version >= (0, 20, 15):
        dummies = self.native.to_dummies(separator=separator, drop_first=drop_first)
    else:
        has_nulls = self.native.is_null().any()
        dummies = self.native.to_dummies(separator=separator)
        kept_columns = dummies.columns
        if drop_first:
            # Index 1 is dropped when nulls are present, index 0 otherwise.
            kept_columns.pop(int(has_nulls))
        dummies = dummies.select(kept_columns)

    as_int8 = dummies.with_columns(pl.all().cast(pl.Int8))
    return PolarsDataFrame.from_native(as_int8, context=self)
|
||||
|
||||
def ewm_mean(
    self,
    *,
    com: float | None,
    span: float | None,
    half_life: float | None,
    alpha: float | None,
    adjust: bool,
    min_samples: int,
    ignore_nulls: bool,
) -> Self:
    """Exponentially weighted mean via the native `ewm_mean`.

    `min_samples` is passed as `min_periods` for backends older than
    Polars 1.21.0 and as `min_samples` otherwise. For backends older than
    Polars 1, positions that are null in the input are set to null in the
    output.
    """
    min_key = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"

    smoothed = self.native.ewm_mean(
        com=com,
        span=span,
        half_life=half_life,
        alpha=alpha,
        adjust=adjust,
        ignore_nulls=ignore_nulls,
        **{min_key: min_samples},
    )
    if self._backend_version < (1,):  # pragma: no cover
        null_preserving = (
            pl.when(~self.native.is_null()).then(smoothed).otherwise(None)
        )
        return self._with_native(pl.select(null_preserving)[self.native.name])

    return self._with_native(smoothed)
|
||||
|
||||
@requires.backend_version((1,))
def rolling_var(
    self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
    """Rolling variance via the native `rolling_var`.

    `min_samples` is passed as `min_periods` for backends older than
    Polars 1.21.0 and as `min_samples` otherwise.
    """
    min_key = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
    rolled = self.native.rolling_var(
        window_size=window_size, center=center, ddof=ddof, **{min_key: min_samples}
    )
    return self._with_native(rolled)
|
||||
|
||||
@requires.backend_version((1,))
def rolling_std(
    self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
    """Rolling standard deviation via the native `rolling_std`.

    `min_samples` is passed as `min_periods` for backends older than
    Polars 1.21.0 and as `min_samples` otherwise.
    """
    min_key = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
    rolled = self.native.rolling_std(
        window_size=window_size, center=center, ddof=ddof, **{min_key: min_samples}
    )
    return self._with_native(rolled)
|
||||
|
||||
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
    """Rolling sum via the native `rolling_sum`.

    `min_samples` is passed as `min_periods` for backends older than
    Polars 1.21.0 and as `min_samples` otherwise.
    """
    min_key = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
    rolled = self.native.rolling_sum(
        window_size=window_size, center=center, **{min_key: min_samples}
    )
    return self._with_native(rolled)
|
||||
|
||||
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
    """Rolling mean via the native `rolling_mean`.

    `min_samples` is passed as `min_periods` for backends older than
    Polars 1.21.0 and as `min_samples` otherwise.
    """
    min_key = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
    rolled = self.native.rolling_mean(
        window_size=window_size, center=center, **{min_key: min_samples}
    )
    return self._with_native(rolled)
|
||||
|
||||
def sort(self, *, descending: bool, nulls_last: bool) -> Self:
    """Sort the series.

    Backends from Polars 0.20.6 on receive `nulls_last` directly. For older
    backends the series is sorted without it and, when `nulls_last` is set,
    the non-null rows are concatenated ahead of the null rows.
    """
    if self._backend_version >= (0, 20, 6):
        ordered = self.native.sort(descending=descending, nulls_last=nulls_last)
        return self._with_native(ordered)

    ordered = self.native.sort(descending=descending)
    if nulls_last:
        null_mask = ordered.is_null()
        non_null_part = ordered.filter(~null_mask)
        null_part = ordered.filter(null_mask)
        ordered = pl.concat([non_null_part, null_part])
    return self._with_native(ordered)
|
||||
|
||||
def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
    """Scatter `values` at `indices` on a clone of the native series.

    The native series is cloned before `scatter` is called, and the result
    of that call is wrapped and returned.
    """
    working_copy = self.native.clone()
    scattered = working_copy.scatter(indices, extract_native(values))
    return self._with_native(scattered)
|
||||
|
||||
def value_counts(
    self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
) -> PolarsDataFrame:
    """Count occurrences of each value and return them as a dataframe.

    Backends from Polars 1.0.0 on receive `name` and `normalize` directly.
    For older backends the count column is renamed here (defaulting to
    "proportion" or "count") and, when `normalize` is set, divided by the
    total count.
    """
    from narwhals._polars.dataframe import PolarsDataFrame

    if self._backend_version >= (1, 0, 0):
        counted = self.native.value_counts(
            sort=sort, parallel=parallel, name=name, normalize=normalize
        )
        return PolarsDataFrame.from_native(counted, context=self)

    value_name_ = name or ("proportion" if normalize else "count")
    raw_counts = self.native.value_counts(sort=sort, parallel=parallel)
    if normalize:
        value_expr = pl.col("count") / pl.sum("count")
    else:
        value_expr = pl.col("count")
    selection = {
        (self.native.name): pl.col(self.native.name),
        value_name_: value_expr,
    }
    counted = raw_counts.select(**selection)
    return PolarsDataFrame.from_native(counted, context=self)
|
||||
|
||||
def cum_count(self, *, reverse: bool) -> Self:
    """Cumulative count via the native `cum_count`, re-wrapped."""
    counted = self.native.cum_count(reverse=reverse)
    return self._with_native(counted)
|
||||
|
||||
def __contains__(self, other: Any) -> bool:
    """Membership test delegated to the native series.

    Any exception raised by the native call is passed through
    `catch_polars_exception` and re-raised without the original context.
    """
    try:
        found = self.native.__contains__(other)
    except Exception as e:  # noqa: BLE001
        raise catch_polars_exception(e) from None
    return found
|
||||
|
||||
def is_close(
    self,
    other: Self | NumericLiteral,
    *,
    abs_tol: float,
    rel_tol: float,
    nans_equal: bool,
) -> PolarsSeries:
    """Element-wise closeness check against `other`.

    Backends from Polars 1.32.0 on use the native series `is_close`. Older
    backends evaluate the namespace expression `is_close` on a one-column
    frame built from this series and return the resulting column.
    """
    other_is_series = isinstance(other, PolarsSeries)

    if self._backend_version >= (1, 32, 0):
        rhs = other.native if other_is_series else other
        close = self.native.is_close(
            rhs, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
        )
        return self._with_native(close)

    name = self.name
    ns = self.__narwhals_namespace__()
    rhs_expr = ns.lit(other.native, None) if other_is_series else other
    close_expr = ns.col(name).is_close(
        rhs_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
    )
    return self.to_frame().select(close_expr).get_column(name)
|
||||
|
||||
def mode(self, *, keep: ModeKeepStrategy) -> Self:
    """Return the native mode(s); only the first row when `keep == "any"`."""
    modes = self.native.mode()
    if keep == "any":
        modes = modes.head(1)
    return self._with_native(modes)
|
||||
|
||||
def hist_from_bins(
    self, bins: list[float], *, include_breakpoint: bool
) -> PolarsDataFrame:
    """Histogram of the series for explicit bin edges.

    With at most one edge an empty frame using `_HIST_EMPTY_SCHEMA` is
    returned. With an empty series every interval gets a zero count.
    Otherwise the work is delegated to `_hist_from_data`.
    """
    has_intervals = len(bins) > 1
    if has_intervals and not self.native.is_empty():
        return self._hist_from_data(
            bins=bins, bin_count=None, include_breakpoint=include_breakpoint
        )

    if not has_intervals:
        frame = pl.DataFrame(schema=self._HIST_EMPTY_SCHEMA[include_breakpoint])
    elif include_breakpoint:
        # One row per interval, labelled by the interval's upper edge.
        breakpoints = pl.Series(bins[1:]).to_frame("breakpoint")
        frame = breakpoints.with_columns(count=pl.lit(0, pl.Int64))
    else:
        frame = pl.select(count=pl.zeros(len(bins) - 1, pl.Int64))
    return self.__narwhals_namespace__()._dataframe.from_native(frame, context=self)
|
||||
|
||||
def hist_from_bin_count(
    self, bin_count: int, *, include_breakpoint: bool
) -> PolarsDataFrame:
    """Histogram of the series for a requested number of bins.

    With `bin_count == 0` an empty frame using `_HIST_EMPTY_SCHEMA` is
    returned. With an empty series every bin gets a zero count, with
    breakpoints `1/bin_count .. 1` when requested. Otherwise the work is
    delegated to `_hist_from_data`; for Polars older than 1.15 the bin
    count is first converted to explicit edges.
    """
    if bin_count != 0 and not self.native.is_empty():
        count: int | None
        if BACKEND_VERSION < (1, 15):  # pragma: no cover
            count, bins = None, self._bins_from_bin_count(bin_count=bin_count)
        else:
            count, bins = bin_count, None
        return self._hist_from_data(
            bins=bins,  # type: ignore[arg-type]
            bin_count=count,
            include_breakpoint=include_breakpoint,
        )

    if bin_count == 0:
        frame = pl.DataFrame(schema=self._HIST_EMPTY_SCHEMA[include_breakpoint])
    elif include_breakpoint:
        frame = pl.select(
            breakpoint=pl.int_range(1, bin_count + 1) / bin_count,
            count=pl.lit(0, pl.Int64),
        )
    else:
        frame = pl.select(count=pl.zeros(bin_count, pl.Int64))
    return self.__narwhals_namespace__()._dataframe.from_native(frame, context=self)
|
||||
|
||||
def _bins_from_bin_count(self, bin_count: int) -> pl.Series:  # pragma: no cover
    """Convert a bin count into `bin_count + 1` evenly spaced edges.

    The edges span the series' min to max. When min and max are equal the
    range is widened by 0.5 on each side first.

    Compatibility notes carried over from the original docstring:
    polars <1.15 does not adjust the bins when they have equivalent min/max;
    polars <1.5 with bin_count=... returns bins that range from -inf to +inf
    and has bin_count + 1 bins. For compat: convert `bin_count=` call to
    `bins=`.
    """
    low = cast("float", self.native.min())
    high = cast("float", self.native.max())

    if low == high:
        low, high = low - 0.5, high + 0.5

    step = (high - low) / bin_count
    edge_index = pl.int_range(0, bin_count + 1, eager=True)
    return edge_index * step + low
|
||||
|
||||
def _hist_from_data(
    self, bins: list[float] | None, bin_count: int | None, *, include_breakpoint: bool
) -> PolarsDataFrame:
    """Run the native `hist` on non-empty data and normalise its output.

    The result is adjusted per backend version so that the count column is
    named "count", the breakpoint column is named "breakpoint", and (when
    explicit bins are given) the rows match the requested intervals.
    """
    from narwhals._polars.dataframe import PolarsDataFrame

    data = self.native

    # Polars inconsistently handles NaN values when computing histograms
    # against predefined bins: https://github.com/pola-rs/polars/issues/21082
    if BACKEND_VERSION < (1, 15) or bins is not None:
        data = data.fill_nan(None)

    hist = data.hist(
        bins,
        bin_count=bin_count,
        include_category=False,
        include_breakpoint=include_breakpoint,
    )

    # Column naming: without breakpoints keep only the first column as
    # "count"; with breakpoints fix the pre-1.0 "break_point" spelling.
    if include_breakpoint:
        if BACKEND_VERSION < (1, 0):  # pragma: no cover
            hist = hist.rename({"break_point": "breakpoint"})
    else:
        first_column = hist.columns[0]
        hist = hist.select(pl.col(first_column).alias("count"))

    if bins is not None:  # pragma: no cover
        # polars<1.6 implicitly adds -inf and inf to either end of bins
        if BACKEND_VERSION < (1, 6):
            row_index = pl.int_range(0, len(hist))
            hist = hist.filter((row_index > 0) & (row_index < len(hist) - 1))
        # polars<1.27 makes the lowest bin a left/right closed interval
        if BACKEND_VERSION < (1, 27):
            on_lowest_edge = (pl.lit(data) == bins[0]).sum()
            first_row = hist.slice(0, 1).with_columns(pl.col("count") + on_lowest_edge)
            hist = first_row.vstack(hist.slice(1))

    return PolarsDataFrame.from_native(hist, context=self)
|
||||
|
||||
def to_polars(self) -> pl.Series:
    """Return the wrapped native series unchanged."""
    native_series = self.native
    return native_series
|
||||
|
||||
@property
def dt(self) -> PolarsSeriesDateTimeNamespace:
    """Datetime accessor bound to this series."""
    namespace = PolarsSeriesDateTimeNamespace(self)
    return namespace
|
||||
|
||||
@property
def str(self) -> PolarsSeriesStringNamespace:
    """String accessor bound to this series."""
    namespace = PolarsSeriesStringNamespace(self)
    return namespace
|
||||
|
||||
@property
def cat(self) -> PolarsSeriesCatNamespace:
    """Categorical accessor bound to this series."""
    namespace = PolarsSeriesCatNamespace(self)
    return namespace
|
||||
|
||||
@property
def struct(self) -> PolarsSeriesStructNamespace:
    """Struct accessor bound to this series."""
    namespace = PolarsSeriesStructNamespace(self)
    return namespace
|
||||
|
||||
# Annotation-only declarations: no implementation is assigned here. Each name
# is typed as `Method[R]`, where `R` is the declared return type.
# NOTE(review): presumably these are resolved at runtime by a dynamic
# attribute hook on the enclosing class that forwards to the native series -
# confirm against the class's `__getattr__`, which is outside this view.
__add__: Method[Self]
__and__: Method[Self]
__floordiv__: Method[Self]
__invert__: Method[Self]
__iter__: Method[Iterator[Any]]
__mod__: Method[Self]
__mul__: Method[Self]
__or__: Method[Self]
__pow__: Method[Self]
__radd__: Method[Self]
__rand__: Method[Self]
__rfloordiv__: Method[Self]
__rmod__: Method[Self]
__rmul__: Method[Self]
__ror__: Method[Self]
__rsub__: Method[Self]
__rtruediv__: Method[Self]
__sub__: Method[Self]
__truediv__: Method[Self]
abs: Method[Self]
all: Method[bool]
any: Method[bool]
arg_max: Method[int]
arg_min: Method[int]
arg_true: Method[Self]
clip: Method[Self]
count: Method[int]
cum_max: Method[Self]
cum_min: Method[Self]
cum_prod: Method[Self]
cum_sum: Method[Self]
diff: Method[Self]
drop_nulls: Method[Self]
exp: Method[Self]
fill_null: Method[Self]
fill_nan: Method[Self]
filter: Method[Self]
gather_every: Method[Self]
head: Method[Self]
is_between: Method[Self]
is_duplicated: Method[Self]
is_empty: Method[bool]
is_finite: Method[Self]
is_first_distinct: Method[Self]
is_in: Method[Self]
is_last_distinct: Method[Self]
is_null: Method[Self]
is_sorted: Method[bool]
is_unique: Method[Self]
item: Method[Any]
kurtosis: Method[float | None]
len: Method[int]
log: Method[Self]
max: Method[Any]
mean: Method[float]
min: Method[Any]
n_unique: Method[int]
null_count: Method[int]
quantile: Method[float]
rank: Method[Self]
round: Method[Self]
sample: Method[Self]
shift: Method[Self]
skew: Method[float | None]
sqrt: Method[Self]
std: Method[float]
sum: Method[float]
tail: Method[Self]
to_arrow: Method[pa.Array[Any]]
to_frame: Method[PolarsDataFrame]
to_list: Method[list[Any]]
to_pandas: Method[pd.Series[Any]]
unique: Method[Self]
var: Method[float]
zip_with: Method[Self]
|
||||
|
||||
@property
def list(self) -> PolarsSeriesListNamespace:
    """List accessor bound to this series."""
    namespace = PolarsSeriesListNamespace(self)
    return namespace
|
||||
|
||||
|
||||
class PolarsSeriesNamespace(PolarsAnyNamespace[PolarsSeries, pl.Series]):
    """Base for series accessor namespaces; holds the owning `PolarsSeries`."""

    def __init__(self, series: PolarsSeries) -> None:
        """Store the series this namespace operates on."""
        self._series = series

    @property
    def compliant(self) -> PolarsSeries:
        """The owning compliant series."""
        owner = self._series
        return owner

    @property
    def native(self) -> pl.Series:
        """The native series wrapped by the owning compliant series."""
        owner = self._series
        return owner.native

    @property
    def name(self) -> str:
        """Name of the owning series, read through `compliant`."""
        owner = self.compliant
        return owner.name

    def __narwhals_namespace__(self) -> PolarsNamespace:
        """Namespace object of the owning series."""
        owner = self.compliant
        return owner.__narwhals_namespace__()

    def to_frame(self) -> PolarsDataFrame:
        """The owning series converted to a dataframe."""
        owner = self.compliant
        return owner.to_frame()
|
||||
|
||||
|
||||
# Series-level `dt` accessor: combines the series namespace base with the
# shared datetime namespace; it adds no members of its own.
class PolarsSeriesDateTimeNamespace(
    PolarsSeriesNamespace, PolarsDateTimeNamespace[PolarsSeries, pl.Series]
): ...
|
||||
|
||||
|
||||
class PolarsSeriesStringNamespace(
    PolarsSeriesNamespace, PolarsStringNamespace[PolarsSeries, pl.Series]
):
    """Series-level `str` accessor."""

    def zfill(self, width: int) -> PolarsSeries:
        """Apply the expression-level `str.zfill` to this series.

        The series is turned into a frame, the namespace expression is
        selected on its column, and that column is returned as a series.
        """
        name = self.name
        ns = self.__narwhals_namespace__()
        padded = ns.col(name).str.zfill(width)
        frame = self.to_frame().select(padded)
        return frame.get_column(name)
|
||||
|
||||
|
||||
# Series-level `cat` accessor: combines the series namespace base with the
# shared categorical namespace; it adds no members of its own.
class PolarsSeriesCatNamespace(
    PolarsSeriesNamespace, PolarsCatNamespace[PolarsSeries, pl.Series]
): ...
|
||||
|
||||
|
||||
class PolarsSeriesListNamespace(
    PolarsSeriesNamespace, PolarsListNamespace[PolarsSeries, pl.Series]
):
    """Series-level `list` accessor.

    Both methods route through the expression API: the series is turned
    into a frame, the namespace expression is selected on its column, and
    that column is returned as a series.
    """

    def len(self) -> PolarsSeries:
        """Apply the expression-level `list.len` to this series."""
        name = self.name
        ns = self.__narwhals_namespace__()
        lengths = ns.col(name).list.len()
        frame = self.to_frame().select(lengths)
        return frame.get_column(name)

    def contains(self, item: NonNestedLiteral) -> PolarsSeries:
        """Apply the expression-level `list.contains(item)` to this series."""
        name = self.name
        ns = self.__narwhals_namespace__()
        membership = ns.col(name).list.contains(item)
        frame = self.to_frame().select(membership)
        return frame.get_column(name)
|
||||
|
||||
|
||||
# Series-level `struct` accessor: combines the series namespace base with the
# shared struct namespace; it adds no members of its own.
class PolarsSeriesStructNamespace(
    PolarsSeriesNamespace, PolarsStructNamespace[PolarsSeries, pl.Series]
): ...
|
25
lib/python3.11/site-packages/narwhals/_polars/typing.py
Normal file
25
lib/python3.11/site-packages/narwhals/_polars/typing.py
Normal file
@ -0,0 +1,25 @@
|
||||
from __future__ import annotations # pragma: no cover
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING, # pragma: no cover
|
||||
Union, # pragma: no cover
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
    # Everything in this branch exists for static type checkers only; none of
    # these names are bound at runtime.
    import sys
    from typing import Literal, TypeVar

    # `TypeAlias` is taken from `typing` on Python 3.10+, otherwise from
    # `typing_extensions`.
    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        from typing_extensions import TypeAlias

    from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.series import PolarsSeries

    # Either a compliant Polars expression or a compliant Polars series.
    IntoPolarsExpr: TypeAlias = Union[PolarsExpr, PolarsSeries]
    # Constrained to exactly the eager or the lazy compliant frame.
    FrameT = TypeVar("FrameT", PolarsDataFrame, PolarsLazyFrame)
    # Closed set of accessor names accepted as a namespace's `_accessor`.
    NativeAccessor: TypeAlias = Literal[
        "arr", "cat", "dt", "list", "meta", "name", "str", "bin", "struct"
    ]
|
351
lib/python3.11/site-packages/narwhals/_polars/utils.py
Normal file
351
lib/python3.11/site-packages/narwhals/_polars/utils.py
Normal file
@ -0,0 +1,351 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Final, Protocol, TypeVar, overload
|
||||
|
||||
import polars as pl
|
||||
|
||||
from narwhals._duration import Interval
|
||||
from narwhals._utils import (
|
||||
Implementation,
|
||||
Version,
|
||||
_DeferredIterable,
|
||||
_StoresCompliant,
|
||||
_StoresNative,
|
||||
deep_getattr,
|
||||
isinstance_or_issubclass,
|
||||
)
|
||||
from narwhals.exceptions import (
|
||||
ColumnNotFoundError,
|
||||
ComputeError,
|
||||
DuplicateError,
|
||||
InvalidOperationError,
|
||||
NarwhalsError,
|
||||
ShapeError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Iterable, Iterator, Mapping
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from narwhals._polars.dataframe import Method
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
from narwhals._polars.series import PolarsSeries
|
||||
from narwhals._polars.typing import NativeAccessor
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.typing import IntoDType
|
||||
|
||||
# Unconstrained type variable, used by `extract_native` for pass-through values.
T = TypeVar("T")
# Any native Polars object a compliant wrapper may store.
NativeT = TypeVar(
    "NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr"
)

# Series/expression type variables used to parametrise the accessor namespaces.
NativeT_co = TypeVar("NativeT_co", "pl.Series", "pl.Expr", covariant=True)
CompliantT_co = TypeVar("CompliantT_co", "PolarsSeries", "PolarsExpr", covariant=True)
CompliantT = TypeVar("CompliantT", "PolarsSeries", "PolarsExpr")

# Evaluated once, when this module is imported.
BACKEND_VERSION = Implementation.POLARS._backend_version()
"""Static backend version for `polars`."""

SERIES_RESPECTS_DTYPE: Final[bool] = BACKEND_VERSION >= (0, 20, 26)
"""`pl.Series(dtype=...)` fixed in https://github.com/pola-rs/polars/pull/15962

Includes `SERIES_ACCEPTS_PD_INDEX`.
"""

SERIES_ACCEPTS_PD_INDEX: Final[bool] = BACKEND_VERSION >= (0, 20, 7)
"""`pl.Series(values: pd.Index)` fixed in https://github.com/pola-rs/polars/pull/14087"""
|
||||
|
||||
|
||||
@overload
def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ...
@overload
def extract_native(obj: T) -> T: ...
def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T:
    """Unwrap a compliant Polars object, passing anything else through.

    Returns `obj.native` when `_is_compliant_polars(obj)` is true, otherwise
    `obj` itself.
    """
    if _is_compliant_polars(obj):
        return obj.native
    return obj
|
||||
|
||||
|
||||
def _is_compliant_polars(
    obj: _StoresNative[NativeT] | Any,
) -> TypeIs[_StoresNative[NativeT]]:
    """Tell whether `obj` is one of the four compliant Polars wrappers.

    The wrapper classes are imported inside the function rather than at
    module level.
    """
    from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.series import PolarsSeries

    compliant_types = (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr)
    return isinstance(obj, compliant_types)
|
||||
|
||||
|
||||
def extract_args_kwargs(
    args: Iterable[Any], kwds: Mapping[str, Any], /
) -> tuple[Iterator[Any], dict[str, Any]]:
    """Apply `extract_native` to positional and keyword arguments.

    Returns:
        A lazy iterator over the unwrapped positional arguments and a dict
        of the unwrapped keyword arguments.
    """
    native_args = (extract_native(positional) for positional in args)
    native_kwds = {}
    for key, value in kwds.items():
        native_kwds[key] = extract_native(value)
    return native_args, native_kwds
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(  # noqa: C901, PLR0912
    dtype: pl.DataType, version: Version
) -> DType:
    """Translate a Polars dtype into the narwhals dtype for `version`.

    Results are memoised by `lru_cache` on the `(dtype, version)` pair.
    Checks run top to bottom and the first match wins; anything unmatched
    becomes `dtypes.Unknown()`.
    """
    dtypes = version.dtypes
    if dtype == pl.Float64:
        return dtypes.Float64()
    if dtype == pl.Float32:
        return dtypes.Float32()
    if hasattr(pl, "Int128") and dtype == pl.Int128:  # pragma: no cover
        # Not available for Polars pre 1.8.0
        return dtypes.Int128()
    if dtype == pl.Int64:
        return dtypes.Int64()
    if dtype == pl.Int32:
        return dtypes.Int32()
    if dtype == pl.Int16:
        return dtypes.Int16()
    if dtype == pl.Int8:
        return dtypes.Int8()
    if hasattr(pl, "UInt128") and dtype == pl.UInt128:  # pragma: no cover
        # Not available for Polars pre 1.8.0
        return dtypes.UInt128()
    if dtype == pl.UInt64:
        return dtypes.UInt64()
    if dtype == pl.UInt32:
        return dtypes.UInt32()
    if dtype == pl.UInt16:
        return dtypes.UInt16()
    if dtype == pl.UInt8:
        return dtypes.UInt8()
    if dtype == pl.String:
        return dtypes.String()
    if dtype == pl.Boolean:
        return dtypes.Boolean()
    if dtype == pl.Object:
        return dtypes.Object()
    if dtype == pl.Categorical:
        return dtypes.Categorical()
    if isinstance_or_issubclass(dtype, pl.Enum):
        # V1 builds an Enum without categories; other versions pass the
        # categories wrapped in `_DeferredIterable`.
        if version is Version.V1:
            return dtypes.Enum()  # type: ignore[call-arg]
        categories = _DeferredIterable(dtype.categories.to_list)
        return dtypes.Enum(categories)
    if dtype == pl.Date:
        return dtypes.Date()
    if isinstance_or_issubclass(dtype, pl.Datetime):
        # The bare class carries no unit/zone, so the narwhals defaults apply.
        return (
            dtypes.Datetime()
            if dtype is pl.Datetime
            else dtypes.Datetime(dtype.time_unit, dtype.time_zone)
        )
    if isinstance_or_issubclass(dtype, pl.Duration):
        # Same treatment as Datetime: bare class -> narwhals default unit.
        return (
            dtypes.Duration()
            if dtype is pl.Duration
            else dtypes.Duration(dtype.time_unit)
        )
    if isinstance_or_issubclass(dtype, pl.Struct):
        # Nested dtypes are converted recursively, field by field.
        fields = [
            dtypes.Field(name, native_to_narwhals_dtype(tp, version))
            for name, tp in dtype
        ]
        return dtypes.Struct(fields)
    if isinstance_or_issubclass(dtype, pl.List):
        return dtypes.List(native_to_narwhals_dtype(dtype.inner, version))
    if isinstance_or_issubclass(dtype, pl.Array):
        # The outer dimension is read from `width` before Polars 0.20.30
        # and from `size` afterwards.
        outer_shape = dtype.width if BACKEND_VERSION < (0, 20, 30) else dtype.size
        return dtypes.Array(native_to_narwhals_dtype(dtype.inner, version), outer_shape)
    if dtype == pl.Decimal:
        return dtypes.Decimal()
    if dtype == pl.Time:
        return dtypes.Time()
    if dtype == pl.Binary:
        return dtypes.Binary()
    return dtypes.Unknown()
|
||||
|
||||
|
||||
# Module-level `dtypes` refers to the MAIN-version narwhals dtypes; the
# functions in this module shadow it with a local `version.dtypes`.
dtypes = Version.MAIN.dtypes
# Direct lookup from a narwhals dtype class to a Polars dtype instance, for
# the dtypes that need no parameters.
NW_TO_PL_DTYPES: Mapping[type[DType], pl.DataType] = {
    dtypes.Float64: pl.Float64(),
    dtypes.Float32: pl.Float32(),
    dtypes.Binary: pl.Binary(),
    dtypes.String: pl.String(),
    dtypes.Boolean: pl.Boolean(),
    dtypes.Categorical: pl.Categorical(),
    dtypes.Date: pl.Date(),
    dtypes.Time: pl.Time(),
    dtypes.Int8: pl.Int8(),
    dtypes.Int16: pl.Int16(),
    dtypes.Int32: pl.Int32(),
    dtypes.Int64: pl.Int64(),
    dtypes.UInt8: pl.UInt8(),
    dtypes.UInt16: pl.UInt16(),
    dtypes.UInt32: pl.UInt32(),
    dtypes.UInt64: pl.UInt64(),
    dtypes.Object: pl.Object(),
    dtypes.Unknown: pl.Unknown(),
}
# Narwhals dtypes for which `narwhals_to_native_dtype` raises
# NotImplementedError.
UNSUPPORTED_DTYPES = (dtypes.Decimal,)
|
||||
|
||||
|
||||
def narwhals_to_native_dtype(  # noqa: C901
    dtype: IntoDType, version: Version
) -> pl.DataType:
    """Translate a narwhals dtype into the corresponding Polars dtype.

    Args:
        dtype: Narwhals dtype (class or instance) to convert.
        version: Narwhals API version whose `dtypes` are used for matching.

    Returns:
        The Polars dtype. Parameter-free dtypes come from `NW_TO_PL_DTYPES`;
        nested dtypes are converted recursively. Anything unmatched becomes
        `pl.Unknown()`.

    Raises:
        NotImplementedError: For `Enum` under `Version.V1`, and for dtypes
            listed in `UNSUPPORTED_DTYPES`.
        ValueError: For an `Enum` class given without categories.
    """
    dtypes = version.dtypes
    base_type = dtype.base_type()
    if pl_type := NW_TO_PL_DTYPES.get(base_type):
        return pl_type
    if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
        # Not available for Polars pre 1.8.0
        return pl.Int128()
    if dtype == dtypes.UInt128 and hasattr(pl, "UInt128"):
        # Mirrors the `pl.UInt128` branch of `native_to_narwhals_dtype`, so
        # the dtype round-trips instead of silently becoming `pl.Unknown()`.
        # Guarded because older Polars releases do not define `pl.UInt128`.
        return pl.UInt128()
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            return pl.Enum(dtype.categories)
        # A bare `Enum` class has no categories to build a Polars Enum from.
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        return pl.Datetime(dtype.time_unit, dtype.time_zone)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return pl.Duration(dtype.time_unit)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.List):
        return pl.List(narwhals_to_native_dtype(dtype.inner, version))
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        fields = [
            pl.Field(field.name, narwhals_to_native_dtype(field.dtype, version))
            for field in dtype.fields
        ]
        return pl.Struct(fields)
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        size = dtype.size
        # The size keyword is `width` before Polars 0.20.30, `shape` after.
        kwargs = {"width": size} if BACKEND_VERSION < (0, 20, 30) else {"shape": size}
        return pl.Array(narwhals_to_native_dtype(dtype.inner, version), **kwargs)
    if issubclass(base_type, UNSUPPORTED_DTYPES):
        msg = f"Converting to {base_type.__name__} dtype is not supported for Polars."
        raise NotImplementedError(msg)
    return pl.Unknown()  # pragma: no cover
|
||||
|
||||
|
||||
def _is_polars_exception(exception: Exception) -> bool:
    """Tell whether `exception` originates from Polars.

    From Polars 1 on this is an `isinstance` check against
    `pl.exceptions.PolarsError`. For older versions it falls back to
    looking for "polars.exceptions" in the string form of the exception's
    type.
    """
    has_polars_error_base = BACKEND_VERSION >= (1,)
    if has_polars_error_base:
        # Old versions of Polars didn't have PolarsError.
        return isinstance(exception, pl.exceptions.PolarsError)
    # Last attempt, for old Polars versions.
    type_repr = str(type(exception))
    return "polars.exceptions" in type_repr  # pragma: no cover
|
||||
|
||||
|
||||
def _is_cudf_exception(exception: Exception) -> bool:
    """Tell whether the exception message starts with "CUDF failure"."""
    # These exceptions are raised when running polars on GPUs via cuDF
    message = str(exception)
    return message.startswith("CUDF failure")
|
||||
|
||||
|
||||
def catch_polars_exception(exception: Exception) -> NarwhalsError | Exception:
    """Map a Polars exception onto its narwhals counterpart.

    The returned exception carries `str(exception)` as its message. Other
    Polars or cuDF exceptions become a generic `NarwhalsError`; anything
    else is returned unchanged. The caller is responsible for raising.
    """
    # Checked in this order; the Polars class is looked up lazily by name so
    # that only the classes actually tested are accessed.
    translations = (
        ("ColumnNotFoundError", ColumnNotFoundError),
        ("ShapeError", ShapeError),
        ("InvalidOperationError", InvalidOperationError),
        ("DuplicateError", DuplicateError),
        ("ComputeError", ComputeError),
    )
    for polars_name, narwhals_error in translations:
        if isinstance(exception, getattr(pl.exceptions, polars_name)):
            return narwhals_error(str(exception))
    if _is_polars_exception(exception) or _is_cudf_exception(exception):
        return NarwhalsError(str(exception))  # pragma: no cover
    # Just return exception as-is.
    return exception
|
||||
|
||||
|
||||
class PolarsAnyNamespace(
    _StoresCompliant[CompliantT_co],
    _StoresNative[NativeT_co],
    Protocol[CompliantT_co, NativeT_co],
):
    """Shared base for accessor namespaces that forward to a native accessor.

    Subclasses set `_accessor` to the native accessor name (for example
    "dt"). Attributes not found normally are served by `__getattr__`.
    """

    _accessor: ClassVar[NativeAccessor]

    def __getattr__(self, attr: str) -> Callable[..., CompliantT_co]:
        """Return a callable that forwards `attr` to the native accessor.

        When called, it unwraps its arguments with `extract_args_kwargs`,
        looks up `self.native.<_accessor>.<attr>` via `deep_getattr`, calls
        it, and wraps the result with `self.compliant._with_native`.
        """

        def func(*args: Any, **kwargs: Any) -> CompliantT_co:
            native_args, native_kwargs = extract_args_kwargs(args, kwargs)
            bound_method = deep_getattr(self.native, self._accessor, attr)
            outcome = bound_method(*native_args, **native_kwargs)
            return self.compliant._with_native(outcome)

        return func
|
||||
|
||||
|
||||
class PolarsDateTimeNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    """Datetime accessor namespace; forwards to the native "dt" accessor."""

    _accessor: ClassVar[NativeAccessor] = "dt"

    def truncate(self, every: str) -> CompliantT:
        """Validate `every` with `Interval.parse`, then forward to `truncate`."""
        # Ensure consistent error message is raised.
        Interval.parse(every)
        forward = self.__getattr__("truncate")
        return forward(every)

    def offset_by(self, by: str) -> CompliantT:
        """Validate `by` with `Interval.parse_no_constraints`, then forward."""
        # Ensure consistent error message is raised.
        Interval.parse_no_constraints(by)
        forward = self.__getattr__("offset_by")
        return forward(by)

    # Annotation-only members; no implementation is assigned in this class.
    to_string: Method[CompliantT]
    replace_time_zone: Method[CompliantT]
    convert_time_zone: Method[CompliantT]
    timestamp: Method[CompliantT]
    date: Method[CompliantT]
    year: Method[CompliantT]
    month: Method[CompliantT]
    day: Method[CompliantT]
    hour: Method[CompliantT]
    minute: Method[CompliantT]
    second: Method[CompliantT]
    millisecond: Method[CompliantT]
    microsecond: Method[CompliantT]
    nanosecond: Method[CompliantT]
    ordinal_day: Method[CompliantT]
    weekday: Method[CompliantT]
    total_minutes: Method[CompliantT]
    total_seconds: Method[CompliantT]
    total_milliseconds: Method[CompliantT]
    total_microseconds: Method[CompliantT]
    total_nanoseconds: Method[CompliantT]
|
||||
|
||||
|
||||
class PolarsStringNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    """String accessor namespace; forwards to the native "str" accessor."""

    _accessor: ClassVar[NativeAccessor] = "str"

    # NOTE: Use `abstractmethod` if we have defs to implement, but also `Method` usage
    @abc.abstractmethod
    def zfill(self, width: int) -> CompliantT: ...

    # Annotation-only members; no implementation is assigned in this class.
    len_chars: Method[CompliantT]
    replace: Method[CompliantT]
    replace_all: Method[CompliantT]
    strip_chars: Method[CompliantT]
    starts_with: Method[CompliantT]
    ends_with: Method[CompliantT]
    contains: Method[CompliantT]
    slice: Method[CompliantT]
    split: Method[CompliantT]
    to_date: Method[CompliantT]
    to_datetime: Method[CompliantT]
    to_lowercase: Method[CompliantT]
    to_uppercase: Method[CompliantT]
|
||||
|
||||
|
||||
class PolarsCatNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    """Categorical accessor namespace; forwards to the native "cat" accessor."""

    _accessor: ClassVar[NativeAccessor] = "cat"
    # Annotation-only member; no implementation is assigned in this class.
    get_categories: Method[CompliantT]
|
||||
|
||||
|
||||
class PolarsListNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    """List accessor namespace; forwards to the native "list" accessor."""

    _accessor: ClassVar[NativeAccessor] = "list"

    # Declared abstract, so concrete namespaces must define it themselves.
    @abc.abstractmethod
    def len(self) -> CompliantT: ...

    # Annotation-only members; no implementation is assigned in this class.
    get: Method[CompliantT]

    unique: Method[CompliantT]
|
||||
|
||||
|
||||
class PolarsStructNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    """Struct accessor namespace; forwards to the native "struct" accessor."""

    _accessor: ClassVar[NativeAccessor] = "struct"
    # Annotation-only member; no implementation is assigned in this class.
    field: Method[CompliantT]
|
Reference in New Issue
Block a user