done
This commit is contained in:
668
lib/python3.11/site-packages/narwhals/_pandas_like/utils.py
Normal file
668
lib/python3.11/site-packages/narwhals/_pandas_like/utils.py
Normal file
@ -0,0 +1,668 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import operator
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from narwhals._compliant import EagerSeriesNamespace
|
||||
from narwhals._constants import (
|
||||
MS_PER_SECOND,
|
||||
NS_PER_MICROSECOND,
|
||||
NS_PER_MILLISECOND,
|
||||
NS_PER_SECOND,
|
||||
SECONDS_PER_DAY,
|
||||
US_PER_SECOND,
|
||||
)
|
||||
from narwhals._utils import (
|
||||
Implementation,
|
||||
Version,
|
||||
_DeferredIterable,
|
||||
check_columns_exist,
|
||||
isinstance_or_issubclass,
|
||||
)
|
||||
from narwhals.exceptions import ShapeError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator, Mapping
|
||||
from types import ModuleType
|
||||
|
||||
from pandas._typing import Dtype as PandasDtype
|
||||
from pandas.core.dtypes.dtypes import BaseMaskedDtype
|
||||
from typing_extensions import TypeAlias, TypeIs
|
||||
|
||||
from narwhals._duration import IntervalUnit
|
||||
from narwhals._pandas_like.expr import PandasLikeExpr
|
||||
from narwhals._pandas_like.series import PandasLikeSeries
|
||||
from narwhals._pandas_like.typing import (
|
||||
NativeDataFrameT,
|
||||
NativeNDFrameT,
|
||||
NativeSeriesT,
|
||||
)
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.typing import DTypeBackend, IntoDType, TimeUnit, _1DArray
|
||||
|
||||
ExprT = TypeVar("ExprT", bound=PandasLikeExpr)
# Aliases that make the timestamp-conversion tables below self-describing:
# a mapping key reads as (current unit, target unit) -> (op, factor).
UnitCurrent: TypeAlias = TimeUnit
UnitTarget: TypeAlias = TimeUnit
BinOpBroadcast: TypeAlias = Callable[[Any, int], Any]
IntoRhs: TypeAlias = int


# The three pandas-like backends handled by this module.
PANDAS_LIKE_IMPLEMENTATION = {
    Implementation.PANDAS,
    Implementation.CUDF,
    Implementation.MODIN,
}
|
||||
# Verbose regexes used to parse native dtype *strings* (e.g. "datetime64[ns, UTC]",
# "timestamp[us, tz=UTC][pyarrow]") into their time-unit / time-zone components.
PD_DATETIME_RGX = r"""^
    datetime64\[
        (?P<time_unit>s|ms|us|ns)          # Match time unit: s, ms, us, or ns
        (?:,                               # Begin non-capturing group for optional timezone
            \s*                            # Optional whitespace after comma
            (?P<time_zone>                 # Start named group for timezone
                [a-zA-Z\/]+                # Match timezone name, e.g., UTC, America/New_York
                (?:[+-]\d{2}:\d{2})?       # Optional offset in format +HH:MM or -HH:MM
                |                          # OR
                pytz\.FixedOffset\(\d+\)   # Match pytz.FixedOffset with integer offset in parentheses
            )                              # End time_zone group
        )?                                 # End optional timezone group
    \]                                     # Closing bracket for datetime64
$"""
PATTERN_PD_DATETIME = re.compile(PD_DATETIME_RGX, re.VERBOSE)
PA_DATETIME_RGX = r"""^
    timestamp\[
        (?P<time_unit>s|ms|us|ns)          # Match time unit: s, ms, us, or ns
        (?:,                               # Begin non-capturing group for optional timezone
            \s?tz=                         # Match "tz=" prefix
            (?P<time_zone>                 # Start named group for timezone
                [a-zA-Z\/]*                # Match timezone name (e.g., UTC, America/New_York)
                (?:                        # Begin optional non-capturing group for offset
                    [+-]\d{2}:\d{2}        # Match offset in format +HH:MM or -HH:MM
                )?                         # End optional offset group
            )                              # End time_zone group
        )?                                 # End optional timezone group
    \]                                     # Closing bracket for timestamp
    \[pyarrow\]                            # Literal string "[pyarrow]"
$"""
PATTERN_PA_DATETIME = re.compile(PA_DATETIME_RGX, re.VERBOSE)
PD_DURATION_RGX = r"""^
    timedelta64\[
        (?P<time_unit>s|ms|us|ns)          # Match time unit: s, ms, us, or ns
    \]                                     # Closing bracket for timedelta64
$"""

