done
This commit is contained in:
542
lib/python3.11/site-packages/narwhals/_duckdb/dataframe.py
Normal file
542
lib/python3.11/site-packages/narwhals/_duckdb/dataframe.py
Normal file
@ -0,0 +1,542 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import reduce
|
||||
from operator import and_
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import duckdb
|
||||
from duckdb import StarExpression
|
||||
|
||||
from narwhals._duckdb.utils import (
|
||||
DeferredTimeZone,
|
||||
F,
|
||||
catch_duckdb_exception,
|
||||
col,
|
||||
evaluate_exprs,
|
||||
join_column_names,
|
||||
lit,
|
||||
native_to_narwhals_dtype,
|
||||
window_expression,
|
||||
)
|
||||
from narwhals._sql.dataframe import SQLLazyFrame
|
||||
from narwhals._utils import (
|
||||
Implementation,
|
||||
ValidateBackendVersion,
|
||||
Version,
|
||||
generate_temporary_column_name,
|
||||
not_implemented,
|
||||
parse_columns_to_drop,
|
||||
requires,
|
||||
zip_strict,
|
||||
)
|
||||
from narwhals.dependencies import get_duckdb
|
||||
from narwhals.exceptions import InvalidOperationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from duckdb import Expression
|
||||
from duckdb.typing import DuckDBPyType
|
||||
from typing_extensions import Self, TypeIs
|
||||
|
||||
from narwhals._compliant.typing import CompliantDataFrameAny
|
||||
from narwhals._duckdb.expr import DuckDBExpr
|
||||
from narwhals._duckdb.group_by import DuckDBGroupBy
|
||||
from narwhals._duckdb.namespace import DuckDBNamespace
|
||||
from narwhals._duckdb.series import DuckDBInterchangeSeries
|
||||
from narwhals._typing import _EagerAllowedImpl
|
||||
from narwhals._utils import _LimitedContext
|
||||
from narwhals.dataframe import LazyFrame
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.stable.v1 import DataFrame as DataFrameV1
|
||||
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
|
||||
|
||||
|
||||
class DuckDBLazyFrame(
    SQLLazyFrame[
        "DuckDBExpr",
        "duckdb.DuckDBPyRelation",
        "LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]",
    ],
    ValidateBackendVersion,
):
    """Narwhals-compliant lazy frame backed by a native ``duckdb.DuckDBPyRelation``.

    Every operation builds a new relation lazily; no query executes until
    ``collect`` (or a sink such as ``sink_parquet``) is called.
    """

    _implementation = Implementation.DUCKDB

    def __init__(
        self,
        df: duckdb.DuckDBPyRelation,
        *,
        version: Version,
        validate_backend_version: bool = False,
    ) -> None:
        self._native_frame: duckdb.DuckDBPyRelation = df
        self._version = version
        # Lazily populated caches; see `schema` and `columns` below.
        self._cached_native_schema: dict[str, DuckDBPyType] | None = None
        self._cached_columns: list[str] | None = None
        if validate_backend_version:
            self._validate_backend_version()

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Installed DuckDB version as a tuple, e.g. ``(1, 3, 0)``."""
        return self._implementation._backend_version()

    @staticmethod
    def _is_native(
        obj: duckdb.DuckDBPyRelation | Any,
    ) -> TypeIs[duckdb.DuckDBPyRelation]:
        """Return True if `obj` is a native DuckDB relation."""
        return isinstance(obj, duckdb.DuckDBPyRelation)

    @classmethod
    def from_native(
        cls, data: duckdb.DuckDBPyRelation, /, *, context: _LimitedContext
    ) -> Self:
        """Wrap a native relation, inheriting the narwhals version from `context`."""
        return cls(data, version=context._version)

    def to_narwhals(
        self, *args: Any, **kwds: Any
    ) -> LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]:
        """Wrap self in the public narwhals API.

        Under stable v1, DuckDB was exposed as an interchange-level DataFrame;
        otherwise it is a lazy frame.
        """
        if self._version is Version.V1:
            from narwhals.stable.v1 import DataFrame as DataFrameV1

            return DataFrameV1(self, level="interchange")  # type: ignore[no-any-return]
        return self._version.lazyframe(self, level="lazy")

    def __narwhals_dataframe__(self) -> Self:  # pragma: no cover
        # Keep around for backcompat: only meaningful for stable.v1 interchange.
        if self._version is not Version.V1:
            msg = "__narwhals_dataframe__ is not implemented for DuckDBLazyFrame"
            raise AttributeError(msg)
        return self

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def __native_namespace__(self) -> ModuleType:
        """Return the native `duckdb` module."""
        return get_duckdb()  # type: ignore[no-any-return]

    def __narwhals_namespace__(self) -> DuckDBNamespace:
        from narwhals._duckdb.namespace import DuckDBNamespace

        return DuckDBNamespace(version=self._version)

    def get_column(self, name: str) -> DuckDBInterchangeSeries:
        """Return a single column as an interchange-level series."""
        from narwhals._duckdb.series import DuckDBInterchangeSeries

        return DuckDBInterchangeSeries(self.native.select(name), version=self._version)

    def _iter_columns(self) -> Iterator[Expression]:
        """Yield one column expression per column, in schema order."""
        for name in self.columns:
            yield col(name)

    def collect(
        self, backend: _EagerAllowedImpl | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        """Materialize the relation into an eager frame for `backend`.

        Defaults to PyArrow when `backend` is None; pandas and Polars are also
        supported. Raises ValueError for any other backend.
        """
        if backend is None or backend is Implementation.PYARROW:
            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                self.native.arrow(),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                self.native.df(),
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                self.native.pl(), validate_backend_version=True, version=self._version
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    def head(self, n: int) -> Self:
        """Return the first `n` rows."""
        return self._with_native(self.native.limit(n))

    def simple_select(self, *column_names: str) -> Self:
        """Select columns by name only (no expressions)."""
        return self._with_native(self.native.select(*column_names))

    def aggregate(self, *exprs: DuckDBExpr) -> Self:
        """Aggregate the whole frame with the given expressions."""
        selection = [val.alias(name) for name, val in evaluate_exprs(self, *exprs)]
        try:
            return self._with_native(self.native.aggregate(selection))  # type: ignore[arg-type]
        except Exception as e:  # noqa: BLE001
            # Re-raise DuckDB errors as narwhals-friendly exceptions.
            raise catch_duckdb_exception(e, self) from None

    def select(self, *exprs: DuckDBExpr) -> Self:
        """Project the given expressions."""
        selection = (val.alias(name) for name, val in evaluate_exprs(self, *exprs))
        try:
            return self._with_native(self.native.select(*selection))
        except Exception as e:  # noqa: BLE001
            raise catch_duckdb_exception(e, self) from None

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop `columns`; with `strict=True`, unknown names raise."""
        columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
        selection = (name for name in self.columns if name not in columns_to_drop)
        return self._with_native(self.native.select(*selection))

    def lazy(self, backend: None = None, **_: None) -> Self:
        # The `backend` argument has no effect but we keep it here for
        # backwards compatibility because in `narwhals.stable.v1`
        # function `.from_native()` will return a DataFrame for DuckDB.
        if backend is not None:  # pragma: no cover
            msg = "`backend` argument is not supported for DuckDB"
            raise ValueError(msg)
        return self

    def with_columns(self, *exprs: DuckDBExpr) -> Self:
        """Add or replace columns, keeping existing column order."""
        new_columns_map = dict(evaluate_exprs(self, *exprs))
        # Existing columns first (replaced in place where names collide) ...
        result = [
            new_columns_map.pop(name).alias(name)
            if name in new_columns_map
            else col(name)
            for name in self.columns
        ]
        # ... then any genuinely new columns, appended at the end.
        result.extend(value.alias(name) for name, value in new_columns_map.items())
        try:
            return self._with_native(self.native.select(*result))
        except Exception as e:  # noqa: BLE001
            raise catch_duckdb_exception(e, self) from None

    def filter(self, predicate: DuckDBExpr) -> Self:
        """Keep rows where `predicate` evaluates to true."""
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = predicate(self)[0]
        try:
            return self._with_native(self.native.filter(mask))
        except Exception as e:  # noqa: BLE001
            raise catch_duckdb_exception(e, self) from None

    @property
    def schema(self) -> dict[str, DType]:
        """Mapping of column name to narwhals dtype (native schema is cached)."""
        if self._cached_native_schema is None:
            # Note: prefer `self._cached_native_schema` over `functools.cached_property`
            # due to Python3.13 failures.
            self._cached_native_schema = dict(zip(self.columns, self.native.types))

        deferred_time_zone = DeferredTimeZone(self.native)
        return {
            column_name: native_to_narwhals_dtype(
                duckdb_dtype, self._version, deferred_time_zone
            )
            for column_name, duckdb_dtype in zip_strict(
                self.native.columns, self.native.types
            )
        }

    @property
    def columns(self) -> list[str]:
        """Column names, cached; reuses the cached schema when available."""
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_native_schema is not None
                else self.native.columns
            )
        return self._cached_columns

    def to_pandas(self) -> pd.DataFrame:
        # only if version is v1, keep around for backcompat
        return self.native.df()

    def to_arrow(self) -> pa.Table:
        # only if version is v1, keep around for backcompat
        return self.native.arrow()

    def _with_version(self, version: Version) -> Self:
        """Copy of self under a different narwhals version."""
        return self.__class__(self.native, version=version)

    def _with_native(self, df: duckdb.DuckDBPyRelation) -> Self:
        """Wrap a new native relation, keeping the current version."""
        return self.__class__(df, version=self._version)

    def group_by(
        self, keys: Sequence[str] | Sequence[DuckDBExpr], *, drop_null_keys: bool
    ) -> DuckDBGroupBy:
        """Start a group-by over `keys`."""
        from narwhals._duckdb.group_by import DuckDBGroupBy

        return DuckDBGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def rename(self, mapping: Mapping[str, str]) -> Self:
        """Rename columns according to `mapping` (old name -> new name)."""
        df = self.native
        selection = (
            col(name).alias(mapping[name]) if name in mapping else col(name)
            for name in df.columns
        )
        return self._with_native(self.native.select(*selection))

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join with `other`; right-hand duplicate names get `suffix` appended."""
        # narwhals "full" maps to SQL "outer".
        native_how = "outer" if how == "full" else how

        if native_how == "cross":
            if self._backend_version < (1, 1, 4):
                msg = f"'duckdb>=1.1.4' is required for cross-join, found version: {self._backend_version}"
                raise NotImplementedError(msg)
            rel = self.native.set_alias("lhs").cross(other.native.set_alias("rhs"))
        else:
            # help mypy
            assert left_on is not None  # noqa: S101
            assert right_on is not None  # noqa: S101
            it = (
                col(f'lhs."{left}"') == col(f'rhs."{right}"')
                for left, right in zip_strict(left_on, right_on)
            )
            condition: Expression = reduce(and_, it)
            rel = self.native.set_alias("lhs").join(
                other.native.set_alias("rhs"),
                # NOTE: Fixed in `--pre` https://github.com/duckdb/duckdb/pull/16933
                condition=condition,  # type: ignore[arg-type, unused-ignore]
                how=native_how,
            )

        if native_how in {"inner", "left", "cross", "outer"}:
            # Project lhs columns first, then rhs columns that survive the join,
            # suffixing any that collide with lhs names.
            select = [col(f'lhs."{x}"') for x in self.columns]
            for name in other.columns:
                col_in_lhs: bool = name in self.columns
                if native_how == "outer" and not col_in_lhs:
                    select.append(col(f'rhs."{name}"'))
                elif (native_how == "outer") or (
                    col_in_lhs and (right_on is None or name not in right_on)
                ):
                    select.append(col(f'rhs."{name}"').alias(f"{name}{suffix}"))
                elif right_on is None or name not in right_on:
                    select.append(col(name))
            res = rel.select(*select).set_alias(self.native.alias)
        else:  # semi, anti
            res = rel.select("lhs.*").set_alias(self.native.alias)

        return self._with_native(res)

    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self:
        """As-of (nearest-key) left join; only 'backward'/'forward' strategies."""
        lhs = self.native
        rhs = other.native
        conditions: list[Expression] = []
        if by_left is not None and by_right is not None:
            conditions.extend(
                col(f'lhs."{left}"') == col(f'rhs."{right}"')
                for left, right in zip_strict(by_left, by_right)
            )
        else:
            by_left = by_right = []
        if strategy == "backward":
            conditions.append(col(f'lhs."{left_on}"') >= col(f'rhs."{right_on}"'))
        elif strategy == "forward":
            conditions.append(col(f'lhs."{left_on}"') <= col(f'rhs."{right_on}"'))
        else:
            msg = "Only 'backward' and 'forward' strategies are currently supported for DuckDB"
            raise NotImplementedError(msg)
        condition: Expression = reduce(and_, conditions)
        select = ["lhs.*"]
        for name in rhs.columns:
            if name in lhs.columns and (
                right_on is None or name not in {right_on, *by_right}
            ):
                select.append(f'rhs."{name}" as "{name}{suffix}"')
            elif right_on is None or name not in {right_on, *by_right}:
                select.append(str(col(name)))
        # Replace with Python API call once
        # https://github.com/duckdb/duckdb/discussions/16947 is addressed.
        query = f"""
            SELECT {",".join(select)}
            FROM lhs
            ASOF LEFT JOIN rhs
            ON {condition}
            """  # noqa: S608
        return self._with_native(duckdb.sql(query))

    def collect_schema(self) -> dict[str, DType]:
        """Alias of `schema` (DuckDB schemas are cheap to resolve)."""
        return self.schema

    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        """Drop duplicate rows.

        `keep="any"` without a subset falls through to native DISTINCT;
        otherwise window functions pick the first row (or, for `keep="none"`,
        only rows whose group count is 1).
        """
        if subset_ := subset if keep == "any" else (subset or self.columns):
            # Sanitise input
            if error := self._check_columns_exist(subset_):
                raise error
            idx_name = generate_temporary_column_name(8, self.columns)
            count_name = generate_temporary_column_name(8, [*self.columns, idx_name])
            name = count_name if keep == "none" else idx_name
            idx_expr = window_expression(F("row_number"), subset_).alias(idx_name)
            count_expr = window_expression(
                F("count", StarExpression()), subset_, ()
            ).alias(count_name)
            return self._with_native(
                self.native.select(StarExpression(), idx_expr, count_expr)
                .filter(col(name) == lit(1))
                .select(StarExpression(exclude=[count_name, idx_name]))
            )
        return self._with_native(self.native.unique(join_column_names(*self.columns)))

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        """Sort by the given columns with per-column direction and null placement."""
        if isinstance(descending, bool):
            descending = [descending] * len(by)
        if nulls_last:
            it = (
                col(name).nulls_last() if not desc else col(name).desc().nulls_last()
                for name, desc in zip_strict(by, descending)
            )
        else:
            it = (
                col(name).nulls_first() if not desc else col(name).desc().nulls_first()
                for name, desc in zip_strict(by, descending)
            )
        return self._with_native(self.native.sort(*it))

    def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
        """Keep the top `k` rows ranked by `by` (via QUALIFY row_number)."""
        _df = self.native
        by = list(by)
        if isinstance(reverse, bool):
            descending = [not reverse] * len(by)
        else:
            descending = [not rev for rev in reverse]
        expr = window_expression(
            F("row_number"),
            order_by=by,
            descending=descending,
            nulls_last=[True] * len(by),
        )
        condition = expr <= lit(k)
        query = f"""
            SELECT *
            FROM _df
            QUALIFY {condition}
            """  # noqa: S608
        return self._with_native(duckdb.sql(query))

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        """Drop rows with nulls in `subset` (all columns when None)."""
        subset_ = subset if subset is not None else self.columns
        keep_condition = reduce(and_, (col(name).isnotnull() for name in subset_))
        return self._with_native(self.native.filter(keep_condition))

    def explode(self, columns: Sequence[str]) -> Self:
        """Explode a single List column into one row per element.

        Null or empty lists produce a single row with a null value in the
        exploded column (matching Polars semantics).
        """
        dtypes = self._version.dtypes
        schema = self.collect_schema()
        for name in columns:
            dtype = schema[name]
            if dtype != dtypes.List:
                msg = (
                    f"`explode` operation not supported for dtype `{dtype}`, "
                    "expected List type"
                )
                raise InvalidOperationError(msg)

        if len(columns) != 1:
            msg = (
                "Exploding on multiple columns is not supported with DuckDB backend since "
                "we cannot guarantee that the exploded columns have matching element counts."
            )
            raise NotImplementedError(msg)

        col_to_explode = col(columns[0])
        rel = self.native
        original_columns = self.columns

        # Parenthesize the comparison: `&` binds tighter than `>` in Python,
        # so without parentheses this would evaluate as
        # `(isnotnull AND len) > 0` rather than the intended
        # `isnotnull AND (len > 0)`.
        not_null_condition = col_to_explode.isnotnull() & (
            F("len", col_to_explode) > lit(0)
        )
        non_null_rel = rel.filter(not_null_condition).select(
            *(
                F("unnest", col_to_explode).alias(name) if name in columns else name
                for name in original_columns
            )
        )

        null_rel = rel.filter(~not_null_condition).select(
            *(
                lit(None).alias(name) if name in columns else name
                for name in original_columns
            )
        )

        return self._with_native(non_null_rel.union(null_rel))

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        """Melt `on` columns into (variable, value) long format, keeping `index`."""
        index_ = [] if index is None else index
        on_ = [c for c in self.columns if c not in index_] if on is None else on

        if variable_name == "":
            msg = "`variable_name` cannot be empty string for duckdb backend."
            raise NotImplementedError(msg)

        if value_name == "":
            msg = "`value_name` cannot be empty string for duckdb backend."
            raise NotImplementedError(msg)

        unpivot_on = join_column_names(*on_)
        rel = self.native  # noqa: F841
        # Replace with Python API once
        # https://github.com/duckdb/duckdb/discussions/16980 is addressed.
        query = f"""
            unpivot rel
            on {unpivot_on}
            into
                name {col(variable_name)}
                value {col(value_name)}
            """
        return self._with_native(
            duckdb.sql(query).select(*[*index_, variable_name, value_name])
        )

    @requires.backend_version((1, 3))
    def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
        """Add a 0-based row-index column named `name`, ordered by `order_by`."""
        if order_by is None:
            msg = "Cannot pass `order_by` to `with_row_index` for DuckDB"
            raise TypeError(msg)
        expr = (window_expression(F("row_number"), order_by=order_by) - lit(1)).alias(
            name
        )
        return self._with_native(self.native.select(expr, StarExpression()))

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        """Write the relation to `file` as Parquet (executes the query)."""
        df = self.native  # noqa: F841
        query = f"""
            COPY (SELECT * FROM df)
            TO '{file}'
            (FORMAT parquet)
            """  # noqa: S608
        duckdb.sql(query)

    gather_every = not_implemented.deprecated(
        "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
    )
    tail = not_implemented.deprecated(
        "`LazyFrame.tail` is deprecated and will be removed in a future version."
    )
|
303
lib/python3.11/site-packages/narwhals/_duckdb/expr.py
Normal file
303
lib/python3.11/site-packages/narwhals/_duckdb/expr.py
Normal file
@ -0,0 +1,303 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, cast
|
||||
|
||||
from duckdb import CoalesceOperator, StarExpression
|
||||
|
||||
from narwhals._duckdb.expr_dt import DuckDBExprDateTimeNamespace
|
||||
from narwhals._duckdb.expr_list import DuckDBExprListNamespace
|
||||
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
|
||||
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
|
||||
from narwhals._duckdb.utils import (
|
||||
DeferredTimeZone,
|
||||
F,
|
||||
col,
|
||||
lit,
|
||||
narwhals_to_native_dtype,
|
||||
when,
|
||||
window_expression,
|
||||
)
|
||||
from narwhals._expression_parsing import ExprKind, ExprMetadata
|
||||
from narwhals._sql.expr import SQLExpr
|
||||
from narwhals._utils import Implementation, Version
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from duckdb import Expression
|
||||
from typing_extensions import Self
|
||||
|
||||
from narwhals._compliant import WindowInputs
|
||||
from narwhals._compliant.typing import (
|
||||
AliasNames,
|
||||
EvalNames,
|
||||
EvalSeries,
|
||||
WindowFunction,
|
||||
)
|
||||
from narwhals._duckdb.dataframe import DuckDBLazyFrame
|
||||
from narwhals._duckdb.namespace import DuckDBNamespace
|
||||
from narwhals._utils import _LimitedContext
|
||||
from narwhals.typing import (
|
||||
FillNullStrategy,
|
||||
IntoDType,
|
||||
NonNestedLiteral,
|
||||
RollingInterpolationMethod,
|
||||
)
|
||||
|
||||
DuckDBWindowFunction = WindowFunction[DuckDBLazyFrame, Expression]
|
||||
DuckDBWindowInputs = WindowInputs[Expression]
|
||||
|
||||
|
||||
class DuckDBExpr(SQLExpr["DuckDBLazyFrame", "Expression"]):
    """Narwhals-compliant expression for the DuckDB backend.

    Wraps a callable that, given a `DuckDBLazyFrame`, produces a list of
    native `duckdb.Expression`s, plus an optional window-function variant
    used when the expression is evaluated `over` a partition/ordering.
    """

    _implementation = Implementation.DUCKDB

    def __init__(
        self,
        call: EvalSeries[DuckDBLazyFrame, Expression],
        window_function: DuckDBWindowFunction | None = None,
        *,
        evaluate_output_names: EvalNames[DuckDBLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        implementation: Implementation = Implementation.DUCKDB,
    ) -> None:
        # `call` evaluates the expression against a frame.
        self._call = call
        # Functions resolving the expression's output column names/aliases.
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._metadata: ExprMetadata | None = None
        # Optional window-aware variant of `call`; None when not windowable.
        self._window_function: DuckDBWindowFunction | None = window_function

    def _count_star(self) -> Expression:
        """Return the native `count(*)` expression."""
        return F("count", StarExpression())

    def _window_expression(
        self,
        expr: Expression,
        partition_by: Sequence[str | Expression] = (),
        order_by: Sequence[str | Expression] = (),
        rows_start: int | None = None,
        rows_end: int | None = None,
        *,
        descending: Sequence[bool] | None = None,
        nulls_last: Sequence[bool] | None = None,
    ) -> Expression:
        """Wrap `expr` in an OVER(...) clause via the shared helper."""
        return window_expression(
            expr,
            partition_by,
            order_by,
            rows_start,
            rows_end,
            descending=descending,
            nulls_last=nulls_last,
        )

    def __narwhals_namespace__(self) -> DuckDBNamespace:  # pragma: no cover
        from narwhals._duckdb.namespace import DuckDBNamespace

        return DuckDBNamespace(version=self._version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        """Broadcast an aggregation/literal to the frame's length.

        Literals need no change; aggregations are rewritten as a window over a
        constant partition, which requires duckdb>=1.3.
        """
        if kind is ExprKind.LITERAL:
            return self
        if self._backend_version < (1, 3):
            msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns."
            raise NotImplementedError(msg)
        return self.over([lit(1)], [])

    @classmethod
    def from_column_names(
        cls,
        evaluate_column_names: EvalNames[DuckDBLazyFrame],
        /,
        *,
        context: _LimitedContext,
    ) -> Self:
        """Build an expression selecting columns by (lazily evaluated) names."""

        def func(df: DuckDBLazyFrame) -> list[Expression]:
            return [col(name) for name in evaluate_column_names(df)]

        return cls(
            func,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        """Build an expression selecting columns by positional index."""

        def func(df: DuckDBLazyFrame) -> list[Expression]:
            columns = df.columns
            return [col(columns[i]) for i in column_indices]

        return cls(
            func,
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def _alias_native(cls, expr: Expression, name: str) -> Expression:
        """Alias a native expression."""
        return expr.alias(name)

    def __invert__(self) -> Self:
        # Cast needed as duckdb's `~` overload types don't line up with mypy.
        invert = cast("Callable[..., Expression]", operator.invert)
        return self._with_elementwise(invert)

    def skew(self) -> Self:
        """Sample skewness, with small-count edge cases matching Polars:
        null for 0 rows, NaN for 1 row, 0.0 for 2 rows.
        """

        def func(expr: Expression) -> Expression:
            count = F("count", expr)
            # Adjust population skewness by correction factor to get sample skewness
            sample_skewness = (
                F("skewness", expr)
                * (count - lit(2))
                / F("sqrt", count * (count - lit(1)))
            )
            return when(count == lit(0), lit(None)).otherwise(
                when(count == lit(1), lit(float("nan"))).otherwise(
                    when(count == lit(2), lit(0.0)).otherwise(sample_skewness)
                )
            )

        return self._with_callable(func)

    def kurtosis(self) -> Self:
        """Population kurtosis via DuckDB's `kurtosis_pop`."""
        return self._with_callable(lambda expr: F("kurtosis_pop", expr))

    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> Self:
        """Quantile; only linear interpolation is supported on DuckDB."""

        def func(expr: Expression) -> Expression:
            if interpolation == "linear":
                return F("quantile_cont", expr, lit(quantile))
            msg = "Only linear interpolation methods are supported for DuckDB quantile."
            raise NotImplementedError(msg)

        return self._with_callable(func)

    def n_unique(self) -> Self:
        """Count distinct values, counting nulls as one extra distinct value."""

        def func(expr: Expression) -> Expression:
            # https://stackoverflow.com/a/79338887/4451315
            return F("array_unique", F("array_agg", expr)) + F(
                "max", when(expr.isnotnull(), lit(0)).otherwise(lit(1))
            )

        return self._with_callable(func)

    def len(self) -> Self:
        """Row count (`count(*)` — includes nulls)."""
        return self._with_callable(lambda _expr: F("count"))

    def std(self, ddof: int) -> Self:
        """Standard deviation with `ddof` degrees-of-freedom correction."""
        if ddof == 0:
            return self._with_callable(lambda expr: F("stddev_pop", expr))
        if ddof == 1:
            return self._with_callable(lambda expr: F("stddev_samp", expr))

        def _std(expr: Expression) -> Expression:
            # Generic ddof: rescale the population stddev by sqrt(n / (n - ddof)).
            n_samples = F("count", expr)
            return (
                F("stddev_pop", expr)
                * F("sqrt", n_samples)
                / (F("sqrt", (n_samples - lit(ddof))))
            )

        return self._with_callable(_std)

    def var(self, ddof: int) -> Self:
        """Variance with `ddof` degrees-of-freedom correction."""
        if ddof == 0:
            return self._with_callable(lambda expr: F("var_pop", expr))
        if ddof == 1:
            return self._with_callable(lambda expr: F("var_samp", expr))

        def _var(expr: Expression) -> Expression:
            # Generic ddof: rescale the population variance by n / (n - ddof).
            n_samples = F("count", expr)
            return F("var_pop", expr) * n_samples / (n_samples - lit(ddof))

        return self._with_callable(_var)

    def null_count(self) -> Self:
        """Number of null values (sum of the is-null indicator)."""
        return self._with_callable(lambda expr: F("sum", expr.isnull().cast("int")))

    def is_nan(self) -> Self:
        return self._with_elementwise(lambda expr: F("isnan", expr))

    def is_finite(self) -> Self:
        return self._with_elementwise(lambda expr: F("isfinite", expr))

    def is_in(self, other: Sequence[Any]) -> Self:
        # `contains(list, value)` — the candidate list is passed as a literal.
        return self._with_elementwise(lambda expr: F("contains", lit(other), expr))

    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        """Fill nulls with a constant `value`, or by `strategy`
        ('forward'/'backward') within an optional `limit`-row window.
        """
        if strategy is not None:
            if self._backend_version < (1, 3):  # pragma: no cover
                msg = f"`fill_null` with `strategy={strategy}` is only available in 'duckdb>=1.3.0'."
                raise NotImplementedError(msg)

            def _fill_with_strategy(
                df: DuckDBLazyFrame, inputs: DuckDBWindowInputs
            ) -> Sequence[Expression]:
                # forward fill looks back (last non-null before the row);
                # backward fill looks ahead (first non-null after the row).
                fill_func = "last_value" if strategy == "forward" else "first_value"
                rows_start, rows_end = (
                    (-limit if limit is not None else None, 0)
                    if strategy == "forward"
                    else (0, limit)
                )
                return [
                    window_expression(
                        F(fill_func, expr),
                        inputs.partition_by,
                        inputs.order_by,
                        rows_start=rows_start,
                        rows_end=rows_end,
                        ignore_nulls=True,
                    )
                    for expr in self(df)
                ]

            return self._with_window_function(_fill_with_strategy)

        def _fill_constant(expr: Expression, value: Any) -> Expression:
            return CoalesceOperator(expr, value)

        return self._with_elementwise(_fill_constant, value=value)

    def cast(self, dtype: IntoDType) -> Self:
        """Cast to the DuckDB equivalent of the narwhals `dtype`.

        Both the plain and the window-function evaluation paths are provided,
        so casting composes with `over(...)`.
        """

        def func(df: DuckDBLazyFrame) -> list[Expression]:
            # Time zone resolution is deferred until a frame is available.
            tz = DeferredTimeZone(df.native)
            native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
            return [expr.cast(native_dtype) for expr in self(df)]

        def window_f(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
            tz = DeferredTimeZone(df.native)
            native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
            return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)]

        return self.__class__(
            func,
            window_f,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    @property
    def str(self) -> DuckDBExprStringNamespace:
        """String-method namespace (`expr.str.*`)."""
        return DuckDBExprStringNamespace(self)

    @property
    def dt(self) -> DuckDBExprDateTimeNamespace:
        """Datetime-method namespace (`expr.dt.*`)."""
        return DuckDBExprDateTimeNamespace(self)

    @property
    def list(self) -> DuckDBExprListNamespace:
        """List-method namespace (`expr.list.*`)."""
        return DuckDBExprListNamespace(self)

    @property
    def struct(self) -> DuckDBExprStructNamespace:
        """Struct-method namespace (`expr.struct.*`)."""
        return DuckDBExprStructNamespace(self)
|
132
lib/python3.11/site-packages/narwhals/_duckdb/expr_dt.py
Normal file
132
lib/python3.11/site-packages/narwhals/_duckdb/expr_dt.py
Normal file
@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._constants import (
|
||||
MS_PER_MINUTE,
|
||||
MS_PER_SECOND,
|
||||
NS_PER_SECOND,
|
||||
SECONDS_PER_MINUTE,
|
||||
US_PER_MINUTE,
|
||||
US_PER_SECOND,
|
||||
)
|
||||
from narwhals._duckdb.utils import UNITS_DICT, F, fetch_rel_time_zone, lit
|
||||
from narwhals._duration import Interval
|
||||
from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace
|
||||
from narwhals._utils import not_implemented
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from duckdb import Expression
|
||||
|
||||
from narwhals._duckdb.dataframe import DuckDBLazyFrame
|
||||
from narwhals._duckdb.expr import DuckDBExpr
|
||||
|
||||
|
||||
class DuckDBExprDateTimeNamespace(SQLExprDateTimeNamesSpace["DuckDBExpr"]):
    """Datetime accessor (`expr.dt.*`) for the DuckDB backend."""

    def millisecond(self) -> DuckDBExpr:
        # DuckDB's `millisecond` includes the whole-seconds component;
        # subtract it so the result is the 0-999 sub-second part.
        return self.compliant._with_elementwise(
            lambda expr: F("millisecond", expr) - F("second", expr) * lit(MS_PER_SECOND)
        )

    def microsecond(self) -> DuckDBExpr:
        # As with `millisecond`: strip whole seconds from `microsecond`.
        return self.compliant._with_elementwise(
            lambda expr: F("microsecond", expr) - F("second", expr) * lit(US_PER_SECOND)
        )

    def nanosecond(self) -> DuckDBExpr:
        # As with `millisecond`: strip whole seconds from `nanosecond`.
        return self.compliant._with_elementwise(
            lambda expr: F("nanosecond", expr) - F("second", expr) * lit(NS_PER_SECOND)
        )

    def to_string(self, format: str) -> DuckDBExpr:
        """Format datetimes as strings via DuckDB's `strftime`."""
        return self.compliant._with_elementwise(
            lambda expr: F("strftime", expr, lit(format))
        )

    def weekday(self) -> DuckDBExpr:
        # `isodow`: ISO day-of-week numbering.
        return self.compliant._with_elementwise(lambda expr: F("isodow", expr))

    def date(self) -> DuckDBExpr:
        """Drop the time component by casting to DuckDB's `date` type."""
        return self.compliant._with_elementwise(lambda expr: expr.cast("date"))

    def total_minutes(self) -> DuckDBExpr:
        # Minute part of the interval via `datepart`.
        return self.compliant._with_elementwise(
            lambda expr: F("datepart", lit("minute"), expr)
        )

    def total_seconds(self) -> DuckDBExpr:
        # Fold minutes into seconds; the `second` datepart is the sub-minute remainder.
        return self.compliant._with_elementwise(
            lambda expr: lit(SECONDS_PER_MINUTE) * F("datepart", lit("minute"), expr)
            + F("datepart", lit("second"), expr)
        )

    def total_milliseconds(self) -> DuckDBExpr:
        # The `millisecond` datepart already includes whole seconds,
        # so only minutes need to be folded in.
        return self.compliant._with_elementwise(
            lambda expr: lit(MS_PER_MINUTE) * F("datepart", lit("minute"), expr)
            + F("datepart", lit("millisecond"), expr)
        )

    def total_microseconds(self) -> DuckDBExpr:
        # Same shape as `total_milliseconds`, at microsecond resolution.
        return self.compliant._with_elementwise(
            lambda expr: lit(US_PER_MINUTE) * F("datepart", lit("minute"), expr)
            + F("datepart", lit("microsecond"), expr)
        )

    def truncate(self, every: str) -> DuckDBExpr:
        """Truncate datetimes down to the window given by `every` (e.g. `"1mo"`).

        Raises:
            ValueError: If the interval multiple is not 1.
            NotImplementedError: For nanosecond truncation.
        """
        interval = Interval.parse(every)
        multiple, unit = interval.multiple, interval.unit
        if multiple != 1:
            # https://github.com/duckdb/duckdb/issues/17554
            msg = f"Only multiple 1 is currently supported for DuckDB.\nGot {multiple!s}."
            raise ValueError(msg)
        if unit == "ns":
            msg = "Truncating to nanoseconds is not yet supported for DuckDB."
            raise NotImplementedError(msg)
        format = lit(UNITS_DICT[unit])

        def _truncate(expr: Expression) -> Expression:
            return F("date_trunc", format, expr)

        return self.compliant._with_elementwise(_truncate)

    def offset_by(self, by: str) -> DuckDBExpr:
        """Shift datetimes by the interval string `by` via `date_add`."""
        interval = Interval.parse_no_constraints(by)
        format = lit(f"{interval.multiple!s} {UNITS_DICT[interval.unit]}")

        def _offset_by(expr: Expression) -> Expression:
            return F("date_add", format, expr)

        return self.compliant._with_callable(_offset_by)

    def _no_op_time_zone(self, time_zone: str) -> DuckDBExpr:
        # DuckDB keeps the time zone on the connection, not the dtype, so the
        # only supported "conversion" target is the connection's own zone;
        # anything else raises at evaluation time.
        def func(df: DuckDBLazyFrame) -> Sequence[Expression]:
            native_series_list = self.compliant(df)
            conn_time_zone = fetch_rel_time_zone(df.native)
            if conn_time_zone != time_zone:
                msg = (
                    "DuckDB stores the time zone in the connection, rather than in the "
                    f"data type, so changing the timezone to anything other than {conn_time_zone} "
                    " (the current connection time zone) is not supported."
                )
                raise NotImplementedError(msg)
            return native_series_list

        return self.compliant.__class__(
            func,
            evaluate_output_names=self.compliant._evaluate_output_names,
            alias_output_names=self.compliant._alias_output_names,
            version=self.compliant._version,
        )

    def convert_time_zone(self, time_zone: str) -> DuckDBExpr:
        return self._no_op_time_zone(time_zone)

    def replace_time_zone(self, time_zone: str | None) -> DuckDBExpr:
        # `None` drops the zone by casting to a naive timestamp; any other
        # value goes through the connection-time-zone check above.
        if time_zone is None:
            return self.compliant._with_elementwise(lambda expr: expr.cast("timestamp"))
        return self._no_op_time_zone(time_zone)

    total_nanoseconds = not_implemented()
    timestamp = not_implemented()
|
40
lib/python3.11/site-packages/narwhals/_duckdb/expr_list.py
Normal file
40
lib/python3.11/site-packages/narwhals/_duckdb/expr_list.py
Normal file
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._compliant import LazyExprNamespace
|
||||
from narwhals._compliant.any_namespace import ListNamespace
|
||||
from narwhals._duckdb.utils import F, lit, when
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from duckdb import Expression
|
||||
|
||||
from narwhals._duckdb.expr import DuckDBExpr
|
||||
from narwhals.typing import NonNestedLiteral
|
||||
|
||||
|
||||
class DuckDBExprListNamespace(
    LazyExprNamespace["DuckDBExpr"], ListNamespace["DuckDBExpr"]
):
    """List accessor (`expr.list.*`) for the DuckDB backend."""

    def len(self) -> DuckDBExpr:
        def _length(expr: Expression) -> Expression:
            return F("len", expr)

        return self.compliant._with_elementwise(_length)

    def unique(self) -> DuckDBExpr:
        def _unique(expr: Expression) -> Expression:
            # `list_distinct` drops NULL elements, so re-append a single NULL
            # whenever the original list contained one.
            deduped = F("list_distinct", expr)
            had_null = F("array_position", expr, lit(None)).isnotnull()
            return when(had_null, F("list_append", deduped, lit(None))).otherwise(
                deduped
            )

        return self.compliant._with_callable(_unique)

    def contains(self, item: NonNestedLiteral) -> DuckDBExpr:
        def _contains(expr: Expression) -> Expression:
            return F("list_contains", expr, lit(item))

        return self.compliant._with_elementwise(_contains)

    def get(self, index: int) -> DuckDBExpr:
        # DuckDB lists are 1-indexed, hence the +1 shift.
        one_based = lit(index + 1)

        def _get(expr: Expression) -> Expression:
            return F("list_extract", expr, one_based)

        return self.compliant._with_elementwise(_get)
|
30
lib/python3.11/site-packages/narwhals/_duckdb/expr_str.py
Normal file
30
lib/python3.11/site-packages/narwhals/_duckdb/expr_str.py
Normal file
@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._duckdb.utils import F, lit
|
||||
from narwhals._sql.expr_str import SQLExprStringNamespace
|
||||
from narwhals._utils import not_implemented
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from narwhals._duckdb.expr import DuckDBExpr
|
||||
|
||||
|
||||
class DuckDBExprStringNamespace(SQLExprStringNamespace["DuckDBExpr"]):
    """String accessor (`expr.str.*`) for the DuckDB backend."""

    def to_datetime(self, format: str | None) -> DuckDBExpr:
        """Parse strings as datetimes; an explicit `format` is mandatory."""
        if format is None:
            msg = "Cannot infer format with DuckDB backend, please specify `format` explicitly."
            raise NotImplementedError(msg)
        fmt = lit(format)

        def _parse(expr: Expression) -> Expression:
            return F("strptime", expr, fmt)

        return self.compliant._with_elementwise(_parse)

    def to_date(self, format: str | None) -> DuckDBExpr:
        """Parse strings as dates, via `to_datetime` when a format is given."""
        if format is not None:
            return self.to_datetime(format=format).dt.date()
        expr = self.compliant
        return expr.cast(expr._version.dtypes.Date())

    replace = not_implemented()
|
19
lib/python3.11/site-packages/narwhals/_duckdb/expr_struct.py
Normal file
19
lib/python3.11/site-packages/narwhals/_duckdb/expr_struct.py
Normal file
@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._compliant import LazyExprNamespace
|
||||
from narwhals._compliant.any_namespace import StructNamespace
|
||||
from narwhals._duckdb.utils import F, lit
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from narwhals._duckdb.expr import DuckDBExpr
|
||||
|
||||
|
||||
class DuckDBExprStructNamespace(
    LazyExprNamespace["DuckDBExpr"], StructNamespace["DuckDBExpr"]
):
    """Struct accessor (`expr.struct.*`) for the DuckDB backend."""

    def field(self, name: str) -> DuckDBExpr:
        # Extract the struct field, then rename the output column after it.
        key = lit(name)
        extracted = self.compliant._with_elementwise(
            lambda expr: F("struct_extract", expr, key)
        )
        return extracted.alias(name)
|
33
lib/python3.11/site-packages/narwhals/_duckdb/group_by.py
Normal file
33
lib/python3.11/site-packages/narwhals/_duckdb/group_by.py
Normal file
@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._sql.group_by import SQLGroupBy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from duckdb import Expression # noqa: F401
|
||||
|
||||
from narwhals._duckdb.dataframe import DuckDBLazyFrame
|
||||
from narwhals._duckdb.expr import DuckDBExpr
|
||||
|
||||
|
||||
class DuckDBGroupBy(SQLGroupBy["DuckDBLazyFrame", "DuckDBExpr", "Expression"]):
    """Group-by implementation backed by DuckDB's relational `aggregate`."""

    def __init__(
        self,
        df: DuckDBLazyFrame,
        keys: Sequence[DuckDBExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        # `_parse_keys` (inherited) resolves expression keys to concrete column
        # names on `frame` and records the user-facing output names.
        frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
        self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame

    def agg(self, *exprs: DuckDBExpr) -> DuckDBLazyFrame:
        """Evaluate `exprs` per group, then restore the original key names."""
        agg_columns = list(chain(self._keys, self._evaluate_exprs(exprs)))
        return self.compliant._with_native(
            self.compliant.native.aggregate(agg_columns)  # type: ignore[arg-type]
        ).rename(dict(zip(self._keys, self._output_key_names)))
|
164
lib/python3.11/site-packages/narwhals/_duckdb/namespace.py
Normal file
164
lib/python3.11/site-packages/narwhals/_duckdb/namespace.py
Normal file
@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from functools import reduce
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from duckdb import CoalesceOperator, Expression
|
||||
from duckdb.typing import BIGINT, VARCHAR
|
||||
|
||||
from narwhals._duckdb.dataframe import DuckDBLazyFrame
|
||||
from narwhals._duckdb.expr import DuckDBExpr
|
||||
from narwhals._duckdb.selectors import DuckDBSelectorNamespace
|
||||
from narwhals._duckdb.utils import (
|
||||
DeferredTimeZone,
|
||||
F,
|
||||
concat_str,
|
||||
function,
|
||||
lit,
|
||||
narwhals_to_native_dtype,
|
||||
when,
|
||||
)
|
||||
from narwhals._expression_parsing import (
|
||||
combine_alias_output_names,
|
||||
combine_evaluate_output_names,
|
||||
)
|
||||
from narwhals._sql.namespace import SQLNamespace
|
||||
from narwhals._sql.when_then import SQLThen, SQLWhen
|
||||
from narwhals._utils import Implementation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from duckdb import DuckDBPyRelation # noqa: F401
|
||||
|
||||
from narwhals._utils import Version
|
||||
from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral
|
||||
|
||||
|
||||
class DuckDBNamespace(
    SQLNamespace[DuckDBLazyFrame, DuckDBExpr, "DuckDBPyRelation", Expression]
):
    """Expression constructors and horizontal operations for the DuckDB backend."""

    _implementation: Implementation = Implementation.DUCKDB

    def __init__(self, *, version: Version) -> None:
        self._version = version

    @property
    def selectors(self) -> DuckDBSelectorNamespace:
        # Entry point for `selectors.*`.
        return DuckDBSelectorNamespace.from_namespace(self)

    @property
    def _expr(self) -> type[DuckDBExpr]:
        return DuckDBExpr

    @property
    def _lazyframe(self) -> type[DuckDBLazyFrame]:
        return DuckDBLazyFrame

    def _function(self, name: str, *args: Expression) -> Expression:  # type: ignore[override]
        # `function` special-cases names (e.g. `isnull`) that are methods,
        # not SQL functions, in the duckdb Python API.
        return function(name, *args)

    def _lit(self, value: Any) -> Expression:
        return lit(value)

    def _when(
        self,
        condition: Expression,
        value: Expression,
        otherwise: Expression | None = None,
    ) -> Expression:
        if otherwise is None:
            return when(condition, value)
        return when(condition, value).otherwise(otherwise)

    def _coalesce(self, *exprs: Expression) -> Expression:
        return CoalesceOperator(*exprs)

    def concat(
        self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod
    ) -> DuckDBLazyFrame:
        """Concatenate frames vertically via repeated relational `union`.

        Raises:
            TypeError: If `how == "vertical"` and the schemas are not all equal.
        """
        # Bug fix: materialize `items` *before* any other traversal. The
        # previous implementation consumed the iterable in a comprehension and
        # then called `list(items)` on it, which is empty for one-shot
        # iterators and raised IndexError below.
        items = list(items)
        native_items = [item._native_frame for item in items]
        first = items[0]
        schema = first.schema
        if how == "vertical" and not all(x.schema == schema for x in items[1:]):
            msg = "inputs should all have the same schema"
            raise TypeError(msg)
        res = reduce(lambda x, y: x.union(y), native_items)
        return first._with_native(res)

    def concat_str(
        self, *exprs: DuckDBExpr, separator: str, ignore_nulls: bool
    ) -> DuckDBExpr:
        """Concatenate string columns; NULL-propagating unless `ignore_nulls`."""

        def func(df: DuckDBLazyFrame) -> list[Expression]:
            cols = list(chain.from_iterable(expr(df) for expr in exprs))
            if not ignore_nulls:
                # `concat`/`concat_ws` skip NULLs, so emulate NULL propagation:
                # mask any row where at least one input is NULL.
                null_mask_result = reduce(operator.or_, (s.isnull() for s in cols))
                # Interleave the separator literal between casted columns
                # (no trailing separator after the last one).
                cols_separated = [
                    y
                    for x in [
                        (col.cast(VARCHAR),)
                        if i == len(cols) - 1
                        else (col.cast(VARCHAR), lit(separator))
                        for i, col in enumerate(cols)
                    ]
                    for y in x
                ]
                return [when(~null_mask_result, concat_str(*cols_separated))]
            return [concat_str(*cols, separator=separator)]

        return self._expr(
            call=func,
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def mean_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
        """Row-wise mean over the non-NULL values of `exprs`."""

        def func(cols: Iterable[Expression]) -> Expression:
            cols = list(cols)
            # Sum with NULL treated as 0, divided by the count of non-NULLs.
            return reduce(
                operator.add, (CoalesceOperator(col, lit(0)) for col in cols)
            ) / reduce(operator.add, (col.isnotnull().cast(BIGINT) for col in cols))

        return self._expr._from_elementwise_horizontal_op(func, *exprs)

    def when(self, predicate: DuckDBExpr) -> DuckDBWhen:
        return DuckDBWhen.from_expr(predicate, context=self)

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DuckDBExpr:
        """A literal column named `"literal"`, optionally cast to `dtype`."""

        def func(df: DuckDBLazyFrame) -> list[Expression]:
            tz = DeferredTimeZone(df.native)
            if dtype is not None:
                target = narwhals_to_native_dtype(dtype, self._version, tz)
                return [lit(value).cast(target)]
            return [lit(value)]

        return self._expr(
            func,
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            version=self._version,
        )

    def len(self) -> DuckDBExpr:
        """Row count, as a column named `"len"`."""

        def func(_df: DuckDBLazyFrame) -> list[Expression]:
            return [F("count")]

        return self._expr(
            call=func,
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            version=self._version,
        )
|
||||
|
||||
|
||||
class DuckDBWhen(SQLWhen["DuckDBLazyFrame", Expression, DuckDBExpr]):
    """`when(...)` builder; `_then` wires in the DuckDB `Then` implementation."""

    @property
    def _then(self) -> type[DuckDBThen]:
        return DuckDBThen
|
||||
|
||||
|
||||
# Combines `SQLThen` chaining with full `DuckDBExpr` behaviour; no extra members needed.
class DuckDBThen(SQLThen["DuckDBLazyFrame", Expression, DuckDBExpr], DuckDBExpr): ...
|
33
lib/python3.11/site-packages/narwhals/_duckdb/selectors.py
Normal file
33
lib/python3.11/site-packages/narwhals/_duckdb/selectors.py
Normal file
@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._compliant import CompliantSelector, LazySelectorNamespace
|
||||
from narwhals._duckdb.expr import DuckDBExpr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from duckdb import Expression # noqa: F401
|
||||
|
||||
from narwhals._duckdb.dataframe import DuckDBLazyFrame # noqa: F401
|
||||
from narwhals._duckdb.expr import DuckDBWindowFunction
|
||||
|
||||
|
||||
class DuckDBSelectorNamespace(LazySelectorNamespace["DuckDBLazyFrame", "Expression"]):
    """Entry point for column selectors on the DuckDB backend."""

    @property
    def _selector(self) -> type[DuckDBSelector]:
        return DuckDBSelector
|
||||
|
||||
|
||||
class DuckDBSelector(  # type: ignore[misc]
    CompliantSelector["DuckDBLazyFrame", "Expression"], DuckDBExpr
):
    """A `DuckDBExpr` that additionally supports selector set-operations."""

    # Selectors carry no window function by default.
    _window_function: DuckDBWindowFunction | None = None

    def _to_expr(self) -> DuckDBExpr:
        # Strip the selector behaviour, keeping the call/window/naming machinery.
        return DuckDBExpr(
            self._call,
            self._window_function,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )
|
44
lib/python3.11/site-packages/narwhals/_duckdb/series.py
Normal file
44
lib/python3.11/site-packages/narwhals/_duckdb/series.py
Normal file
@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._duckdb.utils import DeferredTimeZone, native_to_narwhals_dtype
|
||||
from narwhals.dependencies import get_duckdb
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
|
||||
import duckdb
|
||||
from typing_extensions import Never, Self
|
||||
|
||||
from narwhals._utils import Version
|
||||
from narwhals.dtypes import DType
|
||||
|
||||
|
||||
class DuckDBInterchangeSeries:
    """Interchange-level wrapper exposing only the dtype of a single-column relation."""

    def __init__(self, df: duckdb.DuckDBPyRelation, version: Version) -> None:
        self._native_series = df
        self._version = version

    def __narwhals_series__(self) -> Self:
        return self

    def __native_namespace__(self) -> ModuleType:
        return get_duckdb()  # type: ignore[no-any-return]

    @property
    def dtype(self) -> DType:
        """Narwhals dtype of the relation's first column."""
        return native_to_narwhals_dtype(
            self._native_series.types[0],
            self._version,
            DeferredTimeZone(self._native_series),
        )

    def __getattr__(self, attr: str) -> Never:
        # Everything beyond the members above is unsupported at interchange level.
        msg = (  # pragma: no cover
            f"Attribute {attr} is not supported for interchange-level dataframes.\n\n"
            "If you would like to see this kind of object better supported in "
            "Narwhals, please open a feature request "
            "at https://github.com/narwhals-dev/narwhals/issues."
        )
        raise NotImplementedError(msg)  # pragma: no cover
|
18
lib/python3.11/site-packages/narwhals/_duckdb/typing.py
Normal file
18
lib/python3.11/site-packages/narwhals/_duckdb/typing.py
Normal file
@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from duckdb import Expression
|
||||
|
||||
|
||||
class WindowExpressionKwargs(TypedDict, total=False):
    """Keyword arguments accepted by `window_expression`; all keys optional."""

    partition_by: Sequence[str | Expression]  # `partition by` keys
    order_by: Sequence[str | Expression]  # `order by` keys
    rows_start: int | None  # frame start offset (negative = "N preceding")
    rows_end: int | None  # frame end offset ("N following")
    descending: Sequence[bool]  # per-`order_by` sort direction
    nulls_last: Sequence[bool]  # per-`order_by` NULL placement
    ignore_nulls: bool  # inject `ignore nulls` into the function call
|
370
lib/python3.11/site-packages/narwhals/_duckdb/utils.py
Normal file
370
lib/python3.11/site-packages/narwhals/_duckdb/utils.py
Normal file
@ -0,0 +1,370 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import duckdb
|
||||
import duckdb.typing as duckdb_dtypes
|
||||
from duckdb.typing import DuckDBPyType
|
||||
|
||||
from narwhals._utils import Version, isinstance_or_issubclass, zip_strict
|
||||
from narwhals.exceptions import ColumnNotFoundError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from duckdb import DuckDBPyRelation, Expression
|
||||
|
||||
from narwhals._compliant.typing import CompliantLazyFrameAny
|
||||
from narwhals._duckdb.dataframe import DuckDBLazyFrame
|
||||
from narwhals._duckdb.expr import DuckDBExpr
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.typing import IntoDType, TimeUnit
|
||||
|
||||
|
||||
# Narwhals duration/offset unit code -> DuckDB date_trunc/date_add unit name.
UNITS_DICT = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}
# `descending` flag -> SQL sort-direction keyword.
DESCENDING_TO_ORDER = {True: "desc", False: "asc"}
# `nulls_last` flag -> SQL NULL-placement keyword.
NULLS_LAST_TO_NULLS_POS = {True: "nulls last", False: "nulls first"}

col = duckdb.ColumnExpression
"""Alias for `duckdb.ColumnExpression`."""

lit = duckdb.ConstantExpression
"""Alias for `duckdb.ConstantExpression`."""

when = duckdb.CaseExpression
"""Alias for `duckdb.CaseExpression`."""

F = duckdb.FunctionExpression
"""Alias for `duckdb.FunctionExpression`."""
|
||||
|
||||
|
||||
def concat_str(*exprs: Expression, separator: str = "") -> Expression:
    """Concatenate many strings, NULL inputs are skipped.

    Wraps [concat] and [concat_ws] `FunctionExpression`(s).

    Arguments:
        exprs: Native columns.
        separator: String that will be used to separate the values of each column.

    Returns:
        A new native expression.

    [concat]: https://duckdb.org/docs/stable/sql/functions/char.html#concatstring-
    [concat_ws]: https://duckdb.org/docs/stable/sql/functions/char.html#concat_wsseparator-string-
    """
    if separator:
        return F("concat_ws", lit(separator), *exprs)
    return F("concat", *exprs)
|
||||
|
||||
|
||||
def evaluate_exprs(
    df: DuckDBLazyFrame, /, *exprs: DuckDBExpr
) -> list[tuple[str, Expression]]:
    """Evaluate each expression against `df`, pairing results with output names."""
    native_results: list[tuple[str, Expression]] = []
    for expr in exprs:
        native_series_list = expr._call(df)
        output_names = expr._evaluate_output_names(df)
        # Apply any `.alias(...)`-style renames on top of the evaluated names.
        if expr._alias_output_names is not None:
            output_names = expr._alias_output_names(output_names)
        if len(output_names) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
            raise AssertionError(msg)
        native_results.extend(zip(output_names, native_series_list))
    return native_results
|
||||
|
||||
|
||||
class DeferredTimeZone:
    """Object which gets passed between `native_to_narwhals_dtype` calls.

    DuckDB stores the time zone in the connection, rather than in the dtypes, so
    this ensures that when calculating the schema of a dataframe with multiple
    timezone-aware columns, that the connection's time zone is only fetched once.

    Note: we cannot make the time zone a cached `DuckDBLazyFrame` property because
    the time zone can be modified after `DuckDBLazyFrame` creation:

    ```python
    df = nw.from_native(rel)
    print(df.collect_schema())
    rel.query("set timezone = 'Asia/Kolkata'")
    print(df.collect_schema())  # should change to reflect new time zone
    ```
    """

    # Class-level default; each instance overwrites it on first access.
    _cached_time_zone: str | None = None

    def __init__(self, rel: DuckDBPyRelation) -> None:
        self._rel = rel

    @property
    def time_zone(self) -> str:
        """Fetch relation time zone (if it wasn't calculated already)."""
        if self._cached_time_zone is None:
            self._cached_time_zone = fetch_rel_time_zone(self._rel)
        return self._cached_time_zone
|
||||
|
||||
|
||||
def native_to_narwhals_dtype(
    duckdb_dtype: DuckDBPyType, version: Version, deferred_time_zone: DeferredTimeZone
) -> DType:
    """Convert a DuckDB type to the equivalent Narwhals dtype.

    `deferred_time_zone` supplies the connection time zone for
    `timestamp with time zone` columns without re-querying once per column.
    """
    duckdb_dtype_id = duckdb_dtype.id
    dtypes = version.dtypes

    # Handle nested data types first
    if duckdb_dtype_id == "list":
        return dtypes.List(
            native_to_narwhals_dtype(duckdb_dtype.child, version, deferred_time_zone)
        )

    if duckdb_dtype_id == "struct":
        children = duckdb_dtype.children
        return dtypes.Struct(
            [
                dtypes.Field(
                    name=child[0],
                    dtype=native_to_narwhals_dtype(child[1], version, deferred_time_zone),
                )
                for child in children
            ]
        )

    if duckdb_dtype_id == "array":
        # Walk nested fixed-size arrays, accumulating each level's size into
        # the final shape; children are (inner type, size) pairs.
        child, size = duckdb_dtype.children
        shape: list[int] = [size[1]]

        while child[1].id == "array":
            child, size = child[1].children
            shape.insert(0, size[1])

        inner = native_to_narwhals_dtype(child[1], version, deferred_time_zone)
        return dtypes.Array(inner=inner, shape=tuple(shape))

    if duckdb_dtype_id == "enum":
        if version is Version.V1:
            # stable.v1 `Enum` takes no categories argument.
            return dtypes.Enum()  # type: ignore[call-arg]
        categories = duckdb_dtype.children[0][1]
        return dtypes.Enum(categories=categories)

    if duckdb_dtype_id == "timestamp with time zone":
        # Zone lives on the connection; fetched at most once via the deferred object.
        return dtypes.Datetime(time_zone=deferred_time_zone.time_zone)

    return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version)
|
||||
|
||||
|
||||
def fetch_rel_time_zone(rel: duckdb.DuckDBPyRelation) -> str:
    """Return the `TimeZone` setting of the relation's connection."""
    query = "select value from duckdb_settings() where name = 'TimeZone'"
    row = rel.query("duckdb_settings()", query).fetchone()
    # The TimeZone setting always exists, so a row is always returned.
    assert row is not None  # noqa: S101
    return row[0]  # type: ignore[no-any-return]
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) -> DType:
    """Map a flat DuckDB type id to a Narwhals dtype; `Unknown` when unmapped.

    The lookup dict is rebuilt per call, but `lru_cache` makes repeated
    lookups for the same (id, version) pair free.
    """
    dtypes = version.dtypes
    return {
        "hugeint": dtypes.Int128(),
        "bigint": dtypes.Int64(),
        "integer": dtypes.Int32(),
        "smallint": dtypes.Int16(),
        "tinyint": dtypes.Int8(),
        "uhugeint": dtypes.UInt128(),
        "ubigint": dtypes.UInt64(),
        "uinteger": dtypes.UInt32(),
        "usmallint": dtypes.UInt16(),
        "utinyint": dtypes.UInt8(),
        "double": dtypes.Float64(),
        "float": dtypes.Float32(),
        "varchar": dtypes.String(),
        "date": dtypes.Date(),
        "timestamp_s": dtypes.Datetime("s"),
        "timestamp_ms": dtypes.Datetime("ms"),
        "timestamp": dtypes.Datetime(),
        "timestamp_ns": dtypes.Datetime("ns"),
        "boolean": dtypes.Boolean(),
        "interval": dtypes.Duration(),
        "decimal": dtypes.Decimal(),
        "time": dtypes.Time(),
        "blob": dtypes.Binary(),
    }.get(duckdb_dtype_id, dtypes.Unknown())
|
||||
|
||||
|
||||
dtypes = Version.MAIN.dtypes
# Narwhals dtype class -> DuckDB type, for flat (non-parametric) dtypes.
NW_TO_DUCKDB_DTYPES: Mapping[type[DType], DuckDBPyType] = {
    dtypes.Float64: duckdb_dtypes.DOUBLE,
    dtypes.Float32: duckdb_dtypes.FLOAT,
    dtypes.Binary: duckdb_dtypes.BLOB,
    dtypes.String: duckdb_dtypes.VARCHAR,
    dtypes.Boolean: duckdb_dtypes.BOOLEAN,
    dtypes.Date: duckdb_dtypes.DATE,
    dtypes.Time: duckdb_dtypes.TIME,
    dtypes.Int8: duckdb_dtypes.TINYINT,
    dtypes.Int16: duckdb_dtypes.SMALLINT,
    dtypes.Int32: duckdb_dtypes.INTEGER,
    dtypes.Int64: duckdb_dtypes.BIGINT,
    # 128-bit ints have no named constant in duckdb.typing; build from string.
    dtypes.Int128: DuckDBPyType("INT128"),
    dtypes.UInt8: duckdb_dtypes.UTINYINT,
    dtypes.UInt16: duckdb_dtypes.USMALLINT,
    dtypes.UInt32: duckdb_dtypes.UINTEGER,
    dtypes.UInt64: duckdb_dtypes.UBIGINT,
    dtypes.UInt128: DuckDBPyType("UINT128"),
}
# Narwhals time unit -> DuckDB timestamp type of that precision.
TIME_UNIT_TO_TIMESTAMP: Mapping[TimeUnit, DuckDBPyType] = {
    "s": duckdb_dtypes.TIMESTAMP_S,
    "ms": duckdb_dtypes.TIMESTAMP_MS,
    "us": duckdb_dtypes.TIMESTAMP,
    "ns": duckdb_dtypes.TIMESTAMP_NS,
}
# Dtypes with no DuckDB equivalent; conversion raises NotImplementedError.
UNSUPPORTED_DTYPES = (dtypes.Decimal, dtypes.Categorical)
|
||||
|
||||
|
||||
def narwhals_to_native_dtype(  # noqa: PLR0912, C901
    dtype: IntoDType, version: Version, deferred_time_zone: DeferredTimeZone
) -> DuckDBPyType:
    """Convert a Narwhals dtype to the equivalent DuckDB type.

    Arguments:
        dtype: Narwhals dtype (instance or class).
        version: Narwhals API version, used to resolve the dtype registry.
        deferred_time_zone: Lazily-fetched connection time zone, needed for
            timezone-aware `Datetime`.

    Returns:
        The corresponding `DuckDBPyType`.

    Raises:
        NotImplementedError: For dtypes DuckDB cannot represent (`Decimal`,
            `Categorical`), stable-v1 `Enum`, or non-microsecond `Duration`.
        ValueError: For `Enum` without categories, non-microsecond
            timezone-aware `Datetime`, or a zone other than the connection's.
    """
    dtypes = version.dtypes
    base_type = dtype.base_type()
    # Flat dtypes are a plain table lookup.
    if duckdb_type := NW_TO_DUCKDB_DTYPES.get(base_type):
        return duckdb_type
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            return DuckDBPyType(f"ENUM{dtype.categories!r}")
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        tu = dtype.time_unit
        tz = dtype.time_zone
        if not tz:
            return TIME_UNIT_TO_TIMESTAMP[tu]
        if tu != "us":
            msg = f"Only microsecond precision is supported for timezone-aware `Datetime` in DuckDB, got {tu} precision"
            raise ValueError(msg)
        if tz != (rel_tz := deferred_time_zone.time_zone):  # pragma: no cover
            msg = f"Only the connection time zone {rel_tz} is supported, got: {tz}."
            raise ValueError(msg)
        # TODO(unassigned): cover once https://github.com/narwhals-dev/narwhals/issues/2742 addressed
        return duckdb_dtypes.TIMESTAMP_TZ  # pragma: no cover
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        if (tu := dtype.time_unit) != "us":  # pragma: no cover
            msg = f"Only microsecond-precision Duration is supported, got {tu} precision"
            # Bug fix: `msg` was previously built but never raised, so
            # non-"us" durations silently fell through to `INTERVAL`.
            raise NotImplementedError(msg)
        return duckdb_dtypes.INTERVAL
    if isinstance_or_issubclass(dtype, dtypes.List):
        inner = narwhals_to_native_dtype(dtype.inner, version, deferred_time_zone)
        return duckdb.list_type(inner)
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        fields = {
            field.name: narwhals_to_native_dtype(field.dtype, version, deferred_time_zone)
            for field in dtype.fields
        }
        return duckdb.struct_type(fields)
    if isinstance(dtype, dtypes.Array):
        # Flatten nested Array dtypes into DuckDB's `T[a][b]...` notation.
        nw_inner: IntoDType = dtype
        while isinstance(nw_inner, dtypes.Array):
            nw_inner = nw_inner.inner
        duckdb_inner = narwhals_to_native_dtype(nw_inner, version, deferred_time_zone)
        duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape)
        return DuckDBPyType(f"{duckdb_inner}{duckdb_shape_fmt}")
    if issubclass(base_type, UNSUPPORTED_DTYPES):
        msg = f"Converting to {base_type.__name__} dtype is not supported for DuckDB."
        raise NotImplementedError(msg)
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)
|
||||
|
||||
|
||||
def parse_into_expression(into_expression: str | Expression) -> Expression:
    """Promote a column name to a `ColumnExpression`; pass expressions through."""
    if isinstance(into_expression, str):
        return col(into_expression)
    return into_expression
