增加环绕侦察场景适配

This commit is contained in:
2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -1,9 +1,12 @@
from unittest import TestCase
from . import overrides
from . import _private as _private, overrides
from ._private import extbuild as extbuild
from ._private.utils import (
BLAS_SUPPORTS_FPE,
HAS_LAPACK64,
HAS_REFCOUNT,
IS_64BIT,
IS_EDITABLE,
IS_INSTALLED,
IS_MUSL,
@@ -51,8 +54,10 @@ from ._private.utils import (
)
__all__ = [
"BLAS_SUPPORTS_FPE",
"HAS_LAPACK64",
"HAS_REFCOUNT",
"IS_64BIT",
"IS_EDITABLE",
"IS_INSTALLED",
"IS_MUSL",

View File

@@ -25,7 +25,7 @@ from warnings import WarningMessage
import numpy as np
import numpy.linalg._umath_linalg
from numpy import isfinite, isinf, isnan
from numpy import isfinite, isnan
from numpy._core import arange, array, array_repr, empty, float32, intp, isnat, ndarray
__all__ = [
@@ -90,14 +90,7 @@ IS_WASM = platform.machine() in ["wasm32", "wasm64"]
IS_PYPY = sys.implementation.name == 'pypy'
IS_PYSTON = hasattr(sys, "pyston_version_info")
HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None and not IS_PYSTON
BLAS_SUPPORTS_FPE = True
if platform.system() == 'Darwin' or platform.machine() == 'arm64':
try:
blas = np.__config__.CONFIG['Build Dependencies']['blas']
if blas['name'] == 'accelerate':
BLAS_SUPPORTS_FPE = False
except KeyError:
pass
BLAS_SUPPORTS_FPE = np._core._multiarray_umath._blas_supports_fpe(None)
HAS_LAPACK64 = numpy.linalg._umath_linalg._ilp64
@@ -303,9 +296,10 @@ def assert_equal(actual, desired, err_msg='', verbose=True, *, strict=False):
Notes
-----
By default, when one of `actual` and `desired` is a scalar and the other is
an array, the function checks that each element of the array is equal to
the scalar. This behaviour can be disabled by setting ``strict==True``.
When one of `actual` and `desired` is a scalar and the other is array_like, the
function checks that each element of the array_like is equal to the scalar.
Note that empty arrays are therefore considered equal to scalars.
This behaviour can be disabled by setting ``strict==True``.
Examples
--------
@@ -363,7 +357,7 @@ def assert_equal(actual, desired, err_msg='', verbose=True, *, strict=False):
if not isinstance(actual, dict):
raise AssertionError(repr(type(actual)))
assert_equal(len(actual), len(desired), err_msg, verbose)
for k, i in desired.items():
for k in desired:
if k not in actual:
raise AssertionError(repr(k))
assert_equal(actual[k], desired[k], f'key={k!r}\n{err_msg}',
@@ -573,6 +567,8 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
Arrays are not almost equal to 9 decimals
<BLANKLINE>
Mismatched elements: 1 / 2 (50%)
Mismatch at index:
[1]: 2.3333333333333 (ACTUAL), 2.33333334 (DESIRED)
Max absolute difference among violations: 6.66669964e-09
Max relative difference among violations: 2.85715698e-09
ACTUAL: array([1. , 2.333333333])
@@ -755,6 +751,24 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
def isvstring(x):
return x.dtype.char == "T"
def robust_any_difference(x, y):
# We include work-arounds here to handle three types of slightly
# pathological ndarray subclasses:
# (1) all() on fully masked arrays returns np.ma.masked, so we use != True
# (np.ma.masked != True evaluates as np.ma.masked, which is falsy).
# (2) __eq__ on some ndarray subclasses returns Python booleans
# instead of element-wise comparisons, so we cast to np.bool() in
# that case (or in case __eq__ returns some other value with no
# all() method).
# (3) subclasses with bare-bones __array_function__ implementations may
# not implement np.all(), so favor using the .all() method
# We are not committed to supporting cases (2) and (3), but it's nice to
# support them if possible.
result = x == y
if not hasattr(result, "all") or not callable(result.all):
result = np.bool(result)
return result.all() != True
def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
"""Handling nan/inf.
@@ -766,18 +780,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
x_id = func(x)
y_id = func(y)
# We include work-arounds here to handle three types of slightly
# pathological ndarray subclasses:
# (1) all() on `masked` array scalars can return masked arrays, so we
# use != True
# (2) __eq__ on some ndarray subclasses returns Python booleans
# instead of element-wise comparisons, so we cast to np.bool() and
# use isinstance(..., bool) checks
# (3) subclasses with bare-bones __array_function__ implementations may
# not implement np.all(), so favor using the .all() method
# We are not committed to supporting such subclasses, but it's nice to
# support them if possible.
if np.bool(x_id == y_id).all() != True:
if robust_any_difference(x_id, y_id):
msg = build_err_msg(
[x, y],
err_msg + '\n%s location mismatch:'
@@ -787,6 +790,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
raise AssertionError(msg)
# If there is a scalar, then here we know the array has the same
# flag as it everywhere, so we should return the scalar flag.
# np.ma.masked is also handled and converted to np.False_ (even if the other
# array has nans/infs etc.; that's OK given the handling later of fully-masked
# results).
if isinstance(x_id, bool) or x_id.ndim == 0:
return np.bool(x_id)
elif isinstance(y_id, bool) or y_id.ndim == 0:
@@ -794,6 +800,29 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
else:
return y_id
def assert_same_inf_values(x, y, infs_mask):
"""
Verify all inf values match in the two arrays
"""
__tracebackhide__ = True # Hide traceback for py.test
if not infs_mask.any():
return
if x.ndim > 0 and y.ndim > 0:
x = x[infs_mask]
y = y[infs_mask]
else:
assert infs_mask.all()
if robust_any_difference(x, y):
msg = build_err_msg(
[x, y],
err_msg + '\ninf values mismatch:',
verbose=verbose, header=header,
names=names,
precision=precision)
raise AssertionError(msg)
try:
if strict:
cond = x.shape == y.shape and x.dtype == y.dtype
@@ -818,12 +847,15 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
if equal_inf:
flagged |= func_assert_same_pos(x, y,
func=lambda xy: xy == +inf,
hasval='+inf')
flagged |= func_assert_same_pos(x, y,
func=lambda xy: xy == -inf,
hasval='-inf')
# If equal_nan=True, skip comparing nans below for equality if they are
# also infs (e.g. inf+nanj) since that would always fail.
isinf_func = lambda xy: np.logical_and(np.isinf(xy), np.invert(flagged))
infs_mask = func_assert_same_pos(
x, y,
func=isinf_func,
hasval='inf')
assert_same_inf_values(x, y, infs_mask)
flagged |= infs_mask
elif istime(x) and istime(y):
# If one is datetime64 and the other timedelta64 there is no point
@@ -874,6 +906,31 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
percent_mismatch = 100 * n_mismatch / n_elements
remarks = [f'Mismatched elements: {n_mismatch} / {n_elements} '
f'({percent_mismatch:.3g}%)']
if invalids.ndim != 0:
if flagged.ndim > 0:
positions = np.argwhere(np.asarray(~flagged))[invalids]
else:
positions = np.argwhere(np.asarray(invalids))
s = "\n".join(
[
f" {p.tolist()}: {ox if ox.ndim == 0 else ox[tuple(p)]} "
f"({names[0]}), {oy if oy.ndim == 0 else oy[tuple(p)]} "
f"({names[1]})"
for p in positions[:5]
]
)
if len(positions) == 1:
remarks.append(
f"Mismatch at index:\n{s}"
)
elif len(positions) <= 5:
remarks.append(
f"Mismatch at indices:\n{s}"
)
else:
remarks.append(
f"First 5 mismatches are at indices:\n{s}"
)
with errstate(all='ignore'):
# ignore errors for non-numeric types
@@ -990,9 +1047,10 @@ def assert_array_equal(actual, desired, err_msg='', verbose=True, *,
Notes
-----
When one of `actual` and `desired` is a scalar and the other is array_like,
the function checks that each element of the array_like object is equal to
the scalar. This behaviour can be disabled with the `strict` parameter.
When one of `actual` and `desired` is a scalar and the other is array_like, the
function checks that each element of the array_like is equal to the scalar.
Note that empty arrays are therefore considered equal to scalars.
This behaviour can be disabled by setting ``strict==True``.
Examples
--------
@@ -1011,6 +1069,8 @@ def assert_array_equal(actual, desired, err_msg='', verbose=True, *,
Arrays are not equal
<BLANKLINE>
Mismatched elements: 1 / 3 (33.3%)
Mismatch at index:
[1]: 3.141592653589793 (ACTUAL), 3.1415926535897927 (DESIRED)
Max absolute difference among violations: 4.4408921e-16
Max relative difference among violations: 1.41357986e-16
ACTUAL: array([1. , 3.141593, nan])
@@ -1124,6 +1184,8 @@ def assert_array_almost_equal(actual, desired, decimal=6, err_msg='',
Arrays are not almost equal to 5 decimals
<BLANKLINE>
Mismatched elements: 1 / 3 (33.3%)
Mismatch at index:
[1]: 2.33333 (ACTUAL), 2.33339 (DESIRED)
Max absolute difference among violations: 6.e-05
Max relative difference among violations: 2.57136612e-05
ACTUAL: array([1. , 2.33333, nan])
@@ -1143,24 +1205,9 @@ def assert_array_almost_equal(actual, desired, decimal=6, err_msg='',
"""
__tracebackhide__ = True # Hide traceback for py.test
from numpy._core import number, result_type
from numpy._core.fromnumeric import any as npany
from numpy._core.numerictypes import issubdtype
def compare(x, y):
try:
if npany(isinf(x)) or npany(isinf(y)):
xinfid = isinf(x)
yinfid = isinf(y)
if not (xinfid == yinfid).all():
return False
# if one item, x and y is +- inf
if x.size == y.size == 1:
return x == y
x = x[~xinfid]
y = y[~yinfid]
except (TypeError, NotImplementedError):
pass
# make sure y is an inexact type to avoid abs(MIN_INT); will cause
# casting of x later.
dtype = result_type(y, 1.)
@@ -1245,6 +1292,8 @@ def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False):
Arrays are not strictly ordered `x < y`
<BLANKLINE>
Mismatched elements: 1 / 3 (33.3%)
Mismatch at index:
[0]: 1.0 (x), 1.0 (y)
Max absolute difference among violations: 0.
Max relative difference among violations: 0.
x: array([ 1., 1., nan])
@@ -1623,9 +1672,10 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
contrast to the standard usage in numpy, NaNs are compared like numbers,
no assertion is raised if both objects have NaNs in the same positions.
The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note
that ``allclose`` has different default values). It compares the difference
between `actual` and `desired` to ``atol + rtol * abs(desired)``.
The test is equivalent to ``allclose(actual, desired, rtol, atol)``,
except that it is stricter: it doesn't broadcast its operands, and has
tighter default tolerance values. It compares the difference between
`actual` and `desired` to ``atol + rtol * abs(desired)``.
Parameters
----------
@@ -1661,10 +1711,10 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
Notes
-----
When one of `actual` and `desired` is a scalar and the other is
array_like, the function performs the comparison as if the scalar were
broadcasted to the shape of the array.
This behaviour can be disabled with the `strict` parameter.
When one of `actual` and `desired` is a scalar and the other is array_like, the
function performs the comparison as if the scalar were broadcasted to the shape
of the array. Note that empty arrays are therefore considered equal to scalars.
This behaviour can be disabled by setting ``strict==True``.
Examples
--------
@@ -1928,7 +1978,7 @@ def integer_repr(x):
@contextlib.contextmanager
def _assert_warns_context(warning_class, name=None):
__tracebackhide__ = True # Hide traceback for py.test
with suppress_warnings() as sup:
with suppress_warnings(_warn=False) as sup:
l = sup.record(warning_class)
yield
if not len(l) > 0:
@@ -1952,6 +2002,11 @@ def assert_warns(warning_class, *args, **kwargs):
The ability to be used as a context manager is new in NumPy v1.11.0.
.. deprecated:: 2.4
This is deprecated. Use `warnings.catch_warnings` or
``pytest.warns`` instead.
Parameters
----------
warning_class : class
@@ -1979,6 +2034,11 @@ def assert_warns(warning_class, *args, **kwargs):
>>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4)
>>> assert ret == 16
"""
warnings.warn(
"NumPy warning suppression and assertion utilities are deprecated. "
"Use warnings.catch_warnings, warnings.filterwarnings, pytest.warns, "
"or pytest.filterwarnings instead. (Deprecated NumPy 2.4)",
DeprecationWarning, stacklevel=2)
if not args and not kwargs:
return _assert_warns_context(warning_class)
elif len(args) < 1:
@@ -2231,6 +2291,11 @@ class suppress_warnings:
tests might need to see the warning. Additionally it allows easier
specificity for testing warnings and can be nested.
.. deprecated:: 2.4
This is deprecated. Use `warnings.filterwarnings` or
``pytest.filterwarnings`` instead.
Parameters
----------
forwarding_rule : str, optional
@@ -2291,7 +2356,13 @@ class suppress_warnings:
# do something which causes a warning in np.ma.core
pass
"""
def __init__(self, forwarding_rule="always"):
def __init__(self, forwarding_rule="always", _warn=True):
if _warn:
warnings.warn(
"NumPy warning suppression and assertion utilities are deprecated. "
"Use warnings.catch_warnings, warnings.filterwarnings, pytest.warns, "
"or pytest.filterwarnings instead. (Deprecated NumPy 2.4)",
DeprecationWarning, stacklevel=2)
self._entered = False
# Suppressions are either instance or defined inside one with block:

View File

@@ -3,6 +3,7 @@ import sys
import types
import unittest
import warnings
from _typeshed import ConvertibleToFloat, GenericPath, StrOrBytesPath, StrPath
from collections.abc import Callable, Iterable, Sequence
from contextlib import _GeneratorContextManager
from pathlib import Path
@@ -13,6 +14,7 @@ from typing import (
ClassVar,
Final,
Generic,
Literal as L,
NoReturn,
ParamSpec,
Self,
@@ -22,12 +24,9 @@ from typing import (
overload,
type_check_only,
)
from typing import Literal as L
from typing_extensions import TypeVar, deprecated
from unittest.case import SkipTest
from _typeshed import ConvertibleToFloat, GenericPath, StrOrBytesPath, StrPath
from typing_extensions import TypeVar
import numpy as np
from numpy._typing import (
ArrayLike,
@@ -45,9 +44,13 @@ __all__ = [ # noqa: RUF022
"IS_PYPY",
"IS_PYSTON",
"IS_WASM",
"IS_INSTALLED",
"IS_64BIT",
"HAS_LAPACK64",
"HAS_REFCOUNT",
"BLAS_SUPPORTS_FPE",
"NOGIL_BUILD",
"NUMPY_ROOT",
"assert_",
"assert_array_almost_equal_nulp",
"assert_raises_regex",
@@ -94,7 +97,6 @@ _Tss = ParamSpec("_Tss")
_ET = TypeVar("_ET", bound=BaseException, default=BaseException)
_FT = TypeVar("_FT", bound=Callable[..., Any])
_W_co = TypeVar("_W_co", bound=_WarnLog | None, default=_WarnLog | None, covariant=True)
_T_or_bool = TypeVar("_T_or_bool", default=bool)
_StrLike: TypeAlias = str | bytes
_RegexLike: TypeAlias = _StrLike | Pattern[Any]
@@ -131,15 +133,16 @@ IS_MUSL: Final[bool] = ...
IS_PYPY: Final[bool] = ...
IS_PYSTON: Final[bool] = ...
IS_WASM: Final[bool] = ...
IS_64BIT: Final[bool] = ...
HAS_REFCOUNT: Final[bool] = ...
HAS_LAPACK64: Final[bool] = ...
BLAS_SUPPORTS_FPE: Final[bool] = ...
NOGIL_BUILD: Final[bool] = ...
class KnownFailureException(Exception): ...
class IgnoreException(Exception): ...
# NOTE: `warnings.catch_warnings` is incorrectly defined as invariant in typeshed
class clear_and_catch_warnings(warnings.catch_warnings[_W_co], Generic[_W_co]): # type: ignore[type-var] # pyright: ignore[reportInvalidTypeArguments]
class clear_and_catch_warnings(warnings.catch_warnings[_W_co], Generic[_W_co]):
class_modules: ClassVar[tuple[types.ModuleType, ...]] = ()
modules: Final[set[types.ModuleType]]
@overload # record: True
@@ -149,6 +152,7 @@ class clear_and_catch_warnings(warnings.catch_warnings[_W_co], Generic[_W_co]):
@overload # record; bool
def __init__(self, /, record: bool, modules: _ToModules = ()) -> None: ...
@deprecated("Please use warnings.filterwarnings or pytest.mark.filterwarnings instead")
class suppress_warnings:
log: Final[_WarnLog]
def __init__(self, /, forwarding_rule: L["always", "module", "once", "location"] = "always") -> None: ...
@@ -163,9 +167,9 @@ class suppress_warnings:
# Contrary to runtime we can't do `os.name` checks while type checking,
# only `sys.platform` checks
if sys.platform == "win32" or sys.platform == "cygwin":
def memusage(processName: str = ..., instance: int = ...) -> int: ...
def memusage(processName: str = "python", instance: int = 0) -> int: ...
elif sys.platform == "linux":
def memusage(_proc_pid_stat: StrOrBytesPath = ...) -> int | None: ...
def memusage(_proc_pid_stat: StrOrBytesPath | None = None) -> int | None: ...
else:
def memusage() -> NoReturn: ...
@@ -178,10 +182,10 @@ else:
def build_err_msg(
arrays: Iterable[object],
err_msg: object,
header: str = ...,
verbose: bool = ...,
names: Sequence[str] = ...,
precision: SupportsIndex | None = ...,
header: str = "Items are not equal:",
verbose: bool = True,
names: Sequence[str] = ("ACTUAL", "DESIRED"), # = ('ACTUAL', 'DESIRED')
precision: SupportsIndex | None = 8,
) -> str: ...
#
@@ -360,8 +364,10 @@ def assert_array_max_ulp(
#
@overload
@deprecated("Please use warnings.catch_warnings or pytest.warns instead")
def assert_warns(warning_class: _WarningSpec) -> _GeneratorContextManager[None]: ...
@overload
@deprecated("Please use warnings.catch_warnings or pytest.warns instead")
def assert_warns(warning_class: _WarningSpec, func: Callable[_Tss, _T], *args: _Tss.args, **kwargs: _Tss.kwargs) -> _T: ...
#
@@ -453,7 +459,7 @@ def temppath(
) -> _GeneratorContextManager[AnyStr]: ...
#
def check_support_sve(__cache: list[_T_or_bool] = []) -> _T_or_bool: ... # noqa: PYI063
def check_support_sve(__cache: list[bool] = ..., /) -> bool: ... # stubdefaulter: ignore[missing-default]
#
def decorate_methods(

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable, Hashable
from typing import Any
from typing_extensions import TypeIs
import numpy as np

View File

@@ -1,6 +1,5 @@
from collections.abc import Iterable
from typing import ClassVar, Generic, Self
from typing_extensions import TypeVar
import numpy as np

View File

@@ -36,6 +36,9 @@ from numpy.testing import (
class _GenericTest:
def _assert_func(self, *args, **kwargs):
pass
def _test_equal(self, a, b):
self._assert_func(a, b)
@@ -82,8 +85,8 @@ class _GenericTest:
class TestArrayEqual(_GenericTest):
def setup_method(self):
self._assert_func = assert_array_equal
def _assert_func(self, *args, **kwargs):
assert_array_equal(*args, **kwargs)
def test_generic_rank1(self):
"""Test rank 1 array for all dtypes."""
@@ -197,6 +200,40 @@ class TestArrayEqual(_GenericTest):
self._test_equal(a, b)
self._test_equal(b, a)
# Also provides test cases for gh-11121
def test_masked_scalar(self):
# Test masked scalar vs. plain/masked scalar
for a_val, b_val, b_masked in itertools.product(
[3., np.nan, np.inf],
[3., 4., np.nan, np.inf, -np.inf],
[False, True],
):
a = np.ma.MaskedArray(a_val, mask=True)
b = np.ma.MaskedArray(b_val, mask=True) if b_masked else np.array(b_val)
self._test_equal(a, b)
self._test_equal(b, a)
# Test masked scalar vs. plain array
for a_val, b_val in itertools.product(
[3., np.nan, -np.inf],
itertools.product([3., 4., np.nan, np.inf, -np.inf], repeat=2),
):
a = np.ma.MaskedArray(a_val, mask=True)
b = np.array(b_val)
self._test_equal(a, b)
self._test_equal(b, a)
# Test masked scalar vs. masked array
for a_val, b_val, b_mask in itertools.product(
[3., np.nan, np.inf],
itertools.product([3., 4., np.nan, np.inf, -np.inf], repeat=2),
itertools.product([False, True], repeat=2),
):
a = np.ma.MaskedArray(a_val, mask=True)
b = np.ma.MaskedArray(b_val, mask=b_mask)
self._test_equal(a, b)
self._test_equal(b, a)
def test_subclass_that_overrides_eq(self):
# While we cannot guarantee testing functions will always work for
# subclasses, the tests should ideally rely only on subclasses having
@@ -265,6 +302,8 @@ class TestArrayEqual(_GenericTest):
b = np.array([34986, 545676, 439655, 0])
expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
'Mismatch at index:\n'
' [3]: 563766 (ACTUAL), 0 (DESIRED)\n'
'Max absolute difference among violations: 563766\n'
'Max relative difference among violations: inf')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -272,6 +311,9 @@ class TestArrayEqual(_GenericTest):
a = np.array([34986, 545676, 439655.2, 563766])
expected_msg = ('Mismatched elements: 2 / 4 (50%)\n'
'Mismatch at indices:\n'
' [2]: 439655.2 (ACTUAL), 439655 (DESIRED)\n'
' [3]: 563766.0 (ACTUAL), 0 (DESIRED)\n'
'Max absolute difference among violations: '
'563766.\n'
'Max relative difference among violations: '
@@ -350,8 +392,8 @@ class TestBuildErrorMessage:
class TestEqual(TestArrayEqual):
def setup_method(self):
self._assert_func = assert_equal
def _assert_func(self, *args, **kwargs):
assert_equal(*args, **kwargs)
def test_nan_items(self):
self._assert_func(np.nan, np.nan)
@@ -445,8 +487,8 @@ class TestEqual(TestArrayEqual):
class TestArrayAlmostEqual(_GenericTest):
def setup_method(self):
self._assert_func = assert_array_almost_equal
def _assert_func(self, *args, **kwargs):
assert_array_almost_equal(*args, **kwargs)
def test_closeness(self):
# Note that in the course of time we ended up with
@@ -466,6 +508,8 @@ class TestArrayAlmostEqual(_GenericTest):
self._assert_func([1.499999], [0.0], decimal=0)
expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
'Mismatch at index:\n'
' [0]: 1.5 (ACTUAL), 0.0 (DESIRED)\n'
'Max absolute difference among violations: 1.5\n'
'Max relative difference among violations: inf')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -474,12 +518,16 @@ class TestArrayAlmostEqual(_GenericTest):
a = [1.4999999, 0.00003]
b = [1.49999991, 0]
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
'Mismatch at index:\n'
' [1]: 3e-05 (ACTUAL), 0.0 (DESIRED)\n'
'Max absolute difference among violations: 3.e-05\n'
'Max relative difference among violations: inf')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
self._assert_func(a, b, decimal=7)
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
'Mismatch at index:\n'
' [1]: 0.0 (ACTUAL), 3e-05 (DESIRED)\n'
'Max absolute difference among violations: 3.e-05\n'
'Max relative difference among violations: 1.')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -493,6 +541,8 @@ class TestArrayAlmostEqual(_GenericTest):
self._assert_func(x, y, decimal=4)
expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
'Mismatch at index:\n'
' [0]: 1234.2222 (ACTUAL), 1234.2223 (DESIRED)\n'
'Max absolute difference among violations: '
'1.e-04\n'
'Max relative difference among violations: '
@@ -504,6 +554,9 @@ class TestArrayAlmostEqual(_GenericTest):
a = [5498.42354, 849.54345, 0.00]
b = 5498.42354
expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
'Mismatch at indices:\n'
' [1]: 849.54345 (ACTUAL), 5498.42354 (DESIRED)\n'
' [2]: 0.0 (ACTUAL), 5498.42354 (DESIRED)\n'
'Max absolute difference among violations: '
'5498.42354\n'
'Max relative difference among violations: 1.')
@@ -511,6 +564,9 @@ class TestArrayAlmostEqual(_GenericTest):
self._assert_func(a, b, decimal=9)
expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
'Mismatch at indices:\n'
' [1]: 5498.42354 (ACTUAL), 849.54345 (DESIRED)\n'
' [2]: 5498.42354 (ACTUAL), 0.0 (DESIRED)\n'
'Max absolute difference among violations: '
'5498.42354\n'
'Max relative difference among violations: 5.4722099')
@@ -519,6 +575,8 @@ class TestArrayAlmostEqual(_GenericTest):
a = [5498.42354, 0.00]
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
'Mismatch at index:\n'
' [1]: 5498.42354 (ACTUAL), 0.0 (DESIRED)\n'
'Max absolute difference among violations: '
'5498.42354\n'
'Max relative difference among violations: inf')
@@ -527,6 +585,8 @@ class TestArrayAlmostEqual(_GenericTest):
b = 0
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
'Mismatch at index:\n'
' [0]: 5498.42354 (ACTUAL), 0 (DESIRED)\n'
'Max absolute difference among violations: '
'5498.42354\n'
'Max relative difference among violations: inf')
@@ -555,6 +615,18 @@ class TestArrayAlmostEqual(_GenericTest):
assert_raises(AssertionError,
lambda: self._assert_func(a, b))
def test_complex_inf(self):
a = np.array([np.inf + 1.j, 2. + 1.j, 3. + 1.j])
b = a.copy()
self._assert_func(a, b)
b[1] = 3. + 1.j
expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
'Mismatch at index:\n'
' [1]: (2+1j) (ACTUAL), (3+1j) (DESIRED)\n'
'Max absolute difference among violations: 1.\n')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
self._assert_func(a, b)
def test_subclass(self):
a = np.array([[1., 2.], [3., 4.]])
b = np.ma.masked_array([[1., 2.], [0., 4.]],
@@ -603,6 +675,8 @@ class TestArrayAlmostEqual(_GenericTest):
all(z)
b = np.array([1., 202]).view(MyArray)
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
'Mismatch at index:\n'
' [1]: 2.0 (ACTUAL), 202.0 (DESIRED)\n'
'Max absolute difference among violations: 200.\n'
'Max relative difference among violations: 0.99009')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -629,8 +703,8 @@ class TestArrayAlmostEqual(_GenericTest):
class TestAlmostEqual(_GenericTest):
def setup_method(self):
self._assert_func = assert_almost_equal
def _assert_func(self, *args, **kwargs):
assert_almost_equal(*args, **kwargs)
def test_closeness(self):
# Note that in the course of time we ended up with
@@ -693,6 +767,10 @@ class TestAlmostEqual(_GenericTest):
# Test with a different amount of decimal digits
expected_msg = ('Mismatched elements: 3 / 3 (100%)\n'
'Mismatch at indices:\n'
' [0]: 1.00000000001 (ACTUAL), 1.00000000002 (DESIRED)\n'
' [1]: 2.00000000002 (ACTUAL), 2.00000000003 (DESIRED)\n'
' [2]: 3.00003 (ACTUAL), 3.00004 (DESIRED)\n'
'Max absolute difference among violations: 1.e-05\n'
'Max relative difference among violations: '
'3.33328889e-06\n'
@@ -708,6 +786,8 @@ class TestAlmostEqual(_GenericTest):
# differs. Note that we only check for the formatting of the arrays
# themselves.
expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
'Mismatch at index:\n'
' [2]: 3.00003 (ACTUAL), 3.00004 (DESIRED)\n'
'Max absolute difference among violations: 1.e-05\n'
'Max relative difference among violations: '
'3.33328889e-06\n'
@@ -720,6 +800,8 @@ class TestAlmostEqual(_GenericTest):
x = np.array([np.inf, 0])
y = np.array([np.inf, 1])
expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
'Mismatch at index:\n'
' [1]: 0.0 (ACTUAL), 1.0 (DESIRED)\n'
'Max absolute difference among violations: 1.\n'
'Max relative difference among violations: 1.\n'
' ACTUAL: array([inf, 0.])\n'
@@ -731,6 +813,9 @@ class TestAlmostEqual(_GenericTest):
x = np.array([1, 2])
y = np.array([0, 0])
expected_msg = ('Mismatched elements: 2 / 2 (100%)\n'
'Mismatch at indices:\n'
' [0]: 1 (ACTUAL), 0 (DESIRED)\n'
' [1]: 2 (ACTUAL), 0 (DESIRED)\n'
'Max absolute difference among violations: 2\n'
'Max relative difference among violations: inf')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -742,6 +827,12 @@ class TestAlmostEqual(_GenericTest):
x = 2
y = np.ones(20)
expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
'First 5 mismatches are at indices:\n'
' [0]: 2 (ACTUAL), 1.0 (DESIRED)\n'
' [1]: 2 (ACTUAL), 1.0 (DESIRED)\n'
' [2]: 2 (ACTUAL), 1.0 (DESIRED)\n'
' [3]: 2 (ACTUAL), 1.0 (DESIRED)\n'
' [4]: 2 (ACTUAL), 1.0 (DESIRED)\n'
'Max absolute difference among violations: 1.\n'
'Max relative difference among violations: 1.')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -750,6 +841,12 @@ class TestAlmostEqual(_GenericTest):
y = 2
x = np.ones(20)
expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
'First 5 mismatches are at indices:\n'
' [0]: 1.0 (ACTUAL), 2 (DESIRED)\n'
' [1]: 1.0 (ACTUAL), 2 (DESIRED)\n'
' [2]: 1.0 (ACTUAL), 2 (DESIRED)\n'
' [3]: 1.0 (ACTUAL), 2 (DESIRED)\n'
' [4]: 1.0 (ACTUAL), 2 (DESIRED)\n'
'Max absolute difference among violations: 1.\n'
'Max relative difference among violations: 0.5')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -776,8 +873,8 @@ class TestAlmostEqual(_GenericTest):
class TestApproxEqual:
def setup_method(self):
self._assert_func = assert_approx_equal
def _assert_func(self, *args, **kwargs):
assert_approx_equal(*args, **kwargs)
def test_simple_0d_arrays(self):
x = np.array(1234.22)
@@ -819,8 +916,8 @@ class TestApproxEqual:
class TestArrayAssertLess:
def setup_method(self):
self._assert_func = assert_array_less
def _assert_func(self, *args, **kwargs):
assert_array_less(*args, **kwargs)
def test_simple_arrays(self):
x = np.array([1.1, 2.2])
@@ -838,6 +935,9 @@ class TestArrayAssertLess:
b = np.array([2, 4, 6, 8])
expected_msg = ('Mismatched elements: 2 / 4 (50%)\n'
'Mismatch at indices:\n'
' [2]: 6 (x), 6 (y)\n'
' [3]: 20 (x), 8 (y)\n'
'Max absolute difference among violations: 12\n'
'Max relative difference among violations: 1.5')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -849,6 +949,11 @@ class TestArrayAssertLess:
self._assert_func(x, y)
expected_msg = ('Mismatched elements: 4 / 4 (100%)\n'
'Mismatch at indices:\n'
' [0, 0]: 1.2 (x), 1.1 (y)\n'
' [0, 1]: 2.3 (x), 2.2 (y)\n'
' [1, 0]: 3.4 (x), 3.3 (y)\n'
' [1, 1]: 4.5 (x), 4.4 (y)\n'
'Max absolute difference among violations: 0.1\n'
'Max relative difference among violations: 0.09090909')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -867,6 +972,8 @@ class TestArrayAssertLess:
y[0, 0, 0] = 0
expected_msg = ('Mismatched elements: 1 / 8 (12.5%)\n'
'Mismatch at index:\n'
' [0, 0, 0]: 1.0 (x), 0.0 (y)\n'
'Max absolute difference among violations: 1.\n'
'Max relative difference among violations: inf')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -910,12 +1017,20 @@ class TestArrayAssertLess:
y = 999090.54
expected_msg = ('Mismatched elements: 1 / 12 (8.33%)\n'
'Mismatch at index:\n'
' [1, 1]: 999090.54 (x), 999090.54 (y)\n'
'Max absolute difference among violations: 0.\n'
'Max relative difference among violations: 0.')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
self._assert_func(x, y)
expected_msg = ('Mismatched elements: 12 / 12 (100%)\n'
'First 5 mismatches are at indices:\n'
' [0, 0]: 999090.54 (x), 3.4536 (y)\n'
' [0, 1]: 999090.54 (x), 2390.5436 (y)\n'
' [0, 2]: 999090.54 (x), 435.54657 (y)\n'
' [0, 3]: 999090.54 (x), 324525.4535 (y)\n'
' [1, 0]: 999090.54 (x), 5449.54 (y)\n'
'Max absolute difference among violations: '
'999087.0864\n'
'Max relative difference among violations: '
@@ -928,12 +1043,17 @@ class TestArrayAssertLess:
y = np.array(87654.)
expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
'Mismatch at index:\n'
' [0]: 546456.0 (x), 87654.0 (y)\n'
'Max absolute difference among violations: 458802.\n'
'Max relative difference among violations: 5.23423917')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
self._assert_func(x, y)
expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
'Mismatch at indices:\n'
' [1]: 87654.0 (x), 0.0 (y)\n'
' [2]: 87654.0 (x), 15.455 (y)\n'
'Max absolute difference among violations: 87654.\n'
'Max relative difference among violations: '
'5670.5626011')
@@ -943,12 +1063,18 @@ class TestArrayAssertLess:
y = 0
expected_msg = ('Mismatched elements: 3 / 3 (100%)\n'
'Mismatch at indices:\n'
' [0]: 546456.0 (x), 0 (y)\n'
' [1]: 0.0 (x), 0 (y)\n'
' [2]: 15.455 (x), 0 (y)\n'
'Max absolute difference among violations: 546456.\n'
'Max relative difference among violations: inf')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
self._assert_func(x, y)
expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
'Mismatch at index:\n'
' [1]: 0 (x), 0.0 (y)\n'
'Max absolute difference among violations: 0.\n'
'Max relative difference among violations: inf')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -1017,7 +1143,10 @@ class TestArrayAssertLess:
with pytest.raises(AssertionError):
self._assert_func(x, y.astype(np.float32), strict=True)
@pytest.mark.filterwarnings(
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
".*:DeprecationWarning")
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
class TestWarns:
def test_warn(self):
@@ -1134,12 +1263,16 @@ class TestAssertAllclose:
b = np.array([x, y, x, x])
c = np.array([x, y, x, z])
expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
'Mismatch at index:\n'
' [3]: 0.001 (ACTUAL), 0.0 (DESIRED)\n'
'Max absolute difference among violations: 0.001\n'
'Max relative difference among violations: inf')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
assert_allclose(b, c)
expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
'Mismatch at index:\n'
' [3]: 0.0 (ACTUAL), 0.001 (DESIRED)\n'
'Max absolute difference among violations: 0.001\n'
'Max relative difference among violations: 1.')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -1155,6 +1288,8 @@ class TestAssertAllclose:
b = np.array([1, 1, 1, 2])
expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
'Mismatch at index:\n'
' [3]: 1 (ACTUAL), 2 (DESIRED)\n'
'Max absolute difference among violations: 1\n'
'Max relative difference among violations: 0.5')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
@@ -1166,11 +1301,21 @@ class TestAssertAllclose:
# Should not raise:
assert_allclose(a, b, equal_nan=True)
a = np.array([complex(np.nan, np.inf)])
b = np.array([complex(np.nan, np.inf)])
assert_allclose(a, b, equal_nan=True)
b = np.array([complex(np.nan, -np.inf)])
assert_allclose(a, b, equal_nan=True)
def test_not_equal_nan(self):
a = np.array([np.nan])
b = np.array([np.nan])
assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False)
a = np.array([complex(np.nan, np.inf)])
b = np.array([complex(np.nan, np.inf)])
assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False)
def test_equal_nan_default(self):
# Make sure equal_nan default behavior remains unchanged. (All
# of these functions use assert_array_compare under the hood.)
@@ -1219,6 +1364,33 @@ class TestAssertAllclose:
with pytest.raises(AssertionError):
assert_allclose(x, x.astype(np.float32), strict=True)
def test_infs(self):
a = np.array([np.inf])
b = np.array([np.inf])
assert_allclose(a, b)
b = np.array([3.])
expected_msg = 'inf location mismatch:'
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
assert_allclose(a, b)
b = np.array([-np.inf])
expected_msg = 'inf values mismatch:'
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
assert_allclose(a, b)
b = np.array([complex(np.inf, 1.)])
expected_msg = 'inf values mismatch:'
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
assert_allclose(a, b)
a = np.array([complex(np.inf, 1.)])
b = np.array([complex(np.inf, 1.)])
assert_allclose(a, b)
b = np.array([complex(np.inf, 2.)])
expected_msg = 'inf values mismatch:'
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
assert_allclose(a, b)
class TestArrayAlmostEqualNulp:
@@ -1589,6 +1761,7 @@ def _get_fresh_mod():
return my_mod
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_clear_and_catch_warnings():
# Initial state of module, no warnings
my_mod = _get_fresh_mod()
@@ -1621,6 +1794,10 @@ def test_clear_and_catch_warnings():
assert_warn_len_equal(my_mod, 0)
@pytest.mark.filterwarnings(
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
".*:DeprecationWarning")
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_suppress_warnings_module():
# Initial state of module, no warnings
my_mod = _get_fresh_mod()
@@ -1667,6 +1844,10 @@ def test_suppress_warnings_module():
assert_warn_len_equal(my_mod, 0)
@pytest.mark.filterwarnings(
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
".*:DeprecationWarning")
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_suppress_warnings_type():
# Initial state of module, no warnings
my_mod = _get_fresh_mod()
@@ -1695,6 +1876,12 @@ def test_suppress_warnings_type():
assert_warn_len_equal(my_mod, 0)
@pytest.mark.filterwarnings(
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
".*:DeprecationWarning")
@pytest.mark.thread_unsafe(
reason="uses deprecated thread-unsafe warnings control utilities"
)
def test_suppress_warnings_decorate_no_record():
sup = suppress_warnings()
sup.filter(UserWarning)
@@ -1710,6 +1897,12 @@ def test_suppress_warnings_decorate_no_record():
assert_equal(len(w), 1)
@pytest.mark.filterwarnings(
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
".*:DeprecationWarning")
@pytest.mark.thread_unsafe(
reason="uses deprecated thread-unsafe warnings control utilities"
)
def test_suppress_warnings_record():
sup = suppress_warnings()
log1 = sup.record()
@@ -1747,9 +1940,16 @@ def test_suppress_warnings_record():
warnings.warn('Some warning')
warnings.warn('Some other warning')
assert_equal(len(sup2.log), 1)
assert_equal(len(sup.log), 1)
# includes a DeprecationWarning for suppress_warnings
assert_equal(len(sup.log), 2)
@pytest.mark.filterwarnings(
"ignore:.*NumPy warning suppression and assertion utilities are deprecated"
".*:DeprecationWarning")
@pytest.mark.thread_unsafe(
reason="uses deprecated thread-unsafe warnings control utilities"
)
def test_suppress_warnings_forwarding():
def warn_other_module():
# Apply along axis is implemented in python; stacklevel=2 means
@@ -1765,7 +1965,8 @@ def test_suppress_warnings_forwarding():
for i in range(2):
warnings.warn("Some warning")
assert_equal(len(sup.log), 2)
# includes a DeprecationWarning for suppress_warnings
assert_equal(len(sup.log), 3)
with suppress_warnings() as sup:
sup.record()
@@ -1774,7 +1975,8 @@ def test_suppress_warnings_forwarding():
warnings.warn("Some warning")
warnings.warn("Some warning")
assert_equal(len(sup.log), 2)
# includes a DeprecationWarning for suppress_warnings
assert_equal(len(sup.log), 3)
with suppress_warnings() as sup:
sup.record()
@@ -1784,7 +1986,8 @@ def test_suppress_warnings_forwarding():
warnings.warn("Some warning")
warn_other_module()
assert_equal(len(sup.log), 2)
# includes a DeprecationWarning for suppress_warnings
assert_equal(len(sup.log), 3)
with suppress_warnings() as sup:
sup.record()
@@ -1794,7 +1997,8 @@ def test_suppress_warnings_forwarding():
warnings.warn("Some other warning")
warn_other_module()
assert_equal(len(sup.log), 2)
# includes a DeprecationWarning for suppress_warnings
assert_equal(len(sup.log), 3)
def test_tempdir():
@@ -1835,6 +2039,7 @@ class my_cacw(clear_and_catch_warnings):
class_modules = (sys.modules[__name__],)
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_clear_and_catch_warnings_inherit():
# Test can subclass and add default modules
my_mod = _get_fresh_mod()
@@ -1845,6 +2050,7 @@ def test_clear_and_catch_warnings_inherit():
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
@pytest.mark.thread_unsafe(reason="garbage collector is global state")
class TestAssertNoGcCycles:
""" Test assert_no_gc_cycles """