done
This commit is contained in:
1
lib/python3.11/site-packages/narwhals/_sql/__init__.py
Normal file
1
lib/python3.11/site-packages/narwhals/_sql/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# ! Any change to this module will trigger the pyspark and pyspark-connect tests in CI
|
46
lib/python3.11/site-packages/narwhals/_sql/dataframe.py
Normal file
46
lib/python3.11/site-packages/narwhals/_sql/dataframe.py
Normal file
@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from narwhals._compliant.dataframe import CompliantLazyFrame
|
||||
from narwhals._compliant.typing import (
|
||||
CompliantExprT_contra,
|
||||
NativeExprT,
|
||||
NativeLazyFrameT,
|
||||
)
|
||||
from narwhals._translate import ToNarwhalsT_co
|
||||
from narwhals._utils import check_columns_exist
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from typing_extensions import Self, TypeAlias
|
||||
|
||||
from narwhals._compliant.window import WindowInputs
|
||||
from narwhals._sql.expr import SQLExpr
|
||||
from narwhals.exceptions import ColumnNotFoundError
|
||||
|
||||
Incomplete: TypeAlias = Any
|
||||
|
||||
|
||||
class SQLLazyFrame(
|
||||
CompliantLazyFrame[CompliantExprT_contra, NativeLazyFrameT, ToNarwhalsT_co],
|
||||
Protocol[CompliantExprT_contra, NativeLazyFrameT, ToNarwhalsT_co],
|
||||
):
|
||||
def _evaluate_window_expr(
|
||||
self,
|
||||
expr: SQLExpr[Self, NativeExprT],
|
||||
/,
|
||||
window_inputs: WindowInputs[NativeExprT],
|
||||
) -> NativeExprT:
|
||||
result = expr.window_function(self, window_inputs)
|
||||
assert len(result) == 1 # debug assertion # noqa: S101
|
||||
return result[0]
|
||||
|
||||
def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
|
||||
result = expr(self)
|
||||
assert len(result) == 1 # debug assertion # noqa: S101
|
||||
return result[0]
|
||||
|
||||
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
|
||||
return check_columns_exist(subset, available=self.columns)
|
801
lib/python3.11/site-packages/narwhals/_sql/expr.py
Normal file
801
lib/python3.11/site-packages/narwhals/_sql/expr.py
Normal file
@ -0,0 +1,801 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator as op
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol
|
||||
|
||||
from narwhals._compliant.expr import LazyExpr
|
||||
from narwhals._compliant.typing import (
|
||||
AliasNames,
|
||||
EvalNames,
|
||||
EvalSeries,
|
||||
NativeExprT,
|
||||
WindowFunction,
|
||||
)
|
||||
from narwhals._compliant.window import WindowInputs
|
||||
from narwhals._expression_parsing import (
|
||||
combine_alias_output_names,
|
||||
combine_evaluate_output_names,
|
||||
)
|
||||
from narwhals._sql.typing import SQLLazyFrameT
|
||||
from narwhals._utils import Implementation, Version, not_implemented
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
from typing_extensions import Self, TypeIs
|
||||
|
||||
from narwhals._compliant.typing import AliasNames, WindowFunction
|
||||
from narwhals._expression_parsing import ExprMetadata
|
||||
from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace
|
||||
from narwhals._sql.expr_str import SQLExprStringNamespace
|
||||
from narwhals._sql.namespace import SQLNamespace
|
||||
from narwhals.typing import (
|
||||
ModeKeepStrategy,
|
||||
NumericLiteral,
|
||||
PythonLiteral,
|
||||
RankMethod,
|
||||
TemporalLiteral,
|
||||
)
|
||||
|
||||
|
||||
class SQLExpr(LazyExpr[SQLLazyFrameT, NativeExprT], Protocol[SQLLazyFrameT, NativeExprT]):
|
||||
_call: EvalSeries[SQLLazyFrameT, NativeExprT]
|
||||
_evaluate_output_names: EvalNames[SQLLazyFrameT]
|
||||
_alias_output_names: AliasNames | None
|
||||
_version: Version
|
||||
_implementation: Implementation
|
||||
_metadata: ExprMetadata | None
|
||||
_window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
call: EvalSeries[SQLLazyFrameT, NativeExprT],
|
||||
window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None = None,
|
||||
*,
|
||||
evaluate_output_names: EvalNames[SQLLazyFrameT],
|
||||
alias_output_names: AliasNames | None,
|
||||
version: Version,
|
||||
implementation: Implementation = Implementation.DUCKDB,
|
||||
) -> None: ...
|
||||
|
||||
def __call__(self, df: SQLLazyFrameT) -> Sequence[NativeExprT]:
|
||||
return self._call(df)
|
||||
|
||||
def __narwhals_namespace__(
|
||||
self,
|
||||
) -> SQLNamespace[SQLLazyFrameT, Self, Any, NativeExprT]: ...
|
||||
|
||||
def _callable_to_eval_series(
|
||||
self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
|
||||
) -> EvalSeries[SQLLazyFrameT, NativeExprT]:
|
||||
def func(df: SQLLazyFrameT) -> list[NativeExprT]:
|
||||
native_series_list = self(df)
|
||||
other_native_series = {
|
||||
key: df._evaluate_expr(value)
|
||||
if self._is_expr(value)
|
||||
else self._lit(value)
|
||||
for key, value in expressifiable_args.items()
|
||||
}
|
||||
return [
|
||||
call(native_series, **other_native_series)
|
||||
for native_series in native_series_list
|
||||
]
|
||||
|
||||
return func
|
||||
|
||||
def _push_down_window_function(
|
||||
self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
|
||||
) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
|
||||
def window_f(
|
||||
df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
# If a function `f` is elementwise, and `g` is another function, then
|
||||
# - `f(g) over (window)`
|
||||
# - `f(g over (window))
|
||||
# are equivalent.
|
||||
# Make sure to only use with if `call` is elementwise!
|
||||
native_series_list = self.window_function(df, window_inputs)
|
||||
other_native_series = {
|
||||
key: df._evaluate_window_expr(value, window_inputs)
|
||||
if self._is_expr(value)
|
||||
else self._lit(value)
|
||||
for key, value in expressifiable_args.items()
|
||||
}
|
||||
return [
|
||||
call(native_series, **other_native_series)
|
||||
for native_series in native_series_list
|
||||
]
|
||||
|
||||
return window_f
|
||||
|
||||
def _with_window_function(
|
||||
self, window_function: WindowFunction[SQLLazyFrameT, NativeExprT]
|
||||
) -> Self:
|
||||
return self.__class__(
|
||||
self._call,
|
||||
window_function,
|
||||
evaluate_output_names=self._evaluate_output_names,
|
||||
alias_output_names=self._alias_output_names,
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)
|
||||
|
||||
def _with_callable(
|
||||
self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
|
||||
) -> Self:
|
||||
return self.__class__(
|
||||
self._callable_to_eval_series(call, **expressifiable_args),
|
||||
evaluate_output_names=self._evaluate_output_names,
|
||||
alias_output_names=self._alias_output_names,
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)
|
||||
|
||||
def _with_elementwise(
|
||||
self, call: Callable[..., NativeExprT], /, **expressifiable_args: Self | Any
|
||||
) -> Self:
|
||||
return self.__class__(
|
||||
self._callable_to_eval_series(call, **expressifiable_args),
|
||||
self._push_down_window_function(call, **expressifiable_args),
|
||||
evaluate_output_names=self._evaluate_output_names,
|
||||
alias_output_names=self._alias_output_names,
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)
|
||||
|
||||
def _with_binary(self, op: Callable[..., NativeExprT], other: Self | Any) -> Self:
|
||||
return self.__class__(
|
||||
self._callable_to_eval_series(op, other=other),
|
||||
self._push_down_window_function(op, other=other),
|
||||
evaluate_output_names=self._evaluate_output_names,
|
||||
alias_output_names=self._alias_output_names,
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)
|
||||
|
||||
def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
|
||||
current_alias_output_names = self._alias_output_names
|
||||
alias_output_names = (
|
||||
None
|
||||
if func is None
|
||||
else func
|
||||
if current_alias_output_names is None
|
||||
else lambda output_names: func(current_alias_output_names(output_names))
|
||||
)
|
||||
return type(self)(
|
||||
self._call,
|
||||
self._window_function,
|
||||
evaluate_output_names=self._evaluate_output_names,
|
||||
alias_output_names=alias_output_names,
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)
|
||||
|
||||
@property
|
||||
def window_function(self) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
|
||||
def default_window_func(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
assert not inputs.order_by # noqa: S101
|
||||
return [
|
||||
self._window_expression(expr, inputs.partition_by, inputs.order_by)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return self._window_function or default_window_func
|
||||
|
||||
def _function(self, name: str, *args: NativeExprT | PythonLiteral) -> NativeExprT:
|
||||
return self.__narwhals_namespace__()._function(name, *args)
|
||||
|
||||
def _lit(self, value: Any) -> NativeExprT:
|
||||
return self.__narwhals_namespace__()._lit(value)
|
||||
|
||||
def _coalesce(self, *expr: NativeExprT) -> NativeExprT:
|
||||
return self.__narwhals_namespace__()._coalesce(*expr)
|
||||
|
||||
def _count_star(self) -> NativeExprT: ...
|
||||
|
||||
def _when(
|
||||
self,
|
||||
condition: NativeExprT,
|
||||
value: NativeExprT,
|
||||
otherwise: NativeExprT | None = None,
|
||||
) -> NativeExprT:
|
||||
return self.__narwhals_namespace__()._when(condition, value, otherwise)
|
||||
|
||||
def _window_expression(
|
||||
self,
|
||||
expr: NativeExprT,
|
||||
partition_by: Sequence[str | NativeExprT] = (),
|
||||
order_by: Sequence[str | NativeExprT] = (),
|
||||
rows_start: int | None = None,
|
||||
rows_end: int | None = None,
|
||||
*,
|
||||
descending: Sequence[bool] | None = None,
|
||||
nulls_last: Sequence[bool] | None = None,
|
||||
) -> NativeExprT: ...
|
||||
|
||||
def _cum_window_func(
|
||||
self,
|
||||
func_name: Literal["sum", "max", "min", "count", "product"],
|
||||
*,
|
||||
reverse: bool,
|
||||
) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
|
||||
def func(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
return [
|
||||
self._window_expression(
|
||||
self._function(func_name, expr),
|
||||
inputs.partition_by,
|
||||
inputs.order_by,
|
||||
descending=[reverse] * len(inputs.order_by),
|
||||
nulls_last=[reverse] * len(inputs.order_by),
|
||||
rows_end=0,
|
||||
)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return func
|
||||
|
||||
def _rolling_window_func(
|
||||
self,
|
||||
func_name: Literal["sum", "mean", "std", "var"],
|
||||
window_size: int,
|
||||
min_samples: int,
|
||||
ddof: int | None = None,
|
||||
*,
|
||||
center: bool,
|
||||
) -> WindowFunction[SQLLazyFrameT, NativeExprT]:
|
||||
supported_funcs = ["sum", "mean", "std", "var"]
|
||||
if center:
|
||||
half = (window_size - 1) // 2
|
||||
remainder = (window_size - 1) % 2
|
||||
start = -(half + remainder)
|
||||
end = half
|
||||
else:
|
||||
start = -(window_size - 1)
|
||||
end = 0
|
||||
|
||||
def func(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
if func_name in {"sum", "mean"}:
|
||||
func_: str = func_name
|
||||
elif func_name == "var" and ddof == 0:
|
||||
func_ = "var_pop"
|
||||
elif func_name in "var" and ddof == 1:
|
||||
func_ = "var_samp"
|
||||
elif func_name == "std" and ddof == 0:
|
||||
func_ = "stddev_pop"
|
||||
elif func_name == "std" and ddof == 1:
|
||||
func_ = "stddev_samp"
|
||||
elif func_name in {"var", "std"}: # pragma: no cover
|
||||
msg = f"Only ddof=0 and ddof=1 are currently supported for rolling_{func_name}."
|
||||
raise ValueError(msg)
|
||||
else: # pragma: no cover
|
||||
msg = f"Only the following functions are supported: {supported_funcs}.\nGot: {func_name}."
|
||||
raise ValueError(msg)
|
||||
window_kwargs: Any = {
|
||||
"partition_by": inputs.partition_by,
|
||||
"order_by": inputs.order_by,
|
||||
"rows_start": start,
|
||||
"rows_end": end,
|
||||
}
|
||||
return [
|
||||
self._when(
|
||||
self._window_expression(
|
||||
self._function("count", expr), **window_kwargs
|
||||
)
|
||||
>= self._lit(min_samples),
|
||||
self._window_expression(self._function(func_, expr), **window_kwargs),
|
||||
)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return func
|
||||
|
||||
@classmethod
|
||||
def _is_expr(cls, obj: Self | Any) -> TypeIs[Self]:
|
||||
return hasattr(obj, "__narwhals_expr__")
|
||||
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
@classmethod
|
||||
def _alias_native(cls, expr: NativeExprT, name: str, /) -> NativeExprT: ...
|
||||
|
||||
@classmethod
|
||||
def _from_elementwise_horizontal_op(
|
||||
cls, func: Callable[[Iterable[NativeExprT]], NativeExprT], *exprs: Self
|
||||
) -> Self:
|
||||
def call(df: SQLLazyFrameT) -> Sequence[NativeExprT]:
|
||||
cols = (col for _expr in exprs for col in _expr(df))
|
||||
return [func(cols)]
|
||||
|
||||
def window_function(
|
||||
df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
cols = (
|
||||
col for _expr in exprs for col in _expr.window_function(df, window_inputs)
|
||||
)
|
||||
return [func(cols)]
|
||||
|
||||
context = exprs[0]
|
||||
return cls(
|
||||
call,
|
||||
window_function=window_function,
|
||||
evaluate_output_names=combine_evaluate_output_names(*exprs),
|
||||
alias_output_names=combine_alias_output_names(*exprs),
|
||||
version=context._version,
|
||||
implementation=context._implementation,
|
||||
)
|
||||
|
||||
def _is_multi_output_unnamed(self) -> bool:
|
||||
"""Return `True` for multi-output aggregations without names.
|
||||
|
||||
For example, column `'a'` only appears in the output as a grouping key:
|
||||
|
||||
df.group_by('a').agg(nw.all().sum())
|
||||
|
||||
It does not get included in:
|
||||
|
||||
nw.all().sum().
|
||||
"""
|
||||
assert self._metadata is not None # noqa: S101
|
||||
return self._metadata.expansion_kind.is_multi_unnamed()
|
||||
|
||||
# Binary
|
||||
def __eq__(self, other: Self) -> Self: # type: ignore[override]
|
||||
return self._with_binary(lambda expr, other: expr.__eq__(other), other)
|
||||
|
||||
def __ne__(self, other: Self) -> Self: # type: ignore[override]
|
||||
return self._with_binary(lambda expr, other: expr.__ne__(other), other)
|
||||
|
||||
def __add__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__add__(other), other)
|
||||
|
||||
def __sub__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__sub__(other), other)
|
||||
|
||||
def __rsub__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: other - expr, other).alias("literal")
|
||||
|
||||
def __mul__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__mul__(other), other)
|
||||
|
||||
def __truediv__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__truediv__(other), other)
|
||||
|
||||
def __rtruediv__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: other / expr, other).alias("literal")
|
||||
|
||||
def __floordiv__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__floordiv__(other), other)
|
||||
|
||||
def __rfloordiv__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: other // expr, other).alias(
|
||||
"literal"
|
||||
)
|
||||
|
||||
def __pow__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__pow__(other), other)
|
||||
|
||||
def __rpow__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: other**expr, other).alias("literal")
|
||||
|
||||
def __mod__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__mod__(other), other)
|
||||
|
||||
def __rmod__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: other % expr, other).alias("literal")
|
||||
|
||||
def __ge__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__ge__(other), other)
|
||||
|
||||
def __gt__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__gt__(other), other)
|
||||
|
||||
def __le__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__le__(other), other)
|
||||
|
||||
def __lt__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__lt__(other), other)
|
||||
|
||||
def __and__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__and__(other), other)
|
||||
|
||||
def __or__(self, other: Self) -> Self:
|
||||
return self._with_binary(lambda expr, other: expr.__or__(other), other)
|
||||
|
||||
# Aggregations
|
||||
def all(self) -> Self:
|
||||
def f(expr: NativeExprT) -> NativeExprT:
|
||||
return self._coalesce(self._function("bool_and", expr), self._lit(True))
|
||||
|
||||
def window_f(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
return [
|
||||
self._coalesce(
|
||||
self._window_expression(
|
||||
self._function("bool_and", expr), inputs.partition_by
|
||||
),
|
||||
self._lit(True),
|
||||
)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return self._with_callable(f)._with_window_function(window_f)
|
||||
|
||||
def any(self) -> Self:
|
||||
def f(expr: NativeExprT) -> NativeExprT:
|
||||
return self._coalesce(self._function("bool_or", expr), self._lit(False))
|
||||
|
||||
def window_f(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
return [
|
||||
self._coalesce(
|
||||
self._window_expression(
|
||||
self._function("bool_or", expr), inputs.partition_by
|
||||
),
|
||||
self._lit(False),
|
||||
)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return self._with_callable(f)._with_window_function(window_f)
|
||||
|
||||
def max(self) -> Self:
|
||||
return self._with_callable(lambda expr: self._function("max", expr))
|
||||
|
||||
def mean(self) -> Self:
|
||||
return self._with_callable(lambda expr: self._function("mean", expr))
|
||||
|
||||
def median(self) -> Self:
|
||||
return self._with_callable(lambda expr: self._function("median", expr))
|
||||
|
||||
def fill_nan(self, value: float | None) -> Self:
|
||||
def _fill_nan(expr: NativeExprT) -> NativeExprT:
|
||||
return self._when(self._function("isnan", expr), self._lit(value), expr)
|
||||
|
||||
return self._with_elementwise(_fill_nan)
|
||||
|
||||
def min(self) -> Self:
|
||||
return self._with_callable(lambda expr: self._function("min", expr))
|
||||
|
||||
def count(self) -> Self:
|
||||
return self._with_callable(lambda expr: self._function("count", expr))
|
||||
|
||||
def sum(self) -> Self:
|
||||
def f(expr: NativeExprT) -> NativeExprT:
|
||||
return self._coalesce(self._function("sum", expr), self._lit(0))
|
||||
|
||||
def window_f(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
return [
|
||||
self._coalesce(
|
||||
self._window_expression(
|
||||
self._function("sum", expr), inputs.partition_by
|
||||
),
|
||||
self._lit(0),
|
||||
)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return self._with_callable(f)._with_window_function(window_f)
|
||||
|
||||
# Elementwise
|
||||
def abs(self) -> Self:
|
||||
return self._with_elementwise(lambda expr: self._function("abs", expr))
|
||||
|
||||
def clip(
|
||||
self,
|
||||
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
|
||||
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
|
||||
) -> Self:
|
||||
def _clip_lower(expr: NativeExprT, lower_bound: Any) -> NativeExprT:
|
||||
return self._function("greatest", expr, lower_bound)
|
||||
|
||||
def _clip_upper(expr: NativeExprT, upper_bound: Any) -> NativeExprT:
|
||||
return self._function("least", expr, upper_bound)
|
||||
|
||||
def _clip_both(
|
||||
expr: NativeExprT, lower_bound: Any, upper_bound: Any
|
||||
) -> NativeExprT:
|
||||
return self._function(
|
||||
"greatest", self._function("least", expr, upper_bound), lower_bound
|
||||
)
|
||||
|
||||
if lower_bound is None:
|
||||
return self._with_elementwise(_clip_upper, upper_bound=upper_bound)
|
||||
if upper_bound is None:
|
||||
return self._with_elementwise(_clip_lower, lower_bound=lower_bound)
|
||||
return self._with_elementwise(
|
||||
_clip_both, lower_bound=lower_bound, upper_bound=upper_bound
|
||||
)
|
||||
|
||||
def is_null(self) -> Self:
|
||||
return self._with_elementwise(lambda expr: self._function("isnull", expr))
|
||||
|
||||
def round(self, decimals: int) -> Self:
|
||||
return self._with_elementwise(
|
||||
lambda expr: self._function("round", expr, self._lit(decimals))
|
||||
)
|
||||
|
||||
def sqrt(self) -> Self:
|
||||
def _sqrt(expr: NativeExprT) -> NativeExprT:
|
||||
return self._when(
|
||||
expr < self._lit(0), self._lit(float("nan")), self._function("sqrt", expr)
|
||||
)
|
||||
|
||||
return self._with_elementwise(_sqrt)
|
||||
|
||||
def exp(self) -> Self:
|
||||
return self._with_elementwise(lambda expr: self._function("exp", expr))
|
||||
|
||||
def log(self, base: float) -> Self:
|
||||
def _log(expr: NativeExprT) -> NativeExprT:
|
||||
F = self._function
|
||||
return self._when(
|
||||
expr < self._lit(0),
|
||||
self._lit(float("nan")),
|
||||
self._when(
|
||||
expr == self._lit(0),
|
||||
self._lit(float("-inf")),
|
||||
op.truediv(F("log", expr), F("log", self._lit(base))),
|
||||
),
|
||||
)
|
||||
|
||||
return self._with_elementwise(_log)
|
||||
|
||||
# Cumulative
|
||||
def cum_sum(self, *, reverse: bool) -> Self:
|
||||
return self._with_window_function(self._cum_window_func("sum", reverse=reverse))
|
||||
|
||||
def cum_max(self, *, reverse: bool) -> Self:
|
||||
return self._with_window_function(self._cum_window_func("max", reverse=reverse))
|
||||
|
||||
def cum_min(self, *, reverse: bool) -> Self:
|
||||
return self._with_window_function(self._cum_window_func("min", reverse=reverse))
|
||||
|
||||
def cum_count(self, *, reverse: bool) -> Self:
|
||||
return self._with_window_function(self._cum_window_func("count", reverse=reverse))
|
||||
|
||||
def cum_prod(self, *, reverse: bool) -> Self:
|
||||
return self._with_window_function(
|
||||
self._cum_window_func("product", reverse=reverse)
|
||||
)
|
||||
|
||||
# Rolling
|
||||
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
|
||||
return self._with_window_function(
|
||||
self._rolling_window_func("sum", window_size, min_samples, center=center)
|
||||
)
|
||||
|
||||
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
|
||||
return self._with_window_function(
|
||||
self._rolling_window_func("mean", window_size, min_samples, center=center)
|
||||
)
|
||||
|
||||
def rolling_var(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self:
|
||||
return self._with_window_function(
|
||||
self._rolling_window_func(
|
||||
"var", window_size, min_samples, ddof=ddof, center=center
|
||||
)
|
||||
)
|
||||
|
||||
def rolling_std(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self:
|
||||
return self._with_window_function(
|
||||
self._rolling_window_func(
|
||||
"std", window_size, min_samples, ddof=ddof, center=center
|
||||
)
|
||||
)
|
||||
|
||||
# Other window functions
|
||||
def diff(self) -> Self:
|
||||
def func(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
F = self._function
|
||||
window = self._window_expression
|
||||
return [
|
||||
op.sub(expr, window(F("lag", expr), inputs.partition_by, inputs.order_by))
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return self._with_window_function(func)
|
||||
|
||||
def shift(self, n: int) -> Self:
|
||||
def func(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
return [
|
||||
self._window_expression(
|
||||
self._function("lag", expr, n), inputs.partition_by, inputs.order_by
|
||||
)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return self._with_window_function(func)
|
||||
|
||||
def is_first_distinct(self) -> Self:
|
||||
def func(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
# pyright checkers think the return type is `list[bool]` because of `==`
|
||||
return [
|
||||
self._window_expression(
|
||||
self._function("row_number"),
|
||||
(*inputs.partition_by, expr),
|
||||
inputs.order_by,
|
||||
)
|
||||
== self._lit(1)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return self._with_window_function(func)
|
||||
|
||||
def is_last_distinct(self) -> Self:
|
||||
def func(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
return [
|
||||
self._window_expression(
|
||||
self._function("row_number"),
|
||||
(*inputs.partition_by, expr),
|
||||
inputs.order_by,
|
||||
descending=[True] * len(inputs.order_by),
|
||||
nulls_last=[True] * len(inputs.order_by),
|
||||
)
|
||||
== self._lit(1)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return self._with_window_function(func)
|
||||
|
||||
def rank(self, method: RankMethod, *, descending: bool) -> Self:
|
||||
if method in {"min", "max", "average"}:
|
||||
func = self._function("rank")
|
||||
elif method == "dense":
|
||||
func = self._function("dense_rank")
|
||||
else: # method == "ordinal"
|
||||
func = self._function("row_number")
|
||||
|
||||
def _rank(
|
||||
expr: NativeExprT,
|
||||
partition_by: Sequence[str | NativeExprT] = (),
|
||||
order_by: Sequence[str | NativeExprT] = (),
|
||||
*,
|
||||
descending: Sequence[bool],
|
||||
nulls_last: Sequence[bool],
|
||||
) -> NativeExprT:
|
||||
count_expr = self._count_star()
|
||||
window_kwargs: dict[str, Any] = {
|
||||
"partition_by": partition_by,
|
||||
"order_by": (expr, *order_by),
|
||||
"descending": descending,
|
||||
"nulls_last": nulls_last,
|
||||
}
|
||||
count_window_kwargs: dict[str, Any] = {"partition_by": (*partition_by, expr)}
|
||||
window = self._window_expression
|
||||
F = self._function
|
||||
if method == "max":
|
||||
rank_expr = op.sub(
|
||||
op.add(
|
||||
window(func, **window_kwargs),
|
||||
window(count_expr, **count_window_kwargs),
|
||||
),
|
||||
self._lit(1),
|
||||
)
|
||||
elif method == "average":
|
||||
rank_expr = op.add(
|
||||
window(func, **window_kwargs),
|
||||
op.truediv(
|
||||
op.sub(window(count_expr, **count_window_kwargs), self._lit(1)),
|
||||
self._lit(2.0),
|
||||
),
|
||||
)
|
||||
else:
|
||||
rank_expr = window(func, **window_kwargs)
|
||||
return self._when(~F("isnull", expr), rank_expr) # type: ignore[operator]
|
||||
|
||||
def _unpartitioned_rank(expr: NativeExprT) -> NativeExprT:
|
||||
return _rank(expr, descending=[descending], nulls_last=[True])
|
||||
|
||||
def _partitioned_rank(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
# node: when `descending` / `nulls_last` are supported in `.over`, they should be respected here
|
||||
# https://github.com/narwhals-dev/narwhals/issues/2790
|
||||
return [
|
||||
_rank(
|
||||
expr,
|
||||
inputs.partition_by,
|
||||
inputs.order_by,
|
||||
descending=[descending] + [False] * len(inputs.order_by),
|
||||
nulls_last=[True] + [False] * len(inputs.order_by),
|
||||
)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return self._with_callable(_unpartitioned_rank)._with_window_function(
|
||||
_partitioned_rank
|
||||
)
|
||||
|
||||
def is_unique(self) -> Self:
|
||||
def _is_unique(
|
||||
expr: NativeExprT, *partition_by: str | NativeExprT
|
||||
) -> NativeExprT:
|
||||
return self._window_expression(
|
||||
self._count_star(), (expr, *partition_by)
|
||||
) == self._lit(1)
|
||||
|
||||
def _unpartitioned_is_unique(expr: NativeExprT) -> NativeExprT:
|
||||
return _is_unique(expr)
|
||||
|
||||
def _partitioned_is_unique(
|
||||
df: SQLLazyFrameT, inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
assert not inputs.order_by # noqa: S101
|
||||
return [_is_unique(expr, *inputs.partition_by) for expr in self(df)]
|
||||
|
||||
return self._with_callable(_unpartitioned_is_unique)._with_window_function(
|
||||
_partitioned_is_unique
|
||||
)
|
||||
|
||||
# Other
|
||||
def over(
|
||||
self, partition_by: Sequence[str | NativeExprT], order_by: Sequence[str]
|
||||
) -> Self:
|
||||
def func(df: SQLLazyFrameT) -> Sequence[NativeExprT]:
|
||||
return self.window_function(df, WindowInputs(partition_by, order_by))
|
||||
|
||||
return self.__class__(
|
||||
func,
|
||||
evaluate_output_names=self._evaluate_output_names,
|
||||
alias_output_names=self._alias_output_names,
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)
|
||||
|
||||
def mode(self, *, keep: ModeKeepStrategy) -> Self:
|
||||
if keep != "any":
|
||||
msg = (
|
||||
f"`Expr.mode(keep='{keep}')` is not implemented for backend {self._implementation}\n\n"
|
||||
"Hint: Use `nw.col(...).mode(keep='any')` instead."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
return self._with_callable(lambda expr: self._function("mode", expr))
|
||||
|
||||
# Namespaces
|
||||
@property
|
||||
def str(self) -> SQLExprStringNamespace[Self]: ...
|
||||
|
||||
@property
|
||||
def dt(self) -> SQLExprDateTimeNamesSpace[Self]: ...
|
||||
|
||||
# Not implemented
|
||||
|
||||
arg_max: not_implemented = not_implemented()
|
||||
arg_min: not_implemented = not_implemented()
|
||||
arg_true: not_implemented = not_implemented()
|
||||
cat: not_implemented = not_implemented() # type: ignore[assignment]
|
||||
drop_nulls: not_implemented = not_implemented()
|
||||
ewm_mean: not_implemented = not_implemented()
|
||||
gather_every: not_implemented = not_implemented()
|
||||
head: not_implemented = not_implemented()
|
||||
map_batches: not_implemented = not_implemented()
|
||||
replace_strict: not_implemented = not_implemented()
|
||||
sort: not_implemented = not_implemented()
|
||||
tail: not_implemented = not_implemented()
|
||||
sample: not_implemented = not_implemented()
|
||||
unique: not_implemented = not_implemented()
|
48
lib/python3.11/site-packages/narwhals/_sql/expr_dt.py
Normal file
48
lib/python3.11/site-packages/narwhals/_sql/expr_dt.py
Normal file
@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Generic
|
||||
|
||||
from narwhals._compliant import LazyExprNamespace
|
||||
from narwhals._compliant.any_namespace import DateTimeNamespace
|
||||
from narwhals._sql.typing import SQLExprT
|
||||
|
||||
|
||||
class SQLExprDateTimeNamesSpace(
|
||||
LazyExprNamespace[SQLExprT], DateTimeNamespace[SQLExprT], Generic[SQLExprT]
|
||||
):
|
||||
def _function(self, name: str, *args: Any) -> SQLExprT:
|
||||
return self.compliant._function(name, *args) # type: ignore[no-any-return]
|
||||
|
||||
def year(self) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(lambda expr: self._function("year", expr))
|
||||
|
||||
def month(self) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function("month", expr)
|
||||
)
|
||||
|
||||
def day(self) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(lambda expr: self._function("day", expr))
|
||||
|
||||
def hour(self) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(lambda expr: self._function("hour", expr))
|
||||
|
||||
def minute(self) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function("minute", expr)
|
||||
)
|
||||
|
||||
def second(self) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function("second", expr)
|
||||
)
|
||||
|
||||
def ordinal_day(self) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function("dayofyear", expr)
|
||||
)
|
||||
|
||||
def date(self) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function("to_date", expr)
|
||||
)
|
138
lib/python3.11/site-packages/narwhals/_sql/expr_str.py
Normal file
138
lib/python3.11/site-packages/narwhals/_sql/expr_str.py
Normal file
@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Generic
|
||||
|
||||
from narwhals._compliant import LazyExprNamespace
|
||||
from narwhals._compliant.any_namespace import StringNamespace
|
||||
from narwhals._sql.typing import SQLExprT
|
||||
|
||||
|
||||
class SQLExprStringNamespace(
|
||||
LazyExprNamespace[SQLExprT], StringNamespace[SQLExprT], Generic[SQLExprT]
|
||||
):
|
||||
def _lit(self, value: Any) -> SQLExprT:
|
||||
return self.compliant._lit(value) # type: ignore[no-any-return]
|
||||
|
||||
def _function(self, name: str, *args: Any) -> SQLExprT:
|
||||
return self.compliant._function(name, *args) # type: ignore[no-any-return]
|
||||
|
||||
def _when(self, condition: Any, value: Any, otherwise: Any | None = None) -> SQLExprT:
|
||||
return self.compliant._when(condition, value, otherwise) # type: ignore[no-any-return]
|
||||
|
||||
def contains(self, pattern: str, *, literal: bool) -> SQLExprT:
|
||||
def func(expr: Any) -> Any:
|
||||
if literal:
|
||||
return self._function("contains", expr, self._lit(pattern))
|
||||
return self._function("regexp_matches", expr, self._lit(pattern))
|
||||
|
||||
return self.compliant._with_elementwise(func)
|
||||
|
||||
def ends_with(self, suffix: str) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function("ends_with", expr, self._lit(suffix))
|
||||
)
|
||||
|
||||
def len_chars(self) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function("length", expr)
|
||||
)
|
||||
|
||||
def replace_all(
|
||||
self, pattern: str, value: str | SQLExprT, *, literal: bool
|
||||
) -> SQLExprT:
|
||||
fname: str = "replace" if literal else "regexp_replace"
|
||||
|
||||
options: list[Any] = []
|
||||
if not literal and self.compliant._implementation.is_duckdb():
|
||||
options = [self._lit("g")]
|
||||
|
||||
if isinstance(value, str):
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function(
|
||||
fname, expr, self._lit(pattern), self._lit(value), *options
|
||||
)
|
||||
)
|
||||
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr, value: self._function(
|
||||
fname, expr, self._lit(pattern), value, *options
|
||||
),
|
||||
value=value,
|
||||
)
|
||||
|
||||
def slice(self, offset: int, length: int | None) -> SQLExprT:
|
||||
def func(expr: SQLExprT) -> SQLExprT:
|
||||
col_length = self._function("length", expr)
|
||||
|
||||
_offset = (
|
||||
col_length + self._lit(offset + 1)
|
||||
if offset < 0
|
||||
else self._lit(offset + 1)
|
||||
)
|
||||
_length = self._lit(length) if length is not None else col_length
|
||||
return self._function("substr", expr, _offset, _length)
|
||||
|
||||
return self.compliant._with_elementwise(func)
|
||||
|
||||
def split(self, by: str) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function("str_split", expr, self._lit(by))
|
||||
)
|
||||
|
||||
def starts_with(self, prefix: str) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function("starts_with", expr, self._lit(prefix))
|
||||
)
|
||||
|
||||
def strip_chars(self, characters: str | None) -> SQLExprT:
|
||||
import string
|
||||
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function(
|
||||
"trim",
|
||||
expr,
|
||||
self._lit(string.whitespace if characters is None else characters),
|
||||
)
|
||||
)
|
||||
|
||||
def to_lowercase(self) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function("lower", expr)
|
||||
)
|
||||
|
||||
def to_uppercase(self) -> SQLExprT:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: self._function("upper", expr)
|
||||
)
|
||||
|
||||
def zfill(self, width: int) -> SQLExprT:
|
||||
# There is no built-in zfill function, so we need to implement it manually
|
||||
# using string manipulation functions.
|
||||
|
||||
def func(expr: Any) -> Any:
|
||||
less_than_width = self._function("length", expr) < self._lit(width)
|
||||
zero, hyphen, plus = self._lit("0"), self._lit("-"), self._lit("+")
|
||||
|
||||
starts_with_minus = self._function("starts_with", expr, hyphen)
|
||||
starts_with_plus = self._function("starts_with", expr, plus)
|
||||
substring = self._function("substr", expr, self._lit(2))
|
||||
padded_substring = self._function(
|
||||
"lpad", substring, self._lit(width - 1), zero
|
||||
)
|
||||
return self._when(
|
||||
starts_with_minus & less_than_width,
|
||||
self._function("concat", hyphen, padded_substring),
|
||||
self._when(
|
||||
starts_with_plus & less_than_width,
|
||||
self._function("concat", plus, padded_substring),
|
||||
self._when(
|
||||
less_than_width,
|
||||
self._function("lpad", expr, self._lit(width), zero),
|
||||
expr,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# can't use `_with_elementwise` due to `when` operator.
|
||||
# TODO(unassigned): implement `window_func` like we do in `Expr.cast`
|
||||
return self.compliant._with_callable(func)
|
45
lib/python3.11/site-packages/narwhals/_sql/group_by.py
Normal file
45
lib/python3.11/site-packages/narwhals/_sql/group_by.py
Normal file
@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from narwhals._compliant.group_by import CompliantGroupBy, ParseKeysGroupBy
|
||||
from narwhals._compliant.typing import CompliantLazyFrameT, NativeExprT_co
|
||||
from narwhals._sql.typing import SQLExprT_contra
|
||||
from narwhals._utils import zip_strict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator
|
||||
|
||||
|
||||
class SQLGroupBy(
|
||||
ParseKeysGroupBy[CompliantLazyFrameT, SQLExprT_contra],
|
||||
CompliantGroupBy[CompliantLazyFrameT, SQLExprT_contra],
|
||||
Protocol[CompliantLazyFrameT, SQLExprT_contra, NativeExprT_co],
|
||||
):
|
||||
_keys: list[str]
|
||||
_output_key_names: list[str]
|
||||
|
||||
def _evaluate_expr(self, expr: SQLExprT_contra, /) -> Iterator[NativeExprT_co]:
|
||||
output_names = expr._evaluate_output_names(self.compliant)
|
||||
aliases = (
|
||||
expr._alias_output_names(output_names)
|
||||
if expr._alias_output_names
|
||||
else output_names
|
||||
)
|
||||
native_exprs = expr(self.compliant)
|
||||
if expr._is_multi_output_unnamed():
|
||||
exclude = {*self._keys, *self._output_key_names}
|
||||
for native_expr, name, alias in zip_strict(
|
||||
native_exprs, output_names, aliases
|
||||
):
|
||||
if name not in exclude:
|
||||
yield expr._alias_native(native_expr, alias)
|
||||
else:
|
||||
for native_expr, alias in zip_strict(native_exprs, aliases):
|
||||
yield expr._alias_native(native_expr, alias)
|
||||
|
||||
def _evaluate_exprs(
|
||||
self, exprs: Iterable[SQLExprT_contra], /
|
||||
) -> Iterator[NativeExprT_co]:
|
||||
for expr in exprs:
|
||||
yield from self._evaluate_expr(expr)
|
73
lib/python3.11/site-packages/narwhals/_sql/namespace.py
Normal file
73
lib/python3.11/site-packages/narwhals/_sql/namespace.py
Normal file
@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from narwhals._compliant import LazyNamespace
|
||||
from narwhals._compliant.typing import NativeExprT, NativeFrameT_co
|
||||
from narwhals._sql.typing import SQLExprT, SQLLazyFrameT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from narwhals.typing import PythonLiteral
|
||||
|
||||
|
||||
class SQLNamespace(
|
||||
LazyNamespace[SQLLazyFrameT, SQLExprT, NativeFrameT_co],
|
||||
Protocol[SQLLazyFrameT, SQLExprT, NativeFrameT_co, NativeExprT],
|
||||
):
|
||||
def _function(self, name: str, *args: NativeExprT | PythonLiteral) -> NativeExprT: ...
|
||||
def _lit(self, value: Any) -> NativeExprT: ...
|
||||
def _when(
|
||||
self,
|
||||
condition: NativeExprT,
|
||||
value: NativeExprT,
|
||||
otherwise: NativeExprT | None = None,
|
||||
) -> NativeExprT: ...
|
||||
def _coalesce(self, *exprs: NativeExprT) -> NativeExprT: ...
|
||||
|
||||
# Horizontal functions
|
||||
def any_horizontal(self, *exprs: SQLExprT, ignore_nulls: bool) -> SQLExprT:
|
||||
def func(cols: Iterable[NativeExprT]) -> NativeExprT:
|
||||
if ignore_nulls:
|
||||
cols = (self._coalesce(col, self._lit(False)) for col in cols)
|
||||
return reduce(operator.or_, cols)
|
||||
|
||||
return self._expr._from_elementwise_horizontal_op(func, *exprs)
|
||||
|
||||
def all_horizontal(self, *exprs: SQLExprT, ignore_nulls: bool) -> SQLExprT:
|
||||
def func(cols: Iterable[NativeExprT]) -> NativeExprT:
|
||||
if ignore_nulls:
|
||||
cols = (self._coalesce(col, self._lit(True)) for col in cols)
|
||||
return reduce(operator.and_, cols)
|
||||
|
||||
return self._expr._from_elementwise_horizontal_op(func, *exprs)
|
||||
|
||||
def max_horizontal(self, *exprs: SQLExprT) -> SQLExprT:
|
||||
def func(cols: Iterable[NativeExprT]) -> NativeExprT:
|
||||
return self._function("greatest", *cols)
|
||||
|
||||
return self._expr._from_elementwise_horizontal_op(func, *exprs)
|
||||
|
||||
def min_horizontal(self, *exprs: SQLExprT) -> SQLExprT:
|
||||
def func(cols: Iterable[NativeExprT]) -> NativeExprT:
|
||||
return self._function("least", *cols)
|
||||
|
||||
return self._expr._from_elementwise_horizontal_op(func, *exprs)
|
||||
|
||||
def sum_horizontal(self, *exprs: SQLExprT) -> SQLExprT:
|
||||
def func(cols: Iterable[NativeExprT]) -> NativeExprT:
|
||||
return reduce(
|
||||
operator.add, (self._coalesce(col, self._lit(0)) for col in cols)
|
||||
)
|
||||
|
||||
return self._expr._from_elementwise_horizontal_op(func, *exprs)
|
||||
|
||||
# Other
|
||||
def coalesce(self, *exprs: SQLExprT) -> SQLExprT:
|
||||
def func(cols: Iterable[NativeExprT]) -> NativeExprT:
|
||||
return self._coalesce(*cols)
|
||||
|
||||
return self._expr._from_elementwise_horizontal_op(func, *exprs)
|
14
lib/python3.11/site-packages/narwhals/_sql/typing.py
Normal file
14
lib/python3.11/site-packages/narwhals/_sql/typing.py
Normal file
@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from narwhals._sql.dataframe import SQLLazyFrame
|
||||
from narwhals._sql.expr import SQLExpr
|
||||
|
||||
SQLExprAny = SQLExpr[Any, Any]
|
||||
SQLLazyFrameAny = SQLLazyFrame[Any, Any, Any]
|
||||
|
||||
SQLExprT = TypeVar("SQLExprT", bound="SQLExprAny")
|
||||
SQLExprT_contra = TypeVar("SQLExprT_contra", bound="SQLExprAny", contravariant=True)
|
||||
SQLLazyFrameT = TypeVar("SQLLazyFrameT", bound="SQLLazyFrameAny")
|
106
lib/python3.11/site-packages/narwhals/_sql/when_then.py
Normal file
106
lib/python3.11/site-packages/narwhals/_sql/when_then.py
Normal file
@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from narwhals._compliant.typing import NativeExprT
|
||||
from narwhals._compliant.when_then import CompliantThen, CompliantWhen
|
||||
from narwhals._sql.typing import SQLExprT, SQLLazyFrameT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from narwhals._compliant.typing import WindowFunction
|
||||
from narwhals._compliant.when_then import IntoExpr
|
||||
from narwhals._compliant.window import WindowInputs
|
||||
from narwhals._utils import _LimitedContext
|
||||
|
||||
|
||||
class SQLWhen(
|
||||
CompliantWhen[SQLLazyFrameT, NativeExprT, SQLExprT],
|
||||
Protocol[SQLLazyFrameT, NativeExprT, SQLExprT],
|
||||
):
|
||||
@property
|
||||
def _then(self) -> type[SQLThen[SQLLazyFrameT, NativeExprT, SQLExprT]]: ...
|
||||
|
||||
def __call__(self, df: SQLLazyFrameT) -> Sequence[NativeExprT]:
|
||||
is_expr = self._condition._is_expr
|
||||
when = df.__narwhals_namespace__()._when
|
||||
lit = df.__narwhals_namespace__()._lit
|
||||
condition = df._evaluate_expr(self._condition)
|
||||
then_ = self._then_value
|
||||
then = df._evaluate_expr(then_) if is_expr(then_) else lit(then_)
|
||||
other_ = self._otherwise_value
|
||||
if other_ is None:
|
||||
result = when(condition, then)
|
||||
else:
|
||||
otherwise = df._evaluate_expr(other_) if is_expr(other_) else lit(other_)
|
||||
result = when(condition, then).otherwise(otherwise)
|
||||
return [result]
|
||||
|
||||
@classmethod
|
||||
def from_expr(cls, condition: SQLExprT, /, *, context: _LimitedContext) -> Self:
|
||||
obj = cls.__new__(cls)
|
||||
obj._condition = condition
|
||||
obj._then_value = None
|
||||
obj._otherwise_value = None
|
||||
obj._implementation = context._implementation
|
||||
obj._version = context._version
|
||||
return obj
|
||||
|
||||
def _window_function(
|
||||
self, df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
when = df.__narwhals_namespace__()._when
|
||||
lit = df.__narwhals_namespace__()._lit
|
||||
is_expr = self._condition._is_expr
|
||||
condition = self._condition.window_function(df, window_inputs)[0]
|
||||
then_ = self._then_value
|
||||
then = (
|
||||
then_.window_function(df, window_inputs)[0] if is_expr(then_) else lit(then_)
|
||||
)
|
||||
|
||||
other_ = self._otherwise_value
|
||||
if other_ is None:
|
||||
result = when(condition, then)
|
||||
else:
|
||||
other = (
|
||||
other_.window_function(df, window_inputs)[0]
|
||||
if is_expr(other_)
|
||||
else lit(other_)
|
||||
)
|
||||
result = when(condition, then).otherwise(other)
|
||||
return [result]
|
||||
|
||||
|
||||
class SQLThen(
|
||||
CompliantThen[
|
||||
SQLLazyFrameT,
|
||||
NativeExprT,
|
||||
SQLExprT,
|
||||
SQLWhen[SQLLazyFrameT, NativeExprT, SQLExprT],
|
||||
],
|
||||
Protocol[SQLLazyFrameT, NativeExprT, SQLExprT],
|
||||
):
|
||||
_window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None
|
||||
|
||||
@classmethod
|
||||
def from_when(
|
||||
cls,
|
||||
when: SQLWhen[SQLLazyFrameT, NativeExprT, SQLExprT],
|
||||
then: IntoExpr[NativeExprT, SQLExprT],
|
||||
/,
|
||||
) -> Self:
|
||||
when._then_value = then
|
||||
obj = cls.__new__(cls)
|
||||
obj._call = when
|
||||
obj._window_function = when._window_function
|
||||
obj._when_value = when
|
||||
obj._evaluate_output_names = getattr(
|
||||
then, "_evaluate_output_names", lambda _df: ["literal"]
|
||||
)
|
||||
obj._alias_output_names = getattr(then, "_alias_output_names", None)
|
||||
obj._implementation = when._implementation
|
||||
obj._version = when._version
|
||||
return obj
|
Reference in New Issue
Block a user