PATTERN_PD_DURATION = re.compile(PD_DURATION_RGX, re.VERBOSE)
PA_DURATION_RGX = r"""^
    duration\[
        (?P<time_unit>s|ms|us|ns)          # Match time unit: s, ms, us, or ns
    \]                                     # Closing bracket for duration
    \[pyarrow\]                            # Literal string "[pyarrow]"
$"""
PATTERN_PA_DURATION = re.compile(PA_DURATION_RGX, re.VERBOSE)
|
||||
|
||||
# Spelled-out interval unit names as used by native libraries.
NativeIntervalUnit: TypeAlias = Literal[
    "year",
    "quarter",
    "month",
    "week",
    "day",
    "hour",
    "minute",
    "second",
    "millisecond",
    "microsecond",
    "nanosecond",
]
# Narwhals interval suffix -> pandas offset alias where they differ
# (presumably consumed when building pandas frequency strings — confirm at call sites).
ALIAS_DICT = {"d": "D", "m": "min"}
# Narwhals interval suffix -> spelled-out native unit name.
UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}

PANDAS_VERSION = Implementation.PANDAS._backend_version()
"""Static backend version for `pandas`.

Always available if we reached here, due to a module-level import.
"""
|
||||
|
||||
|
||||
def is_pandas_or_modin(implementation: Implementation) -> bool:
    """Return ``True`` for the pandas and modin backends (not cuDF)."""
    cpu_pandas_like = {Implementation.PANDAS, Implementation.MODIN}
    return implementation in cpu_pandas_like
|
||||
|
||||
|
||||
def align_and_extract_native(
    lhs: PandasLikeSeries, rhs: PandasLikeSeries | object
) -> tuple[pd.Series[Any] | object, pd.Series[Any] | object]:
    """Unwrap both sides of a binary operation to native objects, aligning indices.

    A broadcast (length-1) side is collapsed to a scalar; a Series RHS whose
    index differs from the LHS is re-indexed to match. A ``list`` RHS is
    rejected with ``TypeError``; any other non-Series RHS is treated as a
    scalar and passed through unchanged.
    """
    from narwhals._pandas_like.series import PandasLikeSeries

    if isinstance(rhs, PandasLikeSeries):
        if lhs._broadcast and not rhs._broadcast:
            # Only the left side broadcasts: collapse it to a scalar.
            return lhs.native.iloc[0], rhs.native
        if rhs._broadcast:
            # Right side broadcasts (possibly both do): collapse the RHS.
            return lhs.native, rhs.native.iloc[0]
        lhs_index = lhs.native.index
        rhs_native = (
            rhs.native
            if rhs.native.index is lhs_index
            else set_index(rhs.native, lhs_index, implementation=rhs._implementation)
        )
        return lhs.native, rhs_native

    if isinstance(rhs, list):
        msg = "Expected Series or scalar, got list."
        raise TypeError(msg)
    # `rhs` must be scalar, so just leave it as-is
    return lhs.native, rhs
|
||||
|
||||
|
||||
def set_index(
    obj: NativeNDFrameT, index: Any, *, implementation: Implementation
) -> NativeNDFrameT:
    """Wrapper around pandas' `set_axis` to set the object's index.

    Validates length when `index` is a native Index, and picks the cheapest
    assignment strategy (`copy`/in-place) per implementation and version.
    """
    native_index_cls = implementation.to_native_namespace().Index
    if isinstance(index, native_index_cls):
        expected_len, actual_len = len(index), len(obj)
        if expected_len != actual_len:
            msg = f"Expected object of length {expected_len}, got length: {actual_len}"
            raise ShapeError(msg)
    if implementation is Implementation.CUDF:
        # Shallow copy so the caller's object keeps its original index.
        result = obj.copy(deep=False)
        result.index = index
        return result
    if implementation is Implementation.PANDAS and (
        (1, 5) <= implementation._backend_version() < (3,)
    ):  # pragma: no cover
        # `copy=` keyword available in this version window.
        return obj.set_axis(index, axis=0, copy=False)
    return obj.set_axis(index, axis=0)  # pragma: no cover
