done
This commit is contained in:
159
lib/python3.11/site-packages/narwhals/_arrow/group_by.py
Normal file
159
lib/python3.11/site-packages/narwhals/_arrow/group_by.py
Normal file
@ -0,0 +1,159 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
|
||||
from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar
|
||||
from narwhals._compliant import EagerGroupBy
|
||||
from narwhals._expression_parsing import evaluate_output_names_and_aliases
|
||||
from narwhals._utils import generate_temporary_column_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Mapping, Sequence
|
||||
|
||||
from narwhals._arrow.dataframe import ArrowDataFrame
|
||||
from narwhals._arrow.expr import ArrowExpr
|
||||
from narwhals._arrow.typing import ( # type: ignore[attr-defined]
|
||||
AggregateOptions,
|
||||
Aggregation,
|
||||
Incomplete,
|
||||
)
|
||||
from narwhals._compliant.typing import NarwhalsAggregation
|
||||
from narwhals.typing import UniqueKeepStrategy
|
||||
|
||||
|
||||
class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
    """Group-by over an `ArrowDataFrame`, executed through `pyarrow.TableGroupBy`.

    Aggregations are translated to pyarrow aggregation names and options in
    `agg`; iteration over the individual groups is provided by `__iter__`.
    """

    # Narwhals aggregation name -> pyarrow aggregation name.
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "approximate_median",
        "max": "max",
        "min": "min",
        "std": "stddev",
        "var": "variance",
        "len": "count",
        "n_unique": "count_distinct",
        "count": "count",
        "all": "all",
        "any": "any",
    }
    # `unique(keep=...)` strategy -> pyarrow aggregation name.
    _REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
        "any": "min",
        "first": "min",
        "last": "max",
    }

    def __init__(
        self,
        df: ArrowDataFrame,
        keys: Sequence[ArrowExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Resolve the grouping keys and build the native pyarrow group-by.

        Args:
            df: Frame to group; kept so `__iter__` can restore its columns.
            keys: Grouping keys, resolved through `_parse_keys`.
            drop_null_keys: When true, rows with a null in any key column are
                dropped before grouping.
        """
        self._df = df
        self._drop_null_keys = drop_null_keys
        frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
        if drop_null_keys:
            self._compliant_frame = frame.drop_nulls(self._keys)
        else:
            self._compliant_frame = frame
        self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)

    def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
        """Aggregate every group with `exprs` and return the combined frame.

        Each expression (checked by `_ensure_all_simple`) is turned into one
        `(column, pyarrow aggregation, options)` triple per output column.
        The aggregated columns are then renamed to the requested aliases and
        the key columns to their output names.

        Raises:
            AssertionError: If a depth-0 expression is not `len`, or if the
                column names produced by pyarrow differ from the expected
                `"<column>_<aggregation>"` names.
        """
        self._ensure_all_simple(exprs)

        aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = []
        # Both lists start with the key columns and grow in lockstep: the i-th
        # expected pyarrow name corresponds to the i-th final column name.
        expected_pyarrow_column_names: list[str] = self._keys.copy()
        new_column_names: list[str] = self._keys.copy()
        exclude = (*self._keys, *self._output_key_names)

        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, exclude
            )

            if expr._depth == 0:
                # Column-less expression such as `agg(nw.len())`: count all
                # rows (nulls included) of the first key column.
                if expr._function_name != "len":  # pragma: no cover
                    msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
                    raise AssertionError(msg)

                aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))
                expected_pyarrow_column_names.append(f"{self._keys[0]}_count")
                new_column_names.append(aliases[0])
                continue

            # Choose the pyarrow options from the narwhals-side function name,
            # before it is remapped to the pyarrow aggregation name.
            leaf_name = self._leaf_name(expr)
            option: Any = None
            if leaf_name in {"std", "var"}:
                assert "ddof" in expr._scalar_kwargs  # noqa: S101
                option = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"])
            elif leaf_name in {"len", "n_unique"}:
                option = pc.CountOptions(mode="all")
            elif leaf_name == "count":
                option = pc.CountOptions(mode="only_valid")
            elif leaf_name in {"all", "any"}:
                option = pc.ScalarAggregateOptions(min_count=0)

            function_name = self._remap_expr_name(leaf_name)
            new_column_names.extend(aliases)
            for output_name in output_names:
                expected_pyarrow_column_names.append(f"{output_name}_{function_name}")
                aggs.append((output_name, function_name, option))

        result_simple = self._grouped.aggregate(aggs)

        # Renaming is done with care: record, for every expected pyarrow name,
        # the positions at which it occurs (a name may occur more than once).
        expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list)
        for position, name in enumerate(expected_pyarrow_column_names):
            expected_old_names_indices[name].append(position)

        actual_names = result_simple.column_names
        names_differ = set(actual_names) != set(expected_pyarrow_column_names)
        lengths_differ = len(actual_names) != len(expected_pyarrow_column_names)
        if names_differ or lengths_differ:  # pragma: no cover
            msg = (
                f"Safety assertion failed, expected {expected_pyarrow_column_names} "
                f"got {result_simple.column_names}, "
                "please report a bug at https://github.com/narwhals-dev/narwhals/issues"
            )
            raise AssertionError(msg)

        # Walk pyarrow's actual column order and, for each column, consume the
        # earliest not-yet-used expected position carrying the same name.
        index_map: list[int] = []
        for name in result_simple.column_names:
            index_map.append(expected_old_names_indices[name].pop(0))

        reordered_names = [new_column_names[i] for i in index_map]
        renamed = result_simple.rename_columns(reordered_names)
        wrapped = self.compliant._with_native(renamed)
        return wrapped.rename(dict(zip(self._keys, self._output_key_names)))

    def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
        """Yield `(key_tuple, group_frame)` for every distinct key combination.

        The key columns are combined into one temporary string column; each
        unique value of that column selects the rows of one group. The yielded
        frame is restricted to the columns of the frame passed to `__init__`.
        """
        col_token = generate_temporary_column_name(
            n_bytes=8, columns=self.compliant.columns
        )
        null_token: str = "__null_token_value__"  # noqa: S105

        table = self.compliant.native
        key_columns = (table[key] for key in self._keys)
        # NOTE(review): the keys are joined with an empty separator, so distinct
        # multi-key combinations could presumably produce the same string --
        # confirm against `cast_to_comparable_string_types`.
        it, separator_scalar = cast_to_comparable_string_types(
            *key_columns, separator=""
        )
        # NOTE: the type stubs claim `separator` has to be a `ChunkedArray`;
        # at runtime a `str` is accepted, hence the `Incomplete` alias.
        concat_str: Incomplete = pc.binary_join_element_wise
        key_values = concat_str(
            *it, separator_scalar, null_handling="replace", null_replacement=null_token
        )
        table = table.add_column(i=0, field_=col_token, column=key_values)

        for v in pc.unique(key_values):
            mask = pc.equal(table[col_token], v)
            group_native = table.filter(mask).drop([col_token])
            t = self.compliant._with_native(group_native)
            # Every row of a group carries the same key values; read the first.
            row = t.simple_select(*self._keys).row(0)
            group_key = tuple(extract_py_scalar(el) for el in row)
            yield group_key, t.simple_select(*self._df.columns)
|
Reference in New Issue
Block a user