This commit is contained in:
2025-09-07 22:09:54 +02:00
parent e1b817252c
commit 2fc0d000b6
7796 changed files with 2159515 additions and 933 deletions

View File

@@ -0,0 +1,22 @@
"""Common test support for all numpy test scripts.
This single module should provide all the common functionality for numpy tests
in a single location, so that test scripts can just import it and work right
away.
"""
from unittest import TestCase
from . import _private, overrides
from ._private import extbuild
from ._private.utils import *
from ._private.utils import _assert_valid_refcount, _gen_alignment_data
# Public API: everything re-exported from ``_private.utils`` plus the two
# extras imported above (``TestCase`` and the ``overrides`` submodule).
__all__ = (
    _private.utils.__all__ + ['TestCase', 'overrides']
)

# ``np.testing.test()`` — pytest-based runner for this subpackage.
from numpy._pytesttester import PytestTester
test = PytestTester(__name__)
del PytestTester  # keep the module namespace clean

View File

@@ -0,0 +1,102 @@
from unittest import TestCase
from . import overrides
from ._private.utils import (
HAS_LAPACK64,
HAS_REFCOUNT,
IS_EDITABLE,
IS_INSTALLED,
IS_MUSL,
IS_PYPY,
IS_PYSTON,
IS_WASM,
NOGIL_BUILD,
NUMPY_ROOT,
IgnoreException,
KnownFailureException,
SkipTest,
assert_,
assert_allclose,
assert_almost_equal,
assert_approx_equal,
assert_array_almost_equal,
assert_array_almost_equal_nulp,
assert_array_compare,
assert_array_equal,
assert_array_less,
assert_array_max_ulp,
assert_equal,
assert_no_gc_cycles,
assert_no_warnings,
assert_raises,
assert_raises_regex,
assert_string_equal,
assert_warns,
break_cycles,
build_err_msg,
check_support_sve,
clear_and_catch_warnings,
decorate_methods,
jiffies,
measure,
memusage,
print_assert_equal,
run_threaded,
rundocs,
runstring,
suppress_warnings,
tempdir,
temppath,
verbose,
)
# Names exported by ``numpy.testing``; kept sorted and in sync with the
# runtime module's ``__all__``.
__all__ = [
    "HAS_LAPACK64",
    "HAS_REFCOUNT",
    "IS_EDITABLE",
    "IS_INSTALLED",
    "IS_MUSL",
    "IS_PYPY",
    "IS_PYSTON",
    "IS_WASM",
    "NOGIL_BUILD",
    "NUMPY_ROOT",
    "IgnoreException",
    "KnownFailureException",
    "SkipTest",
    "TestCase",
    "assert_",
    "assert_allclose",
    "assert_almost_equal",
    "assert_approx_equal",
    "assert_array_almost_equal",
    "assert_array_almost_equal_nulp",
    "assert_array_compare",
    "assert_array_equal",
    "assert_array_less",
    "assert_array_max_ulp",
    "assert_equal",
    "assert_no_gc_cycles",
    "assert_no_warnings",
    "assert_raises",
    "assert_raises_regex",
    "assert_string_equal",
    "assert_warns",
    "break_cycles",
    "build_err_msg",
    "check_support_sve",
    "clear_and_catch_warnings",
    "decorate_methods",
    "jiffies",
    "measure",
    "memusage",
    "overrides",
    "print_assert_equal",
    "run_threaded",
    "rundocs",
    "runstring",
    "suppress_warnings",
    "tempdir",
    "temppath",
    "verbose",
]

View File

@@ -0,0 +1,250 @@
"""
Build a c-extension module on-the-fly in tests.
See build_and_import_extensions for usage hints
"""
import os
import pathlib
import subprocess
import sys
import sysconfig
import textwrap
__all__ = ['build_and_import_extension', 'compile_extension_module']
def build_and_import_extension(
        modname, functions, *, prologue="", build_dir=None,
        include_dirs=None, more_init=""):
    """
    Build and import a c-extension module `modname` from a list of function
    fragments `functions`.

    Parameters
    ----------
    functions : list of fragments
        Each fragment is a sequence of func_name, calling convention, snippet.
    prologue : string
        Code to precede the rest, usually extra ``#include`` or ``#define``
        macros.
    build_dir : pathlib.Path
        Where to build the module, usually a temporary directory
    include_dirs : list
        Extra directories to find include files when compiling
    more_init : string
        Code to appear in the module PyMODINIT_FUNC

    Returns
    -------
    out: module
        The module will have been loaded and is ready for use

    Examples
    --------
    >>> functions = [("test_bytes", "METH_O", \"\"\"
        if ( !PyBytes_Check(args)) {
            Py_RETURN_FALSE;
        }
        Py_RETURN_TRUE;
    \"\"\")]
    >>> mod = build_and_import_extension("testme", functions)
    >>> assert not mod.test_bytes('abc')
    >>> assert mod.test_bytes(b'abc')
    """
    if include_dirs is None:
        include_dirs = []
    # C bodies for the functions plus the method/module tables.
    body = prologue + _make_methods(functions, modname)
    # Boilerplate module-creation code; opt out of the GIL on free-threaded
    # (Py_GIL_DISABLED) builds.
    init = """
    PyObject *mod = PyModule_Create(&moduledef);
    #ifdef Py_GIL_DISABLED
    PyUnstable_Module_SetGIL(mod, Py_MOD_GIL_NOT_USED);
    #endif
    """
    if not build_dir:
        build_dir = pathlib.Path('.')
    if more_init:
        # Give user-supplied init code a conventional error macro to use.
        init += """#define INITERROR return NULL
    """
        init += more_init
    init += "\nreturn mod;"
    source_string = _make_source(modname, init, body)
    mod_so = compile_extension_module(
        modname, build_dir, include_dirs, source_string)
    # Local import: keep the module-level import footprint minimal.
    import importlib.util
    # Load the freshly built shared object as a regular Python module.
    spec = importlib.util.spec_from_file_location(modname, mod_so)
    foo = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(foo)
    return foo