|
||||
|
||||
|
||||
def rename(
    obj: NativeNDFrameT, *args: Any, implementation: Implementation, **kwargs: Any
) -> NativeNDFrameT:
    """Wrapper around pandas' rename so that we can set `copy` based on implementation/version."""
    on_pandas_3 = (
        implementation is Implementation.PANDAS
        and implementation._backend_version() >= (3,)
    )
    if on_pandas_3:  # pragma: no cover
        # pandas>=3: call without the `copy` keyword.
        return obj.rename(*args, **kwargs, inplace=False)
    return obj.rename(*args, **kwargs, copy=False, inplace=False)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=16)
def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) -> DType:  # noqa: C901, PLR0912
    """Map a non-``object`` native dtype to the corresponding narwhals dtype.

    Dispatches on the *string form* of the dtype so that numpy, pandas-nullable
    and pyarrow-backed spellings are all handled uniformly. Cached because the
    same dtypes are looked up repeatedly.
    """
    dtype = str(native_dtype)
    dtypes = version.dtypes

    # Fixed-name dtypes: each narwhals dtype with all its known native spellings.
    # The alias groups are mutually disjoint, so iteration order is immaterial.
    simple_dtypes = (
        (dtypes.Int64, ("int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]")),
        (dtypes.Int32, ("int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]")),
        (dtypes.Int16, ("int16", "Int16", "Int16[pyarrow]", "int16[pyarrow]")),
        (dtypes.Int8, ("int8", "Int8", "Int8[pyarrow]", "int8[pyarrow]")),
        (dtypes.UInt64, ("uint64", "UInt64", "UInt64[pyarrow]", "uint64[pyarrow]")),
        (dtypes.UInt32, ("uint32", "UInt32", "UInt32[pyarrow]", "uint32[pyarrow]")),
        (dtypes.UInt16, ("uint16", "UInt16", "UInt16[pyarrow]", "uint16[pyarrow]")),
        (dtypes.UInt8, ("uint8", "UInt8", "UInt8[pyarrow]", "uint8[pyarrow]")),
        (
            dtypes.Float64,
            (
                "float64",
                "Float64",
                "Float64[pyarrow]",
                "float64[pyarrow]",
                "double[pyarrow]",
            ),
        ),
        (
            dtypes.Float32,
            (
                "float32",
                "Float32",
                "Float32[pyarrow]",
                "float32[pyarrow]",
                "float[pyarrow]",
            ),
        ),
        (
            dtypes.String,
            (
                # "there is no problem which can't be solved by adding an extra string type" pandas
                "string",
                "string[python]",
                "string[pyarrow]",
                "string[pyarrow_numpy]",
                "large_string[pyarrow]",
                "str",
            ),
        ),
        (dtypes.Boolean, ("bool", "boolean", "boolean[pyarrow]", "bool[pyarrow]")),
    )
    for factory, aliases in simple_dtypes:
        if dtype in aliases:
            return factory()
    if dtype.startswith("dictionary<"):
        return dtypes.Categorical()
    if dtype == "category":
        return native_categorical_to_narwhals_dtype(native_dtype, version)
    # Parameterised dtypes: parse unit (and zone) out of the string form.
    if match_ := (PATTERN_PD_DATETIME.match(dtype) or PATTERN_PA_DATETIME.match(dtype)):
        dt_time_unit: TimeUnit = match_.group("time_unit")  # type: ignore[assignment]
        dt_time_zone: str | None = match_.group("time_zone")
        return dtypes.Datetime(dt_time_unit, dt_time_zone)
    if match_ := (PATTERN_PD_DURATION.match(dtype) or PATTERN_PA_DURATION.match(dtype)):
        du_time_unit: TimeUnit = match_.group("time_unit")  # type: ignore[assignment]
        return dtypes.Duration(du_time_unit)
    if dtype == "date32[day][pyarrow]":
        return dtypes.Date()
    if dtype.endswith("[pyarrow]"):
        # Remaining pyarrow-only dtypes (timestamp/duration were consumed above).
        if dtype.startswith("decimal"):
            return dtypes.Decimal()
        if dtype.startswith("time"):
            return dtypes.Time()
        if dtype.startswith("binary"):
            return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover
