This commit is contained in:
2025-09-07 22:09:54 +02:00
parent e1b817252c
commit 2fc0d000b6
7796 changed files with 2159515 additions and 933 deletions

View File

@ -0,0 +1,678 @@
from __future__ import annotations
from collections.abc import Iterator, Mapping, Sequence, Sized
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast, overload
import polars as pl
from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.series import PolarsSeries
from narwhals._polars.utils import (
catch_polars_exception,
extract_args_kwargs,
native_to_narwhals_dtype,
)
from narwhals._utils import (
Implementation,
_into_arrow_table,
convert_str_slice_to_int_slice,
is_compliant_series,
is_index_selector,
is_range,
is_sequence_like,
is_slice_index,
is_slice_none,
parse_columns_to_drop,
requires,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ColumnNotFoundError
if TYPE_CHECKING:
from collections.abc import Iterable
from types import ModuleType
from typing import Callable
import pandas as pd
import pyarrow as pa
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.group_by import PolarsGroupBy, PolarsLazyGroupBy
from narwhals._spark_like.utils import SparkSession
from narwhals._translate import IntoArrowTable
from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
from narwhals._utils import Version, _LimitedContext
from narwhals.dataframe import DataFrame, LazyFrame
from narwhals.dtypes import DType
from narwhals.typing import (
IntoSchema,
JoinStrategy,
MultiColSelector,
MultiIndexSelector,
PivotAgg,
SingleIndexSelector,
_2DArray,
)
# Placeholder for an arbitrary object type; used by the pass-through overload
# of `PolarsDataFrame._from_native_object`.
T = TypeVar("T")
# Return type of a dynamically resolved method (see `Method` below).
R = TypeVar("R")
Method: TypeAlias = "Callable[..., R]"
"""Generic alias representing all methods implemented via `__getattr__`.
Where `R` is the return type.
"""
# Method names that the Polars frame wrappers do not implement themselves:
# their `__getattr__` hooks only accept names listed here and forward the call
# to the wrapped native Polars frame.
INHERITED_METHODS = frozenset(
    {
        "clone",
        "drop_nulls",
        "estimated_size",
        "explode",
        "filter",
        "gather_every",
        "head",
        "is_unique",
        "item",
        "iter_rows",
        "join_asof",
        "rename",
        "row",
        "rows",
        "sample",
        "select",
        "sink_parquet",
        "sort",
        "tail",
        "to_arrow",
        "to_pandas",
        "unique",
        "with_columns",
        "write_csv",
        "write_parquet",
    }
)
# Type variable constrained to the two native Polars frame types that
# `PolarsBaseFrame` is parametrised with (`pl.DataFrame` and `pl.LazyFrame`).
NativePolarsFrame = TypeVar("NativePolarsFrame", pl.DataFrame, pl.LazyFrame)
class PolarsBaseFrame(Generic[NativePolarsFrame]):
    """Functionality common to the eager and lazy Polars-backed frames.

    An instance pairs a native Polars frame with the Narwhals `Version` it
    belongs to. The `Method[...]` annotations below only declare names for
    type checkers; this class gives them no implementation (both subclasses
    in this module define a `__getattr__` that forwards to the native frame).
    """

    drop_nulls: Method[Self]
    explode: Method[Self]
    filter: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    join_asof: Method[Self]
    rename: Method[Self]
    select: Method[Self]
    sort: Method[Self]
    tail: Method[Self]
    unique: Method[Self]
    with_columns: Method[Self]

    _native_frame: NativePolarsFrame
    _implementation = Implementation.POLARS
    _version: Version

    def __init__(
        self,
        df: NativePolarsFrame,
        *,
        version: Version,
        validate_backend_version: bool = False,
    ) -> None:
        """Wrap `df` under `version`, optionally checking the installed backend."""
        self._native_frame = df
        self._version = version
        if validate_backend_version:
            self._validate_backend_version()

    def _validate_backend_version(self) -> None:
        """Raise if the installed version is below `nw._utils.MIN_VERSIONS`.

        **Only needed when moving between backends**; in every other case the
        validation has already happened.
        """
        # The call is made purely for its validating side effect.
        self._implementation._backend_version()

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Version tuple of the installed backend."""
        return self._implementation._backend_version()

    @property
    def native(self) -> NativePolarsFrame:
        """The wrapped native Polars frame."""
        return self._native_frame

    @property
    def columns(self) -> list[str]:
        """Column names as reported by the native frame."""
        return self.native.columns

    def __narwhals_namespace__(self) -> PolarsNamespace:
        return PolarsNamespace(version=self._version)

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is not Implementation.POLARS:  # pragma: no cover
            msg = f"Expected polars, got: {type(self._implementation)}"
            raise AssertionError(msg)
        return self._implementation.to_native_namespace()

    def _with_native(self, df: NativePolarsFrame) -> Self:
        """Return a new wrapper of the same class around `df`, keeping the version."""
        return self.__class__(df, version=self._version)

    def _with_version(self, version: Version) -> Self:
        """Return a new wrapper around the same native frame under `version`."""
        return self.__class__(self.native, version=version)

    @classmethod
    def from_native(cls, data: NativePolarsFrame, /, *, context: _LimitedContext) -> Self:
        """Wrap `data`, taking the version from `context`."""
        return cls(data, version=context._version)

    def simple_select(self, *column_names: str) -> Self:
        """Select columns by name using the native `select`."""
        selected = self.native.select(*column_names)
        return self._with_native(selected)

    def aggregate(self, *exprs: Any) -> Self:
        """Delegate to `select` with the same expressions."""
        return self.select(*exprs)

    @property
    def schema(self) -> dict[str, DType]:
        """Mapping of column name to Narwhals dtype (same as `collect_schema()`)."""
        return self.collect_schema()

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join with `other` through the native `join`.

        For polars<0.20.29 a `"full"` join is passed to Polars as `"outer"`.
        """
        how_native = how
        if self._backend_version < (0, 20, 29) and how == "full":
            how_native = "outer"
        joined = self.native.join(
            other=other.native,
            how=how_native,  # type: ignore[arg-type]
            left_on=left_on,
            right_on=right_on,
            suffix=suffix,
        )
        return self._with_native(joined)

    def top_k(
        self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool]
    ) -> Self:
        """Return the top `k` rows by `by`.

        For polars<1.0.0 the `reverse` value is passed as `descending`.
        """
        if self._backend_version >= (1, 0, 0):
            result = self.native.top_k(k=k, by=by, reverse=reverse)
        else:
            result = self.native.top_k(
                k=k,
                by=by,
                descending=reverse,  # type: ignore[call-arg]
            )
        return self._with_native(result)

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        """Unpivot the frame; uses the native `melt` for polars<1.0.0."""
        if self._backend_version >= (1, 0, 0):
            result = self.native.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        else:
            result = self.native.melt(
                id_vars=index,
                value_vars=on,
                variable_name=variable_name,
                value_name=value_name,
            )
        return self._with_native(result)

    def collect_schema(self) -> dict[str, DType]:
        """Return the schema with every native dtype converted to a Narwhals dtype."""
        frame = self.native
        if self._backend_version < (1,):
            native_schema = frame.schema
        else:
            native_schema = frame.collect_schema()
        return {
            column: native_to_narwhals_dtype(native_dtype, self._version)
            for column, native_dtype in native_schema.items()
        }

    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        """Add a row-index column called `name` as the first column.

        Without `order_by` the native `with_row_index` is used. Otherwise an
        integer range sorted by `order_by` is selected ahead of all columns.
        """
        if order_by is None:
            return self._with_native(self.native.with_row_index(name))
        # `pl.count()` is used instead of `pl.len()` for polars<0.20.5.
        n_rows = pl.count() if self._backend_version < (0, 20, 5) else pl.len()
        index = pl.int_range(start=0, end=n_rows).sort_by(order_by).alias(name)
        return self._with_native(self.native.select(index, pl.all()))
