done
This commit is contained in:
@ -0,0 +1 @@
|
||||
# ! Any change to this module will trigger the pyspark and pyspark-connect tests in CI
|
601
lib/python3.11/site-packages/narwhals/_spark_like/dataframe.py
Normal file
601
lib/python3.11/site-packages/narwhals/_spark_like/dataframe.py
Normal file
@ -0,0 +1,601 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import reduce
|
||||
from operator import and_
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from narwhals._exceptions import issue_warning
|
||||
from narwhals._namespace import is_native_spark_like
|
||||
from narwhals._spark_like.utils import (
|
||||
catch_pyspark_connect_exception,
|
||||
catch_pyspark_sql_exception,
|
||||
evaluate_exprs,
|
||||
import_functions,
|
||||
import_native_dtypes,
|
||||
import_window,
|
||||
native_to_narwhals_dtype,
|
||||
)
|
||||
from narwhals._sql.dataframe import SQLLazyFrame
|
||||
from narwhals._utils import (
|
||||
Implementation,
|
||||
ValidateBackendVersion,
|
||||
generate_temporary_column_name,
|
||||
not_implemented,
|
||||
parse_columns_to_drop,
|
||||
zip_strict,
|
||||
)
|
||||
from narwhals.exceptions import InvalidOperationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
import pyarrow as pa
|
||||
from sqlframe.base.column import Column
|
||||
from sqlframe.base.dataframe import BaseDataFrame
|
||||
from sqlframe.base.window import Window
|
||||
from typing_extensions import Self, TypeAlias, TypeIs
|
||||
|
||||
from narwhals._compliant.typing import CompliantDataFrameAny
|
||||
from narwhals._spark_like.expr import SparkLikeExpr
|
||||
from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
|
||||
from narwhals._spark_like.namespace import SparkLikeNamespace
|
||||
from narwhals._spark_like.utils import SparkSession
|
||||
from narwhals._typing import _EagerAllowedImpl
|
||||
from narwhals._utils import Version, _LimitedContext
|
||||
from narwhals.dataframe import LazyFrame
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.typing import JoinStrategy, LazyUniqueKeepStrategy
|
||||
|
||||
SQLFrameDataFrame = BaseDataFrame[Any, Any, Any, Any, Any]
|
||||
|
||||
Incomplete: TypeAlias = Any # pragma: no cover
|
||||
"""Marker for working code that fails type checking."""
|
||||
|
||||
|
||||
class SparkLikeLazyFrame(
|
||||
SQLLazyFrame["SparkLikeExpr", "SQLFrameDataFrame", "LazyFrame[SQLFrameDataFrame]"],
|
||||
ValidateBackendVersion,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
native_dataframe: SQLFrameDataFrame,
|
||||
*,
|
||||
version: Version,
|
||||
implementation: Implementation,
|
||||
validate_backend_version: bool = False,
|
||||
) -> None:
|
||||
self._native_frame: SQLFrameDataFrame = native_dataframe
|
||||
self._implementation = implementation
|
||||
self._version = version
|
||||
self._cached_schema: dict[str, DType] | None = None
|
||||
self._cached_columns: list[str] | None = None
|
||||
if validate_backend_version: # pragma: no cover
|
||||
self._validate_backend_version()
|
||||
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]: # pragma: no cover
|
||||
return self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def _F(self): # type: ignore[no-untyped-def] # noqa: ANN202
|
||||
if TYPE_CHECKING:
|
||||
from sqlframe.base import functions
|
||||
|
||||
return functions
|
||||
return import_functions(self._implementation)
|
||||
|
||||
@property
|
||||
def _native_dtypes(self): # type: ignore[no-untyped-def] # noqa: ANN202
|
||||
if TYPE_CHECKING:
|
||||
from sqlframe.base import types
|
||||
|
||||
return types
|
||||
return import_native_dtypes(self._implementation)
|
||||
|
||||
@property
|
||||
def _Window(self) -> type[Window]:
|
||||
if TYPE_CHECKING:
|
||||
from sqlframe.base.window import Window
|
||||
|
||||
return Window
|
||||
return import_window(self._implementation)
|
||||
|
||||
@staticmethod
|
||||
def _is_native(obj: SQLFrameDataFrame | Any) -> TypeIs[SQLFrameDataFrame]:
|
||||
return is_native_spark_like(obj)
|
||||
|
||||
@classmethod
|
||||
def from_native(cls, data: SQLFrameDataFrame, /, *, context: _LimitedContext) -> Self:
|
||||
return cls(data, version=context._version, implementation=context._implementation)
|
||||
|
||||
def to_narwhals(self) -> LazyFrame[SQLFrameDataFrame]:
|
||||
return self._version.lazyframe(self, level="lazy")
|
||||
|
||||
def __native_namespace__(self) -> ModuleType: # pragma: no cover
|
||||
return self._implementation.to_native_namespace()
|
||||
|
||||
def __narwhals_namespace__(self) -> SparkLikeNamespace:
|
||||
from narwhals._spark_like.namespace import SparkLikeNamespace
|
||||
|
||||
return SparkLikeNamespace(
|
||||
version=self._version, implementation=self._implementation
|
||||
)
|
||||
|
||||
def __narwhals_lazyframe__(self) -> Self:
|
||||
return self
|
||||
|
||||
def _with_version(self, version: Version) -> Self:
|
||||
return self.__class__(
|
||||
self.native, version=version, implementation=self._implementation
|
||||
)
|
||||
|
||||
def _with_native(self, df: SQLFrameDataFrame) -> Self:
|
||||
return self.__class__(
|
||||
df, version=self._version, implementation=self._implementation
|
||||
)
|
||||
|
||||
def _to_arrow_schema(self) -> pa.Schema: # pragma: no cover
|
||||
import pyarrow as pa # ignore-banned-import
|
||||
|
||||
from narwhals._arrow.utils import narwhals_to_native_dtype
|
||||
|
||||
schema: list[tuple[str, pa.DataType]] = []
|
||||
nw_schema = self.collect_schema()
|
||||
native_schema = self.native.schema
|
||||
for key, value in nw_schema.items():
|
||||
try:
|
||||
native_dtype = narwhals_to_native_dtype(value, self._version)
|
||||
except Exception as exc: # noqa: BLE001,PERF203
|
||||
native_spark_dtype = native_schema[key].dataType # type: ignore[index]
|
||||
# If we can't convert the type, just set it to `pa.null`, and warn.
|
||||
# Avoid the warning if we're starting from PySpark's void type.
|
||||
# We can avoid the check when we introduce `nw.Null` dtype.
|
||||
null_type = self._native_dtypes.NullType # pyright: ignore[reportAttributeAccessIssue]
|
||||
if not isinstance(native_spark_dtype, null_type):
|
||||
issue_warning(
|
||||
f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
|
||||
UserWarning,
|
||||
)
|
||||
schema.append((key, pa.null()))
|
||||
else:
|
||||
schema.append((key, native_dtype))
|
||||
return pa.schema(schema)
|
||||
|
||||
def _collect_to_arrow(self) -> pa.Table:
|
||||
if self._implementation.is_pyspark() and self._backend_version < (4,):
|
||||
import pyarrow as pa # ignore-banned-import
|
||||
|
||||
try:
|
||||
return pa.Table.from_batches(self.native._collect_as_arrow())
|
||||
except ValueError as exc:
|
||||
if "at least one RecordBatch" in str(exc):
|
||||
# Empty dataframe
|
||||
|
||||
data: dict[str, list[Any]] = {k: [] for k in self.columns}
|
||||
pa_schema = self._to_arrow_schema()
|
||||
return pa.Table.from_pydict(data, schema=pa_schema)
|
||||
raise # pragma: no cover
|
||||
elif self._implementation.is_pyspark_connect() and self._backend_version < (4,):
|
||||
import pyarrow as pa # ignore-banned-import
|
||||
|
||||
pa_schema = self._to_arrow_schema()
|
||||
return pa.Table.from_pandas(self.native.toPandas(), schema=pa_schema)
|
||||
else:
|
||||
return self.native.toArrow()
|
||||
|
||||
def _iter_columns(self) -> Iterator[Column]:
|
||||
for col in self.columns:
|
||||
yield self._F.col(col)
|
||||
|
||||
@property
|
||||
def columns(self) -> list[str]:
|
||||
if self._cached_columns is None:
|
||||
self._cached_columns = (
|
||||
list(self.schema)
|
||||
if self._cached_schema is not None
|
||||
else self.native.columns
|
||||
)
|
||||
return self._cached_columns
|
||||
|
||||
def _collect(
|
||||
self, backend: _EagerAllowedImpl | None, **kwargs: Any
|
||||
) -> CompliantDataFrameAny:
|
||||
if backend is Implementation.PANDAS:
|
||||
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
|
||||
|
||||
return PandasLikeDataFrame(
|
||||
self.native.toPandas(),
|
||||
implementation=Implementation.PANDAS,
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
validate_column_names=True,
|
||||
)
|
||||
|
||||
if backend is None or backend is Implementation.PYARROW:
|
||||
from narwhals._arrow.dataframe import ArrowDataFrame
|
||||
|
||||
return ArrowDataFrame(
|
||||
self._collect_to_arrow(),
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
validate_column_names=True,
|
||||
)
|
||||
|
||||
if backend is Implementation.POLARS:
|
||||
import polars as pl # ignore-banned-import
|
||||
|
||||
from narwhals._polars.dataframe import PolarsDataFrame
|
||||
|
||||
return PolarsDataFrame(
|
||||
pl.from_arrow(self._collect_to_arrow()), # type: ignore[arg-type]
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
)
|
||||
|
||||
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
|
||||
raise ValueError(msg) # pragma: no cover
|
||||
|
||||
def collect(
|
||||
self, backend: _EagerAllowedImpl | None, **kwargs: Any
|
||||
) -> CompliantDataFrameAny:
|
||||
if self._implementation.is_pyspark_connect():
|
||||
try:
|
||||
return self._collect(backend, **kwargs)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_pyspark_connect_exception(e) from None
|
||||
return self._collect(backend, **kwargs)
|
||||
|
||||
def simple_select(self, *column_names: str) -> Self:
|
||||
return self._with_native(self.native.select(*column_names))
|
||||
|
||||
def aggregate(self, *exprs: SparkLikeExpr) -> Self:
|
||||
new_columns = evaluate_exprs(self, *exprs)
|
||||
|
||||
new_columns_list = [col.alias(col_name) for col_name, col in new_columns]
|
||||
if self._implementation.is_pyspark():
|
||||
try:
|
||||
return self._with_native(self.native.agg(*new_columns_list))
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_pyspark_sql_exception(e, self) from None
|
||||
return self._with_native(self.native.agg(*new_columns_list))
|
||||
|
||||
def select(self, *exprs: SparkLikeExpr) -> Self:
|
||||
new_columns = evaluate_exprs(self, *exprs)
|
||||
new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
|
||||
if self._implementation.is_pyspark(): # pragma: no cover
|
||||
try:
|
||||
return self._with_native(self.native.select(*new_columns_list))
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_pyspark_sql_exception(e, self) from None
|
||||
return self._with_native(self.native.select(*new_columns_list))
|
||||
|
||||
def with_columns(self, *exprs: SparkLikeExpr) -> Self:
|
||||
new_columns = evaluate_exprs(self, *exprs)
|
||||
if self._implementation.is_pyspark(): # pragma: no cover
|
||||
try:
|
||||
return self._with_native(self.native.withColumns(dict(new_columns)))
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_pyspark_sql_exception(e, self) from None
|
||||
|
||||
return self._with_native(self.native.withColumns(dict(new_columns)))
|
||||
|
||||
def filter(self, predicate: SparkLikeExpr) -> Self:
|
||||
# `[0]` is safe as the predicate's expression only returns a single column
|
||||
condition = predicate._call(self)[0]
|
||||
if self._implementation.is_pyspark():
|
||||
try:
|
||||
return self._with_native(self.native.where(condition))
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_pyspark_sql_exception(e, self) from None
|
||||
return self._with_native(self.native.where(condition))
|
||||
|
||||
@property
|
||||
def schema(self) -> dict[str, DType]:
|
||||
if self._cached_schema is None:
|
||||
self._cached_schema = {
|
||||
field.name: native_to_narwhals_dtype(
|
||||
field.dataType,
|
||||
self._version,
|
||||
self._native_dtypes,
|
||||
self.native.sparkSession,
|
||||
)
|
||||
for field in self.native.schema
|
||||
}
|
||||
return self._cached_schema
|
||||
|
||||
def collect_schema(self) -> dict[str, DType]:
|
||||
return self.schema
|
||||
|
||||
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
|
||||
columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
|
||||
return self._with_native(self.native.drop(*columns_to_drop))
|
||||
|
||||
def head(self, n: int) -> Self:
|
||||
return self._with_native(self.native.limit(n))
|
||||
|
||||
def group_by(
|
||||
self, keys: Sequence[str] | Sequence[SparkLikeExpr], *, drop_null_keys: bool
|
||||
) -> SparkLikeLazyGroupBy:
|
||||
from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
|
||||
|
||||
return SparkLikeLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)
|
||||
|
||||
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
|
||||
if isinstance(descending, bool):
|
||||
descending = [descending] * len(by)
|
||||
|
||||
if nulls_last:
|
||||
sort_funcs = (
|
||||
self._F.desc_nulls_last if d else self._F.asc_nulls_last
|
||||
for d in descending
|
||||
)
|
||||
else:
|
||||
sort_funcs = (
|
||||
self._F.desc_nulls_first if d else self._F.asc_nulls_first
|
||||
for d in descending
|
||||
)
|
||||
|
||||
sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)]
|
||||
return self._with_native(self.native.sort(*sort_cols))
|
||||
|
||||
def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
|
||||
by = list(by)
|
||||
if isinstance(reverse, bool):
|
||||
reverse = [reverse] * len(by)
|
||||
sort_funcs = (
|
||||
self._F.desc_nulls_last if not d else self._F.asc_nulls_last for d in reverse
|
||||
)
|
||||
sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)]
|
||||
return self._with_native(self.native.sort(*sort_cols).limit(k))
|
||||
|
||||
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
|
||||
subset = list(subset) if subset else None
|
||||
return self._with_native(self.native.dropna(subset=subset))
|
||||
|
||||
def rename(self, mapping: Mapping[str, str]) -> Self:
|
||||
rename_mapping = {
|
||||
colname: mapping.get(colname, colname) for colname in self.columns
|
||||
}
|
||||
return self._with_native(
|
||||
self.native.select(
|
||||
[self._F.col(old).alias(new) for old, new in rename_mapping.items()]
|
||||
)
|
||||
)
|
||||
|
||||
def unique(
|
||||
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
|
||||
) -> Self:
|
||||
if subset and (error := self._check_columns_exist(subset)):
|
||||
raise error
|
||||
subset = list(subset) if subset else None
|
||||
if keep == "none":
|
||||
tmp = generate_temporary_column_name(8, self.columns)
|
||||
window = self._Window.partitionBy(subset or self.columns)
|
||||
df = (
|
||||
self.native.withColumn(tmp, self._F.count("*").over(window))
|
||||
.filter(self._F.col(tmp) == self._F.lit(1))
|
||||
.drop(self._F.col(tmp))
|
||||
)
|
||||
return self._with_native(df)
|
||||
return self._with_native(self.native.dropDuplicates(subset=subset))
|
||||
|
||||
def join(
|
||||
self,
|
||||
other: Self,
|
||||
how: JoinStrategy,
|
||||
left_on: Sequence[str] | None,
|
||||
right_on: Sequence[str] | None,
|
||||
suffix: str,
|
||||
) -> Self:
|
||||
left_columns = self.columns
|
||||
right_columns = other.columns
|
||||
|
||||
right_on_: list[str] = list(right_on) if right_on is not None else []
|
||||
left_on_: list[str] = list(left_on) if left_on is not None else []
|
||||
|
||||
# create a mapping for columns on other
|
||||
# `right_on` columns will be renamed as `left_on`
|
||||
# the remaining columns will be either added the suffix or left unchanged.
|
||||
right_cols_to_rename = (
|
||||
[c for c in right_columns if c not in right_on_]
|
||||
if how != "full"
|
||||
else right_columns
|
||||
)
|
||||
|
||||
rename_mapping = {
|
||||
**dict(zip(right_on_, left_on_)),
|
||||
**{
|
||||
colname: f"{colname}{suffix}" if colname in left_columns else colname
|
||||
for colname in right_cols_to_rename
|
||||
},
|
||||
}
|
||||
other_native = other.native.select(
|
||||
[self._F.col(old).alias(new) for old, new in rename_mapping.items()]
|
||||
)
|
||||
|
||||
# If how in {"semi", "anti"}, then resulting columns are same as left columns
|
||||
# Otherwise, we add the right columns with the new mapping, while keeping the
|
||||
# original order of right_columns.
|
||||
col_order = left_columns.copy()
|
||||
|
||||
if how in {"inner", "left", "cross"}:
|
||||
col_order.extend(
|
||||
rename_mapping[colname]
|
||||
for colname in right_columns
|
||||
if colname not in right_on_
|
||||
)
|
||||
elif how == "full":
|
||||
col_order.extend(rename_mapping.values())
|
||||
|
||||
right_on_remapped = [rename_mapping[c] for c in right_on_]
|
||||
on_ = (
|
||||
reduce(
|
||||
and_,
|
||||
(
|
||||
getattr(self.native, left_key) == getattr(other_native, right_key)
|
||||
for left_key, right_key in zip_strict(left_on_, right_on_remapped)
|
||||
),
|
||||
)
|
||||
if how == "full"
|
||||
else None
|
||||
if how == "cross"
|
||||
else left_on_
|
||||
)
|
||||
how_native = "full_outer" if how == "full" else how
|
||||
return self._with_native(
|
||||
self.native.join(other_native, on=on_, how=how_native).select(col_order)
|
||||
)
|
||||
|
||||
def explode(self, columns: Sequence[str]) -> Self:
|
||||
dtypes = self._version.dtypes
|
||||
|
||||
schema = self.collect_schema()
|
||||
for col_to_explode in columns:
|
||||
dtype = schema[col_to_explode]
|
||||
|
||||
if dtype != dtypes.List:
|
||||
msg = (
|
||||
f"`explode` operation not supported for dtype `{dtype}`, "
|
||||
"expected List type"
|
||||
)
|
||||
raise InvalidOperationError(msg)
|
||||
|
||||
column_names = self.columns
|
||||
|
||||
if len(columns) != 1:
|
||||
msg = (
|
||||
"Exploding on multiple columns is not supported with SparkLike backend since "
|
||||
"we cannot guarantee that the exploded columns have matching element counts."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
if self._implementation.is_pyspark() or self._implementation.is_pyspark_connect():
|
||||
return self._with_native(
|
||||
self.native.select(
|
||||
*[
|
||||
self._F.col(col_name).alias(col_name)
|
||||
if col_name != columns[0]
|
||||
else self._F.explode_outer(col_name).alias(col_name)
|
||||
for col_name in column_names
|
||||
]
|
||||
)
|
||||
)
|
||||
if self._implementation.is_sqlframe():
|
||||
# Not every sqlframe dialect supports `explode_outer` function
|
||||
# (see https://github.com/eakmanrq/sqlframe/blob/3cb899c515b101ff4c197d84b34fae490d0ed257/sqlframe/base/functions.py#L2288-L2289)
|
||||
# therefore we simply explode the array column which will ignore nulls and
|
||||
# zero sized arrays, and append these specific condition with nulls (to
|
||||
# match polars behavior).
|
||||
|
||||
def null_condition(col_name: str) -> Column:
|
||||
return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0)
|
||||
|
||||
return self._with_native(
|
||||
self.native.select(
|
||||
*[
|
||||
self._F.col(col_name).alias(col_name)
|
||||
if col_name != columns[0]
|
||||
else self._F.explode(col_name).alias(col_name)
|
||||
for col_name in column_names
|
||||
]
|
||||
).union(
|
||||
self.native.filter(null_condition(columns[0])).select(
|
||||
*[
|
||||
self._F.col(col_name).alias(col_name)
|
||||
if col_name != columns[0]
|
||||
else self._F.lit(None).alias(col_name)
|
||||
for col_name in column_names
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
msg = "Unreachable code, please report an issue at https://github.com/narwhals-dev/narwhals/issues" # pragma: no cover
|
||||
raise AssertionError(msg)
|
||||
|
||||
def unpivot(
|
||||
self,
|
||||
on: Sequence[str] | None,
|
||||
index: Sequence[str] | None,
|
||||
variable_name: str,
|
||||
value_name: str,
|
||||
) -> Self:
|
||||
if self._implementation.is_sqlframe():
|
||||
if variable_name == "":
|
||||
msg = "`variable_name` cannot be empty string for sqlframe backend."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
if value_name == "":
|
||||
msg = "`value_name` cannot be empty string for sqlframe backend."
|
||||
raise NotImplementedError(msg)
|
||||
else: # pragma: no cover
|
||||
pass
|
||||
|
||||
ids = tuple(index) if index else ()
|
||||
values = (
|
||||
tuple(set(self.columns).difference(set(ids))) if on is None else tuple(on)
|
||||
)
|
||||
unpivoted_native_frame = self.native.unpivot(
|
||||
ids=ids,
|
||||
values=values,
|
||||
variableColumnName=variable_name,
|
||||
valueColumnName=value_name,
|
||||
)
|
||||
if index is None:
|
||||
unpivoted_native_frame = unpivoted_native_frame.drop(*ids)
|
||||
return self._with_native(unpivoted_native_frame)
|
||||
|
||||
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
|
||||
if order_by is None:
|
||||
msg = "Cannot pass `order_by` to `with_row_index` for PySpark-like"
|
||||
raise TypeError(msg)
|
||||
row_index_expr = (
|
||||
self._F.row_number().over(
|
||||
self._Window.partitionBy(self._F.lit(1)).orderBy(*order_by)
|
||||
)
|
||||
- 1
|
||||
).alias(name)
|
||||
return self._with_native(self.native.select(row_index_expr, *self.columns))
|
||||
|
||||
def sink_parquet(self, file: str | Path | BytesIO) -> None:
|
||||
self.native.write.parquet(file)
|
||||
|
||||
@classmethod
|
||||
def _from_compliant_dataframe(
|
||||
cls,
|
||||
frame: CompliantDataFrameAny,
|
||||
/,
|
||||
*,
|
||||
session: SparkSession,
|
||||
implementation: Implementation,
|
||||
version: Version,
|
||||
) -> SparkLikeLazyFrame:
|
||||
from importlib.util import find_spec
|
||||
|
||||
impl = implementation
|
||||
is_spark_v4 = (not impl.is_sqlframe()) and impl._backend_version() >= (4, 0, 0)
|
||||
if is_spark_v4: # pragma: no cover
|
||||
# pyspark.sql requires pyarrow to be installed from v4.0.0
|
||||
# and since v4.0.0 the input to `createDataFrame` can be a PyArrow Table.
|
||||
data: Any = frame.to_arrow()
|
||||
elif find_spec("pandas"):
|
||||
data = frame.to_pandas()
|
||||
else: # pragma: no cover
|
||||
data = tuple(frame.iter_rows(named=True, buffer_size=512))
|
||||
|
||||
return cls(
|
||||
session.createDataFrame(data),
|
||||
version=version,
|
||||
implementation=implementation,
|
||||
validate_backend_version=True,
|
||||
)
|
||||
|
||||
gather_every = not_implemented.deprecated(
|
||||
"`LazyFrame.gather_every` is deprecated and will be removed in a future version."
|
||||
)
|
||||
join_asof = not_implemented()
|
||||
tail = not_implemented.deprecated(
|
||||
"`LazyFrame.tail` is deprecated and will be removed in a future version."
|
||||
)
|
391
lib/python3.11/site-packages/narwhals/_spark_like/expr.py
Normal file
391
lib/python3.11/site-packages/narwhals/_spark_like/expr.py
Normal file
@ -0,0 +1,391 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast
|
||||
|
||||
from narwhals._expression_parsing import ExprKind, ExprMetadata
|
||||
from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
|
||||
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
|
||||
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
|
||||
from narwhals._spark_like.expr_struct import SparkLikeExprStructNamespace
|
||||
from narwhals._spark_like.utils import (
|
||||
import_functions,
|
||||
import_native_dtypes,
|
||||
import_window,
|
||||
narwhals_to_native_dtype,
|
||||
true_divide,
|
||||
)
|
||||
from narwhals._sql.expr import SQLExpr
|
||||
from narwhals._utils import Implementation, Version, not_implemented, zip_strict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Mapping, Sequence
|
||||
|
||||
from sqlframe.base.column import Column
|
||||
from sqlframe.base.window import Window, WindowSpec
|
||||
from typing_extensions import Self, TypeAlias
|
||||
|
||||
from narwhals._compliant import WindowInputs
|
||||
from narwhals._compliant.typing import (
|
||||
AliasNames,
|
||||
EvalNames,
|
||||
EvalSeries,
|
||||
WindowFunction,
|
||||
)
|
||||
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
|
||||
from narwhals._spark_like.namespace import SparkLikeNamespace
|
||||
from narwhals._utils import _LimitedContext
|
||||
from narwhals.typing import FillNullStrategy, IntoDType, NonNestedLiteral, RankMethod
|
||||
|
||||
NativeRankMethod: TypeAlias = Literal["rank", "dense_rank", "row_number"]
|
||||
SparkWindowFunction = WindowFunction[SparkLikeLazyFrame, Column]
|
||||
SparkWindowInputs = WindowInputs[Column]
|
||||
|
||||
|
||||
class SparkLikeExpr(SQLExpr["SparkLikeLazyFrame", "Column"]):
|
||||
def __init__(
|
||||
self,
|
||||
call: EvalSeries[SparkLikeLazyFrame, Column],
|
||||
window_function: SparkWindowFunction | None = None,
|
||||
*,
|
||||
evaluate_output_names: EvalNames[SparkLikeLazyFrame],
|
||||
alias_output_names: AliasNames | None,
|
||||
version: Version,
|
||||
implementation: Implementation,
|
||||
) -> None:
|
||||
self._call = call
|
||||
self._evaluate_output_names = evaluate_output_names
|
||||
self._alias_output_names = alias_output_names
|
||||
self._version = version
|
||||
self._implementation = implementation
|
||||
self._metadata: ExprMetadata | None = None
|
||||
self._window_function: SparkWindowFunction | None = window_function
|
||||
|
||||
_REMAP_RANK_METHOD: ClassVar[Mapping[RankMethod, NativeRankMethod]] = {
|
||||
"min": "rank",
|
||||
"max": "rank",
|
||||
"average": "rank",
|
||||
"dense": "dense_rank",
|
||||
"ordinal": "row_number",
|
||||
}
|
||||
|
||||
def _count_star(self) -> Column:
|
||||
return self._F.count("*")
|
||||
|
||||
def _window_expression(
|
||||
self,
|
||||
expr: Column,
|
||||
partition_by: Sequence[str | Column] = (),
|
||||
order_by: Sequence[str | Column] = (),
|
||||
rows_start: int | None = None,
|
||||
rows_end: int | None = None,
|
||||
*,
|
||||
descending: Sequence[bool] | None = None,
|
||||
nulls_last: Sequence[bool] | None = None,
|
||||
) -> Column:
|
||||
window = self.partition_by(*partition_by)
|
||||
if order_by:
|
||||
window = window.orderBy(
|
||||
*self._sort(*order_by, descending=descending, nulls_last=nulls_last)
|
||||
)
|
||||
if rows_start is not None and rows_end is not None:
|
||||
window = window.rowsBetween(rows_start, rows_end)
|
||||
elif rows_end is not None:
|
||||
window = window.rowsBetween(self._Window.unboundedPreceding, rows_end)
|
||||
elif rows_start is not None: # pragma: no cover
|
||||
window = window.rowsBetween(rows_start, self._Window.unboundedFollowing)
|
||||
return expr.over(window)
|
||||
|
||||
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
|
||||
if kind is ExprKind.LITERAL:
|
||||
return self
|
||||
return self.over([self._F.lit(1)], [])
|
||||
|
||||
@property
|
||||
def _F(self): # type: ignore[no-untyped-def] # noqa: ANN202
|
||||
if TYPE_CHECKING:
|
||||
from sqlframe.base import functions
|
||||
|
||||
return functions
|
||||
return import_functions(self._implementation)
|
||||
|
||||
@property
|
||||
def _native_dtypes(self): # type: ignore[no-untyped-def] # noqa: ANN202
|
||||
if TYPE_CHECKING:
|
||||
from sqlframe.base import types
|
||||
|
||||
return types
|
||||
return import_native_dtypes(self._implementation)
|
||||
|
||||
@property
|
||||
def _Window(self) -> type[Window]:
|
||||
if TYPE_CHECKING:
|
||||
from sqlframe.base.window import Window
|
||||
|
||||
return Window
|
||||
return import_window(self._implementation)
|
||||
|
||||
def _sort(
|
||||
self,
|
||||
*cols: Column | str,
|
||||
descending: Sequence[bool] | None = None,
|
||||
nulls_last: Sequence[bool] | None = None,
|
||||
) -> Iterator[Column]:
|
||||
F = self._F
|
||||
descending = descending or [False] * len(cols)
|
||||
nulls_last = nulls_last or [False] * len(cols)
|
||||
mapping = {
|
||||
(False, False): F.asc_nulls_first,
|
||||
(False, True): F.asc_nulls_last,
|
||||
(True, False): F.desc_nulls_first,
|
||||
(True, True): F.desc_nulls_last,
|
||||
}
|
||||
yield from (
|
||||
mapping[(_desc, _nulls_last)](col)
|
||||
for col, _desc, _nulls_last in zip_strict(cols, descending, nulls_last)
|
||||
)
|
||||
|
||||
def partition_by(self, *cols: Column | str) -> WindowSpec:
|
||||
"""Wraps `Window().partitionBy`, with default and `WindowInputs` handling."""
|
||||
return self._Window.partitionBy(*cols or [self._F.lit(1)])
|
||||
|
||||
def __narwhals_namespace__(self) -> SparkLikeNamespace: # pragma: no cover
|
||||
from narwhals._spark_like.namespace import SparkLikeNamespace
|
||||
|
||||
return SparkLikeNamespace(
|
||||
version=self._version, implementation=self._implementation
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _alias_native(cls, expr: Column, name: str) -> Column:
|
||||
return expr.alias(name)
|
||||
|
||||
@classmethod
|
||||
def from_column_names(
|
||||
cls: type[Self],
|
||||
evaluate_column_names: EvalNames[SparkLikeLazyFrame],
|
||||
/,
|
||||
*,
|
||||
context: _LimitedContext,
|
||||
) -> Self:
|
||||
def func(df: SparkLikeLazyFrame) -> list[Column]:
|
||||
return [df._F.col(col_name) for col_name in evaluate_column_names(df)]
|
||||
|
||||
return cls(
|
||||
func,
|
||||
evaluate_output_names=evaluate_column_names,
|
||||
alias_output_names=None,
|
||||
version=context._version,
|
||||
implementation=context._implementation,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
|
||||
def func(df: SparkLikeLazyFrame) -> list[Column]:
|
||||
columns = df.columns
|
||||
return [df._F.col(columns[i]) for i in column_indices]
|
||||
|
||||
return cls(
|
||||
func,
|
||||
evaluate_output_names=cls._eval_names_indices(column_indices),
|
||||
alias_output_names=None,
|
||||
version=context._version,
|
||||
implementation=context._implementation,
|
||||
)
|
||||
|
||||
def __truediv__(self, other: SparkLikeExpr) -> Self:
|
||||
def _truediv(expr: Column, other: Column) -> Column:
|
||||
return true_divide(self._F, expr, other)
|
||||
|
||||
return self._with_binary(_truediv, other)
|
||||
|
||||
def __rtruediv__(self, other: SparkLikeExpr) -> Self:
|
||||
def _rtruediv(expr: Column, other: Column) -> Column:
|
||||
return true_divide(self._F, other, expr)
|
||||
|
||||
return self._with_binary(_rtruediv, other).alias("literal")
|
||||
|
||||
def __floordiv__(self, other: SparkLikeExpr) -> Self:
|
||||
def _floordiv(expr: Column, other: Column) -> Column:
|
||||
return self._F.floor(true_divide(self._F, expr, other))
|
||||
|
||||
return self._with_binary(_floordiv, other)
|
||||
|
||||
def __rfloordiv__(self, other: SparkLikeExpr) -> Self:
|
||||
def _rfloordiv(expr: Column, other: Column) -> Column:
|
||||
return self._F.floor(true_divide(self._F, other, expr))
|
||||
|
||||
return self._with_binary(_rfloordiv, other).alias("literal")
|
||||
|
||||
def __invert__(self) -> Self:
|
||||
invert = cast("Callable[..., Column]", operator.invert)
|
||||
return self._with_elementwise(invert)
|
||||
|
||||
def cast(self, dtype: IntoDType) -> Self:
|
||||
def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
|
||||
spark_dtype = narwhals_to_native_dtype(
|
||||
dtype, self._version, self._native_dtypes, df.native.sparkSession
|
||||
)
|
||||
return [expr.cast(spark_dtype) for expr in self(df)]
|
||||
|
||||
def window_f(
|
||||
df: SparkLikeLazyFrame, inputs: SparkWindowInputs
|
||||
) -> Sequence[Column]:
|
||||
spark_dtype = narwhals_to_native_dtype(
|
||||
dtype, self._version, self._native_dtypes, df.native.sparkSession
|
||||
)
|
||||
return [expr.cast(spark_dtype) for expr in self.window_function(df, inputs)]
|
||||
|
||||
return self.__class__(
|
||||
func,
|
||||
window_f,
|
||||
evaluate_output_names=self._evaluate_output_names,
|
||||
alias_output_names=self._alias_output_names,
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)
|
||||
|
||||
def median(self) -> Self:
|
||||
def _median(expr: Column) -> Column:
|
||||
if self._implementation in {
|
||||
Implementation.PYSPARK,
|
||||
Implementation.PYSPARK_CONNECT,
|
||||
} and Implementation.PYSPARK._backend_version() < (3, 4): # pragma: no cover
|
||||
# Use percentile_approx with default accuracy parameter (10000)
|
||||
return self._F.percentile_approx(expr.cast("double"), 0.5)
|
||||
|
||||
return self._F.median(expr)
|
||||
|
||||
return self._with_callable(_median)
|
||||
|
||||
def null_count(self) -> Self:
|
||||
def _null_count(expr: Column) -> Column:
|
||||
return self._F.count_if(self._F.isnull(expr))
|
||||
|
||||
return self._with_callable(_null_count)
|
||||
|
||||
def std(self, ddof: int) -> Self:
|
||||
F = self._F
|
||||
if ddof == 0:
|
||||
return self._with_callable(F.stddev_pop)
|
||||
if ddof == 1:
|
||||
return self._with_callable(F.stddev_samp)
|
||||
|
||||
def func(expr: Column) -> Column:
|
||||
n_rows = F.count(expr)
|
||||
return F.stddev_samp(expr) * F.sqrt((n_rows - 1) / (n_rows - ddof))
|
||||
|
||||
return self._with_callable(func)
|
||||
|
||||
def var(self, ddof: int) -> Self:
|
||||
F = self._F
|
||||
if ddof == 0:
|
||||
return self._with_callable(F.var_pop)
|
||||
if ddof == 1:
|
||||
return self._with_callable(F.var_samp)
|
||||
|
||||
def func(expr: Column) -> Column:
|
||||
n_rows = F.count(expr)
|
||||
return F.var_samp(expr) * (n_rows - 1) / (n_rows - ddof)
|
||||
|
||||
return self._with_callable(func)
|
||||
|
||||
def is_finite(self) -> Self:
|
||||
def _is_finite(expr: Column) -> Column:
|
||||
# A value is finite if it's not NaN, and not infinite, while NULLs should be
|
||||
# preserved
|
||||
is_finite_condition = (
|
||||
~self._F.isnan(expr)
|
||||
& (expr != self._F.lit(float("inf")))
|
||||
& (expr != self._F.lit(float("-inf")))
|
||||
)
|
||||
return self._F.when(~self._F.isnull(expr), is_finite_condition).otherwise(
|
||||
None
|
||||
)
|
||||
|
||||
return self._with_elementwise(_is_finite)
|
||||
|
||||
def is_in(self, values: Sequence[Any]) -> Self:
|
||||
def _is_in(expr: Column) -> Column:
|
||||
return expr.isin(values) if values else self._F.lit(False)
|
||||
|
||||
return self._with_elementwise(_is_in)
|
||||
|
||||
def len(self) -> Self:
|
||||
def _len(_expr: Column) -> Column:
|
||||
# Use count(*) to count all rows including nulls
|
||||
return self._F.count("*")
|
||||
|
||||
return self._with_callable(_len)
|
||||
|
||||
def skew(self) -> Self:
|
||||
return self._with_callable(self._F.skewness)
|
||||
|
||||
def kurtosis(self) -> Self:
|
||||
return self._with_callable(self._F.kurtosis)
|
||||
|
||||
def n_unique(self) -> Self:
|
||||
def _n_unique(expr: Column) -> Column:
|
||||
return self._F.count_distinct(expr) + self._F.max(
|
||||
self._F.isnull(expr).cast(self._native_dtypes.IntegerType())
|
||||
)
|
||||
|
||||
return self._with_callable(_n_unique)
|
||||
|
||||
def is_nan(self) -> Self:
|
||||
def _is_nan(expr: Column) -> Column:
|
||||
return self._F.when(self._F.isnull(expr), None).otherwise(self._F.isnan(expr))
|
||||
|
||||
return self._with_elementwise(_is_nan)
|
||||
|
||||
def fill_null(
|
||||
self,
|
||||
value: Self | NonNestedLiteral,
|
||||
strategy: FillNullStrategy | None,
|
||||
limit: int | None,
|
||||
) -> Self:
|
||||
if strategy is not None:
|
||||
|
||||
def _fill_with_strategy(
|
||||
df: SparkLikeLazyFrame, inputs: SparkWindowInputs
|
||||
) -> Sequence[Column]:
|
||||
fn = self._F.last_value if strategy == "forward" else self._F.first_value
|
||||
if strategy == "forward":
|
||||
start = self._Window.unboundedPreceding if limit is None else -limit
|
||||
end = self._Window.currentRow
|
||||
else:
|
||||
start = self._Window.currentRow
|
||||
end = self._Window.unboundedFollowing if limit is None else limit
|
||||
return [
|
||||
fn(expr, ignoreNulls=True).over(
|
||||
self.partition_by(*inputs.partition_by)
|
||||
.orderBy(*self._sort(*inputs.order_by))
|
||||
.rowsBetween(start, end)
|
||||
)
|
||||
for expr in self(df)
|
||||
]
|
||||
|
||||
return self._with_window_function(_fill_with_strategy)
|
||||
|
||||
def _fill_constant(expr: Column, value: Column) -> Column:
|
||||
return self._F.ifnull(expr, value)
|
||||
|
||||
return self._with_elementwise(_fill_constant, value=value)
|
||||
|
||||
@property
|
||||
def str(self) -> SparkLikeExprStringNamespace:
|
||||
return SparkLikeExprStringNamespace(self)
|
||||
|
||||
@property
|
||||
def dt(self) -> SparkLikeExprDateTimeNamespace:
|
||||
return SparkLikeExprDateTimeNamespace(self)
|
||||
|
||||
@property
|
||||
def list(self) -> SparkLikeExprListNamespace:
|
||||
return SparkLikeExprListNamespace(self)
|
||||
|
||||
@property
|
||||
def struct(self) -> SparkLikeExprStructNamespace:
|
||||
return SparkLikeExprStructNamespace(self)
|
||||
|
||||
quantile = not_implemented()
|
192
lib/python3.11/site-packages/narwhals/_spark_like/expr_dt.py
Normal file
192
lib/python3.11/site-packages/narwhals/_spark_like/expr_dt.py
Normal file
@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._constants import US_PER_SECOND
|
||||
from narwhals._duration import Interval
|
||||
from narwhals._spark_like.utils import (
|
||||
UNITS_DICT,
|
||||
fetch_session_time_zone,
|
||||
strptime_to_pyspark_format,
|
||||
)
|
||||
from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace
|
||||
from narwhals._utils import not_implemented
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlframe.base.column import Column
|
||||
|
||||
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
|
||||
from narwhals._spark_like.expr import SparkLikeExpr
|
||||
|
||||
|
||||
class SparkLikeExprDateTimeNamespace(SQLExprDateTimeNamesSpace["SparkLikeExpr"]):
|
||||
def _weekday(self, expr: Column) -> Column:
|
||||
# PySpark's dayofweek returns 1-7 for Sunday-Saturday
|
||||
return (self.compliant._F.dayofweek(expr) + 6) % 7
|
||||
|
||||
def to_string(self, format: str) -> SparkLikeExpr:
|
||||
F = self.compliant._F
|
||||
|
||||
def _to_string(expr: Column) -> Column:
|
||||
# Handle special formats
|
||||
if format == "%G-W%V":
|
||||
return self._format_iso_week(expr)
|
||||
if format == "%G-W%V-%u":
|
||||
return self._format_iso_week_with_day(expr)
|
||||
|
||||
format_, suffix = self._format_microseconds(expr, format)
|
||||
|
||||
# Convert Python format to PySpark format
|
||||
pyspark_fmt = strptime_to_pyspark_format(format_)
|
||||
|
||||
result = F.date_format(expr, pyspark_fmt)
|
||||
if "T" in format_:
|
||||
# `strptime_to_pyspark_format` replaces "T" with " " since pyspark
|
||||
# does not support the literal "T" in `date_format`.
|
||||
# If no other spaces are in the given format, then we can revert this
|
||||
# operation, otherwise we raise an exception.
|
||||
if " " not in format_:
|
||||
result = F.replace(result, F.lit(" "), F.lit("T"))
|
||||
else: # pragma: no cover
|
||||
msg = (
|
||||
"`dt.to_string` with a format that contains both spaces and "
|
||||
" the literal 'T' is not supported for spark-like backends."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
return F.concat(result, *suffix)
|
||||
|
||||
return self.compliant._with_elementwise(_to_string)
|
||||
|
||||
def millisecond(self) -> SparkLikeExpr:
|
||||
def _millisecond(expr: Column) -> Column:
|
||||
return self.compliant._F.floor(
|
||||
(self.compliant._F.unix_micros(expr) % US_PER_SECOND) / 1000
|
||||
)
|
||||
|
||||
return self.compliant._with_elementwise(_millisecond)
|
||||
|
||||
def microsecond(self) -> SparkLikeExpr:
|
||||
def _microsecond(expr: Column) -> Column:
|
||||
return self.compliant._F.unix_micros(expr) % US_PER_SECOND
|
||||
|
||||
return self.compliant._with_elementwise(_microsecond)
|
||||
|
||||
def nanosecond(self) -> SparkLikeExpr:
|
||||
def _nanosecond(expr: Column) -> Column:
|
||||
return (self.compliant._F.unix_micros(expr) % US_PER_SECOND) * 1000
|
||||
|
||||
return self.compliant._with_elementwise(_nanosecond)
|
||||
|
||||
def weekday(self) -> SparkLikeExpr:
|
||||
return self.compliant._with_elementwise(self._weekday)
|
||||
|
||||
def truncate(self, every: str) -> SparkLikeExpr:
|
||||
interval = Interval.parse(every)
|
||||
multiple, unit = interval.multiple, interval.unit
|
||||
if multiple != 1:
|
||||
msg = f"Only multiple 1 is currently supported for Spark-like.\nGot {multiple!s}."
|
||||
raise ValueError(msg)
|
||||
if unit == "ns":
|
||||
msg = "Truncating to nanoseconds is not yet supported for Spark-like."
|
||||
raise NotImplementedError(msg)
|
||||
format = UNITS_DICT[unit]
|
||||
|
||||
def _truncate(expr: Column) -> Column:
|
||||
return self.compliant._F.date_trunc(format, expr)
|
||||
|
||||
return self.compliant._with_elementwise(_truncate)
|
||||
|
||||
def offset_by(self, by: str) -> SparkLikeExpr:
|
||||
interval = Interval.parse_no_constraints(by)
|
||||
multiple, unit = interval.multiple, interval.unit
|
||||
if unit == "ns": # pragma: no cover
|
||||
msg = "Offsetting by nanoseconds is not yet supported for Spark-like."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
F = self.compliant._F
|
||||
|
||||
def _offset_by(expr: Column) -> Column:
|
||||
# https://github.com/eakmanrq/sqlframe/issues/441
|
||||
return F.timestamp_add( # pyright: ignore[reportAttributeAccessIssue]
|
||||
UNITS_DICT[unit], F.lit(multiple), expr
|
||||
)
|
||||
|
||||
return self.compliant._with_callable(_offset_by)
|
||||
|
||||
def _no_op_time_zone(self, time_zone: str) -> SparkLikeExpr: # pragma: no cover
|
||||
def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
|
||||
native_series_list = self.compliant(df)
|
||||
conn_time_zone = fetch_session_time_zone(df.native.sparkSession)
|
||||
if conn_time_zone != time_zone:
|
||||
msg = (
|
||||
"PySpark stores the time zone in the session, rather than in the "
|
||||
f"data type, so changing the timezone to anything other than {conn_time_zone} "
|
||||
" (the current session time zone) is not supported."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
return native_series_list
|
||||
|
||||
return self.compliant.__class__(
|
||||
func,
|
||||
evaluate_output_names=self.compliant._evaluate_output_names,
|
||||
alias_output_names=self.compliant._alias_output_names,
|
||||
version=self.compliant._version,
|
||||
implementation=self.compliant._implementation,
|
||||
)
|
||||
|
||||
def convert_time_zone(self, time_zone: str) -> SparkLikeExpr: # pragma: no cover
|
||||
return self._no_op_time_zone(time_zone)
|
||||
|
||||
def replace_time_zone(
|
||||
self, time_zone: str | None
|
||||
) -> SparkLikeExpr: # pragma: no cover
|
||||
if time_zone is None:
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: expr.cast("timestamp_ntz")
|
||||
)
|
||||
return self._no_op_time_zone(time_zone)
|
||||
|
||||
def _format_iso_week_with_day(self, expr: Column) -> Column:
|
||||
"""Format datetime as ISO week string with day."""
|
||||
F = self.compliant._F
|
||||
|
||||
year = F.date_format(expr, "yyyy")
|
||||
week = F.lpad(F.weekofyear(expr).cast("string"), 2, "0")
|
||||
day = self._weekday(expr)
|
||||
return F.concat(year, F.lit("-W"), week, F.lit("-"), day.cast("string"))
|
||||
|
||||
def _format_iso_week(self, expr: Column) -> Column:
|
||||
"""Format datetime as ISO week string."""
|
||||
F = self.compliant._F
|
||||
|
||||
year = F.date_format(expr, "yyyy")
|
||||
week = F.lpad(F.weekofyear(expr).cast("string"), 2, "0")
|
||||
return F.concat(year, F.lit("-W"), week)
|
||||
|
||||
def _format_microseconds(
|
||||
self, expr: Column, format: str
|
||||
) -> tuple[str, tuple[Column, ...]]:
|
||||
"""Format microseconds if present in format, else it's a no-op."""
|
||||
F = self.compliant._F
|
||||
|
||||
suffix: tuple[Column, ...]
|
||||
if format.endswith((".%f", "%.f")):
|
||||
import re
|
||||
|
||||
micros = F.unix_micros(expr) % US_PER_SECOND
|
||||
micros_str = F.lpad(micros.cast("string"), 6, "0")
|
||||
suffix = (F.lit("."), micros_str)
|
||||
format_ = re.sub(r"(.%|%.)f$", "", format)
|
||||
return format_, suffix
|
||||
|
||||
return format, ()
|
||||
|
||||
timestamp = not_implemented()
|
||||
total_seconds = not_implemented()
|
||||
total_minutes = not_implemented()
|
||||
total_milliseconds = not_implemented()
|
||||
total_microseconds = not_implemented()
|
||||
total_nanoseconds = not_implemented()
|
@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._compliant import LazyExprNamespace
|
||||
from narwhals._compliant.any_namespace import ListNamespace
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlframe.base.column import Column
|
||||
|
||||
from narwhals._spark_like.expr import SparkLikeExpr
|
||||
from narwhals.typing import NonNestedLiteral
|
||||
|
||||
|
||||
class SparkLikeExprListNamespace(
|
||||
LazyExprNamespace["SparkLikeExpr"], ListNamespace["SparkLikeExpr"]
|
||||
):
|
||||
def len(self) -> SparkLikeExpr:
|
||||
return self.compliant._with_elementwise(self.compliant._F.array_size)
|
||||
|
||||
def unique(self) -> SparkLikeExpr:
|
||||
return self.compliant._with_elementwise(self.compliant._F.array_distinct)
|
||||
|
||||
def contains(self, item: NonNestedLiteral) -> SparkLikeExpr:
|
||||
def func(expr: Column) -> Column:
|
||||
F = self.compliant._F
|
||||
return F.array_contains(expr, F.lit(item))
|
||||
|
||||
return self.compliant._with_elementwise(func)
|
||||
|
||||
def get(self, index: int) -> SparkLikeExpr:
|
||||
def _get(expr: Column) -> Column:
|
||||
return expr.getItem(index)
|
||||
|
||||
return self.compliant._with_elementwise(_get)
|
@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._spark_like.utils import strptime_to_pyspark_format
|
||||
from narwhals._sql.expr_str import SQLExprStringNamespace
|
||||
from narwhals._utils import _is_naive_format, not_implemented
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from narwhals._spark_like.expr import SparkLikeExpr
|
||||
|
||||
|
||||
class SparkLikeExprStringNamespace(SQLExprStringNamespace["SparkLikeExpr"]):
|
||||
def to_datetime(self, format: str | None) -> SparkLikeExpr:
|
||||
F = self.compliant._F
|
||||
if not format:
|
||||
function = F.to_timestamp
|
||||
elif _is_naive_format(format):
|
||||
function = partial(
|
||||
F.to_timestamp_ntz, format=F.lit(strptime_to_pyspark_format(format))
|
||||
)
|
||||
else:
|
||||
format = strptime_to_pyspark_format(format)
|
||||
function = partial(F.to_timestamp, format=format)
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: function(F.replace(expr, F.lit("T"), F.lit(" ")))
|
||||
)
|
||||
|
||||
def to_date(self, format: str | None) -> SparkLikeExpr:
|
||||
F = self.compliant._F
|
||||
return self.compliant._with_elementwise(
|
||||
lambda expr: F.to_date(expr, format=strptime_to_pyspark_format(format))
|
||||
)
|
||||
|
||||
replace = not_implemented()
|
@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._compliant import LazyExprNamespace
|
||||
from narwhals._compliant.any_namespace import StructNamespace
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlframe.base.column import Column
|
||||
|
||||
from narwhals._spark_like.expr import SparkLikeExpr
|
||||
|
||||
|
||||
class SparkLikeExprStructNamespace(
|
||||
LazyExprNamespace["SparkLikeExpr"], StructNamespace["SparkLikeExpr"]
|
||||
):
|
||||
def field(self, name: str) -> SparkLikeExpr:
|
||||
def func(expr: Column) -> Column:
|
||||
return expr.getField(name)
|
||||
|
||||
return self.compliant._with_elementwise(func).alias(name)
|
@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._sql.group_by import SQLGroupBy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlframe.base.column import Column # noqa: F401
|
||||
|
||||
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
|
||||
from narwhals._spark_like.expr import SparkLikeExpr
|
||||
|
||||
|
||||
class SparkLikeLazyGroupBy(SQLGroupBy["SparkLikeLazyFrame", "SparkLikeExpr", "Column"]):
|
||||
def __init__(
|
||||
self,
|
||||
df: SparkLikeLazyFrame,
|
||||
keys: Sequence[SparkLikeExpr] | Sequence[str],
|
||||
/,
|
||||
*,
|
||||
drop_null_keys: bool,
|
||||
) -> None:
|
||||
frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
|
||||
self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
|
||||
|
||||
def agg(self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
|
||||
result = (
|
||||
self.compliant.native.groupBy(*self._keys).agg(*agg_columns)
|
||||
if (agg_columns := list(self._evaluate_exprs(exprs)))
|
||||
else self.compliant.native.select(*self._keys).dropDuplicates()
|
||||
)
|
||||
|
||||
return self.compliant._with_native(result).rename(
|
||||
dict(zip(self._keys, self._output_key_names))
|
||||
)
|
230
lib/python3.11/site-packages/narwhals/_spark_like/namespace.py
Normal file
230
lib/python3.11/site-packages/narwhals/_spark_like/namespace.py
Normal file
@ -0,0 +1,230 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from narwhals._expression_parsing import (
|
||||
combine_alias_output_names,
|
||||
combine_evaluate_output_names,
|
||||
)
|
||||
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
|
||||
from narwhals._spark_like.expr import SparkLikeExpr
|
||||
from narwhals._spark_like.selectors import SparkLikeSelectorNamespace
|
||||
from narwhals._spark_like.utils import (
|
||||
import_functions,
|
||||
import_native_dtypes,
|
||||
narwhals_to_native_dtype,
|
||||
true_divide,
|
||||
)
|
||||
from narwhals._sql.namespace import SQLNamespace
|
||||
from narwhals._sql.when_then import SQLThen, SQLWhen
|
||||
from narwhals._utils import zip_strict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from sqlframe.base.column import Column
|
||||
|
||||
from narwhals._spark_like.dataframe import SQLFrameDataFrame # noqa: F401
|
||||
from narwhals._utils import Implementation, Version
|
||||
from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral, PythonLiteral
|
||||
|
||||
# Adjust slight SQL vs PySpark differences
|
||||
FUNCTION_REMAPPINGS = {
|
||||
"starts_with": "startswith",
|
||||
"ends_with": "endswith",
|
||||
"trim": "btrim",
|
||||
"str_split": "split",
|
||||
"regexp_matches": "regexp",
|
||||
}
|
||||
|
||||
|
||||
class SparkLikeNamespace(
|
||||
SQLNamespace[SparkLikeLazyFrame, SparkLikeExpr, "SQLFrameDataFrame", "Column"]
|
||||
):
|
||||
def __init__(self, *, version: Version, implementation: Implementation) -> None:
|
||||
self._version = version
|
||||
self._implementation = implementation
|
||||
|
||||
@property
|
||||
def selectors(self) -> SparkLikeSelectorNamespace:
|
||||
return SparkLikeSelectorNamespace.from_namespace(self)
|
||||
|
||||
@property
|
||||
def _expr(self) -> type[SparkLikeExpr]:
|
||||
return SparkLikeExpr
|
||||
|
||||
@property
|
||||
def _lazyframe(self) -> type[SparkLikeLazyFrame]:
|
||||
return SparkLikeLazyFrame
|
||||
|
||||
@property
|
||||
def _F(self): # type: ignore[no-untyped-def] # noqa: ANN202
|
||||
if TYPE_CHECKING:
|
||||
from sqlframe.base import functions
|
||||
|
||||
return functions
|
||||
return import_functions(self._implementation)
|
||||
|
||||
@property
|
||||
def _native_dtypes(self): # type: ignore[no-untyped-def] # noqa: ANN202
|
||||
if TYPE_CHECKING:
|
||||
from sqlframe.base import types
|
||||
|
||||
return types
|
||||
return import_native_dtypes(self._implementation)
|
||||
|
||||
def _function(self, name: str, *args: Column | PythonLiteral) -> Column:
|
||||
return getattr(self._F, FUNCTION_REMAPPINGS.get(name, name))(*args)
|
||||
|
||||
def _lit(self, value: Any) -> Column:
|
||||
return self._F.lit(value)
|
||||
|
||||
def _when(
|
||||
self, condition: Column, value: Column, otherwise: Column | None = None
|
||||
) -> Column:
|
||||
if otherwise is None:
|
||||
return self._F.when(condition, value)
|
||||
return self._F.when(condition, value).otherwise(otherwise)
|
||||
|
||||
def _coalesce(self, *exprs: Column) -> Column:
|
||||
return self._F.coalesce(*exprs)
|
||||
|
||||
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> SparkLikeExpr:
|
||||
def _lit(df: SparkLikeLazyFrame) -> list[Column]:
|
||||
column = df._F.lit(value)
|
||||
if dtype:
|
||||
native_dtype = narwhals_to_native_dtype(
|
||||
dtype, self._version, df._native_dtypes, df.native.sparkSession
|
||||
)
|
||||
column = column.cast(native_dtype)
|
||||
|
||||
return [column]
|
||||
|
||||
return self._expr(
|
||||
call=_lit,
|
||||
evaluate_output_names=lambda _df: ["literal"],
|
||||
alias_output_names=None,
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)
|
||||
|
||||
def len(self) -> SparkLikeExpr:
|
||||
def func(df: SparkLikeLazyFrame) -> list[Column]:
|
||||
return [df._F.count("*")]
|
||||
|
||||
return self._expr(
|
||||
func,
|
||||
evaluate_output_names=lambda _df: ["len"],
|
||||
alias_output_names=None,
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)
|
||||
|
||||
def mean_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
|
||||
def func(cols: Iterable[Column]) -> Column:
|
||||
cols = list(cols)
|
||||
F = exprs[0]._F
|
||||
numerator = reduce(
|
||||
operator.add, (self._F.coalesce(col, self._F.lit(0)) for col in cols)
|
||||
)
|
||||
denominator = reduce(
|
||||
operator.add,
|
||||
(col.isNotNull().cast(self._native_dtypes.IntegerType()) for col in cols),
|
||||
)
|
||||
return true_divide(F, numerator, denominator)
|
||||
|
||||
return self._expr._from_elementwise_horizontal_op(func, *exprs)
|
||||
|
||||
def concat(
|
||||
self, items: Iterable[SparkLikeLazyFrame], *, how: ConcatMethod
|
||||
) -> SparkLikeLazyFrame:
|
||||
dfs = [item._native_frame for item in items]
|
||||
if how == "vertical":
|
||||
cols_0 = dfs[0].columns
|
||||
for i, df in enumerate(dfs[1:], start=1):
|
||||
cols_current = df.columns
|
||||
if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)):
|
||||
msg = (
|
||||
"unable to vstack, column names don't match:\n"
|
||||
f" - dataframe 0: {cols_0}\n"
|
||||
f" - dataframe {i}: {cols_current}\n"
|
||||
)
|
||||
raise TypeError(msg)
|
||||
|
||||
return SparkLikeLazyFrame(
|
||||
native_dataframe=reduce(lambda x, y: x.union(y), dfs),
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)
|
||||
|
||||
if how == "diagonal":
|
||||
return SparkLikeLazyFrame(
|
||||
native_dataframe=reduce(
|
||||
lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs
|
||||
),
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
def concat_str(
|
||||
self, *exprs: SparkLikeExpr, separator: str, ignore_nulls: bool
|
||||
) -> SparkLikeExpr:
|
||||
def func(df: SparkLikeLazyFrame) -> list[Column]:
|
||||
cols = [s for _expr in exprs for s in _expr(df)]
|
||||
cols_casted = [s.cast(df._native_dtypes.StringType()) for s in cols]
|
||||
null_mask = [df._F.isnull(s) for s in cols]
|
||||
|
||||
if not ignore_nulls:
|
||||
null_mask_result = reduce(operator.or_, null_mask)
|
||||
result = df._F.when(
|
||||
~null_mask_result,
|
||||
reduce(
|
||||
lambda x, y: df._F.format_string(f"%s{separator}%s", x, y),
|
||||
cols_casted,
|
||||
),
|
||||
).otherwise(df._F.lit(None))
|
||||
else:
|
||||
init_value, *values = [
|
||||
df._F.when(~nm, col).otherwise(df._F.lit(""))
|
||||
for col, nm in zip_strict(cols_casted, null_mask)
|
||||
]
|
||||
|
||||
separators = (
|
||||
df._F.when(nm, df._F.lit("")).otherwise(df._F.lit(separator))
|
||||
for nm in null_mask[:-1]
|
||||
)
|
||||
result = reduce(
|
||||
lambda x, y: df._F.format_string("%s%s", x, y),
|
||||
(
|
||||
df._F.format_string("%s%s", s, v)
|
||||
for s, v in zip_strict(separators, values)
|
||||
),
|
||||
init_value,
|
||||
)
|
||||
|
||||
return [result]
|
||||
|
||||
return self._expr(
|
||||
call=func,
|
||||
evaluate_output_names=combine_evaluate_output_names(*exprs),
|
||||
alias_output_names=combine_alias_output_names(*exprs),
|
||||
version=self._version,
|
||||
implementation=self._implementation,
|
||||
)

    def when(self, predicate: SparkLikeExpr) -> SparkLikeWhen:
        return SparkLikeWhen.from_expr(predicate, context=self)


class SparkLikeWhen(SQLWhen[SparkLikeLazyFrame, "Column", SparkLikeExpr]):
    @property
    def _then(self) -> type[SparkLikeThen]:
        return SparkLikeThen


class SparkLikeThen(
    SQLThen[SparkLikeLazyFrame, "Column", SparkLikeExpr], SparkLikeExpr
): ...
@ -0,0 +1,31 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import CompliantSelector, LazySelectorNamespace
from narwhals._spark_like.expr import SparkLikeExpr

if TYPE_CHECKING:
    from sqlframe.base.column import Column  # noqa: F401

    from narwhals._spark_like.dataframe import SparkLikeLazyFrame  # noqa: F401
    from narwhals._spark_like.expr import SparkWindowFunction


class SparkLikeSelectorNamespace(LazySelectorNamespace["SparkLikeLazyFrame", "Column"]):
    @property
    def _selector(self) -> type[SparkLikeSelector]:
        return SparkLikeSelector


class SparkLikeSelector(CompliantSelector["SparkLikeLazyFrame", "Column"], SparkLikeExpr):  # type: ignore[misc]
    _window_function: SparkWindowFunction | None = None

    def _to_expr(self) -> SparkLikeExpr:
        return SparkLikeExpr(
            self._call,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            implementation=self._implementation,
        )
319
lib/python3.11/site-packages/narwhals/_spark_like/utils.py
Normal file
@ -0,0 +1,319 @@
from __future__ import annotations

import operator
from collections.abc import Callable
from functools import lru_cache
from importlib import import_module
from operator import attrgetter
from types import ModuleType
from typing import TYPE_CHECKING, Any, overload

from narwhals._utils import Implementation, Version, isinstance_or_issubclass
from narwhals.exceptions import ColumnNotFoundError, UnsupportedDTypeError

if TYPE_CHECKING:
    from collections.abc import Mapping

    import sqlframe.base.types as sqlframe_types
    from sqlframe.base.column import Column
    from sqlframe.base.session import _BaseSession as Session
    from typing_extensions import TypeAlias

    from narwhals._compliant.typing import CompliantLazyFrameAny
    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
    from narwhals._spark_like.expr import SparkLikeExpr
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType

    _NativeDType: TypeAlias = sqlframe_types.DataType
    SparkSession = Session[Any, Any, Any, Any, Any, Any, Any]

UNITS_DICT = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}

# see https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html
# and https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior
DATETIME_PATTERNS_MAPPING = {
    "%Y": "yyyy",  # Year with century (4 digits)
    "%y": "yy",  # Year without century (2 digits)
    "%m": "MM",  # Month (01-12)
    "%d": "dd",  # Day of the month (01-31)
    "%H": "HH",  # Hour (24-hour clock) (00-23)
    "%I": "hh",  # Hour (12-hour clock) (01-12)
    "%M": "mm",  # Minute (00-59)
    "%S": "ss",  # Second (00-59)
    "%f": "S",  # Microseconds -> Milliseconds
    "%p": "a",  # AM/PM
    "%a": "E",  # Abbreviated weekday name
    "%A": "E",  # Full weekday name
    "%j": "D",  # Day of the year
    "%z": "Z",  # Timezone offset
    "%s": "X",  # Unix timestamp
}


# NOTE: don't lru_cache this as `ModuleType` isn't hashable
def native_to_narwhals_dtype(  # noqa: C901, PLR0912
    dtype: _NativeDType, version: Version, spark_types: ModuleType, session: SparkSession
) -> DType:
    dtypes = version.dtypes
    if TYPE_CHECKING:
        native = sqlframe_types
    else:
        native = spark_types

    if isinstance(dtype, native.DoubleType):
        return dtypes.Float64()
    if isinstance(dtype, native.FloatType):
        return dtypes.Float32()
    if isinstance(dtype, native.LongType):
        return dtypes.Int64()
    if isinstance(dtype, native.IntegerType):
        return dtypes.Int32()
    if isinstance(dtype, native.ShortType):
        return dtypes.Int16()
    if isinstance(dtype, native.ByteType):
        return dtypes.Int8()
    if isinstance(dtype, (native.StringType, native.VarcharType, native.CharType)):
        return dtypes.String()
    if isinstance(dtype, native.BooleanType):
        return dtypes.Boolean()
    if isinstance(dtype, native.DateType):
        return dtypes.Date()
    if isinstance(dtype, native.TimestampNTZType):
        # TODO(marco): cover this
        return dtypes.Datetime()  # pragma: no cover
    if isinstance(dtype, native.TimestampType):
        return dtypes.Datetime(time_zone=fetch_session_time_zone(session))
    if isinstance(dtype, native.DecimalType):
        # TODO(marco): cover this
        return dtypes.Decimal()  # pragma: no cover
    if isinstance(dtype, native.ArrayType):
        return dtypes.List(
            inner=native_to_narwhals_dtype(
                dtype.elementType, version, spark_types, session
            )
        )
    if isinstance(dtype, native.StructType):
        return dtypes.Struct(
            fields=[
                dtypes.Field(
                    name=field.name,
                    dtype=native_to_narwhals_dtype(
                        field.dataType, version, spark_types, session
                    ),
                )
                for field in dtype
            ]
        )
    if isinstance(dtype, native.BinaryType):
        return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover
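# Illustrative sketch (the `types` and `session` names here are assumptions): given a
# live PySpark or SQLFrame `types` module and `session`,
# native_to_narwhals_dtype(types.LongType(), version, types, session) returns Int64(),
# and an ArrayType(LongType()) is mapped recursively to List(Int64()).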


@lru_cache(maxsize=4)
def fetch_session_time_zone(session: SparkSession) -> str:
    # Timezone can't be changed in PySpark session, so this can be cached.
    try:
        return session.conf.get("spark.sql.session.timeZone")  # type: ignore[attr-defined]
    except Exception:  # noqa: BLE001
        # https://github.com/eakmanrq/sqlframe/issues/406
        return "<unknown>"


IntoSparkDType: TypeAlias = Callable[[ModuleType], Callable[[], "_NativeDType"]]
dtypes = Version.MAIN.dtypes
NW_TO_SPARK_DTYPES: Mapping[type[DType], IntoSparkDType] = {
    dtypes.Float64: attrgetter("DoubleType"),
    dtypes.Float32: attrgetter("FloatType"),
    dtypes.Binary: attrgetter("BinaryType"),
    dtypes.String: attrgetter("StringType"),
    dtypes.Boolean: attrgetter("BooleanType"),
    dtypes.Date: attrgetter("DateType"),
    dtypes.Int8: attrgetter("ByteType"),
    dtypes.Int16: attrgetter("ShortType"),
    dtypes.Int32: attrgetter("IntegerType"),
    dtypes.Int64: attrgetter("LongType"),
}
UNSUPPORTED_DTYPES = (
    dtypes.UInt64,
    dtypes.UInt32,
    dtypes.UInt16,
    dtypes.UInt8,
    dtypes.Enum,
    dtypes.Categorical,
    dtypes.Time,
)


def narwhals_to_native_dtype(
    dtype: IntoDType, version: Version, spark_types: ModuleType, session: SparkSession
) -> _NativeDType:
    dtypes = version.dtypes
    if TYPE_CHECKING:
        native = sqlframe_types
    else:
        native = spark_types
    base_type = dtype.base_type()
    if into_spark_type := NW_TO_SPARK_DTYPES.get(base_type):
        return into_spark_type(native)()
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        if (tu := dtype.time_unit) != "us":  # pragma: no cover
            msg = f"Only microsecond precision is supported for PySpark, got: {tu}."
            raise ValueError(msg)
        dt_time_zone = dtype.time_zone
        if dt_time_zone is None:
            return native.TimestampNTZType()
        if dt_time_zone != (tz := fetch_session_time_zone(session)):  # pragma: no cover
            msg = f"Only {tz} time zone is supported, as that's the connection time zone, got: {dt_time_zone}"
            raise ValueError(msg)
        # TODO(unassigned): cover once https://github.com/narwhals-dev/narwhals/issues/2742 addressed
        return native.TimestampType()  # pragma: no cover
    if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)):
        return native.ArrayType(
            elementType=narwhals_to_native_dtype(dtype.inner, version, native, session)
        )
    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
        return native.StructType(
            fields=[
                native.StructField(
                    name=field.name,
                    dataType=narwhals_to_native_dtype(
                        field.dtype, version, native, session
                    ),
                )
                for field in dtype.fields
            ]
        )
    if issubclass(base_type, UNSUPPORTED_DTYPES):  # pragma: no cover
        msg = f"Converting to {base_type.__name__} dtype is not supported for Spark-Like backend."
        raise UnsupportedDTypeError(msg)
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def evaluate_exprs(
    df: SparkLikeLazyFrame, /, *exprs: SparkLikeExpr
) -> list[tuple[str, Column]]:
    native_results: list[tuple[str, Column]] = []

    for expr in exprs:
        native_series_list = expr._call(df)
        output_names = expr._evaluate_output_names(df)
        if expr._alias_output_names is not None:
            output_names = expr._alias_output_names(output_names)
        if len(output_names) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
            raise AssertionError(msg)
        native_results.extend(zip(output_names, native_series_list))

    return native_results


def import_functions(implementation: Implementation, /) -> ModuleType:
    if implementation is Implementation.PYSPARK:
        from pyspark.sql import functions

        return functions
    if implementation is Implementation.PYSPARK_CONNECT:
        from pyspark.sql.connect import functions

        return functions
    from sqlframe.base.session import _BaseSession

    return import_module(f"sqlframe.{_BaseSession().execution_dialect_name}.functions")


def import_native_dtypes(implementation: Implementation, /) -> ModuleType:
    if implementation is Implementation.PYSPARK:
        from pyspark.sql import types

        return types
    if implementation is Implementation.PYSPARK_CONNECT:
        from pyspark.sql.connect import types

        return types
    from sqlframe.base.session import _BaseSession

    return import_module(f"sqlframe.{_BaseSession().execution_dialect_name}.types")


def import_window(implementation: Implementation, /) -> type[Any]:
    if implementation is Implementation.PYSPARK:
        from pyspark.sql import Window

        return Window

    if implementation is Implementation.PYSPARK_CONNECT:
        from pyspark.sql.connect.window import Window

        return Window
    from sqlframe.base.session import _BaseSession

    return import_module(
        f"sqlframe.{_BaseSession().execution_dialect_name}.window"
    ).Window


@overload
def strptime_to_pyspark_format(format: None) -> None: ...


@overload
def strptime_to_pyspark_format(format: str) -> str: ...


def strptime_to_pyspark_format(format: str | None) -> str | None:
    """Converts a Python strptime datetime format string to a PySpark datetime format string."""
    if format is None:  # pragma: no cover
        return None

    # Replace Python format specifiers with PySpark specifiers
    pyspark_format = format
    for py_format, spark_format in DATETIME_PATTERNS_MAPPING.items():
        pyspark_format = pyspark_format.replace(py_format, spark_format)
    return pyspark_format.replace("T", " ")
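# Usage sketch (illustrative): strptime_to_pyspark_format("%Y-%m-%dT%H:%M:%S") returns
# "yyyy-MM-dd HH:mm:ss": each strptime directive is swapped via
# DATETIME_PATTERNS_MAPPING and the literal "T" separator becomes a space.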


def true_divide(F: Any, left: Column, right: Column) -> Column:
    # PySpark before 3.5 doesn't have `try_divide`, and SQLFrame doesn't have it either.
    divide = getattr(F, "try_divide", operator.truediv)
    return divide(left, right)
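# Usage sketch (illustrative, with `F` assumed to be `pyspark.sql.functions`):
# true_divide(F, numerator, denominator) builds the division via F.try_divide on
# PySpark >= 3.5, and falls back to the plain `/` operator elsewhere.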


def catch_pyspark_sql_exception(
    exception: Exception, frame: CompliantLazyFrameAny, /
) -> ColumnNotFoundError | Exception:  # pragma: no cover
    from pyspark.errors import AnalysisException

    if isinstance(exception, AnalysisException) and str(exception).startswith(
        "[UNRESOLVED_COLUMN.WITH_SUGGESTION]"
    ):
        return ColumnNotFoundError.from_available_column_names(
            available_columns=frame.columns
        )
    # Just return exception as-is.
    return exception


def catch_pyspark_connect_exception(
    exception: Exception, /
) -> ColumnNotFoundError | Exception:  # pragma: no cover
    from pyspark.errors.exceptions.connect import AnalysisException

    if isinstance(exception, AnalysisException) and str(exception).startswith(
        "[UNRESOLVED_COLUMN.WITH_SUGGESTION]"
    ):
        return ColumnNotFoundError(str(exception))
    # Just return exception as-is.
    return exception