|
||||
|
||||
|
||||
def object_native_to_narwhals_dtype(
    series: PandasLikeSeries | None, version: Version, implementation: Implementation
) -> DType:
    """Resolve a native ``object`` dtype by sniffing a prefix of the values.

    ``series is None`` is treated the same as an empty Series.
    """
    dtypes = version.dtypes
    if implementation is Implementation.CUDF:
        # Per conversations with their maintainers, they don't support arbitrary
        # objects, so we can just return String.
        return dtypes.String()

    # Arbitrary limit of 100 elements to use to sniff dtype.
    inferred_dtype = (
        "empty"
        if series is None
        else pd.api.types.infer_dtype(series.head(100), skipna=True)
    )
    if inferred_dtype == "string":
        return dtypes.String()
    if inferred_dtype == "empty":
        # Default to String for empty Series, but preserve Object in V1.
        return dtypes.Object() if version is Version.V1 else dtypes.String()
    return dtypes.Object()
|
||||
|
||||
|
||||
def native_categorical_to_narwhals_dtype(
    native_dtype: pd.CategoricalDtype,
    version: Version,
    implementation: Literal[Implementation.CUDF] | None = None,
) -> DType:
    """Map a native categorical dtype to `Categorical`, or `Enum` when ordered.

    V1 always maps to `Categorical` for backwards compatibility.
    """
    dtypes = version.dtypes
    if version is Version.V1 or not native_dtype.ordered:
        return dtypes.Categorical()
    # Ordered categories become an Enum; defer materializing the category list.
    if implementation is Implementation.CUDF:
        into_iter = _cudf_categorical_to_list(native_dtype)
    else:
        into_iter = native_dtype.categories.to_list
    return dtypes.Enum(_DeferredIterable(into_iter))
|
||||
|
||||
|
||||
def _cudf_categorical_to_list(
|
||||
native_dtype: Any,
|
||||
) -> Callable[[], list[Any]]: # pragma: no cover
|
||||
# NOTE: https://docs.rapids.ai/api/cudf/stable/user_guide/api_docs/api/cudf.core.dtypes.categoricaldtype/#cudf.core.dtypes.CategoricalDtype
|
||||
def fn() -> list[Any]:
|
||||
return native_dtype.categories.to_arrow().to_pylist()
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def native_to_narwhals_dtype(
    native_dtype: Any,
    version: Version,
    implementation: Implementation,
    *,
    allow_object: bool = False,
) -> DType:
    """Translate any native dtype into a narwhals dtype.

    `allow_object=False` turns an unexpected ``object`` dtype into an
    `AssertionError`, since callers are expected to route those separately.
    """
    dtype_str = str(native_dtype)

    if dtype_str.startswith(("large_list", "list", "struct", "fixed_size_list")):
        # Nested dtypes: delegate to the arrow conversion helpers.
        from narwhals._arrow.utils import (
            native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
        )

        if hasattr(native_dtype, "to_arrow"):  # pragma: no cover
            # cudf, cudf.pandas
            return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version)
        return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version)

    if dtype_str == "category" and implementation.is_cudf():
        # https://github.com/rapidsai/cudf/issues/18536
        # https://github.com/rapidsai/cudf/issues/14027
        return native_categorical_to_narwhals_dtype(
            native_dtype, version, Implementation.CUDF
        )
    if dtype_str != "object":
        return non_object_native_to_narwhals_dtype(native_dtype, version)

    # Everything below handles the `object` dtype.
    if implementation is Implementation.DASK:
        # Per conversations with their maintainers, they don't support arbitrary
        # objects, so we can just return String.
        return version.dtypes.String()
    if allow_object:
        return object_native_to_narwhals_dtype(None, version, implementation)
    msg = (
        "Unreachable code, object dtype should be handled separately"  # pragma: no cover
    )
    raise AssertionError(msg)