class PolarsDataFrame(PolarsBaseFrame[pl.DataFrame]):
    """Compliant eager frame wrapping a native `pl.DataFrame`.

    Attribute names contained in `INHERITED_METHODS` are resolved by
    `__getattr__`: the call is forwarded to the native frame and the result is
    re-wrapped by `_from_native_object`.
    """

    clone: Method[Self]
    collect: Method[CompliantDataFrameAny]
    estimated_size: Method[int | float]
    gather_every: Method[Self]
    item: Method[Any]
    iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]]
    is_unique: Method[PolarsSeries]
    row: Method[tuple[Any, ...]]
    rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]]
    sample: Method[Self]
    to_arrow: Method[pa.Table]
    to_pandas: Method[pd.DataFrame]
    # NOTE: `write_csv` requires an `@overload` for `str | None`
    # Can't do that here 😟
    write_csv: Method[Any]
    write_parquet: Method[None]

    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
        """Build a frame from Arrow-compatible `data`.

        For polars>=1.3 `data` is handed to `pl.DataFrame` directly; older
        versions first convert it with `_into_arrow_table`.
        """
        if context._implementation._backend_version() >= (1, 3):
            native = pl.DataFrame(data)
        else:  # pragma: no cover
            native = cast("pl.DataFrame", pl.from_arrow(_into_arrow_table(data, context)))
        return cls.from_native(native, context=context)

    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _LimitedContext,
        schema: IntoSchema | None,
    ) -> Self:
        """Build a frame from a mapping, converting `schema` to Polars when given."""
        from narwhals.schema import Schema

        pl_schema = Schema(schema).to_polars() if schema is not None else schema
        return cls.from_native(pl.from_dict(data, pl_schema), context=context)

    @staticmethod
    def _is_native(obj: pl.DataFrame | Any) -> TypeIs[pl.DataFrame]:
        """Return whether `obj` is a `pl.DataFrame`."""
        return isinstance(obj, pl.DataFrame)

    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _LimitedContext,  # NOTE: Maybe only `Implementation`?
        schema: IntoSchema | Sequence[str] | None,
    ) -> Self:
        """Build a frame from a 2D array.

        A mapping/`Schema` is converted to a Polars schema; any other value
        (column names or `None`) is passed to Polars unchanged.
        """
        from narwhals.schema import Schema

        pl_schema = (
            Schema(schema).to_polars()
            if isinstance(schema, (Mapping, Schema))
            else schema
        )
        return cls.from_native(pl.from_numpy(data, pl_schema), context=context)

    def to_narwhals(self) -> DataFrame[pl.DataFrame]:
        return self._version.dataframe(self, level="full")

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsDataFrame"

    def __narwhals_dataframe__(self) -> Self:
        return self

    @overload
    def _from_native_object(self, obj: pl.Series) -> PolarsSeries: ...
    @overload
    def _from_native_object(self, obj: pl.DataFrame) -> Self: ...
    @overload
    def _from_native_object(self, obj: T) -> T: ...
    def _from_native_object(
        self, obj: pl.Series | pl.DataFrame | T
    ) -> Self | PolarsSeries | T:
        """Wrap a native result: Series -> `PolarsSeries`, DataFrame -> `Self`.

        Anything else is returned unchanged.
        """
        if isinstance(obj, pl.Series):
            return PolarsSeries.from_native(obj, context=self)
        if self._is_native(obj):
            return self._with_native(obj)
        # scalar
        return obj

    def __len__(self) -> int:
        return len(self.native)

    def __getattr__(self, attr: str) -> Any:
        """Resolve `attr` as a forwarded native method.

        Only names in `INHERITED_METHODS` are accepted. The returned callable
        unpacks its arguments with `extract_args_kwargs`, calls the native
        method and wraps the result with `_from_native_object`.

        Raises:
            AttributeError: If `attr` is not in `INHERITED_METHODS`.
        """
        if attr not in INHERITED_METHODS:  # pragma: no cover
            msg = f"{self.__class__.__name__} has no attribute '{attr}'."
            raise AttributeError(msg)

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            try:
                return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))
            except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
                msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
                raise ColumnNotFoundError(msg) from e
            except Exception as e:  # noqa: BLE001
                # Any other error is translated by `catch_polars_exception`.
                raise catch_polars_exception(e) from None

        return func

    def __array__(
        self, dtype: Any | None = None, *, copy: bool | None = None
    ) -> _2DArray:
        """Convert to a NumPy array via the native `__array__`.

        `copy` is forwarded to Polars for polars>=0.20.28.

        Raises:
            NotImplementedError: If `copy` is given with polars<0.20.28.
        """
        if self._backend_version < (0, 20, 28):
            if copy is not None:
                msg = "`copy` in `__array__` is only supported for 'polars>=0.20.28'"
                raise NotImplementedError(msg)
            return self.native.__array__(dtype)
        return self.native.__array__(dtype, copy=copy)

    def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
        """Return the native `to_numpy()` result.

        `dtype` and `copy` are accepted but not forwarded to Polars.
        """
        return self.native.to_numpy()

    @property
    def shape(self) -> tuple[int, int]:
        return self.native.shape

    def __getitem__(  # noqa: C901, PLR0912
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[PolarsSeries],
            MultiColSelector[PolarsSeries],
        ],
    ) -> Any:
        """Select by a `(rows, columns)` pair.

        For polars>0.20.30 both selectors are unwrapped (compliant series ->
        native) and handed to the native `__getitem__`. Older versions go
        through the manual column-then-row selection below.
        """
        rows, columns = item
        if self._backend_version > (0, 20, 30):
            rows_native = rows.native if is_compliant_series(rows) else rows
            columns_native = columns.native if is_compliant_series(columns) else columns
            selector = rows_native, columns_native
            selected = self.native.__getitem__(selector)  # type: ignore[index]
            return self._from_native_object(selected)
        else:  # pragma: no cover # noqa: RET505
            # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
            # Polars version we support
            # This mostly mirrors the logic in `EagerDataFrame.__getitem__`.
            rows = list(rows) if isinstance(rows, tuple) else rows
            columns = list(columns) if isinstance(columns, tuple) else columns
            if is_numpy_array_1d(columns):
                columns = columns.tolist()
            native = self.native
            # Step 1: narrow the columns (skipped for a `slice(None)`-like selector).
            if not is_slice_none(columns):
                if isinstance(columns, Sized) and len(columns) == 0:
                    return self.select()
                if is_index_selector(columns):
                    if is_slice_index(columns) or is_range(columns):
                        native = native.select(
                            self.columns[slice(columns.start, columns.stop, columns.step)]
                        )
                    # NOTE: `mypy` loses track of `PolarsSeries` when `is_compliant_series` is used here
                    # `pyright` is fine
                    elif isinstance(columns, PolarsSeries):
                        native = native[:, columns.native.to_list()]
                    else:
                        native = native[:, columns]
                elif isinstance(columns, slice):
                    native = native.select(
                        self.columns[
                            slice(*convert_str_slice_to_int_slice(columns, self.columns))
                        ]
                    )
                elif is_compliant_series(columns):
                    native = native.select(columns.native.to_list())
                elif is_sequence_like(columns):
                    native = native.select(columns)
                else:
                    msg = f"Unreachable code, got unexpected type: {type(columns)}"
                    raise AssertionError(msg)
            # Step 2: narrow the rows of the column-selected frame.
            if not is_slice_none(rows):
                if isinstance(rows, int):
                    # Wrapped in a list so the result is still a frame.
                    native = native[[rows], :]
                elif isinstance(rows, (slice, range)):
                    native = native[rows, :]
                elif is_compliant_series(rows):
                    native = native[rows.native, :]
                elif is_sequence_like(rows):
                    native = native[rows, :]
                else:
                    msg = f"Unreachable code, got unexpected type: {type(rows)}"
                    raise AssertionError(msg)
            return self._with_native(native)

    def get_column(self, name: str) -> PolarsSeries:
        """Return column `name` wrapped as a `PolarsSeries`."""
        return PolarsSeries.from_native(self.native.get_column(name), context=self)

    def iter_columns(self) -> Iterator[PolarsSeries]:
        """Yield every column wrapped as a `PolarsSeries`."""
        for series in self.native.iter_columns():
            yield PolarsSeries.from_native(series, context=self)

    def lazy(
        self,
        backend: _LazyAllowedImpl | None = None,
        *,
        session: SparkSession | None = None,
    ) -> CompliantLazyFrameAny:
        """Convert to a lazy frame of the requested `backend`.

        `None` or Polars yields a `PolarsLazyFrame`; DuckDB, Dask, Ibis and
        Spark-like backends build their own compliant lazy frame from this
        frame's data.

        Raises:
            ValueError: If a Spark-like backend is requested without `session`.
        """
        if backend is None or backend is Implementation.POLARS:
            return PolarsLazyFrame.from_native(self.native.lazy(), context=self)
        if backend is Implementation.DUCKDB:
            import duckdb  # ignore-banned-import

            from narwhals._duckdb.dataframe import DuckDBLazyFrame

            # NOTE(review): `duckdb.table("_df")` refers to this local by name
            # (presumably through DuckDB's replacement scan - confirm), so the
            # variable must not be renamed without updating the string.
            _df = self.native
            return DuckDBLazyFrame(
                duckdb.table("_df"), validate_backend_version=True, version=self._version
            )
        if backend is Implementation.DASK:
            import dask.dataframe as dd  # ignore-banned-import

            from narwhals._dask.dataframe import DaskLazyFrame

            return DaskLazyFrame(
                dd.from_pandas(self.native.to_pandas()),
                validate_backend_version=True,
                version=self._version,
            )
        if backend is Implementation.IBIS:
            import ibis  # ignore-banned-import

            from narwhals._ibis.dataframe import IbisLazyFrame

            return IbisLazyFrame(
                ibis.memtable(self.native, columns=self.columns),
                validate_backend_version=True,
                version=self._version,
            )
        if backend.is_spark_like():
            from narwhals._spark_like.dataframe import SparkLikeLazyFrame

            if session is None:
                msg = "Spark like backends require `session` to be not None."
                raise ValueError(msg)
            return SparkLikeLazyFrame._from_compliant_dataframe(
                self,  # pyright: ignore[reportArgumentType]
                session=session,
                implementation=backend,
                version=self._version,
            )
        raise AssertionError  # pragma: no cover

    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ...
    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, PolarsSeries] | dict[str, list[Any]]:
        """Return a column-name mapping of `PolarsSeries` or of plain lists."""
        if as_series:
            return {
                name: PolarsSeries.from_native(col, context=self)
                for name, col in self.native.to_dict().items()
            }
        return self.native.to_dict(as_series=False)

    def group_by(
        self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
    ) -> PolarsGroupBy:
        from narwhals._polars.group_by import PolarsGroupBy

        return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop the columns resolved by `parse_columns_to_drop`."""
        to_drop = parse_columns_to_drop(self, columns, strict=strict)
        return self._with_native(self.native.drop(to_drop))

    @requires.backend_version((1,))
    def pivot(
        self,
        on: Sequence[str],
        *,
        index: Sequence[str] | None,
        values: Sequence[str] | None,
        aggregate_function: PivotAgg | None,
        sort_columns: bool,
        separator: str,
    ) -> Self:
        """Pivot through the native `pivot`, translating any raised exception."""
        try:
            result = self.native.pivot(
                on,
                index=index,
                values=values,
                aggregate_function=aggregate_function,
                sort_columns=sort_columns,
                separator=separator,
            )
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None
        return self._from_native_object(result)

    def to_polars(self) -> pl.DataFrame:
        return self.native

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Same as the base `join`, with exceptions translated."""
        try:
            return super().join(
                other=other, how=how, left_on=left_on, right_on=right_on, suffix=suffix
            )
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None

    def top_k(
        self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool]
    ) -> Self:
        """Same as the base `top_k`, with exceptions translated."""
        try:
            return super().top_k(k=k, by=by, reverse=reverse)
        except Exception as e:  # noqa: BLE001 # pragma: no cover
            raise catch_polars_exception(e) from None
