chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.codegen.abstract_nodes import List
|
||||
|
||||
|
||||
def test_List():
|
||||
l = List(2, 3, 4)
|
||||
assert l == List(2, 3, 4)
|
||||
assert str(l) == "[2, 3, 4]"
|
||||
x, y, z = symbols('x y z')
|
||||
l = List(x**2,y**3,z**4)
|
||||
# contrary to python's built-in list, we can call e.g. "replace" on List.
|
||||
m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp)
|
||||
assert m == [x**2, y-3, z-4]
|
||||
hash(m)
|
||||
@@ -0,0 +1,180 @@
|
||||
import tempfile
|
||||
from sympy import log, Min, Max, sqrt
|
||||
from sympy.core.numbers import Float
|
||||
from sympy.core.symbol import Symbol, symbols
|
||||
from sympy.functions.elementary.trigonometric import cos
|
||||
from sympy.codegen.ast import Assignment, Raise, RuntimeError_, QuotedString
|
||||
from sympy.codegen.algorithms import newtons_method, newtons_method_function
|
||||
from sympy.codegen.cfunctions import expm1
|
||||
from sympy.codegen.fnodes import bind_C
|
||||
from sympy.codegen.futils import render_as_module as f_module
|
||||
from sympy.codegen.pyutils import render_as_module as py_module
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c, has_fortran
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, raises, skip_under_pyodide
|
||||
|
||||
cython = import_module('cython')
|
||||
wurlitzer = import_module('wurlitzer')
|
||||
|
||||
def test_newtons_method():
|
||||
x, dx, atol = symbols('x dx atol')
|
||||
expr = cos(x) - x**3
|
||||
algo = newtons_method(expr, x, atol, dx)
|
||||
assert algo.has(Assignment(dx, -expr/expr.diff(x)))
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_newtons_method_function__ccode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x)
|
||||
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
|
||||
compile_kw = {"std": 'c99'}
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton.c', ('#include <math.h>\n'
|
||||
'#include <stdio.h>\n') + ccode(func)),
|
||||
('_newton.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double)\n"
|
||||
"def py_newton(x):\n"
|
||||
" return newton(x)\n"))
|
||||
], build_dir=folder, compile_kwargs=compile_kw)
|
||||
assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_newtons_method_function__fcode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x, attrs=[bind_C(name='newton')])
|
||||
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_fortran():
|
||||
skip("No Fortran compiler found.")
|
||||
|
||||
f_mod = f_module([func], 'mod_newton')
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton.f90', f_mod),
|
||||
('_newton.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double*)\n"
|
||||
"def py_newton(double x):\n"
|
||||
" return newton(&x)\n"))
|
||||
], build_dir=folder)
|
||||
assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
def test_newtons_method_function__pycode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x)
|
||||
py_mod = py_module(func)
|
||||
namespace = {}
|
||||
exec(py_mod, namespace, namespace)
|
||||
res = eval('newton(0.5)', namespace)
|
||||
assert abs(res - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
@may_xfail
|
||||
@skip_under_pyodide("Emscripten does not support process spawning")
|
||||
def test_newtons_method_function__ccode_parameters():
|
||||
args = x, A, k, p = symbols('x A k p')
|
||||
expr = A*cos(k*x) - p*x**3
|
||||
raises(ValueError, lambda: newtons_method_function(expr, x))
|
||||
use_wurlitzer = wurlitzer
|
||||
|
||||
func = newtons_method_function(expr, x, args, debug=use_wurlitzer)
|
||||
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
|
||||
compile_kw = {"std": 'c99'}
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton_par.c', ('#include <math.h>\n'
|
||||
'#include <stdio.h>\n') + ccode(func)),
|
||||
('_newton_par.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double, double, double, double)\n"
|
||||
"def py_newton(x, A=1, k=1, p=1):\n"
|
||||
" return newton(x, A, k, p)\n"))
|
||||
], compile_kwargs=compile_kw, build_dir=folder)
|
||||
|
||||
if use_wurlitzer:
|
||||
with wurlitzer.pipes() as (out, err):
|
||||
result = mod.py_newton(0.5)
|
||||
else:
|
||||
result = mod.py_newton(0.5)
|
||||
|
||||
assert abs(result - 0.865474033102) < 1e-12
|
||||
|
||||
if not use_wurlitzer:
|
||||
skip("C-level output only tested when package 'wurlitzer' is available.")
|
||||
|
||||
out, err = out.read(), err.read()
|
||||
assert err == ''
|
||||
assert out == """\
|
||||
x= 0.5
|
||||
x= 1.1121 d_x= 0.61214
|
||||
x= 0.90967 d_x= -0.20247
|
||||
x= 0.86726 d_x= -0.042409
|
||||
x= 0.86548 d_x= -0.0017867
|
||||
x= 0.86547 d_x= -3.1022e-06
|
||||
x= 0.86547 d_x= -9.3421e-12
|
||||
x= 0.86547 d_x= 3.6902e-17
|
||||
""" # try to run tests with LC_ALL=C if this assertion fails
|
||||
|
||||
|
||||
def test_newtons_method_function__rtol_cse_nan():
|
||||
a, b, c, N_geo, N_tot = symbols('a b c N_geo N_tot', real=True, nonnegative=True)
|
||||
i = Symbol('i', integer=True, nonnegative=True)
|
||||
N_ari = N_tot - N_geo - 1
|
||||
delta_ari = (c-b)/N_ari
|
||||
ln_delta_geo = log(b) + log(-expm1((log(a)-log(b))/N_geo))
|
||||
eqb_log = ln_delta_geo - log(delta_ari)
|
||||
|
||||
def _clamp(low, expr, high):
|
||||
return Min(Max(low, expr), high)
|
||||
|
||||
meth_kw = {
|
||||
'clamped_newton': {'delta_fn': lambda e, x: _clamp(
|
||||
(sqrt(a*x)-x)*0.99,
|
||||
-e/e.diff(x),
|
||||
(sqrt(c*x)-x)*0.99
|
||||
)},
|
||||
'halley': {'delta_fn': lambda e, x: (-2*(e*e.diff(x))/(2*e.diff(x)**2 - e*e.diff(x, 2)))},
|
||||
'halley_alt': {'delta_fn': lambda e, x: (-e/e.diff(x)/(1-e/e.diff(x)*e.diff(x,2)/2/e.diff(x)))},
|
||||
}
|
||||
args = eqb_log, b
|
||||
for use_cse in [False, True]:
|
||||
kwargs = {
|
||||
'params': (b, a, c, N_geo, N_tot), 'itermax': 60, 'debug': True, 'cse': use_cse,
|
||||
'counter': i, 'atol': 1e-100, 'rtol': 2e-16, 'bounds': (a,c),
|
||||
'handle_nan': Raise(RuntimeError_(QuotedString("encountered NaN.")))
|
||||
}
|
||||
func = {k: newtons_method_function(*args, func_name=f"{k}_b", **dict(kwargs, **kw)) for k, kw in meth_kw.items()}
|
||||
py_mod = {k: py_module(v) for k, v in func.items()}
|
||||
namespace = {}
|
||||
root_find_b = {}
|
||||
for k, v in py_mod.items():
|
||||
ns = namespace[k] = {}
|
||||
exec(v, ns, ns)
|
||||
root_find_b[k] = ns[f'{k}_b']
|
||||
ref = Float('13.2261515064168768938151923226496')
|
||||
reftol = {'clamped_newton': 2e-16, 'halley': 2e-16, 'halley_alt': 3e-16}
|
||||
guess = 4.0
|
||||
for meth, func in root_find_b.items():
|
||||
result = func(guess, 1e-2, 1e2, 50, 100)
|
||||
req = ref*reftol[meth]
|
||||
if use_cse:
|
||||
req *= 2
|
||||
assert abs(result - ref) < req
|
||||
@@ -0,0 +1,58 @@
|
||||
# This file contains tests that exercise multiple AST nodes
|
||||
|
||||
import tempfile
|
||||
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, skip_under_pyodide
|
||||
from sympy.codegen.ast import (
|
||||
FunctionDefinition, FunctionPrototype, Variable, Pointer, real, Assignment,
|
||||
integer, CodeBlock, While
|
||||
)
|
||||
from sympy.codegen.cnodes import void, PreIncrement
|
||||
from sympy.codegen.cutils import render_as_source_file
|
||||
|
||||
cython = import_module('cython')
|
||||
np = import_module('numpy')
|
||||
|
||||
def _mk_func1():
|
||||
declars = n, inp, out = Variable('n', integer), Pointer('inp', real), Pointer('out', real)
|
||||
i = Variable('i', integer)
|
||||
whl = While(i<n, [Assignment(out[i], inp[i]), PreIncrement(i)])
|
||||
body = CodeBlock(i.as_Declaration(value=0), whl)
|
||||
return FunctionDefinition(void, 'our_test_function', declars, body)
|
||||
|
||||
|
||||
def _render_compile_import(funcdef, build_dir):
|
||||
code_str = render_as_source_file(funcdef, settings={"contract": False})
|
||||
declar = ccode(FunctionPrototype.from_FunctionDefinition(funcdef))
|
||||
return compile_link_import_strings([
|
||||
('our_test_func.c', code_str),
|
||||
('_our_test_func.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern {declar}\n"
|
||||
"def _{fname}({typ}[:] inp, {typ}[:] out):\n"
|
||||
" {fname}(inp.size, &inp[0], &out[0])").format(
|
||||
declar=declar, fname=funcdef.name, typ='double'
|
||||
))
|
||||
], build_dir=build_dir)
|
||||
|
||||
|
||||
@may_xfail
|
||||
@skip_under_pyodide("Emscripten does not support process spawning")
|
||||
def test_copying_function():
|
||||
if not np:
|
||||
skip("numpy not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
if not cython:
|
||||
skip("Cython not found.")
|
||||
|
||||
info = None
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = _render_compile_import(_mk_func1(), build_dir=folder)
|
||||
inp = np.arange(10.0)
|
||||
out = np.empty_like(inp)
|
||||
mod._our_test_function(inp, out)
|
||||
assert np.allclose(inp, out)
|
||||
@@ -0,0 +1,53 @@
|
||||
import math
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.codegen.rewriting import optimize
|
||||
from sympy.codegen.approximations import SumApprox, SeriesApprox
|
||||
|
||||
|
||||
def test_SumApprox_trivial():
|
||||
x = symbols('x')
|
||||
expr1 = 1 + x
|
||||
sum_approx = SumApprox(bounds={x: (-1e-20, 1e-20)}, reltol=1e-16)
|
||||
apx1 = optimize(expr1, [sum_approx])
|
||||
assert apx1 - 1 == 0
|
||||
|
||||
|
||||
def test_SumApprox_monotone_terms():
|
||||
x, y, z = symbols('x y z')
|
||||
expr1 = exp(z)*(x**2 + y**2 + 1)
|
||||
bnds1 = {x: (0, 1e-3), y: (100, 1000)}
|
||||
sum_approx_m2 = SumApprox(bounds=bnds1, reltol=1e-2)
|
||||
sum_approx_m5 = SumApprox(bounds=bnds1, reltol=1e-5)
|
||||
sum_approx_m11 = SumApprox(bounds=bnds1, reltol=1e-11)
|
||||
assert (optimize(expr1, [sum_approx_m2])/exp(z) - (y**2)).simplify() == 0
|
||||
assert (optimize(expr1, [sum_approx_m5])/exp(z) - (y**2 + 1)).simplify() == 0
|
||||
assert (optimize(expr1, [sum_approx_m11])/exp(z) - (y**2 + 1 + x**2)).simplify() == 0
|
||||
|
||||
|
||||
def test_SeriesApprox_trivial():
|
||||
x, z = symbols('x z')
|
||||
for factor in [1, exp(z)]:
|
||||
x = symbols('x')
|
||||
expr1 = exp(x)*factor
|
||||
bnds1 = {x: (-1, 1)}
|
||||
series_approx_50 = SeriesApprox(bounds=bnds1, reltol=0.50)
|
||||
series_approx_10 = SeriesApprox(bounds=bnds1, reltol=0.10)
|
||||
series_approx_05 = SeriesApprox(bounds=bnds1, reltol=0.05)
|
||||
c = (bnds1[x][1] + bnds1[x][0])/2 # 0.0
|
||||
f0 = math.exp(c) # 1.0
|
||||
|
||||
ref_50 = f0 + x + x**2/2
|
||||
ref_10 = f0 + x + x**2/2 + x**3/6
|
||||
ref_05 = f0 + x + x**2/2 + x**3/6 + x**4/24
|
||||
|
||||
res_50 = optimize(expr1, [series_approx_50])
|
||||
res_10 = optimize(expr1, [series_approx_10])
|
||||
res_05 = optimize(expr1, [series_approx_05])
|
||||
|
||||
assert (res_50/factor - ref_50).simplify() == 0
|
||||
assert (res_10/factor - ref_10).simplify() == 0
|
||||
assert (res_05/factor - ref_05).simplify() == 0
|
||||
|
||||
max_ord3 = SeriesApprox(bounds=bnds1, reltol=0.05, max_order=3)
|
||||
assert optimize(expr1, [max_ord3]) == expr1
|
||||
@@ -0,0 +1,661 @@
|
||||
import math
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.numbers import nan, oo, Float, Integer
|
||||
from sympy.core.relational import Lt
|
||||
from sympy.core.symbol import symbols, Symbol
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.sets.fancysets import Range
|
||||
from sympy.tensor.indexed import Idx, IndexedBase
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
|
||||
from sympy.codegen.ast import (
|
||||
Assignment, Attribute, aug_assign, CodeBlock, For, Type, Variable, Pointer, Declaration,
|
||||
AddAugmentedAssignment, SubAugmentedAssignment, MulAugmentedAssignment,
|
||||
DivAugmentedAssignment, ModAugmentedAssignment, value_const, pointer_const,
|
||||
integer, real, complex_, int8, uint8, float16 as f16, float32 as f32,
|
||||
float64 as f64, float80 as f80, float128 as f128, complex64 as c64, complex128 as c128,
|
||||
While, Scope, String, Print, QuotedString, FunctionPrototype, FunctionDefinition, Return,
|
||||
FunctionCall, untyped, IntBaseType, intc, Node, none, NoneToken, Token, Comment
|
||||
)
|
||||
|
||||
x, y, z, t, x0, x1, x2, a, b = symbols("x, y, z, t, x0, x1, x2, a, b")
|
||||
n = symbols("n", integer=True)
|
||||
A = MatrixSymbol('A', 3, 1)
|
||||
mat = Matrix([1, 2, 3])
|
||||
B = IndexedBase('B')
|
||||
i = Idx("i", n)
|
||||
A22 = MatrixSymbol('A22',2,2)
|
||||
B22 = MatrixSymbol('B22',2,2)
|
||||
|
||||
|
||||
def test_Assignment():
|
||||
# Here we just do things to show they don't error
|
||||
Assignment(x, y)
|
||||
Assignment(x, 0)
|
||||
Assignment(A, mat)
|
||||
Assignment(A[1,0], 0)
|
||||
Assignment(A[1,0], x)
|
||||
Assignment(B[i], x)
|
||||
Assignment(B[i], 0)
|
||||
a = Assignment(x, y)
|
||||
assert a.func(*a.args) == a
|
||||
assert a.op == ':='
|
||||
# Here we test things to show that they error
|
||||
# Matrix to scalar
|
||||
raises(ValueError, lambda: Assignment(B[i], A))
|
||||
raises(ValueError, lambda: Assignment(B[i], mat))
|
||||
raises(ValueError, lambda: Assignment(x, mat))
|
||||
raises(ValueError, lambda: Assignment(x, A))
|
||||
raises(ValueError, lambda: Assignment(A[1,0], mat))
|
||||
# Scalar to matrix
|
||||
raises(ValueError, lambda: Assignment(A, x))
|
||||
raises(ValueError, lambda: Assignment(A, 0))
|
||||
# Non-atomic lhs
|
||||
raises(TypeError, lambda: Assignment(mat, A))
|
||||
raises(TypeError, lambda: Assignment(0, x))
|
||||
raises(TypeError, lambda: Assignment(x*x, 1))
|
||||
raises(TypeError, lambda: Assignment(A + A, mat))
|
||||
raises(TypeError, lambda: Assignment(B, 0))
|
||||
|
||||
|
||||
def test_AugAssign():
|
||||
# Here we just do things to show they don't error
|
||||
aug_assign(x, '+', y)
|
||||
aug_assign(x, '+', 0)
|
||||
aug_assign(A, '+', mat)
|
||||
aug_assign(A[1, 0], '+', 0)
|
||||
aug_assign(A[1, 0], '+', x)
|
||||
aug_assign(B[i], '+', x)
|
||||
aug_assign(B[i], '+', 0)
|
||||
|
||||
# Check creation via aug_assign vs constructor
|
||||
for binop, cls in [
|
||||
('+', AddAugmentedAssignment),
|
||||
('-', SubAugmentedAssignment),
|
||||
('*', MulAugmentedAssignment),
|
||||
('/', DivAugmentedAssignment),
|
||||
('%', ModAugmentedAssignment),
|
||||
]:
|
||||
a = aug_assign(x, binop, y)
|
||||
b = cls(x, y)
|
||||
assert a.func(*a.args) == a == b
|
||||
assert a.binop == binop
|
||||
assert a.op == binop + '='
|
||||
|
||||
# Here we test things to show that they error
|
||||
# Matrix to scalar
|
||||
raises(ValueError, lambda: aug_assign(B[i], '+', A))
|
||||
raises(ValueError, lambda: aug_assign(B[i], '+', mat))
|
||||
raises(ValueError, lambda: aug_assign(x, '+', mat))
|
||||
raises(ValueError, lambda: aug_assign(x, '+', A))
|
||||
raises(ValueError, lambda: aug_assign(A[1, 0], '+', mat))
|
||||
# Scalar to matrix
|
||||
raises(ValueError, lambda: aug_assign(A, '+', x))
|
||||
raises(ValueError, lambda: aug_assign(A, '+', 0))
|
||||
# Non-atomic lhs
|
||||
raises(TypeError, lambda: aug_assign(mat, '+', A))
|
||||
raises(TypeError, lambda: aug_assign(0, '+', x))
|
||||
raises(TypeError, lambda: aug_assign(x * x, '+', 1))
|
||||
raises(TypeError, lambda: aug_assign(A + A, '+', mat))
|
||||
raises(TypeError, lambda: aug_assign(B, '+', 0))
|
||||
|
||||
|
||||
def test_Assignment_printing():
|
||||
assignment_classes = [
|
||||
Assignment,
|
||||
AddAugmentedAssignment,
|
||||
SubAugmentedAssignment,
|
||||
MulAugmentedAssignment,
|
||||
DivAugmentedAssignment,
|
||||
ModAugmentedAssignment,
|
||||
]
|
||||
pairs = [
|
||||
(x, 2 * y + 2),
|
||||
(B[i], x),
|
||||
(A22, B22),
|
||||
(A[0, 0], x),
|
||||
]
|
||||
|
||||
for cls in assignment_classes:
|
||||
for lhs, rhs in pairs:
|
||||
a = cls(lhs, rhs)
|
||||
assert repr(a) == '%s(%s, %s)' % (cls.__name__, repr(lhs), repr(rhs))
|
||||
|
||||
|
||||
def test_CodeBlock():
|
||||
c = CodeBlock(Assignment(x, 1), Assignment(y, x + 1))
|
||||
assert c.func(*c.args) == c
|
||||
|
||||
assert c.left_hand_sides == Tuple(x, y)
|
||||
assert c.right_hand_sides == Tuple(1, x + 1)
|
||||
|
||||
def test_CodeBlock_topological_sort():
|
||||
assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(t, x),
|
||||
Assignment(y, 2),
|
||||
]
|
||||
|
||||
ordered_assignments = [
|
||||
# Note that the unrelated z=1 and y=2 are kept in that order
|
||||
Assignment(z, 1),
|
||||
Assignment(y, 2),
|
||||
Assignment(x, y + z),
|
||||
Assignment(t, x),
|
||||
]
|
||||
c1 = CodeBlock.topological_sort(assignments)
|
||||
assert c1 == CodeBlock(*ordered_assignments)
|
||||
|
||||
# Cycle
|
||||
invalid_assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(y, x),
|
||||
Assignment(y, 2),
|
||||
]
|
||||
|
||||
raises(ValueError, lambda: CodeBlock.topological_sort(invalid_assignments))
|
||||
|
||||
# Free symbols
|
||||
free_assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, a * b),
|
||||
Assignment(t, x),
|
||||
Assignment(y, b + 3),
|
||||
]
|
||||
|
||||
free_assignments_ordered = [
|
||||
Assignment(z, a * b),
|
||||
Assignment(y, b + 3),
|
||||
Assignment(x, y + z),
|
||||
Assignment(t, x),
|
||||
]
|
||||
|
||||
c2 = CodeBlock.topological_sort(free_assignments)
|
||||
assert c2 == CodeBlock(*free_assignments_ordered)
|
||||
|
||||
def test_CodeBlock_free_symbols():
|
||||
c1 = CodeBlock(
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(t, x),
|
||||
Assignment(y, 2),
|
||||
)
|
||||
assert c1.free_symbols == set()
|
||||
|
||||
c2 = CodeBlock(
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, a * b),
|
||||
Assignment(t, x),
|
||||
Assignment(y, b + 3),
|
||||
)
|
||||
assert c2.free_symbols == {a, b}
|
||||
|
||||
def test_CodeBlock_cse():
|
||||
c1 = CodeBlock(
|
||||
Assignment(y, 1),
|
||||
Assignment(x, sin(y)),
|
||||
Assignment(z, sin(y)),
|
||||
Assignment(t, x*z),
|
||||
)
|
||||
assert c1.cse() == CodeBlock(
|
||||
Assignment(y, 1),
|
||||
Assignment(x0, sin(y)),
|
||||
Assignment(x, x0),
|
||||
Assignment(z, x0),
|
||||
Assignment(t, x*z),
|
||||
)
|
||||
|
||||
# Multiple assignments to same symbol not supported
|
||||
raises(NotImplementedError, lambda: CodeBlock(
|
||||
Assignment(x, 1),
|
||||
Assignment(y, 1), Assignment(y, 2)
|
||||
).cse())
|
||||
|
||||
# Check auto-generated symbols do not collide with existing ones
|
||||
c2 = CodeBlock(
|
||||
Assignment(x0, sin(y) + 1),
|
||||
Assignment(x1, 2 * sin(y)),
|
||||
Assignment(z, x * y),
|
||||
)
|
||||
assert c2.cse() == CodeBlock(
|
||||
Assignment(x2, sin(y)),
|
||||
Assignment(x0, x2 + 1),
|
||||
Assignment(x1, 2 * x2),
|
||||
Assignment(z, x * y),
|
||||
)
|
||||
|
||||
|
||||
def test_CodeBlock_cse__issue_14118():
|
||||
# see https://github.com/sympy/sympy/issues/14118
|
||||
c = CodeBlock(
|
||||
Assignment(A22, Matrix([[x, sin(y)],[3, 4]])),
|
||||
Assignment(B22, Matrix([[sin(y), 2*sin(y)], [sin(y)**2, 7]]))
|
||||
)
|
||||
assert c.cse() == CodeBlock(
|
||||
Assignment(x0, sin(y)),
|
||||
Assignment(A22, Matrix([[x, x0],[3, 4]])),
|
||||
Assignment(B22, Matrix([[x0, 2*x0], [x0**2, 7]]))
|
||||
)
|
||||
|
||||
def test_For():
|
||||
f = For(n, Range(0, 3), (Assignment(A[n, 0], x + n), aug_assign(x, '+', y)))
|
||||
f = For(n, (1, 2, 3, 4, 5), (Assignment(A[n, 0], x + n),))
|
||||
assert f.func(*f.args) == f
|
||||
raises(TypeError, lambda: For(n, x, (x + y,)))
|
||||
|
||||
|
||||
def test_none():
|
||||
assert none.is_Atom
|
||||
assert none == none
|
||||
class Foo(Token):
|
||||
pass
|
||||
foo = Foo()
|
||||
assert foo != none
|
||||
assert none == None
|
||||
assert none == NoneToken()
|
||||
assert none.func(*none.args) == none
|
||||
|
||||
|
||||
def test_String():
|
||||
st = String('foobar')
|
||||
assert st.is_Atom
|
||||
assert st == String('foobar')
|
||||
assert st.text == 'foobar'
|
||||
assert st.func(**st.kwargs()) == st
|
||||
assert st.func(*st.args) == st
|
||||
|
||||
|
||||
class Signifier(String):
|
||||
pass
|
||||
|
||||
si = Signifier('foobar')
|
||||
assert si != st
|
||||
assert si.text == st.text
|
||||
s = String('foo')
|
||||
assert str(s) == 'foo'
|
||||
assert repr(s) == "String('foo')"
|
||||
|
||||
def test_Comment():
|
||||
c = Comment('foobar')
|
||||
assert c.text == 'foobar'
|
||||
assert str(c) == 'foobar'
|
||||
|
||||
def test_Node():
|
||||
n = Node()
|
||||
assert n == Node()
|
||||
assert n.func(*n.args) == n
|
||||
|
||||
|
||||
def test_Type():
|
||||
t = Type('MyType')
|
||||
assert len(t.args) == 1
|
||||
assert t.name == String('MyType')
|
||||
assert str(t) == 'MyType'
|
||||
assert repr(t) == "Type(String('MyType'))"
|
||||
assert Type(t) == t
|
||||
assert t.func(*t.args) == t
|
||||
t1 = Type('t1')
|
||||
t2 = Type('t2')
|
||||
assert t1 != t2
|
||||
assert t1 == t1 and t2 == t2
|
||||
t1b = Type('t1')
|
||||
assert t1 == t1b
|
||||
assert t2 != t1b
|
||||
|
||||
|
||||
def test_Type__from_expr():
|
||||
assert Type.from_expr(i) == integer
|
||||
u = symbols('u', real=True)
|
||||
assert Type.from_expr(u) == real
|
||||
assert Type.from_expr(n) == integer
|
||||
assert Type.from_expr(3) == integer
|
||||
assert Type.from_expr(3.0) == real
|
||||
assert Type.from_expr(3+1j) == complex_
|
||||
raises(ValueError, lambda: Type.from_expr(sum))
|
||||
|
||||
|
||||
def test_Type__cast_check__integers():
|
||||
# Rounding
|
||||
raises(ValueError, lambda: integer.cast_check(3.5))
|
||||
assert integer.cast_check('3') == 3
|
||||
assert integer.cast_check(Float('3.0000000000000000000')) == 3
|
||||
assert integer.cast_check(Float('3.0000000000000000001')) == 3 # unintuitive maybe?
|
||||
|
||||
# Range
|
||||
assert int8.cast_check(127.0) == 127
|
||||
raises(ValueError, lambda: int8.cast_check(128))
|
||||
assert int8.cast_check(-128) == -128
|
||||
raises(ValueError, lambda: int8.cast_check(-129))
|
||||
|
||||
assert uint8.cast_check(0) == 0
|
||||
assert uint8.cast_check(128) == 128
|
||||
raises(ValueError, lambda: uint8.cast_check(256.0))
|
||||
raises(ValueError, lambda: uint8.cast_check(-1))
|
||||
|
||||
def test_Attribute():
|
||||
noexcept = Attribute('noexcept')
|
||||
assert noexcept == Attribute('noexcept')
|
||||
alignas16 = Attribute('alignas', [16])
|
||||
alignas32 = Attribute('alignas', [32])
|
||||
assert alignas16 != alignas32
|
||||
assert alignas16.func(*alignas16.args) == alignas16
|
||||
|
||||
|
||||
def test_Variable():
|
||||
v = Variable(x, type=real)
|
||||
assert v == Variable(v)
|
||||
assert v == Variable('x', type=real)
|
||||
assert v.symbol == x
|
||||
assert v.type == real
|
||||
assert value_const not in v.attrs
|
||||
assert v.func(*v.args) == v
|
||||
assert str(v) == 'Variable(x, type=real)'
|
||||
|
||||
w = Variable(y, f32, attrs={value_const})
|
||||
assert w.symbol == y
|
||||
assert w.type == f32
|
||||
assert value_const in w.attrs
|
||||
assert w.func(*w.args) == w
|
||||
|
||||
v_n = Variable(n, type=Type.from_expr(n))
|
||||
assert v_n.type == integer
|
||||
assert v_n.func(*v_n.args) == v_n
|
||||
v_i = Variable(i, type=Type.from_expr(n))
|
||||
assert v_i.type == integer
|
||||
assert v_i != v_n
|
||||
|
||||
a_i = Variable.deduced(i)
|
||||
assert a_i.type == integer
|
||||
assert Variable.deduced(Symbol('x', real=True)).type == real
|
||||
assert a_i.func(*a_i.args) == a_i
|
||||
|
||||
v_n2 = Variable.deduced(n, value=3.5, cast_check=False)
|
||||
assert v_n2.func(*v_n2.args) == v_n2
|
||||
assert abs(v_n2.value - 3.5) < 1e-15
|
||||
raises(ValueError, lambda: Variable.deduced(n, value=3.5, cast_check=True))
|
||||
|
||||
v_n3 = Variable.deduced(n)
|
||||
assert v_n3.type == integer
|
||||
assert str(v_n3) == 'Variable(n, type=integer)'
|
||||
assert Variable.deduced(z, value=3).type == integer
|
||||
assert Variable.deduced(z, value=3.0).type == real
|
||||
assert Variable.deduced(z, value=3.0+1j).type == complex_
|
||||
|
||||
|
||||
def test_Pointer():
|
||||
p = Pointer(x)
|
||||
assert p.symbol == x
|
||||
assert p.type == untyped
|
||||
assert value_const not in p.attrs
|
||||
assert pointer_const not in p.attrs
|
||||
assert p.func(*p.args) == p
|
||||
|
||||
u = symbols('u', real=True)
|
||||
pu = Pointer(u, type=Type.from_expr(u), attrs={value_const, pointer_const})
|
||||
assert pu.symbol is u
|
||||
assert pu.type == real
|
||||
assert value_const in pu.attrs
|
||||
assert pointer_const in pu.attrs
|
||||
assert pu.func(*pu.args) == pu
|
||||
|
||||
i = symbols('i', integer=True)
|
||||
deref = pu[i]
|
||||
assert deref.indices == (i,)
|
||||
|
||||
|
||||
def test_Declaration():
|
||||
u = symbols('u', real=True)
|
||||
vu = Variable(u, type=Type.from_expr(u))
|
||||
assert Declaration(vu).variable.type == real
|
||||
vn = Variable(n, type=Type.from_expr(n))
|
||||
assert Declaration(vn).variable.type == integer
|
||||
|
||||
# PR 19107, does not allow comparison between expressions and Basic
|
||||
# lt = StrictLessThan(vu, vn)
|
||||
# assert isinstance(lt, StrictLessThan)
|
||||
|
||||
vuc = Variable(u, Type.from_expr(u), value=3.0, attrs={value_const})
|
||||
assert value_const in vuc.attrs
|
||||
assert pointer_const not in vuc.attrs
|
||||
decl = Declaration(vuc)
|
||||
assert decl.variable == vuc
|
||||
assert isinstance(decl.variable.value, Float)
|
||||
assert decl.variable.value == 3.0
|
||||
assert decl.func(*decl.args) == decl
|
||||
assert vuc.as_Declaration() == decl
|
||||
assert vuc.as_Declaration(value=None, attrs=None) == Declaration(vu)
|
||||
|
||||
vy = Variable(y, type=integer, value=3)
|
||||
decl2 = Declaration(vy)
|
||||
assert decl2.variable == vy
|
||||
assert decl2.variable.value == Integer(3)
|
||||
|
||||
vi = Variable(i, type=Type.from_expr(i), value=3.0)
|
||||
decl3 = Declaration(vi)
|
||||
assert decl3.variable.type == integer
|
||||
assert decl3.variable.value == 3.0
|
||||
|
||||
raises(ValueError, lambda: Declaration(vi, 42))
|
||||
|
||||
|
||||
def test_IntBaseType():
|
||||
assert intc.name == String('intc')
|
||||
assert intc.args == (intc.name,)
|
||||
assert str(IntBaseType('a').name) == 'a'
|
||||
|
||||
|
||||
def test_FloatType():
|
||||
assert f16.dig == 3
|
||||
assert f32.dig == 6
|
||||
assert f64.dig == 15
|
||||
assert f80.dig == 18
|
||||
assert f128.dig == 33
|
||||
|
||||
assert f16.decimal_dig == 5
|
||||
assert f32.decimal_dig == 9
|
||||
assert f64.decimal_dig == 17
|
||||
assert f80.decimal_dig == 21
|
||||
assert f128.decimal_dig == 36
|
||||
|
||||
assert f16.max_exponent == 16
|
||||
assert f32.max_exponent == 128
|
||||
assert f64.max_exponent == 1024
|
||||
assert f80.max_exponent == 16384
|
||||
assert f128.max_exponent == 16384
|
||||
|
||||
assert f16.min_exponent == -13
|
||||
assert f32.min_exponent == -125
|
||||
assert f64.min_exponent == -1021
|
||||
assert f80.min_exponent == -16381
|
||||
assert f128.min_exponent == -16381
|
||||
|
||||
assert abs(f16.eps / Float('0.00097656', precision=16) - 1) < 0.1*10**-f16.dig
|
||||
assert abs(f32.eps / Float('1.1920929e-07', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.eps / Float('2.2204460492503131e-16', precision=64) - 1) < 0.1*10**-f64.dig
|
||||
assert abs(f80.eps / Float('1.08420217248550443401e-19', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.eps / Float(' 1.92592994438723585305597794258492732e-34', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
assert abs(f16.max / Float('65504', precision=16) - 1) < .1*10**-f16.dig
|
||||
assert abs(f32.max / Float('3.40282347e+38', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.max / Float('1.79769313486231571e+308', precision=64) - 1) < 0.1*10**-f64.dig # cf. np.finfo(np.float64).max
|
||||
assert abs(f80.max / Float('1.18973149535723176502e+4932', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.max / Float('1.18973149535723176508575932662800702e+4932', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
# cf. np.finfo(np.float32).tiny
|
||||
assert abs(f16.tiny / Float('6.1035e-05', precision=16) - 1) < 0.1*10**-f16.dig
|
||||
assert abs(f32.tiny / Float('1.17549435e-38', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.tiny / Float('2.22507385850720138e-308', precision=64) - 1) < 0.1*10**-f64.dig
|
||||
assert abs(f80.tiny / Float('3.36210314311209350626e-4932', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.tiny / Float('3.3621031431120935062626778173217526e-4932', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
assert f64.cast_check(0.5) == Float(0.5, 17)
|
||||
assert abs(f64.cast_check(3.7) - 3.7) < 3e-17
|
||||
assert isinstance(f64.cast_check(3), (Float, float))
|
||||
|
||||
assert f64.cast_nocheck(oo) == float('inf')
|
||||
assert f64.cast_nocheck(-oo) == float('-inf')
|
||||
assert f64.cast_nocheck(float(oo)) == float('inf')
|
||||
assert f64.cast_nocheck(float(-oo)) == float('-inf')
|
||||
assert math.isnan(f64.cast_nocheck(nan))
|
||||
|
||||
assert f32 != f64
|
||||
assert f64 == f64.func(*f64.args)
|
||||
|
||||
|
||||
def test_Type__cast_check__floating_point():
|
||||
raises(ValueError, lambda: f32.cast_check(123.45678949))
|
||||
raises(ValueError, lambda: f32.cast_check(12.345678949))
|
||||
raises(ValueError, lambda: f32.cast_check(1.2345678949))
|
||||
raises(ValueError, lambda: f32.cast_check(.12345678949))
|
||||
assert abs(123.456789049 - f32.cast_check(123.456789049) - 4.9e-8) < 1e-8
|
||||
assert abs(0.12345678904 - f32.cast_check(0.12345678904) - 4e-11) < 1e-11
|
||||
|
||||
dcm21 = Float('0.123456789012345670499') # 21 decimals
|
||||
assert abs(dcm21 - f64.cast_check(dcm21) - 4.99e-19) < 1e-19
|
||||
|
||||
f80.cast_check(Float('0.12345678901234567890103', precision=88))
|
||||
raises(ValueError, lambda: f80.cast_check(Float('0.12345678901234567890149', precision=88)))
|
||||
|
||||
v10 = 12345.67894
|
||||
raises(ValueError, lambda: f32.cast_check(v10))
|
||||
assert abs(Float(str(v10), precision=64+8) - f64.cast_check(v10)) < v10*1e-16
|
||||
|
||||
assert abs(f32.cast_check(2147483647) - 2147483650) < 1
|
||||
|
||||
|
||||
def test_Type__cast_check__complex_floating_point():
|
||||
val9_11 = 123.456789049 + 0.123456789049j
|
||||
raises(ValueError, lambda: c64.cast_check(.12345678949 + .12345678949j))
|
||||
assert abs(val9_11 - c64.cast_check(val9_11) - 4.9e-8) < 1e-8
|
||||
|
||||
dcm21 = Float('0.123456789012345670499') + 1e-20j # 21 decimals
|
||||
assert abs(dcm21 - c128.cast_check(dcm21) - 4.99e-19) < 1e-19
|
||||
v19 = Float('0.1234567890123456749') + 1j*Float('0.1234567890123456749')
|
||||
raises(ValueError, lambda: c128.cast_check(v19))
|
||||
|
||||
|
||||
def test_While():
|
||||
xpp = AddAugmentedAssignment(x, 1)
|
||||
whl1 = While(x < 2, [xpp])
|
||||
assert whl1.condition.args[0] == x
|
||||
assert whl1.condition.args[1] == 2
|
||||
assert whl1.condition == Lt(x, 2, evaluate=False)
|
||||
assert whl1.body.args == (xpp,)
|
||||
assert whl1.func(*whl1.args) == whl1
|
||||
|
||||
cblk = CodeBlock(AddAugmentedAssignment(x, 1))
|
||||
whl2 = While(x < 2, cblk)
|
||||
assert whl1 == whl2
|
||||
assert whl1 != While(x < 3, [xpp])
|
||||
|
||||
|
||||
def test_Scope():
|
||||
assign = Assignment(x, y)
|
||||
incr = AddAugmentedAssignment(x, 1)
|
||||
scp = Scope([assign, incr])
|
||||
cblk = CodeBlock(assign, incr)
|
||||
assert scp.body == cblk
|
||||
assert scp == Scope(cblk)
|
||||
assert scp != Scope([incr, assign])
|
||||
assert scp.func(*scp.args) == scp
|
||||
|
||||
|
||||
def test_Print():
|
||||
fmt = "%d %.3f"
|
||||
ps = Print([n, x], fmt)
|
||||
assert str(ps.format_string) == fmt
|
||||
assert ps.print_args == Tuple(n, x)
|
||||
assert ps.args == (Tuple(n, x), QuotedString(fmt), none)
|
||||
assert ps == Print((n, x), fmt)
|
||||
assert ps != Print([x, n], fmt)
|
||||
assert ps.func(*ps.args) == ps
|
||||
|
||||
ps2 = Print([n, x])
|
||||
assert ps2 == Print([n, x])
|
||||
assert ps2 != ps
|
||||
assert ps2.format_string == None
|
||||
|
||||
|
||||
def test_FunctionPrototype_and_FunctionDefinition():
|
||||
vx = Variable(x, type=real)
|
||||
vn = Variable(n, type=integer)
|
||||
fp1 = FunctionPrototype(real, 'power', [vx, vn])
|
||||
assert fp1.return_type == real
|
||||
assert fp1.name == String('power')
|
||||
assert fp1.parameters == Tuple(vx, vn)
|
||||
assert fp1 == FunctionPrototype(real, 'power', [vx, vn])
|
||||
assert fp1 != FunctionPrototype(real, 'power', [vn, vx])
|
||||
assert fp1.func(*fp1.args) == fp1
|
||||
|
||||
|
||||
body = [Assignment(x, x**n), Return(x)]
|
||||
fd1 = FunctionDefinition(real, 'power', [vx, vn], body)
|
||||
assert fd1.return_type == real
|
||||
assert str(fd1.name) == 'power'
|
||||
assert fd1.parameters == Tuple(vx, vn)
|
||||
assert fd1.body == CodeBlock(*body)
|
||||
assert fd1 == FunctionDefinition(real, 'power', [vx, vn], body)
|
||||
assert fd1 != FunctionDefinition(real, 'power', [vx, vn], body[::-1])
|
||||
assert fd1.func(*fd1.args) == fd1
|
||||
|
||||
fp2 = FunctionPrototype.from_FunctionDefinition(fd1)
|
||||
assert fp2 == fp1
|
||||
|
||||
fd2 = FunctionDefinition.from_FunctionPrototype(fp1, body)
|
||||
assert fd2 == fd1
|
||||
|
||||
|
||||
def test_Return():
|
||||
rs = Return(x)
|
||||
assert rs.args == (x,)
|
||||
assert rs == Return(x)
|
||||
assert rs != Return(y)
|
||||
assert rs.func(*rs.args) == rs
|
||||
|
||||
|
||||
def test_FunctionCall():
|
||||
fc = FunctionCall('power', (x, 3))
|
||||
assert fc.function_args[0] == x
|
||||
assert fc.function_args[1] == 3
|
||||
assert len(fc.function_args) == 2
|
||||
assert isinstance(fc.function_args[1], Integer)
|
||||
assert fc == FunctionCall('power', (x, 3))
|
||||
assert fc != FunctionCall('power', (3, x))
|
||||
assert fc != FunctionCall('Power', (x, 3))
|
||||
assert fc.func(*fc.args) == fc
|
||||
|
||||
fc2 = FunctionCall('fma', [2, 3, 4])
|
||||
assert len(fc2.function_args) == 3
|
||||
assert fc2.function_args[0] == 2
|
||||
assert fc2.function_args[1] == 3
|
||||
assert fc2.function_args[2] == 4
|
||||
assert str(fc2) in ( # not sure if QuotedString is a better default...
|
||||
'FunctionCall(fma, function_args=(2, 3, 4))',
|
||||
'FunctionCall("fma", function_args=(2, 3, 4))',
|
||||
)
|
||||
|
||||
def test_ast_replace():
|
||||
x = Variable('x', real)
|
||||
y = Variable('y', real)
|
||||
n = Variable('n', integer)
|
||||
|
||||
pwer = FunctionDefinition(real, 'pwer', [x, n], [pow(x.symbol, n.symbol)])
|
||||
pname = pwer.name
|
||||
pcall = FunctionCall('pwer', [y, 3])
|
||||
|
||||
tree1 = CodeBlock(pwer, pcall)
|
||||
assert str(tree1.args[0].name) == 'pwer'
|
||||
assert str(tree1.args[1].name) == 'pwer'
|
||||
for a, b in zip(tree1, [pwer, pcall]):
|
||||
assert a == b
|
||||
|
||||
tree2 = tree1.replace(pname, String('power'))
|
||||
assert str(tree1.args[0].name) == 'pwer'
|
||||
assert str(tree1.args[1].name) == 'pwer'
|
||||
assert str(tree2.args[0].name) == 'power'
|
||||
assert str(tree2.args[1].name) == 'power'
|
||||
@@ -0,0 +1,186 @@
|
||||
from sympy.core.numbers import (Rational, pi)
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.codegen.cfunctions import (
|
||||
expm1, log1p, exp2, log2, fma, log10, Sqrt, Cbrt, hypot, isnan, isinf
|
||||
)
|
||||
from sympy.core.function import expand_log
|
||||
|
||||
|
||||
def test_expm1():
|
||||
# Eval
|
||||
assert expm1(0) == 0
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand and rewrite
|
||||
assert expm1(x).expand(func=True) - exp(x) == -1
|
||||
assert expm1(x).rewrite('tractable') - exp(x) == -1
|
||||
assert expm1(x).rewrite('exp') - exp(x) == -1
|
||||
|
||||
# Precision
|
||||
assert not ((exp(1e-10).evalf() - 1) - 1e-10 - 5e-21) < 1e-22 # for comparison
|
||||
assert abs(expm1(1e-10).evalf() - 1e-10 - 5e-21) < 1e-22
|
||||
|
||||
# Properties
|
||||
assert expm1(x).is_real
|
||||
assert expm1(x).is_finite
|
||||
|
||||
# Diff
|
||||
assert expm1(42*x).diff(x) - 42*exp(42*x) == 0
|
||||
assert expm1(42*x).diff(x) - expm1(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_log1p():
|
||||
# Eval
|
||||
assert log1p(0) == 0
|
||||
d = S(10)
|
||||
assert expand_log(log1p(d**-1000) - log(d**1000 + 1) + log(d**1000)) == 0
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand and rewrite
|
||||
assert log1p(x).expand(func=True) - log(x + 1) == 0
|
||||
assert log1p(x).rewrite('tractable') - log(x + 1) == 0
|
||||
assert log1p(x).rewrite('log') - log(x + 1) == 0
|
||||
|
||||
# Precision
|
||||
assert not abs(log(1e-99 + 1).evalf() - 1e-99) < 1e-100 # for comparison
|
||||
assert abs(expand_log(log1p(1e-99)).evalf() - 1e-99) < 1e-100
|
||||
|
||||
# Properties
|
||||
assert log1p(-2**Rational(-1, 2)).is_real
|
||||
|
||||
assert not log1p(-1).is_finite
|
||||
assert log1p(pi).is_finite
|
||||
|
||||
assert not log1p(x).is_positive
|
||||
assert log1p(Symbol('y', positive=True)).is_positive
|
||||
|
||||
assert not log1p(x).is_zero
|
||||
assert log1p(Symbol('z', zero=True)).is_zero
|
||||
|
||||
assert not log1p(x).is_nonnegative
|
||||
assert log1p(Symbol('o', nonnegative=True)).is_nonnegative
|
||||
|
||||
# Diff
|
||||
assert log1p(42*x).diff(x) - 42/(42*x + 1) == 0
|
||||
assert log1p(42*x).diff(x) - log1p(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_exp2():
|
||||
# Eval
|
||||
assert exp2(2) == 4
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand
|
||||
assert exp2(x).expand(func=True) - 2**x == 0
|
||||
|
||||
# Diff
|
||||
assert exp2(42*x).diff(x) - 42*exp2(42*x)*log(2) == 0
|
||||
assert exp2(42*x).diff(x) - exp2(42*x).diff(x) == 0
|
||||
|
||||
|
||||
def test_log2():
|
||||
# Eval
|
||||
assert log2(8) == 3
|
||||
assert log2(pi) != log(pi)/log(2) # log2 should *save* (CPU) instructions
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
assert log2(x) != log(x)/log(2)
|
||||
assert log2(2**x) == x
|
||||
|
||||
# Expand
|
||||
assert log2(x).expand(func=True) - log(x)/log(2) == 0
|
||||
|
||||
# Diff
|
||||
assert log2(42*x).diff() - 1/(log(2)*x) == 0
|
||||
assert log2(42*x).diff() - log2(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_fma():
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
# Expand
|
||||
assert fma(x, y, z).expand(func=True) - x*y - z == 0
|
||||
|
||||
expr = fma(17*x, 42*y, 101*z)
|
||||
|
||||
# Diff
|
||||
assert expr.diff(x) - expr.expand(func=True).diff(x) == 0
|
||||
assert expr.diff(y) - expr.expand(func=True).diff(y) == 0
|
||||
assert expr.diff(z) - expr.expand(func=True).diff(z) == 0
|
||||
|
||||
assert expr.diff(x) - 17*42*y == 0
|
||||
assert expr.diff(y) - 17*42*x == 0
|
||||
assert expr.diff(z) - 101 == 0
|
||||
|
||||
|
||||
def test_log10():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert log10(x).expand(func=True) - log(x)/log(10) == 0
|
||||
|
||||
# Diff
|
||||
assert log10(42*x).diff(x) - 1/(log(10)*x) == 0
|
||||
assert log10(42*x).diff(x) - log10(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_Cbrt():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert Cbrt(x).expand(func=True) - x**Rational(1, 3) == 0
|
||||
|
||||
# Diff
|
||||
assert Cbrt(42*x).diff(x) - 42*(42*x)**(Rational(1, 3) - 1)/3 == 0
|
||||
assert Cbrt(42*x).diff(x) - Cbrt(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_Sqrt():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert Sqrt(x).expand(func=True) - x**S.Half == 0
|
||||
|
||||
# Diff
|
||||
assert Sqrt(42*x).diff(x) - 42*(42*x)**(S.Half - 1)/2 == 0
|
||||
assert Sqrt(42*x).diff(x) - Sqrt(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_hypot():
|
||||
x, y = symbols('x y')
|
||||
|
||||
# Expand
|
||||
assert hypot(x, y).expand(func=True) - (x**2 + y**2)**S.Half == 0
|
||||
|
||||
# Diff
|
||||
assert hypot(17*x, 42*y).diff(x).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(x) == 0
|
||||
assert hypot(17*x, 42*y).diff(y).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(y) == 0
|
||||
|
||||
assert hypot(17*x, 42*y).diff(x).expand(func=True) - 2*17*17*x*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0
|
||||
assert hypot(17*x, 42*y).diff(y).expand(func=True) - 2*42*42*y*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0
|
||||
|
||||
|
||||
def test_isnan_isinf():
|
||||
x = Symbol('x')
|
||||
|
||||
# isinf
|
||||
assert isinf(+S.Infinity) == True
|
||||
assert isinf(-S.Infinity) == True
|
||||
assert isinf(S.Pi) == False
|
||||
isinfx = isinf(x)
|
||||
assert isinfx not in (False, True)
|
||||
assert isinfx.func is isinf
|
||||
assert isinfx.args == (x,)
|
||||
|
||||
# isnan
|
||||
assert isnan(S.NaN) == True
|
||||
assert isnan(S.Pi) == False
|
||||
isnanx = isnan(x)
|
||||
assert isnanx not in (False, True)
|
||||
assert isnanx.func is isnan
|
||||
assert isnanx.args == (x,)
|
||||
@@ -0,0 +1,112 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.codegen.ast import Declaration, Variable, float64, int64, String, CodeBlock
|
||||
from sympy.codegen.cnodes import (
|
||||
alignof, CommaOperator, goto, Label, PreDecrement, PostDecrement, PreIncrement, PostIncrement,
|
||||
sizeof, union, struct
|
||||
)
|
||||
|
||||
x, y = symbols('x y')
|
||||
|
||||
|
||||
def test_alignof():
|
||||
ax = alignof(x)
|
||||
assert ccode(ax) == 'alignof(x)'
|
||||
assert ax.func(*ax.args) == ax
|
||||
|
||||
|
||||
def test_CommaOperator():
|
||||
expr = CommaOperator(PreIncrement(x), 2*x)
|
||||
assert ccode(expr) == '(++(x), 2*x)'
|
||||
assert expr.func(*expr.args) == expr
|
||||
|
||||
|
||||
def test_goto_Label():
|
||||
s = 'early_exit'
|
||||
g = goto(s)
|
||||
assert g.func(*g.args) == g
|
||||
assert g != goto('foobar')
|
||||
assert ccode(g) == 'goto early_exit'
|
||||
|
||||
l1 = Label(s)
|
||||
assert ccode(l1) == 'early_exit:'
|
||||
assert l1 == Label('early_exit')
|
||||
assert l1 != Label('foobar')
|
||||
|
||||
body = [PreIncrement(x)]
|
||||
l2 = Label(s, body)
|
||||
assert l2.name == String("early_exit")
|
||||
assert l2.body == CodeBlock(PreIncrement(x))
|
||||
assert ccode(l2) == ("early_exit:\n"
|
||||
"++(x);")
|
||||
|
||||
body = [PreIncrement(x), PreDecrement(y)]
|
||||
l2 = Label(s, body)
|
||||
assert l2.name == String("early_exit")
|
||||
assert l2.body == CodeBlock(PreIncrement(x), PreDecrement(y))
|
||||
assert ccode(l2) == ("early_exit:\n"
|
||||
"{\n ++(x);\n --(y);\n}")
|
||||
|
||||
|
||||
def test_PreDecrement():
|
||||
p = PreDecrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '--(x)'
|
||||
|
||||
|
||||
def test_PostDecrement():
|
||||
p = PostDecrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '(x)--'
|
||||
|
||||
|
||||
def test_PreIncrement():
|
||||
p = PreIncrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '++(x)'
|
||||
|
||||
|
||||
def test_PostIncrement():
|
||||
p = PostIncrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '(x)++'
|
||||
|
||||
|
||||
def test_sizeof():
|
||||
typename = 'unsigned int'
|
||||
sz = sizeof(typename)
|
||||
assert ccode(sz) == 'sizeof(%s)' % typename
|
||||
assert sz.func(*sz.args) == sz
|
||||
assert not sz.is_Atom
|
||||
assert sz.atoms() == {String('unsigned int'), String('sizeof')}
|
||||
|
||||
|
||||
def test_struct():
|
||||
vx, vy = Variable(x, type=float64), Variable(y, type=float64)
|
||||
s = struct('vec2', [vx, vy])
|
||||
assert s.func(*s.args) == s
|
||||
assert s == struct('vec2', (vx, vy))
|
||||
assert s != struct('vec2', (vy, vx))
|
||||
assert str(s.name) == 'vec2'
|
||||
assert len(s.declarations) == 2
|
||||
assert all(isinstance(arg, Declaration) for arg in s.declarations)
|
||||
assert ccode(s) == (
|
||||
"struct vec2 {\n"
|
||||
" double x;\n"
|
||||
" double y;\n"
|
||||
"}")
|
||||
|
||||
|
||||
def test_union():
|
||||
vx, vy = Variable(x, type=float64), Variable(y, type=int64)
|
||||
u = union('dualuse', [vx, vy])
|
||||
assert u.func(*u.args) == u
|
||||
assert u == union('dualuse', (vx, vy))
|
||||
assert str(u.name) == 'dualuse'
|
||||
assert len(u.declarations) == 2
|
||||
assert all(isinstance(arg, Declaration) for arg in u.declarations)
|
||||
assert ccode(u) == (
|
||||
"union dualuse {\n"
|
||||
" double x;\n"
|
||||
" int64_t y;\n"
|
||||
"}")
|
||||
@@ -0,0 +1,14 @@
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.codegen.ast import Type
|
||||
from sympy.codegen.cxxnodes import using
|
||||
from sympy.printing.codeprinter import cxxcode
|
||||
|
||||
x = Symbol('x')
|
||||
|
||||
def test_using():
|
||||
v = Type('std::vector')
|
||||
u1 = using(v)
|
||||
assert cxxcode(u1) == 'using std::vector'
|
||||
|
||||
u2 = using(v, 'vec')
|
||||
assert cxxcode(u2) == 'using vec = std::vector'
|
||||
@@ -0,0 +1,213 @@
|
||||
import os
|
||||
import tempfile
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.codegen.ast import (
|
||||
Assignment, Print, Declaration, FunctionDefinition, Return, real,
|
||||
FunctionCall, Variable, Element, integer
|
||||
)
|
||||
from sympy.codegen.fnodes import (
|
||||
allocatable, ArrayConstructor, isign, dsign, cmplx, kind, literal_dp,
|
||||
Program, Module, use, Subroutine, dimension, assumed_extent, ImpliedDoLoop,
|
||||
intent_out, size, Do, SubroutineCall, sum_, array, bind_C
|
||||
)
|
||||
from sympy.codegen.futils import render_as_module
|
||||
from sympy.core.expr import unchanged
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import fcode
|
||||
from sympy.utilities._compilation import has_fortran, compile_run_strings, compile_link_import_strings
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, XFAIL
|
||||
|
||||
cython = import_module('cython')
|
||||
np = import_module('numpy')
|
||||
|
||||
|
||||
def test_size():
|
||||
x = Symbol('x', real=True)
|
||||
sx = size(x)
|
||||
assert fcode(sx, source_format='free') == 'size(x)'
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_size_assumed_shape():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
a = Symbol('a', real=True)
|
||||
body = [Return((sum_(a**2)/size(a))**.5)]
|
||||
arr = array(a, dim=[':'], intent='in')
|
||||
fd = FunctionDefinition(real, 'rms', [arr], body)
|
||||
render_as_module([fd], 'mod_rms')
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('rms.f90', render_as_module([fd], 'mod_rms')),
|
||||
('main.f90', (
|
||||
'program myprog\n'
|
||||
'use mod_rms, only: rms\n'
|
||||
'real*8, dimension(4), parameter :: x = [4, 2, 2, 2]\n'
|
||||
'print "(f7.5)", dsqrt(7d0) - rms(x)\n'
|
||||
'end program\n'
|
||||
))
|
||||
], clean=True)
|
||||
assert '0.00000' in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@XFAIL # https://github.com/sympy/sympy/issues/20265
|
||||
@may_xfail
|
||||
def test_ImpliedDoLoop():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
a, i = symbols('a i', integer=True)
|
||||
idl = ImpliedDoLoop(i**3, i, -3, 3, 2)
|
||||
ac = ArrayConstructor([-28, idl, 28])
|
||||
a = array(a, dim=[':'], attrs=[allocatable])
|
||||
prog = Program('idlprog', [
|
||||
a.as_Declaration(),
|
||||
Assignment(a, ac),
|
||||
Print([a])
|
||||
])
|
||||
fsrc = fcode(prog, standard=2003, source_format='free')
|
||||
(stdout, stderr), info = compile_run_strings([('main.f90', fsrc)], clean=True)
|
||||
for numstr in '-28 -27 -1 1 27 28'.split():
|
||||
assert numstr in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_Program():
|
||||
x = Symbol('x', real=True)
|
||||
vx = Variable.deduced(x, 42)
|
||||
decl = Declaration(vx)
|
||||
prnt = Print([x, x+1])
|
||||
prog = Program('foo', [decl, prnt])
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([('main.f90', fcode(prog, standard=90))], clean=True)
|
||||
assert '42' in stdout
|
||||
assert '43' in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_Module():
|
||||
x = Symbol('x', real=True)
|
||||
v_x = Variable.deduced(x)
|
||||
sq = FunctionDefinition(real, 'sqr', [v_x], [Return(x**2)])
|
||||
mod_sq = Module('mod_sq', [], [sq])
|
||||
sq_call = FunctionCall('sqr', [42.])
|
||||
prg_sq = Program('foobar', [
|
||||
use('mod_sq', only=['sqr']),
|
||||
Print(['"Square of 42 = "', sq_call])
|
||||
])
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('mod_sq.f90', fcode(mod_sq, standard=90)),
|
||||
('main.f90', fcode(prg_sq, standard=90))
|
||||
], clean=True)
|
||||
assert '42' in stdout
|
||||
assert str(42**2) in stdout
|
||||
assert stderr == ''
|
||||
|
||||
|
||||
@XFAIL # https://github.com/sympy/sympy/issues/20265
|
||||
@may_xfail
|
||||
def test_Subroutine():
|
||||
# Code to generate the subroutine in the example from
|
||||
# http://www.fortran90.org/src/best-practices.html#arrays
|
||||
r = Symbol('r', real=True)
|
||||
i = Symbol('i', integer=True)
|
||||
v_r = Variable.deduced(r, attrs=(dimension(assumed_extent), intent_out))
|
||||
v_i = Variable.deduced(i)
|
||||
v_n = Variable('n', integer)
|
||||
do_loop = Do([
|
||||
Assignment(Element(r, [i]), literal_dp(1)/i**2)
|
||||
], i, 1, v_n)
|
||||
sub = Subroutine("f", [v_r], [
|
||||
Declaration(v_n),
|
||||
Declaration(v_i),
|
||||
Assignment(v_n, size(r)),
|
||||
do_loop
|
||||
])
|
||||
x = Symbol('x', real=True)
|
||||
v_x3 = Variable.deduced(x, attrs=[dimension(3)])
|
||||
mod = Module('mymod', definitions=[sub])
|
||||
prog = Program('foo', [
|
||||
use(mod, only=[sub]),
|
||||
Declaration(v_x3),
|
||||
SubroutineCall(sub, [v_x3]),
|
||||
Print([sum_(v_x3), v_x3])
|
||||
])
|
||||
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('a.f90', fcode(mod, standard=90)),
|
||||
('b.f90', fcode(prog, standard=90))
|
||||
], clean=True)
|
||||
ref = [1.0/i**2 for i in range(1, 4)]
|
||||
assert str(sum(ref))[:-3] in stdout
|
||||
for _ in ref:
|
||||
assert str(_)[:-3] in stdout
|
||||
assert stderr == ''
|
||||
|
||||
|
||||
def test_isign():
|
||||
x = Symbol('x', integer=True)
|
||||
assert unchanged(isign, 1, x)
|
||||
assert fcode(isign(1, x), standard=95, source_format='free') == 'isign(1, x)'
|
||||
|
||||
|
||||
def test_dsign():
|
||||
x = Symbol('x')
|
||||
assert unchanged(dsign, 1, x)
|
||||
assert fcode(dsign(literal_dp(1), x), standard=95, source_format='free') == 'dsign(1d0, x)'
|
||||
|
||||
|
||||
def test_cmplx():
|
||||
x = Symbol('x')
|
||||
assert unchanged(cmplx, 1, x)
|
||||
|
||||
|
||||
def test_kind():
|
||||
x = Symbol('x')
|
||||
assert unchanged(kind, x)
|
||||
|
||||
|
||||
def test_literal_dp():
|
||||
assert fcode(literal_dp(0), source_format='free') == '0d0'
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_bind_C():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
if not cython:
|
||||
skip("Cython not found.")
|
||||
if not np:
|
||||
skip("NumPy not found.")
|
||||
|
||||
a = Symbol('a', real=True)
|
||||
s = Symbol('s', integer=True)
|
||||
body = [Return((sum_(a**2)/s)**.5)]
|
||||
arr = array(a, dim=[s], intent='in')
|
||||
fd = FunctionDefinition(real, 'rms', [arr, s], body, attrs=[bind_C('rms')])
|
||||
f_mod = render_as_module([fd], 'mod_rms')
|
||||
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('rms.f90', f_mod),
|
||||
('_rms.pyx', (
|
||||
"#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double rms(double*, int*)\n"
|
||||
"def py_rms(double[::1] x):\n"
|
||||
" cdef int s = x.size\n"
|
||||
" return rms(&x[0], &s)\n"))
|
||||
], build_dir=folder)
|
||||
assert abs(mod.py_rms(np.array([2., 4., 2., 2.])) - 7**0.5) < 1e-14
|
||||
@@ -0,0 +1,50 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.core.function import Function
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.matrices.dense import zeros
|
||||
from sympy.simplify.simplify import simplify
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from sympy.printing.numpy import NumPyPrinter
|
||||
from sympy.testing.pytest import skip
|
||||
from sympy.external import import_module
|
||||
|
||||
|
||||
def test_matrix_solve_issue_24862():
|
||||
A = Matrix(3, 3, symbols('a:9'))
|
||||
b = Matrix(3, 1, symbols('b:3'))
|
||||
hash(MatrixSolve(A, b))
|
||||
|
||||
|
||||
def test_matrix_solve_derivative_exact():
|
||||
q = symbols('q')
|
||||
a11, a12, a21, a22, b1, b2 = (
|
||||
f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
|
||||
A = Matrix([[a11, a12], [a21, a22]])
|
||||
b = Matrix([b1, b2])
|
||||
x_lu = A.LUsolve(b)
|
||||
dxdq_lu = A.LUsolve(b.diff(q) - A.diff(q) * A.LUsolve(b))
|
||||
assert simplify(x_lu.diff(q) - dxdq_lu) == zeros(2, 1)
|
||||
# dxdq_ms is the MatrixSolve equivalent of dxdq_lu
|
||||
dxdq_ms = MatrixSolve(A, b.diff(q) - A.diff(q) * MatrixSolve(A, b))
|
||||
assert MatrixSolve(A, b).diff(q) == dxdq_ms
|
||||
|
||||
|
||||
def test_matrix_solve_derivative_numpy():
|
||||
np = import_module('numpy')
|
||||
if not np:
|
||||
skip("numpy not installed.")
|
||||
q = symbols('q')
|
||||
a11, a12, a21, a22, b1, b2 = (
|
||||
f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
|
||||
A = Matrix([[a11, a12], [a21, a22]])
|
||||
b = Matrix([b1, b2])
|
||||
dx_lu = A.LUsolve(b).diff(q)
|
||||
subs = {a11.diff(q): 0.2, a12.diff(q): 0.3, a21.diff(q): 0.1,
|
||||
a22.diff(q): 0.5, b1.diff(q): 0.4, b2.diff(q): 0.9,
|
||||
a11: 1.3, a12: 0.5, a21: 1.2, a22: 4, b1: 6.2, b2: 3.5}
|
||||
p, p_vals = zip(*subs.items())
|
||||
dx_sm = MatrixSolve(A, b).diff(q)
|
||||
np.testing.assert_allclose(
|
||||
lambdify(p, dx_sm, printer=NumPyPrinter)(*p_vals),
|
||||
lambdify(p, dx_lu, printer=NumPyPrinter)(*p_vals))
|
||||
@@ -0,0 +1,69 @@
|
||||
from itertools import product
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import Max, Min
|
||||
from sympy.printing.repr import srepr
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2, minimum, maximum, amax, amin
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
def test_logaddexp():
|
||||
lae_xy = logaddexp(x, y)
|
||||
ref_xy = log(exp(x) + exp(y))
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
lae_xy.diff(wrt, deriv_order) -
|
||||
ref_xy.diff(wrt, deriv_order)
|
||||
).rewrite(log).simplify() == 0
|
||||
|
||||
one_third_e = 1*exp(1)/3
|
||||
two_thirds_e = 2*exp(1)/3
|
||||
logThirdE = log(one_third_e)
|
||||
logTwoThirdsE = log(two_thirds_e)
|
||||
lae_sum_to_e = logaddexp(logThirdE, logTwoThirdsE)
|
||||
assert lae_sum_to_e.rewrite(log) == 1
|
||||
assert lae_sum_to_e.simplify() == 1
|
||||
was = logaddexp(2, 3)
|
||||
assert srepr(was) == srepr(was.simplify()) # cannot simplify with 2, 3
|
||||
|
||||
|
||||
def test_logaddexp2():
|
||||
lae2_xy = logaddexp2(x, y)
|
||||
ref2_xy = log(2**x + 2**y)/log(2)
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
lae2_xy.diff(wrt, deriv_order) -
|
||||
ref2_xy.diff(wrt, deriv_order)
|
||||
).rewrite(log).cancel() == 0
|
||||
|
||||
def lb(x):
|
||||
return log(x)/log(2)
|
||||
|
||||
two_thirds = S.One*2/3
|
||||
four_thirds = 2*two_thirds
|
||||
lbTwoThirds = lb(two_thirds)
|
||||
lbFourThirds = lb(four_thirds)
|
||||
lae2_sum_to_2 = logaddexp2(lbTwoThirds, lbFourThirds)
|
||||
assert lae2_sum_to_2.rewrite(log) == 1
|
||||
assert lae2_sum_to_2.simplify() == 1
|
||||
was = logaddexp2(x, y)
|
||||
assert srepr(was) == srepr(was.simplify()) # cannot simplify with x, y
|
||||
|
||||
|
||||
def test_minimum_maximum():
|
||||
for MM, mm in zip([Min, Max], [minimum, maximum]):
|
||||
ref = MM(x, y, z)
|
||||
m = mm(x, y, z)
|
||||
assert m != ref
|
||||
assert m.rewrite(MM) == ref
|
||||
|
||||
|
||||
def test_amin_amax():
|
||||
for am in [amin, amax]:
|
||||
assert am(x).array == x
|
||||
assert am(x).axis == None
|
||||
assert am(x, axis=3).axis == 3
|
||||
with raises(ValueError):
|
||||
am(x, y, z)
|
||||
@@ -0,0 +1,13 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.codegen.pynodes import List
|
||||
|
||||
|
||||
def test_List():
|
||||
l = List(2, 3, 4)
|
||||
assert l == List(2, 3, 4)
|
||||
assert str(l) == "[2, 3, 4]"
|
||||
x, y, z = symbols('x y z')
|
||||
l = List(x**2,y**3,z**4)
|
||||
# contrary to python's built-in list, we can call e.g. "replace" on List.
|
||||
m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp)
|
||||
assert m == [x**2, y-3, z-4]
|
||||
@@ -0,0 +1,7 @@
|
||||
from sympy.codegen.ast import Print
|
||||
from sympy.codegen.pyutils import render_as_module
|
||||
|
||||
def test_standard():
|
||||
ast = Print('x y'.split(), r"coordinate: %12.5g %12.5g\n")
|
||||
assert render_as_module(ast, standard='python3') == \
|
||||
'\n\nprint("coordinate: %12.5g %12.5g\\n" % (x, y), end="")'
|
||||
@@ -0,0 +1,479 @@
|
||||
import tempfile
|
||||
from sympy.core.numbers import pi, Rational
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin, sinc)
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.assumptions import assuming, Q
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.codegen.cfunctions import log2, exp2, expm1, log1p
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
|
||||
from sympy.codegen.scipy_nodes import cosm1, powm1
|
||||
from sympy.codegen.rewriting import (
|
||||
optimize, cosm1_opt, log2_opt, exp2_opt, expm1_opt, log1p_opt, powm1_opt, optims_c99,
|
||||
create_expand_pow_optimization, matinv_opt, logaddexp_opt, logaddexp2_opt,
|
||||
optims_numpy, optims_scipy, sinc_opts, FuncMinusOneOptim
|
||||
)
|
||||
from sympy.testing.pytest import XFAIL, skip
|
||||
from sympy.utilities import lambdify
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
|
||||
cython = import_module('cython')
|
||||
numpy = import_module('numpy')
|
||||
scipy = import_module('scipy')
|
||||
|
||||
|
||||
def test_log2_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = 7*log(3*x + 5)/(log(2))
|
||||
opt1 = optimize(expr1, [log2_opt])
|
||||
assert opt1 == 7*log2(3*x + 5)
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
expr2 = 3*log(5*x + 7)/(13*log(2))
|
||||
opt2 = optimize(expr2, [log2_opt])
|
||||
assert opt2 == 3*log2(5*x + 7)/13
|
||||
assert opt2.rewrite(log) == expr2
|
||||
|
||||
expr3 = log(x)/log(2)
|
||||
opt3 = optimize(expr3, [log2_opt])
|
||||
assert opt3 == log2(x)
|
||||
assert opt3.rewrite(log) == expr3
|
||||
|
||||
expr4 = log(x)/log(2) + log(x+1)
|
||||
opt4 = optimize(expr4, [log2_opt])
|
||||
assert opt4 == log2(x) + log(2)*log2(x+1)
|
||||
assert opt4.rewrite(log) == expr4
|
||||
|
||||
expr5 = log(17)
|
||||
opt5 = optimize(expr5, [log2_opt])
|
||||
assert opt5 == expr5
|
||||
|
||||
expr6 = log(x + 3)/log(2)
|
||||
opt6 = optimize(expr6, [log2_opt])
|
||||
assert str(opt6) == 'log2(x + 3)'
|
||||
assert opt6.rewrite(log) == expr6
|
||||
|
||||
|
||||
def test_exp2_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = 1 + 2**x
|
||||
opt1 = optimize(expr1, [exp2_opt])
|
||||
assert opt1 == 1 + exp2(x)
|
||||
assert opt1.rewrite(Pow) == expr1
|
||||
|
||||
expr2 = 1 + 3**x
|
||||
assert expr2 == optimize(expr2, [exp2_opt])
|
||||
|
||||
|
||||
def test_expm1_opt():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = exp(x) - 1
|
||||
opt1 = optimize(expr1, [expm1_opt])
|
||||
assert expm1(x) - opt1 == 0
|
||||
assert opt1.rewrite(exp) == expr1
|
||||
|
||||
expr2 = 3*exp(x) - 3
|
||||
opt2 = optimize(expr2, [expm1_opt])
|
||||
assert 3*expm1(x) == opt2
|
||||
assert opt2.rewrite(exp) == expr2
|
||||
|
||||
expr3 = 3*exp(x) - 5
|
||||
opt3 = optimize(expr3, [expm1_opt])
|
||||
assert 3*expm1(x) - 2 == opt3
|
||||
assert opt3.rewrite(exp) == expr3
|
||||
expm1_opt_non_opportunistic = FuncMinusOneOptim(exp, expm1, opportunistic=False)
|
||||
assert expr3 == optimize(expr3, [expm1_opt_non_opportunistic])
|
||||
assert opt1 == optimize(expr1, [expm1_opt_non_opportunistic])
|
||||
assert opt2 == optimize(expr2, [expm1_opt_non_opportunistic])
|
||||
|
||||
expr4 = 3*exp(x) + log(x) - 3
|
||||
opt4 = optimize(expr4, [expm1_opt])
|
||||
assert 3*expm1(x) + log(x) == opt4
|
||||
assert opt4.rewrite(exp) == expr4
|
||||
|
||||
expr5 = 3*exp(2*x) - 3
|
||||
opt5 = optimize(expr5, [expm1_opt])
|
||||
assert 3*expm1(2*x) == opt5
|
||||
assert opt5.rewrite(exp) == expr5
|
||||
|
||||
expr6 = (2*exp(x) + 1)/(exp(x) + 1) + 1
|
||||
opt6 = optimize(expr6, [expm1_opt])
|
||||
assert opt6.count_ops() <= expr6.count_ops()
|
||||
|
||||
def ev(e):
|
||||
return e.subs(x, 3).evalf()
|
||||
assert abs(ev(expr6) - ev(opt6)) < 1e-15
|
||||
|
||||
y = Symbol('y')
|
||||
expr7 = (2*exp(x) - 1)/(1 - exp(y)) - 1/(1-exp(y))
|
||||
opt7 = optimize(expr7, [expm1_opt])
|
||||
assert -2*expm1(x)/expm1(y) == opt7
|
||||
assert (opt7.rewrite(exp) - expr7).factor() == 0
|
||||
|
||||
expr8 = (1+exp(x))**2 - 4
|
||||
opt8 = optimize(expr8, [expm1_opt])
|
||||
tgt8a = (exp(x) + 3)*expm1(x)
|
||||
tgt8b = 2*expm1(x) + expm1(2*x)
|
||||
# Both tgt8a & tgt8b seem to give full precision (~16 digits for double)
|
||||
# for x=1e-7 (compare with expr8 which only achieves ~8 significant digits).
|
||||
# If we can show that either tgt8a or tgt8b is preferable, we can
|
||||
# change this test to ensure the preferable version is returned.
|
||||
assert (tgt8a - tgt8b).rewrite(exp).factor() == 0
|
||||
assert opt8 in (tgt8a, tgt8b)
|
||||
assert (opt8.rewrite(exp) - expr8).factor() == 0
|
||||
|
||||
expr9 = sin(expr8)
|
||||
opt9 = optimize(expr9, [expm1_opt])
|
||||
tgt9a = sin(tgt8a)
|
||||
tgt9b = sin(tgt8b)
|
||||
assert opt9 in (tgt9a, tgt9b)
|
||||
assert (opt9.rewrite(exp) - expr9.rewrite(exp)).factor().is_zero
|
||||
|
||||
|
||||
def test_expm1_two_exp_terms():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = exp(x) + exp(y) - 2
|
||||
opt1 = optimize(expr1, [expm1_opt])
|
||||
assert opt1 == expm1(x) + expm1(y)
|
||||
|
||||
|
||||
def test_cosm1_opt():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = cos(x) - 1
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert cosm1(x) - opt1 == 0
|
||||
assert opt1.rewrite(cos) == expr1
|
||||
|
||||
expr2 = 3*cos(x) - 3
|
||||
opt2 = optimize(expr2, [cosm1_opt])
|
||||
assert 3*cosm1(x) == opt2
|
||||
assert opt2.rewrite(cos) == expr2
|
||||
|
||||
expr3 = 3*cos(x) - 5
|
||||
opt3 = optimize(expr3, [cosm1_opt])
|
||||
assert 3*cosm1(x) - 2 == opt3
|
||||
assert opt3.rewrite(cos) == expr3
|
||||
cosm1_opt_non_opportunistic = FuncMinusOneOptim(cos, cosm1, opportunistic=False)
|
||||
assert expr3 == optimize(expr3, [cosm1_opt_non_opportunistic])
|
||||
assert opt1 == optimize(expr1, [cosm1_opt_non_opportunistic])
|
||||
assert opt2 == optimize(expr2, [cosm1_opt_non_opportunistic])
|
||||
|
||||
expr4 = 3*cos(x) + log(x) - 3
|
||||
opt4 = optimize(expr4, [cosm1_opt])
|
||||
assert 3*cosm1(x) + log(x) == opt4
|
||||
assert opt4.rewrite(cos) == expr4
|
||||
|
||||
expr5 = 3*cos(2*x) - 3
|
||||
opt5 = optimize(expr5, [cosm1_opt])
|
||||
assert 3*cosm1(2*x) == opt5
|
||||
assert opt5.rewrite(cos) == expr5
|
||||
|
||||
expr6 = 2 - 2*cos(x)
|
||||
opt6 = optimize(expr6, [cosm1_opt])
|
||||
assert -2*cosm1(x) == opt6
|
||||
assert opt6.rewrite(cos) == expr6
|
||||
|
||||
|
||||
def test_cosm1_two_cos_terms():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = cos(x) + cos(y) - 2
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert opt1 == cosm1(x) + cosm1(y)
|
||||
|
||||
|
||||
def test_expm1_cosm1_mixed():
|
||||
x = Symbol('x')
|
||||
expr1 = exp(x) + cos(x) - 2
|
||||
opt1 = optimize(expr1, [expm1_opt, cosm1_opt])
|
||||
assert opt1 == cosm1(x) + expm1(x)
|
||||
|
||||
|
||||
def _check_num_lambdify(expr, opt, val_subs, approx_ref, lambdify_kw=None, poorness=1e10):
|
||||
""" poorness=1e10 signifies that `expr` loses precision of at least ten decimal digits. """
|
||||
num_ref = expr.subs(val_subs).evalf()
|
||||
eps = numpy.finfo(numpy.float64).eps
|
||||
assert abs(num_ref - approx_ref) < approx_ref*eps
|
||||
f1 = lambdify(list(val_subs.keys()), opt, **(lambdify_kw or {}))
|
||||
args_float = tuple(map(float, val_subs.values()))
|
||||
num_err1 = abs(f1(*args_float) - approx_ref)
|
||||
assert num_err1 < abs(num_ref*eps)
|
||||
f2 = lambdify(list(val_subs.keys()), expr, **(lambdify_kw or {}))
|
||||
num_err2 = abs(f2(*args_float) - approx_ref)
|
||||
assert num_err2 > abs(num_ref*eps*poorness) # this only ensures that the *test* works as intended
|
||||
|
||||
|
||||
def test_cosm1_apart():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = 1/cos(x) - 1
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert opt1 == -cosm1(x)/cos(x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr1, opt1, {x: S(10)**-30}, 5e-61, lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
expr2 = 2/cos(x) - 2
|
||||
opt2 = optimize(expr2, optims_scipy)
|
||||
assert opt2 == -2*cosm1(x)/cos(x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr2, opt2, {x: S(10)**-30}, 1e-60, lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
expr3 = pi/cos(3*x) - pi
|
||||
opt3 = optimize(expr3, [cosm1_opt])
|
||||
assert opt3 == -pi*cosm1(3*x)/cos(3*x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr3, opt3, {x: S(10)**-30/3}, float(5e-61*pi), lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
|
||||
def test_powm1():
|
||||
args = x, y = map(Symbol, "xy")
|
||||
|
||||
expr1 = x**y - 1
|
||||
opt1 = optimize(expr1, [powm1_opt])
|
||||
assert opt1 == powm1(x, y)
|
||||
for arg in args:
|
||||
assert expr1.diff(arg) == opt1.diff(arg)
|
||||
if scipy and tuple(map(int, scipy.version.version.split('.')[:3])) >= (1, 10, 0):
|
||||
subs1_a = {x: Rational(*(1.0+1e-13).as_integer_ratio()), y: pi}
|
||||
ref1_f64_a = 3.139081648208105e-13
|
||||
_check_num_lambdify(expr1, opt1, subs1_a, ref1_f64_a, lambdify_kw={"modules": 'scipy'}, poorness=10**11)
|
||||
|
||||
subs1_b = {x: pi, y: Rational(*(1e-10).as_integer_ratio())}
|
||||
ref1_f64_b = 1.1447298859149205e-10
|
||||
_check_num_lambdify(expr1, opt1, subs1_b, ref1_f64_b, lambdify_kw={"modules": 'scipy'}, poorness=10**9)
|
||||
|
||||
|
||||
def test_log1p_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = log(x + 1)
|
||||
opt1 = optimize(expr1, [log1p_opt])
|
||||
assert log1p(x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
expr2 = log(3*x + 3)
|
||||
opt2 = optimize(expr2, [log1p_opt])
|
||||
assert log1p(x) + log(3) == opt2
|
||||
assert (opt2.rewrite(log) - expr2).simplify() == 0
|
||||
|
||||
expr3 = log(2*x + 1)
|
||||
opt3 = optimize(expr3, [log1p_opt])
|
||||
assert log1p(2*x) - opt3 == 0
|
||||
assert opt3.rewrite(log) == expr3
|
||||
|
||||
expr4 = log(x+3)
|
||||
opt4 = optimize(expr4, [log1p_opt])
|
||||
assert str(opt4) == 'log(x + 3)'
|
||||
|
||||
|
||||
def test_optims_c99():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1
|
||||
opt1 = optimize(expr1, optims_c99).simplify()
|
||||
assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x)
|
||||
assert opt1.rewrite(exp).rewrite(log).rewrite(Pow) == expr1
|
||||
|
||||
expr2 = log(x)/log(2) + log(x + 1)
|
||||
opt2 = optimize(expr2, optims_c99)
|
||||
assert opt2 == log2(x) + log1p(x)
|
||||
assert opt2.rewrite(log) == expr2
|
||||
|
||||
expr3 = log(x)/log(2) + log(17*x + 17)
|
||||
opt3 = optimize(expr3, optims_c99)
|
||||
delta3 = opt3 - (log2(x) + log(17) + log1p(x))
|
||||
assert delta3 == 0
|
||||
assert (opt3.rewrite(log) - expr3).simplify() == 0
|
||||
|
||||
expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17)
|
||||
opt4 = optimize(expr4, optims_c99).simplify()
|
||||
delta4 = opt4 - (exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x))
|
||||
assert delta4 == 0
|
||||
assert (opt4.rewrite(exp).rewrite(log).rewrite(Pow) - expr4).simplify() == 0
|
||||
|
||||
expr5 = 3*exp(2*x) - 3
|
||||
opt5 = optimize(expr5, optims_c99)
|
||||
delta5 = opt5 - 3*expm1(2*x)
|
||||
assert delta5 == 0
|
||||
assert opt5.rewrite(exp) == expr5
|
||||
|
||||
expr6 = exp(2*x) - 3
|
||||
opt6 = optimize(expr6, optims_c99)
|
||||
assert opt6 in (expm1(2*x) - 2, expr6) # expm1(2*x) - 2 is not better or worse
|
||||
|
||||
expr7 = log(3*x + 3)
|
||||
opt7 = optimize(expr7, optims_c99)
|
||||
delta7 = opt7 - (log(3) + log1p(x))
|
||||
assert delta7 == 0
|
||||
assert (opt7.rewrite(log) - expr7).simplify() == 0
|
||||
|
||||
expr8 = log(2*x + 3)
|
||||
opt8 = optimize(expr8, optims_c99)
|
||||
assert opt8 == expr8
|
||||
|
||||
|
||||
def test_create_expand_pow_optimization():
|
||||
cc = lambda x: ccode(
|
||||
optimize(x, [create_expand_pow_optimization(4)]))
|
||||
x = Symbol('x')
|
||||
assert cc(x**4) == 'x*x*x*x'
|
||||
assert cc(x**4 + x**2) == 'x*x + x*x*x*x'
|
||||
assert cc(x**5 + x**4) == 'pow(x, 5) + x*x*x*x'
|
||||
assert cc(sin(x)**4) == 'pow(sin(x), 4)'
|
||||
# gh issue 15335
|
||||
assert cc(x**(-4)) == '1.0/(x*x*x*x)'
|
||||
assert cc(x**(-5)) == 'pow(x, -5)'
|
||||
assert cc(-x**4) == '-(x*x*x*x)'
|
||||
assert cc(x**4 - x**2) == '-(x*x) + x*x*x*x'
|
||||
i = Symbol('i', integer=True)
|
||||
assert cc(x**i - x**2) == 'pow(x, i) - (x*x)'
|
||||
y = Symbol('y', real=True)
|
||||
assert cc(Abs(exp(y**4))) == "exp(y*y*y*y)"
|
||||
|
||||
# gh issue 20753
|
||||
cc2 = lambda x: ccode(optimize(x, [create_expand_pow_optimization(
|
||||
4, base_req=lambda b: b.is_Function)]))
|
||||
assert cc2(x**3 + sin(x)**3) == "pow(x, 3) + sin(x)*sin(x)*sin(x)"
|
||||
|
||||
|
||||
def test_matsolve():
|
||||
n = Symbol('n', integer=True)
|
||||
A = MatrixSymbol('A', n, n)
|
||||
x = MatrixSymbol('x', n, 1)
|
||||
|
||||
with assuming(Q.fullrank(A)):
|
||||
assert optimize(A**(-1) * x, [matinv_opt]) == MatrixSolve(A, x)
|
||||
assert optimize(A**(-1) * x + x, [matinv_opt]) == MatrixSolve(A, x) + x
|
||||
|
||||
|
||||
def test_logaddexp_opt():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = log(exp(x) + exp(y))
|
||||
opt1 = optimize(expr1, [logaddexp_opt])
|
||||
assert logaddexp(x, y) - opt1 == 0
|
||||
assert logaddexp(y, x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
|
||||
def test_logaddexp2_opt():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = log(2**x + 2**y)/log(2)
|
||||
opt1 = optimize(expr1, [logaddexp2_opt])
|
||||
assert logaddexp2(x, y) - opt1 == 0
|
||||
assert logaddexp2(y, x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
|
||||
def test_sinc_opts():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, sinc_opts) == v
|
||||
|
||||
x = Symbol('x')
|
||||
check({
|
||||
sin(x)/x : sinc(x),
|
||||
sin(2*x)/(2*x) : sinc(2*x),
|
||||
sin(3*x)/x : 3*sinc(3*x),
|
||||
x*sin(x) : x*sin(x)
|
||||
})
|
||||
|
||||
y = Symbol('y')
|
||||
check({
|
||||
sin(x*y)/(x*y) : sinc(x*y),
|
||||
y*sin(x/y)/x : sinc(x/y),
|
||||
sin(sin(x))/sin(x) : sinc(sin(x)),
|
||||
sin(3*sin(x))/sin(x) : 3*sinc(3*sin(x)),
|
||||
sin(x)/y : sin(x)/y
|
||||
})
|
||||
|
||||
|
||||
def test_optims_numpy():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, optims_numpy) == v
|
||||
|
||||
x = Symbol('x')
|
||||
check({
|
||||
sin(2*x)/(2*x) + exp(2*x) - 1: sinc(2*x) + expm1(2*x),
|
||||
log(x+3)/log(2) + log(x**2 + 1): log1p(x**2) + log2(x+3)
|
||||
})
|
||||
|
||||
|
||||
@XFAIL # room for improvement, ideally this test case should pass.
|
||||
def test_optims_numpy_TODO():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, optims_numpy) == v
|
||||
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
check({
|
||||
log(x*y)*sin(x*y)*log(x*y+1)/(log(2)*x*y): log2(x*y)*sinc(x*y)*log1p(x*y),
|
||||
exp(x*sin(y)/y) - 1: expm1(x*sinc(y))
|
||||
})
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_compiled_ccode_with_rewriting():
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
|
||||
x = Symbol('x')
|
||||
about_two = 2**(58/S(117))*3**(97/S(117))*5**(4/S(39))*7**(92/S(117))/S(30)*pi
|
||||
# about_two: 1.999999999999581826
|
||||
unchanged = 2*exp(x) - about_two
|
||||
xval = S(10)**-11
|
||||
ref = unchanged.subs(x, xval).n(19) # 2.0418173913673213e-11
|
||||
|
||||
rewritten = optimize(2*exp(x) - about_two, [expm1_opt])
|
||||
|
||||
# Unfortunately, we need to call ``.n()`` on our expressions before we hand them
|
||||
# to ``ccode``, and we need to request a large number of significant digits.
|
||||
# In this test, results converged for double precision when the following number
|
||||
# of significant digits were chosen:
|
||||
NUMBER_OF_DIGITS = 25 # TODO: this should ideally be automatically handled.
|
||||
|
||||
func_c = '''
|
||||
#include <math.h>
|
||||
|
||||
double func_unchanged(double x) {
|
||||
return %(unchanged)s;
|
||||
}
|
||||
double func_rewritten(double x) {
|
||||
return %(rewritten)s;
|
||||
}
|
||||
''' % {"unchanged": ccode(unchanged.n(NUMBER_OF_DIGITS)),
|
||||
"rewritten": ccode(rewritten.n(NUMBER_OF_DIGITS))}
|
||||
|
||||
func_pyx = '''
|
||||
#cython: language_level=3
|
||||
cdef extern double func_unchanged(double)
|
||||
cdef extern double func_rewritten(double)
|
||||
def py_unchanged(x):
|
||||
return func_unchanged(x)
|
||||
def py_rewritten(x):
|
||||
return func_rewritten(x)
|
||||
'''
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings(
|
||||
[('func.c', func_c), ('_func.pyx', func_pyx)],
|
||||
build_dir=folder, compile_kwargs={"std": 'c99'}
|
||||
)
|
||||
err_rewritten = abs(mod.py_rewritten(1e-11) - ref)
|
||||
err_unchanged = abs(mod.py_unchanged(1e-11) - ref)
|
||||
assert 1e-27 < err_rewritten < 1e-25 # highly accurate.
|
||||
assert 1e-19 < err_unchanged < 1e-16 # quite poor.
|
||||
|
||||
# Tolerances used above were determined as follows:
|
||||
# >>> no_opt = unchanged.subs(x, xval.evalf()).evalf()
|
||||
# >>> with_opt = rewritten.n(25).subs(x, 1e-11).evalf()
|
||||
# >>> with_opt - ref, no_opt - ref
|
||||
# (1.1536301877952077e-26, 1.6547074214222335e-18)
|
||||
@@ -0,0 +1,44 @@
|
||||
from itertools import product
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import exp, log
|
||||
from sympy.functions.elementary.trigonometric import cos
|
||||
from sympy.core.numbers import pi
|
||||
from sympy.codegen.scipy_nodes import cosm1, powm1
|
||||
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
|
||||
def test_cosm1():
|
||||
cm1_xy = cosm1(x*y)
|
||||
ref_xy = cos(x*y) - 1
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
cm1_xy.diff(wrt, deriv_order) -
|
||||
ref_xy.diff(wrt, deriv_order)
|
||||
).rewrite(cos).simplify() == 0
|
||||
|
||||
expr_minus2 = cosm1(pi)
|
||||
assert expr_minus2.rewrite(cos) == -2
|
||||
assert cosm1(3.14).simplify() == cosm1(3.14) # cannot simplify with 3.14
|
||||
assert cosm1(pi/2).simplify() == -1
|
||||
assert (1/cos(x) - 1 + cosm1(x)/cos(x)).simplify() == 0
|
||||
|
||||
|
||||
def test_powm1():
|
||||
cases = {
|
||||
powm1(x, y): x**y - 1,
|
||||
powm1(x*y, z): (x*y)**z - 1,
|
||||
powm1(x, y*z): x**(y*z)-1,
|
||||
powm1(x*y*z, x*y*z): (x*y*z)**(x*y*z)-1
|
||||
}
|
||||
for pm1_e, ref_e in cases.items():
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
der = pm1_e.diff(wrt, deriv_order)
|
||||
ref = ref_e.diff(wrt, deriv_order)
|
||||
delta = (der - ref).rewrite(Pow)
|
||||
assert delta.simplify() == 0
|
||||
|
||||
eulers_constant_m1 = powm1(x, 1/log(x))
|
||||
assert eulers_constant_m1.rewrite(Pow) == exp(1) - 1
|
||||
assert eulers_constant_m1.simplify() == exp(1) - 1
|
||||
Reference in New Issue
Block a user