|
||||
|
||||
|
||||
# Detection of pandas' masked ("numpy_nullable") extension dtypes. Two
# implementations are defined, selected once at import time: per the NOTE in
# the fallback below, the `base` attribute this relies on was added between
# pandas 1.1 and 1.2.
if Implementation.PANDAS._backend_version() >= (1, 2):

    def is_dtype_numpy_nullable(dtype: Any) -> TypeIs[BaseMaskedDtype]:
        """Return `True` if `dtype` is `"numpy_nullable"`."""
        # NOTE: We need a sentinel as the positive case is `BaseMaskedDtype.base = None`
        # See https://github.com/narwhals-dev/narwhals/pull/2740#discussion_r2171667055
        sentinel = object()
        return (
            isinstance(dtype, pd.api.extensions.ExtensionDtype)
            and getattr(dtype, "base", sentinel) is None
        )

else:  # pragma: no cover

    def is_dtype_numpy_nullable(dtype: Any) -> TypeIs[BaseMaskedDtype]:
        """Return `True` if `dtype` is `"numpy_nullable"` (pandas<1.2 fallback)."""
        # NOTE: `base` attribute was added between 1.1-1.2
        # Checking by isinstance requires using an import path that is no longer valid
        # `1.1`: https://github.com/pandas-dev/pandas/blob/b5958ee1999e9aead1938c0bba2b674378807b3d/pandas/core/arrays/masked.py#L37
        # `1.2`: https://github.com/pandas-dev/pandas/blob/7c48ff4409c622c582c56a5702373f726de08e96/pandas/core/arrays/masked.py#L41
        # `1.5`: https://github.com/pandas-dev/pandas/blob/35b0d1dcadf9d60722c055ee37442dc76a29e64c/pandas/core/dtypes/dtypes.py#L1609
        if isinstance(dtype, pd.api.extensions.ExtensionDtype):
            from pandas.core.arrays.masked import (  # type: ignore[attr-defined]
                BaseMaskedDtype as OldBaseMaskedDtype,  # pyright: ignore[reportAttributeAccessIssue]
            )

            return isinstance(dtype, OldBaseMaskedDtype)
        return False
|
||||
|
||||
|
||||
def get_dtype_backend(dtype: Any, implementation: Implementation) -> DTypeBackend:
    """Get dtype backend for pandas type.

    Matches pandas' `dtype_backend` argument in `convert_dtypes`.
    """
    if implementation is Implementation.CUDF:
        # cuDF has no dtype-backend concept.
        return None
    if is_dtype_pyarrow(dtype):
        return "pyarrow"
    if is_dtype_numpy_nullable(dtype):
        return "numpy_nullable"
    return None
|
||||
|
||||
|
||||
# NOTE: Use this to avoid annotating inline
def iter_dtype_backends(
    dtypes: Iterable[Any], implementation: Implementation
) -> Iterator[DTypeBackend]:
    """Yield a `DTypeBackend` per-dtype.

    Matches pandas' `dtype_backend` argument in `convert_dtypes`.
    """
    for dtype in dtypes:
        yield get_dtype_backend(dtype, implementation)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=16)
def is_dtype_pyarrow(dtype: Any) -> TypeIs[pd.ArrowDtype]:
    """Return ``True`` if `dtype` is a pyarrow-backed pandas dtype.

    Guards the attribute lookup because `pd.ArrowDtype` does not exist in
    every supported pandas version.
    """
    arrow_dtype_cls = getattr(pd, "ArrowDtype", None)
    return arrow_dtype_cls is not None and isinstance(dtype, arrow_dtype_cls)