class PolarsLazyFrame(PolarsBaseFrame[pl.LazyFrame]):
    """Compliant lazy frame wrapping a native `pl.LazyFrame`.

    Attribute names contained in `INHERITED_METHODS` are resolved by
    `__getattr__`, which forwards the call to the native frame and re-wraps
    the result with `_with_native`.
    """

    sink_parquet: Method[None]

    @staticmethod
    def _is_native(obj: pl.LazyFrame | Any) -> TypeIs[pl.LazyFrame]:
        """Return whether `obj` is a `pl.LazyFrame`."""
        return isinstance(obj, pl.LazyFrame)

    def to_narwhals(self) -> LazyFrame[pl.LazyFrame]:
        return self._version.lazyframe(self, level="lazy")

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsLazyFrame"

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def __getattr__(self, attr: str) -> Any:
        """Resolve `attr` as a forwarded native method.

        Raises:
            AttributeError: If `attr` is not in `INHERITED_METHODS`.
        """
        if attr not in INHERITED_METHODS:  # pragma: no cover
            msg = f"{self.__class__.__name__} has no attribute '{attr}'."
            raise AttributeError(msg)

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            try:
                return self._with_native(getattr(self.native, attr)(*pos, **kwds))
            except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
                raise ColumnNotFoundError(str(e)) from e

        return func

    def _iter_columns(self) -> Iterator[PolarsSeries]:  # pragma: no cover
        """Collect with the Polars backend and yield the resulting columns."""
        yield from self.collect(Implementation.POLARS).iter_columns()

    def collect_schema(self) -> dict[str, DType]:
        """Same as the base `collect_schema`, with exceptions translated."""
        try:
            return super().collect_schema()
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None

    def collect(
        self, backend: _EagerAllowedImpl | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        """Collect the native frame and wrap the result for `backend`.

        `kwargs` are forwarded to the native `collect`. `None` or Polars gives
        a `PolarsDataFrame`; pandas and PyArrow convert the collected frame
        with `to_pandas()` / `to_arrow()`.

        Raises:
            ValueError: If `backend` is none of the supported values.
        """
        try:
            result = self.native.collect(**kwargs)
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None
        if backend is None or backend is Implementation.POLARS:
            return PolarsDataFrame.from_native(result, context=self)
        if backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                result.to_pandas(),
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=False,
            )
        if backend is Implementation.PYARROW:
            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                result.to_arrow(),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=False,
            )
        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    def group_by(
        self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
    ) -> PolarsLazyGroupBy:
        from narwhals._polars.group_by import PolarsLazyGroupBy

        return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop `columns`; `strict` is only forwarded for polars>=1.0.0."""
        if self._backend_version < (1, 0, 0):
            return self._with_native(self.native.drop(columns))
        return self._with_native(self.native.drop(columns, strict=strict))

View File

@ -0,0 +1,479 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Literal
import polars as pl
from narwhals._polars.utils import (
PolarsAnyNamespace,
PolarsCatNamespace,
PolarsDateTimeNamespace,
PolarsListNamespace,
PolarsStringNamespace,
PolarsStructNamespace,
extract_args_kwargs,
extract_native,
narwhals_to_native_dtype,
)
from narwhals._utils import Implementation, requires
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from typing_extensions import Self
from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._polars.dataframe import Method
from narwhals._polars.namespace import PolarsNamespace
from narwhals._utils import Version
from narwhals.typing import IntoDType, ModeKeepStrategy, NumericLiteral
class PolarsExpr:
    """Compliant expression wrapping a native `pl.Expr`.

    Methods that are not defined explicitly are resolved by `__getattr__`,
    which forwards the call to the native expression and re-wraps the result.
    The `Method[Self]` annotations at the end of the class only declare such
    names for type checkers.
    """

    # CompliantExpr
    _implementation: Implementation = Implementation.POLARS
    _version: Version
    _native_expr: pl.Expr
    _metadata: ExprMetadata | None = None
    _evaluate_output_names: Any
    _alias_output_names: Any
    __call__: Any

    # CompliantExpr + builtin descriptor
    # TODO @dangotbanned: Remove in #2713
    @classmethod
    def from_column_names(cls, *_: Any, **__: Any) -> Self:
        raise NotImplementedError

    @classmethod
    def from_column_indices(cls, *_: Any, **__: Any) -> Self:
        raise NotImplementedError

    @staticmethod
    def _eval_names_indices(*_: Any) -> Any:
        raise NotImplementedError

    def __narwhals_expr__(self) -> Self:  # pragma: no cover
        return self

    def __narwhals_namespace__(self) -> PolarsNamespace:  # pragma: no cover
        from narwhals._polars.namespace import PolarsNamespace

        return PolarsNamespace(version=self._version)

    def __init__(self, expr: pl.Expr, version: Version) -> None:
        """Wrap the native `expr` under the given Narwhals `version`."""
        self._native_expr = expr
        self._version = version

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Version tuple of the installed backend."""
        return self._implementation._backend_version()

    @property
    def native(self) -> pl.Expr:
        """The wrapped native expression."""
        return self._native_expr

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsExpr"

    def _with_native(self, expr: pl.Expr) -> Self:
        """Return a new wrapper of the same class around `expr`."""
        return self.__class__(expr, self._version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        # Let Polars do its thing.
        return self

    def __getattr__(self, attr: str) -> Any:
        """Return a callable forwarding `attr` to the native expression.

        No check is made that `attr` exists; a missing native attribute only
        surfaces when the returned callable is invoked.
        """

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._with_native(getattr(self.native, attr)(*pos, **kwds))

        return func

    def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]:
        """Return `min_samples` keyed by the name the installed Polars expects.

        The keyword is `min_periods` for polars<1.21.0 and `min_samples` after.
        """
        name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
        return {name: min_samples}

    def cast(self, dtype: IntoDType) -> Self:
        """Cast to `dtype` after converting it to the native Polars dtype."""
        dtype_pl = narwhals_to_native_dtype(dtype, self._version)
        return self._with_native(self.native.cast(dtype_pl))

    def ewm_mean(
        self,
        *,
        com: float | None,
        span: float | None,
        half_life: float | None,
        alpha: float | None,
        adjust: bool,
        min_samples: int,
        ignore_nulls: bool,
    ) -> Self:
        """Exponentially weighted mean through the native `ewm_mean`."""
        native = self.native.ewm_mean(
            com=com,
            span=span,
            half_life=half_life,
            alpha=alpha,
            adjust=adjust,
            ignore_nulls=ignore_nulls,
            **self._renamed_min_periods(min_samples),
        )
        if self._backend_version < (1,):  # pragma: no cover
            # On polars<1 the result is masked so that null inputs stay null.
            native = pl.when(~self.native.is_null()).then(native).otherwise(None)
        return self._with_native(native)

    def is_nan(self) -> Self:
        """NaN check; for polars<1.18 it is only evaluated where the input is not null."""
        if self._backend_version >= (1, 18):
            native = self.native.is_nan()
        else:  # pragma: no cover
            native = pl.when(self.native.is_not_null()).then(self.native.is_nan())
        return self._with_native(native)

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        """Apply the expression as a window over `partition_by`.

        An empty `partition_by` is replaced by the literal `1`, and an empty
        `order_by` is passed as `None`.

        Raises:
            NotImplementedError: If `order_by` is given on polars<1.9.
        """
        # NOTE(review): the check is `< (1, 9)` while the message says 1.10;
        # confirm which Polars release introduced `order_by`.
        if self._backend_version < (1, 9):
            if order_by:
                msg = "`order_by` in Polars requires version 1.10 or greater"
                raise NotImplementedError(msg)
            native = self.native.over(partition_by or pl.lit(1))
        else:
            native = self.native.over(
                partition_by or pl.lit(1), order_by=order_by or None
            )
        return self._with_native(native)

    @requires.backend_version((1,))
    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_var(
            window_size=window_size, center=center, ddof=ddof, **kwds
        )
        return self._with_native(native)

    @requires.backend_version((1,))
    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_std(
            window_size=window_size, center=center, ddof=ddof, **kwds
        )
        return self._with_native(native)

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_sum(window_size=window_size, center=center, **kwds)
        return self._with_native(native)

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_mean(window_size=window_size, center=center, **kwds)
        return self._with_native(native)

    def map_batches(
        self,
        function: Callable[[Any], Any],
        return_dtype: IntoDType | None,
        *,
        returns_scalar: bool,
    ) -> Self:
        """Apply `function` through the native `map_batches`.

        Without an explicit `return_dtype`, `None` is passed for polars<1.32
        and `pl.self_dtype()` otherwise. `returns_scalar` is only forwarded
        for polars>=0.20.31.
        """
        pl_version = self._backend_version
        if return_dtype is not None:
            return_dtype_pl = narwhals_to_native_dtype(return_dtype, self._version)
        elif pl_version < (1, 32):
            return_dtype_pl = None
        else:
            return_dtype_pl = pl.self_dtype()
        kwargs = {} if pl_version < (0, 20, 31) else {"returns_scalar": returns_scalar}
        native = self.native.map_batches(function, return_dtype_pl, **kwargs)
        return self._with_native(native)

    @requires.backend_version((1,))
    def replace_strict(
        self,
        old: Sequence[Any] | Mapping[Any, Any],
        new: Sequence[Any],
        *,
        return_dtype: IntoDType | None,
    ) -> Self:
        """Replace values through the native `replace_strict`.

        `return_dtype` is converted to a native dtype unless it is `None`.
        """
        # Explicit `None` check (as in `map_batches`) so the decision never
        # depends on the truthiness of a dtype object.
        return_dtype_pl = (
            narwhals_to_native_dtype(return_dtype, self._version)
            if return_dtype is not None
            else None
        )
        native = self.native.replace_strict(old, new, return_dtype=return_dtype_pl)
        return self._with_native(native)

    # Operators: each unwraps `other` with `extract_native` and delegates to
    # the matching dunder of the native expression.
    def __eq__(self, other: object) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__eq__(extract_native(other)))  # type: ignore[operator]

    def __ne__(self, other: object) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__ne__(extract_native(other)))  # type: ignore[operator]

    def __ge__(self, other: Any) -> Self:
        return self._with_native(self.native.__ge__(extract_native(other)))

    def __gt__(self, other: Any) -> Self:
        return self._with_native(self.native.__gt__(extract_native(other)))

    def __le__(self, other: Any) -> Self:
        return self._with_native(self.native.__le__(extract_native(other)))

    def __lt__(self, other: Any) -> Self:
        return self._with_native(self.native.__lt__(extract_native(other)))

    def __and__(self, other: PolarsExpr | bool | Any) -> Self:
        return self._with_native(self.native.__and__(extract_native(other)))  # type: ignore[operator]

    def __or__(self, other: PolarsExpr | bool | Any) -> Self:
        return self._with_native(self.native.__or__(extract_native(other)))  # type: ignore[operator]

    def __add__(self, other: Any) -> Self:
        return self._with_native(self.native.__add__(extract_native(other)))

    def __sub__(self, other: Any) -> Self:
        return self._with_native(self.native.__sub__(extract_native(other)))

    def __mul__(self, other: Any) -> Self:
        return self._with_native(self.native.__mul__(extract_native(other)))

    def __pow__(self, other: Any) -> Self:
        return self._with_native(self.native.__pow__(extract_native(other)))

    def __truediv__(self, other: Any) -> Self:
        return self._with_native(self.native.__truediv__(extract_native(other)))

    def __floordiv__(self, other: Any) -> Self:
        return self._with_native(self.native.__floordiv__(extract_native(other)))

    def __mod__(self, other: Any) -> Self:
        return self._with_native(self.native.__mod__(extract_native(other)))

    def __invert__(self) -> Self:
        return self._with_native(self.native.__invert__())

    def cum_count(self, *, reverse: bool) -> Self:
        return self._with_native(self.native.cum_count(reverse=reverse))

    def is_close(
        self,
        other: Self | NumericLiteral,
        *,
        abs_tol: float,
        rel_tol: float,
        nans_equal: bool,
    ) -> Self:
        """Element-wise closeness check against `other`.

        For polars>=1.32.0 the native `is_close` is used; older versions build
        the equivalent expression by hand (see the inline comments).
        """
        left = self.native
        right = other.native if isinstance(other, PolarsExpr) else pl.lit(other)
        if self._backend_version < (1, 32, 0):
            lower_bound = right.abs()
            tolerance = (left.abs().clip(lower_bound) * rel_tol).clip(abs_tol)
            # Values are close if abs_diff <= tolerance, and both finite
            abs_diff = (left - right).abs()
            all_ = pl.all_horizontal
            is_close = all_((abs_diff <= tolerance), left.is_finite(), right.is_finite())
            # Handle infinity cases: infinities are "close" only if they have the same sign
            is_same_inf = all_(
                left.is_infinite(), right.is_infinite(), (left.sign() == right.sign())
            )
            # Handle nan cases:
            #   * nans_equals = True => if both values are NaN, then True
            #   * nans_equals = False => if any value is NaN, then False
            left_is_nan, right_is_nan = left.is_nan(), right.is_nan()
            either_nan = left_is_nan | right_is_nan
            result = (is_close | is_same_inf) & either_nan.not_()
            if nans_equal:
                result = result | (left_is_nan & right_is_nan)
        else:
            result = left.is_close(
                right, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
            )
        return self._with_native(result)

    def mode(self, *, keep: ModeKeepStrategy) -> Self:
        """Return the native `mode()`; with `keep="any"` only its first value."""
        result = self.native.mode()
        return self._with_native(result.first() if keep == "any" else result)

    @property
    def dt(self) -> PolarsExprDateTimeNamespace:
        return PolarsExprDateTimeNamespace(self)

    @property
    def str(self) -> PolarsExprStringNamespace:
        return PolarsExprStringNamespace(self)

    @property
    def cat(self) -> PolarsExprCatNamespace:
        return PolarsExprCatNamespace(self)

    @property
    def name(self) -> PolarsExprNameNamespace:
        return PolarsExprNameNamespace(self)

    @property
    def list(self) -> PolarsExprListNamespace:
        return PolarsExprListNamespace(self)

    @property
    def struct(self) -> PolarsExprStructNamespace:
        return PolarsExprStructNamespace(self)

    # Polars
    abs: Method[Self]
    all: Method[Self]
    any: Method[Self]
    alias: Method[Self]
    arg_max: Method[Self]
    arg_min: Method[Self]
    arg_true: Method[Self]
    clip: Method[Self]
    count: Method[Self]
    cum_max: Method[Self]
    cum_min: Method[Self]
    cum_prod: Method[Self]
    cum_sum: Method[Self]
    diff: Method[Self]
    drop_nulls: Method[Self]
    exp: Method[Self]
    fill_null: Method[Self]
    fill_nan: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    is_between: Method[Self]
    is_duplicated: Method[Self]
    is_finite: Method[Self]
    is_first_distinct: Method[Self]
    is_in: Method[Self]
    is_last_distinct: Method[Self]
    is_null: Method[Self]
    is_unique: Method[Self]
    kurtosis: Method[Self]
    len: Method[Self]
    log: Method[Self]
    max: Method[Self]
    mean: Method[Self]
    median: Method[Self]
    min: Method[Self]
    n_unique: Method[Self]
    null_count: Method[Self]
    quantile: Method[Self]
    rank: Method[Self]
    round: Method[Self]
    sample: Method[Self]
    shift: Method[Self]
    skew: Method[Self]
    sqrt: Method[Self]
    std: Method[Self]
    sum: Method[Self]
    sort: Method[Self]
    tail: Method[Self]
    unique: Method[Self]
    var: Method[Self]
    __rfloordiv__: Method[Self]
    __rsub__: Method[Self]
    __rmod__: Method[Self]
    __rpow__: Method[Self]
    __rtruediv__: Method[Self]
class PolarsExprNamespace(PolarsAnyNamespace[PolarsExpr, pl.Expr]):
    """Base for the accessor namespaces attached to a `PolarsExpr`."""

    def __init__(self, expr: PolarsExpr) -> None:
        """Remember the compliant expression this namespace belongs to."""
        self._expr = expr

    @property
    def compliant(self) -> PolarsExpr:
        """The owning compliant expression."""
        return self._expr

    @property
    def native(self) -> pl.Expr:
        """The native expression wrapped by the owning compliant expression."""
        owner = self._expr
        return owner.native
class PolarsExprDateTimeNamespace(
    PolarsExprNamespace, PolarsDateTimeNamespace[PolarsExpr, pl.Expr]
):
    """`dt` accessor for `PolarsExpr`; defines nothing beyond its two bases."""
class PolarsExprStringNamespace(
    PolarsExprNamespace, PolarsStringNamespace[PolarsExpr, pl.Expr]
):
    """`str` accessor for `PolarsExpr`."""

    def zfill(self, width: int) -> PolarsExpr:
        """Zero-pad strings to `width`.

        For polars<=1.30.0, values that start with `+` and are shorter than
        `width` are rebuilt by hand: the leading character is dropped, the
        remainder is zero-filled to `width - 1` and then padded back to
        `width` with `+`.

        Raises:
            NotImplementedError: For polars<0.20.5.
        """
        version = self.compliant._backend_version
        padded = self.native.str.zfill(width)
        if version < (0, 20, 5):  # pragma: no cover
            # Reason:
            # `TypeError: argument 'length': 'Expr' object cannot be interpreted as an integer`
            # in `native_expr.str.slice(1, length)`
            msg = "`zfill` is only available in 'polars>=0.20.5', found version '0.20.4'."
            raise NotImplementedError(msg)
        if version > (1, 30, 0):
            return self.compliant._with_native(padded)
        n_chars = self.native.str.len_chars()
        too_short = n_chars < width
        sign = "+"
        has_sign = self.native.str.starts_with(sign)
        rebuilt = (
            self.native.str.slice(1, n_chars)
            .str.zfill(width - 1)
            .str.pad_start(width, sign)
        )
        patched = pl.when(has_sign & too_short).then(rebuilt).otherwise(padded)
        return self.compliant._with_native(patched)
class PolarsExprCatNamespace(
    PolarsExprNamespace, PolarsCatNamespace[PolarsExpr, pl.Expr]
):
    """`cat` accessor for `PolarsExpr`; defines nothing beyond its two bases."""
