done
lib/python3.11/site-packages/narwhals/_spark_like/utils.py (new file, 319 lines)
@@ -0,0 +1,319 @@
from __future__ import annotations

import operator
from collections.abc import Callable
from functools import lru_cache
from importlib import import_module
from operator import attrgetter
from types import ModuleType
from typing import TYPE_CHECKING, Any, overload

from narwhals._utils import Implementation, Version, isinstance_or_issubclass
from narwhals.exceptions import ColumnNotFoundError, UnsupportedDTypeError

if TYPE_CHECKING:
    from collections.abc import Mapping

    import sqlframe.base.types as sqlframe_types
    from sqlframe.base.column import Column
    from sqlframe.base.session import _BaseSession as Session
    from typing_extensions import TypeAlias

    from narwhals._compliant.typing import CompliantLazyFrameAny
    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
    from narwhals._spark_like.expr import SparkLikeExpr
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType

    _NativeDType: TypeAlias = sqlframe_types.DataType
    SparkSession = Session[Any, Any, Any, Any, Any, Any, Any]

UNITS_DICT = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}

# see https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html
# and https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior
DATETIME_PATTERNS_MAPPING = {
    "%Y": "yyyy",  # Year with century (4 digits)
    "%y": "yy",  # Year without century (2 digits)
    "%m": "MM",  # Month (01-12)
    "%d": "dd",  # Day of the month (01-31)
    "%H": "HH",  # Hour (24-hour clock) (00-23)
    "%I": "hh",  # Hour (12-hour clock) (01-12)
    "%M": "mm",  # Minute (00-59)
    "%S": "ss",  # Second (00-59)
    "%f": "S",  # Microseconds -> Milliseconds
    "%p": "a",  # AM/PM
    "%a": "E",  # Abbreviated weekday name
    "%A": "E",  # Full weekday name
    "%j": "D",  # Day of the year
    "%z": "Z",  # Timezone offset
    "%s": "X",  # Unix timestamp
}


# NOTE: don't lru_cache this as `ModuleType` isn't hashable
def native_to_narwhals_dtype(  # noqa: C901, PLR0912
    dtype: _NativeDType, version: Version, spark_types: ModuleType, session: SparkSession
) -> DType:
    dtypes = version.dtypes
    if TYPE_CHECKING:
        native = sqlframe_types
    else:
        native = spark_types

    if isinstance(dtype, native.DoubleType):
        return dtypes.Float64()
    if isinstance(dtype, native.FloatType):
        return dtypes.Float32()
    if isinstance(dtype, native.LongType):
        return dtypes.Int64()
    if isinstance(dtype, native.IntegerType):
        return dtypes.Int32()
    if isinstance(dtype, native.ShortType):
        return dtypes.Int16()
    if isinstance(dtype, native.ByteType):
        return dtypes.Int8()
    if isinstance(dtype, (native.StringType, native.VarcharType, native.CharType)):
        return dtypes.String()
    if isinstance(dtype, native.BooleanType):
        return dtypes.Boolean()
    if isinstance(dtype, native.DateType):
        return dtypes.Date()
    if isinstance(dtype, native.TimestampNTZType):
        # TODO(marco): cover this
        return dtypes.Datetime()  # pragma: no cover
    if isinstance(dtype, native.TimestampType):
        return dtypes.Datetime(time_zone=fetch_session_time_zone(session))
    if isinstance(dtype, native.DecimalType):
        # TODO(marco): cover this
        return dtypes.Decimal()  # pragma: no cover
    if isinstance(dtype, native.ArrayType):
        return dtypes.List(
            inner=native_to_narwhals_dtype(
                dtype.elementType, version, spark_types, session
            )
        )
    if isinstance(dtype, native.StructType):
        return dtypes.Struct(
            fields=[
                dtypes.Field(
                    name=field.name,
                    dtype=native_to_narwhals_dtype(
                        field.dataType, version, spark_types, session
                    ),
                )
                for field in dtype
            ]
        )
    if isinstance(dtype, native.BinaryType):
        return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover

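For orientation, a minimal sketch of what the conversion above produces for a nested native type. This is illustrative only and not part of the committed file; it assumes `pyspark` and `narwhals` are installed and uses `pyspark.sql.types` as the native types module (no running Spark session is needed to construct these type objects).

```python
from pyspark.sql import types as T

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._utils import Version

nested = T.ArrayType(T.StructType([T.StructField("a", T.LongType())]))
# `session` is only consulted for time-zone-aware TimestampType columns, so for
# this input it is never touched; None stands in purely for illustration.
result = native_to_narwhals_dtype(nested, Version.MAIN, T, None)
print(result)  # a Narwhals List dtype whose inner Struct has an Int64 field "a"
```
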
@lru_cache(maxsize=4)
def fetch_session_time_zone(session: SparkSession) -> str:
    # The timezone can't be changed in a PySpark session, so this can be cached.
    try:
        return session.conf.get("spark.sql.session.timeZone")  # type: ignore[attr-defined]
    except Exception:  # noqa: BLE001
        # https://github.com/eakmanrq/sqlframe/issues/406
        return "<unknown>"


IntoSparkDType: TypeAlias = Callable[[ModuleType], Callable[[], "_NativeDType"]]
dtypes = Version.MAIN.dtypes
NW_TO_SPARK_DTYPES: Mapping[type[DType], IntoSparkDType] = {
    dtypes.Float64: attrgetter("DoubleType"),
    dtypes.Float32: attrgetter("FloatType"),
    dtypes.Binary: attrgetter("BinaryType"),
    dtypes.String: attrgetter("StringType"),
    dtypes.Boolean: attrgetter("BooleanType"),
    dtypes.Date: attrgetter("DateType"),
    dtypes.Int8: attrgetter("ByteType"),
    dtypes.Int16: attrgetter("ShortType"),
    dtypes.Int32: attrgetter("IntegerType"),
    dtypes.Int64: attrgetter("LongType"),
}
UNSUPPORTED_DTYPES = (
    dtypes.UInt64,
    dtypes.UInt32,
    dtypes.UInt16,
    dtypes.UInt8,
    dtypes.Enum,
    dtypes.Categorical,
    dtypes.Time,
)

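As a small sketch of how the lookup table above is meant to be used for the simple (non-nested) dtypes: each Narwhals dtype class maps to an `attrgetter` which, applied to a native `types` module, yields the Spark type class to instantiate. Illustrative only, not part of the committed file; it assumes `pyspark` is installed.

```python
from operator import attrgetter

from pyspark.sql import types as T

# What NW_TO_SPARK_DTYPES holds for dtypes.Int64: resolve the class from the
# native module, then call it to get a type instance.
into_spark = attrgetter("LongType")
assert isinstance(into_spark(T)(), T.LongType)
```
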
def narwhals_to_native_dtype(
    dtype: IntoDType, version: Version, spark_types: ModuleType, session: SparkSession
) -> _NativeDType:
    dtypes = version.dtypes
    if TYPE_CHECKING:
        native = sqlframe_types
    else:
        native = spark_types
    base_type = dtype.base_type()
    if into_spark_type := NW_TO_SPARK_DTYPES.get(base_type):
        return into_spark_type(native)()
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        if (tu := dtype.time_unit) != "us":  # pragma: no cover
            msg = f"Only microsecond precision is supported for PySpark, got: {tu}."
            raise ValueError(msg)
        dt_time_zone = dtype.time_zone
        if dt_time_zone is None:
            return native.TimestampNTZType()
        if dt_time_zone != (tz := fetch_session_time_zone(session)):  # pragma: no cover
            msg = f"Only {tz} time zone is supported, as that's the connection time zone, got: {dt_time_zone}"
            raise ValueError(msg)
        # TODO(unassigned): cover once https://github.com/narwhals-dev/narwhals/issues/2742 addressed
        return native.TimestampType()  # pragma: no cover
    if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)):
        return native.ArrayType(
            elementType=narwhals_to_native_dtype(dtype.inner, version, native, session)
        )
    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
        return native.StructType(
            fields=[
                native.StructField(
                    name=field.name,
                    dataType=narwhals_to_native_dtype(
                        field.dtype, version, native, session
                    ),
                )
                for field in dtype.fields
            ]
        )
    if issubclass(base_type, UNSUPPORTED_DTYPES):  # pragma: no cover
        msg = f"Converting to {base_type.__name__} dtype is not supported for Spark-Like backend."
        raise UnsupportedDTypeError(msg)
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def evaluate_exprs(
    df: SparkLikeLazyFrame, /, *exprs: SparkLikeExpr
) -> list[tuple[str, Column]]:
    native_results: list[tuple[str, Column]] = []

    for expr in exprs:
        native_series_list = expr._call(df)
        output_names = expr._evaluate_output_names(df)
        if expr._alias_output_names is not None:
            output_names = expr._alias_output_names(output_names)
        if len(output_names) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
            raise AssertionError(msg)
        native_results.extend(zip(output_names, native_series_list))

    return native_results


def import_functions(implementation: Implementation, /) -> ModuleType:
    if implementation is Implementation.PYSPARK:
        from pyspark.sql import functions

        return functions
    if implementation is Implementation.PYSPARK_CONNECT:
        from pyspark.sql.connect import functions

        return functions
    from sqlframe.base.session import _BaseSession

    return import_module(f"sqlframe.{_BaseSession().execution_dialect_name}.functions")

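A usage sketch for the backend dispatch above (illustrative only, not part of the committed file): the PySpark branch can be checked without an active session, since it simply returns the `pyspark.sql.functions` module.

```python
import pyspark.sql.functions

from narwhals._spark_like.utils import import_functions
from narwhals._utils import Implementation

F = import_functions(Implementation.PYSPARK)
assert F is pyspark.sql.functions
# For SQLFrame backends, the same helper resolves
# `sqlframe.<execution_dialect_name>.functions` from the active _BaseSession.
```
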
def import_native_dtypes(implementation: Implementation, /) -> ModuleType:
    if implementation is Implementation.PYSPARK:
        from pyspark.sql import types

        return types
    if implementation is Implementation.PYSPARK_CONNECT:
        from pyspark.sql.connect import types

        return types
    from sqlframe.base.session import _BaseSession

    return import_module(f"sqlframe.{_BaseSession().execution_dialect_name}.types")


def import_window(implementation: Implementation, /) -> type[Any]:
    if implementation is Implementation.PYSPARK:
        from pyspark.sql import Window

        return Window

    if implementation is Implementation.PYSPARK_CONNECT:
        from pyspark.sql.connect.window import Window

        return Window
    from sqlframe.base.session import _BaseSession

    return import_module(
        f"sqlframe.{_BaseSession().execution_dialect_name}.window"
    ).Window


@overload
def strptime_to_pyspark_format(format: None) -> None: ...


@overload
def strptime_to_pyspark_format(format: str) -> str: ...


def strptime_to_pyspark_format(format: str | None) -> str | None:
    """Converts a Python strptime datetime format string to a PySpark datetime format string."""
    if format is None:  # pragma: no cover
        return None

    # Replace Python format specifiers with PySpark specifiers
    pyspark_format = format
    for py_format, spark_format in DATETIME_PATTERNS_MAPPING.items():
        pyspark_format = pyspark_format.replace(py_format, spark_format)
    return pyspark_format.replace("T", " ")

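A quick worked example of the translation (illustrative only, not part of the committed file); the expected outputs follow directly from DATETIME_PATTERNS_MAPPING and the trailing literal-"T" replacement, and only require `narwhals` to be installed.

```python
from narwhals._spark_like.utils import strptime_to_pyspark_format

assert strptime_to_pyspark_format("%Y-%m-%dT%H:%M:%S") == "yyyy-MM-dd HH:mm:ss"
assert strptime_to_pyspark_format("%d/%m/%y %I:%M %p") == "dd/MM/yy hh:mm a"
```
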
def true_divide(F: Any, left: Column, right: Column) -> Column:
    # PySpark before 3.5 doesn't have `try_divide`, and SQLFrame doesn't have it either.
    divide = getattr(F, "try_divide", operator.truediv)
    return divide(left, right)


def catch_pyspark_sql_exception(
    exception: Exception, frame: CompliantLazyFrameAny, /
) -> ColumnNotFoundError | Exception:  # pragma: no cover
    from pyspark.errors import AnalysisException

    if isinstance(exception, AnalysisException) and str(exception).startswith(
        "[UNRESOLVED_COLUMN.WITH_SUGGESTION]"
    ):
        return ColumnNotFoundError.from_available_column_names(
            available_columns=frame.columns
        )
    # Just return the exception as-is.
    return exception


def catch_pyspark_connect_exception(
    exception: Exception, /
) -> ColumnNotFoundError | Exception:  # pragma: no cover
    from pyspark.errors.exceptions.connect import AnalysisException

    if isinstance(exception, AnalysisException) and str(exception).startswith(
        "[UNRESOLVED_COLUMN.WITH_SUGGESTION]"
    ):
        return ColumnNotFoundError(str(exception))
    # Just return the exception as-is.
    return exception