|
||||
|
||||
|
||||
dtypes = Version.MAIN.dtypes
# Narwhals dtype -> native name, independent of the dtype backend.
NW_TO_PD_DTYPES_INVARIANT: Mapping[type[DType], str] = {
    # TODO(Unassigned): is there no pyarrow-backed categorical?
    # or at least, convert_dtypes(dtype_backend='pyarrow') doesn't
    # convert to it?
    dtypes.Categorical: "category",
    dtypes.Object: "object",
}
# Narwhals dtype -> (dtype backend -> native name); keyed by the same backend
# values as pandas' `convert_dtypes(dtype_backend=...)`.
NW_TO_PD_DTYPES_BACKEND: Mapping[type[DType], Mapping[DTypeBackend, str | type[Any]]] = {
    dtypes.Float64: {
        "pyarrow": "Float64[pyarrow]",
        "numpy_nullable": "Float64",
        None: "float64",
    },
    dtypes.Float32: {
        "pyarrow": "Float32[pyarrow]",
        "numpy_nullable": "Float32",
        None: "float32",
    },
    dtypes.Int64: {"pyarrow": "Int64[pyarrow]", "numpy_nullable": "Int64", None: "int64"},
    dtypes.Int32: {"pyarrow": "Int32[pyarrow]", "numpy_nullable": "Int32", None: "int32"},
    dtypes.Int16: {"pyarrow": "Int16[pyarrow]", "numpy_nullable": "Int16", None: "int16"},
    dtypes.Int8: {"pyarrow": "Int8[pyarrow]", "numpy_nullable": "Int8", None: "int8"},
    dtypes.UInt64: {
        "pyarrow": "UInt64[pyarrow]",
        "numpy_nullable": "UInt64",
        None: "uint64",
    },
    dtypes.UInt32: {
        "pyarrow": "UInt32[pyarrow]",
        "numpy_nullable": "UInt32",
        None: "uint32",
    },
    dtypes.UInt16: {
        "pyarrow": "UInt16[pyarrow]",
        "numpy_nullable": "UInt16",
        None: "uint16",
    },
    dtypes.UInt8: {"pyarrow": "UInt8[pyarrow]", "numpy_nullable": "UInt8", None: "uint8"},
    dtypes.String: {"pyarrow": "string[pyarrow]", "numpy_nullable": "string", None: str},
    dtypes.Boolean: {
        "pyarrow": "boolean[pyarrow]",
        "numpy_nullable": "boolean",
        None: "bool",
    },
}
# Dtypes with no pandas-like representation at all.
UNSUPPORTED_DTYPES = (dtypes.Decimal,)
|
||||
|
||||
|
||||
def narwhals_to_native_dtype(  # noqa: C901, PLR0912
    dtype: IntoDType,
    dtype_backend: DTypeBackend,
    implementation: Implementation,
    version: Version,
) -> str | PandasDtype:
    """Convert a narwhals dtype into a native pandas-like dtype string/instance.

    Args:
        dtype: Narwhals dtype (instance or class) to convert.
        dtype_backend: One of `None`, `"pyarrow"`, `"numpy_nullable"` —
            matches pandas' `dtype_backend` argument in `convert_dtypes`.
        implementation: Target pandas-like backend.
        version: Narwhals API version (affects `Enum` handling).

    Returns:
        A dtype string (e.g. `"Int64[pyarrow]"`) or a native dtype instance
        (e.g. an ordered `CategoricalDtype` for `Enum`).

    Raises:
        ValueError: For an invalid `dtype_backend`, or an `Enum` without categories.
        NotImplementedError: For unsupported conversions (`Decimal`, `Enum`
            under stable.v1).
        ModuleNotFoundError: If `Date` is requested without pyarrow installed.
    """
    if dtype_backend not in {None, "pyarrow", "numpy_nullable"}:
        msg = f"Expected one of {{None, 'pyarrow', 'numpy_nullable'}}, got: '{dtype_backend}'"
        raise ValueError(msg)
    dtypes = version.dtypes
    base_type = dtype.base_type()
    # Backend-independent dtypes first, then the backend-parameterised tables.
    if pd_type := NW_TO_PD_DTYPES_INVARIANT.get(base_type):
        return pd_type
    if into_pd_type := NW_TO_PD_DTYPES_BACKEND.get(base_type):
        return into_pd_type[dtype_backend]
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        # Pandas does not support "ms" or "us" time units before version 2.0
        if is_pandas_or_modin(implementation) and PANDAS_VERSION < (
            2,
        ):  # pragma: no cover
            dt_time_unit = "ns"
        else:
            dt_time_unit = dtype.time_unit

        if dtype_backend == "pyarrow":
            tz_part = f", tz={tz}" if (tz := dtype.time_zone) else ""
            return f"timestamp[{dt_time_unit}{tz_part}][pyarrow]"
        tz_part = f", {tz}" if (tz := dtype.time_zone) else ""
        return f"datetime64[{dt_time_unit}{tz_part}]"
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        # Same pre-2.0 time-unit limitation as for Datetime above.
        if is_pandas_or_modin(implementation) and PANDAS_VERSION < (
            2,
        ):  # pragma: no cover
            du_time_unit = "ns"
        else:
            du_time_unit = dtype.time_unit
        return (
            f"duration[{du_time_unit}][pyarrow]"
            if dtype_backend == "pyarrow"
            else f"timedelta64[{du_time_unit}]"
        )
    if isinstance_or_issubclass(dtype, dtypes.Date):
        try:
            import pyarrow as pa  # ignore-banned-import # noqa: F401
        except ModuleNotFoundError as exc:  # pragma: no cover
            # Re-raised (chained) with a clearer installation hint.
            msg = "'pyarrow>=13.0.0' is required for `Date` dtype."
            raise ModuleNotFoundError(msg) from exc
        return "date32[pyarrow]"
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            # Only an Enum *instance* carries categories; build an ordered categorical.
            ns = implementation.to_native_namespace()
            return ns.CategoricalDtype(dtype.categories, ordered=True)
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if issubclass(
        base_type, (dtypes.Struct, dtypes.Array, dtypes.List, dtypes.Time, dtypes.Binary)
    ):
        # Nested / arrow-only dtypes delegate to the pyarrow-backed conversion.
        return narwhals_to_native_arrow_dtype(dtype, implementation, version)
    if issubclass(base_type, UNSUPPORTED_DTYPES):
        msg = f"Converting to {base_type.__name__} dtype is not supported for {implementation}."
        raise NotImplementedError(msg)
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)
|
||||
|
||||
|
||||
def narwhals_to_native_arrow_dtype(
    dtype: IntoDType, implementation: Implementation, version: Version
) -> pd.ArrowDtype:
    """Convert a narwhals dtype to a `pd.ArrowDtype`.

    Only available for pandas/modin with pandas>=2.2 and pyarrow installed.

    Raises:
        ImportError: If pyarrow cannot be imported.
        NotImplementedError: If the implementation/version combination cannot
            represent `dtype` as an Arrow-backed dtype.
    """
    if is_pandas_or_modin(implementation) and PANDAS_VERSION >= (2, 2):
        try:
            import pyarrow as pa  # ignore-banned-import # noqa: F401
        except ImportError as exc:  # pragma: no cover
            # Fix: message previously read "to to the following exception".
            msg = f"Unable to convert to {dtype} due to the following exception: {exc.msg}"
            raise ImportError(msg) from exc
        from narwhals._arrow.utils import narwhals_to_native_dtype as _to_arrow_dtype

        return pd.ArrowDtype(_to_arrow_dtype(dtype, version))
    msg = (  # pragma: no cover
        f"Converting to {dtype} dtype is not supported for implementation "
        f"{implementation} and version {version}."
    )
    raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def int_dtype_mapper(dtype: Any) -> str:
    """Return the 64-bit integer dtype name matching *dtype*'s backend family."""
    name = str(dtype)
    if "pyarrow" in name:
        return "Int64[pyarrow]"
    # Nullable pandas dtype names are capitalized (e.g. "Int64"); numpy's are not.
    if name != name.lower():  # pragma: no cover
        return "Int64"
    return "int64"
