2025-09-07 22:09:54 +02:00
parent e1b817252c
commit 2fc0d000b6
7796 changed files with 2159515 additions and 933 deletions

@@ -0,0 +1 @@
# ! Any change to this module will trigger the pyspark and pyspark-connect tests in CI

@@ -0,0 +1,601 @@
from __future__ import annotations
from functools import reduce
from operator import and_
from typing import TYPE_CHECKING, Any
from narwhals._exceptions import issue_warning
from narwhals._namespace import is_native_spark_like
from narwhals._spark_like.utils import (
catch_pyspark_connect_exception,
catch_pyspark_sql_exception,
evaluate_exprs,
import_functions,
import_native_dtypes,
import_window,
native_to_narwhals_dtype,
)
from narwhals._sql.dataframe import SQLLazyFrame
from narwhals._utils import (
Implementation,
ValidateBackendVersion,
generate_temporary_column_name,
not_implemented,
parse_columns_to_drop,
zip_strict,
)
from narwhals.exceptions import InvalidOperationError
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from io import BytesIO
from pathlib import Path
from types import ModuleType
import pyarrow as pa
from sqlframe.base.column import Column
from sqlframe.base.dataframe import BaseDataFrame
from sqlframe.base.window import Window
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.typing import CompliantDataFrameAny
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals._spark_like.utils import SparkSession
from narwhals._typing import _EagerAllowedImpl
from narwhals._utils import Version, _LimitedContext
from narwhals.dataframe import LazyFrame
from narwhals.dtypes import DType
from narwhals.typing import JoinStrategy, LazyUniqueKeepStrategy
SQLFrameDataFrame = BaseDataFrame[Any, Any, Any, Any, Any]
Incomplete: TypeAlias = Any # pragma: no cover
"""Marker for working code that fails type checking."""
class SparkLikeLazyFrame(
SQLLazyFrame["SparkLikeExpr", "SQLFrameDataFrame", "LazyFrame[SQLFrameDataFrame]"],
ValidateBackendVersion,
):
def __init__(
self,
native_dataframe: SQLFrameDataFrame,
*,
version: Version,
implementation: Implementation,
validate_backend_version: bool = False,
) -> None:
self._native_frame: SQLFrameDataFrame = native_dataframe
self._implementation = implementation
self._version = version
self._cached_schema: dict[str, DType] | None = None
self._cached_columns: list[str] | None = None
if validate_backend_version: # pragma: no cover
self._validate_backend_version()
@property
def _backend_version(self) -> tuple[int, ...]: # pragma: no cover
return self._implementation._backend_version()
@property
def _F(self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from sqlframe.base import functions
return functions
return import_functions(self._implementation)
@property
def _native_dtypes(self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from sqlframe.base import types
return types
return import_native_dtypes(self._implementation)
@property
def _Window(self) -> type[Window]:
if TYPE_CHECKING:
from sqlframe.base.window import Window
return Window
return import_window(self._implementation)
@staticmethod
def _is_native(obj: SQLFrameDataFrame | Any) -> TypeIs[SQLFrameDataFrame]:
return is_native_spark_like(obj)
@classmethod
def from_native(cls, data: SQLFrameDataFrame, /, *, context: _LimitedContext) -> Self:
return cls(data, version=context._version, implementation=context._implementation)
def to_narwhals(self) -> LazyFrame[SQLFrameDataFrame]:
return self._version.lazyframe(self, level="lazy")
def __native_namespace__(self) -> ModuleType: # pragma: no cover
return self._implementation.to_native_namespace()
def __narwhals_namespace__(self) -> SparkLikeNamespace:
from narwhals._spark_like.namespace import SparkLikeNamespace
return SparkLikeNamespace(
version=self._version, implementation=self._implementation
)
def __narwhals_lazyframe__(self) -> Self:
return self
def _with_version(self, version: Version) -> Self:
return self.__class__(
self.native, version=version, implementation=self._implementation
)
def _with_native(self, df: SQLFrameDataFrame) -> Self:
return self.__class__(
df, version=self._version, implementation=self._implementation
)
def _to_arrow_schema(self) -> pa.Schema: # pragma: no cover
import pyarrow as pa # ignore-banned-import
from narwhals._arrow.utils import narwhals_to_native_dtype
schema: list[tuple[str, pa.DataType]] = []
nw_schema = self.collect_schema()
native_schema = self.native.schema
for key, value in nw_schema.items():
try:
native_dtype = narwhals_to_native_dtype(value, self._version)
except Exception as exc: # noqa: BLE001,PERF203
native_spark_dtype = native_schema[key].dataType # type: ignore[index]
# If we can't convert the type, just set it to `pa.null`, and warn.
# Avoid the warning if we're starting from PySpark's void type.
# We can avoid the check when we introduce `nw.Null` dtype.
null_type = self._native_dtypes.NullType # pyright: ignore[reportAttributeAccessIssue]
if not isinstance(native_spark_dtype, null_type):
issue_warning(
f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
UserWarning,
)
schema.append((key, pa.null()))
else:
schema.append((key, native_dtype))
return pa.schema(schema)
def _collect_to_arrow(self) -> pa.Table:
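# PySpark < 4 only exposes the private `_collect_as_arrow` (which fails on empty
# frames, handled below); PySpark Connect < 4 has to round-trip through pandas;
# everything else can use the public `toArrow`.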
if self._implementation.is_pyspark() and self._backend_version < (4,):
import pyarrow as pa # ignore-banned-import
try:
return pa.Table.from_batches(self.native._collect_as_arrow())
except ValueError as exc:
if "at least one RecordBatch" in str(exc):
# Empty dataframe
data: dict[str, list[Any]] = {k: [] for k in self.columns}
pa_schema = self._to_arrow_schema()
return pa.Table.from_pydict(data, schema=pa_schema)
raise # pragma: no cover
elif self._implementation.is_pyspark_connect() and self._backend_version < (4,):
import pyarrow as pa # ignore-banned-import
pa_schema = self._to_arrow_schema()
return pa.Table.from_pandas(self.native.toPandas(), schema=pa_schema)
else:
return self.native.toArrow()
def _iter_columns(self) -> Iterator[Column]:
for col in self.columns:
yield self._F.col(col)
@property
def columns(self) -> list[str]:
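# Reuse the cached schema for the column names when it is already available,
# otherwise ask the native frame (and cache the result).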
if self._cached_columns is None:
self._cached_columns = (
list(self.schema)
if self._cached_schema is not None
else self.native.columns
)
return self._cached_columns
def _collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny:
if backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
self.native.toPandas(),
implementation=Implementation.PANDAS,
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
if backend is None or backend is Implementation.PYARROW:
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
self._collect_to_arrow(),
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
if backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame(
pl.from_arrow(self._collect_to_arrow()), # type: ignore[arg-type]
validate_backend_version=True,
version=self._version,
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise ValueError(msg) # pragma: no cover
def collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny:
if self._implementation.is_pyspark_connect():
try:
return self._collect(backend, **kwargs)
except Exception as e: # noqa: BLE001
raise catch_pyspark_connect_exception(e) from None
return self._collect(backend, **kwargs)
def simple_select(self, *column_names: str) -> Self:
return self._with_native(self.native.select(*column_names))
def aggregate(self, *exprs: SparkLikeExpr) -> Self:
new_columns = evaluate_exprs(self, *exprs)
new_columns_list = [col.alias(col_name) for col_name, col in new_columns]
if self._implementation.is_pyspark():
try:
return self._with_native(self.native.agg(*new_columns_list))
except Exception as e: # noqa: BLE001
raise catch_pyspark_sql_exception(e, self) from None
return self._with_native(self.native.agg(*new_columns_list))
def select(self, *exprs: SparkLikeExpr) -> Self:
new_columns = evaluate_exprs(self, *exprs)
new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
if self._implementation.is_pyspark(): # pragma: no cover
try:
return self._with_native(self.native.select(*new_columns_list))
except Exception as e: # noqa: BLE001
raise catch_pyspark_sql_exception(e, self) from None
return self._with_native(self.native.select(*new_columns_list))
def with_columns(self, *exprs: SparkLikeExpr) -> Self:
new_columns = evaluate_exprs(self, *exprs)
if self._implementation.is_pyspark(): # pragma: no cover
try:
return self._with_native(self.native.withColumns(dict(new_columns)))
except Exception as e: # noqa: BLE001
raise catch_pyspark_sql_exception(e, self) from None
return self._with_native(self.native.withColumns(dict(new_columns)))
def filter(self, predicate: SparkLikeExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
condition = predicate._call(self)[0]
if self._implementation.is_pyspark():
try:
return self._with_native(self.native.where(condition))
except Exception as e: # noqa: BLE001
raise catch_pyspark_sql_exception(e, self) from None
return self._with_native(self.native.where(condition))
@property
def schema(self) -> dict[str, DType]:
if self._cached_schema is None:
self._cached_schema = {
field.name: native_to_narwhals_dtype(
field.dataType,
self._version,
self._native_dtypes,
self.native.sparkSession,
)
for field in self.native.schema
}
return self._cached_schema
def collect_schema(self) -> dict[str, DType]:
return self.schema
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
return self._with_native(self.native.drop(*columns_to_drop))
def head(self, n: int) -> Self:
return self._with_native(self.native.limit(n))
def group_by(
self, keys: Sequence[str] | Sequence[SparkLikeExpr], *, drop_null_keys: bool
) -> SparkLikeLazyGroupBy:
from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
return SparkLikeLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
if isinstance(descending, bool):
descending = [descending] * len(by)
if nulls_last:
sort_funcs = (
self._F.desc_nulls_last if d else self._F.asc_nulls_last
for d in descending
)
else:
sort_funcs = (
self._F.desc_nulls_first if d else self._F.asc_nulls_first
for d in descending
)
sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)]
return self._with_native(self.native.sort(*sort_cols))
def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
by = list(by)
if isinstance(reverse, bool):
reverse = [reverse] * len(by)
sort_funcs = (
self._F.desc_nulls_last if not d else self._F.asc_nulls_last for d in reverse
)
sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)]
return self._with_native(self.native.sort(*sort_cols).limit(k))
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
subset = list(subset) if subset else None
return self._with_native(self.native.dropna(subset=subset))
def rename(self, mapping: Mapping[str, str]) -> Self:
rename_mapping = {
colname: mapping.get(colname, colname) for colname in self.columns
}
return self._with_native(
self.native.select(
[self._F.col(old).alias(new) for old, new in rename_mapping.items()]
)
)
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self:
if subset and (error := self._check_columns_exist(subset)):
raise error
subset = list(subset) if subset else None
if keep == "none":
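# Keep only rows whose key combination occurs exactly once: count occurrences
# over a window partitioned by `subset` and filter on count == 1.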
tmp = generate_temporary_column_name(8, self.columns)
window = self._Window.partitionBy(subset or self.columns)
df = (
self.native.withColumn(tmp, self._F.count("*").over(window))
.filter(self._F.col(tmp) == self._F.lit(1))
.drop(self._F.col(tmp))
)
return self._with_native(df)
return self._with_native(self.native.dropDuplicates(subset=subset))
def join(
self,
other: Self,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
left_columns = self.columns
right_columns = other.columns
right_on_: list[str] = list(right_on) if right_on is not None else []
left_on_: list[str] = list(left_on) if left_on is not None else []
# create a mapping for columns on other
# `right_on` columns will be renamed as `left_on`
# the remaining columns will either get the suffix appended or be left unchanged.
right_cols_to_rename = (
[c for c in right_columns if c not in right_on_]
if how != "full"
else right_columns
)
rename_mapping = {
**dict(zip(right_on_, left_on_)),
**{
colname: f"{colname}{suffix}" if colname in left_columns else colname
for colname in right_cols_to_rename
},
}
other_native = other.native.select(
[self._F.col(old).alias(new) for old, new in rename_mapping.items()]
)
# If `how` is in {"semi", "anti"}, the resulting columns are the same as the left columns.
# Otherwise, we add the right columns with the new mapping, while keeping the
# original order of right_columns.
col_order = left_columns.copy()
if how in {"inner", "left", "cross"}:
col_order.extend(
rename_mapping[colname]
for colname in right_columns
if colname not in right_on_
)
elif how == "full":
col_order.extend(rename_mapping.values())
right_on_remapped = [rename_mapping[c] for c in right_on_]
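# A full join needs an explicit equality predicate on the (renamed) key columns,
# a cross join takes no predicate, and the remaining strategies can join directly
# on the shared `left_on_` names.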
on_ = (
reduce(
and_,
(
getattr(self.native, left_key) == getattr(other_native, right_key)
for left_key, right_key in zip_strict(left_on_, right_on_remapped)
),
)
if how == "full"
else None
if how == "cross"
else left_on_
)
how_native = "full_outer" if how == "full" else how
return self._with_native(
self.native.join(other_native, on=on_, how=how_native).select(col_order)
)
def explode(self, columns: Sequence[str]) -> Self:
dtypes = self._version.dtypes
schema = self.collect_schema()
for col_to_explode in columns:
dtype = schema[col_to_explode]
if dtype != dtypes.List:
msg = (
f"`explode` operation not supported for dtype `{dtype}`, "
"expected List type"
)
raise InvalidOperationError(msg)
column_names = self.columns
if len(columns) != 1:
msg = (
"Exploding on multiple columns is not supported with SparkLike backend since "
"we cannot guarantee that the exploded columns have matching element counts."
)
raise NotImplementedError(msg)
if self._implementation.is_pyspark() or self._implementation.is_pyspark_connect():
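# `explode_outer` keeps rows whose array is null or empty by emitting a null
# element, matching Polars' behaviour.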
return self._with_native(
self.native.select(
*[
self._F.col(col_name).alias(col_name)
if col_name != columns[0]
else self._F.explode_outer(col_name).alias(col_name)
for col_name in column_names
]
)
)
if self._implementation.is_sqlframe():
# Not every sqlframe dialect supports `explode_outer` function
# (see https://github.com/eakmanrq/sqlframe/blob/3cb899c515b101ff4c197d84b34fae490d0ed257/sqlframe/base/functions.py#L2288-L2289)
# therefore we simply explode the array column, which skips nulls and
# zero-sized arrays, and then append rows for those cases with null values
# (to match Polars behavior).
def null_condition(col_name: str) -> Column:
return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0)
return self._with_native(
self.native.select(
*[
self._F.col(col_name).alias(col_name)
if col_name != columns[0]
else self._F.explode(col_name).alias(col_name)
for col_name in column_names
]
).union(
self.native.filter(null_condition(columns[0])).select(
*[
self._F.col(col_name).alias(col_name)
if col_name != columns[0]
else self._F.lit(None).alias(col_name)
for col_name in column_names
]
)
)
)
msg = "Unreachable code, please report an issue at https://github.com/narwhals-dev/narwhals/issues" # pragma: no cover
raise AssertionError(msg)
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
if self._implementation.is_sqlframe():
if variable_name == "":
msg = "`variable_name` cannot be empty string for sqlframe backend."
raise NotImplementedError(msg)
if value_name == "":
msg = "`value_name` cannot be empty string for sqlframe backend."
raise NotImplementedError(msg)
else: # pragma: no cover
pass
ids = tuple(index) if index else ()
values = (
tuple(set(self.columns).difference(set(ids))) if on is None else tuple(on)
)
unpivoted_native_frame = self.native.unpivot(
ids=ids,
values=values,
variableColumnName=variable_name,
valueColumnName=value_name,
)
if index is None:
unpivoted_native_frame = unpivoted_native_frame.drop(*ids)
return self._with_native(unpivoted_native_frame)
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
if order_by is None:
msg = "`order_by` is required in `with_row_index` for PySpark-like backends"
raise TypeError(msg)
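# Number rows over the whole frame (a single constant partition) ordered by
# `order_by`; `row_number` is 1-based, so subtract 1 for a 0-based index.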
row_index_expr = (
self._F.row_number().over(
self._Window.partitionBy(self._F.lit(1)).orderBy(*order_by)
)
- 1
).alias(name)
return self._with_native(self.native.select(row_index_expr, *self.columns))
def sink_parquet(self, file: str | Path | BytesIO) -> None:
self.native.write.parquet(file)
@classmethod
def _from_compliant_dataframe(
cls,
frame: CompliantDataFrameAny,
/,
*,
session: SparkSession,
implementation: Implementation,
version: Version,
) -> SparkLikeLazyFrame:
from importlib.util import find_spec
impl = implementation
is_spark_v4 = (not impl.is_sqlframe()) and impl._backend_version() >= (4, 0, 0)
if is_spark_v4: # pragma: no cover
# pyspark.sql requires pyarrow to be installed from v4.0.0
# and since v4.0.0 the input to `createDataFrame` can be a PyArrow Table.
data: Any = frame.to_arrow()
elif find_spec("pandas"):
data = frame.to_pandas()
else: # pragma: no cover
data = tuple(frame.iter_rows(named=True, buffer_size=512))
return cls(
session.createDataFrame(data),
version=version,
implementation=implementation,
validate_backend_version=True,
)
gather_every = not_implemented.deprecated(
"`LazyFrame.gather_every` is deprecated and will be removed in a future version."
)
join_asof = not_implemented()
tail = not_implemented.deprecated(
"`LazyFrame.tail` is deprecated and will be removed in a future version."
)

@@ -0,0 +1,391 @@
from __future__ import annotations
import operator
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast
from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
from narwhals._spark_like.expr_struct import SparkLikeExprStructNamespace
from narwhals._spark_like.utils import (
import_functions,
import_native_dtypes,
import_window,
narwhals_to_native_dtype,
true_divide,
)
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version, not_implemented, zip_strict
if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
from sqlframe.base.column import Column
from sqlframe.base.window import Window, WindowSpec
from typing_extensions import Self, TypeAlias
from narwhals._compliant import WindowInputs
from narwhals._compliant.typing import (
AliasNames,
EvalNames,
EvalSeries,
WindowFunction,
)
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals._utils import _LimitedContext
from narwhals.typing import FillNullStrategy, IntoDType, NonNestedLiteral, RankMethod
NativeRankMethod: TypeAlias = Literal["rank", "dense_rank", "row_number"]
SparkWindowFunction = WindowFunction[SparkLikeLazyFrame, Column]
SparkWindowInputs = WindowInputs[Column]
class SparkLikeExpr(SQLExpr["SparkLikeLazyFrame", "Column"]):
def __init__(
self,
call: EvalSeries[SparkLikeLazyFrame, Column],
window_function: SparkWindowFunction | None = None,
*,
evaluate_output_names: EvalNames[SparkLikeLazyFrame],
alias_output_names: AliasNames | None,
version: Version,
implementation: Implementation,
) -> None:
self._call = call
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._version = version
self._implementation = implementation
self._metadata: ExprMetadata | None = None
self._window_function: SparkWindowFunction | None = window_function
_REMAP_RANK_METHOD: ClassVar[Mapping[RankMethod, NativeRankMethod]] = {
"min": "rank",
"max": "rank",
"average": "rank",
"dense": "dense_rank",
"ordinal": "row_number",
}
def _count_star(self) -> Column:
return self._F.count("*")
def _window_expression(
self,
expr: Column,
partition_by: Sequence[str | Column] = (),
order_by: Sequence[str | Column] = (),
rows_start: int | None = None,
rows_end: int | None = None,
*,
descending: Sequence[bool] | None = None,
nulls_last: Sequence[bool] | None = None,
) -> Column:
window = self.partition_by(*partition_by)
if order_by:
window = window.orderBy(
*self._sort(*order_by, descending=descending, nulls_last=nulls_last)
)
if rows_start is not None and rows_end is not None:
window = window.rowsBetween(rows_start, rows_end)
elif rows_end is not None:
window = window.rowsBetween(self._Window.unboundedPreceding, rows_end)
elif rows_start is not None: # pragma: no cover
window = window.rowsBetween(rows_start, self._Window.unboundedFollowing)
return expr.over(window)
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
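# Literals broadcast on their own; aggregations are rewritten as a window
# over a constant partition so the scalar result is repeated for every row.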
if kind is ExprKind.LITERAL:
return self
return self.over([self._F.lit(1)], [])
@property
def _F(self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from sqlframe.base import functions
return functions
return import_functions(self._implementation)
@property
def _native_dtypes(self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from sqlframe.base import types
return types
return import_native_dtypes(self._implementation)
@property
def _Window(self) -> type[Window]:
if TYPE_CHECKING:
from sqlframe.base.window import Window
return Window
return import_window(self._implementation)
def _sort(
self,
*cols: Column | str,
descending: Sequence[bool] | None = None,
nulls_last: Sequence[bool] | None = None,
) -> Iterator[Column]:
F = self._F
descending = descending or [False] * len(cols)
nulls_last = nulls_last or [False] * len(cols)
mapping = {
(False, False): F.asc_nulls_first,
(False, True): F.asc_nulls_last,
(True, False): F.desc_nulls_first,
(True, True): F.desc_nulls_last,
}
yield from (
mapping[(_desc, _nulls_last)](col)
for col, _desc, _nulls_last in zip_strict(cols, descending, nulls_last)
)
def partition_by(self, *cols: Column | str) -> WindowSpec:
"""Wraps `Window().partitionBy`, with default and `WindowInputs` handling."""
return self._Window.partitionBy(*cols or [self._F.lit(1)])
def __narwhals_namespace__(self) -> SparkLikeNamespace: # pragma: no cover
from narwhals._spark_like.namespace import SparkLikeNamespace
return SparkLikeNamespace(
version=self._version, implementation=self._implementation
)
@classmethod
def _alias_native(cls, expr: Column, name: str) -> Column:
return expr.alias(name)
@classmethod
def from_column_names(
cls: type[Self],
evaluate_column_names: EvalNames[SparkLikeLazyFrame],
/,
*,
context: _LimitedContext,
) -> Self:
def func(df: SparkLikeLazyFrame) -> list[Column]:
return [df._F.col(col_name) for col_name in evaluate_column_names(df)]
return cls(
func,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
version=context._version,
implementation=context._implementation,
)
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
def func(df: SparkLikeLazyFrame) -> list[Column]:
columns = df.columns
return [df._F.col(columns[i]) for i in column_indices]
return cls(
func,
evaluate_output_names=cls._eval_names_indices(column_indices),
alias_output_names=None,
version=context._version,
implementation=context._implementation,
)
def __truediv__(self, other: SparkLikeExpr) -> Self:
def _truediv(expr: Column, other: Column) -> Column:
return true_divide(self._F, expr, other)
return self._with_binary(_truediv, other)
def __rtruediv__(self, other: SparkLikeExpr) -> Self:
def _rtruediv(expr: Column, other: Column) -> Column:
return true_divide(self._F, other, expr)
return self._with_binary(_rtruediv, other).alias("literal")
def __floordiv__(self, other: SparkLikeExpr) -> Self:
def _floordiv(expr: Column, other: Column) -> Column:
return self._F.floor(true_divide(self._F, expr, other))
return self._with_binary(_floordiv, other)
def __rfloordiv__(self, other: SparkLikeExpr) -> Self:
def _rfloordiv(expr: Column, other: Column) -> Column:
return self._F.floor(true_divide(self._F, other, expr))
return self._with_binary(_rfloordiv, other).alias("literal")
def __invert__(self) -> Self:
invert = cast("Callable[..., Column]", operator.invert)
return self._with_elementwise(invert)
def cast(self, dtype: IntoDType) -> Self:
def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
spark_dtype = narwhals_to_native_dtype(
dtype, self._version, self._native_dtypes, df.native.sparkSession
)
return [expr.cast(spark_dtype) for expr in self(df)]
def window_f(
df: SparkLikeLazyFrame, inputs: SparkWindowInputs
) -> Sequence[Column]:
spark_dtype = narwhals_to_native_dtype(
dtype, self._version, self._native_dtypes, df.native.sparkSession
)
return [expr.cast(spark_dtype) for expr in self.window_function(df, inputs)]
return self.__class__(
func,
window_f,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
implementation=self._implementation,
)
def median(self) -> Self:
def _median(expr: Column) -> Column:
if self._implementation in {
Implementation.PYSPARK,
Implementation.PYSPARK_CONNECT,
} and Implementation.PYSPARK._backend_version() < (3, 4): # pragma: no cover
# Use percentile_approx with default accuracy parameter (10000)
return self._F.percentile_approx(expr.cast("double"), 0.5)
return self._F.median(expr)
return self._with_callable(_median)
def null_count(self) -> Self:
def _null_count(expr: Column) -> Column:
return self._F.count_if(self._F.isnull(expr))
return self._with_callable(_null_count)
def std(self, ddof: int) -> Self:
F = self._F
if ddof == 0:
return self._with_callable(F.stddev_pop)
if ddof == 1:
return self._with_callable(F.stddev_samp)
def func(expr: Column) -> Column:
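# `stddev_samp` uses ddof=1, so rescale by sqrt((n - 1) / (n - ddof)).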
n_rows = F.count(expr)
return F.stddev_samp(expr) * F.sqrt((n_rows - 1) / (n_rows - ddof))
return self._with_callable(func)
def var(self, ddof: int) -> Self:
F = self._F
if ddof == 0:
return self._with_callable(F.var_pop)
if ddof == 1:
return self._with_callable(F.var_samp)
def func(expr: Column) -> Column:
n_rows = F.count(expr)
return F.var_samp(expr) * (n_rows - 1) / (n_rows - ddof)
return self._with_callable(func)
def is_finite(self) -> Self:
def _is_finite(expr: Column) -> Column:
# A value is finite if it's not NaN, and not infinite, while NULLs should be
# preserved
is_finite_condition = (
~self._F.isnan(expr)
& (expr != self._F.lit(float("inf")))
& (expr != self._F.lit(float("-inf")))
)
return self._F.when(~self._F.isnull(expr), is_finite_condition).otherwise(
None
)
return self._with_elementwise(_is_finite)
def is_in(self, values: Sequence[Any]) -> Self:
def _is_in(expr: Column) -> Column:
return expr.isin(values) if values else self._F.lit(False)
return self._with_elementwise(_is_in)
def len(self) -> Self:
def _len(_expr: Column) -> Column:
# Use count(*) to count all rows including nulls
return self._F.count("*")
return self._with_callable(_len)
def skew(self) -> Self:
return self._with_callable(self._F.skewness)
def kurtosis(self) -> Self:
return self._with_callable(self._F.kurtosis)
def n_unique(self) -> Self:
def _n_unique(expr: Column) -> Column:
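# `count_distinct` ignores nulls, so add 1 (the max of an is-null flag)
# whenever the column contains any null.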
return self._F.count_distinct(expr) + self._F.max(
self._F.isnull(expr).cast(self._native_dtypes.IntegerType())
)
return self._with_callable(_n_unique)
def is_nan(self) -> Self:
def _is_nan(expr: Column) -> Column:
return self._F.when(self._F.isnull(expr), None).otherwise(self._F.isnan(expr))
return self._with_elementwise(_is_nan)
def fill_null(
self,
value: Self | NonNestedLiteral,
strategy: FillNullStrategy | None,
limit: int | None,
) -> Self:
if strategy is not None:
def _fill_with_strategy(
df: SparkLikeLazyFrame, inputs: SparkWindowInputs
) -> Sequence[Column]:
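# Forward fill takes the last non-null value in a trailing window, backward
# fill the first non-null value in a leading window; `limit` bounds the frame.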
fn = self._F.last_value if strategy == "forward" else self._F.first_value
if strategy == "forward":
start = self._Window.unboundedPreceding if limit is None else -limit
end = self._Window.currentRow
else:
start = self._Window.currentRow
end = self._Window.unboundedFollowing if limit is None else limit
return [
fn(expr, ignoreNulls=True).over(
self.partition_by(*inputs.partition_by)
.orderBy(*self._sort(*inputs.order_by))
.rowsBetween(start, end)
)
for expr in self(df)
]
return self._with_window_function(_fill_with_strategy)
def _fill_constant(expr: Column, value: Column) -> Column:
return self._F.ifnull(expr, value)
return self._with_elementwise(_fill_constant, value=value)
@property
def str(self) -> SparkLikeExprStringNamespace:
return SparkLikeExprStringNamespace(self)
@property
def dt(self) -> SparkLikeExprDateTimeNamespace:
return SparkLikeExprDateTimeNamespace(self)
@property
def list(self) -> SparkLikeExprListNamespace:
return SparkLikeExprListNamespace(self)
@property
def struct(self) -> SparkLikeExprStructNamespace:
return SparkLikeExprStructNamespace(self)
quantile = not_implemented()

@@ -0,0 +1,192 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._constants import US_PER_SECOND
from narwhals._duration import Interval
from narwhals._spark_like.utils import (
UNITS_DICT,
fetch_session_time_zone,
strptime_to_pyspark_format,
)
from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace
from narwhals._utils import not_implemented
if TYPE_CHECKING:
from collections.abc import Sequence
from sqlframe.base.column import Column
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
class SparkLikeExprDateTimeNamespace(SQLExprDateTimeNamesSpace["SparkLikeExpr"]):
def _weekday(self, expr: Column) -> Column:
# PySpark's dayofweek returns 1-7 for Sunday-Saturday, so shift it to the
# ISO convention where Monday is 1 and Sunday is 7.
return (self.compliant._F.dayofweek(expr) + 5) % 7 + 1
def to_string(self, format: str) -> SparkLikeExpr:
F = self.compliant._F
def _to_string(expr: Column) -> Column:
# Handle special formats
if format == "%G-W%V":
return self._format_iso_week(expr)
if format == "%G-W%V-%u":
return self._format_iso_week_with_day(expr)
format_, suffix = self._format_microseconds(expr, format)
# Convert Python format to PySpark format
pyspark_fmt = strptime_to_pyspark_format(format_)
result = F.date_format(expr, pyspark_fmt)
if "T" in format_:
# `strptime_to_pyspark_format` replaces "T" with " " since pyspark
# does not support the literal "T" in `date_format`.
# If no other spaces are in the given format, then we can revert this
# operation, otherwise we raise an exception.
if " " not in format_:
result = F.replace(result, F.lit(" "), F.lit("T"))
else: # pragma: no cover
msg = (
"`dt.to_string` with a format that contains both spaces and "
"the literal 'T' is not supported for spark-like backends."
)
raise NotImplementedError(msg)
return F.concat(result, *suffix)
return self.compliant._with_elementwise(_to_string)
def millisecond(self) -> SparkLikeExpr:
def _millisecond(expr: Column) -> Column:
return self.compliant._F.floor(
(self.compliant._F.unix_micros(expr) % US_PER_SECOND) / 1000
)
return self.compliant._with_elementwise(_millisecond)
def microsecond(self) -> SparkLikeExpr:
def _microsecond(expr: Column) -> Column:
return self.compliant._F.unix_micros(expr) % US_PER_SECOND
return self.compliant._with_elementwise(_microsecond)
def nanosecond(self) -> SparkLikeExpr:
def _nanosecond(expr: Column) -> Column:
return (self.compliant._F.unix_micros(expr) % US_PER_SECOND) * 1000
return self.compliant._with_elementwise(_nanosecond)
def weekday(self) -> SparkLikeExpr:
return self.compliant._with_elementwise(self._weekday)
def truncate(self, every: str) -> SparkLikeExpr:
interval = Interval.parse(every)
multiple, unit = interval.multiple, interval.unit
if multiple != 1:
msg = f"Only multiple 1 is currently supported for Spark-like.\nGot {multiple!s}."
raise ValueError(msg)
if unit == "ns":
msg = "Truncating to nanoseconds is not yet supported for Spark-like."
raise NotImplementedError(msg)
format = UNITS_DICT[unit]
def _truncate(expr: Column) -> Column:
return self.compliant._F.date_trunc(format, expr)
return self.compliant._with_elementwise(_truncate)
def offset_by(self, by: str) -> SparkLikeExpr:
interval = Interval.parse_no_constraints(by)
multiple, unit = interval.multiple, interval.unit
if unit == "ns": # pragma: no cover
msg = "Offsetting by nanoseconds is not yet supported for Spark-like."
raise NotImplementedError(msg)
F = self.compliant._F
def _offset_by(expr: Column) -> Column:
# https://github.com/eakmanrq/sqlframe/issues/441
return F.timestamp_add( # pyright: ignore[reportAttributeAccessIssue]
UNITS_DICT[unit], F.lit(multiple), expr
)
return self.compliant._with_callable(_offset_by)
def _no_op_time_zone(self, time_zone: str) -> SparkLikeExpr: # pragma: no cover
def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
native_series_list = self.compliant(df)
conn_time_zone = fetch_session_time_zone(df.native.sparkSession)
if conn_time_zone != time_zone:
msg = (
"PySpark stores the time zone in the session, rather than in the "
f"data type, so changing the timezone to anything other than {conn_time_zone} "
"(the current session time zone) is not supported."
)
raise NotImplementedError(msg)
return native_series_list
return self.compliant.__class__(
func,
evaluate_output_names=self.compliant._evaluate_output_names,
alias_output_names=self.compliant._alias_output_names,
version=self.compliant._version,
implementation=self.compliant._implementation,
)
def convert_time_zone(self, time_zone: str) -> SparkLikeExpr: # pragma: no cover
return self._no_op_time_zone(time_zone)
def replace_time_zone(
self, time_zone: str | None
) -> SparkLikeExpr: # pragma: no cover
if time_zone is None:
return self.compliant._with_elementwise(
lambda expr: expr.cast("timestamp_ntz")
)
return self._no_op_time_zone(time_zone)
def _format_iso_week_with_day(self, expr: Column) -> Column:
"""Format datetime as ISO week string with day."""
F = self.compliant._F
year = F.date_format(expr, "yyyy")
week = F.lpad(F.weekofyear(expr).cast("string"), 2, "0")
day = self._weekday(expr)
return F.concat(year, F.lit("-W"), week, F.lit("-"), day.cast("string"))
def _format_iso_week(self, expr: Column) -> Column:
"""Format datetime as ISO week string."""
F = self.compliant._F
year = F.date_format(expr, "yyyy")
week = F.lpad(F.weekofyear(expr).cast("string"), 2, "0")
return F.concat(year, F.lit("-W"), week)
def _format_microseconds(
self, expr: Column, format: str
) -> tuple[str, tuple[Column, ...]]:
"""Format microseconds if present in format, else it's a no-op."""
F = self.compliant._F
suffix: tuple[Column, ...]
if format.endswith((".%f", "%.f")):
import re
micros = F.unix_micros(expr) % US_PER_SECOND
micros_str = F.lpad(micros.cast("string"), 6, "0")
suffix = (F.lit("."), micros_str)
format_ = re.sub(r"(.%|%.)f$", "", format)
return format_, suffix
return format, ()
timestamp = not_implemented()
total_seconds = not_implemented()
total_minutes = not_implemented()
total_milliseconds = not_implemented()
total_microseconds = not_implemented()
total_nanoseconds = not_implemented()

@@ -0,0 +1,35 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import ListNamespace
if TYPE_CHECKING:
from sqlframe.base.column import Column
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals.typing import NonNestedLiteral
class SparkLikeExprListNamespace(
LazyExprNamespace["SparkLikeExpr"], ListNamespace["SparkLikeExpr"]
):
def len(self) -> SparkLikeExpr:
return self.compliant._with_elementwise(self.compliant._F.array_size)
def unique(self) -> SparkLikeExpr:
return self.compliant._with_elementwise(self.compliant._F.array_distinct)
def contains(self, item: NonNestedLiteral) -> SparkLikeExpr:
def func(expr: Column) -> Column:
F = self.compliant._F
return F.array_contains(expr, F.lit(item))
return self.compliant._with_elementwise(func)
def get(self, index: int) -> SparkLikeExpr:
def _get(expr: Column) -> Column:
return expr.getItem(index)
return self.compliant._with_elementwise(_get)

@@ -0,0 +1,36 @@
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING
from narwhals._spark_like.utils import strptime_to_pyspark_format
from narwhals._sql.expr_str import SQLExprStringNamespace
from narwhals._utils import _is_naive_format, not_implemented
if TYPE_CHECKING:
from narwhals._spark_like.expr import SparkLikeExpr
class SparkLikeExprStringNamespace(SQLExprStringNamespace["SparkLikeExpr"]):
def to_datetime(self, format: str | None) -> SparkLikeExpr:
F = self.compliant._F
if not format:
function = F.to_timestamp
elif _is_naive_format(format):
function = partial(
F.to_timestamp_ntz, format=F.lit(strptime_to_pyspark_format(format))
)
else:
format = strptime_to_pyspark_format(format)
function = partial(F.to_timestamp, format=format)
return self.compliant._with_elementwise(
lambda expr: function(F.replace(expr, F.lit("T"), F.lit(" ")))
)
def to_date(self, format: str | None) -> SparkLikeExpr:
F = self.compliant._F
return self.compliant._with_elementwise(
lambda expr: F.to_date(expr, format=strptime_to_pyspark_format(format))
)
replace = not_implemented()

@@ -0,0 +1,21 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StructNamespace
if TYPE_CHECKING:
from sqlframe.base.column import Column
from narwhals._spark_like.expr import SparkLikeExpr
class SparkLikeExprStructNamespace(
LazyExprNamespace["SparkLikeExpr"], StructNamespace["SparkLikeExpr"]
):
def field(self, name: str) -> SparkLikeExpr:
def func(expr: Column) -> Column:
return expr.getField(name)
return self.compliant._with_elementwise(func).alias(name)

@@ -0,0 +1,37 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._sql.group_by import SQLGroupBy
if TYPE_CHECKING:
from collections.abc import Sequence
from sqlframe.base.column import Column # noqa: F401
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
class SparkLikeLazyGroupBy(SQLGroupBy["SparkLikeLazyFrame", "SparkLikeExpr", "Column"]):
def __init__(
self,
df: SparkLikeLazyFrame,
keys: Sequence[SparkLikeExpr] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
def agg(self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
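# With no aggregation expressions, the result is just the distinct key rows.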
result = (
self.compliant.native.groupBy(*self._keys).agg(*agg_columns)
if (agg_columns := list(self._evaluate_exprs(exprs)))
else self.compliant.native.select(*self._keys).dropDuplicates()
)
return self.compliant._with_native(result).rename(
dict(zip(self._keys, self._output_key_names))
)

@@ -0,0 +1,230 @@
from __future__ import annotations
import operator
from functools import reduce
from typing import TYPE_CHECKING, Any
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals._spark_like.selectors import SparkLikeSelectorNamespace
from narwhals._spark_like.utils import (
import_functions,
import_native_dtypes,
narwhals_to_native_dtype,
true_divide,
)
from narwhals._sql.namespace import SQLNamespace
from narwhals._sql.when_then import SQLThen, SQLWhen
from narwhals._utils import zip_strict
if TYPE_CHECKING:
from collections.abc import Iterable
from sqlframe.base.column import Column
from narwhals._spark_like.dataframe import SQLFrameDataFrame # noqa: F401
from narwhals._utils import Implementation, Version
from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral, PythonLiteral
# Adjust for slight naming differences between SQL and PySpark functions
FUNCTION_REMAPPINGS = {
"starts_with": "startswith",
"ends_with": "endswith",
"trim": "btrim",
"str_split": "split",
"regexp_matches": "regexp",
}
class SparkLikeNamespace(
SQLNamespace[SparkLikeLazyFrame, SparkLikeExpr, "SQLFrameDataFrame", "Column"]
):
def __init__(self, *, version: Version, implementation: Implementation) -> None:
self._version = version
self._implementation = implementation
@property
def selectors(self) -> SparkLikeSelectorNamespace:
return SparkLikeSelectorNamespace.from_namespace(self)
@property
def _expr(self) -> type[SparkLikeExpr]:
return SparkLikeExpr
@property
def _lazyframe(self) -> type[SparkLikeLazyFrame]:
return SparkLikeLazyFrame
@property
def _F(self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from sqlframe.base import functions
return functions
return import_functions(self._implementation)
@property
def _native_dtypes(self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from sqlframe.base import types
return types
return import_native_dtypes(self._implementation)
def _function(self, name: str, *args: Column | PythonLiteral) -> Column:
return getattr(self._F, FUNCTION_REMAPPINGS.get(name, name))(*args)
def _lit(self, value: Any) -> Column:
return self._F.lit(value)
def _when(
self, condition: Column, value: Column, otherwise: Column | None = None
) -> Column:
if otherwise is None:
return self._F.when(condition, value)
return self._F.when(condition, value).otherwise(otherwise)
def _coalesce(self, *exprs: Column) -> Column:
return self._F.coalesce(*exprs)
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> SparkLikeExpr:
def _lit(df: SparkLikeLazyFrame) -> list[Column]:
column = df._F.lit(value)
if dtype:
native_dtype = narwhals_to_native_dtype(
dtype, self._version, df._native_dtypes, df.native.sparkSession
)
column = column.cast(native_dtype)
return [column]
return self._expr(
call=_lit,
evaluate_output_names=lambda _df: ["literal"],
alias_output_names=None,
version=self._version,
implementation=self._implementation,
)
def len(self) -> SparkLikeExpr:
def func(df: SparkLikeLazyFrame) -> list[Column]:
return [df._F.count("*")]
return self._expr(
func,
evaluate_output_names=lambda _df: ["len"],
alias_output_names=None,
version=self._version,
implementation=self._implementation,
)
def mean_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
def func(cols: Iterable[Column]) -> Column:
cols = list(cols)
F = exprs[0]._F
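# Row-wise mean: sum the values with nulls treated as 0, then divide by the
# count of non-null entries in each row.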
numerator = reduce(
operator.add, (self._F.coalesce(col, self._F.lit(0)) for col in cols)
)
denominator = reduce(
operator.add,
(col.isNotNull().cast(self._native_dtypes.IntegerType()) for col in cols),
)
return true_divide(F, numerator, denominator)
return self._expr._from_elementwise_horizontal_op(func, *exprs)
def concat(
self, items: Iterable[SparkLikeLazyFrame], *, how: ConcatMethod
) -> SparkLikeLazyFrame:
dfs = [item._native_frame for item in items]
if how == "vertical":
cols_0 = dfs[0].columns
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.columns
if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)):
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0}\n"
f" - dataframe {i}: {cols_current}\n"
)
raise TypeError(msg)
return SparkLikeLazyFrame(
native_dataframe=reduce(lambda x, y: x.union(y), dfs),
version=self._version,
implementation=self._implementation,
)
if how == "diagonal":
return SparkLikeLazyFrame(
native_dataframe=reduce(
lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs
),
version=self._version,
implementation=self._implementation,
)
raise NotImplementedError
def concat_str(
self, *exprs: SparkLikeExpr, separator: str, ignore_nulls: bool
) -> SparkLikeExpr:
def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [s for _expr in exprs for s in _expr(df)]
cols_casted = [s.cast(df._native_dtypes.StringType()) for s in cols]
null_mask = [df._F.isnull(s) for s in cols]
if not ignore_nulls:
null_mask_result = reduce(operator.or_, null_mask)
result = df._F.when(
~null_mask_result,
reduce(
lambda x, y: df._F.format_string(f"%s{separator}%s", x, y),
cols_casted,
),
).otherwise(df._F.lit(None))
else:
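# With `ignore_nulls`, null values become empty strings and the separator
# that would follow a null is dropped as well.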
init_value, *values = [
df._F.when(~nm, col).otherwise(df._F.lit(""))
for col, nm in zip_strict(cols_casted, null_mask)
]
separators = (
df._F.when(nm, df._F.lit("")).otherwise(df._F.lit(separator))
for nm in null_mask[:-1]
)
result = reduce(
lambda x, y: df._F.format_string("%s%s", x, y),
(
df._F.format_string("%s%s", s, v)
for s, v in zip_strict(separators, values)
),
init_value,
)
return [result]
return self._expr(
call=func,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
implementation=self._implementation,
)
def when(self, predicate: SparkLikeExpr) -> SparkLikeWhen:
return SparkLikeWhen.from_expr(predicate, context=self)
class SparkLikeWhen(SQLWhen[SparkLikeLazyFrame, "Column", SparkLikeExpr]):
@property
def _then(self) -> type[SparkLikeThen]:
return SparkLikeThen
class SparkLikeThen(
SQLThen[SparkLikeLazyFrame, "Column", SparkLikeExpr], SparkLikeExpr
): ...

@@ -0,0 +1,31 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import CompliantSelector, LazySelectorNamespace
from narwhals._spark_like.expr import SparkLikeExpr
if TYPE_CHECKING:
from sqlframe.base.column import Column # noqa: F401
from narwhals._spark_like.dataframe import SparkLikeLazyFrame # noqa: F401
from narwhals._spark_like.expr import SparkWindowFunction
class SparkLikeSelectorNamespace(LazySelectorNamespace["SparkLikeLazyFrame", "Column"]):
@property
def _selector(self) -> type[SparkLikeSelector]:
return SparkLikeSelector
class SparkLikeSelector(CompliantSelector["SparkLikeLazyFrame", "Column"], SparkLikeExpr): # type: ignore[misc]
_window_function: SparkWindowFunction | None = None
def _to_expr(self) -> SparkLikeExpr:
return SparkLikeExpr(
self._call,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
implementation=self._implementation,
)

@@ -0,0 +1,319 @@
from __future__ import annotations
import operator
from collections.abc import Callable
from functools import lru_cache
from importlib import import_module
from operator import attrgetter
from types import ModuleType
from typing import TYPE_CHECKING, Any, overload
from narwhals._utils import Implementation, Version, isinstance_or_issubclass
from narwhals.exceptions import ColumnNotFoundError, UnsupportedDTypeError
if TYPE_CHECKING:
from collections.abc import Mapping
import sqlframe.base.types as sqlframe_types
from sqlframe.base.column import Column
from sqlframe.base.session import _BaseSession as Session
from typing_extensions import TypeAlias
from narwhals._compliant.typing import CompliantLazyFrameAny
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals.dtypes import DType
from narwhals.typing import IntoDType
_NativeDType: TypeAlias = sqlframe_types.DataType
SparkSession = Session[Any, Any, Any, Any, Any, Any, Any]
UNITS_DICT = {
"y": "year",
"q": "quarter",
"mo": "month",
"d": "day",
"h": "hour",
"m": "minute",
"s": "second",
"ms": "millisecond",
"us": "microsecond",
"ns": "nanosecond",
}
# see https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html
# and https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior
DATETIME_PATTERNS_MAPPING = {
"%Y": "yyyy", # Year with century (4 digits)
"%y": "yy", # Year without century (2 digits)
"%m": "MM", # Month (01-12)
"%d": "dd", # Day of the month (01-31)
"%H": "HH", # Hour (24-hour clock) (00-23)
"%I": "hh", # Hour (12-hour clock) (01-12)
"%M": "mm", # Minute (00-59)
"%S": "ss", # Second (00-59)
"%f": "S", # Microseconds -> Milliseconds
"%p": "a", # AM/PM
"%a": "E", # Abbreviated weekday name
"%A": "E", # Full weekday name
"%j": "D", # Day of the year
"%z": "Z", # Timezone offset
"%s": "X", # Unix timestamp
}
# NOTE: don't lru_cache this as `ModuleType` isn't hashable
def native_to_narwhals_dtype( # noqa: C901, PLR0912
dtype: _NativeDType, version: Version, spark_types: ModuleType, session: SparkSession
) -> DType:
dtypes = version.dtypes
if TYPE_CHECKING:
native = sqlframe_types
else:
native = spark_types
if isinstance(dtype, native.DoubleType):
return dtypes.Float64()
if isinstance(dtype, native.FloatType):
return dtypes.Float32()
if isinstance(dtype, native.LongType):
return dtypes.Int64()
if isinstance(dtype, native.IntegerType):
return dtypes.Int32()
if isinstance(dtype, native.ShortType):
return dtypes.Int16()
if isinstance(dtype, native.ByteType):
return dtypes.Int8()
if isinstance(dtype, (native.StringType, native.VarcharType, native.CharType)):
return dtypes.String()
if isinstance(dtype, native.BooleanType):
return dtypes.Boolean()
if isinstance(dtype, native.DateType):
return dtypes.Date()
if isinstance(dtype, native.TimestampNTZType):
# TODO(marco): cover this
return dtypes.Datetime() # pragma: no cover
if isinstance(dtype, native.TimestampType):
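# PySpark keeps the time zone in the session rather than in the data type,
# so timezone-aware timestamps are tagged with the session time zone.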
return dtypes.Datetime(time_zone=fetch_session_time_zone(session))
if isinstance(dtype, native.DecimalType):
# TODO(marco): cover this
return dtypes.Decimal() # pragma: no cover
if isinstance(dtype, native.ArrayType):
return dtypes.List(
inner=native_to_narwhals_dtype(
dtype.elementType, version, spark_types, session
)
)
if isinstance(dtype, native.StructType):
return dtypes.Struct(
fields=[
dtypes.Field(
name=field.name,
dtype=native_to_narwhals_dtype(
field.dataType, version, spark_types, session
),
)
for field in dtype
]
)
if isinstance(dtype, native.BinaryType):
return dtypes.Binary()
return dtypes.Unknown() # pragma: no cover
@lru_cache(maxsize=4)
def fetch_session_time_zone(session: SparkSession) -> str:
# Timezone can't be changed in PySpark session, so this can be cached.
try:
return session.conf.get("spark.sql.session.timeZone") # type: ignore[attr-defined]
except Exception: # noqa: BLE001
# https://github.com/eakmanrq/sqlframe/issues/406
return "<unknown>"
IntoSparkDType: TypeAlias = Callable[[ModuleType], Callable[[], "_NativeDType"]]
dtypes = Version.MAIN.dtypes
NW_TO_SPARK_DTYPES: Mapping[type[DType], IntoSparkDType] = {
dtypes.Float64: attrgetter("DoubleType"),
dtypes.Float32: attrgetter("FloatType"),
dtypes.Binary: attrgetter("BinaryType"),
dtypes.String: attrgetter("StringType"),
dtypes.Boolean: attrgetter("BooleanType"),
dtypes.Date: attrgetter("DateType"),
dtypes.Int8: attrgetter("ByteType"),
dtypes.Int16: attrgetter("ShortType"),
dtypes.Int32: attrgetter("IntegerType"),
dtypes.Int64: attrgetter("LongType"),
}
UNSUPPORTED_DTYPES = (
dtypes.UInt64,
dtypes.UInt32,
dtypes.UInt16,
dtypes.UInt8,
dtypes.Enum,
dtypes.Categorical,
dtypes.Time,
)
def narwhals_to_native_dtype(
dtype: IntoDType, version: Version, spark_types: ModuleType, session: SparkSession
) -> _NativeDType:
dtypes = version.dtypes
if TYPE_CHECKING:
native = sqlframe_types
else:
native = spark_types
base_type = dtype.base_type()
if into_spark_type := NW_TO_SPARK_DTYPES.get(base_type):
return into_spark_type(native)()
if isinstance_or_issubclass(dtype, dtypes.Datetime):
if (tu := dtype.time_unit) != "us": # pragma: no cover
msg = f"Only microsecond precision is supported for PySpark, got: {tu}."
raise ValueError(msg)
dt_time_zone = dtype.time_zone
if dt_time_zone is None:
return native.TimestampNTZType()
if dt_time_zone != (tz := fetch_session_time_zone(session)): # pragma: no cover
msg = f"Only {tz} time zone is supported, as that's the connection time zone, got: {dt_time_zone}"
raise ValueError(msg)
# TODO(unassigned): cover once https://github.com/narwhals-dev/narwhals/issues/2742 addressed
return native.TimestampType() # pragma: no cover
if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)):
return native.ArrayType(
elementType=narwhals_to_native_dtype(dtype.inner, version, native, session)
)
if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover
return native.StructType(
fields=[
native.StructField(
name=field.name,
dataType=narwhals_to_native_dtype(
field.dtype, version, native, session
),
)
for field in dtype.fields
]
)
if issubclass(base_type, UNSUPPORTED_DTYPES): # pragma: no cover
msg = f"Converting to {base_type.__name__} dtype is not supported for Spark-Like backend."
raise UnsupportedDTypeError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
def evaluate_exprs(
df: SparkLikeLazyFrame, /, *exprs: SparkLikeExpr
) -> list[tuple[str, Column]]:
native_results: list[tuple[str, Column]] = []
for expr in exprs:
native_series_list = expr._call(df)
output_names = expr._evaluate_output_names(df)
if expr._alias_output_names is not None:
output_names = expr._alias_output_names(output_names)
if len(output_names) != len(native_series_list): # pragma: no cover
msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
raise AssertionError(msg)
native_results.extend(zip(output_names, native_series_list))
return native_results
def import_functions(implementation: Implementation, /) -> ModuleType:
if implementation is Implementation.PYSPARK:
from pyspark.sql import functions
return functions
if implementation is Implementation.PYSPARK_CONNECT:
from pyspark.sql.connect import functions
return functions
from sqlframe.base.session import _BaseSession
return import_module(f"sqlframe.{_BaseSession().execution_dialect_name}.functions")
def import_native_dtypes(implementation: Implementation, /) -> ModuleType:
if implementation is Implementation.PYSPARK:
from pyspark.sql import types
return types
if implementation is Implementation.PYSPARK_CONNECT:
from pyspark.sql.connect import types
return types
from sqlframe.base.session import _BaseSession
return import_module(f"sqlframe.{_BaseSession().execution_dialect_name}.types")
def import_window(implementation: Implementation, /) -> type[Any]:
if implementation is Implementation.PYSPARK:
from pyspark.sql import Window
return Window
if implementation is Implementation.PYSPARK_CONNECT:
from pyspark.sql.connect.window import Window
return Window
from sqlframe.base.session import _BaseSession
return import_module(
f"sqlframe.{_BaseSession().execution_dialect_name}.window"
).Window
@overload
def strptime_to_pyspark_format(format: None) -> None: ...
@overload
def strptime_to_pyspark_format(format: str) -> str: ...
def strptime_to_pyspark_format(format: str | None) -> str | None:
"""Converts a Python strptime datetime format string to a PySpark datetime format string."""
if format is None: # pragma: no cover
return None
# Replace Python format specifiers with PySpark specifiers
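# e.g. "%Y-%m-%d %H:%M:%S" becomes "yyyy-MM-dd HH:mm:ss"; a literal "T" is
# replaced by a space below because `date_format` cannot emit it.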
pyspark_format = format
for py_format, spark_format in DATETIME_PATTERNS_MAPPING.items():
pyspark_format = pyspark_format.replace(py_format, spark_format)
return pyspark_format.replace("T", " ")
def true_divide(F: Any, left: Column, right: Column) -> Column:
# PySpark before 3.5 doesn't have `try_divide`, and SQLFrame doesn't have it either.
divide = getattr(F, "try_divide", operator.truediv)
return divide(left, right)
def catch_pyspark_sql_exception(
exception: Exception, frame: CompliantLazyFrameAny, /
) -> ColumnNotFoundError | Exception: # pragma: no cover
from pyspark.errors import AnalysisException
if isinstance(exception, AnalysisException) and str(exception).startswith(
"[UNRESOLVED_COLUMN.WITH_SUGGESTION]"
):
return ColumnNotFoundError.from_available_column_names(
available_columns=frame.columns
)
# Just return exception as-is.
return exception
def catch_pyspark_connect_exception(
exception: Exception, /
) -> ColumnNotFoundError | Exception: # pragma: no cover
from pyspark.errors.exceptions.connect import AnalysisException
if isinstance(exception, AnalysisException) and str(exception).startswith(
"[UNRESOLVED_COLUMN.WITH_SUGGESTION]"
):
return ColumnNotFoundError(str(exception))
# Just return exception as-is.
return exception