Files
2025-09-07 22:09:54 +02:00

139 lines
5.1 KiB
Python

from __future__ import annotations
from typing import Any, Generic
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StringNamespace
from narwhals._sql.typing import SQLExprT
class SQLExprStringNamespace(
    LazyExprNamespace[SQLExprT], StringNamespace[SQLExprT], Generic[SQLExprT]
):
    """String-method namespace built on top of backend function calls.

    Each public method creates a small callable that turns a native
    expression into a named backend function call (through
    ``self.compliant._function``) and hands that callable to the wrapped
    compliant expression.
    """

    def _lit(self, value: Any) -> SQLExprT:
        """Wrap ``value`` as a literal using the compliant expression."""
        return self.compliant._lit(value)  # type: ignore[no-any-return]

    def _function(self, name: str, *args: Any) -> SQLExprT:
        """Build a call to the backend function ``name`` with ``args``."""
        return self.compliant._function(name, *args)  # type: ignore[no-any-return]

    def _when(
        self, condition: Any, value: Any, otherwise: Any | None = None
    ) -> SQLExprT:
        """Build a conditional via the compliant expression's ``_when``."""
        return self.compliant._when(condition, value, otherwise)  # type: ignore[no-any-return]

    def contains(self, pattern: str, *, literal: bool) -> SQLExprT:
        """Test each string against ``pattern``.

        Emits a ``contains`` call when ``literal`` is true and a
        ``regexp_matches`` call otherwise.
        """
        fname = "contains" if literal else "regexp_matches"

        def matcher(expr: Any) -> Any:
            return self._function(fname, expr, self._lit(pattern))

        return self.compliant._with_elementwise(matcher)

    def ends_with(self, suffix: str) -> SQLExprT:
        """Emit an ``ends_with`` call against the literal ``suffix``."""

        def check_suffix(expr: Any) -> Any:
            return self._function("ends_with", expr, self._lit(suffix))

        return self.compliant._with_elementwise(check_suffix)

    def len_chars(self) -> SQLExprT:
        """Emit a ``length`` call on each string."""

        def measure(expr: Any) -> Any:
            return self._function("length", expr)

        return self.compliant._with_elementwise(measure)

    def replace_all(
        self, pattern: str, value: str | SQLExprT, *, literal: bool
    ) -> SQLExprT:
        """Replace matches of ``pattern`` with ``value``.

        Uses ``replace`` for literal patterns and ``regexp_replace``
        otherwise. ``value`` may be a plain string (wrapped as a literal) or
        another expression (forwarded to ``_with_elementwise`` by keyword).
        """
        fname = "regexp_replace" if not literal else "replace"
        # Only the DuckDB regex path receives the trailing "g" option literal.
        needs_flag = not literal and self.compliant._implementation.is_duckdb()
        extra: tuple[Any, ...] = (self._lit("g"),) if needs_flag else ()

        if isinstance(value, str):

            def with_literal(expr: Any) -> Any:
                return self._function(
                    fname, expr, self._lit(pattern), self._lit(value), *extra
                )

            return self.compliant._with_elementwise(with_literal)

        # The parameter must be called `value`: it is supplied by keyword.
        def with_expression(expr: Any, value: Any) -> Any:
            return self._function(fname, expr, self._lit(pattern), value, *extra)

        return self.compliant._with_elementwise(with_expression, value=value)

    def slice(self, offset: int, length: int | None) -> SQLExprT:
        """Emit a ``substr`` call selecting ``length`` characters from ``offset``.

        The start position passed to ``substr`` is ``offset + 1``; for a
        negative ``offset`` the string length is added to it. When ``length``
        is ``None`` the string length itself is passed as the count.
        """
        # Shift by one (presumably because `substr` positions are 1-based).
        start = offset + 1

        def func(expr: SQLExprT) -> SQLExprT:
            col_length = self._function("length", expr)
            if offset < 0:
                begin = col_length + self._lit(start)
            else:
                begin = self._lit(start)
            if length is None:
                count = col_length
            else:
                count = self._lit(length)
            return self._function("substr", expr, begin, count)

        return self.compliant._with_elementwise(func)

    def split(self, by: str) -> SQLExprT:
        """Emit a ``str_split`` call with the literal separator ``by``."""

        def splitter(expr: Any) -> Any:
            return self._function("str_split", expr, self._lit(by))

        return self.compliant._with_elementwise(splitter)

    def starts_with(self, prefix: str) -> SQLExprT:
        """Emit a ``starts_with`` call against the literal ``prefix``."""

        def check_prefix(expr: Any) -> Any:
            return self._function("starts_with", expr, self._lit(prefix))

        return self.compliant._with_elementwise(check_prefix)

    def strip_chars(self, characters: str | None) -> SQLExprT:
        """Emit a ``trim`` call removing ``characters``.

        When ``characters`` is ``None``, ``string.whitespace`` is used.
        """
        import string

        to_strip = string.whitespace if characters is None else characters

        def trimmer(expr: Any) -> Any:
            return self._function("trim", expr, self._lit(to_strip))

        return self.compliant._with_elementwise(trimmer)

    def to_lowercase(self) -> SQLExprT:
        """Emit a ``lower`` call on each string."""

        def lower(expr: Any) -> Any:
            return self._function("lower", expr)

        return self.compliant._with_elementwise(lower)

    def to_uppercase(self) -> SQLExprT:
        """Emit an ``upper`` call on each string."""

        def upper(expr: Any) -> Any:
            return self._function("upper", expr)

        return self.compliant._with_elementwise(upper)

    def zfill(self, width: int) -> SQLExprT:
        """Left-pad strings shorter than ``width`` with ``"0"``.

        A leading ``-`` or ``+`` is kept in front of the padding. Strings
        that are not shorter than ``width`` are returned unchanged.
        """
        # There is no built-in zfill function, so it is assembled by hand
        # from string manipulation functions.

        def func(expr: Any) -> Any:
            current_length = self._function("length", expr)
            too_short = current_length < self._lit(width)
            zero = self._lit("0")
            hyphen = self._lit("-")
            plus = self._lit("+")

            has_minus = self._function("starts_with", expr, hyphen)
            has_plus = self._function("starts_with", expr, plus)
            # Everything after the first character, padded to leave room
            # for the sign that is re-attached below.
            unsigned = self._function("substr", expr, self._lit(2))
            padded_unsigned = self._function(
                "lpad", unsigned, self._lit(width - 1), zero
            )

            minus_condition = has_minus & too_short
            minus_result = self._function("concat", hyphen, padded_unsigned)
            plus_condition = has_plus & too_short
            plus_result = self._function("concat", plus, padded_unsigned)
            plain_result = self._function("lpad", expr, self._lit(width), zero)

            # Chain the conditionals from the innermost fallback outwards.
            no_sign_case = self._when(too_short, plain_result, expr)
            plus_case = self._when(plus_condition, plus_result, no_sign_case)
            return self._when(minus_condition, minus_result, plus_case)

        # can't use `_with_elementwise` due to `when` operator.
        # TODO(unassigned): implement `window_func` like we do in `Expr.cast`
        return self.compliant._with_callable(func)