|
||||
|
||||
|
||||
# (current unit, target unit) -> (binary op, integer factor) for rescaling
# epoch values between time units; consumed by `calculate_timestamp_datetime`.
# Downscaling uses floor-division, upscaling uses multiplication.
_TIMESTAMP_DATETIME_OP_FACTOR: Mapping[
    tuple[UnitCurrent, UnitTarget], tuple[BinOpBroadcast, IntoRhs]
] = {
    ("ns", "us"): (operator.floordiv, 1_000),
    ("ns", "ms"): (operator.floordiv, 1_000_000),
    ("us", "ns"): (operator.mul, NS_PER_MICROSECOND),
    ("us", "ms"): (operator.floordiv, 1_000),
    ("ms", "ns"): (operator.mul, NS_PER_MILLISECOND),
    ("ms", "us"): (operator.mul, 1_000),
    ("s", "ns"): (operator.mul, NS_PER_SECOND),
    ("s", "us"): (operator.mul, US_PER_SECOND),
    ("s", "ms"): (operator.mul, MS_PER_SECOND),
}
|
||||
|
||||
|
||||
def calculate_timestamp_datetime(
    s: NativeSeriesT, current: TimeUnit, time_unit: TimeUnit
) -> NativeSeriesT:
    """Rescale integer epoch values in *s* from `current` to `time_unit`."""
    if current == time_unit:
        return s
    item = _TIMESTAMP_DATETIME_OP_FACTOR.get((current, time_unit))
    if item is None:  # pragma: no cover
        msg = (
            f"unexpected time unit {current}, please report an issue at "
            "https://github.com/narwhals-dev/narwhals"
        )
        raise AssertionError(msg)
    op, factor = item
    return op(s, factor)
