增加环绕侦察场景适配
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
from unittest import TestCase
|
||||
|
||||
from . import overrides
|
||||
from . import _private as _private, overrides
|
||||
from ._private import extbuild as extbuild
|
||||
from ._private.utils import (
|
||||
BLAS_SUPPORTS_FPE,
|
||||
HAS_LAPACK64,
|
||||
HAS_REFCOUNT,
|
||||
IS_64BIT,
|
||||
IS_EDITABLE,
|
||||
IS_INSTALLED,
|
||||
IS_MUSL,
|
||||
@@ -51,8 +54,10 @@ from ._private.utils import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BLAS_SUPPORTS_FPE",
|
||||
"HAS_LAPACK64",
|
||||
"HAS_REFCOUNT",
|
||||
"IS_64BIT",
|
||||
"IS_EDITABLE",
|
||||
"IS_INSTALLED",
|
||||
"IS_MUSL",
|
||||
|
||||
@@ -25,7 +25,7 @@ from warnings import WarningMessage
|
||||
|
||||
import numpy as np
|
||||
import numpy.linalg._umath_linalg
|
||||
from numpy import isfinite, isinf, isnan
|
||||
from numpy import isfinite, isnan
|
||||
from numpy._core import arange, array, array_repr, empty, float32, intp, isnat, ndarray
|
||||
|
||||
__all__ = [
|
||||
@@ -90,14 +90,7 @@ IS_WASM = platform.machine() in ["wasm32", "wasm64"]
|
||||
IS_PYPY = sys.implementation.name == 'pypy'
|
||||
IS_PYSTON = hasattr(sys, "pyston_version_info")
|
||||
HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None and not IS_PYSTON
|
||||
BLAS_SUPPORTS_FPE = True
|
||||
if platform.system() == 'Darwin' or platform.machine() == 'arm64':
|
||||
try:
|
||||
blas = np.__config__.CONFIG['Build Dependencies']['blas']
|
||||
if blas['name'] == 'accelerate':
|
||||
BLAS_SUPPORTS_FPE = False
|
||||
except KeyError:
|
||||
pass
|
||||
BLAS_SUPPORTS_FPE = np._core._multiarray_umath._blas_supports_fpe(None)
|
||||
|
||||
HAS_LAPACK64 = numpy.linalg._umath_linalg._ilp64
|
||||
|
||||
@@ -303,9 +296,10 @@ def assert_equal(actual, desired, err_msg='', verbose=True, *, strict=False):
|
||||
|
||||
Notes
|
||||
-----
|
||||
By default, when one of `actual` and `desired` is a scalar and the other is
|
||||
an array, the function checks that each element of the array is equal to
|
||||
the scalar. This behaviour can be disabled by setting ``strict==True``.
|
||||
When one of `actual` and `desired` is a scalar and the other is array_like, the
|
||||
function checks that each element of the array_like is equal to the scalar.
|
||||
Note that empty arrays are therefore considered equal to scalars.
|
||||
This behaviour can be disabled by setting ``strict==True``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -363,7 +357,7 @@ def assert_equal(actual, desired, err_msg='', verbose=True, *, strict=False):
|
||||
if not isinstance(actual, dict):
|
||||
raise AssertionError(repr(type(actual)))
|
||||
assert_equal(len(actual), len(desired), err_msg, verbose)
|
||||
for k, i in desired.items():
|
||||
for k in desired:
|
||||
if k not in actual:
|
||||
raise AssertionError(repr(k))
|
||||
assert_equal(actual[k], desired[k], f'key={k!r}\n{err_msg}',
|
||||
@@ -573,6 +567,8 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
|
||||
Arrays are not almost equal to 9 decimals
|
||||
<BLANKLINE>
|
||||
Mismatched elements: 1 / 2 (50%)
|
||||
Mismatch at index:
|
||||
[1]: 2.3333333333333 (ACTUAL), 2.33333334 (DESIRED)
|
||||
Max absolute difference among violations: 6.66669964e-09
|
||||
Max relative difference among violations: 2.85715698e-09
|
||||
ACTUAL: array([1. , 2.333333333])
|
||||
@@ -755,6 +751,24 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
|
||||
def isvstring(x):
|
||||
return x.dtype.char == "T"
|
||||
|
||||
def robust_any_difference(x, y):
|
||||
# We include work-arounds here to handle three types of slightly
|
||||
# pathological ndarray subclasses:
|
||||
# (1) all() on fully masked arrays returns np.ma.masked, so we use != True
|
||||
# (np.ma.masked != True evaluates as np.ma.masked, which is falsy).
|
||||
# (2) __eq__ on some ndarray subclasses returns Python booleans
|
||||
# instead of element-wise comparisons, so we cast to np.bool() in
|
||||
# that case (or in case __eq__ returns some other value with no
|
||||
# all() method).
|
||||
# (3) subclasses with bare-bones __array_function__ implementations may
|
||||
# not implement np.all(), so favor using the .all() method
|
||||
# We are not committed to supporting cases (2) and (3), but it's nice to
|
||||
# support them if possible.
|
||||
result = x == y
|
||||
if not hasattr(result, "all") or not callable(result.all):
|
||||
result = np.bool(result)
|
||||
return result.all() != True
|
||||
|
||||
def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
|
||||
"""Handling nan/inf.
|
||||
|
||||
@@ -766,18 +780,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
|
||||
|
||||
x_id = func(x)
|
||||
y_id = func(y)
|
||||
# We include work-arounds here to handle three types of slightly
|
||||
# pathological ndarray subclasses:
|
||||
# (1) all() on `masked` array scalars can return masked arrays, so we
|
||||
# use != True
|
||||
# (2) __eq__ on some ndarray subclasses returns Python booleans
|
||||
# instead of element-wise comparisons, so we cast to np.bool() and
|
||||
# use isinstance(..., bool) checks
|
||||
# (3) subclasses with bare-bones __array_function__ implementations may
|
||||
# not implement np.all(), so favor using the .all() method
|
||||
# We are not committed to supporting such subclasses, but it's nice to
|
||||
# support them if possible.
|
||||
if np.bool(x_id == y_id).all() != True:
|
||||
if robust_any_difference(x_id, y_id):
|
||||
msg = build_err_msg(
|
||||
[x, y],
|
||||
err_msg + '\n%s location mismatch:'
|
||||
@@ -787,6 +790,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
|
||||
raise AssertionError(msg)
|
||||
# If there is a scalar, then here we know the array has the same
|
||||
# flag as it everywhere, so we should return the scalar flag.
|
||||
# np.ma.masked is also handled and converted to np.False_ (even if the other
|
||||
# array has nans/infs etc.; that's OK given the handling later of fully-masked
|
||||
# results).
|
||||
if isinstance(x_id, bool) or x_id.ndim == 0:
|
||||
return np.bool(x_id)
|
||||
elif isinstance(y_id, bool) or y_id.ndim == 0:
|
||||
@@ -794,6 +800,29 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
|
||||
else:
|
||||
return y_id
|
||||
|
||||
def assert_same_inf_values(x, y, infs_mask):
|
||||
"""
|
||||
Verify all inf values match in the two arrays
|
||||
"""
|
||||
__tracebackhide__ = True # Hide traceback for py.test
|
||||
|
||||
if not infs_mask.any():
|
||||
return
|
||||
if x.ndim > 0 and y.ndim > 0:
|
||||
x = x[infs_mask]
|
||||
y = y[infs_mask]
|
||||
else:
|
||||
assert infs_mask.all()
|
||||
|
||||
if robust_any_difference(x, y):
|
||||
msg = build_err_msg(
|
||||
[x, y],
|
||||
err_msg + '\ninf values mismatch:',
|
||||
verbose=verbose, header=header,
|
||||
names=names,
|
||||
precision=precision)
|
||||
raise AssertionError(msg)
|
||||
|
||||
try:
|
||||
if strict:
|
||||
cond = x.shape == y.shape and x.dtype == y.dtype
|
||||
@@ -818,12 +847,15 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
|
||||
flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
|
||||
|
||||
if equal_inf:
|
||||
flagged |= func_assert_same_pos(x, y,
|
||||
func=lambda xy: xy == +inf,
|
||||
hasval='+inf')
|
||||
flagged |= func_assert_same_pos(x, y,
|
||||
func=lambda xy: xy == -inf,
|
||||
hasval='-inf')
|
||||
# If equal_nan=True, skip comparing nans below for equality if they are
|
||||
# also infs (e.g. inf+nanj) since that would always fail.
|
||||
isinf_func = lambda xy: np.logical_and(np.isinf(xy), np.invert(flagged))
|
||||
infs_mask = func_assert_same_pos(
|
||||
x, y,
|
||||
func=isinf_func,
|
||||
hasval='inf')
|
||||
assert_same_inf_values(x, y, infs_mask)
|
||||
flagged |= infs_mask
|
||||
|
||||
elif istime(x) and istime(y):
|
||||
# If one is datetime64 and the other timedelta64 there is no point
|
||||
@@ -874,6 +906,31 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
|
||||
percent_mismatch = 100 * n_mismatch / n_elements
|
||||
remarks = [f'Mismatched elements: {n_mismatch} / {n_elements} '
|
||||
f'({percent_mismatch:.3g}%)']
|
||||
if invalids.ndim != 0:
|
||||
if flagged.ndim > 0:
|
||||
positions = np.argwhere(np.asarray(~flagged))[invalids]
|
||||
else:
|
||||
positions = np.argwhere(np.asarray(invalids))
|
||||
s = "\n".join(
|
||||
[
|
||||
f" {p.tolist()}: {ox if ox.ndim == 0 else ox[tuple(p)]} "
|
||||
f"({names[0]}), {oy if oy.ndim == 0 else oy[tuple(p)]} "
|
||||
f"({names[1]})"
|
||||
for p in positions[:5]
|
||||
]
|
||||
)
|
||||
if len(positions) == 1:
|
||||
remarks.append(
|
||||
f"Mismatch at index:\n{s}"
|
||||
)
|
||||
elif len(positions) <= 5:
|
||||
remarks.append(
|
||||
f"Mismatch at indices:\n{s}"
|
||||
)
|
||||
else:
|
||||
remarks.append(
|
||||
f"First 5 mismatches are at indices:\n{s}"
|
||||
)
|
||||
|
||||
with errstate(all='ignore'):
|
||||
# ignore errors for non-numeric types
|
||||
@@ -990,9 +1047,10 @@ def assert_array_equal(actual, desired, err_msg='', verbose=True, *,
|
||||
|
||||
Notes
|
||||
-----
|
||||
When one of `actual` and `desired` is a scalar and the other is array_like,
|
||||
the function checks that each element of the array_like object is equal to
|
||||
the scalar. This behaviour can be disabled with the `strict` parameter.
|
||||
When one of `actual` and `desired` is a scalar and the other is array_like, the
|
||||
function checks that each element of the array_like is equal to the scalar.
|
||||
Note that empty arrays are therefore considered equal to scalars.
|
||||
This behaviour can be disabled by setting ``strict==True``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -1011,6 +1069,8 @@ def assert_array_equal(actual, desired, err_msg='', verbose=True, *,
|
||||
Arrays are not equal
|
||||
<BLANKLINE>
|
||||
Mismatched elements: 1 / 3 (33.3%)
|
||||
Mismatch at index:
|
||||
[1]: 3.141592653589793 (ACTUAL), 3.1415926535897927 (DESIRED)
|
||||
Max absolute difference among violations: 4.4408921e-16
|
||||
Max relative difference among violations: 1.41357986e-16
|
||||
ACTUAL: array([1. , 3.141593, nan])
|
||||
@@ -1124,6 +1184,8 @@ def assert_array_almost_equal(actual, desired, decimal=6, err_msg='',
|
||||
Arrays are not almost equal to 5 decimals
|
||||
<BLANKLINE>
|
||||
Mismatched elements: 1 / 3 (33.3%)
|
||||
Mismatch at index:
|
||||
[1]: 2.33333 (ACTUAL), 2.33339 (DESIRED)
|
||||
Max absolute difference among violations: 6.e-05
|
||||
Max relative difference among violations: 2.57136612e-05
|
||||
ACTUAL: array([1. , 2.33333, nan])
|
||||
@@ -1143,24 +1205,9 @@ def assert_array_almost_equal(actual, desired, decimal=6, err_msg='',
|
||||
"""
|
||||
__tracebackhide__ = True # Hide traceback for py.test
|
||||
from numpy._core import number, result_type
|
||||
from numpy._core.fromnumeric import any as npany
|
||||
from numpy._core.numerictypes import issubdtype
|
||||
|
||||
def compare(x, y):
|
||||
try:
|
||||
if npany(isinf(x)) or npany(isinf(y)):
|
||||
xinfid = isinf(x)
|
||||
yinfid = isinf(y)
|
||||
if not (xinfid == yinfid).all():
|
||||
return False
|
||||
# if one item, x and y is +- inf
|
||||
if x.size == y.size == 1:
|
||||
return x == y
|
||||
x = x[~xinfid]
|
||||
y = y[~yinfid]
|
||||
except (TypeError, NotImplementedError):
|
||||
pass
|
||||
|
||||
# make sure y is an inexact type to avoid abs(MIN_INT); will cause
|
||||
# casting of x later.
|
||||
dtype = result_type(y, 1.)
|
||||
@@ -1245,6 +1292,8 @@ def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False):
|
||||
Arrays are not strictly ordered `x < y`
|
||||
<BLANKLINE>
|
||||
Mismatched elements: 1 / 3 (33.3%)
|
||||
Mismatch at index:
|
||||
[0]: 1.0 (x), 1.0 (y)
|
||||
Max absolute difference among violations: 0.
|
||||
Max relative difference among violations: 0.
|
||||
x: array([ 1., 1., nan])
|
||||
@@ -1623,9 +1672,10 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
|
||||
contrast to the standard usage in numpy, NaNs are compared like numbers,
|
||||
no assertion is raised if both objects have NaNs in the same positions.
|
||||
|
||||
The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note
|
||||
that ``allclose`` has different default values). It compares the difference
|
||||
between `actual` and `desired` to ``atol + rtol * abs(desired)``.
|
||||
The test is equivalent to ``allclose(actual, desired, rtol, atol)``,
|
||||
except that it is stricter: it doesn't broadcast its operands, and has
|
||||
tighter default tolerance values. It compares the difference between
|
||||
`actual` and `desired` to ``atol + rtol * abs(desired)``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -1661,10 +1711,10 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
|
||||
|
||||
Notes
|
||||
-----
|
||||
When one of `actual` and `desired` is a scalar and the other is
|
||||
array_like, the function performs the comparison as if the scalar were
|
||||
broadcasted to the shape of the array.
|
||||
This behaviour can be disabled with the `strict` parameter.
|
||||
When one of `actual` and `desired` is a scalar and the other is array_like, the
|
||||
function performs the comparison as if the scalar were broadcasted to the shape
|
||||
of the array. Note that empty arrays are therefore considered equal to scalars.
|
||||
This behaviour can be disabled by setting ``strict==True``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -1928,7 +1978,7 @@ def integer_repr(x):
|
||||
@contextlib.contextmanager
|
||||
def _assert_warns_context(warning_class, name=None):
|
||||
__tracebackhide__ = True # Hide traceback for py.test
|
||||
with suppress_warnings() as sup:
|
||||
with suppress_warnings(_warn=False) as sup:
|
||||
l = sup.record(warning_class)
|
||||
yield
|
||||
if not len(l) > 0:
|
||||
@@ -1952,6 +2002,11 @@ def assert_warns(warning_class, *args, **kwargs):
|
||||
|
||||
The ability to be used as a context manager is new in NumPy v1.11.0.
|
||||
|
||||
.. deprecated:: 2.4
|
||||
|
||||
This is deprecated. Use `warnings.catch_warnings` or
|
||||
``pytest.warns`` instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
warning_class : class
|
||||
@@ -1979,6 +2034,11 @@ def assert_warns(warning_class, *args, **kwargs):
|
||||
>>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4)
|
||||
>>> assert ret == 16
|
||||
"""
|
||||
warnings.warn(
|
||||
"NumPy warning suppression and assertion utilities are deprecated. "
|
||||
"Use warnings.catch_warnings, warnings.filterwarnings, pytest.warns, "
|
||||
"or pytest.filterwarnings instead. (Deprecated NumPy 2.4)",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
if not args and not kwargs:
|
||||
return _assert_warns_context(warning_class)
|
||||
elif len(args) < 1:
|
||||
@@ -2231,6 +2291,11 @@ class suppress_warnings:
|
||||
tests might need to see the warning. Additionally it allows easier
|
||||
specificity for testing warnings and can be nested.
|
||||
|
||||
.. deprecated:: 2.4
|
||||
|
||||
This is deprecated. Use `warnings.filterwarnings` or
|
||||
``pytest.filterwarnings`` instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
forwarding_rule : str, optional
|
||||
@@ -2291,7 +2356,13 @@ class suppress_warnings:
|
||||
# do something which causes a warning in np.ma.core
|
||||
pass
|
||||
"""
|
||||
def __init__(self, forwarding_rule="always"):
|
||||
def __init__(self, forwarding_rule="always", _warn=True):
|
||||
if _warn:
|
||||
warnings.warn(
|
||||
"NumPy warning suppression and assertion utilities are deprecated. "
|
||||
"Use warnings.catch_warnings, warnings.filterwarnings, pytest.warns, "
|
||||
"or pytest.filterwarnings instead. (Deprecated NumPy 2.4)",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self._entered = False
|
||||
|
||||
# Suppressions are either instance or defined inside one with block:
|
||||
|
||||
@@ -3,6 +3,7 @@ import sys
|
||||
import types
|
||||
import unittest
|
||||
import warnings
|
||||
from _typeshed import ConvertibleToFloat, GenericPath, StrOrBytesPath, StrPath
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from contextlib import _GeneratorContextManager
|
||||
from pathlib import Path
|
||||
@@ -13,6 +14,7 @@ from typing import (
|
||||
ClassVar,
|
||||
Final,
|
||||
Generic,
|
||||
Literal as L,
|
||||
NoReturn,
|
||||
ParamSpec,
|
||||
Self,
|
||||
@@ -22,12 +24,9 @@ from typing import (
|
||||
overload,
|
||||
type_check_only,
|
||||
)
|
||||
from typing import Literal as L
|
||||
from typing_extensions import TypeVar, deprecated
|
||||
from unittest.case import SkipTest
|
||||
|
||||
from _typeshed import ConvertibleToFloat, GenericPath, StrOrBytesPath, StrPath
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import numpy as np
|
||||
from numpy._typing import (
|
||||
ArrayLike,
|
||||
@@ -45,9 +44,13 @@ __all__ = [ # noqa: RUF022
|
||||
"IS_PYPY",
|
||||
"IS_PYSTON",
|
||||
"IS_WASM",
|
||||
"IS_INSTALLED",
|
||||
"IS_64BIT",
|
||||
"HAS_LAPACK64",
|
||||
"HAS_REFCOUNT",
|
||||
"BLAS_SUPPORTS_FPE",
|
||||
"NOGIL_BUILD",
|
||||
"NUMPY_ROOT",
|
||||
"assert_",
|
||||
"assert_array_almost_equal_nulp",
|
||||
"assert_raises_regex",
|
||||
@@ -94,7 +97,6 @@ _Tss = ParamSpec("_Tss")
|
||||
_ET = TypeVar("_ET", bound=BaseException, default=BaseException)
|
||||
_FT = TypeVar("_FT", bound=Callable[..., Any])
|
||||
_W_co = TypeVar("_W_co", bound=_WarnLog | None, default=_WarnLog | None, covariant=True)
|
||||
_T_or_bool = TypeVar("_T_or_bool", default=bool)
|
||||
|
||||
_StrLike: TypeAlias = str | bytes
|
||||
_RegexLike: TypeAlias = _StrLike | Pattern[Any]
|
||||
@@ -131,15 +133,16 @@ IS_MUSL: Final[bool] = ...
|
||||
IS_PYPY: Final[bool] = ...
|
||||
IS_PYSTON: Final[bool] = ...
|
||||
IS_WASM: Final[bool] = ...
|
||||
IS_64BIT: Final[bool] = ...
|
||||
HAS_REFCOUNT: Final[bool] = ...
|
||||
HAS_LAPACK64: Final[bool] = ...
|
||||
BLAS_SUPPORTS_FPE: Final[bool] = ...
|
||||
NOGIL_BUILD: Final[bool] = ...
|
||||
|
||||
class KnownFailureException(Exception): ...
|
||||
class IgnoreException(Exception): ...
|
||||
|
||||
# NOTE: `warnings.catch_warnings` is incorrectly defined as invariant in typeshed
|
||||
class clear_and_catch_warnings(warnings.catch_warnings[_W_co], Generic[_W_co]): # type: ignore[type-var] # pyright: ignore[reportInvalidTypeArguments]
|
||||
class clear_and_catch_warnings(warnings.catch_warnings[_W_co], Generic[_W_co]):
|
||||
class_modules: ClassVar[tuple[types.ModuleType, ...]] = ()
|
||||
modules: Final[set[types.ModuleType]]
|
||||
@overload # record: True
|
||||
@@ -149,6 +152,7 @@ class clear_and_catch_warnings(warnings.catch_warnings[_W_co], Generic[_W_co]):
|
||||
@overload # record; bool
|
||||
def __init__(self, /, record: bool, modules: _ToModules = ()) -> None: ...
|
||||
|
||||
@deprecated("Please use warnings.filterwarnings or pytest.mark.filterwarnings instead")
|
||||
class suppress_warnings:
|
||||
log: Final[_WarnLog]
|
||||
def __init__(self, /, forwarding_rule: L["always", "module", "once", "location"] = "always") -> None: ...
|
||||
@@ -163,9 +167,9 @@ class suppress_warnings:
|
||||
# Contrary to runtime we can't do `os.name` checks while type checking,
|
||||
# only `sys.platform` checks
|
||||
if sys.platform == "win32" or sys.platform == "cygwin":
|
||||
def memusage(processName: str = ..., instance: int = ...) -> int: ...
|
||||
def memusage(processName: str = "python", instance: int = 0) -> int: ...
|
||||
elif sys.platform == "linux":
|
||||
def memusage(_proc_pid_stat: StrOrBytesPath = ...) -> int | None: ...
|
||||
def memusage(_proc_pid_stat: StrOrBytesPath | None = None) -> int | None: ...
|
||||
else:
|
||||
def memusage() -> NoReturn: ...
|
||||
|
||||
@@ -178,10 +182,10 @@ else:
|
||||
def build_err_msg(
|
||||
arrays: Iterable[object],
|
||||
err_msg: object,
|
||||
header: str = ...,
|
||||
verbose: bool = ...,
|
||||
names: Sequence[str] = ...,
|
||||
precision: SupportsIndex | None = ...,
|
||||
header: str = "Items are not equal:",
|
||||
verbose: bool = True,
|
||||
names: Sequence[str] = ("ACTUAL", "DESIRED"), # = ('ACTUAL', 'DESIRED')
|
||||
precision: SupportsIndex | None = 8,
|
||||
) -> str: ...
|
||||
|
||||
#
|
||||
@@ -360,8 +364,10 @@ def assert_array_max_ulp(
|
||||
|
||||
#
|
||||
@overload
|
||||
@deprecated("Please use warnings.catch_warnings or pytest.warns instead")
|
||||
def assert_warns(warning_class: _WarningSpec) -> _GeneratorContextManager[None]: ...
|
||||
@overload
|
||||
@deprecated("Please use warnings.catch_warnings or pytest.warns instead")
|
||||
def assert_warns(warning_class: _WarningSpec, func: Callable[_Tss, _T], *args: _Tss.args, **kwargs: _Tss.kwargs) -> _T: ...
|
||||
|
||||
#
|
||||
@@ -453,7 +459,7 @@ def temppath(
|
||||
) -> _GeneratorContextManager[AnyStr]: ...
|
||||
|
||||
#
|
||||
def check_support_sve(__cache: list[_T_or_bool] = []) -> _T_or_bool: ... # noqa: PYI063
|
||||
def check_support_sve(__cache: list[bool] = ..., /) -> bool: ... # stubdefaulter: ignore[missing-default]
|
||||
|
||||
#
|
||||
def decorate_methods(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Callable, Hashable
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import ClassVar, Generic, Self
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -36,6 +36,9 @@ from numpy.testing import (
|
||||
|
||||
class _GenericTest:
|
||||
|
||||
def _assert_func(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def _test_equal(self, a, b):
|
||||
self._assert_func(a, b)
|
||||
|
||||
@@ -82,8 +85,8 @@ class _GenericTest:
|
||||
|
||||
class TestArrayEqual(_GenericTest):
|
||||
|
||||
def setup_method(self):
|
||||
self._assert_func = assert_array_equal
|
||||
def _assert_func(self, *args, **kwargs):
|
||||
assert_array_equal(*args, **kwargs)
|
||||
|
||||
def test_generic_rank1(self):
|
||||
"""Test rank 1 array for all dtypes."""
|
||||
@@ -197,6 +200,40 @@ class TestArrayEqual(_GenericTest):
|
||||
self._test_equal(a, b)
|
||||
self._test_equal(b, a)
|
||||
|
||||
# Also provides test cases for gh-11121
|
||||
def test_masked_scalar(self):
|
||||
# Test masked scalar vs. plain/masked scalar
|
||||
for a_val, b_val, b_masked in itertools.product(
|
||||
[3., np.nan, np.inf],
|
||||
[3., 4., np.nan, np.inf, -np.inf],
|
||||
[False, True],
|
||||
):
|
||||
a = np.ma.MaskedArray(a_val, mask=True)
|
||||
b = np.ma.MaskedArray(b_val, mask=True) if b_masked else np.array(b_val)
|
||||
self._test_equal(a, b)
|
||||
self._test_equal(b, a)
|
||||
|
||||
# Test masked scalar vs. plain array
|
||||
for a_val, b_val in itertools.product(
|
||||
[3., np.nan, -np.inf],
|
||||
itertools.product([3., 4., np.nan, np.inf, -np.inf], repeat=2),
|
||||
):
|
||||
a = np.ma.MaskedArray(a_val, mask=True)
|
||||
b = np.array(b_val)
|
||||
self._test_equal(a, b)
|
||||
self._test_equal(b, a)
|
||||
|
||||
# Test masked scalar vs. masked array
|
||||
for a_val, b_val, b_mask in itertools.product(
|
||||
[3., np.nan, np.inf],
|
||||
itertools.product([3., 4., np.nan, np.inf, -np.inf], repeat=2),
|
||||
itertools.product([False, True], repeat=2),
|
||||
):
|
||||
a = np.ma.MaskedArray(a_val, mask=True)
|
||||
b = np.ma.MaskedArray(b_val, mask=b_mask)
|
||||
self._test_equal(a, b)
|
||||
self._test_equal(b, a)
|
||||
|
||||
def test_subclass_that_overrides_eq(self):
|
||||
# While we cannot guarantee testing functions will always work for
|
||||
# subclasses, the tests should ideally rely only on subclasses having
|
||||
@@ -265,6 +302,8 @@ class TestArrayEqual(_GenericTest):
|
||||
b = np.array([34986, 545676, 439655, 0])
|
||||
|
||||
expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [3]: 563766 (ACTUAL), 0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 563766\n'
|
||||
'Max relative difference among violations: inf')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -272,6 +311,9 @@ class TestArrayEqual(_GenericTest):
|
||||
|
||||
a = np.array([34986, 545676, 439655.2, 563766])
|
||||
expected_msg = ('Mismatched elements: 2 / 4 (50%)\n'
|
||||
'Mismatch at indices:\n'
|
||||
' [2]: 439655.2 (ACTUAL), 439655 (DESIRED)\n'
|
||||
' [3]: 563766.0 (ACTUAL), 0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: '
|
||||
'563766.\n'
|
||||
'Max relative difference among violations: '
|
||||
@@ -350,8 +392,8 @@ class TestBuildErrorMessage:
|
||||
|
||||
class TestEqual(TestArrayEqual):
|
||||
|
||||
def setup_method(self):
|
||||
self._assert_func = assert_equal
|
||||
def _assert_func(self, *args, **kwargs):
|
||||
assert_equal(*args, **kwargs)
|
||||
|
||||
def test_nan_items(self):
|
||||
self._assert_func(np.nan, np.nan)
|
||||
@@ -445,8 +487,8 @@ class TestEqual(TestArrayEqual):
|
||||
|
||||
class TestArrayAlmostEqual(_GenericTest):
|
||||
|
||||
def setup_method(self):
|
||||
self._assert_func = assert_array_almost_equal
|
||||
def _assert_func(self, *args, **kwargs):
|
||||
assert_array_almost_equal(*args, **kwargs)
|
||||
|
||||
def test_closeness(self):
|
||||
# Note that in the course of time we ended up with
|
||||
@@ -466,6 +508,8 @@ class TestArrayAlmostEqual(_GenericTest):
|
||||
self._assert_func([1.499999], [0.0], decimal=0)
|
||||
|
||||
expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [0]: 1.5 (ACTUAL), 0.0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 1.5\n'
|
||||
'Max relative difference among violations: inf')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -474,12 +518,16 @@ class TestArrayAlmostEqual(_GenericTest):
|
||||
a = [1.4999999, 0.00003]
|
||||
b = [1.49999991, 0]
|
||||
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [1]: 3e-05 (ACTUAL), 0.0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 3.e-05\n'
|
||||
'Max relative difference among violations: inf')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
self._assert_func(a, b, decimal=7)
|
||||
|
||||
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [1]: 0.0 (ACTUAL), 3e-05 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 3.e-05\n'
|
||||
'Max relative difference among violations: 1.')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -493,6 +541,8 @@ class TestArrayAlmostEqual(_GenericTest):
|
||||
self._assert_func(x, y, decimal=4)
|
||||
|
||||
expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [0]: 1234.2222 (ACTUAL), 1234.2223 (DESIRED)\n'
|
||||
'Max absolute difference among violations: '
|
||||
'1.e-04\n'
|
||||
'Max relative difference among violations: '
|
||||
@@ -504,6 +554,9 @@ class TestArrayAlmostEqual(_GenericTest):
|
||||
a = [5498.42354, 849.54345, 0.00]
|
||||
b = 5498.42354
|
||||
expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
|
||||
'Mismatch at indices:\n'
|
||||
' [1]: 849.54345 (ACTUAL), 5498.42354 (DESIRED)\n'
|
||||
' [2]: 0.0 (ACTUAL), 5498.42354 (DESIRED)\n'
|
||||
'Max absolute difference among violations: '
|
||||
'5498.42354\n'
|
||||
'Max relative difference among violations: 1.')
|
||||
@@ -511,6 +564,9 @@ class TestArrayAlmostEqual(_GenericTest):
|
||||
self._assert_func(a, b, decimal=9)
|
||||
|
||||
expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
|
||||
'Mismatch at indices:\n'
|
||||
' [1]: 5498.42354 (ACTUAL), 849.54345 (DESIRED)\n'
|
||||
' [2]: 5498.42354 (ACTUAL), 0.0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: '
|
||||
'5498.42354\n'
|
||||
'Max relative difference among violations: 5.4722099')
|
||||
@@ -519,6 +575,8 @@ class TestArrayAlmostEqual(_GenericTest):
|
||||
|
||||
a = [5498.42354, 0.00]
|
||||
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [1]: 5498.42354 (ACTUAL), 0.0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: '
|
||||
'5498.42354\n'
|
||||
'Max relative difference among violations: inf')
|
||||
@@ -527,6 +585,8 @@ class TestArrayAlmostEqual(_GenericTest):
|
||||
|
||||
b = 0
|
||||
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [0]: 5498.42354 (ACTUAL), 0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: '
|
||||
'5498.42354\n'
|
||||
'Max relative difference among violations: inf')
|
||||
@@ -555,6 +615,18 @@ class TestArrayAlmostEqual(_GenericTest):
|
||||
assert_raises(AssertionError,
|
||||
lambda: self._assert_func(a, b))
|
||||
|
||||
def test_complex_inf(self):
|
||||
a = np.array([np.inf + 1.j, 2. + 1.j, 3. + 1.j])
|
||||
b = a.copy()
|
||||
self._assert_func(a, b)
|
||||
b[1] = 3. + 1.j
|
||||
expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [1]: (2+1j) (ACTUAL), (3+1j) (DESIRED)\n'
|
||||
'Max absolute difference among violations: 1.\n')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
self._assert_func(a, b)
|
||||
|
||||
def test_subclass(self):
|
||||
a = np.array([[1., 2.], [3., 4.]])
|
||||
b = np.ma.masked_array([[1., 2.], [0., 4.]],
|
||||
@@ -603,6 +675,8 @@ class TestArrayAlmostEqual(_GenericTest):
|
||||
all(z)
|
||||
b = np.array([1., 202]).view(MyArray)
|
||||
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [1]: 2.0 (ACTUAL), 202.0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 200.\n'
|
||||
'Max relative difference among violations: 0.99009')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -629,8 +703,8 @@ class TestArrayAlmostEqual(_GenericTest):
|
||||
|
||||
class TestAlmostEqual(_GenericTest):
|
||||
|
||||
def setup_method(self):
|
||||
self._assert_func = assert_almost_equal
|
||||
def _assert_func(self, *args, **kwargs):
|
||||
assert_almost_equal(*args, **kwargs)
|
||||
|
||||
def test_closeness(self):
|
||||
# Note that in the course of time we ended up with
|
||||
@@ -693,6 +767,10 @@ class TestAlmostEqual(_GenericTest):
|
||||
|
||||
# Test with a different amount of decimal digits
|
||||
expected_msg = ('Mismatched elements: 3 / 3 (100%)\n'
|
||||
'Mismatch at indices:\n'
|
||||
' [0]: 1.00000000001 (ACTUAL), 1.00000000002 (DESIRED)\n'
|
||||
' [1]: 2.00000000002 (ACTUAL), 2.00000000003 (DESIRED)\n'
|
||||
' [2]: 3.00003 (ACTUAL), 3.00004 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 1.e-05\n'
|
||||
'Max relative difference among violations: '
|
||||
'3.33328889e-06\n'
|
||||
@@ -708,6 +786,8 @@ class TestAlmostEqual(_GenericTest):
|
||||
# differs. Note that we only check for the formatting of the arrays
|
||||
# themselves.
|
||||
expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [2]: 3.00003 (ACTUAL), 3.00004 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 1.e-05\n'
|
||||
'Max relative difference among violations: '
|
||||
'3.33328889e-06\n'
|
||||
@@ -720,6 +800,8 @@ class TestAlmostEqual(_GenericTest):
|
||||
x = np.array([np.inf, 0])
|
||||
y = np.array([np.inf, 1])
|
||||
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [1]: 0.0 (ACTUAL), 1.0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 1.\n'
|
||||
'Max relative difference among violations: 1.\n'
|
||||
' ACTUAL: array([inf, 0.])\n'
|
||||
@@ -731,6 +813,9 @@ class TestAlmostEqual(_GenericTest):
|
||||
x = np.array([1, 2])
|
||||
y = np.array([0, 0])
|
||||
expected_msg = ('Mismatched elements: 2 / 2 (100%)\n'
|
||||
'Mismatch at indices:\n'
|
||||
' [0]: 1 (ACTUAL), 0 (DESIRED)\n'
|
||||
' [1]: 2 (ACTUAL), 0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 2\n'
|
||||
'Max relative difference among violations: inf')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -742,6 +827,12 @@ class TestAlmostEqual(_GenericTest):
|
||||
x = 2
|
||||
y = np.ones(20)
|
||||
expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
|
||||
'First 5 mismatches are at indices:\n'
|
||||
' [0]: 2 (ACTUAL), 1.0 (DESIRED)\n'
|
||||
' [1]: 2 (ACTUAL), 1.0 (DESIRED)\n'
|
||||
' [2]: 2 (ACTUAL), 1.0 (DESIRED)\n'
|
||||
' [3]: 2 (ACTUAL), 1.0 (DESIRED)\n'
|
||||
' [4]: 2 (ACTUAL), 1.0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 1.\n'
|
||||
'Max relative difference among violations: 1.')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -750,6 +841,12 @@ class TestAlmostEqual(_GenericTest):
|
||||
y = 2
|
||||
x = np.ones(20)
|
||||
expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
|
||||
'First 5 mismatches are at indices:\n'
|
||||
' [0]: 1.0 (ACTUAL), 2 (DESIRED)\n'
|
||||
' [1]: 1.0 (ACTUAL), 2 (DESIRED)\n'
|
||||
' [2]: 1.0 (ACTUAL), 2 (DESIRED)\n'
|
||||
' [3]: 1.0 (ACTUAL), 2 (DESIRED)\n'
|
||||
' [4]: 1.0 (ACTUAL), 2 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 1.\n'
|
||||
'Max relative difference among violations: 0.5')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -776,8 +873,8 @@ class TestAlmostEqual(_GenericTest):
|
||||
|
||||
class TestApproxEqual:
|
||||
|
||||
def setup_method(self):
|
||||
self._assert_func = assert_approx_equal
|
||||
def _assert_func(self, *args, **kwargs):
|
||||
assert_approx_equal(*args, **kwargs)
|
||||
|
||||
def test_simple_0d_arrays(self):
|
||||
x = np.array(1234.22)
|
||||
@@ -819,8 +916,8 @@ class TestApproxEqual:
|
||||
|
||||
class TestArrayAssertLess:
|
||||
|
||||
def setup_method(self):
|
||||
self._assert_func = assert_array_less
|
||||
def _assert_func(self, *args, **kwargs):
|
||||
assert_array_less(*args, **kwargs)
|
||||
|
||||
def test_simple_arrays(self):
|
||||
x = np.array([1.1, 2.2])
|
||||
@@ -838,6 +935,9 @@ class TestArrayAssertLess:
|
||||
b = np.array([2, 4, 6, 8])
|
||||
|
||||
expected_msg = ('Mismatched elements: 2 / 4 (50%)\n'
|
||||
'Mismatch at indices:\n'
|
||||
' [2]: 6 (x), 6 (y)\n'
|
||||
' [3]: 20 (x), 8 (y)\n'
|
||||
'Max absolute difference among violations: 12\n'
|
||||
'Max relative difference among violations: 1.5')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -849,6 +949,11 @@ class TestArrayAssertLess:
|
||||
|
||||
self._assert_func(x, y)
|
||||
expected_msg = ('Mismatched elements: 4 / 4 (100%)\n'
|
||||
'Mismatch at indices:\n'
|
||||
' [0, 0]: 1.2 (x), 1.1 (y)\n'
|
||||
' [0, 1]: 2.3 (x), 2.2 (y)\n'
|
||||
' [1, 0]: 3.4 (x), 3.3 (y)\n'
|
||||
' [1, 1]: 4.5 (x), 4.4 (y)\n'
|
||||
'Max absolute difference among violations: 0.1\n'
|
||||
'Max relative difference among violations: 0.09090909')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -867,6 +972,8 @@ class TestArrayAssertLess:
|
||||
|
||||
y[0, 0, 0] = 0
|
||||
expected_msg = ('Mismatched elements: 1 / 8 (12.5%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [0, 0, 0]: 1.0 (x), 0.0 (y)\n'
|
||||
'Max absolute difference among violations: 1.\n'
|
||||
'Max relative difference among violations: inf')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -910,12 +1017,20 @@ class TestArrayAssertLess:
|
||||
y = 999090.54
|
||||
|
||||
expected_msg = ('Mismatched elements: 1 / 12 (8.33%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [1, 1]: 999090.54 (x), 999090.54 (y)\n'
|
||||
'Max absolute difference among violations: 0.\n'
|
||||
'Max relative difference among violations: 0.')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
self._assert_func(x, y)
|
||||
|
||||
expected_msg = ('Mismatched elements: 12 / 12 (100%)\n'
|
||||
'First 5 mismatches are at indices:\n'
|
||||
' [0, 0]: 999090.54 (x), 3.4536 (y)\n'
|
||||
' [0, 1]: 999090.54 (x), 2390.5436 (y)\n'
|
||||
' [0, 2]: 999090.54 (x), 435.54657 (y)\n'
|
||||
' [0, 3]: 999090.54 (x), 324525.4535 (y)\n'
|
||||
' [1, 0]: 999090.54 (x), 5449.54 (y)\n'
|
||||
'Max absolute difference among violations: '
|
||||
'999087.0864\n'
|
||||
'Max relative difference among violations: '
|
||||
@@ -928,12 +1043,17 @@ class TestArrayAssertLess:
|
||||
y = np.array(87654.)
|
||||
|
||||
expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [0]: 546456.0 (x), 87654.0 (y)\n'
|
||||
'Max absolute difference among violations: 458802.\n'
|
||||
'Max relative difference among violations: 5.23423917')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
self._assert_func(x, y)
|
||||
|
||||
expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
|
||||
'Mismatch at indices:\n'
|
||||
' [1]: 87654.0 (x), 0.0 (y)\n'
|
||||
' [2]: 87654.0 (x), 15.455 (y)\n'
|
||||
'Max absolute difference among violations: 87654.\n'
|
||||
'Max relative difference among violations: '
|
||||
'5670.5626011')
|
||||
@@ -943,12 +1063,18 @@ class TestArrayAssertLess:
|
||||
y = 0
|
||||
|
||||
expected_msg = ('Mismatched elements: 3 / 3 (100%)\n'
|
||||
'Mismatch at indices:\n'
|
||||
' [0]: 546456.0 (x), 0 (y)\n'
|
||||
' [1]: 0.0 (x), 0 (y)\n'
|
||||
' [2]: 15.455 (x), 0 (y)\n'
|
||||
'Max absolute difference among violations: 546456.\n'
|
||||
'Max relative difference among violations: inf')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
self._assert_func(x, y)
|
||||
|
||||
expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [1]: 0 (x), 0.0 (y)\n'
|
||||
'Max absolute difference among violations: 0.\n'
|
||||
'Max relative difference among violations: inf')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -1017,7 +1143,10 @@ class TestArrayAssertLess:
|
||||
with pytest.raises(AssertionError):
|
||||
self._assert_func(x, y.astype(np.float32), strict=True)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings(
|
||||
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
|
||||
".*:DeprecationWarning")
|
||||
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
|
||||
class TestWarns:
|
||||
|
||||
def test_warn(self):
|
||||
@@ -1134,12 +1263,16 @@ class TestAssertAllclose:
|
||||
b = np.array([x, y, x, x])
|
||||
c = np.array([x, y, x, z])
|
||||
expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [3]: 0.001 (ACTUAL), 0.0 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 0.001\n'
|
||||
'Max relative difference among violations: inf')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
assert_allclose(b, c)
|
||||
|
||||
expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [3]: 0.0 (ACTUAL), 0.001 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 0.001\n'
|
||||
'Max relative difference among violations: 1.')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -1155,6 +1288,8 @@ class TestAssertAllclose:
|
||||
b = np.array([1, 1, 1, 2])
|
||||
|
||||
expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
|
||||
'Mismatch at index:\n'
|
||||
' [3]: 1 (ACTUAL), 2 (DESIRED)\n'
|
||||
'Max absolute difference among violations: 1\n'
|
||||
'Max relative difference among violations: 0.5')
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
@@ -1166,11 +1301,21 @@ class TestAssertAllclose:
|
||||
# Should not raise:
|
||||
assert_allclose(a, b, equal_nan=True)
|
||||
|
||||
a = np.array([complex(np.nan, np.inf)])
|
||||
b = np.array([complex(np.nan, np.inf)])
|
||||
assert_allclose(a, b, equal_nan=True)
|
||||
b = np.array([complex(np.nan, -np.inf)])
|
||||
assert_allclose(a, b, equal_nan=True)
|
||||
|
||||
def test_not_equal_nan(self):
|
||||
a = np.array([np.nan])
|
||||
b = np.array([np.nan])
|
||||
assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False)
|
||||
|
||||
a = np.array([complex(np.nan, np.inf)])
|
||||
b = np.array([complex(np.nan, np.inf)])
|
||||
assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False)
|
||||
|
||||
def test_equal_nan_default(self):
|
||||
# Make sure equal_nan default behavior remains unchanged. (All
|
||||
# of these functions use assert_array_compare under the hood.)
|
||||
@@ -1219,6 +1364,33 @@ class TestAssertAllclose:
|
||||
with pytest.raises(AssertionError):
|
||||
assert_allclose(x, x.astype(np.float32), strict=True)
|
||||
|
||||
def test_infs(self):
|
||||
a = np.array([np.inf])
|
||||
b = np.array([np.inf])
|
||||
assert_allclose(a, b)
|
||||
|
||||
b = np.array([3.])
|
||||
expected_msg = 'inf location mismatch:'
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
assert_allclose(a, b)
|
||||
|
||||
b = np.array([-np.inf])
|
||||
expected_msg = 'inf values mismatch:'
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
assert_allclose(a, b)
|
||||
b = np.array([complex(np.inf, 1.)])
|
||||
expected_msg = 'inf values mismatch:'
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
assert_allclose(a, b)
|
||||
|
||||
a = np.array([complex(np.inf, 1.)])
|
||||
b = np.array([complex(np.inf, 1.)])
|
||||
assert_allclose(a, b)
|
||||
|
||||
b = np.array([complex(np.inf, 2.)])
|
||||
expected_msg = 'inf values mismatch:'
|
||||
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
|
||||
assert_allclose(a, b)
|
||||
|
||||
class TestArrayAlmostEqualNulp:
|
||||
|
||||
@@ -1589,6 +1761,7 @@ def _get_fresh_mod():
|
||||
return my_mod
|
||||
|
||||
|
||||
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
|
||||
def test_clear_and_catch_warnings():
|
||||
# Initial state of module, no warnings
|
||||
my_mod = _get_fresh_mod()
|
||||
@@ -1621,6 +1794,10 @@ def test_clear_and_catch_warnings():
|
||||
assert_warn_len_equal(my_mod, 0)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings(
|
||||
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
|
||||
".*:DeprecationWarning")
|
||||
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
|
||||
def test_suppress_warnings_module():
|
||||
# Initial state of module, no warnings
|
||||
my_mod = _get_fresh_mod()
|
||||
@@ -1667,6 +1844,10 @@ def test_suppress_warnings_module():
|
||||
assert_warn_len_equal(my_mod, 0)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings(
|
||||
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
|
||||
".*:DeprecationWarning")
|
||||
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
|
||||
def test_suppress_warnings_type():
|
||||
# Initial state of module, no warnings
|
||||
my_mod = _get_fresh_mod()
|
||||
@@ -1695,6 +1876,12 @@ def test_suppress_warnings_type():
|
||||
assert_warn_len_equal(my_mod, 0)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings(
|
||||
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
|
||||
".*:DeprecationWarning")
|
||||
@pytest.mark.thread_unsafe(
|
||||
reason="uses deprecated thread-unsafe warnings control utilities"
|
||||
)
|
||||
def test_suppress_warnings_decorate_no_record():
|
||||
sup = suppress_warnings()
|
||||
sup.filter(UserWarning)
|
||||
@@ -1710,6 +1897,12 @@ def test_suppress_warnings_decorate_no_record():
|
||||
assert_equal(len(w), 1)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings(
|
||||
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
|
||||
".*:DeprecationWarning")
|
||||
@pytest.mark.thread_unsafe(
|
||||
reason="uses deprecated thread-unsafe warnings control utilities"
|
||||
)
|
||||
def test_suppress_warnings_record():
|
||||
sup = suppress_warnings()
|
||||
log1 = sup.record()
|
||||
@@ -1747,9 +1940,16 @@ def test_suppress_warnings_record():
|
||||
warnings.warn('Some warning')
|
||||
warnings.warn('Some other warning')
|
||||
assert_equal(len(sup2.log), 1)
|
||||
assert_equal(len(sup.log), 1)
|
||||
# includes a DeprecationWarning for suppress_warnings
|
||||
assert_equal(len(sup.log), 2)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings(
|
||||
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
|
||||
".*:DeprecationWarning")
|
||||
@pytest.mark.thread_unsafe(
|
||||
reason="uses deprecated thread-unsafe warnings control utilities"
|
||||
)
|
||||
def test_suppress_warnings_forwarding():
|
||||
def warn_other_module():
|
||||
# Apply along axis is implemented in python; stacklevel=2 means
|
||||
@@ -1765,7 +1965,8 @@ def test_suppress_warnings_forwarding():
|
||||
for i in range(2):
|
||||
warnings.warn("Some warning")
|
||||
|
||||
assert_equal(len(sup.log), 2)
|
||||
# includes a DeprecationWarning for suppress_warnings
|
||||
assert_equal(len(sup.log), 3)
|
||||
|
||||
with suppress_warnings() as sup:
|
||||
sup.record()
|
||||
@@ -1774,7 +1975,8 @@ def test_suppress_warnings_forwarding():
|
||||
warnings.warn("Some warning")
|
||||
warnings.warn("Some warning")
|
||||
|
||||
assert_equal(len(sup.log), 2)
|
||||
# includes a DeprecationWarning for suppress_warnings
|
||||
assert_equal(len(sup.log), 3)
|
||||
|
||||
with suppress_warnings() as sup:
|
||||
sup.record()
|
||||
@@ -1784,7 +1986,8 @@ def test_suppress_warnings_forwarding():
|
||||
warnings.warn("Some warning")
|
||||
warn_other_module()
|
||||
|
||||
assert_equal(len(sup.log), 2)
|
||||
# includes a DeprecationWarning for suppress_warnings
|
||||
assert_equal(len(sup.log), 3)
|
||||
|
||||
with suppress_warnings() as sup:
|
||||
sup.record()
|
||||
@@ -1794,7 +1997,8 @@ def test_suppress_warnings_forwarding():
|
||||
warnings.warn("Some other warning")
|
||||
warn_other_module()
|
||||
|
||||
assert_equal(len(sup.log), 2)
|
||||
# includes a DeprecationWarning for suppress_warnings
|
||||
assert_equal(len(sup.log), 3)
|
||||
|
||||
|
||||
def test_tempdir():
|
||||
@@ -1835,6 +2039,7 @@ class my_cacw(clear_and_catch_warnings):
|
||||
class_modules = (sys.modules[__name__],)
|
||||
|
||||
|
||||
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
|
||||
def test_clear_and_catch_warnings_inherit():
|
||||
# Test can subclass and add default modules
|
||||
my_mod = _get_fresh_mod()
|
||||
@@ -1845,6 +2050,7 @@ def test_clear_and_catch_warnings_inherit():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
|
||||
@pytest.mark.thread_unsafe(reason="garbage collector is global state")
|
||||
class TestAssertNoGcCycles:
|
||||
""" Test assert_no_gc_cycles """
|
||||
|
||||
|
||||
Reference in New Issue
Block a user