done
This commit is contained in:
106
lib/python3.11/site-packages/narwhals/_sql/when_then.py
Normal file
106
lib/python3.11/site-packages/narwhals/_sql/when_then.py
Normal file
@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from narwhals._compliant.typing import NativeExprT
|
||||
from narwhals._compliant.when_then import CompliantThen, CompliantWhen
|
||||
from narwhals._sql.typing import SQLExprT, SQLLazyFrameT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from narwhals._compliant.typing import WindowFunction
|
||||
from narwhals._compliant.when_then import IntoExpr
|
||||
from narwhals._compliant.window import WindowInputs
|
||||
from narwhals._utils import _LimitedContext
|
||||
|
||||
|
||||
class SQLWhen(
|
||||
CompliantWhen[SQLLazyFrameT, NativeExprT, SQLExprT],
|
||||
Protocol[SQLLazyFrameT, NativeExprT, SQLExprT],
|
||||
):
|
||||
@property
|
||||
def _then(self) -> type[SQLThen[SQLLazyFrameT, NativeExprT, SQLExprT]]: ...
|
||||
|
||||
def __call__(self, df: SQLLazyFrameT) -> Sequence[NativeExprT]:
|
||||
is_expr = self._condition._is_expr
|
||||
when = df.__narwhals_namespace__()._when
|
||||
lit = df.__narwhals_namespace__()._lit
|
||||
condition = df._evaluate_expr(self._condition)
|
||||
then_ = self._then_value
|
||||
then = df._evaluate_expr(then_) if is_expr(then_) else lit(then_)
|
||||
other_ = self._otherwise_value
|
||||
if other_ is None:
|
||||
result = when(condition, then)
|
||||
else:
|
||||
otherwise = df._evaluate_expr(other_) if is_expr(other_) else lit(other_)
|
||||
result = when(condition, then).otherwise(otherwise)
|
||||
return [result]
|
||||
|
||||
@classmethod
|
||||
def from_expr(cls, condition: SQLExprT, /, *, context: _LimitedContext) -> Self:
|
||||
obj = cls.__new__(cls)
|
||||
obj._condition = condition
|
||||
obj._then_value = None
|
||||
obj._otherwise_value = None
|
||||
obj._implementation = context._implementation
|
||||
obj._version = context._version
|
||||
return obj
|
||||
|
||||
def _window_function(
|
||||
self, df: SQLLazyFrameT, window_inputs: WindowInputs[NativeExprT]
|
||||
) -> Sequence[NativeExprT]:
|
||||
when = df.__narwhals_namespace__()._when
|
||||
lit = df.__narwhals_namespace__()._lit
|
||||
is_expr = self._condition._is_expr
|
||||
condition = self._condition.window_function(df, window_inputs)[0]
|
||||
then_ = self._then_value
|
||||
then = (
|
||||
then_.window_function(df, window_inputs)[0] if is_expr(then_) else lit(then_)
|
||||
)
|
||||
|
||||
other_ = self._otherwise_value
|
||||
if other_ is None:
|
||||
result = when(condition, then)
|
||||
else:
|
||||
other = (
|
||||
other_.window_function(df, window_inputs)[0]
|
||||
if is_expr(other_)
|
||||
else lit(other_)
|
||||
)
|
||||
result = when(condition, then).otherwise(other)
|
||||
return [result]
|
||||
|
||||
|
||||
class SQLThen(
|
||||
CompliantThen[
|
||||
SQLLazyFrameT,
|
||||
NativeExprT,
|
||||
SQLExprT,
|
||||
SQLWhen[SQLLazyFrameT, NativeExprT, SQLExprT],
|
||||
],
|
||||
Protocol[SQLLazyFrameT, NativeExprT, SQLExprT],
|
||||
):
|
||||
_window_function: WindowFunction[SQLLazyFrameT, NativeExprT] | None
|
||||
|
||||
@classmethod
|
||||
def from_when(
|
||||
cls,
|
||||
when: SQLWhen[SQLLazyFrameT, NativeExprT, SQLExprT],
|
||||
then: IntoExpr[NativeExprT, SQLExprT],
|
||||
/,
|
||||
) -> Self:
|
||||
when._then_value = then
|
||||
obj = cls.__new__(cls)
|
||||
obj._call = when
|
||||
obj._window_function = when._window_function
|
||||
obj._when_value = when
|
||||
obj._evaluate_output_names = getattr(
|
||||
then, "_evaluate_output_names", lambda _df: ["literal"]
|
||||
)
|
||||
obj._alias_output_names = getattr(then, "_alias_output_names", None)
|
||||
obj._implementation = when._implementation
|
||||
obj._version = when._version
|
||||
return obj
|
Reference in New Issue
Block a user