done
lib/python3.11/site-packages/narwhals/_arrow/dataframe.py (new file, 792 lines)
@@ -0,0 +1,792 @@
from __future__ import annotations

from collections.abc import Collection, Iterator, Mapping, Sequence
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._compliant import EagerDataFrame
from narwhals._expression_parsing import ExprKind
from narwhals._utils import (
    Implementation,
    Version,
    check_column_names_are_unique,
    convert_str_slice_to_int_slice,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
    scale_bytes,
    supports_arrow_c_stream,
    zip_strict,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ShapeError

if TYPE_CHECKING:
    from collections.abc import Iterable
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import pandas as pd
    import polars as pl
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.group_by import ArrowGroupBy
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._arrow.typing import (  # type: ignore[attr-defined]
        ChunkedArrayAny,
        Mask,
        Order,
    )
    from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
    from narwhals._spark_like.utils import SparkSession
    from narwhals._translate import IntoArrowTable
    from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dtypes import DType
    from narwhals.typing import (
        IntoSchema,
        JoinStrategy,
        SizedMultiIndexSelector,
        SizedMultiNameSelector,
        SizeUnit,
        UniqueKeepStrategy,
        _1DArray,
        _2DArray,
        _SliceIndex,
        _SliceName,
    )

JoinType: TypeAlias = Literal[
    "left semi",
    "right semi",
    "left anti",
    "right anti",
    "inner",
    "left outer",
    "right outer",
    "full outer",
]
PromoteOptions: TypeAlias = Literal["none", "default", "permissive"]


class ArrowDataFrame(
    EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "ChunkedArrayAny"]
):
    _implementation = Implementation.PYARROW

    def __init__(
        self,
        native_dataframe: pa.Table,
        *,
        version: Version,
        validate_column_names: bool,
        validate_backend_version: bool = False,
    ) -> None:
        if validate_column_names:
            check_column_names_are_unique(native_dataframe.column_names)
        if validate_backend_version:
            self._validate_backend_version()
        self._native_frame = native_dataframe
        self._version = version

    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
        backend_version = context._implementation._backend_version()
        if cls._is_native(data):
            native = data
        elif backend_version >= (14,) or isinstance(data, Collection):
            native = pa.table(data)
        elif supports_arrow_c_stream(data):  # pragma: no cover
            msg = f"'pyarrow>=14.0.0' is required for `from_arrow` for object of type {type(data).__name__!r}."
            raise ModuleNotFoundError(msg)
        else:  # pragma: no cover
            msg = f"`from_arrow` is not supported for object of type {type(data).__name__!r}."
            raise TypeError(msg)
        return cls.from_native(native, context=context)

    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _LimitedContext,
        schema: IntoSchema | None,
    ) -> Self:
        from narwhals.schema import Schema

        pa_schema = Schema(schema).to_arrow() if schema is not None else schema
        if pa_schema and not data:
            native = pa_schema.empty_table()
        else:
            native = pa.Table.from_pydict(data, schema=pa_schema)
        return cls.from_native(native, context=context)

    @staticmethod
    def _is_native(obj: pa.Table | Any) -> TypeIs[pa.Table]:
        return isinstance(obj, pa.Table)

    @classmethod
    def from_native(cls, data: pa.Table, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version, validate_column_names=True)

    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _LimitedContext,
        schema: IntoSchema | Sequence[str] | None,
    ) -> Self:
        from narwhals.schema import Schema

        arrays = [pa.array(val) for val in data.T]
        if isinstance(schema, (Mapping, Schema)):
            native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow())
        else:
            native = pa.Table.from_arrays(arrays, cls._numpy_column_names(data, schema))
        return cls.from_native(native, context=context)

    def __narwhals_namespace__(self) -> ArrowNamespace:
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(version=self._version)

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.PYARROW:
            return self._implementation.to_native_namespace()

        msg = f"Expected pyarrow, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_dataframe__(self) -> Self:
        return self

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version, validate_column_names=False)

    def _with_native(self, df: pa.Table, *, validate_column_names: bool = True) -> Self:
        return self.__class__(
            df, version=self._version, validate_column_names=validate_column_names
        )

    @property
    def shape(self) -> tuple[int, int]:
        return self.native.shape

    def __len__(self) -> int:
        return len(self.native)

    def row(self, index: int) -> tuple[Any, ...]:
        return tuple(col[index] for col in self.native.itercolumns())

    @overload
    def rows(self, *, named: Literal[True]) -> list[dict[str, Any]]: ...

    @overload
    def rows(self, *, named: Literal[False]) -> list[tuple[Any, ...]]: ...

    @overload
    def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ...

    def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
        if not named:
            return list(self.iter_rows(named=False, buffer_size=512))  # type: ignore[return-value]
        return self.native.to_pylist()

    def iter_columns(self) -> Iterator[ArrowSeries]:
        for name, series in zip_strict(self.columns, self.native.itercolumns()):
            yield ArrowSeries.from_native(series, context=self, name=name)

    _iter_columns = iter_columns

    def iter_rows(
        self, *, named: bool, buffer_size: int
    ) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
        df = self.native
        num_rows = df.num_rows

        if not named:
            for i in range(0, num_rows, buffer_size):
                rows = df[i : i + buffer_size].to_pydict().values()
                yield from zip_strict(*rows)
        else:
            for i in range(0, num_rows, buffer_size):
                yield from df[i : i + buffer_size].to_pylist()

    def get_column(self, name: str) -> ArrowSeries:
        if not isinstance(name, str):
            msg = f"Expected str, got: {type(name)}"
            raise TypeError(msg)
        return ArrowSeries.from_native(self.native[name], context=self, name=name)

    def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray:
        return self.native.__array__(dtype, copy=copy)

    def _gather(self, rows: SizedMultiIndexSelector[ChunkedArrayAny]) -> Self:
        if len(rows) == 0:
            return self._with_native(self.native.slice(0, 0))
        if self._backend_version < (18,) and isinstance(rows, tuple):
            rows = list(rows)
        return self._with_native(self.native.take(rows))

    def _gather_slice(self, rows: _SliceIndex | range) -> Self:
        start = rows.start or 0
        stop = rows.stop if rows.stop is not None else len(self.native)
        if start < 0:
            start = len(self.native) + start
        if stop < 0:
            stop = len(self.native) + stop
        if rows.step is not None and rows.step != 1:
            msg = "Slicing with step is not supported on PyArrow tables"
            raise NotImplementedError(msg)
        return self._with_native(self.native.slice(start, stop - start))

    def _select_slice_name(self, columns: _SliceName) -> Self:
        start, stop, step = convert_str_slice_to_int_slice(columns, self.columns)
        return self._with_native(self.native.select(self.columns[start:stop:step]))

    def _select_slice_index(self, columns: _SliceIndex | range) -> Self:
        return self._with_native(
            self.native.select(self.columns[columns.start : columns.stop : columns.step])
        )

    def _select_multi_index(
        self, columns: SizedMultiIndexSelector[ChunkedArrayAny]
    ) -> Self:
        selector: Sequence[int]
        if isinstance(columns, pa.ChunkedArray):
            # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
            selector = cast("Sequence[int]", columns.to_pylist())
        # TODO @dangotbanned: Fix upstream, it is actually much narrower
        # **Doesn't accept `ndarray`**
        elif is_numpy_array_1d(columns):
            selector = columns.tolist()
        else:
            selector = columns
        return self._with_native(self.native.select(selector))

    def _select_multi_name(
        self, columns: SizedMultiNameSelector[ChunkedArrayAny]
    ) -> Self:
        selector: Sequence[str] | _1DArray
        if isinstance(columns, pa.ChunkedArray):
            # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
            selector = cast("Sequence[str]", columns.to_pylist())
        else:
            selector = columns
        # NOTE: Fixed in https://github.com/zen-xu/pyarrow-stubs/pull/221
        return self._with_native(self.native.select(selector))  # pyright: ignore[reportArgumentType]

    @property
    def schema(self) -> dict[str, DType]:
        return {
            field.name: native_to_narwhals_dtype(field.type, self._version)
            for field in self.native.schema
        }

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def estimated_size(self, unit: SizeUnit) -> int | float:
        sz = self.native.nbytes
        return scale_bytes(sz, unit)

    explode = not_implemented()

    @property
    def columns(self) -> list[str]:
        return self.native.column_names

    def simple_select(self, *column_names: str) -> Self:
        return self._with_native(
            self.native.select(list(column_names)), validate_column_names=False
        )

    def select(self, *exprs: ArrowExpr) -> Self:
        new_series = self._evaluate_into_exprs(*exprs)
        if not new_series:
            # return empty dataframe, like Polars does
            return self._with_native(
                self.native.__class__.from_arrays([]), validate_column_names=False
            )
        names = [s.name for s in new_series]
        align = new_series[0]._align_full_broadcast
        reshaped = align(*new_series)
        df = pa.Table.from_arrays([s.native for s in reshaped], names=names)
        return self._with_native(df, validate_column_names=True)

    def _extract_comparand(self, other: ArrowSeries) -> ChunkedArrayAny:
        length = len(self)
        if not other._broadcast:
            if (len_other := len(other)) != length:
                msg = f"Expected object of length {length}, got: {len_other}."
                raise ShapeError(msg)
            return other.native

        value = other.native[0]
        return pa.chunked_array([pa.repeat(value, length)])

    def with_columns(self, *exprs: ArrowExpr) -> Self:
        # NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame)
        # All `pyarrow` data is immutable, so this is fine
        native_frame = self.native
        new_columns = self._evaluate_into_exprs(*exprs)
        columns = self.columns

        for col_value in new_columns:
            col_name = col_value.name
            column = self._extract_comparand(col_value)
            native_frame = (
                native_frame.set_column(columns.index(col_name), col_name, column=column)
                if col_name in columns
                else native_frame.append_column(col_name, column=column)
            )

        return self._with_native(native_frame, validate_column_names=False)

    def group_by(
        self, keys: Sequence[str] | Sequence[ArrowExpr], *, drop_null_keys: bool
    ) -> ArrowGroupBy:
        from narwhals._arrow.group_by import ArrowGroupBy

        return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        how_to_join_map: dict[str, JoinType] = {
            "anti": "left anti",
            "semi": "left semi",
            "inner": "inner",
            "left": "left outer",
            "full": "full outer",
        }

        if how == "cross":
            plx = self.__narwhals_namespace__()
            key_token = generate_temporary_column_name(
                n_bytes=8, columns=[*self.columns, *other.columns]
            )

            return self._with_native(
                self.with_columns(
                    plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
                )
                .native.join(
                    other.with_columns(
                        plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
                    ).native,
                    keys=key_token,
                    right_keys=key_token,
                    join_type="inner",
                    right_suffix=suffix,
                )
                .drop([key_token])
            )

        coalesce_keys = how != "full"  # polars full join does not coalesce keys
        return self._with_native(
            self.native.join(
                other.native,
                keys=left_on or [],  # type: ignore[arg-type]
                right_keys=right_on,  # type: ignore[arg-type]
                join_type=how_to_join_map[how],
                right_suffix=suffix,
                coalesce_keys=coalesce_keys,
            )
        )

    join_asof = not_implemented()

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        to_drop = parse_columns_to_drop(self, columns, strict=strict)
        return self._with_native(self.native.drop(to_drop), validate_column_names=False)

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        if subset is None:
            return self._with_native(self.native.drop_null(), validate_column_names=False)
        plx = self.__narwhals_namespace__()
        mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
        return self.filter(mask)

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            order: Order = "descending" if descending else "ascending"
            sorting: list[tuple[str, Order]] = [(key, order) for key in by]
        else:
            sorting = [
                (key, "descending" if is_descending else "ascending")
                for key, is_descending in zip_strict(by, descending)
            ]

        null_placement = "at_end" if nulls_last else "at_start"

        return self._with_native(
            self.native.sort_by(sorting, null_placement=null_placement),
            validate_column_names=False,
        )

    def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
        if isinstance(reverse, bool):
            order: Order = "ascending" if reverse else "descending"
            sorting: list[tuple[str, Order]] = [(key, order) for key in by]
        else:
            sorting = [
                (key, "ascending" if is_ascending else "descending")
                for key, is_ascending in zip_strict(by, reverse)
            ]
        return self._with_native(
            self.native.take(pc.select_k_unstable(self.native, k, sorting)),  # type: ignore[call-overload]
            validate_column_names=False,
        )

    def to_pandas(self) -> pd.DataFrame:
        return self.native.to_pandas()

    def to_polars(self) -> pl.DataFrame:
        import polars as pl  # ignore-banned-import

        return pl.from_arrow(self.native)  # type: ignore[return-value]

    def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
        import numpy as np  # ignore-banned-import

        arr: Any = np.column_stack([col.to_numpy() for col in self.native.columns])
        return arr

    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, ArrowSeries]: ...

    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...

    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, ArrowSeries] | dict[str, list[Any]]:
        it = self.iter_columns()
        if as_series:
            return {ser.name: ser for ser in it}
        return {ser.name: ser.to_list() for ser in it}

    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        plx = self.__narwhals_namespace__()
        if order_by is None:
            import numpy as np  # ignore-banned-import

            data = pa.array(np.arange(len(self), dtype=np.int64))
            row_index = plx._expr._from_series(
                plx._series.from_iterable(data, context=self, name=name)
            )
        else:
            rank = plx.col(order_by[0]).rank("ordinal", descending=False)
            row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
        return self.select(row_index, plx.all())

    def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self:
        if isinstance(predicate, list):
            mask_native: Mask | ChunkedArrayAny = predicate
        else:
            # `[0]` is safe as the predicate's expression only returns a single column
            mask_native = self._evaluate_into_exprs(predicate)[0].native
        return self._with_native(
            self.native.filter(mask_native), validate_column_names=False
        )

    def head(self, n: int) -> Self:
        df = self.native
        if n >= 0:
            return self._with_native(df.slice(0, n), validate_column_names=False)
        num_rows = df.num_rows
        return self._with_native(
            df.slice(0, max(0, num_rows + n)), validate_column_names=False
        )

    def tail(self, n: int) -> Self:
        df = self.native
        if n >= 0:
            num_rows = df.num_rows
            return self._with_native(
                df.slice(max(0, num_rows - n)), validate_column_names=False
            )
        return self._with_native(df.slice(abs(n)), validate_column_names=False)

    def lazy(
        self,
        backend: _LazyAllowedImpl | None = None,
        *,
        session: SparkSession | None = None,
    ) -> CompliantLazyFrameAny:
        if backend is None:
            return self
        if backend is Implementation.DUCKDB:
            import duckdb  # ignore-banned-import

            from narwhals._duckdb.dataframe import DuckDBLazyFrame

            _df = self.native
            return DuckDBLazyFrame(
                duckdb.table("_df"), validate_backend_version=True, version=self._version
            )
        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsLazyFrame

            return PolarsLazyFrame(
                cast("pl.DataFrame", pl.from_arrow(self.native)).lazy(),
                validate_backend_version=True,
                version=self._version,
            )
        if backend is Implementation.DASK:
            import dask.dataframe as dd  # ignore-banned-import

            from narwhals._dask.dataframe import DaskLazyFrame

            return DaskLazyFrame(
                dd.from_pandas(self.native.to_pandas()),
                validate_backend_version=True,
                version=self._version,
            )
        if backend is Implementation.IBIS:
            import ibis  # ignore-banned-import

            from narwhals._ibis.dataframe import IbisLazyFrame

            return IbisLazyFrame(
                ibis.memtable(self.native, columns=self.columns),
                validate_backend_version=True,
                version=self._version,
            )

        if backend.is_spark_like():
            from narwhals._spark_like.dataframe import SparkLikeLazyFrame

            if session is None:
                msg = "Spark like backends require `session` to be not None."
                raise ValueError(msg)

            return SparkLikeLazyFrame._from_compliant_dataframe(
                self, session=session, implementation=backend, version=self._version
            )

        raise AssertionError  # pragma: no cover

    def collect(
        self, backend: _EagerAllowedImpl | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        if backend is Implementation.PYARROW or backend is None:
            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                self.native, version=self._version, validate_column_names=False
            )

        if backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                self.native.to_pandas(),
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=False,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                cast("pl.DataFrame", pl.from_arrow(self.native)),
                validate_backend_version=True,
                version=self._version,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise AssertionError(msg)  # pragma: no cover

    def clone(self) -> Self:
        return self._with_native(self.native, validate_column_names=False)

    def item(self, row: int | None, column: int | str | None) -> Any:
        from narwhals._arrow.series import maybe_extract_py_scalar

        if row is None and column is None:
            if self.shape != (1, 1):
                msg = (
                    "can only call `.item()` if the dataframe is of shape (1, 1),"
                    " or if explicit row/col values are provided;"
                    f" frame has shape {self.shape!r}"
                )
                raise ValueError(msg)
            return maybe_extract_py_scalar(self.native[0][0], return_py_scalar=True)

        if row is None or column is None:
            msg = "cannot call `.item()` with only one of `row` or `column`"
            raise ValueError(msg)

        _col = self.columns.index(column) if isinstance(column, str) else column
        return maybe_extract_py_scalar(self.native[_col][row], return_py_scalar=True)

    def rename(self, mapping: Mapping[str, str]) -> Self:
        names: dict[str, str] | list[str]
        if self._backend_version >= (17,):
            names = cast("dict[str, str]", mapping)
        else:  # pragma: no cover
            names = [mapping.get(c, c) for c in self.columns]
        return self._with_native(self.native.rename_columns(names))

    def write_parquet(self, file: str | Path | BytesIO) -> None:
        import pyarrow.parquet as pp

        pp.write_table(self.native, file)

    @overload
    def write_csv(self, file: None) -> str: ...

    @overload
    def write_csv(self, file: str | Path | BytesIO) -> None: ...

    def write_csv(self, file: str | Path | BytesIO | None) -> str | None:
        import pyarrow.csv as pa_csv

        if file is None:
            csv_buffer = pa.BufferOutputStream()
            pa_csv.write_csv(self.native, csv_buffer)
            return csv_buffer.getvalue().to_pybytes().decode()
        pa_csv.write_csv(self.native, file)
        return None

    def is_unique(self) -> ArrowSeries:
        import numpy as np  # ignore-banned-import

        col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
        row_index = pa.array(np.arange(len(self)))
        keep_idx = (
            self.native.append_column(col_token, row_index)
            .group_by(self.columns)
            .aggregate([(col_token, "min"), (col_token, "max")])
        )
        native = pa.chunked_array(
            pc.and_(
                pc.is_in(row_index, keep_idx[f"{col_token}_min"]),
                pc.is_in(row_index, keep_idx[f"{col_token}_max"]),
            )
        )
        return ArrowSeries.from_native(native, context=self)

    def unique(
        self,
        subset: Sequence[str] | None,
        *,
        keep: UniqueKeepStrategy,
        maintain_order: bool | None = None,
    ) -> Self:
        # The param `maintain_order` is only here for compatibility with the Polars API
        # and has no effect on the output.
        import numpy as np  # ignore-banned-import

        if subset and (error := self._check_columns_exist(subset)):
            raise error
        subset = list(subset or self.columns)

        if keep in {"any", "first", "last"}:
            from narwhals._arrow.group_by import ArrowGroupBy

            agg_func = ArrowGroupBy._REMAP_UNIQUE[keep]
            col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
            keep_idx_native = (
                self.native.append_column(col_token, pa.array(np.arange(len(self))))
                .group_by(subset)
                .aggregate([(col_token, agg_func)])
                .column(f"{col_token}_{agg_func}")
            )
            return self._with_native(
                self.native.take(keep_idx_native), validate_column_names=False
            )

        keep_idx = self.simple_select(*subset).is_unique()
        plx = self.__narwhals_namespace__()
        return self.filter(plx._expr._from_series(keep_idx))

    def gather_every(self, n: int, offset: int) -> Self:
        return self._with_native(self.native[offset::n], validate_column_names=False)

    def to_arrow(self) -> pa.Table:
        return self.native

    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self:
        import numpy as np  # ignore-banned-import

        num_rows = len(self)
        if n is None and fraction is not None:
            n = int(num_rows * fraction)
        rng = np.random.default_rng(seed=seed)
        idx = np.arange(num_rows)
        mask = rng.choice(idx, size=n, replace=with_replacement)
        return self._with_native(self.native.take(mask), validate_column_names=False)

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        n_rows = len(self)
        index_ = [] if index is None else index
        on_ = [c for c in self.columns if c not in index_] if on is None else on
        concat = (
            partial(pa.concat_tables, promote_options="permissive")
            if self._backend_version >= (14, 0, 0)
            else pa.concat_tables
        )
        names = [*index_, variable_name, value_name]
        return self._with_native(
            concat(
                [
                    pa.Table.from_arrays(
                        [
                            *(self.native.column(idx_col) for idx_col in index_),
                            cast(
                                "ChunkedArrayAny",
                                pa.array([on_col] * n_rows, pa.string()),
                            ),
                            self.native.column(on_col),
                        ],
                        names=names,
                    )
                    for on_col in on_
                ]
            )
        )
        # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
        # upcast numeric to non-numeric (e.g. string) datatypes

    pivot = not_implemented()
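
Usage sketch, not part of this commit: `ArrowDataFrame` is the compliant-level wrapper that the public narwhals API dispatches to for `pyarrow.Table` inputs, so the methods above back calls like the following (assuming only the public `narwhals` entry points shown here):

import pyarrow as pa
import narwhals as nw

tbl = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
df = nw.from_native(tbl, eager_only=True)  # wraps an ArrowDataFrame internally

print(df.shape)  # (3, 2) -- ArrowDataFrame.shape delegates to pa.Table.shape
print(df.sort("a", descending=True).rows())  # [(3, 'z'), (2, 'y'), (1, 'x')]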
lib/python3.11/site-packages/narwhals/_arrow/expr.py (new file, 170 lines)
@@ -0,0 +1,170 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._compliant import EagerExpr
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    not_implemented,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import Self

    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
    from narwhals._expression_parsing import ExprMetadata
    from narwhals._utils import Version, _LimitedContext


class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
    _implementation: Implementation = Implementation.PYARROW

    def __init__(
        self,
        call: EvalSeries[ArrowDataFrame, ArrowSeries],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[ArrowDataFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        scalar_kwargs: ScalarKwargs | None = None,
        implementation: Implementation | None = None,
    ) -> None:
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._scalar_kwargs = scalar_kwargs or {}
        self._metadata: ExprMetadata | None = None

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[ArrowDataFrame],
        /,
        *,
        context: _LimitedContext,
        function_name: str = "",
    ) -> Self:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            try:
                return [
                    ArrowSeries(
                        df.native[column_name], name=column_name, version=df._version
                    )
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            tbl = df.native
            cols = df.columns
            return [
                ArrowSeries.from_native(tbl[i], name=cols[i], context=df)
                for i in column_indices
            ]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    def __narwhals_namespace__(self) -> ArrowNamespace:
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(version=self._version)

    def _reuse_series_extra_kwargs(
        self, *, returns_scalar: bool = False
    ) -> dict[str, Any]:
        return {"_return_py_scalar": False} if returns_scalar else {}

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        if (
            partition_by
            and self._metadata is not None
            and not self._metadata.is_scalar_like
        ):
            msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
            raise NotImplementedError(msg)

        if not partition_by:
            # e.g. `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            assert order_by  # noqa: S101

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                token = generate_temporary_column_name(8, df.columns)
                df = df.with_row_index(token, order_by=None).sort(
                    *order_by, descending=False, nulls_last=False
                )
                result = self(df.drop([token], strict=True))
                # TODO(marco): is there a way to do this efficiently without
                # doing 2 sorts? Here we're sorting the dataframe and then
                # again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
                sorting_indices = pc.sort_indices(df.get_column(token).native)
                return [s._with_native(s.native.take(sorting_indices)) for s in result]
        else:

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
                if overlap := set(output_names).intersection(partition_by):
                    # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
                    # we just don't support it yet.
                    msg = (
                        f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
                        "This is not yet supported."
                    )
                    raise NotImplementedError(msg)

                tmp = df.group_by(partition_by, drop_null_keys=False).agg(self)
                tmp = df.simple_select(*partition_by).join(
                    tmp,
                    how="left",
                    left_on=partition_by,
                    right_on=partition_by,
                    suffix="_right",
                )
                return [tmp.get_column(alias) for alias in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    ewm_mean = not_implemented()
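
The `partition_by`-empty branch of `over` above works by sorting, evaluating, then inverting the sort through a saved row index. A standalone sketch of the same trick in plain pyarrow, not part of this commit (`pc.cumulative_sum` stands in for the wrapped expression):

import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"key": [3, 1, 2], "x": [10, 20, 30]})
tbl = tbl.append_column("_idx", pa.array(range(tbl.num_rows)))  # remember positions
tbl = tbl.sort_by([("key", "ascending")])  # evaluate in `order_by` order
cum = pc.cumulative_sum(tbl["x"])          # stand-in for the wrapped expression
restore = pc.sort_indices(tbl["_idx"])     # indices that undo the sort
print(cum.take(restore).to_pylist())       # [60, 20, 50] -- original row order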
lib/python3.11/site-packages/narwhals/_arrow/group_by.py (new file, 159 lines)
@@ -0,0 +1,159 @@
from __future__ import annotations

import collections
from typing import TYPE_CHECKING, Any, ClassVar

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar
from narwhals._compliant import EagerGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import generate_temporary_column_name

if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping, Sequence

    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.typing import (  # type: ignore[attr-defined]
        AggregateOptions,
        Aggregation,
        Incomplete,
    )
    from narwhals._compliant.typing import NarwhalsAggregation
    from narwhals.typing import UniqueKeepStrategy


class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "approximate_median",
        "max": "max",
        "min": "min",
        "std": "stddev",
        "var": "variance",
        "len": "count",
        "n_unique": "count_distinct",
        "count": "count",
        "all": "all",
        "any": "any",
    }
    _REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
        "any": "min",
        "first": "min",
        "last": "max",
    }

    def __init__(
        self,
        df: ArrowDataFrame,
        keys: Sequence[ArrowExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        self._df = df
        frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
        self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
        self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
        self._drop_null_keys = drop_null_keys

    def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
        self._ensure_all_simple(exprs)
        aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = []
        expected_pyarrow_column_names: list[str] = self._keys.copy()
        new_column_names: list[str] = self._keys.copy()
        exclude = (*self._keys, *self._output_key_names)

        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, exclude
            )

            if expr._depth == 0:
                # e.g. `agg(nw.len())`
                if expr._function_name != "len":  # pragma: no cover
                    msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
                    raise AssertionError(msg)

                new_column_names.append(aliases[0])
                expected_pyarrow_column_names.append(f"{self._keys[0]}_count")
                aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))
                continue

            function_name = self._leaf_name(expr)
            if function_name in {"std", "var"}:
                assert "ddof" in expr._scalar_kwargs  # noqa: S101
                option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"])
            elif function_name in {"len", "n_unique"}:
                option = pc.CountOptions(mode="all")
            elif function_name == "count":
                option = pc.CountOptions(mode="only_valid")
            elif function_name in {"all", "any"}:
                option = pc.ScalarAggregateOptions(min_count=0)
            else:
                option = None

            function_name = self._remap_expr_name(function_name)
            new_column_names.extend(aliases)
            expected_pyarrow_column_names.extend(
                [f"{output_name}_{function_name}" for output_name in output_names]
            )
            aggs.extend(
                [(output_name, function_name, option) for output_name in output_names]
            )

        result_simple = self._grouped.aggregate(aggs)

        # Rename columns, being very careful
        expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list)
        for idx, item in enumerate(expected_pyarrow_column_names):
            expected_old_names_indices[item].append(idx)
        if not (
            set(result_simple.column_names) == set(expected_pyarrow_column_names)
            and len(result_simple.column_names) == len(expected_pyarrow_column_names)
        ):  # pragma: no cover
            msg = (
                f"Safety assertion failed, expected {expected_pyarrow_column_names} "
                f"got {result_simple.column_names}, "
                "please report a bug at https://github.com/narwhals-dev/narwhals/issues"
            )
            raise AssertionError(msg)
        index_map: list[int] = [
            expected_old_names_indices[item].pop(0) for item in result_simple.column_names
        ]
        new_column_names = [new_column_names[i] for i in index_map]
        result_simple = result_simple.rename_columns(new_column_names)
        return self.compliant._with_native(result_simple).rename(
            dict(zip(self._keys, self._output_key_names))
        )

    def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
        col_token = generate_temporary_column_name(
            n_bytes=8, columns=self.compliant.columns
        )
        null_token: str = "__null_token_value__"  # noqa: S105

        table = self.compliant.native
        it, separator_scalar = cast_to_comparable_string_types(
            *(table[key] for key in self._keys), separator=""
        )
        # NOTE: stubs indicate `separator` must also be a `ChunkedArray`
        # Reality: `str` is fine
        concat_str: Incomplete = pc.binary_join_element_wise
        key_values = concat_str(
            *it, separator_scalar, null_handling="replace", null_replacement=null_token
        )
        table = table.add_column(i=0, field_=col_token, column=key_values)

        for v in pc.unique(key_values):
            t = self.compliant._with_native(
                table.filter(pc.equal(table[col_token], v)).drop([col_token])
            )
            row = t.simple_select(*self._keys).row(0)
            yield (
                tuple(extract_py_scalar(el) for el in row),
                t.simple_select(*self._df.columns),
            )
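
For reference, the `{column}_{aggregation}` naming that `agg` defends with its safety assertion comes straight from `pa.TableGroupBy.aggregate`; a minimal demonstration, not part of this commit:

import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"g": ["a", "a", "b"], "x": [1, 2, 3]})
out = tbl.group_by("g").aggregate(
    [("x", "sum"), ("x", "count", pc.CountOptions(mode="all"))]
)
print(out.column_names)  # ['x_sum', 'x_count', 'g'] -- group keys come last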
lib/python3.11/site-packages/narwhals/_arrow/namespace.py (new file, 303 lines)
@@ -0,0 +1,303 @@
from __future__ import annotations

import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Literal

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.selectors import ArrowSelectorNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import cast_to_comparable_string_types
from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen
from narwhals._expression_parsing import (
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence

    from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete
    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._utils import Version
    from narwhals.typing import IntoDType, NonNestedLiteral


class ArrowNamespace(
    EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table, "ChunkedArrayAny"]
):
    _implementation = Implementation.PYARROW

    @property
    def _dataframe(self) -> type[ArrowDataFrame]:
        return ArrowDataFrame

    @property
    def _expr(self) -> type[ArrowExpr]:
        return ArrowExpr

    @property
    def _series(self) -> type[ArrowSeries]:
        return ArrowSeries

    def __init__(self, *, version: Version) -> None:
        self._version = version

    def len(self) -> ArrowExpr:
        # coverage bug? this is definitely hit
        return self._expr(  # pragma: no cover
            lambda df: [
                ArrowSeries.from_iterable([len(df.native)], name="len", context=self)
            ],
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            version=self._version,
        )

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> ArrowExpr:
        def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
            arrow_series = ArrowSeries.from_iterable(
                data=[value], name="literal", context=self
            )
            if dtype:
                return arrow_series.cast(dtype)
            return arrow_series

        return self._expr(
            lambda df: [_lit_arrow_series(df)],
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            version=self._version,
        )

    def all_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series: Iterator[ArrowSeries] = chain.from_iterable(e(df) for e in exprs)
            align = self._series._align_full_broadcast
            if ignore_nulls:
                series = (s.fill_null(True, None, None) for s in series)
            return [reduce(operator.and_, align(*series))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="all_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def any_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series: Iterator[ArrowSeries] = chain.from_iterable(e(df) for e in exprs)
            align = self._series._align_full_broadcast
            if ignore_nulls:
                series = (s.fill_null(False, None, None) for s in series)
            return [reduce(operator.or_, align(*series))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="any_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            it = chain.from_iterable(expr(df) for expr in exprs)
            series = (s.fill_null(0, strategy=None, limit=None) for s in it)
            align = self._series._align_full_broadcast
            return [reduce(operator.add, align(*series))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="sum_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        int_64 = self._version.dtypes.Int64()

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            expr_results = list(chain.from_iterable(expr(df) for expr in exprs))
            align = self._series._align_full_broadcast
            series = align(
                *(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
            )
            non_na = align(*(1 - s.is_null().cast(int_64) for s in expr_results))
            return [reduce(operator.add, series) / reduce(operator.add, non_na)]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="mean_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
            init_series, *series = align(init_series, *series)
            native_series = reduce(
                pc.min_element_wise, [s.native for s in series], init_series.native
            )
            return [
                ArrowSeries(native_series, name=init_series.name, version=self._version)
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="min_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
            init_series, *series = align(init_series, *series)
            native_series = reduce(
                pc.max_element_wise, [s.native for s in series], init_series.native
            )
            return [
                ArrowSeries(native_series, name=init_series.name, version=self._version)
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="max_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        if self._backend_version >= (14,):
            return pa.concat_tables(dfs, promote_options="default")
        return pa.concat_tables(dfs, promote=True)  # pragma: no cover

    def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        names = list(chain.from_iterable(df.column_names for df in dfs))
        arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
        return pa.Table.from_arrays(arrays, names=names)

    def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        cols_0 = dfs[0].column_names
        for i, df in enumerate(dfs[1:], start=1):
            cols_current = df.column_names
            if cols_current != cols_0:
                msg = (
                    "unable to vstack, column names don't match:\n"
                    f" - dataframe 0: {cols_0}\n"
                    f" - dataframe {i}: {cols_current}\n"
                )
                raise TypeError(msg)
        return pa.concat_tables(dfs)

    @property
    def selectors(self) -> ArrowSelectorNamespace:
        return ArrowSelectorNamespace.from_namespace(self)

    def when(self, predicate: ArrowExpr) -> ArrowWhen:
        return ArrowWhen.from_expr(predicate, context=self)

    def concat_str(
        self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool
    ) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            compliant_series_list = align(
                *(chain.from_iterable(expr(df) for expr in exprs))
            )
            name = compliant_series_list[0].name
            null_handling: Literal["skip", "emit_null"] = (
                "skip" if ignore_nulls else "emit_null"
            )
            it, separator_scalar = cast_to_comparable_string_types(
                *(s.native for s in compliant_series_list), separator=separator
            )
            # NOTE: stubs indicate `separator` must also be a `ChunkedArray`
            # Reality: `str` is fine
            concat_str: Incomplete = pc.binary_join_element_wise
            compliant = self._series(
                concat_str(*it, separator_scalar, null_handling=null_handling),
                name=name,
                version=self._version,
            )
            return [compliant]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="concat_str",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def coalesce(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = align(*chain.from_iterable(expr(df) for expr in exprs))
            return [
                ArrowSeries(
                    pc.coalesce(init_series.native, *(s.native for s in series)),
                    name=init_series.name,
                    version=self._version,
                )
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="coalesce",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )


class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]):
    @property
    def _then(self) -> type[ArrowThen]:
        return ArrowThen

    def _if_then_else(
        self,
        when: ChunkedArrayAny,
        then: ChunkedArrayAny,
        otherwise: ArrayOrScalar | NonNestedLiteral,
        /,
    ) -> ChunkedArrayAny:
        otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise
        return pc.if_else(when, then, otherwise)


class ArrowThen(
    CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr, ArrowWhen], ArrowExpr
):
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "whenthen"
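
`ArrowWhen._if_then_else` above collapses when/then/otherwise into a single `pc.if_else` kernel call, synthesizing a null column of the `then` type when `otherwise` is absent; a minimal demonstration, not part of this commit:

import pyarrow as pa
import pyarrow.compute as pc

cond = pa.chunked_array([[True, False, None]])
then = pa.chunked_array([[1, 2, 3]])
otherwise = pa.nulls(len(cond), then.type)  # what the `otherwise is None` branch builds
print(pc.if_else(cond, then, otherwise).to_pylist())  # [1, None, None]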
lib/python3.11/site-packages/narwhals/_arrow/selectors.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._arrow.expr import ArrowExpr
from narwhals._compliant import CompliantSelector, EagerSelectorNamespace

if TYPE_CHECKING:
    from narwhals._arrow.dataframe import ArrowDataFrame  # noqa: F401
    from narwhals._arrow.series import ArrowSeries  # noqa: F401
    from narwhals._compliant.typing import ScalarKwargs


class ArrowSelectorNamespace(EagerSelectorNamespace["ArrowDataFrame", "ArrowSeries"]):
    @property
    def _selector(self) -> type[ArrowSelector]:
        return ArrowSelector


class ArrowSelector(CompliantSelector["ArrowDataFrame", "ArrowSeries"], ArrowExpr):  # type: ignore[misc]
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "selector"

    def _to_expr(self) -> ArrowExpr:
        return ArrowExpr(
            self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )
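
`ArrowSelector` is what the public selectors resolve to on this backend; a usage sketch, not part of this commit (assumes the public `narwhals.selectors` module):

import pyarrow as pa
import narwhals as nw
import narwhals.selectors as ncs

df = nw.from_native(pa.table({"a": [1], "b": ["x"]}), eager_only=True)
print(df.select(ncs.numeric()).columns)  # ['a'] -- resolved via ArrowSelector._to_expr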
lib/python3.11/site-packages/narwhals/_arrow/series.py (new file, 1191 lines)
File diff suppressed because it is too large
lib/python3.11/site-packages/narwhals/_arrow/series_cat.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pyarrow as pa

from narwhals._arrow.utils import ArrowSeriesNamespace

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import Incomplete


class ArrowSeriesCatNamespace(ArrowSeriesNamespace):
    def get_categories(self) -> ArrowSeries:
        # NOTE: Should be `list[pa.DictionaryArray]`, but `DictionaryArray` has no attributes
        chunks: Incomplete = self.native.chunks
        return self.with_native(pa.concat_arrays(x.dictionary for x in chunks).unique())
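
`get_categories` above concatenates each chunk's dictionary and de-duplicates; the same steps in plain pyarrow, not part of this commit:

import pyarrow as pa

arr = pa.chunked_array([pa.array(["a", "b", "a"]).dictionary_encode()])
cats = pa.concat_arrays(chunk.dictionary for chunk in arr.chunks).unique()
print(cats.to_pylist())  # ['a', 'b']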
lib/python3.11/site-packages/narwhals/_arrow/series_dt.py (new file, 226 lines)
@@ -0,0 +1,226 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import UNITS_DICT, ArrowSeriesNamespace, floordiv_compat, lit
from narwhals._constants import (
    MS_PER_MINUTE,
    MS_PER_SECOND,
    NS_PER_MICROSECOND,
    NS_PER_MILLISECOND,
    NS_PER_MINUTE,
    NS_PER_SECOND,
    SECONDS_PER_DAY,
    SECONDS_PER_MINUTE,
    US_PER_MINUTE,
    US_PER_SECOND,
)
from narwhals._duration import Interval

if TYPE_CHECKING:
    from collections.abc import Mapping

    from typing_extensions import TypeAlias

    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny
    from narwhals.dtypes import Datetime
    from narwhals.typing import TimeUnit

UnitCurrent: TypeAlias = TimeUnit
UnitTarget: TypeAlias = TimeUnit
BinOpBroadcast: TypeAlias = Callable[[ChunkedArrayAny, ScalarAny], ChunkedArrayAny]
IntoRhs: TypeAlias = int


class ArrowSeriesDateTimeNamespace(ArrowSeriesNamespace):
    _TIMESTAMP_DATE_FACTOR: ClassVar[Mapping[TimeUnit, int]] = {
        "ns": NS_PER_SECOND,
        "us": US_PER_SECOND,
        "ms": MS_PER_SECOND,
        "s": 1,
    }
    _TIMESTAMP_DATETIME_OP_FACTOR: ClassVar[
        Mapping[tuple[UnitCurrent, UnitTarget], tuple[BinOpBroadcast, IntoRhs]]
    ] = {
        ("ns", "us"): (floordiv_compat, 1_000),
        ("ns", "ms"): (floordiv_compat, 1_000_000),
        ("us", "ns"): (pc.multiply, NS_PER_MICROSECOND),
        ("us", "ms"): (floordiv_compat, 1_000),
        ("ms", "ns"): (pc.multiply, NS_PER_MILLISECOND),
        ("ms", "us"): (pc.multiply, 1_000),
        ("s", "ns"): (pc.multiply, NS_PER_SECOND),
        ("s", "us"): (pc.multiply, US_PER_SECOND),
        ("s", "ms"): (pc.multiply, MS_PER_SECOND),
    }
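    # NOTE: `_TIMESTAMP_DATE_FACTOR` scales date-derived seconds into the target unit,
    # while `_TIMESTAMP_DATETIME_OP_FACTOR` maps a `(current, target)` unit pair to a
    # compute op plus factor: floor-divide when coarsening, multiply when refining.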

    @property
    def unit(self) -> TimeUnit:  # NOTE: Unsafe (native).
        return cast("pa.TimestampType[TimeUnit, Any]", self.native.type).unit

    @property
    def time_zone(self) -> str | None:  # NOTE: Unsafe (narwhals).
        return cast("Datetime", self.compliant.dtype).time_zone

    def to_string(self, format: str) -> ArrowSeries:
        # PyArrow differs from other libraries in that %S also prints out
        # the fractional part of the second...:'(
        # https://arrow.apache.org/docs/python/generated/pyarrow.compute.strftime.html
        format = format.replace("%S.%f", "%S").replace("%S%.f", "%S")
        return self.with_native(pc.strftime(self.native, format))

    def replace_time_zone(self, time_zone: str | None) -> ArrowSeries:
        if time_zone is not None:
            result = pc.assume_timezone(pc.local_timestamp(self.native), time_zone)
        else:
            result = pc.local_timestamp(self.native)
        return self.with_native(result)

    def convert_time_zone(self, time_zone: str) -> ArrowSeries:
        ser = self.replace_time_zone("UTC") if self.time_zone is None else self.compliant
        return self.with_native(ser.native.cast(pa.timestamp(self.unit, time_zone)))

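    # Convert Datetime/Date values to an integer timestamp in `time_unit`.
    # Datetime input reuses the op/factor table above; Date input is first scaled
    # from days to seconds, then into the requested unit.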
    def timestamp(self, time_unit: TimeUnit) -> ArrowSeries:
        ser = self.compliant
        dtypes = ser._version.dtypes
        if isinstance(ser.dtype, dtypes.Datetime):
            current = ser.dtype.time_unit
            s_cast = self.native.cast(pa.int64())
            if current == time_unit:
                result = s_cast
            elif item := self._TIMESTAMP_DATETIME_OP_FACTOR.get((current, time_unit)):
                fn, factor = item
                result = fn(s_cast, lit(factor))
            else:  # pragma: no cover
                msg = f"unexpected time unit {current}, please report an issue at https://github.com/narwhals-dev/narwhals"
                raise AssertionError(msg)
            return self.with_native(result)
        if isinstance(ser.dtype, dtypes.Date):
            time_s = pc.multiply(self.native.cast(pa.int32()), lit(SECONDS_PER_DAY))
            factor = self._TIMESTAMP_DATE_FACTOR[time_unit]
            return self.with_native(pc.multiply(time_s, lit(factor)))
        msg = "Input should be either of Date or Datetime type"
        raise TypeError(msg)

    def date(self) -> ArrowSeries:
        return self.with_native(self.native.cast(pa.date32()))

    def year(self) -> ArrowSeries:
        return self.with_native(pc.year(self.native))

    def month(self) -> ArrowSeries:
        return self.with_native(pc.month(self.native))

    def day(self) -> ArrowSeries:
        return self.with_native(pc.day(self.native))

    def hour(self) -> ArrowSeries:
        return self.with_native(pc.hour(self.native))

    def minute(self) -> ArrowSeries:
        return self.with_native(pc.minute(self.native))

    def second(self) -> ArrowSeries:
        return self.with_native(pc.second(self.native))

    def millisecond(self) -> ArrowSeries:
        return self.with_native(pc.millisecond(self.native))

    def microsecond(self) -> ArrowSeries:
        arr = self.native
        result = pc.add(pc.multiply(pc.millisecond(arr), lit(1000)), pc.microsecond(arr))
        return self.with_native(result)

    def nanosecond(self) -> ArrowSeries:
        result = pc.add(
            pc.multiply(self.microsecond().native, lit(1000)), pc.nanosecond(self.native)
        )
        return self.with_native(result)

    def ordinal_day(self) -> ArrowSeries:
        return self.with_native(pc.day_of_year(self.native))

    def weekday(self) -> ArrowSeries:
        return self.with_native(pc.day_of_week(self.native, count_from_zero=False))

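    # The `total_*` methods below treat the series as a Duration: the stored
    # integer is rescaled from the native unit, multiplying towards finer units
    # and dividing towards coarser ones to stay in integer arithmetic.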
    def total_minutes(self) -> ArrowSeries:
        unit_to_minutes_factor = {
            "s": SECONDS_PER_MINUTE,
            "ms": MS_PER_MINUTE,
            "us": US_PER_MINUTE,
            "ns": NS_PER_MINUTE,
        }
        factor = lit(unit_to_minutes_factor[self.unit], type=pa.int64())
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_seconds(self) -> ArrowSeries:
        unit_to_seconds_factor = {
            "s": 1,
            "ms": MS_PER_SECOND,
            "us": US_PER_SECOND,
            "ns": NS_PER_SECOND,
        }
        factor = lit(unit_to_seconds_factor[self.unit], type=pa.int64())
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_milliseconds(self) -> ArrowSeries:
        unit_to_milli_factor = {
            "s": 1e3,  # seconds
            "ms": 1,  # milli
            "us": 1e3,  # micro
            "ns": 1e6,  # nano
        }
        factor = lit(unit_to_milli_factor[self.unit], type=pa.int64())
        if self.unit == "s":
            return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_microseconds(self) -> ArrowSeries:
        unit_to_micro_factor = {
            "s": 1e6,  # seconds
            "ms": 1e3,  # milli
            "us": 1,  # micro
            "ns": 1e3,  # nano
        }
        factor = lit(unit_to_micro_factor[self.unit], type=pa.int64())
        if self.unit in {"s", "ms"}:
            return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_nanoseconds(self) -> ArrowSeries:
        unit_to_nano_factor = {
            "s": NS_PER_SECOND,
            "ms": NS_PER_MILLISECOND,
            "us": NS_PER_MICROSECOND,
            "ns": 1,
        }
        factor = lit(unit_to_nano_factor[self.unit], type=pa.int64())
        return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))

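    # `truncate` lowers onto `pc.floor_temporal` via the narwhals `Interval`
    # parser; `offset_by` adds a `pa.duration` scalar instead, so calendar units
    # (y/q/mo) are rejected as not yet supported.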
    def truncate(self, every: str) -> ArrowSeries:
        interval = Interval.parse(every)
        return self.with_native(
            pc.floor_temporal(self.native, interval.multiple, UNITS_DICT[interval.unit])
        )

    def offset_by(self, by: str) -> ArrowSeries:
        interval = Interval.parse_no_constraints(by)
        native = self.native
        if interval.unit in {"y", "q", "mo"}:
            msg = f"Offsetting by {interval.unit} is not yet supported for pyarrow."
            raise NotImplementedError(msg)
        dtype = self.compliant.dtype
        datetime_dtype = self.version.dtypes.Datetime
        if interval.unit == "d" and isinstance(dtype, datetime_dtype) and dtype.time_zone:
            offset: pa.DurationScalar[Any] = lit(interval.to_timedelta())
            native_naive = pc.local_timestamp(native)
            result = pc.assume_timezone(pc.add(native_naive, offset), dtype.time_zone)
            return self.with_native(result)
        if interval.unit == "ns":  # pragma: no cover
            offset = lit(interval.multiple, pa.duration("ns"))  # type: ignore[assignment]
        else:
            offset = lit(interval.to_timedelta())
        return self.with_native(pc.add(native, offset))
24
lib/python3.11/site-packages/narwhals/_arrow/series_list.py
Normal file
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import ArrowSeriesNamespace
from narwhals._utils import not_implemented

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries


class ArrowSeriesListNamespace(ArrowSeriesNamespace):
    def len(self) -> ArrowSeries:
        return self.with_native(pc.list_value_length(self.native).cast(pa.uint32()))

    unique = not_implemented()

    contains = not_implemented()

    def get(self, index: int) -> ArrowSeries:
        return self.with_native(pc.list_element(self.native, index))
115
lib/python3.11/site-packages/narwhals/_arrow/series_str.py
Normal file
@@ -0,0 +1,115 @@
from __future__ import annotations

import string
from typing import TYPE_CHECKING

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import Incomplete


class ArrowSeriesStringNamespace(ArrowSeriesNamespace):
    def len_chars(self) -> ArrowSeries:
        return self.with_native(pc.utf8_length(self.native))

    def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> ArrowSeries:
        fn = pc.replace_substring if literal else pc.replace_substring_regex
        try:
            arr = fn(self.native, pattern, replacement=value, max_replacements=n)
        except TypeError as e:
            if not isinstance(value, str):
                msg = "PyArrow backed `.str.replace` only supports str replacement values"
                raise TypeError(msg) from e
            raise
        return self.with_native(arr)

    def replace_all(self, pattern: str, value: str, *, literal: bool) -> ArrowSeries:
        try:
            return self.replace(pattern, value, literal=literal, n=-1)
        except TypeError as e:
            if not isinstance(value, str):
                msg = "PyArrow backed `.str.replace_all` only supports str replacement values."
                raise TypeError(msg) from e
            raise

    def strip_chars(self, characters: str | None) -> ArrowSeries:
        return self.with_native(
            pc.utf8_trim(self.native, characters or string.whitespace)
        )

    def starts_with(self, prefix: str) -> ArrowSeries:
        return self.with_native(pc.equal(self.slice(0, len(prefix)).native, lit(prefix)))

    def ends_with(self, suffix: str) -> ArrowSeries:
        return self.with_native(
            pc.equal(self.slice(-len(suffix), None).native, lit(suffix))
        )

    def contains(self, pattern: str, *, literal: bool) -> ArrowSeries:
        check_func = pc.match_substring if literal else pc.match_substring_regex
        return self.with_native(check_func(self.native, pattern))

    def slice(self, offset: int, length: int | None) -> ArrowSeries:
        stop = offset + length if length is not None else None
        return self.with_native(
            pc.utf8_slice_codeunits(self.native, start=offset, stop=stop)
        )

    def split(self, by: str) -> ArrowSeries:
        split_series = pc.split_pattern(self.native, by)  # type: ignore[call-overload]
        return self.with_native(split_series)

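    # When no format is given, `to_datetime` infers one from the data via
    # `parse_datetime_format` (see `narwhals._arrow.utils`), then parses with
    # microsecond precision.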
    def to_datetime(self, format: str | None) -> ArrowSeries:
        format = parse_datetime_format(self.native) if format is None else format
        timestamp_array = pc.strptime(self.native, format=format, unit="us")
        return self.with_native(timestamp_array)

    def to_date(self, format: str | None) -> ArrowSeries:
        return self.to_datetime(format=format).dt.date()

    def to_uppercase(self) -> ArrowSeries:
        return self.with_native(pc.utf8_upper(self.native))

    def to_lowercase(self) -> ArrowSeries:
        return self.with_native(pc.utf8_lower(self.native))

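    # `zfill` pads to `width` with leading zeros while keeping a leading sign in
    # place: the sign is split off, the remainder is zero-padded to `width - 1`,
    # and the pieces are re-joined; strings already at least `width` long pass
    # through unchanged.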
    def zfill(self, width: int) -> ArrowSeries:
        binary_join: Incomplete = pc.binary_join_element_wise
        native = self.native
        hyphen, plus = lit("-"), lit("+")
        first_char, remaining_chars = (
            self.slice(0, 1).native,
            self.slice(1, None).native,
        )

        # Conditions
        less_than_width = pc.less(pc.utf8_length(native), lit(width))
        starts_with_hyphen = pc.equal(first_char, hyphen)
        starts_with_plus = pc.equal(first_char, plus)

        conditions = pc.make_struct(
            pc.and_(starts_with_hyphen, less_than_width),
            pc.and_(starts_with_plus, less_than_width),
            less_than_width,
        )

        # Cases
        padded_remaining_chars = pc.utf8_lpad(remaining_chars, width - 1, padding="0")

        result = pc.case_when(
            conditions,
            binary_join(
                pa.repeat(hyphen, len(native)), padded_remaining_chars, ""
            ),  # starts with hyphen and less than width
            binary_join(
                pa.repeat(plus, len(native)), padded_remaining_chars, ""
            ),  # starts with plus and less than width
            pc.utf8_lpad(native, width=width, padding="0"),  # less than width
            native,
        )
        return self.with_native(result)
15
lib/python3.11/site-packages/narwhals/_arrow/series_struct.py
Normal file
@@ -0,0 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pyarrow.compute as pc

from narwhals._arrow.utils import ArrowSeriesNamespace

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries


class ArrowSeriesStructNamespace(ArrowSeriesNamespace):
    def field(self, name: str) -> ArrowSeries:
        return self.with_native(pc.struct_field(self.native, name)).alias(name)
72
lib/python3.11/site-packages/narwhals/_arrow/typing.py
Normal file
@@ -0,0 +1,72 @@
from __future__ import annotations  # pragma: no cover

from typing import (
    TYPE_CHECKING,  # pragma: no cover
    Any,  # pragma: no cover
    TypeVar,  # pragma: no cover
)

if TYPE_CHECKING:
    import sys
    from typing import Generic, Literal

    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        from typing_extensions import TypeAlias

    import pyarrow as pa
    from pyarrow.__lib_pxi.table import (
        AggregateOptions,  # noqa: F401
        Aggregation,  # noqa: F401
    )
    from pyarrow._stubs_typing import (  # pyright: ignore[reportMissingModuleSource]
        Indices,  # noqa: F401
        Mask,  # noqa: F401
        Order,  # noqa: F401
    )

    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.series import ArrowSeries

    IntoArrowExpr: TypeAlias = "ArrowExpr | ArrowSeries"
    TieBreaker: TypeAlias = Literal["min", "max", "first", "dense"]
    NullPlacement: TypeAlias = Literal["at_start", "at_end"]
    NativeIntervalUnit: TypeAlias = Literal[
        "year",
        "quarter",
        "month",
        "week",
        "day",
        "hour",
        "minute",
        "second",
        "millisecond",
        "microsecond",
        "nanosecond",
    ]

    ChunkedArrayAny: TypeAlias = pa.ChunkedArray[Any]
    ArrayAny: TypeAlias = pa.Array[Any]
    ArrayOrChunkedArray: TypeAlias = "ArrayAny | ChunkedArrayAny"
    ScalarAny: TypeAlias = pa.Scalar[Any]
    ArrayOrScalar: TypeAlias = "ArrayOrChunkedArray | ScalarAny"
    ArrayOrScalarT1 = TypeVar("ArrayOrScalarT1", ArrayAny, ChunkedArrayAny, ScalarAny)
    ArrayOrScalarT2 = TypeVar("ArrayOrScalarT2", ArrayAny, ChunkedArrayAny, ScalarAny)
    _AsPyType = TypeVar("_AsPyType")

    class _BasicDataType(pa.DataType, Generic[_AsPyType]): ...


Incomplete: TypeAlias = Any  # pragma: no cover
"""
Marker for working code that fails on the stubs.

Common issues:
- Annotated for `Array`, but not `ChunkedArray`
- Relies on typing information that the stubs don't provide statically
- Missing attributes
- Incorrect return types
- Inconsistent use of generic/concrete types
- `_clone_signature` used on signatures that are not identical
"""
438
lib/python3.11/site-packages/narwhals/_arrow/utils.py
Normal file
@@ -0,0 +1,438 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._compliant import EagerSeriesNamespace
from narwhals._utils import Version, isinstance_or_issubclass

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping

    from typing_extensions import TypeAlias, TypeIs

    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import (
        ArrayAny,
        ArrayOrScalar,
        ArrayOrScalarT1,
        ArrayOrScalarT2,
        ChunkedArrayAny,
        NativeIntervalUnit,
        ScalarAny,
    )
    from narwhals._duration import IntervalUnit
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType, PythonLiteral

    # NOTE: stubs don't allow for `ChunkedArray[StructArray]`
    # Intended to represent the `.chunks` property storing `list[pa.StructArray]`
    ChunkedArrayStructArray: TypeAlias = ChunkedArrayAny

    def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ...
    def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ...
    def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ...
    def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ...
    def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ...
    def is_dictionary(t: Any) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ...
    def extract_regex(
        strings: ChunkedArrayAny,
        /,
        pattern: str,
        *,
        options: Any = None,
        memory_pool: Any = None,
    ) -> ChunkedArrayStructArray: ...

else:
    from pyarrow.compute import extract_regex
    from pyarrow.types import (
        is_dictionary,  # noqa: F401
        is_duration,
        is_fixed_size_list,
        is_large_list,
        is_list,
        is_timestamp,
    )

UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}

lit = pa.scalar
"""Alias for `pyarrow.scalar`."""


def extract_py_scalar(value: Any, /) -> Any:
    from narwhals._arrow.series import maybe_extract_py_scalar

    return maybe_extract_py_scalar(value, return_py_scalar=True)


def is_array_or_scalar(obj: Any) -> TypeIs[ArrayOrScalar]:
    """Return True for any base `pyarrow` container."""
    return isinstance(obj, (pa.ChunkedArray, pa.Array, pa.Scalar))


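# Normalize an `Array`/`ChunkedArray` (or a list of chunks) to a `ChunkedArray`,
# so downstream kernels can assume a single container type.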
def chunked_array(
    arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, /
) -> ChunkedArrayAny:
    if isinstance(arr, pa.ChunkedArray):
        return arr
    if isinstance(arr, list):
        return pa.chunked_array(arr, dtype)
    return pa.chunked_array([arr], dtype)


def nulls_like(n: int, series: ArrowSeries) -> ArrayAny:
    """Create a strongly-typed Array instance with all elements null.

    Uses the type of `series`, without upsetting `mypy`.
    """
    return pa.nulls(n, series.native.type)


def zeros(n: int, /) -> pa.Int64Array:
    return pa.repeat(0, n)


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType:  # noqa: C901, PLR0912
    dtypes = version.dtypes
    if pa.types.is_int64(dtype):
        return dtypes.Int64()
    if pa.types.is_int32(dtype):
        return dtypes.Int32()
    if pa.types.is_int16(dtype):
        return dtypes.Int16()
    if pa.types.is_int8(dtype):
        return dtypes.Int8()
    if pa.types.is_uint64(dtype):
        return dtypes.UInt64()
    if pa.types.is_uint32(dtype):
        return dtypes.UInt32()
    if pa.types.is_uint16(dtype):
        return dtypes.UInt16()
    if pa.types.is_uint8(dtype):
        return dtypes.UInt8()
    if pa.types.is_boolean(dtype):
        return dtypes.Boolean()
    if pa.types.is_float64(dtype):
        return dtypes.Float64()
    if pa.types.is_float32(dtype):
        return dtypes.Float32()
    # bug in coverage? it shows `31->exit` (where `31` is currently the line number of
    # the next line), even though both when the if condition is true and false are covered
    if (  # pragma: no cover
        pa.types.is_string(dtype)
        or pa.types.is_large_string(dtype)
        or getattr(pa.types, "is_string_view", lambda _: False)(dtype)
    ):
        return dtypes.String()
    if pa.types.is_date32(dtype):
        return dtypes.Date()
    if is_timestamp(dtype):
        return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz)
    if is_duration(dtype):
        return dtypes.Duration(time_unit=dtype.unit)
    if pa.types.is_dictionary(dtype):
        return dtypes.Categorical()
    if pa.types.is_struct(dtype):
        return dtypes.Struct(
            [
                dtypes.Field(
                    dtype.field(i).name,
                    native_to_narwhals_dtype(dtype.field(i).type, version),
                )
                for i in range(dtype.num_fields)
            ]
        )
    if is_list(dtype) or is_large_list(dtype):
        return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version))
    if is_fixed_size_list(dtype):
        return dtypes.Array(
            native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size
        )
    if pa.types.is_decimal(dtype):
        return dtypes.Decimal()
    if pa.types.is_time32(dtype) or pa.types.is_time64(dtype):
        return dtypes.Time()
    if pa.types.is_binary(dtype):
        return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover


dtypes = Version.MAIN.dtypes
NW_TO_PA_DTYPES: Mapping[type[DType], pa.DataType] = {
    dtypes.Float64: pa.float64(),
    dtypes.Float32: pa.float32(),
    dtypes.Binary: pa.binary(),
    dtypes.String: pa.string(),
    dtypes.Boolean: pa.bool_(),
    dtypes.Categorical: pa.dictionary(pa.uint32(), pa.string()),
    dtypes.Date: pa.date32(),
    dtypes.Time: pa.time64("ns"),
    dtypes.Int8: pa.int8(),
    dtypes.Int16: pa.int16(),
    dtypes.Int32: pa.int32(),
    dtypes.Int64: pa.int64(),
    dtypes.UInt8: pa.uint8(),
    dtypes.UInt16: pa.uint16(),
    dtypes.UInt32: pa.uint32(),
    dtypes.UInt64: pa.uint64(),
}
UNSUPPORTED_DTYPES = (dtypes.Decimal, dtypes.Object)


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType:
    dtypes = version.dtypes
    base_type = dtype.base_type()
    if pa_type := NW_TO_PA_DTYPES.get(base_type):
        return pa_type
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        unit = dtype.time_unit
        return pa.timestamp(unit, tz) if (tz := dtype.time_zone) else pa.timestamp(unit)
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return pa.duration(dtype.time_unit)
    if isinstance_or_issubclass(dtype, dtypes.List):
        return pa.list_(value_type=narwhals_to_native_dtype(dtype.inner, version=version))
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        return pa.struct(
            [
                (field.name, narwhals_to_native_dtype(field.dtype, version=version))
                for field in dtype.fields
            ]
        )
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        inner = narwhals_to_native_dtype(dtype.inner, version=version)
        list_size = dtype.size
        return pa.list_(inner, list_size=list_size)
    if issubclass(base_type, UNSUPPORTED_DTYPES):
        msg = f"Converting to {base_type.__name__} dtype is not supported for PyArrow."
        raise NotImplementedError(msg)
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def extract_native(
    lhs: ArrowSeries, rhs: ArrowSeries | PythonLiteral | ScalarAny
) -> tuple[ChunkedArrayAny | ScalarAny, ChunkedArrayAny | ScalarAny]:
    """Extract native objects in binary operation.

    If the comparison isn't supported, return `NotImplemented` so that the
    "right-hand-side" operation (e.g. `__radd__`) can be tried.

    If one of the two sides has a `_broadcast` flag, then extract the scalar
    underneath it so that PyArrow can do its own broadcasting.
    """
    from narwhals._arrow.series import ArrowSeries

    if rhs is None:  # pragma: no cover
        return lhs.native, lit(None, type=lhs._type)

    if isinstance(rhs, ArrowSeries):
        if lhs._broadcast and not rhs._broadcast:
            return lhs.native[0], rhs.native
        if rhs._broadcast:
            return lhs.native, rhs.native[0]
        return lhs.native, rhs.native

    if isinstance(rhs, list):
        msg = "Expected Series or scalar, got list."
        raise TypeError(msg)

    return lhs.native, rhs if isinstance(rhs, pa.Scalar) else lit(rhs)


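# Floor division with Python semantics: pyarrow's `divide_checked` truncates
# towards zero (e.g. -7 // 2 would come back as -3 rather than -4), so when the
# operands differ in sign and leave a remainder, the quotient is nudged down by one.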
def floordiv_compat(left: ArrayOrScalar, right: ArrayOrScalar, /) -> Any:
    # The following lines are adapted from pandas' pyarrow implementation.
    # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154

    if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
        divided = pc.divide_checked(left, right)
        # TODO @dangotbanned: Use a `TypeVar` in guards
        # Narrowing to a `Union` isn't interacting well with the rest of the stubs
        # https://github.com/zen-xu/pyarrow-stubs/pull/215
        if pa.types.is_signed_integer(divided.type):
            div_type = cast("pa._lib.Int64Type", divided.type)
            has_remainder = pc.not_equal(pc.multiply(divided, right), left)
            has_one_negative_operand = pc.less(
                pc.bit_wise_xor(left, right), lit(0, div_type)
            )
            result = pc.if_else(
                pc.and_(has_remainder, has_one_negative_operand),
                pc.subtract(divided, lit(1, div_type)),
                divided,
            )
        else:
            result = divided  # pragma: no cover
        result = result.cast(left.type)
    else:
        divided = pc.divide(left, right)
        result = pc.floor(divided)
    return result


def cast_for_truediv(
    arrow_array: ArrayOrScalarT1, pa_object: ArrayOrScalarT2
) -> tuple[ArrayOrScalarT1, ArrayOrScalarT2]:
    # Lifted from:
    # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122
    # Ensure int / int -> float mirroring Python/Numpy behavior
    # as pc.divide_checked(int, int) -> int
    if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(pa_object.type):
        # GH: 56645.  # noqa: ERA001
        # https://github.com/apache/arrow/issues/35563
        return arrow_array.cast(pa.float64(), safe=False), pa_object.cast(
            pa.float64(), safe=False
        )

    return arrow_array, pa_object


# Regex for date, time, separator and timezone components
DATE_RE = r"(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})"
SEP_RE = r"(?P<sep>\s|T)"
TIME_RE = r"(?P<time>\d{2}:\d{2}(?::\d{2})?|\d{6}?)"  # \s*(?P<period>[AP]M)?)?
HMS_RE = r"^(?P<hms>\d{2}:\d{2}:\d{2})$"
HM_RE = r"^(?P<hm>\d{2}:\d{2})$"
HMS_RE_NO_SEP = r"^(?P<hms_no_sep>\d{6})$"
TZ_RE = r"(?P<tz>Z|[+-]\d{2}:?\d{2})"  # Matches 'Z', '+02:00', '+0200', '+02', etc.
FULL_RE = rf"{DATE_RE}{SEP_RE}?{TIME_RE}?{TZ_RE}?$"

# Separate regexes for different date formats
YMD_RE = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])$"
DMY_RE = r"^(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
MDY_RE = r"^(?P<month>0[1-9]|1[0-2])(?P<sep1>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
YMD_RE_NO_SEP = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<month>0[1-9]|1[0-2])(?P<day>0[1-9]|[12][0-9]|3[01])$"

DATE_FORMATS = (
    (YMD_RE_NO_SEP, "%Y%m%d"),
    (YMD_RE, "%Y-%m-%d"),
    (DMY_RE, "%d-%m-%Y"),
    (MDY_RE, "%m-%d-%Y"),
)
TIME_FORMATS = ((HMS_RE, "%H:%M:%S"), (HM_RE, "%H:%M"), (HMS_RE_NO_SEP, "%H%M%S"))


def _extract_regex_concat_arrays(
    strings: ChunkedArrayAny,
    /,
    pattern: str,
    *,
    options: Any = None,
    memory_pool: Any = None,
) -> pa.StructArray:
    r = pa.concat_arrays(
        extract_regex(strings, pattern, options=options, memory_pool=memory_pool).chunks
    )
    return cast("pa.StructArray", r)


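# Format inference samples the first ten non-null values; e.g. a column of
# "2020-01-31 23:59:59" strings resolves to "%Y-%m-%d %H:%M:%S".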
def parse_datetime_format(arr: ChunkedArrayAny) -> str:
    """Try to infer datetime format from StringArray."""
    matches = _extract_regex_concat_arrays(arr.drop_null().slice(0, 10), pattern=FULL_RE)
    if not pc.all(matches.is_valid()).as_py():
        msg = (
            "Unable to infer datetime format, provided format is not supported. "
            "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
        )
        raise NotImplementedError(msg)

    separators = matches.field("sep")
    tz = matches.field("tz")

    # separators and time zones must be unique
    if pc.count(pc.unique(separators)).as_py() > 1:
        msg = "Found multiple separator values while inferring datetime format."
        raise ValueError(msg)

    if pc.count(pc.unique(tz)).as_py() > 1:
        msg = "Found multiple timezone values while inferring datetime format."
        raise ValueError(msg)

    date_value = _parse_date_format(cast("pc.StringArray", matches.field("date")))
    time_value = _parse_time_format(cast("pc.StringArray", matches.field("time")))

    sep_value = separators[0].as_py()
    tz_value = "%z" if tz[0].as_py() else ""

    return f"{date_value}{sep_value}{time_value}{tz_value}"


def _parse_date_format(arr: pc.StringArray) -> str:
    for date_rgx, date_fmt in DATE_FORMATS:
        matches = pc.extract_regex(arr, pattern=date_rgx)
        if date_fmt == "%Y%m%d" and pc.all(matches.is_valid()).as_py():
            return date_fmt
        if (
            pc.all(matches.is_valid()).as_py()
            and pc.count(pc.unique(sep1 := matches.field("sep1"))).as_py() == 1
            and pc.count(pc.unique(sep2 := matches.field("sep2"))).as_py() == 1
            and (date_sep_value := sep1[0].as_py()) == sep2[0].as_py()
        ):
            return date_fmt.replace("-", date_sep_value)

    msg = (
        "Unable to infer datetime format. "
        "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
    )
    raise ValueError(msg)


def _parse_time_format(arr: pc.StringArray) -> str:
    for time_rgx, time_fmt in TIME_FORMATS:
        matches = pc.extract_regex(arr, pattern=time_rgx)
        if pc.all(matches.is_valid()).as_py():
            return time_fmt
    return ""


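# e.g. `window_size=4, center=True` pads two nulls on the left, one on the
# right, and returns an offset of 3.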
def pad_series(
    series: ArrowSeries, *, window_size: int, center: bool
) -> tuple[ArrowSeries, int]:
    """Pad series with None values on the left and/or right side, depending on the specified parameters.

    Arguments:
        series: The input ArrowSeries to be padded.
        window_size: The desired size of the window.
        center: Specifies whether to center the padding or not.

    Returns:
        A tuple containing the padded ArrowSeries and the offset value.
    """
    if not center:
        return series, 0
    offset_left = window_size // 2
    # subtract one if window_size is even
    offset_right = offset_left - (window_size % 2 == 0)
    pad_left = pa.array([None] * offset_left, type=series._type)
    pad_right = pa.array([None] * offset_right, type=series._type)
    concat = pa.concat_arrays([pad_left, *series.native.chunks, pad_right])
    return series._with_native(concat), offset_left + offset_right


def cast_to_comparable_string_types(
    *chunked_arrays: ChunkedArrayAny, separator: str
) -> tuple[Iterator[ChunkedArrayAny], ScalarAny]:
    # Ensure `chunked_arrays` are either all `string` or all `large_string`.
    dtype = (
        pa.string()  # (PyArrow default)
        if not any(pa.types.is_large_string(ca.type) for ca in chunked_arrays)
        else pa.large_string()
    )
    return (ca.cast(dtype) for ca in chunked_arrays), lit(separator, dtype)


class ArrowSeriesNamespace(EagerSeriesNamespace["ArrowSeries", "ChunkedArrayAny"]): ...