def compile_extension_module(
        name, builddir, include_dirs,
        source_string, libraries=None, library_dirs=None):
    """
    Build an extension module and return the filename of the resulting
    native code file.

    Parameters
    ----------
    name : string
        name of the module, possibly including dots if it is a module inside a
        package.
    builddir : pathlib.Path
        Where to build the module, usually a temporary directory
    include_dirs : list
        Extra directories to find include files when compiling
    source_string : string
        Complete C source code of the module.
    libraries : list
        Libraries to link into the extension module
    library_dirs: list
        Where to find the libraries, ``-L`` passed to the linker

    Returns
    -------
    pathlib.Path
        Path of the compiled shared-object file.
    """
    # Only the final dotted component names the importable extension.
    modname = name.split('.')[-1]
    dirname = builddir / name
    dirname.mkdir(exist_ok=True)
    # Materialize the source string as ``source.c`` inside the build dir.
    cfile = _convert_str_to_file(source_string, dirname)
    include_dirs = include_dirs or []
    libraries = libraries or []
    library_dirs = library_dirs or []
    return _c_compile(
        cfile, outputfilename=dirname / modname,
        include_dirs=include_dirs, libraries=libraries,
        library_dirs=library_dirs,
    )
def _convert_str_to_file(source, dirname):
"""Helper function to create a file ``source.c`` in `dirname` that contains
the string in `source`. Returns the file name
"""
filename = dirname / 'source.c'
with filename.open('w') as f:
f.write(str(source))
return filename
def _make_methods(functions, modname):
""" Turns the name, signature, code in functions into complete functions
and lists them in a methods_table. Then turns the methods_table into a
``PyMethodDef`` structure and returns the resulting code fragment ready
for compilation
"""
methods_table = []
codes = []
for funcname, flags, code in functions:
cfuncname = f"{modname}_{funcname}"
if 'METH_KEYWORDS' in flags:
signature = '(PyObject *self, PyObject *args, PyObject *kwargs)'
else:
signature = '(PyObject *self, PyObject *args)'
methods_table.append(
"{\"%s\", (PyCFunction)%s, %s}," % (funcname, cfuncname, flags))
func_code = f"""
static PyObject* {cfuncname}{signature}
{{
{code}
}}
"""
codes.append(func_code)
body = "\n".join(codes) + """
static PyMethodDef methods[] = {
%(methods)s
{ NULL }
};
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"%(modname)s", /* m_name */
NULL, /* m_doc */
-1, /* m_size */
methods, /* m_methods */
};
""" % {'methods': '\n'.join(methods_table), 'modname': modname}
return body
def _make_source(name, init, body):
""" Combines the code fragments into source code ready to be compiled
"""
code = """
#include <Python.h>
%(body)s
PyMODINIT_FUNC
PyInit_%(name)s(void) {
%(init)s
}
""" % {
'name': name, 'init': init, 'body': body,
}
return code
def _c_compile(cfile, outputfilename, include_dirs, libraries,
               library_dirs):
    # Select per-platform compiler/linker flags, then delegate to ``build``.
    link_extra = []
    if sys.platform == 'win32':
        # MSVC: promote "undefined function" (implicit declaration) to error.
        compile_extra = ["/we4013"]
        link_extra.append('/DEBUG')  # generate .pdb file
    elif sys.platform.startswith('linux'):
        # Debug-friendly, position-independent build; implicit declarations
        # are errors here too.
        compile_extra = [
            "-O0", "-g", "-Werror=implicit-function-declaration", "-fPIC"]
    else:
        compile_extra = []
    return build(
        cfile, outputfilename,
        compile_extra, link_extra,
        include_dirs, libraries, library_dirs)