class PolarsExprNameNamespace(PolarsExprNamespace):
    """`name` accessor for `PolarsExpr`.

    Only declares the accessor name and the method names available on it;
    no method is implemented in this class body.
    """

    _accessor = "name"

    keep: Method[PolarsExpr]
    map: Method[PolarsExpr]
    prefix: Method[PolarsExpr]
    suffix: Method[PolarsExpr]
    to_lowercase: Method[PolarsExpr]
    to_uppercase: Method[PolarsExpr]
class PolarsExprListNamespace(
    PolarsExprNamespace, PolarsListNamespace[PolarsExpr, pl.Expr]
):
    """`list` accessor for `PolarsExpr`."""

    def len(self) -> PolarsExpr:
        """Length of each list.

        For polars<1.16 the length is only computed where the input is not
        null and the result is cast to `UInt32`; for polars<1.17 only the
        cast is applied.
        """
        version = self.compliant._backend_version
        expr = self.native
        lengths = expr.list.len()
        if version < (1, 16):  # pragma: no cover
            not_null = ~expr.is_null()
            lengths = pl.when(not_null).then(lengths).cast(pl.UInt32())
        elif version < (1, 17):  # pragma: no cover
            lengths = lengths.cast(pl.UInt32())
        return self.compliant._with_native(lengths)

    def contains(self, item: Any) -> PolarsExpr:
        """Whether each list contains `item`.

        For polars<1.28 the check is only evaluated where the input is not null.
        """
        if self.compliant._backend_version >= (1, 28):
            return self.compliant._with_native(self.native.list.contains(item))
        guarded: pl.Expr = pl.when(self.native.is_not_null()).then(
            self.native.list.contains(item)
        )
        return self.compliant._with_native(guarded)
class PolarsExprStructNamespace(
    PolarsExprNamespace, PolarsStructNamespace[PolarsExpr, pl.Expr]
):
    """`struct` accessor for `PolarsExpr`; defines nothing beyond its two bases."""

View File

