done
This commit is contained in:
351
lib/python3.11/site-packages/narwhals/_polars/utils.py
Normal file
351
lib/python3.11/site-packages/narwhals/_polars/utils.py
Normal file
@ -0,0 +1,351 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Final, Protocol, TypeVar, overload
|
||||
|
||||
import polars as pl
|
||||
|
||||
from narwhals._duration import Interval
|
||||
from narwhals._utils import (
|
||||
Implementation,
|
||||
Version,
|
||||
_DeferredIterable,
|
||||
_StoresCompliant,
|
||||
_StoresNative,
|
||||
deep_getattr,
|
||||
isinstance_or_issubclass,
|
||||
)
|
||||
from narwhals.exceptions import (
|
||||
ColumnNotFoundError,
|
||||
ComputeError,
|
||||
DuplicateError,
|
||||
InvalidOperationError,
|
||||
NarwhalsError,
|
||||
ShapeError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Iterable, Iterator, Mapping
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from narwhals._polars.dataframe import Method
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
from narwhals._polars.series import PolarsSeries
|
||||
from narwhals._polars.typing import NativeAccessor
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.typing import IntoDType
|
||||
|
||||
T = TypeVar("T")
|
||||
NativeT = TypeVar(
|
||||
"NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr"
|
||||
)
|
||||
|
||||
NativeT_co = TypeVar("NativeT_co", "pl.Series", "pl.Expr", covariant=True)
|
||||
CompliantT_co = TypeVar("CompliantT_co", "PolarsSeries", "PolarsExpr", covariant=True)
|
||||
CompliantT = TypeVar("CompliantT", "PolarsSeries", "PolarsExpr")
|
||||
|
||||
BACKEND_VERSION = Implementation.POLARS._backend_version()
|
||||
"""Static backend version for `polars`."""
|
||||
|
||||
SERIES_RESPECTS_DTYPE: Final[bool] = BACKEND_VERSION >= (0, 20, 26)
|
||||
"""`pl.Series(dtype=...)` fixed in https://github.com/pola-rs/polars/pull/15962
|
||||
|
||||
Includes `SERIES_ACCEPTS_PD_INDEX`.
|
||||
"""
|
||||
|
||||
SERIES_ACCEPTS_PD_INDEX: Final[bool] = BACKEND_VERSION >= (0, 20, 7)
|
||||
"""`pl.Series(values: pd.Index)` fixed in https://github.com/pola-rs/polars/pull/14087"""
|
||||
|
||||
|
||||
@overload
|
||||
def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ...
|
||||
@overload
|
||||
def extract_native(obj: T) -> T: ...
|
||||
def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T:
|
||||
return obj.native if _is_compliant_polars(obj) else obj
|
||||
|
||||
|
||||
def _is_compliant_polars(
|
||||
obj: _StoresNative[NativeT] | Any,
|
||||
) -> TypeIs[_StoresNative[NativeT]]:
|
||||
from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
|
||||
from narwhals._polars.expr import PolarsExpr
|
||||
from narwhals._polars.series import PolarsSeries
|
||||
|
||||
return isinstance(obj, (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr))
|
||||
|
||||
|
||||
def extract_args_kwargs(
|
||||
args: Iterable[Any], kwds: Mapping[str, Any], /
|
||||
) -> tuple[Iterator[Any], dict[str, Any]]:
|
||||
it_args = (extract_native(arg) for arg in args)
|
||||
return it_args, {k: extract_native(v) for k, v in kwds.items()}
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def native_to_narwhals_dtype( # noqa: C901, PLR0912
|
||||
dtype: pl.DataType, version: Version
|
||||
) -> DType:
|
||||
dtypes = version.dtypes
|
||||
if dtype == pl.Float64:
|
||||
return dtypes.Float64()
|
||||
if dtype == pl.Float32:
|
||||
return dtypes.Float32()
|
||||
if hasattr(pl, "Int128") and dtype == pl.Int128: # pragma: no cover
|
||||
# Not available for Polars pre 1.8.0
|
||||
return dtypes.Int128()
|
||||
if dtype == pl.Int64:
|
||||
return dtypes.Int64()
|
||||
if dtype == pl.Int32:
|
||||
return dtypes.Int32()
|
||||
if dtype == pl.Int16:
|
||||
return dtypes.Int16()
|
||||
if dtype == pl.Int8:
|
||||
return dtypes.Int8()
|
||||
if hasattr(pl, "UInt128") and dtype == pl.UInt128: # pragma: no cover
|
||||
# Not available for Polars pre 1.8.0
|
||||
return dtypes.UInt128()
|
||||
if dtype == pl.UInt64:
|
||||
return dtypes.UInt64()
|
||||
if dtype == pl.UInt32:
|
||||
return dtypes.UInt32()
|
||||
if dtype == pl.UInt16:
|
||||
return dtypes.UInt16()
|
||||
if dtype == pl.UInt8:
|
||||
return dtypes.UInt8()
|
||||
if dtype == pl.String:
|
||||
return dtypes.String()
|
||||
if dtype == pl.Boolean:
|
||||
return dtypes.Boolean()
|
||||
if dtype == pl.Object:
|
||||
return dtypes.Object()
|
||||
if dtype == pl.Categorical:
|
||||
return dtypes.Categorical()
|
||||
if isinstance_or_issubclass(dtype, pl.Enum):
|
||||
if version is Version.V1:
|
||||
return dtypes.Enum() # type: ignore[call-arg]
|
||||
categories = _DeferredIterable(dtype.categories.to_list)
|
||||
return dtypes.Enum(categories)
|
||||
if dtype == pl.Date:
|
||||
return dtypes.Date()
|
||||
if isinstance_or_issubclass(dtype, pl.Datetime):
|
||||
return (
|
||||
dtypes.Datetime()
|
||||
if dtype is pl.Datetime
|
||||
else dtypes.Datetime(dtype.time_unit, dtype.time_zone)
|
||||
)
|
||||
if isinstance_or_issubclass(dtype, pl.Duration):
|
||||
return (
|
||||
dtypes.Duration()
|
||||
if dtype is pl.Duration
|
||||
else dtypes.Duration(dtype.time_unit)
|
||||
)
|
||||
if isinstance_or_issubclass(dtype, pl.Struct):
|
||||
fields = [
|
||||
dtypes.Field(name, native_to_narwhals_dtype(tp, version))
|
||||
for name, tp in dtype
|
||||
]
|
||||
return dtypes.Struct(fields)
|
||||
if isinstance_or_issubclass(dtype, pl.List):
|
||||
return dtypes.List(native_to_narwhals_dtype(dtype.inner, version))
|
||||
if isinstance_or_issubclass(dtype, pl.Array):
|
||||
outer_shape = dtype.width if BACKEND_VERSION < (0, 20, 30) else dtype.size
|
||||
return dtypes.Array(native_to_narwhals_dtype(dtype.inner, version), outer_shape)
|
||||
if dtype == pl.Decimal:
|
||||
return dtypes.Decimal()
|
||||
if dtype == pl.Time:
|
||||
return dtypes.Time()
|
||||
if dtype == pl.Binary:
|
||||
return dtypes.Binary()
|
||||
return dtypes.Unknown()
|
||||
|
||||
|
||||
dtypes = Version.MAIN.dtypes
|
||||
NW_TO_PL_DTYPES: Mapping[type[DType], pl.DataType] = {
|
||||
dtypes.Float64: pl.Float64(),
|
||||
dtypes.Float32: pl.Float32(),
|
||||
dtypes.Binary: pl.Binary(),
|
||||
dtypes.String: pl.String(),
|
||||
dtypes.Boolean: pl.Boolean(),
|
||||
dtypes.Categorical: pl.Categorical(),
|
||||
dtypes.Date: pl.Date(),
|
||||
dtypes.Time: pl.Time(),
|
||||
dtypes.Int8: pl.Int8(),
|
||||
dtypes.Int16: pl.Int16(),
|
||||
dtypes.Int32: pl.Int32(),
|
||||
dtypes.Int64: pl.Int64(),
|
||||
dtypes.UInt8: pl.UInt8(),
|
||||
dtypes.UInt16: pl.UInt16(),
|
||||
dtypes.UInt32: pl.UInt32(),
|
||||
dtypes.UInt64: pl.UInt64(),
|
||||
dtypes.Object: pl.Object(),
|
||||
dtypes.Unknown: pl.Unknown(),
|
||||
}
|
||||
UNSUPPORTED_DTYPES = (dtypes.Decimal,)
|
||||
|
||||
|
||||
def narwhals_to_native_dtype( # noqa: C901
|
||||
dtype: IntoDType, version: Version
|
||||
) -> pl.DataType:
|
||||
dtypes = version.dtypes
|
||||
base_type = dtype.base_type()
|
||||
if pl_type := NW_TO_PL_DTYPES.get(base_type):
|
||||
return pl_type
|
||||
if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
|
||||
# Not available for Polars pre 1.8.0
|
||||
return pl.Int128()
|
||||
if isinstance_or_issubclass(dtype, dtypes.Enum):
|
||||
if version is Version.V1:
|
||||
msg = "Converting to Enum is not supported in narwhals.stable.v1"
|
||||
raise NotImplementedError(msg)
|
||||
if isinstance(dtype, dtypes.Enum):
|
||||
return pl.Enum(dtype.categories)
|
||||
msg = "Can not cast / initialize Enum without categories present"
|
||||
raise ValueError(msg)
|
||||
if isinstance_or_issubclass(dtype, dtypes.Datetime):
|
||||
return pl.Datetime(dtype.time_unit, dtype.time_zone) # type: ignore[arg-type]
|
||||
if isinstance_or_issubclass(dtype, dtypes.Duration):
|
||||
return pl.Duration(dtype.time_unit) # type: ignore[arg-type]
|
||||
if isinstance_or_issubclass(dtype, dtypes.List):
|
||||
return pl.List(narwhals_to_native_dtype(dtype.inner, version))
|
||||
if isinstance_or_issubclass(dtype, dtypes.Struct):
|
||||
fields = [
|
||||
pl.Field(field.name, narwhals_to_native_dtype(field.dtype, version))
|
||||
for field in dtype.fields
|
||||
]
|
||||
return pl.Struct(fields)
|
||||
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
|
||||
size = dtype.size
|
||||
kwargs = {"width": size} if BACKEND_VERSION < (0, 20, 30) else {"shape": size}
|
||||
return pl.Array(narwhals_to_native_dtype(dtype.inner, version), **kwargs)
|
||||
if issubclass(base_type, UNSUPPORTED_DTYPES):
|
||||
msg = f"Converting to {base_type.__name__} dtype is not supported for Polars."
|
||||
raise NotImplementedError(msg)
|
||||
return pl.Unknown() # pragma: no cover
|
||||
|
||||
|
||||
def _is_polars_exception(exception: Exception) -> bool:
|
||||
if BACKEND_VERSION >= (1,):
|
||||
# Old versions of Polars didn't have PolarsError.
|
||||
return isinstance(exception, pl.exceptions.PolarsError)
|
||||
# Last attempt, for old Polars versions.
|
||||
return "polars.exceptions" in str(type(exception)) # pragma: no cover
|
||||
|
||||
|
||||
def _is_cudf_exception(exception: Exception) -> bool:
|
||||
# These exceptions are raised when running polars on GPUs via cuDF
|
||||
return str(exception).startswith("CUDF failure")
|
||||
|
||||
|
||||
def catch_polars_exception(exception: Exception) -> NarwhalsError | Exception:
|
||||
if isinstance(exception, pl.exceptions.ColumnNotFoundError):
|
||||
return ColumnNotFoundError(str(exception))
|
||||
if isinstance(exception, pl.exceptions.ShapeError):
|
||||
return ShapeError(str(exception))
|
||||
if isinstance(exception, pl.exceptions.InvalidOperationError):
|
||||
return InvalidOperationError(str(exception))
|
||||
if isinstance(exception, pl.exceptions.DuplicateError):
|
||||
return DuplicateError(str(exception))
|
||||
if isinstance(exception, pl.exceptions.ComputeError):
|
||||
return ComputeError(str(exception))
|
||||
if _is_polars_exception(exception) or _is_cudf_exception(exception):
|
||||
return NarwhalsError(str(exception)) # pragma: no cover
|
||||
# Just return exception as-is.
|
||||
return exception
|
||||
|
||||
|
||||
class PolarsAnyNamespace(
|
||||
_StoresCompliant[CompliantT_co],
|
||||
_StoresNative[NativeT_co],
|
||||
Protocol[CompliantT_co, NativeT_co],
|
||||
):
|
||||
_accessor: ClassVar[NativeAccessor]
|
||||
|
||||
def __getattr__(self, attr: str) -> Callable[..., CompliantT_co]:
|
||||
def func(*args: Any, **kwargs: Any) -> CompliantT_co:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
method = deep_getattr(self.native, self._accessor, attr)
|
||||
return self.compliant._with_native(method(*pos, **kwds))
|
||||
|
||||
return func
|
||||
|
||||
|
||||
class PolarsDateTimeNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
|
||||
_accessor: ClassVar[NativeAccessor] = "dt"
|
||||
|
||||
def truncate(self, every: str) -> CompliantT:
|
||||
# Ensure consistent error message is raised.
|
||||
Interval.parse(every)
|
||||
return self.__getattr__("truncate")(every)
|
||||
|
||||
def offset_by(self, by: str) -> CompliantT:
|
||||
# Ensure consistent error message is raised.
|
||||
Interval.parse_no_constraints(by)
|
||||
return self.__getattr__("offset_by")(by)
|
||||
|
||||
to_string: Method[CompliantT]
|
||||
replace_time_zone: Method[CompliantT]
|
||||
convert_time_zone: Method[CompliantT]
|
||||
timestamp: Method[CompliantT]
|
||||
date: Method[CompliantT]
|
||||
year: Method[CompliantT]
|
||||
month: Method[CompliantT]
|
||||
day: Method[CompliantT]
|
||||
hour: Method[CompliantT]
|
||||
minute: Method[CompliantT]
|
||||
second: Method[CompliantT]
|
||||
millisecond: Method[CompliantT]
|
||||
microsecond: Method[CompliantT]
|
||||
nanosecond: Method[CompliantT]
|
||||
ordinal_day: Method[CompliantT]
|
||||
weekday: Method[CompliantT]
|
||||
total_minutes: Method[CompliantT]
|
||||
total_seconds: Method[CompliantT]
|
||||
total_milliseconds: Method[CompliantT]
|
||||
total_microseconds: Method[CompliantT]
|
||||
total_nanoseconds: Method[CompliantT]
|
||||
|
||||
|
||||
class PolarsStringNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
|
||||
_accessor: ClassVar[NativeAccessor] = "str"
|
||||
|
||||
# NOTE: Use `abstractmethod` if we have defs to implement, but also `Method` usage
|
||||
@abc.abstractmethod
|
||||
def zfill(self, width: int) -> CompliantT: ...
|
||||
|
||||
len_chars: Method[CompliantT]
|
||||
replace: Method[CompliantT]
|
||||
replace_all: Method[CompliantT]
|
||||
strip_chars: Method[CompliantT]
|
||||
starts_with: Method[CompliantT]
|
||||
ends_with: Method[CompliantT]
|
||||
contains: Method[CompliantT]
|
||||
slice: Method[CompliantT]
|
||||
split: Method[CompliantT]
|
||||
to_date: Method[CompliantT]
|
||||
to_datetime: Method[CompliantT]
|
||||
to_lowercase: Method[CompliantT]
|
||||
to_uppercase: Method[CompliantT]
|
||||
|
||||
|
||||
class PolarsCatNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
|
||||
_accessor: ClassVar[NativeAccessor] = "cat"
|
||||
get_categories: Method[CompliantT]
|
||||
|
||||
|
||||
class PolarsListNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
|
||||
_accessor: ClassVar[NativeAccessor] = "list"
|
||||
|
||||
@abc.abstractmethod
|
||||
def len(self) -> CompliantT: ...
|
||||
|
||||
get: Method[CompliantT]
|
||||
|
||||
unique: Method[CompliantT]
|
||||
|
||||
|
||||
class PolarsStructNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
|
||||
_accessor: ClassVar[NativeAccessor] = "struct"
|
||||
field: Method[CompliantT]
|
Reference in New Issue
Block a user