def build(cfile, outputfilename, compile_extra, link_extra,
          include_dirs, libraries, library_dirs):
    "use meson to build"
    # Lay out a minimal throw-away meson project next to the C source.
    build_dir = cfile.parent / "build"
    os.makedirs(build_dir, exist_ok=True)
    with open(cfile.parent / "meson.build", "wt") as fid:
        link_dirs = ['-L' + d for d in library_dirs]
        # NOTE(review): `link_extra` and `libraries` are accepted but never
        # written into the meson project below — confirm whether that is
        # intentional.
        fid.write(textwrap.dedent(f"""\
            project('foo', 'c')
            py = import('python').find_installation(pure: false)
            py.extension_module(
                '{outputfilename.parts[-1]}',
                '{cfile.parts[-1]}',
                c_args: {compile_extra},
                link_args: {link_dirs},
                include_directories: {include_dirs},
            )
        """))
    # Pin the interpreter meson should build against to the running one.
    native_file_name = cfile.parent / ".mesonpy-native-file.ini"
    with open(native_file_name, "wt") as fid:
        fid.write(textwrap.dedent(f"""\
            [binaries]
            python = '{sys.executable}'
        """))
    if sys.platform == "win32":
        # --vsenv makes meson activate the Visual Studio environment.
        subprocess.check_call(["meson", "setup",
                               "--buildtype=release",
                               "--vsenv", ".."],
                              cwd=build_dir,
                              )
    else:
        # NOTE(review): --vsenv is a Windows-oriented flag; presumably a
        # no-op on this branch — confirm against meson's behavior.
        subprocess.check_call(["meson", "setup", "--vsenv",
                               "..", f'--native-file={os.fspath(native_file_name)}'],
                              cwd=build_dir
                              )
    so_name = outputfilename.parts[-1] + get_so_suffix()
    subprocess.check_call(["meson", "compile"], cwd=build_dir)
    # Move the artifact up beside the source so callers can import it from
    # the directory they were given.
    os.rename(str(build_dir / so_name), cfile.parent / so_name)
    return cfile.parent / so_name
def get_so_suffix():
    """Return the interpreter's extension-module filename suffix
    (``EXT_SUFFIX``, e.g. ``.cpython-312-x86_64-linux-gnu.so``)."""
    ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
    assert ext_suffix  # every supported interpreter defines EXT_SUFFIX
    return ext_suffix

View File

@@ -0,0 +1,25 @@
import pathlib
import types
from collections.abc import Sequence
__all__ = ["build_and_import_extension", "compile_extension_module"]

# Stub for the runtime function of the same name: builds a C extension from
# (name, calling-convention, code) fragments and imports it.
def build_and_import_extension(
    modname: str,
    functions: Sequence[tuple[str, str, str]],
    *,
    prologue: str = "",
    build_dir: pathlib.Path | None = None,
    include_dirs: Sequence[str] = [],
    more_init: str = "",
) -> types.ModuleType: ...

#
# Stub: compiles `source_string` into a shared object and returns its path.
def compile_extension_module(
    name: str,
    builddir: pathlib.Path,
    include_dirs: Sequence[str],
    source_string: str,
    libraries: Sequence[str] = [],
    library_dirs: Sequence[str] = [],
) -> pathlib.Path: ...

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,499 @@
import ast
import sys
import types
import unittest
import warnings
from collections.abc import Callable, Iterable, Sequence
from contextlib import _GeneratorContextManager
from pathlib import Path
from re import Pattern
from typing import (
Any,
AnyStr,
ClassVar,
Final,
Generic,
NoReturn,
ParamSpec,
Self,
SupportsIndex,
TypeAlias,
TypeVarTuple,
overload,
type_check_only,
)
from typing import Literal as L
from unittest.case import SkipTest
from _typeshed import ConvertibleToFloat, GenericPath, StrOrBytesPath, StrPath
from typing_extensions import TypeVar
import numpy as np
from numpy._typing import (
ArrayLike,
DTypeLike,
NDArray,
_ArrayLikeDT64_co,
_ArrayLikeNumber_co,
_ArrayLikeObject_co,
_ArrayLikeTD64_co,
)
__all__ = [ # noqa: RUF022
"IS_EDITABLE",
"IS_MUSL",
"IS_PYPY",
"IS_PYSTON",
"IS_WASM",
"HAS_LAPACK64",
"HAS_REFCOUNT",
"NOGIL_BUILD",
"assert_",
"assert_array_almost_equal_nulp",
"assert_raises_regex",
"assert_array_max_ulp",
"assert_warns",
"assert_no_warnings",
"assert_allclose",
"assert_equal",
"assert_almost_equal",
"assert_approx_equal",
"assert_array_equal",
"assert_array_less",
"assert_string_equal",
"assert_array_almost_equal",
"assert_raises",
"build_err_msg",
"decorate_methods",
"jiffies",
"memusage",
"print_assert_equal",
"rundocs",
"runstring",
"verbose",
"measure",
"IgnoreException",
"clear_and_catch_warnings",
"SkipTest",
"KnownFailureException",
"temppath",
"tempdir",
"suppress_warnings",
"assert_array_compare",
"assert_no_gc_cycles",
"break_cycles",
"check_support_sve",
"run_threaded",
]
###
_T = TypeVar("_T")
_Ts = TypeVarTuple("_Ts")
_Tss = ParamSpec("_Tss")
_ET = TypeVar("_ET", bound=BaseException, default=BaseException)
_FT = TypeVar("_FT", bound=Callable[..., Any])
_W_co = TypeVar("_W_co", bound=_WarnLog | None, default=_WarnLog | None, covariant=True)
_T_or_bool = TypeVar("_T_or_bool", default=bool)
_StrLike: TypeAlias = str | bytes
_RegexLike: TypeAlias = _StrLike | Pattern[Any]
_NumericArrayLike: TypeAlias = _ArrayLikeNumber_co | _ArrayLikeObject_co
_ExceptionSpec: TypeAlias = type[_ET] | tuple[type[_ET], ...]
_WarningSpec: TypeAlias = type[Warning]
_WarnLog: TypeAlias = list[warnings.WarningMessage]
_ToModules: TypeAlias = Iterable[types.ModuleType]
# Must return a bool or an ndarray/generic type that is supported by `np.logical_and.reduce`
_ComparisonFunc: TypeAlias = Callable[
[NDArray[Any], NDArray[Any]],
bool | np.bool | np.number | NDArray[np.bool | np.number | np.object_],
]
# Type-check only `clear_and_catch_warnings` subclasses for both values of the
# `record` parameter. Copied from the stdlib `warnings` stubs.
@type_check_only
class _clear_and_catch_warnings_with_records(clear_and_catch_warnings):
    # ``record=True`` flavor: entering yields the list of captured warnings.
    def __enter__(self) -> list[warnings.WarningMessage]: ...