@ -0,0 +1,76 @@
from __future__ import annotations
from typing import TYPE_CHECKING, cast
from narwhals._utils import is_sequence_of
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from polars.dataframe.group_by import GroupBy as NativeGroupBy
from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy
from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
from narwhals._polars.expr import PolarsExpr
class PolarsGroupBy:
    """Eager group-by: pairs a `PolarsDataFrame` with its native grouped object."""

    _compliant_frame: PolarsDataFrame
    _grouped: NativeGroupBy

    @property
    def compliant(self) -> PolarsDataFrame:
        """The frame being grouped (after the optional null-key drop)."""
        return self._compliant_frame

    def __init__(
        self,
        df: PolarsDataFrame,
        keys: Sequence[PolarsExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Group `df` by `keys`.

        Arguments:
            df: Frame to group.
            keys: Either column names or compliant expressions.
            drop_null_keys: When true, `df.drop_nulls(keys)` is applied before
                grouping.
        """
        self._keys = list(keys)
        frame = df.drop_nulls(keys) if drop_null_keys else df
        self._compliant_frame = frame
        if is_sequence_of(keys, str):
            # Column names are handed to polars as they are.
            self._grouped = frame.native.group_by(keys)
        else:
            # Compliant expressions are unwrapped to native ones first.
            self._grouped = frame.native.group_by(key.native for key in keys)

    def agg(self, *aggs: PolarsExpr) -> PolarsDataFrame:
        """Aggregate every group with the given expressions."""
        native_aggs = (agg.native for agg in aggs)
        return self.compliant._with_native(self._grouped.agg(native_aggs))

    def __iter__(self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]:
        """Yield `(key_tuple, group_frame)` for every group."""
        for key, group in self._grouped:
            wrapped = self.compliant._with_native(group)
            yield tuple(cast("str", key)), wrapped
class PolarsLazyGroupBy:
    """Lazy group-by: pairs a `PolarsLazyFrame` with its native grouped object."""

    _compliant_frame: PolarsLazyFrame
    _grouped: NativeLazyGroupBy

    @property
    def compliant(self) -> PolarsLazyFrame:
        """The lazy frame being grouped (after the optional null-key drop)."""
        return self._compliant_frame

    def __init__(
        self,
        df: PolarsLazyFrame,
        keys: Sequence[PolarsExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Group `df` by `keys`.

        Arguments:
            df: Lazy frame to group.
            keys: Either column names or compliant expressions.
            drop_null_keys: When true, `df.drop_nulls(keys)` is applied before
                grouping.
        """
        self._keys = list(keys)
        frame = df.drop_nulls(keys) if drop_null_keys else df
        self._compliant_frame = frame
        if is_sequence_of(keys, str):
            # Column names are handed to polars as they are.
            self._grouped = frame.native.group_by(keys)
        else:
            # Compliant expressions are unwrapped to native ones first.
            self._grouped = frame.native.group_by(key.native for key in keys)

    def agg(self, *aggs: PolarsExpr) -> PolarsLazyFrame:
        """Aggregate every group with the given expressions."""
        native_aggs = (agg.native for agg in aggs)
        return self.compliant._with_native(self._grouped.agg(native_aggs))

View File

@ -0,0 +1,281 @@
from __future__ import annotations
import operator
from typing import TYPE_CHECKING, Any, Literal, cast, overload
import polars as pl
from narwhals._expression_parsing import is_expr, is_series
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype
from narwhals._utils import Implementation, requires, zip_strict
from narwhals.dependencies import is_numpy_array_2d
from narwhals.dtypes import DType
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from datetime import timezone
from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen
from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame
from narwhals._polars.typing import FrameT
from narwhals._utils import Version, _LimitedContext
from narwhals.expr import Expr
from narwhals.series import Series
from narwhals.typing import (
Into1DArray,
IntoDType,
IntoSchema,
NonNestedLiteral,
TimeUnit,
_1DArray,
_2DArray,
)
class PolarsNamespace:
    """Compliant-level namespace for the polars backend.

    Attributes that are not defined on the class are resolved by
    `__getattr__`, which forwards the call to the function of the same name on
    the `polars` module and wraps the result in a `PolarsExpr`. The
    `Method[...]` annotations below declare some of those forwarded functions
    for type checkers.
    """

    all: Method[PolarsExpr]
    coalesce: Method[PolarsExpr]
    col: Method[PolarsExpr]
    exclude: Method[PolarsExpr]
    sum_horizontal: Method[PolarsExpr]
    min_horizontal: Method[PolarsExpr]
    max_horizontal: Method[PolarsExpr]
    when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]]

    _implementation: Implementation = Implementation.POLARS
    _version: Version

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Backend version tuple reported by `self._implementation`."""
        return self._implementation._backend_version()

    def __init__(self, *, version: Version) -> None:
        """Remember the narwhals API version used for every object created here."""
        self._version = version

    def __getattr__(self, attr: str) -> Any:
        """Return a callable forwarding to `pl.<attr>`.

        The lookup on the `polars` module happens when the returned callable
        is invoked, not when the attribute is accessed.
        """

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            native_function = getattr(pl, attr)
            native_result = native_function(*pos, **kwds)
            return self._expr(native_result, version=self._version)

        return func

    @property
    def _dataframe(self) -> type[PolarsDataFrame]:
        """The eager frame class (imported on first use)."""
        from narwhals._polars.dataframe import PolarsDataFrame

        return PolarsDataFrame

    @property
    def _lazyframe(self) -> type[PolarsLazyFrame]:
        """The lazy frame class (imported on first use)."""
        from narwhals._polars.dataframe import PolarsLazyFrame

        return PolarsLazyFrame

    @property
    def _expr(self) -> type[PolarsExpr]:
        """The expression class."""
        return PolarsExpr

    @property
    def _series(self) -> type[PolarsSeries]:
        """The series class."""
        return PolarsSeries

    def parse_into_expr(
        self,
        data: Expr | NonNestedLiteral | Series[pl.Series] | _1DArray,
        /,
        *,
        str_as_lit: bool,
    ) -> PolarsExpr | None:
        """Convert `data` into a `PolarsExpr`, or pass `None` through.

        Arguments:
            data: A narwhals expression, a narwhals series, a literal or `None`.
            str_as_lit: When false, a `str` is treated as a column name;
                when true it becomes a literal.
        """
        if data is None:
            # NOTE: To avoid `pl.lit(None)` failing this `None` check
            # https://github.com/pola-rs/polars/blob/58dd8e5770f16a9bef9009a1c05f00e15a5263c7/py-polars/polars/expr/expr.py#L2870-L2872
            return data
        if is_expr(data):
            expr = data._to_compliant_expr(self)
            assert isinstance(expr, self._expr)  # noqa: S101
            return expr
        if isinstance(data, str) and not str_as_lit:
            return self.col(data)
        # Everything else becomes a literal; narwhals series are unwrapped first.
        literal = data.to_native() if is_series(data) else data
        return self.lit(literal, None)

    @overload
    def from_native(self, data: pl.DataFrame, /) -> PolarsDataFrame: ...
    @overload
    def from_native(self, data: pl.LazyFrame, /) -> PolarsLazyFrame: ...
    @overload
    def from_native(self, data: pl.Series, /) -> PolarsSeries: ...
    def from_native(
        self, data: pl.DataFrame | pl.LazyFrame | pl.Series | Any, /
    ) -> PolarsDataFrame | PolarsLazyFrame | PolarsSeries:
        """Wrap a native polars object in the matching compliant class.

        Raises:
            TypeError: If `data` is not recognised by the frame, series or
                lazy-frame class.
        """
        frame_cls = self._dataframe
        if frame_cls._is_native(data):
            return frame_cls.from_native(data, context=self)
        series_cls = self._series
        if series_cls._is_native(data):
            return series_cls.from_native(data, context=self)
        lazy_cls = self._lazyframe
        if lazy_cls._is_native(data):
            return lazy_cls.from_native(data, context=self)
        msg = f"Unsupported type: {type(data).__name__!r}"  # pragma: no cover
        raise TypeError(msg)  # pragma: no cover

    @overload
    def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> PolarsSeries: ...
    @overload
    def from_numpy(
        self, data: _2DArray, /, schema: IntoSchema | Sequence[str] | None
    ) -> PolarsDataFrame: ...
    def from_numpy(
        self,
        data: Into1DArray | _2DArray,
        /,
        schema: IntoSchema | Sequence[str] | None = None,
    ) -> PolarsDataFrame | PolarsSeries:
        """Build a frame from a 2D array, or a series from anything else.

        `schema` is only used on the 2D (frame) path.
        """
        if not is_numpy_array_2d(data):
            return self._series.from_numpy(data, context=self)  # pragma: no cover
        return self._dataframe.from_numpy(data, schema=schema, context=self)

    @requires.backend_version(
        (1, 0, 0), "Please use `col` for columns selection instead."
    )
    def nth(self, *indices: int) -> PolarsExpr:
        """Select columns by position via `pl.nth`."""
        selected = pl.nth(*indices)
        return self._expr(selected, version=self._version)

    def len(self) -> PolarsExpr:
        """Return a row-count expression.

        polars < 0.20.5 uses `pl.count()` aliased to `"len"`; newer versions
        use `pl.len()`.
        """
        legacy = self._backend_version < (0, 20, 5)
        native = pl.count().alias("len") if legacy else pl.len()
        return self._expr(native, version=self._version)

    def all_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr:
        """Row-wise logical AND; with `ignore_nulls`, nulls are first filled with `True`."""
        if ignore_nulls:
            natives = [expr.fill_null(True).native for expr in exprs]
        else:
            natives = [expr.native for expr in exprs]
        return self._expr(pl.all_horizontal(*natives), version=self._version)

    def any_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr:
        """Row-wise logical OR; with `ignore_nulls`, nulls are first filled with `False`."""
        if ignore_nulls:
            natives = [expr.fill_null(False).native for expr in exprs]
        else:
            natives = [expr.native for expr in exprs]
        return self._expr(pl.any_horizontal(*natives), version=self._version)

    def concat(
        self,
        items: Iterable[FrameT],
        *,
        how: Literal["vertical", "horizontal", "diagonal"],
    ) -> PolarsDataFrame | PolarsLazyFrame:
        """Concatenate frames with `pl.concat` and wrap the result.

        The wrapper class is chosen from the type of the native result.
        """
        natives = (item.native for item in items)
        result = pl.concat(natives, how=how)
        if isinstance(result, pl.DataFrame):
            return self._dataframe(result, version=self._version)
        return self._lazyframe.from_native(result, context=self)

    def lit(self, value: Any, dtype: IntoDType | None) -> PolarsExpr:
        """Create a literal expression, converting `dtype` to a polars dtype if given."""
        if dtype is None:
            native = pl.lit(value)
        else:
            native_dtype = narwhals_to_native_dtype(dtype, self._version)
            native = pl.lit(value, dtype=native_dtype)
        return self._expr(native, version=self._version)

    def mean_horizontal(self, *exprs: PolarsExpr) -> PolarsExpr:
        """Row-wise mean.

        polars < 0.20.8 computes it as the horizontal sum of the values divided
        by the horizontal sum of `1 - is_null`; newer versions call
        `pl.mean_horizontal`.
        """
        if self._backend_version >= (0, 20, 8):
            native = pl.mean_horizontal(e._native_expr for e in exprs)
            return self._expr(native, version=self._version)
        total = pl.sum_horizontal(e._native_expr for e in exprs)
        non_null_count = pl.sum_horizontal(
            1 - e.is_null()._native_expr for e in exprs
        )
        return self._expr(total / non_null_count, version=self._version)

    def concat_str(
        self, *exprs: PolarsExpr, separator: str, ignore_nulls: bool
    ) -> PolarsExpr:
        """Concatenate expressions as strings, joined by `separator`.

        polars >= 0.20.6 delegates to `pl.concat_str`. Older versions build
        the result by hand:

        - `ignore_nulls=False`: every input is cast to string and joined; the
          result is kept only for rows where no input is null.
        - `ignore_nulls=True`: null inputs become `""`, and the separator that
          follows a null input is replaced by `""` as well.
        """
        pl_exprs: list[pl.Expr] = [expr._native_expr for expr in exprs]
        if self._backend_version >= (0, 20, 6):
            native = pl.concat_str(
                pl_exprs, separator=separator, ignore_nulls=ignore_nulls
            )
            return self._expr(native, version=self._version)

        null_mask = [expr.is_null() for expr in pl_exprs]
        sep = pl.lit(separator)
        if ignore_nulls:
            init_value, *values = [
                pl.when(nm).then(pl.lit("")).otherwise(expr.cast(pl.String()))
                for expr, nm in zip_strict(pl_exprs, null_mask)
            ]
            # One separator per input except the last.
            separators = [
                pl.when(~nm).then(sep).otherwise(pl.lit("")) for nm in null_mask[:-1]
            ]
            result = pl.fold(
                acc=init_value,
                function=operator.add,
                exprs=[s + v for s, v in zip_strict(separators, values)],
            )
        else:
            null_mask_result = pl.any_horizontal(*null_mask)
            output_expr = pl.reduce(
                lambda x, y: x.cast(pl.String()) + sep + y.cast(pl.String()),  # type: ignore[arg-type,return-value]
                pl_exprs,
            )
            result = pl.when(~null_mask_result).then(output_expr)  # type: ignore[assignment]
        return self._expr(result, version=self._version)

    # NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`)
    # 1. Others have lots of private stuff for code reuse
    #    i. None of that is useful here
    # 2. We don't have a `PolarsSelector` abstraction, and just use `PolarsExpr`
    @property
    def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]:
        """A `PolarsSelectorNamespace` bound to this namespace, cast for typing."""
        selector_namespace = PolarsSelectorNamespace(self)
        return cast(
            "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]",
            selector_namespace,
        )
class PolarsSelectorNamespace:
    """Selector factory for polars.

    Every method builds a `pl.selectors` selector and returns it wrapped in a
    `PolarsExpr`.
    """

    _implementation = Implementation.POLARS

    def __init__(self, context: _LimitedContext, /) -> None:
        """Take the narwhals API version from `context`."""
        self._version = context._version

    def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr:
        """Select columns whose dtype is one of `dtypes`.

        A dtype given as a `DType` subclass (rather than an instance) is
        converted and then replaced by the class of the converted value.
        """
        native_dtypes = []
        for dtype in dtypes:
            native = narwhals_to_native_dtype(dtype, self._version)
            if isinstance(dtype, type) and issubclass(dtype, DType):
                native = native.__class__
            native_dtypes.append(native)
        selector = pl.selectors.by_dtype(native_dtypes)
        return PolarsExpr(selector, version=self._version)

    def matches(self, pattern: str) -> PolarsExpr:
        """Select columns whose name matches `pattern`."""
        selector = pl.selectors.matches(pattern=pattern)
        return PolarsExpr(selector, version=self._version)

    def numeric(self) -> PolarsExpr:
        """Select numeric columns."""
        selector = pl.selectors.numeric()
        return PolarsExpr(selector, version=self._version)

    def boolean(self) -> PolarsExpr:
        """Select boolean columns."""
        selector = pl.selectors.boolean()
        return PolarsExpr(selector, version=self._version)

    def string(self) -> PolarsExpr:
        """Select string columns."""
        selector = pl.selectors.string()
        return PolarsExpr(selector, version=self._version)

    def categorical(self) -> PolarsExpr:
        """Select categorical columns."""
        selector = pl.selectors.categorical()
        return PolarsExpr(selector, version=self._version)

    def all(self) -> PolarsExpr:
        """Select every column."""
        selector = pl.selectors.all()
        return PolarsExpr(selector, version=self._version)

    def datetime(
        self,
        time_unit: TimeUnit | Iterable[TimeUnit] | None,
        time_zone: str | timezone | Iterable[str | timezone | None] | None,
    ) -> PolarsExpr:
        """Select datetime columns, forwarding `time_unit` and `time_zone` to polars."""
        selector = pl.selectors.datetime(time_unit=time_unit, time_zone=time_zone)  # type: ignore[arg-type]
        return PolarsExpr(selector, version=self._version)

View File

@ -0,0 +1,795 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar, cast, overload
import polars as pl
from narwhals._polars.utils import (
BACKEND_VERSION,
SERIES_ACCEPTS_PD_INDEX,
SERIES_RESPECTS_DTYPE,
PolarsAnyNamespace,
PolarsCatNamespace,
PolarsDateTimeNamespace,
PolarsListNamespace,
PolarsStringNamespace,
PolarsStructNamespace,
catch_polars_exception,
extract_args_kwargs,
extract_native,
narwhals_to_native_dtype,
native_to_narwhals_dtype,
)
from narwhals._utils import Implementation, requires
from narwhals.dependencies import is_numpy_array_1d, is_pandas_index
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from types import ModuleType
from typing import Literal, TypeVar
import pandas as pd
import pyarrow as pa
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._polars.dataframe import Method, PolarsDataFrame
from narwhals._polars.namespace import PolarsNamespace
from narwhals._utils import Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.series import Series
from narwhals.typing import (
Into1DArray,
IntoDType,
ModeKeepStrategy,
MultiIndexSelector,
NonNestedLiteral,
NumericLiteral,
_1DArray,
)
T = TypeVar("T")
IncludeBreakpoint: TypeAlias = Literal[False, True]
Incomplete: TypeAlias = Any
# Series methods where PolarsSeries just defers to Polars.Series directly.
# `PolarsSeries.__getattr__` only forwards names listed here. Names that
# `PolarsSeries` also defines explicitly (`is_close`, `mode`) never reach
# `__getattr__`, because normal attribute lookup finds the explicit method first.
INHERITED_METHODS = frozenset(
    (
        # Operators (including reflected variants) and iteration.
        "__add__", "__and__", "__floordiv__", "__invert__", "__iter__",
        "__mod__", "__mul__", "__or__", "__pow__",
        "__radd__", "__rand__", "__rfloordiv__", "__rmod__", "__rmul__",
        "__ror__", "__rsub__", "__rtruediv__",
        "__sub__", "__truediv__",
        # Regular methods, alphabetical apart from `fill_null` / `fill_nan`.
        "abs", "all", "any", "arg_max", "arg_min", "arg_true",
        "clip", "count", "cum_max", "cum_min", "cum_prod", "cum_sum",
        "diff", "drop_nulls",
        "exp",
        "fill_null", "fill_nan", "filter",
        "gather_every",
        "head",
        "is_between", "is_close", "is_duplicated", "is_empty", "is_finite",
        "is_first_distinct", "is_in", "is_last_distinct", "is_null",
        "is_sorted", "is_unique", "item",
        "kurtosis",
        "len", "log",
        "max", "mean", "min", "mode",
        "n_unique", "null_count",
        "quantile",
        "rank", "round",
        "sample", "shift", "skew", "sqrt", "std", "sum",
        "tail", "to_arrow", "to_frame", "to_list", "to_pandas",
        "unique",
        "var",
        "zip_with",
    )
)
class PolarsSeries:
    """Compliant-level wrapper around a native `pl.Series`.

    Methods named in `INHERITED_METHODS` are not implemented here:
    `__getattr__` forwards them to the native series and re-wraps the result.
    They are declared as `Method[...]` annotations near the end of the class
    body so that type checkers know about them.
    """

    _implementation: Implementation = Implementation.POLARS
    _native_series: pl.Series
    _version: Version

    # Column names of the empty histogram frame, keyed by `include_breakpoint`.
    _HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
        True: ["breakpoint", "count"],
        False: ["count"],
    }

    def __init__(self, series: pl.Series, *, version: Version) -> None:
        """Store the native series and the narwhals API version."""
        self._native_series = series
        self._version = version

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Backend version tuple reported by `self._implementation`."""
        return self._implementation._backend_version()

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsSeries"

    def __narwhals_namespace__(self) -> PolarsNamespace:
        """Return a `PolarsNamespace` carrying this series' version."""
        # Imported lazily: `narwhals._polars.namespace` imports this module at
        # import time.
        from narwhals._polars.namespace import PolarsNamespace

        return PolarsNamespace(version=self._version)

    def __narwhals_series__(self) -> Self:
        return self

    def __native_namespace__(self) -> ModuleType:
        """Return the native module for this implementation.

        Raises:
            AssertionError: If `_implementation` is not `Implementation.POLARS`.
        """
        if self._implementation is Implementation.POLARS:
            return self._implementation.to_native_namespace()
        msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def _with_version(self, version: Version) -> Self:
        """Return a new wrapper around the same native series with another version."""
        return self.__class__(self.native, version=version)

    @classmethod
    def from_iterable(
        cls,
        data: Iterable[Any],
        *,
        context: _LimitedContext,
        name: str = "",
        dtype: IntoDType | None = None,
    ) -> Self:
        """Build a series from an iterable of values.

        When `SERIES_RESPECTS_DTYPE` is true the dtype is passed straight to
        the `pl.Series` constructor. Otherwise the series is built without a
        dtype and cast afterwards, and a pandas index is first converted with
        `to_series()` when `SERIES_ACCEPTS_PD_INDEX` is false.
        """
        version = context._version
        dtype_pl = narwhals_to_native_dtype(dtype, version) if dtype else None
        values: Incomplete = data
        if SERIES_RESPECTS_DTYPE:
            native = pl.Series(name, values, dtype=dtype_pl)
        else:  # pragma: no cover
            if (not SERIES_ACCEPTS_PD_INDEX) and is_pandas_index(values):
                values = values.to_series()
            native = pl.Series(name, values)
            if dtype_pl:
                native = native.cast(dtype_pl)
        return cls.from_native(native, context=context)

    @staticmethod
    def _is_native(obj: pl.Series | Any) -> TypeIs[pl.Series]:
        """Return whether `obj` is a `pl.Series`."""
        return isinstance(obj, pl.Series)

    @classmethod
    def from_native(cls, data: pl.Series, /, *, context: _LimitedContext) -> Self:
        """Wrap a native series, taking the version from `context`."""
        return cls(data, version=context._version)

    @classmethod
    def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self:
        """Build a series from a 1D numpy array.

        Anything that is not a 1D numpy array is wrapped in a one-element list.
        """
        native = pl.Series(data if is_numpy_array_1d(data) else [data])
        return cls.from_native(native, context=context)

    def to_narwhals(self) -> Series[pl.Series]:
        """Wrap this compliant series in the public narwhals `Series`."""
        return self._version.series(self, level="full")

    def _with_native(self, series: pl.Series) -> Self:
        """Wrap another native series, keeping this wrapper's version."""
        return self.__class__(series, version=self._version)

    @overload
    def _from_native_object(self, series: pl.Series) -> Self: ...
    @overload
    def _from_native_object(self, series: pl.DataFrame) -> PolarsDataFrame: ...
    @overload
    def _from_native_object(self, series: T) -> T: ...
    def _from_native_object(
        self, series: pl.Series | pl.DataFrame | T
    ) -> Self | PolarsDataFrame | T:
        """Wrap a native result according to its type.

        A `pl.Series` becomes a `PolarsSeries`, a `pl.DataFrame` becomes a
        `PolarsDataFrame`, and anything else is returned unchanged.
        """
        if self._is_native(series):
            return self._with_native(series)
        if isinstance(series, pl.DataFrame):
            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame.from_native(series, context=self)
        # scalar
        return series

    def __getattr__(self, attr: str) -> Any:
        """Forward methods listed in `INHERITED_METHODS` to the native series.

        Raises:
            AttributeError: If `attr` is not in `INHERITED_METHODS`.
        """
        if attr not in INHERITED_METHODS:
            msg = f"{self.__class__.__name__} has no attribute '{attr}'."
            raise AttributeError(msg)

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))

        return func

    def __len__(self) -> int:
        return len(self.native)

    @property
    def name(self) -> str:
        """Name of the native series."""
        return self.native.name

    @property
    def dtype(self) -> DType:
        """The native dtype converted to a narwhals dtype."""
        return native_to_narwhals_dtype(self.native.dtype, self._version)

    @property
    def native(self) -> pl.Series:
        """The wrapped `pl.Series`."""
        return self._native_series

    def alias(self, name: str) -> Self:
        """Return the series under a new name."""
        return self._from_native_object(self.native.alias(name))

    def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self:
        """Index the native series; a `PolarsSeries` selector is unwrapped first."""
        if isinstance(item, PolarsSeries):
            return self._from_native_object(self.native.__getitem__(item.native))
        return self._from_native_object(self.native.__getitem__(item))

    def cast(self, dtype: IntoDType) -> Self:
        """Cast to the polars equivalent of `dtype`."""
        dtype_pl = narwhals_to_native_dtype(dtype, self._version)
        return self._with_native(self.native.cast(dtype_pl))

    @requires.backend_version((1,))
    def replace_strict(
        self,
        old: Sequence[Any] | Mapping[Any, Any],
        new: Sequence[Any],
        *,
        return_dtype: IntoDType | None,
    ) -> Self:
        """Replace values via the native `replace_strict`.

        `return_dtype`, when given, is converted to a polars dtype first.
        """
        ser = self.native
        dtype = (
            narwhals_to_native_dtype(return_dtype, self._version)
            if return_dtype
            else None
        )
        return self._with_native(ser.replace_strict(old, new, return_dtype=dtype))

    def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray:
        """Convert to a numpy array by delegating to `__array__`."""
        return self.__array__(dtype, copy=copy)

    def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray:
        """Convert to a numpy array; `copy` is only forwarded on polars >= 0.20.29."""
        if self._backend_version < (0, 20, 29):
            return self.native.__array__(dtype=dtype)
        return self.native.__array__(dtype=dtype, copy=copy)

    def __eq__(self, other: object) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__eq__(extract_native(other)))

    def __ne__(self, other: object) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__ne__(extract_native(other)))

    # NOTE: These need to be anything that can't match `PolarsExpr`, due to overload order
    def __ge__(self, other: Self) -> Self:
        return self._with_native(self.native.__ge__(extract_native(other)))

    def __gt__(self, other: Self) -> Self:
        return self._with_native(self.native.__gt__(extract_native(other)))

    def __le__(self, other: Self) -> Self:
        return self._with_native(self.native.__le__(extract_native(other)))

    def __lt__(self, other: Self) -> Self:
        return self._with_native(self.native.__lt__(extract_native(other)))

    def __rpow__(self, other: PolarsSeries | Any) -> Self:
        result = self.native.__rpow__(extract_native(other))
        if self._backend_version < (1, 16, 1):
            # Explicitly set alias to work around https://github.com/pola-rs/polars/issues/20071
            result = result.alias(self.name)
        return self._with_native(result)

    def is_nan(self) -> Self:
        """Return a boolean series marking NaN values.

        Any exception raised by polars is translated by
        `catch_polars_exception`. On polars < 1.18 the result is kept only for
        non-null inputs.
        """
        try:
            native_is_nan = self.native.is_nan()
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None
        if self._backend_version < (1, 18):  # pragma: no cover
            select = pl.when(self.native.is_not_null()).then(native_is_nan)
            return self._with_native(pl.select(select)[self.name])
        return self._with_native(native_is_nan)

    def median(self) -> Any:
        """Return the median.

        Raises:
            InvalidOperationError: If the dtype is not numeric.
        """
        from narwhals.exceptions import InvalidOperationError

        if not self.dtype.is_numeric():
            msg = "`median` operation not supported for non-numeric input type."
            raise InvalidOperationError(msg)
        return self.native.median()

    def to_dummies(self, *, separator: str, drop_first: bool) -> PolarsDataFrame:
        """One-hot encode the series; every output column is cast to `Int8`.

        polars < 0.20.15 is called without `drop_first`, so the first category
        column is removed here instead. The index removed is `int(has_nulls)`,
        i.e. 1 when the series contains nulls and 0 otherwise.
        """
        from narwhals._polars.dataframe import PolarsDataFrame

        if self._backend_version < (0, 20, 15):
            has_nulls = self.native.is_null().any()
            result = self.native.to_dummies(separator=separator)
            output_columns = result.columns
            if drop_first:
                _ = output_columns.pop(int(has_nulls))
            result = result.select(output_columns)
        else:
            result = self.native.to_dummies(separator=separator, drop_first=drop_first)
        result = result.with_columns(pl.all().cast(pl.Int8))
        return PolarsDataFrame.from_native(result, context=self)

    def _min_samples_kwargs(self, min_samples: int) -> dict[str, Any]:
        """Return `min_samples` under the keyword name used for this polars version.

        The keyword is `min_periods` on polars < 1.21.0 and `min_samples`
        otherwise.
        """
        key = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
        return {key: min_samples}

    def ewm_mean(
        self,
        *,
        com: float | None,
        span: float | None,
        half_life: float | None,
        alpha: float | None,
        adjust: bool,
        min_samples: int,
        ignore_nulls: bool,
    ) -> Self:
        """Exponentially weighted mean via the native `ewm_mean`.

        On polars < 1 the result is kept only for non-null inputs.
        """
        native_result = self.native.ewm_mean(
            com=com,
            span=span,
            half_life=half_life,
            alpha=alpha,
            adjust=adjust,
            ignore_nulls=ignore_nulls,
            **self._min_samples_kwargs(min_samples),
        )
        if self._backend_version < (1,):  # pragma: no cover
            return self._with_native(
                pl.select(
                    pl.when(~self.native.is_null()).then(native_result).otherwise(None)
                )[self.native.name]
            )
        return self._with_native(native_result)

    @requires.backend_version((1,))
    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling variance via the native `rolling_var`."""
        return self._with_native(
            self.native.rolling_var(
                window_size=window_size,
                center=center,
                ddof=ddof,
                **self._min_samples_kwargs(min_samples),
            )
        )

    @requires.backend_version((1,))
    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling standard deviation via the native `rolling_std`."""
        return self._with_native(
            self.native.rolling_std(
                window_size=window_size,
                center=center,
                ddof=ddof,
                **self._min_samples_kwargs(min_samples),
            )
        )

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        """Rolling sum via the native `rolling_sum`."""
        return self._with_native(
            self.native.rolling_sum(
                window_size=window_size,
                center=center,
                **self._min_samples_kwargs(min_samples),
            )
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        """Rolling mean via the native `rolling_mean`."""
        return self._with_native(
            self.native.rolling_mean(
                window_size=window_size,
                center=center,
                **self._min_samples_kwargs(min_samples),
            )
        )

    def sort(self, *, descending: bool, nulls_last: bool) -> Self:
        """Sort the series.

        polars < 0.20.6 is called without `nulls_last`; when it is requested,
        the sorted non-null values are concatenated with the nulls instead.
        """
        if self._backend_version < (0, 20, 6):
            result = self.native.sort(descending=descending)
            if nulls_last:
                is_null = result.is_null()
                result = pl.concat([result.filter(~is_null), result.filter(is_null)])
        else:
            result = self.native.sort(descending=descending, nulls_last=nulls_last)
        return self._with_native(result)

    def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
        """Set `values` at `indices` on a clone of the native series."""
        s = self.native.clone().scatter(indices, extract_native(values))
        return self._with_native(s)

    def value_counts(
        self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
    ) -> PolarsDataFrame:
        """Count unique values.

        polars < 1.0.0 is called without `name` and `normalize`; the count
        column is renamed, and divided by the total when normalizing, here
        instead. The default column name is `"proportion"` when normalizing
        and `"count"` otherwise.
        """
        from narwhals._polars.dataframe import PolarsDataFrame

        if self._backend_version < (1, 0, 0):
            value_name_ = name or ("proportion" if normalize else "count")
            result = self.native.value_counts(sort=sort, parallel=parallel).select(
                **{
                    (self.native.name): pl.col(self.native.name),
                    value_name_: pl.col("count") / pl.sum("count")
                    if normalize
                    else pl.col("count"),
                }
            )
        else:
            result = self.native.value_counts(
                sort=sort, parallel=parallel, name=name, normalize=normalize
            )
        return PolarsDataFrame.from_native(result, context=self)

    def cum_count(self, *, reverse: bool) -> Self:
        """Cumulative count via the native `cum_count`."""
        return self._with_native(self.native.cum_count(reverse=reverse))

    def __contains__(self, other: Any) -> bool:
        """Membership test; polars exceptions are translated by `catch_polars_exception`."""
        try:
            return self.native.__contains__(other)
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None

    def is_close(
        self,
        other: Self | NumericLiteral,
        *,
        abs_tol: float,
        rel_tol: float,
        nans_equal: bool,
    ) -> PolarsSeries:
        """Element-wise closeness check against `other`.

        polars < 1.32.0 is routed through the expression-level `is_close` on a
        one-column frame; newer versions call the native series method.
        """
        if self._backend_version < (1, 32, 0):
            name = self.name
            ns = self.__narwhals_namespace__()
            other_expr = (
                ns.lit(other.native, None) if isinstance(other, PolarsSeries) else other
            )
            expr = ns.col(name).is_close(
                other_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
            )
            return self.to_frame().select(expr).get_column(name)
        other_series = other.native if isinstance(other, PolarsSeries) else other
        result = self.native.is_close(
            other_series, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
        )
        return self._with_native(result)

    def mode(self, *, keep: ModeKeepStrategy) -> Self:
        """Return the mode(s); only the first one when `keep == "any"`."""
        result = self.native.mode()
        return self._with_native(result.head(1) if keep == "any" else result)

    def hist_from_bins(
        self, bins: list[float], *, include_breakpoint: bool
    ) -> PolarsDataFrame:
        """Histogram for explicit bin edges.

        - Fewer than two edges: an empty frame with the `_HIST_EMPTY_SCHEMA`
          columns.
        - Empty series: zero counts, one row per bin, with breakpoints
          `bins[1:]` when requested.
        - Otherwise: computed by `_hist_from_data`.
        """
        if len(bins) <= 1:
            native = pl.DataFrame(schema=self._HIST_EMPTY_SCHEMA[include_breakpoint])
        elif self.native.is_empty():
            if include_breakpoint:
                native = (
                    pl.Series(bins[1:])
                    .to_frame("breakpoint")
                    .with_columns(count=pl.lit(0, pl.Int64))
                )
            else:
                native = pl.select(count=pl.zeros(len(bins) - 1, pl.Int64))
        else:
            return self._hist_from_data(
                bins=bins, bin_count=None, include_breakpoint=include_breakpoint
            )
        return self.__narwhals_namespace__()._dataframe.from_native(native, context=self)

    def hist_from_bin_count(
        self, bin_count: int, *, include_breakpoint: bool
    ) -> PolarsDataFrame:
        """Histogram for a requested number of bins.

        - `bin_count == 0`: an empty frame with the `_HIST_EMPTY_SCHEMA`
          columns.
        - Empty series: zero counts, with breakpoints
          `1/bin_count, 2/bin_count, ..., 1` when requested.
        - Otherwise: computed by `_hist_from_data`. On polars < 1.15 the bin
          count is first converted to explicit edges by `_bins_from_bin_count`.
        """
        if bin_count == 0:
            native = pl.DataFrame(schema=self._HIST_EMPTY_SCHEMA[include_breakpoint])
        elif self.native.is_empty():
            if include_breakpoint:
                native = pl.select(
                    breakpoint=pl.int_range(1, bin_count + 1) / bin_count,
                    count=pl.lit(0, pl.Int64),
                )
            else:
                native = pl.select(count=pl.zeros(bin_count, pl.Int64))
        else:
            count: int | None
            if BACKEND_VERSION < (1, 15):  # pragma: no cover
                count = None
                bins = self._bins_from_bin_count(bin_count=bin_count)
            else:
                count = bin_count
                bins = None
            return self._hist_from_data(
                bins=bins,  # type: ignore[arg-type]
                bin_count=count,
                include_breakpoint=include_breakpoint,
            )
        return self.__narwhals_namespace__()._dataframe.from_native(native, context=self)

    def _bins_from_bin_count(self, bin_count: int) -> pl.Series:  # pragma: no cover
        """Prepare bins based on backend version compatibility.

        polars <1.15 does not adjust the bins when they have equivalent min/max
        polars <1.5 with bin_count=...
        returns bins that range from -inf to +inf and has bin_count + 1 bins.
        for compat: convert `bin_count=` call to `bins=`

        The edges are `bin_count + 1` evenly spaced values from the series
        minimum to its maximum. When minimum and maximum are equal, the range
        is widened by 0.5 on each side.
        """
        lower = cast("float", self.native.min())
        upper = cast("float", self.native.max())
        if lower == upper:
            lower -= 0.5
            upper += 0.5
        width = (upper - lower) / bin_count
        return pl.int_range(0, bin_count + 1, eager=True) * width + lower

    def _hist_from_data(
        self, bins: list[float] | None, bin_count: int | None, *, include_breakpoint: bool
    ) -> PolarsDataFrame:
        """Calculate histogram from non-empty data and post-process the results based on the backend version."""
        from narwhals._polars.dataframe import PolarsDataFrame

        series = self.native
        # Polars inconsistently handles NaN values when computing histograms
        # against predefined bins: https://github.com/pola-rs/polars/issues/21082
        if BACKEND_VERSION < (1, 15) or bins is not None:
            series = series.fill_nan(None)
        df = series.hist(
            bins,
            bin_count=bin_count,
            include_category=False,
            include_breakpoint=include_breakpoint,
        )
        # Apply post-processing corrections
        # Handle column naming
        if not include_breakpoint:
            col_name = df.columns[0]
            df = df.select(pl.col(col_name).alias("count"))
        elif BACKEND_VERSION < (1, 0):  # pragma: no cover
            df = df.rename({"break_point": "breakpoint"})
        if bins is not None:  # pragma: no cover
            # polars<1.6 implicitly adds -inf and inf to either end of bins
            if BACKEND_VERSION < (1, 6):
                r = pl.int_range(0, len(df))
                df = df.filter((r > 0) & (r < len(df) - 1))
            # polars<1.27 makes the lowest bin a left/right closed interval
            if BACKEND_VERSION < (1, 27):
                df = (
                    df.slice(0, 1)
                    .with_columns(pl.col("count") + ((pl.lit(series) == bins[0]).sum()))
                    .vstack(df.slice(1))
                )
        return PolarsDataFrame.from_native(df, context=self)

    def to_polars(self) -> pl.Series:
        """Return the wrapped `pl.Series`."""
        return self.native

    @property
    def dt(self) -> PolarsSeriesDateTimeNamespace:
        return PolarsSeriesDateTimeNamespace(self)

    @property
    def str(self) -> PolarsSeriesStringNamespace:
        return PolarsSeriesStringNamespace(self)

    @property
    def cat(self) -> PolarsSeriesCatNamespace:
        return PolarsSeriesCatNamespace(self)

    @property
    def struct(self) -> PolarsSeriesStructNamespace:
        return PolarsSeriesStructNamespace(self)

    # Declarations (for type checkers) of the methods forwarded by `__getattr__`.
    __add__: Method[Self]
    __and__: Method[Self]
    __floordiv__: Method[Self]
    __invert__: Method[Self]
    __iter__: Method[Iterator[Any]]
    __mod__: Method[Self]
    __mul__: Method[Self]
    __or__: Method[Self]
    __pow__: Method[Self]
    __radd__: Method[Self]
    __rand__: Method[Self]
    __rfloordiv__: Method[Self]
    __rmod__: Method[Self]
    __rmul__: Method[Self]
    __ror__: Method[Self]
    __rsub__: Method[Self]
    __rtruediv__: Method[Self]
    __sub__: Method[Self]
    __truediv__: Method[Self]
    abs: Method[Self]
    all: Method[bool]
    any: Method[bool]
    arg_max: Method[int]
    arg_min: Method[int]
    arg_true: Method[Self]
    clip: Method[Self]
    count: Method[int]
    cum_max: Method[Self]
    cum_min: Method[Self]
    cum_prod: Method[Self]
    cum_sum: Method[Self]
    diff: Method[Self]
    drop_nulls: Method[Self]
    exp: Method[Self]
    fill_null: Method[Self]
    fill_nan: Method[Self]
    filter: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    is_between: Method[Self]
    is_duplicated: Method[Self]
    is_empty: Method[bool]
    is_finite: Method[Self]
    is_first_distinct: Method[Self]
    is_in: Method[Self]
    is_last_distinct: Method[Self]
    is_null: Method[Self]
    is_sorted: Method[bool]
    is_unique: Method[Self]
    item: Method[Any]
    kurtosis: Method[float | None]
    len: Method[int]
    log: Method[Self]
    max: Method[Any]
    mean: Method[float]
    min: Method[Any]
    n_unique: Method[int]
    null_count: Method[int]
    quantile: Method[float]
    rank: Method[Self]
    round: Method[Self]
    sample: Method[Self]
    shift: Method[Self]
    skew: Method[float | None]
    sqrt: Method[Self]
    std: Method[float]
    sum: Method[float]
    tail: Method[Self]
    to_arrow: Method[pa.Array[Any]]
    to_frame: Method[PolarsDataFrame]
    to_list: Method[list[Any]]
    to_pandas: Method[pd.Series[Any]]
    unique: Method[Self]
    var: Method[float]
    zip_with: Method[Self]

    @property
    def list(self) -> PolarsSeriesListNamespace:
        return PolarsSeriesListNamespace(self)
class PolarsSeriesNamespace(PolarsAnyNamespace[PolarsSeries, pl.Series]):
    """Base class for the series-level accessor namespaces.

    Holds the owning `PolarsSeries` and exposes a few shortcuts to it.
    """

    def __init__(self, series: PolarsSeries) -> None:
        """Remember the series this accessor belongs to."""
        self._series = series

    @property
    def compliant(self) -> PolarsSeries:
        """The owning compliant series."""
        return self._series

    @property
    def native(self) -> pl.Series:
        """The native series wrapped by the owning compliant series."""
        owner = self._series
        return owner.native

    @property
    def name(self) -> str:
        """Name of the owning series."""
        owner = self.compliant
        return owner.name

    def __narwhals_namespace__(self) -> PolarsNamespace:
        """Namespace obtained from the owning series."""
        owner = self.compliant
        return owner.__narwhals_namespace__()

    def to_frame(self) -> PolarsDataFrame:
        """The owning series converted to a frame."""
        owner = self.compliant
        return owner.to_frame()
class PolarsSeriesDateTimeNamespace(
    PolarsSeriesNamespace,
    PolarsDateTimeNamespace[PolarsSeries, pl.Series],
):
    """Series-level `dt` accessor.

    Defines nothing of its own; all behaviour comes from the two base classes.
    """
class PolarsSeriesStringNamespace(
    PolarsSeriesNamespace,
    PolarsStringNamespace[PolarsSeries, pl.Series],
):
    """Series-level `str` accessor."""

    def zfill(self, width: int) -> PolarsSeries:
        """Zero-pad strings to `width`.

        The work is delegated to the expression-level `str.zfill`, applied to
        a one-column frame built from this series.
        """
        column = self.name
        namespace = self.__narwhals_namespace__()
        frame = self.to_frame()
        padded = namespace.col(column).str.zfill(width)
        return frame.select(padded).get_column(column)
class PolarsSeriesCatNamespace(
    PolarsSeriesNamespace,
    PolarsCatNamespace[PolarsSeries, pl.Series],
):
    """Series-level `cat` accessor.

    Defines nothing of its own; all behaviour comes from the two base classes.
    """
class PolarsSeriesListNamespace(
    PolarsSeriesNamespace,
    PolarsListNamespace[PolarsSeries, pl.Series],
):
    """Series-level `list` accessor.

    Both methods delegate to their expression-level counterpart, applied to a
    one-column frame built from this series.
    """

    def len(self) -> PolarsSeries:
        """Return the length of each list."""
        column = self.name
        namespace = self.__narwhals_namespace__()
        frame = self.to_frame()
        lengths = namespace.col(column).list.len()
        return frame.select(lengths).get_column(column)

    def contains(self, item: NonNestedLiteral) -> PolarsSeries:
        """Return whether each list contains `item`."""
        column = self.name
        namespace = self.__narwhals_namespace__()
        frame = self.to_frame()
        found = namespace.col(column).list.contains(item)
        return frame.select(found).get_column(column)
class PolarsSeriesStructNamespace(
    PolarsSeriesNamespace,
    PolarsStructNamespace[PolarsSeries, pl.Series],
):
    """Series-level `struct` accessor.

    Defines nothing of its own; all behaviour comes from the two base classes.
    """

View File

@ -0,0 +1,25 @@
from __future__ import annotations # pragma: no cover
from typing import (
TYPE_CHECKING, # pragma: no cover
Union, # pragma: no cover
)
if TYPE_CHECKING:
import sys
from typing import Literal, TypeVar
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
IntoPolarsExpr: TypeAlias = Union[PolarsExpr, PolarsSeries]
FrameT = TypeVar("FrameT", PolarsDataFrame, PolarsLazyFrame)
NativeAccessor: TypeAlias = Literal[
"arr", "cat", "dt", "list", "meta", "name", "str", "bin", "struct"
]

View File

@ -0,0 +1,351 @@
from __future__ import annotations
import abc
from functools import lru_cache
from typing import TYPE_CHECKING, Any, ClassVar, Final, Protocol, TypeVar, overload
import polars as pl
from narwhals._duration import Interval
from narwhals._utils import (
Implementation,
Version,
_DeferredIterable,
_StoresCompliant,
_StoresNative,
deep_getattr,
isinstance_or_issubclass,
)
from narwhals.exceptions import (
ColumnNotFoundError,
ComputeError,
DuplicateError,
InvalidOperationError,
NarwhalsError,
ShapeError,
)
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator, Mapping
from typing_extensions import TypeIs
from narwhals._polars.dataframe import Method
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
from narwhals._polars.typing import NativeAccessor
from narwhals.dtypes import DType
from narwhals.typing import IntoDType
# Unconstrained type variable, used for pass-through signatures such as
# `extract_native`.
T = TypeVar("T")
# Any of the native Polars objects listed in the bound.
NativeT = TypeVar(
"NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr"
)
# Covariant and constrained to exactly `pl.Series` or `pl.Expr`.
NativeT_co = TypeVar("NativeT_co", "pl.Series", "pl.Expr", covariant=True)
# Covariant and constrained to exactly `PolarsSeries` or `PolarsExpr`.
CompliantT_co = TypeVar("CompliantT_co", "PolarsSeries", "PolarsExpr", covariant=True)
# Invariant counterpart of `CompliantT_co`; the accessor namespace classes in
# this module are parametrised with it.
CompliantT = TypeVar("CompliantT", "PolarsSeries", "PolarsExpr")
# Computed once, when this module is imported.
BACKEND_VERSION = Implementation.POLARS._backend_version()
"""Static backend version for `polars`."""
# Feature flags below are plain comparisons against `BACKEND_VERSION`.
SERIES_RESPECTS_DTYPE: Final[bool] = BACKEND_VERSION >= (0, 20, 26)
"""`pl.Series(dtype=...)` fixed in https://github.com/pola-rs/polars/pull/15962
Includes `SERIES_ACCEPTS_PD_INDEX`.
"""
# Lower threshold than `SERIES_RESPECTS_DTYPE`, so this is always true when
# that flag is true.
SERIES_ACCEPTS_PD_INDEX: Final[bool] = BACKEND_VERSION >= (0, 20, 7)
"""`pl.Series(values: pd.Index)` fixed in https://github.com/pola-rs/polars/pull/14087"""
@overload
def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ...
@overload
def extract_native(obj: T) -> T: ...
def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T:
    """Unwrap a compliant Polars object, leaving anything else untouched.

    Returns `obj.native` when `_is_compliant_polars(obj)` holds, otherwise
    `obj` itself.
    """
    if _is_compliant_polars(obj):
        return obj.native
    return obj
def _is_compliant_polars(
    obj: _StoresNative[NativeT] | Any,
) -> TypeIs[_StoresNative[NativeT]]:
    """Tell whether `obj` is one of the four compliant Polars wrapper types."""
    # Imported at call time rather than at module level
    # (presumably to avoid a circular import - confirm).
    from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.series import PolarsSeries

    compliant_types = (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr)
    return isinstance(obj, compliant_types)
def extract_args_kwargs(
    args: Iterable[Any], kwds: Mapping[str, Any], /
) -> tuple[Iterator[Any], dict[str, Any]]:
    """Run `extract_native` over positional and keyword arguments.

    Positional arguments come back as a lazy iterator (unwrapped only when
    consumed); keyword arguments are unwrapped eagerly into a new dict.
    """
    native_args = map(extract_native, args)
    native_kwds = {key: extract_native(value) for key, value in kwds.items()}
    return native_args, native_kwds
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(  # noqa: C901, PLR0912
    dtype: pl.DataType, version: Version
) -> DType:
    """Translate a Polars dtype into the Narwhals dtype of the given `version`.

    Results are memoised by `lru_cache` on the `(dtype, version)` arguments.
    A dtype that matches none of the checks below maps to `Unknown`.
    """
    dtypes = version.dtypes

    # Parameter-free dtypes carry the same name in `polars` and in the Narwhals
    # dtypes namespace, so they are matched by name. The order is significant
    # only in that it mirrors the sequence of checks this replaces.
    version_gated = ("Int128", "UInt128")
    plain_before_enum = (
        "Float64",
        "Float32",
        "Int128",
        "Int64",
        "Int32",
        "Int16",
        "Int8",
        "UInt128",
        "UInt64",
        "UInt32",
        "UInt16",
        "UInt8",
        "String",
        "Boolean",
        "Object",
        "Categorical",
    )
    for type_name in plain_before_enum:
        if type_name in version_gated and not hasattr(pl, type_name):  # pragma: no cover
            # Not available for Polars pre 1.8.0
            continue
        if dtype == getattr(pl, type_name):
            return getattr(dtypes, type_name)()

    if isinstance_or_issubclass(dtype, pl.Enum):
        if version is not Version.V1:
            # Categories are wrapped so that `to_list` is only called on demand.
            return dtypes.Enum(_DeferredIterable(dtype.categories.to_list))
        return dtypes.Enum()  # type: ignore[call-arg]

    if dtype == pl.Date:
        return dtypes.Date()

    if isinstance_or_issubclass(dtype, pl.Datetime):
        # The bare class carries no unit/zone, so the Narwhals defaults apply.
        if dtype is pl.Datetime:
            return dtypes.Datetime()
        return dtypes.Datetime(dtype.time_unit, dtype.time_zone)

    if isinstance_or_issubclass(dtype, pl.Duration):
        if dtype is pl.Duration:
            return dtypes.Duration()
        return dtypes.Duration(dtype.time_unit)

    if isinstance_or_issubclass(dtype, pl.Struct):
        converted_fields = [
            dtypes.Field(field_name, native_to_narwhals_dtype(field_dtype, version))
            for field_name, field_dtype in dtype
        ]
        return dtypes.Struct(converted_fields)

    if isinstance_or_issubclass(dtype, pl.List):
        inner_dtype = native_to_narwhals_dtype(dtype.inner, version)
        return dtypes.List(inner_dtype)

    if isinstance_or_issubclass(dtype, pl.Array):
        # The attribute holding the outer dimension depends on the Polars version.
        if BACKEND_VERSION < (0, 20, 30):
            outer_shape = dtype.width
        else:
            outer_shape = dtype.size
        inner_dtype = native_to_narwhals_dtype(dtype.inner, version)
        return dtypes.Array(inner_dtype, outer_shape)

    for type_name in ("Decimal", "Time", "Binary"):
        if dtype == getattr(pl, type_name):
            return getattr(dtypes, type_name)()

    return dtypes.Unknown()
# Dtypes namespace of `Version.MAIN`; its classes are the keys of the table below.
dtypes = Version.MAIN.dtypes
# Lookup table from a Narwhals dtype class to a ready-made Polars dtype
# instance. Every value here is constructed without arguments.
# NOTE(review): `narwhals_to_native_dtype` looks keys up with
# `dtype.base_type()` for any `version` - presumably base types are shared
# across versions; confirm.
NW_TO_PL_DTYPES: Mapping[type[DType], pl.DataType] = {
dtypes.Float64: pl.Float64(),
dtypes.Float32: pl.Float32(),
dtypes.Binary: pl.Binary(),
dtypes.String: pl.String(),
dtypes.Boolean: pl.Boolean(),
dtypes.Categorical: pl.Categorical(),
dtypes.Date: pl.Date(),
dtypes.Time: pl.Time(),
dtypes.Int8: pl.Int8(),
dtypes.Int16: pl.Int16(),
dtypes.Int32: pl.Int32(),
dtypes.Int64: pl.Int64(),
dtypes.UInt8: pl.UInt8(),
dtypes.UInt16: pl.UInt16(),
dtypes.UInt32: pl.UInt32(),
dtypes.UInt64: pl.UInt64(),
dtypes.Object: pl.Object(),
dtypes.Unknown: pl.Unknown(),
}
# Narwhals dtypes for which `narwhals_to_native_dtype` raises
# `NotImplementedError` instead of converting.
UNSUPPORTED_DTYPES = (dtypes.Decimal,)
def narwhals_to_native_dtype(  # noqa: C901
    dtype: IntoDType, version: Version
) -> pl.DataType:
    """Translate a Narwhals dtype (class or instance) into a Polars dtype.

    Arguments:
        dtype: Narwhals dtype to convert.
        version: Narwhals API version whose dtypes namespace is matched against.

    Returns:
        The corresponding Polars dtype, or `pl.Unknown()` when nothing matches.

    Raises:
        NotImplementedError: For `Enum` under `Version.V1`, and for any dtype
            listed in `UNSUPPORTED_DTYPES`.
        ValueError: If `Enum` is given as a bare class, i.e. without categories.
    """
    dtypes = version.dtypes
    base_type = dtype.base_type()
    # Parameter-free dtypes are resolved straight from the lookup table.
    if (pl_type := NW_TO_PL_DTYPES.get(base_type)) is not None:
        return pl_type
    if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
        # Not available for Polars pre 1.8.0
        return pl.Int128()
    if dtype == dtypes.UInt128 and hasattr(pl, "UInt128"):
        # Mirrors the `UInt128` branch of `native_to_narwhals_dtype`. When the
        # installed Polars lacks `UInt128`, this falls through to the
        # `pl.Unknown()` default exactly as it did before.
        return pl.UInt128()
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            return pl.Enum(dtype.categories)
        # Only an instance carries categories; the bare class cannot be converted.
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        return pl.Datetime(dtype.time_unit, dtype.time_zone)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return pl.Duration(dtype.time_unit)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.List):
        return pl.List(narwhals_to_native_dtype(dtype.inner, version))
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        fields = [
            pl.Field(field.name, narwhals_to_native_dtype(field.dtype, version))
            for field in dtype.fields
        ]
        return pl.Struct(fields)
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        size = dtype.size
        # The keyword naming the outer dimension depends on the Polars version.
        kwargs = {"width": size} if BACKEND_VERSION < (0, 20, 30) else {"shape": size}
        return pl.Array(narwhals_to_native_dtype(dtype.inner, version), **kwargs)
    if issubclass(base_type, UNSUPPORTED_DTYPES):
        msg = f"Converting to {base_type.__name__} dtype is not supported for Polars."
        raise NotImplementedError(msg)
    return pl.Unknown()  # pragma: no cover