|
||||
|
||||
|
||||
def generate_partition_by_sql(*partition_by: str | Expression) -> str:
    """Render a `partition by` clause, or `""` when there are no keys."""
    if not partition_by:
        return ""
    rendered = [f"{parse_into_expression(key)}" for key in partition_by]
    return f"partition by {', '.join(rendered)}"
|
||||
|
||||
|
||||
def join_column_names(*names: str) -> str:
    """Quote each column name via `ColumnExpression` and join with ", "."""
    quoted = [str(col(name)) for name in names]
    return ", ".join(quoted)
|
||||
|
||||
|
||||
def generate_order_by_sql(
    *order_by: str | Expression, descending: Sequence[bool], nulls_last: Sequence[bool]
) -> str:
    """Render an `order by` clause, or `""` when there are no keys.

    `descending` and `nulls_last` must match `order_by` in length
    (enforced by `zip_strict`).
    """
    if not order_by:
        return ""
    by_sql = ",".join(
        f"{parse_into_expression(x)} {DESCENDING_TO_ORDER[_descending]} {NULLS_LAST_TO_NULLS_POS[_nulls_last]}"
        for x, _descending, _nulls_last in zip_strict(order_by, descending, nulls_last)
    )
    return f"order by {by_sql}"
|
||||
|
||||
|
||||
def window_expression(
    expr: Expression,
    partition_by: Sequence[str | Expression] = (),
    order_by: Sequence[str | Expression] = (),
    rows_start: int | None = None,
    rows_end: int | None = None,
    *,
    descending: Sequence[bool] | None = None,
    nulls_last: Sequence[bool] | None = None,
    ignore_nulls: bool = False,
) -> Expression:
    """Wrap `expr` in an `over (...)` window clause via raw SQL.

    Arguments:
        expr: Aggregate/window function expression to wrap.
        partition_by: `partition by` keys.
        order_by: `order by` keys.
        rows_start: Frame start offset; negated into "N preceding".
        rows_end: Frame end offset ("N following").
        descending: Per-key sort direction (defaults to all ascending).
        nulls_last: Per-key NULL placement (defaults to all nulls-first).
        ignore_nulls: Inject `ignore nulls` into the function call.

    Returns:
        A `SQLExpression` of the form `func(...) over (...)`.

    Raises:
        NotImplementedError: If `duckdb.SQLExpression` is unavailable (DuckDB < 1.3.0).
    """
    # TODO(unassigned): Replace with `duckdb.WindowExpression` when they release it.
    # https://github.com/duckdb/duckdb/discussions/14725#discussioncomment-11200348
    try:
        from duckdb import SQLExpression
    except ModuleNotFoundError as exc:  # pragma: no cover
        msg = f"DuckDB>=1.3.0 is required for this operation. Found: DuckDB {duckdb.__version__}"
        raise NotImplementedError(msg) from exc
    pb = generate_partition_by_sql(*partition_by)
    descending = descending or [False] * len(order_by)
    nulls_last = nulls_last or [False] * len(order_by)
    ob = generate_order_by_sql(*order_by, descending=descending, nulls_last=nulls_last)

    # Frame spec: both bounds, end-only, start-only, or none.
    if rows_start is not None and rows_end is not None:
        rows = f"rows between {-rows_start} preceding and {rows_end} following"
    elif rows_end is not None:
        rows = f"rows between unbounded preceding and {rows_end} following"
    elif rows_start is not None:
        rows = f"rows between {-rows_start} preceding and unbounded following"
    else:
        rows = ""

    # Splice `ignore nulls` just inside the function call's closing paren.
    func = f"{str(expr).removesuffix(')')} ignore nulls)" if ignore_nulls else str(expr)
    return SQLExpression(f"{func} over ({pb} {ob} {rows})")
|
||||
|
||||
|
||||
def catch_duckdb_exception(
    exception: Exception, frame: CompliantLazyFrameAny, /
) -> ColumnNotFoundError | Exception:
    """Translate DuckDB binder errors about missing columns into `ColumnNotFoundError`."""
    missing_column_hints = (
        "not found in FROM clause",
        "this column cannot be referenced before it is defined",
    )
    if isinstance(exception, duckdb.BinderException):
        text = str(exception)
        if any(hint in text for hint in missing_column_hints):
            return ColumnNotFoundError.from_available_column_names(
                available_columns=frame.columns
            )
    # Anything else is passed through untouched.
    return exception
|
||||
|
||||
|
||||
def function(name: str, *args: Expression) -> Expression:
    """Build a `FunctionExpression`, special-casing `isnull` (an Expression method)."""
    if name == "isnull":
        first, *_ = args
        return first.isnull()
    return F(name, *args)
|
Reference in New Issue
Block a user