@type_check_only
class _clear_and_catch_warnings_without_records(clear_and_catch_warnings):
    # ``record=False`` flavor: entering yields nothing.
    def __enter__(self) -> None: ...
###
verbose: int = 0
NUMPY_ROOT: Final[Path] = ...
IS_INSTALLED: Final[bool] = ...
IS_EDITABLE: Final[bool] = ...
IS_MUSL: Final[bool] = ...
IS_PYPY: Final[bool] = ...
IS_PYSTON: Final[bool] = ...
IS_WASM: Final[bool] = ...
HAS_REFCOUNT: Final[bool] = ...
HAS_LAPACK64: Final[bool] = ...
NOGIL_BUILD: Final[bool] = ...
# Raised/used to mark a test as a known, expected failure.
class KnownFailureException(Exception): ...

# Raised to signal an error the test machinery should ignore.
class IgnoreException(Exception): ...
# NOTE: `warnings.catch_warnings` is incorrectly defined as invariant in typeshed
class clear_and_catch_warnings(warnings.catch_warnings[_W_co], Generic[_W_co]): # type: ignore[type-var] # pyright: ignore[reportInvalidTypeArguments]
    # Modules whose ``__warningregistry__`` is cleared on enter and restored
    # on exit, in addition to any passed via ``modules``.
    class_modules: ClassVar[tuple[types.ModuleType, ...]] = ()
    modules: Final[set[types.ModuleType]]
    @overload # record: True
    def __init__(self: clear_and_catch_warnings[_WarnLog], /, record: L[True], modules: _ToModules = ()) -> None: ...
    @overload # record: False (default)
    def __init__(self: clear_and_catch_warnings[None], /, record: L[False] = False, modules: _ToModules = ()) -> None: ...
    @overload # record: bool
    def __init__(self, /, record: bool, modules: _ToModules = ()) -> None: ...
class suppress_warnings:
    # Warnings recorded while the context is active.
    log: Final[_WarnLog]
    def __init__(self, /, forwarding_rule: L["always", "module", "once", "location"] = "always") -> None: ...
    def __enter__(self) -> Self: ...
    def __exit__(self, cls: type[BaseException] | None, exc: BaseException | None, tb: types.TracebackType | None, /) -> None: ...
    # Usable as a decorator as well as a context manager.
    def __call__(self, /, func: _FT) -> _FT: ...
    #
    # ``filter`` suppresses matching warnings; ``record`` additionally
    # returns the log list they are appended to.
    def filter(self, /, category: type[Warning] = ..., message: str = "", module: types.ModuleType | None = None) -> None: ...
    def record(self, /, category: type[Warning] = ..., message: str = "", module: types.ModuleType | None = None) -> _WarnLog: ...