def _is_polars_exception(exception: Exception) -> bool:
    """Tell whether `exception` comes from Polars."""
    if BACKEND_VERSION < (1,):  # pragma: no cover
        # Old versions of Polars didn't have PolarsError, so as a last attempt
        # look for the exceptions module in the type's string form.
        return "polars.exceptions" in str(type(exception))
    return isinstance(exception, pl.exceptions.PolarsError)
def _is_cudf_exception(exception: Exception) -> bool:
    """Tell whether the exception message starts with `"CUDF failure"`."""
    # These exceptions are raised when running polars on GPUs via cuDF
    message = str(exception)
    return message.startswith("CUDF failure")
def catch_polars_exception(exception: Exception) -> NarwhalsError | Exception:
    """Map a Polars exception onto its Narwhals counterpart.

    The translated exception is returned, not raised. Specific Polars
    exception types are tried first, in order; any other Polars or cuDF
    failure becomes a generic `NarwhalsError`; everything else is returned
    unchanged.
    """
    # (attribute name on `pl.exceptions`, Narwhals exception to build).
    # The attribute is looked up only when its row is reached.
    translations = (
        ("ColumnNotFoundError", ColumnNotFoundError),
        ("ShapeError", ShapeError),
        ("InvalidOperationError", InvalidOperationError),
        ("DuplicateError", DuplicateError),
        ("ComputeError", ComputeError),
    )
    for polars_name, narwhals_error in translations:
        if isinstance(exception, getattr(pl.exceptions, polars_name)):
            return narwhals_error(str(exception))
    if _is_polars_exception(exception) or _is_cudf_exception(exception):
        return NarwhalsError(str(exception))  # pragma: no cover
    # Just return exception as-is.
    return exception