|
||||
|
||||
|
||||
# Time unit -> number of that unit per second (1 for "s" itself); consumed by
# `calculate_timestamp_date`.
_TIMESTAMP_DATE_FACTOR: Mapping[TimeUnit, int] = {
    "ns": NS_PER_SECOND,
    "us": US_PER_SECOND,
    "ms": MS_PER_SECOND,
    "s": 1,
}
|
||||
|
||||
|
||||
def calculate_timestamp_date(s: NativeSeriesT, time_unit: TimeUnit) -> NativeSeriesT:
    """Convert *s* (days since epoch) to a timestamp expressed in `time_unit`."""
    seconds = s * SECONDS_PER_DAY
    return seconds * _TIMESTAMP_DATE_FACTOR[time_unit]
|
||||
|
||||
|
||||
def select_columns_by_name(
    df: NativeDataFrameT,
    column_names: list[str] | _1DArray,  # NOTE: Cannot be a tuple!
    implementation: Implementation,
) -> NativeDataFrameT | Any:
    """Select columns by name.

    Prefer this over `df.loc[:, column_names]` as it's
    generally more performant.
    """
    available = df.columns
    # Fast path: selecting every column in its current order is a no-op.
    if len(column_names) == df.shape[1] and (available == column_names).all():
        return df
    needs_loc = (available.dtype.kind == "b") or (
        implementation is Implementation.PANDAS
        and implementation._backend_version() < (1, 5)
    )
    if needs_loc:
        # See https://github.com/narwhals-dev/narwhals/issues/1349#issuecomment-2470118122
        # for why we need this
        if error := check_columns_exist(column_names, available=available.tolist()):
            raise error
        return df.loc[:, column_names]
    try:
        return df[column_names]
    except KeyError as e:
        # Re-raise with narwhals' own missing-column error when that's the cause.
        if error := check_columns_exist(column_names, available=available.tolist()):
            raise error from e
        raise
|
||||
|
||||
|
||||
def is_non_nullable_boolean(s: PandasLikeSeries) -> bool:
    """Return ``True`` when the native dtype is numpy's plain ``bool``."""
    # cuDF booleans are nullable but the native dtype is still 'bool'.
    if s._implementation not in {
        Implementation.PANDAS,
        Implementation.MODIN,
        Implementation.DASK,
    }:
        return False
    return s.native.dtype == "bool"
|
||||
|
||||
|
||||
def import_array_module(implementation: Implementation, /) -> ModuleType:
    """Returns numpy or cupy module depending on the given implementation."""
    if is_pandas_or_modin(implementation):
        import numpy as np

        return np
    if implementation is Implementation.CUDF:
        import cupy as cp  # ignore-banned-import # cuDF dependency.

        return cp
    msg = f"Expected pandas/modin/cudf, got: {implementation}"  # pragma: no cover
    raise AssertionError(msg)
|
||||
|
||||
|
||||
# Concrete series-namespace for the pandas-like backends; all behaviour is
# inherited from `EagerSeriesNamespace`.
class PandasLikeSeriesNamespace(EagerSeriesNamespace["PandasLikeSeries", Any]): ...
|
Reference in New Issue
Block a user