# Contrary to runtime we can't do `os.name` checks while type checking,
# only `sys.platform` checks
if sys.platform == "win32" or sys.platform == "cygwin":
def memusage(processName: str = ..., instance: int = ...) -> int: ...
elif sys.platform == "linux":
def memusage(_proc_pid_stat: StrOrBytesPath = ...) -> int | None: ...
else:
def memusage() -> NoReturn: ...
if sys.platform == "linux":
def jiffies(_proc_pid_stat: StrOrBytesPath = ..., _load_time: list[float] = []) -> int: ...
else:
def jiffies(_load_time: list[float] = []) -> int: ...
#
def build_err_msg(
arrays: Iterable[object],
err_msg: object,
header: str = ...,
verbose: bool = ...,
names: Sequence[str] = ...,
precision: SupportsIndex | None = ...,
) -> str: ...
#
def print_assert_equal(test_string: str, actual: object, desired: object) -> None: ...
#
def assert_(val: object, msg: str | Callable[[], str] = "") -> None: ...
#
def assert_equal(
actual: object,
desired: object,
err_msg: object = "",
verbose: bool = True,
*,
strict: bool = False,
) -> None: ...
def assert_almost_equal(
actual: _NumericArrayLike,
desired: _NumericArrayLike,
decimal: int = 7,
err_msg: object = "",
verbose: bool = True,
) -> None: ...
#
def assert_approx_equal(
actual: ConvertibleToFloat,
desired: ConvertibleToFloat,
significant: int = 7,
err_msg: object = "",
verbose: bool = True,
) -> None: ...
#
def assert_array_compare(
comparison: _ComparisonFunc,
x: ArrayLike,
y: ArrayLike,
err_msg: object = "",
verbose: bool = True,
header: str = "",
precision: SupportsIndex = 6,
equal_nan: bool = True,
equal_inf: bool = True,
*,
strict: bool = False,
names: tuple[str, str] = ("ACTUAL", "DESIRED"),
) -> None: ...
#
def assert_array_equal(
actual: object,
desired: object,
err_msg: object = "",
verbose: bool = True,
*,
strict: bool = False,
) -> None: ...
#
def assert_array_almost_equal(
actual: _NumericArrayLike,
desired: _NumericArrayLike,
decimal: float = 6,
err_msg: object = "",
verbose: bool = True,
) -> None: ...
@overload
def assert_array_less(
x: _ArrayLikeDT64_co,
y: _ArrayLikeDT64_co,
err_msg: object = "",
verbose: bool = True,
*,
strict: bool = False,
) -> None: ...
@overload
def assert_array_less(
x: _ArrayLikeTD64_co,
y: _ArrayLikeTD64_co,
err_msg: object = "",
verbose: bool = True,
*,
strict: bool = False,
) -> None: ...
@overload
def assert_array_less(
x: _NumericArrayLike,
y: _NumericArrayLike,
err_msg: object = "",
verbose: bool = True,
*,
strict: bool = False,
) -> None: ...
#
def assert_string_equal(actual: str, desired: str) -> None: ...
#
@overload
def assert_raises(
exception_class: _ExceptionSpec[_ET],
/,
*,
msg: str | None = None,
) -> unittest.case._AssertRaisesContext[_ET]: ...
@overload
def assert_raises(
exception_class: _ExceptionSpec,
callable: Callable[_Tss, Any],
/,
*args: _Tss.args,
**kwargs: _Tss.kwargs,
) -> None: ...
#
@overload
def assert_raises_regex(
exception_class: _ExceptionSpec[_ET],
expected_regexp: _RegexLike,
*,
msg: str | None = None,
) -> unittest.case._AssertRaisesContext[_ET]: ...
@overload
def assert_raises_regex(
exception_class: _ExceptionSpec,
expected_regexp: _RegexLike,
callable: Callable[_Tss, Any],
*args: _Tss.args,
**kwargs: _Tss.kwargs,
) -> None: ...
#
@overload
def assert_allclose(
actual: _ArrayLikeTD64_co,
desired: _ArrayLikeTD64_co,
rtol: float = 1e-7,
atol: float = 0,
equal_nan: bool = True,
err_msg: object = "",
verbose: bool = True,
*,
strict: bool = False,
) -> None: ...
@overload
def assert_allclose(
actual: _NumericArrayLike,
desired: _NumericArrayLike,
rtol: float = 1e-7,
atol: float = 0,
equal_nan: bool = True,
err_msg: object = "",
verbose: bool = True,
*,
strict: bool = False,
) -> None: ...
#
def assert_array_almost_equal_nulp(
x: _ArrayLikeNumber_co,
y: _ArrayLikeNumber_co,
nulp: float = 1,
) -> None: ...
#
def assert_array_max_ulp(
a: _ArrayLikeNumber_co,
b: _ArrayLikeNumber_co,
maxulp: float = 1,
dtype: DTypeLike | None = None,
) -> NDArray[Any]: ...
#
@overload
def assert_warns(warning_class: _WarningSpec) -> _GeneratorContextManager[None]: ...
@overload
def assert_warns(warning_class: _WarningSpec, func: Callable[_Tss, _T], *args: _Tss.args, **kwargs: _Tss.kwargs) -> _T: ...
#
@overload
def assert_no_warnings() -> _GeneratorContextManager[None]: ...
@overload
def assert_no_warnings(func: Callable[_Tss, _T], /, *args: _Tss.args, **kwargs: _Tss.kwargs) -> _T: ...
#
@overload
def assert_no_gc_cycles() -> _GeneratorContextManager[None]: ...
@overload
def assert_no_gc_cycles(func: Callable[_Tss, Any], /, *args: _Tss.args, **kwargs: _Tss.kwargs) -> None: ...
###
#
@overload
def tempdir(
suffix: None = None,
prefix: None = None,
dir: None = None,
) -> _GeneratorContextManager[str]: ...
@overload
def tempdir(
suffix: AnyStr | None = None,
prefix: AnyStr | None = None,
*,
dir: GenericPath[AnyStr],
) -> _GeneratorContextManager[AnyStr]: ...
@overload
def tempdir(
suffix: AnyStr | None = None,
*,
prefix: AnyStr,
dir: GenericPath[AnyStr] | None = None,
) -> _GeneratorContextManager[AnyStr]: ...
@overload
def tempdir(
suffix: AnyStr,
prefix: AnyStr | None = None,
dir: GenericPath[AnyStr] | None = None,
) -> _GeneratorContextManager[AnyStr]: ...
#
@overload
def temppath(
suffix: None = None,
prefix: None = None,
dir: None = None,
text: bool = False,
) -> _GeneratorContextManager[str]: ...
@overload
def temppath(
suffix: AnyStr | None,
prefix: AnyStr | None,
dir: GenericPath[AnyStr],
text: bool = False,
) -> _GeneratorContextManager[AnyStr]: ...
@overload
def temppath(
suffix: AnyStr | None = None,
prefix: AnyStr | None = None,
*,
dir: GenericPath[AnyStr],
text: bool = False,
) -> _GeneratorContextManager[AnyStr]: ...
@overload
def temppath(
suffix: AnyStr | None,
prefix: AnyStr,
dir: GenericPath[AnyStr] | None = None,
text: bool = False,
) -> _GeneratorContextManager[AnyStr]: ...
@overload
def temppath(
suffix: AnyStr | None = None,
*,
prefix: AnyStr,
dir: GenericPath[AnyStr] | None = None,
text: bool = False,
) -> _GeneratorContextManager[AnyStr]: ...
@overload
def temppath(
suffix: AnyStr,
prefix: AnyStr | None = None,
dir: GenericPath[AnyStr] | None = None,
text: bool = False,
) -> _GeneratorContextManager[AnyStr]: ...
#
def check_support_sve(__cache: list[_T_or_bool] = []) -> _T_or_bool: ... # noqa: PYI063
#
def decorate_methods(
cls: type,
decorator: Callable[[Callable[..., Any]], Any],
testmatch: _RegexLike | None = None,
) -> None: ...
#
@overload
def run_threaded(
func: Callable[[], None],
max_workers: int = 8,
pass_count: bool = False,
pass_barrier: bool = False,
outer_iterations: int = 1,
prepare_args: None = None,
) -> None: ...
@overload
def run_threaded(
func: Callable[[*_Ts], None],
max_workers: int,
pass_count: bool,
pass_barrier: bool,
outer_iterations: int,
prepare_args: tuple[*_Ts],
) -> None: ...
@overload
def run_threaded(
func: Callable[[*_Ts], None],
max_workers: int = 8,
pass_count: bool = False,
pass_barrier: bool = False,
outer_iterations: int = 1,
*,
prepare_args: tuple[*_Ts],
) -> None: ...
#
def runstring(astr: _StrLike | types.CodeType, dict: dict[str, Any] | None) -> Any: ... # noqa: ANN401
def rundocs(filename: StrPath | None = None, raise_on_error: bool = True) -> None: ...
def measure(code_str: _StrLike | ast.AST, times: int = 1, label: str | None = None) -> float: ...
def break_cycles() -> None: ...