class PolarsAnyNamespace(
    _StoresCompliant[CompliantT_co],
    _StoresNative[NativeT_co],
    Protocol[CompliantT_co, NativeT_co],
):
    """Base for accessor namespaces that forward calls to a native accessor.

    Subclasses set `_accessor` to the name of the accessor to forward to.
    """

    _accessor: ClassVar[NativeAccessor]

    def __getattr__(self, attr: str) -> Callable[..., CompliantT_co]:
        """Build a forwarding callable for an attribute not defined on the class.

        Nothing is resolved on the native object until the returned callable
        is actually invoked.
        """

        def forward(*args: Any, **kwargs: Any) -> CompliantT_co:
            native_args, native_kwargs = extract_args_kwargs(args, kwargs)
            native_method = deep_getattr(self.native, self._accessor, attr)
            rewrap = self.compliant._with_native
            return rewrap(native_method(*native_args, **native_kwargs))

        return forward
class PolarsDateTimeNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    """`dt` accessor namespace.

    `truncate` and `offset_by` parse their argument with `Interval` before
    forwarding. Names annotated with `Method` have no body here and are
    served by the inherited `__getattr__`.
    """

    _accessor: ClassVar[NativeAccessor] = "dt"

    def truncate(self, every: str) -> CompliantT:
        """Check `every` with `Interval.parse`, then forward to `truncate`."""
        # Ensure consistent error message is raised.
        Interval.parse(every)
        forwarded = self.__getattr__("truncate")
        return forwarded(every)

    def offset_by(self, by: str) -> CompliantT:
        """Check `by` with `Interval.parse_no_constraints`, then forward to `offset_by`."""
        # Ensure consistent error message is raised.
        Interval.parse_no_constraints(by)
        forwarded = self.__getattr__("offset_by")
        return forwarded(by)

    # Conversions and time zones.
    to_string: Method[CompliantT]
    replace_time_zone: Method[CompliantT]
    convert_time_zone: Method[CompliantT]
    timestamp: Method[CompliantT]
    # Components.
    date: Method[CompliantT]
    year: Method[CompliantT]
    month: Method[CompliantT]
    day: Method[CompliantT]
    hour: Method[CompliantT]
    minute: Method[CompliantT]
    second: Method[CompliantT]
    millisecond: Method[CompliantT]
    microsecond: Method[CompliantT]
    nanosecond: Method[CompliantT]
    ordinal_day: Method[CompliantT]
    weekday: Method[CompliantT]
    # Totals.
    total_minutes: Method[CompliantT]
    total_seconds: Method[CompliantT]
    total_milliseconds: Method[CompliantT]
    total_microseconds: Method[CompliantT]
    total_nanoseconds: Method[CompliantT]
class PolarsStringNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    """`str` accessor namespace.

    Names annotated with `Method` have no body here and are served by the
    inherited `__getattr__`; `zfill` is left for subclasses to implement.
    """

    _accessor: ClassVar[NativeAccessor] = "str"

    # NOTE: Use `abstractmethod` if we have defs to implement, but also `Method` usage
    @abc.abstractmethod
    def zfill(self, width: int) -> CompliantT:
        """Abstract: each concrete namespace supplies its own `zfill`."""

    len_chars: Method[CompliantT]
    replace: Method[CompliantT]
    replace_all: Method[CompliantT]
    strip_chars: Method[CompliantT]
    starts_with: Method[CompliantT]
    ends_with: Method[CompliantT]
    contains: Method[CompliantT]
    slice: Method[CompliantT]
    split: Method[CompliantT]
    to_date: Method[CompliantT]
    to_datetime: Method[CompliantT]
    to_lowercase: Method[CompliantT]
    to_uppercase: Method[CompliantT]
class PolarsCatNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    """`cat` accessor namespace; its only declared method is forwarded."""

    _accessor: ClassVar[NativeAccessor] = "cat"

    # Served by the inherited `__getattr__`.
    get_categories: Method[CompliantT]
class PolarsListNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    """`list` accessor namespace.

    `len` is left for subclasses to implement; the `Method`-annotated names
    are served by the inherited `__getattr__`.
    """

    _accessor: ClassVar[NativeAccessor] = "list"

    @abc.abstractmethod
    def len(self) -> CompliantT:
        """Abstract: each concrete namespace supplies its own `len`."""

    get: Method[CompliantT]
    unique: Method[CompliantT]
class PolarsStructNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    """`struct` accessor namespace."""

    _accessor: ClassVar[NativeAccessor] = "struct"

    # Served by the inherited `__getattr__`.
    field: Method[CompliantT]