done
This commit is contained in:
303
lib/python3.11/site-packages/narwhals/_duckdb/expr.py
Normal file
303
lib/python3.11/site-packages/narwhals/_duckdb/expr.py
Normal file
@ -0,0 +1,303 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, cast
|
||||
|
||||
from duckdb import CoalesceOperator, StarExpression
|
||||
|
||||
from narwhals._duckdb.expr_dt import DuckDBExprDateTimeNamespace
|
||||
from narwhals._duckdb.expr_list import DuckDBExprListNamespace
|
||||
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
|
||||
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
|
||||
from narwhals._duckdb.utils import (
|
||||
DeferredTimeZone,
|
||||
F,
|
||||
col,
|
||||
lit,
|
||||
narwhals_to_native_dtype,
|
||||
when,
|
||||
window_expression,
|
||||
)
|
||||
from narwhals._expression_parsing import ExprKind, ExprMetadata
|
||||
from narwhals._sql.expr import SQLExpr
|
||||
from narwhals._utils import Implementation, Version
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from duckdb import Expression
|
||||
from typing_extensions import Self
|
||||
|
||||
from narwhals._compliant import WindowInputs
|
||||
from narwhals._compliant.typing import (
|
||||
AliasNames,
|
||||
EvalNames,
|
||||
EvalSeries,
|
||||
WindowFunction,
|
||||
)
|
||||
from narwhals._duckdb.dataframe import DuckDBLazyFrame
|
||||
from narwhals._duckdb.namespace import DuckDBNamespace
|
||||
from narwhals._utils import _LimitedContext
|
||||
from narwhals.typing import (
|
||||
FillNullStrategy,
|
||||
IntoDType,
|
||||
NonNestedLiteral,
|
||||
RollingInterpolationMethod,
|
||||
)
|
||||
|
||||
DuckDBWindowFunction = WindowFunction[DuckDBLazyFrame, Expression]
|
||||
DuckDBWindowInputs = WindowInputs[Expression]
|
||||
|
||||
|
||||
class DuckDBExpr(SQLExpr["DuckDBLazyFrame", "Expression"]):
|
||||
_implementation = Implementation.DUCKDB
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
call: EvalSeries[DuckDBLazyFrame, Expression],
|
||||
window_function: DuckDBWindowFunction | None = None,
|
||||
*,
|
||||
evaluate_output_names: EvalNames[DuckDBLazyFrame],
|
||||
alias_output_names: AliasNames | None,
|
||||
version: Version,
|
||||
implementation: Implementation = Implementation.DUCKDB,
|
||||
) -> None:
|
||||
self._call = call
|
||||
self._evaluate_output_names = evaluate_output_names
|
||||
self._alias_output_names = alias_output_names
|
||||
self._version = version
|
||||
self._metadata: ExprMetadata | None = None
|
||||
self._window_function: DuckDBWindowFunction | None = window_function
|
||||
|
||||
def _count_star(self) -> Expression:
|
||||
return F("count", StarExpression())
|
||||
|
||||
def _window_expression(
|
||||
self,
|
||||
expr: Expression,
|
||||
partition_by: Sequence[str | Expression] = (),
|
||||
order_by: Sequence[str | Expression] = (),
|
||||
rows_start: int | None = None,
|
||||
rows_end: int | None = None,
|
||||
*,
|
||||
descending: Sequence[bool] | None = None,
|
||||
nulls_last: Sequence[bool] | None = None,
|
||||
) -> Expression:
|
||||
return window_expression(
|
||||
expr,
|
||||
partition_by,
|
||||
order_by,
|
||||
rows_start,
|
||||
rows_end,
|
||||
descending=descending,
|
||||
nulls_last=nulls_last,
|
||||
)
|
||||
|
||||
def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover
|
||||
from narwhals._duckdb.namespace import DuckDBNamespace
|
||||
|
||||
return DuckDBNamespace(version=self._version)
|
||||
|
||||
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
|
||||
if kind is ExprKind.LITERAL:
|
||||
return self
|
||||
if self._backend_version < (1, 3):
|
||||
msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns."
|
||||
raise NotImplementedError(msg)
|
||||
return self.over([lit(1)], [])
|
||||
|
||||
@classmethod
|
||||
def from_column_names(
|
||||
cls,
|
||||
evaluate_column_names: EvalNames[DuckDBLazyFrame],
|
||||
/,
|
||||
*,
|
||||
context: _LimitedContext,
|
||||
) -> Self:
|
||||
def func(df: DuckDBLazyFrame) -> list[Expression]:
|
||||
return [col(name) for name in evaluate_column_names(df)]
|
||||
|
||||
return cls(
|
||||
func,
|
||||
evaluate_output_names=evaluate_column_names,
|
||||
alias_output_names=None,
|
||||
version=context._version,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
|
||||
def func(df: DuckDBLazyFrame) -> list[Expression]:
|
||||
columns = df.columns
|
||||
return [col(columns[i]) for i in column_indices]
|
||||
|
||||
return cls(
|
||||
func,
|
||||
evaluate_output_names=cls._eval_names_indices(column_indices),
|
||||
alias_output_names=None,
|
||||
version=context._version,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _alias_native(cls, expr: Expression, name: str) -> Expression:
|
||||
return expr.alias(name)
|
||||
|
||||
def __invert__(self) -> Self:
|
||||
invert = cast("Callable[..., Expression]", operator.invert)
|
||||
return self._with_elementwise(invert)
|
||||
|
||||
def skew(self) -> Self:
|
||||
def func(expr: Expression) -> Expression:
|
||||
count = F("count", expr)
|
||||
# Adjust population skewness by correction factor to get sample skewness
|
||||
sample_skewness = (
|
||||
F("skewness", expr)
|
||||
* (count - lit(2))
|
||||
/ F("sqrt", count * (count - lit(1)))
|
||||
)
|
||||
return when(count == lit(0), lit(None)).otherwise(
|
||||
when(count == lit(1), lit(float("nan"))).otherwise(
|
||||
when(count == lit(2), lit(0.0)).otherwise(sample_skewness)
|
||||
)
|
||||
)
|
||||
|
||||
return self._with_callable(func)
|
||||
|
||||
def kurtosis(self) -> Self:
|
||||
return self._with_callable(lambda expr: F("kurtosis_pop", expr))
|
||||
|
||||
def quantile(
|
||||
self, quantile: float, interpolation: RollingInterpolationMethod
|
||||
) -> Self:
|
||||
def func(expr: Expression) -> Expression:
|
||||
if interpolation == "linear":
|
||||
return F("quantile_cont", expr, lit(quantile))
|
||||
msg = "Only linear interpolation methods are supported for DuckDB quantile."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
return self._with_callable(func)
|
||||
|
||||
def n_unique(self) -> Self:
|
||||
def func(expr: Expression) -> Expression:
|
||||
# https://stackoverflow.com/a/79338887/4451315
|
||||
return F("array_unique", F("array_agg", expr)) + F(
|
||||
"max", when(expr.isnotnull(), lit(0)).otherwise(lit(1))
|
||||
)
|
||||
|
||||
return self._with_callable(func)
|
||||
|
||||
def len(self) -> Self:
|
||||
return self._with_callable(lambda _expr: F("count"))
|
||||
|
||||
def std(self, ddof: int) -> Self:
|
||||
if ddof == 0:
|
||||
return self._with_callable(lambda expr: F("stddev_pop", expr))
|
||||
if ddof == 1:
|
||||
return self._with_callable(lambda expr: F("stddev_samp", expr))
|
||||
|
||||
def _std(expr: Expression) -> Expression:
|
||||
n_samples = F("count", expr)
|
||||
return (
|
||||
F("stddev_pop", expr)
|
||||
* F("sqrt", n_samples)
|
||||
/ (F("sqrt", (n_samples - lit(ddof))))
|
||||
)
|
||||
|
||||
return self._with_callable(_std)
|
||||
|
||||
def var(self, ddof: int) -> Self:
|
||||
if ddof == 0:
|
||||
return self._with_callable(lambda expr: F("var_pop", expr))
|
||||
if ddof == 1:
|
||||
return self._with_callable(lambda expr: F("var_samp", expr))
|
||||
|
||||
def _var(expr: Expression) -> Expression:
|
||||
n_samples = F("count", expr)
|
||||
return F("var_pop", expr) * n_samples / (n_samples - lit(ddof))
|
||||
|
||||
return self._with_callable(_var)
|
||||
|
||||
def null_count(self) -> Self:
|
||||
return self._with_callable(lambda expr: F("sum", expr.isnull().cast("int")))
|
||||
|
||||
def is_nan(self) -> Self:
|
||||
return self._with_elementwise(lambda expr: F("isnan", expr))
|
||||
|
||||
def is_finite(self) -> Self:
|
||||
return self._with_elementwise(lambda expr: F("isfinite", expr))
|
||||
|
||||
def is_in(self, other: Sequence[Any]) -> Self:
|
||||
return self._with_elementwise(lambda expr: F("contains", lit(other), expr))
|
||||
|
||||
def fill_null(
|
||||
self,
|
||||
value: Self | NonNestedLiteral,
|
||||
strategy: FillNullStrategy | None,
|
||||
limit: int | None,
|
||||
) -> Self:
|
||||
if strategy is not None:
|
||||
if self._backend_version < (1, 3): # pragma: no cover
|
||||
msg = f"`fill_null` with `strategy={strategy}` is only available in 'duckdb>=1.3.0'."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def _fill_with_strategy(
|
||||
df: DuckDBLazyFrame, inputs: DuckDBWindowInputs
|
||||
) -> Sequence[Expression]:
|
||||
fill_func = "last_value" if strategy == "forward" else "first_value"
|
||||
rows_start, rows_end = (
|
||||
(-limit if limit is not None else None, 0)
|
||||
if strategy == "forward"
|
||||
else (0, limit)
|
||||
)
|
||||
return [
|
||||
window_expression(
|
||||
F(fill_func, expr),
|
||||
inputs.partition_by,
|
||||
inputs.order_by,
|
||||
rows_start=rows_start,
|
||||
rows_end=rows_end,
|
||||
ignore_nulls=True,
|
||||
)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return self._with_window_function(_fill_with_strategy)
|
||||
|
||||
def _fill_constant(expr: Expression, value: Any) -> Expression:
|
||||
return CoalesceOperator(expr, value)
|
||||
|
||||
return self._with_elementwise(_fill_constant, value=value)
|
||||
|
||||
def cast(self, dtype: IntoDType) -> Self:
|
||||
def func(df: DuckDBLazyFrame) -> list[Expression]:
|
||||
tz = DeferredTimeZone(df.native)
|
||||
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
|
||||
return [expr.cast(native_dtype) for expr in self(df)]
|
||||
|
||||
def window_f(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
|
||||
tz = DeferredTimeZone(df.native)
|
||||
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
|
||||
return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)]
|
||||
|
||||
return self.__class__(
|
||||
func,
|
||||
window_f,
|
||||
evaluate_output_names=self._evaluate_output_names,
|
||||
alias_output_names=self._alias_output_names,
|
||||
version=self._version,
|
||||
)
|
||||
|
||||
@property
|
||||
def str(self) -> DuckDBExprStringNamespace:
|
||||
return DuckDBExprStringNamespace(self)
|
||||
|
||||
@property
|
||||
def dt(self) -> DuckDBExprDateTimeNamespace:
|
||||
return DuckDBExprDateTimeNamespace(self)
|
||||
|
||||
@property
|
||||
def list(self) -> DuckDBExprListNamespace:
|
||||
return DuckDBExprListNamespace(self)
|
||||
|
||||
@property
|
||||
def struct(self) -> DuckDBExprStructNamespace:
|
||||
return DuckDBExprStructNamespace(self)
|
Reference in New Issue
Block a user