View File

@@ -0,0 +1,84 @@
"""Tools for testing implementations of __array_function__ and ufunc overrides
"""
import numpy._core.umath as _umath
from numpy import ufunc as _ufunc
from numpy._core.overrides import ARRAY_FUNCTIONS as _array_functions
def get_overridable_numpy_ufuncs():
"""List all numpy ufuncs overridable via `__array_ufunc__`
Parameters
----------
None
Returns
-------
set
A set containing all overridable ufuncs in the public numpy API.
"""
ufuncs = {obj for obj in _umath.__dict__.values()
if isinstance(obj, _ufunc)}
return ufuncs
def allows_array_ufunc_override(func):
    """Determine if a function can be overridden via `__array_ufunc__`

    Parameters
    ----------
    func : callable
        Function that may be overridable via `__array_ufunc__`

    Returns
    -------
    bool
        `True` if `func` is overridable via `__array_ufunc__` and
        `False` otherwise.

    Notes
    -----
    This function is equivalent to ``isinstance(func, np.ufunc)`` and
    will work correctly for ufuncs defined outside of Numpy.
    """
    # All ufuncs — numpy's own and third-party ones — are instances of
    # ``numpy.ufunc``, so a plain isinstance check suffices.
    return isinstance(func, _ufunc)
def get_overridable_numpy_array_functions():
"""List all numpy functions overridable via `__array_function__`
Parameters
----------
None
Returns
-------
set
A set containing all functions in the public numpy API that are
overridable via `__array_function__`.
"""
# 'import numpy' doesn't import recfunctions, so make sure it's imported
# so ufuncs defined there show up in the ufunc listing
from numpy.lib import recfunctions # noqa: F401
return _array_functions.copy()
def allows_array_function_override(func):
"""Determine if a Numpy function can be overridden via `__array_function__`
Parameters
----------
func : callable
Function that may be overridable via `__array_function__`
Returns
-------
bool
`True` if `func` is a function in the Numpy API that is
overridable via `__array_function__` and `False` otherwise.
"""
return func in _array_functions

View File

@@ -0,0 +1,11 @@
from collections.abc import Callable, Hashable
from typing import Any
from typing_extensions import TypeIs
import numpy as np
# Stubs for numpy.testing.overrides: introspection helpers for the
# `__array_ufunc__` / `__array_function__` override protocols.
def get_overridable_numpy_ufuncs() -> set[np.ufunc]: ...
def get_overridable_numpy_array_functions() -> set[Callable[..., Any]]: ...
# TypeIs lets type checkers narrow `func` to np.ufunc on a True result.
def allows_array_ufunc_override(func: object) -> TypeIs[np.ufunc]: ...
def allows_array_function_override(func: Hashable) -> bool: ...

View File

@@ -0,0 +1,207 @@
#!/usr/bin/env python3
"""Prints type-coercion tables for the built-in NumPy types
"""
from collections import namedtuple
import numpy as np
from numpy._core.numerictypes import obj2sctype
# Generic object that can be added, but doesn't do anything else
class GenericObject:
    """Minimal object supporting only addition (which always yields the
    instance itself); used to probe object-dtype coercion behaviour."""

    # numpy should treat instances as plain objects
    dtype = np.dtype('O')

    def __init__(self, v):
        self.v = v

    def __add__(self, other):
        return self

    def __radd__(self, other):
        return self
def print_cancast_table(ntypes):
print('X', end=' ')
for char in ntypes:
print(char, end=' ')
print()
for row in ntypes:
print(row, end=' ')
for col in ntypes:
if np.can_cast(row, col, "equiv"):
cast = "#"
elif np.can_cast(row, col, "safe"):
cast = "="
elif np.can_cast(row, col, "same_kind"):
cast = "~"
elif np.can_cast(row, col, "unsafe"):
cast = "."
else:
cast = " "
print(cast, end=' ')
print()
def print_coercion_table(ntypes, inputfirstvalue, inputsecondvalue, firstarray,
                         use_promote_types=False):
    """Print a table of the dtype character produced by coercing each pair
    of type codes in `ntypes`, either through an actual ``np.add`` (the
    default) or through ``np.promote_types``.

    Error cells: ``!`` ValueError, ``@`` OverflowError, ``#`` TypeError.
    """
    # Header row.
    print('+', end=' ')
    for char in ntypes:
        print(char, end=' ')
    print()
    for row in ntypes:
        # 'O' gets the add-only GenericObject to exercise object coercion.
        if row == 'O':
            rowtype = GenericObject
        else:
            rowtype = obj2sctype(row)

        print(row, end=' ')
        for col in ntypes:
            if col == 'O':
                coltype = GenericObject
            else:
                coltype = obj2sctype(col)
            try:
                if firstarray:
                    # Wrap the first operand in a 1-element array so the
                    # array-with-scalar coercion path is taken.
                    rowvalue = np.array([rowtype(inputfirstvalue)], dtype=rowtype)
                else:
                    rowvalue = rowtype(inputfirstvalue)
                colvalue = coltype(inputsecondvalue)
                if use_promote_types:
                    char = np.promote_types(rowvalue.dtype, colvalue.dtype).char
                else:
                    value = np.add(rowvalue, colvalue)
                    if isinstance(value, np.ndarray):
                        char = value.dtype.char
                    else:
                        char = np.dtype(type(value)).char
            # Map each failure mode to a distinct table symbol.
            except ValueError:
                char = '!'
            except OverflowError:
                char = '@'
            except TypeError:
                char = '#'
            print(char, end=' ')
        print()
def print_new_cast_table(*, can_cast=True, legacy=False, flags=False):
    """Prints new casts, the values given are default "can-cast" values, not
    actual ones.
    """
    # Local import: _multiarray_tests is a test-only compiled module.
    from numpy._core._multiarray_tests import get_all_cast_information
    # Symbols for each casting level reported by the cast info.
    cast_table = {
        -1: " ",
        0: "#",  # No cast (classify as equivalent here)
        1: "#",  # equivalent casting
        2: "=",  # safe casting
        3: "~",  # same-kind casting
        4: ".",  # unsafe casting
    }
    # Symbols for each combination of the three flag bits (pyapi=1,
    # unaligned=2, no-float-errors=4).
    # NOTE(review): these entries look like glyphs stripped to empty strings
    # by an encoding/extraction step — confirm against the upstream file.
    flags_table = {
        0: "", 7: "",
        1: "", 2: "", 4: "",
        3: "", 5: "",
        6: "",
    }
    cast_info = namedtuple("cast_info", ["can_cast", "legacy", "flags"])
    no_cast_info = cast_info(" ", " ", " ")
    casts = get_all_cast_information()
    # table[from_dtype][to_dtype] -> cast_info
    table = {}
    dtypes = set()
    for cast in casts:
        dtypes.add(cast["from"])
        dtypes.add(cast["to"])
        if cast["from"] not in table:
            table[cast["from"]] = {}
        to_dict = table[cast["from"]]

        can_cast = cast_table[cast["casting"]]
        legacy = "L" if cast["legacy"] else "."
        # Pack the three boolean properties into one flag bitfield.
        flags = 0
        if cast["requires_pyapi"]:
            flags |= 1
        if cast["supports_unaligned"]:
            flags |= 2
        if cast["no_floatingpoint_errors"]:
            flags |= 4

        flags = flags_table[flags]
        to_dict[cast["to"]] = cast_info(can_cast=can_cast, legacy=legacy, flags=flags)

    # The np.dtype(x.type) is a bit strange, because dtype classes do
    # not expose much yet.
    types = np.typecodes["All"]
    def sorter(x):
        # This is a bit weird hack, to get a table as close as possible to
        # the one printing all typecodes (but expecting user-dtypes).
        dtype = np.dtype(x.type)
        try:
            indx = types.index(dtype.char)
        except ValueError:
            # Unknown (user) dtypes sort last.
            indx = np.inf
        return (indx, dtype.char)

    dtypes = sorted(dtypes, key=sorter)

    def print_table(field="can_cast"):
        # Print one field of the cast_info tuples as a 2-D character table.
        print('X', end=' ')
        for dt in dtypes:
            print(np.dtype(dt.type).char, end=' ')
        print()
        for from_dt in dtypes:
            print(np.dtype(from_dt.type).char, end=' ')
            row = table.get(from_dt, {})
            for to_dt in dtypes:
                print(getattr(row.get(to_dt, no_cast_info), field), end=' ')
            print()

    if can_cast:
        # Print the actual table:
        print()
        print("Casting: # is equivalent, = is safe, ~ is same-kind, and . is unsafe")
        print()
        print_table("can_cast")

    if legacy:
        print()
        print("L denotes a legacy cast . a non-legacy one.")
        print()
        print_table("legacy")

    if flags:
        print()
        print(f"{flags_table[0]}: no flags, "
              f"{flags_table[1]}: PyAPI, "
              f"{flags_table[2]}: supports unaligned, "
              f"{flags_table[4]}: no-float-errors")
        print()
        print_table("flags")
if __name__ == '__main__':
    # Dump every table for the full set of built-in type codes.
    print("can cast")
    print_cancast_table(np.typecodes['All'])
    print()
    print("In these tables, ValueError is '!', OverflowError is '@', TypeError is '#'")
    print()
    # Coercion via np.add with all four scalar/array operand combinations.
    print("scalar + scalar")
    print_coercion_table(np.typecodes['All'], 0, 0, False)
    print()
    print("scalar + neg scalar")
    print_coercion_table(np.typecodes['All'], 0, -1, False)
    print()
    print("array + scalar")
    print_coercion_table(np.typecodes['All'], 0, 0, True)
    print()
    print("array + neg scalar")
    print_coercion_table(np.typecodes['All'], 0, -1, True)
    print()
    # Same table computed via np.promote_types instead of an actual add.
    print("promote_types")
    print_coercion_table(np.typecodes['All'], 0, 0, False, True)
    print("New casting type promotion:")
    print_new_cast_table(can_cast=True, legacy=True, flags=True)

View File

@@ -0,0 +1,27 @@
from collections.abc import Iterable
from typing import ClassVar, Generic, Self
from typing_extensions import TypeVar
import numpy as np
# Covariant value type held by GenericObject; defaults to ``object``.
_VT_co = TypeVar("_VT_co", default=object, covariant=True)

# undocumented
class GenericObject(Generic[_VT_co]):
    # Class-level object dtype used by the coercion tables.
    dtype: ClassVar[np.dtype[np.object_]] = ...
    v: _VT_co
    def __init__(self, /, v: _VT_co) -> None: ...
    # Addition from either side returns the instance itself.
    def __add__(self, other: object, /) -> Self: ...
    def __radd__(self, other: object, /) -> Self: ...

# Table printers; all write to stdout and return nothing.
def print_cancast_table(ntypes: Iterable[str]) -> None: ...
def print_coercion_table(
    ntypes: Iterable[str],
    inputfirstvalue: int,
    inputsecondvalue: int,
    firstarray: bool,
    use_promote_types: bool = False,
) -> None: ...
def print_new_cast_table(*, can_cast: bool = True, legacy: bool = False, flags: bool = False) -> None: ...

File diff suppressed because it is too large Load Diff