chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,24 @@
|
||||
""" The ``sympy.codegen`` module contains classes and functions for building
|
||||
abstract syntax trees of algorithms. These trees may then be printed by the
|
||||
code-printers in ``sympy.printing``.
|
||||
|
||||
There are several submodules available:
|
||||
- ``sympy.codegen.ast``: AST nodes useful across multiple languages.
|
||||
- ``sympy.codegen.cnodes``: AST nodes useful for the C family of languages.
|
||||
- ``sympy.codegen.fnodes``: AST nodes useful for Fortran.
|
||||
- ``sympy.codegen.cfunctions``: functions specific to C (C99 math functions)
|
||||
- ``sympy.codegen.ffunctions``: functions specific to Fortran (e.g. ``kind``).
|
||||
|
||||
|
||||
|
||||
"""
|
||||
from .ast import (
|
||||
Assignment, aug_assign, CodeBlock, For, Attribute, Variable, Declaration,
|
||||
While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Assignment', 'aug_assign', 'CodeBlock', 'For', 'Attribute', 'Variable',
|
||||
'Declaration', 'While', 'Scope', 'Print', 'FunctionPrototype',
|
||||
'FunctionDefinition', 'FunctionCall',
|
||||
]
|
||||
@@ -0,0 +1,18 @@
|
||||
"""This module provides containers for python objects that are valid
|
||||
printing targets but are not a subclass of SymPy's Printable.
|
||||
"""
|
||||
|
||||
|
||||
from sympy.core.containers import Tuple
|
||||
|
||||
|
||||
class List(Tuple):
|
||||
"""Represents a (frozen) (Python) list (for code printing purposes)."""
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, list):
|
||||
return self == List(*other)
|
||||
else:
|
||||
return self.args == other
|
||||
|
||||
def __hash__(self):
|
||||
return super().__hash__()
|
||||
@@ -0,0 +1,180 @@
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.numbers import oo
|
||||
from sympy.core.relational import (Gt, Lt)
|
||||
from sympy.core.symbol import (Dummy, Symbol)
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.miscellaneous import Min, Max
|
||||
from sympy.logic.boolalg import And
|
||||
from sympy.codegen.ast import (
|
||||
Assignment, AddAugmentedAssignment, break_, CodeBlock, Declaration, FunctionDefinition,
|
||||
Print, Return, Scope, While, Variable, Pointer, real
|
||||
)
|
||||
from sympy.codegen.cfunctions import isnan
|
||||
|
||||
""" This module collects functions for constructing ASTs representing algorithms. """
|
||||
|
||||
def newtons_method(expr, wrt, atol=1e-12, delta=None, *, rtol=4e-16, debug=False,
|
||||
itermax=None, counter=None, delta_fn=lambda e, x: -e/e.diff(x),
|
||||
cse=False, handle_nan=None,
|
||||
bounds=None):
|
||||
""" Generates an AST for Newton-Raphson method (a root-finding algorithm).
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Returns an abstract syntax tree (AST) based on ``sympy.codegen.ast`` for Netwon's
|
||||
method of root-finding.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : expression
|
||||
wrt : Symbol
|
||||
With respect to, i.e. what is the variable.
|
||||
atol : number or expression
|
||||
Absolute tolerance (stopping criterion)
|
||||
rtol : number or expression
|
||||
Relative tolerance (stopping criterion)
|
||||
delta : Symbol
|
||||
Will be a ``Dummy`` if ``None``.
|
||||
debug : bool
|
||||
Whether to print convergence information during iterations
|
||||
itermax : number or expr
|
||||
Maximum number of iterations.
|
||||
counter : Symbol
|
||||
Will be a ``Dummy`` if ``None``.
|
||||
delta_fn: Callable[[Expr, Symbol], Expr]
|
||||
computes the step, default is newtons method. For e.g. Halley's method
|
||||
use delta_fn=lambda e, x: -2*e*e.diff(x)/(2*e.diff(x)**2 - e*e.diff(x, 2))
|
||||
cse: bool
|
||||
Perform common sub-expression elimination on delta expression
|
||||
handle_nan: Token
|
||||
How to handle occurrence of not-a-number (NaN).
|
||||
bounds: Optional[tuple[Expr, Expr]]
|
||||
Perform optimization within bounds
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import symbols, cos
|
||||
>>> from sympy.codegen.ast import Assignment
|
||||
>>> from sympy.codegen.algorithms import newtons_method
|
||||
>>> x, dx, atol = symbols('x dx atol')
|
||||
>>> expr = cos(x) - x**3
|
||||
>>> algo = newtons_method(expr, x, atol=atol, delta=dx)
|
||||
>>> algo.has(Assignment(dx, -expr/expr.diff(x)))
|
||||
True
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [1] https://en.wikipedia.org/wiki/Newton%27s_method
|
||||
|
||||
"""
|
||||
|
||||
if delta is None:
|
||||
delta = Dummy()
|
||||
Wrapper = Scope
|
||||
name_d = 'delta'
|
||||
else:
|
||||
Wrapper = lambda x: x
|
||||
name_d = delta.name
|
||||
|
||||
delta_expr = delta_fn(expr, wrt)
|
||||
if cse:
|
||||
from sympy.simplify.cse_main import cse
|
||||
cses, (red,) = cse([delta_expr.factor()])
|
||||
whl_bdy = [Assignment(dum, sub_e) for dum, sub_e in cses]
|
||||
whl_bdy += [Assignment(delta, red)]
|
||||
else:
|
||||
whl_bdy = [Assignment(delta, delta_expr)]
|
||||
if handle_nan is not None:
|
||||
whl_bdy += [While(isnan(delta), CodeBlock(handle_nan, break_))]
|
||||
whl_bdy += [AddAugmentedAssignment(wrt, delta)]
|
||||
if bounds is not None:
|
||||
whl_bdy += [Assignment(wrt, Min(Max(wrt, bounds[0]), bounds[1]))]
|
||||
if debug:
|
||||
prnt = Print([wrt, delta], r"{}=%12.5g {}=%12.5g\n".format(wrt.name, name_d))
|
||||
whl_bdy += [prnt]
|
||||
req = Gt(Abs(delta), atol + rtol*Abs(wrt))
|
||||
declars = [Declaration(Variable(delta, type=real, value=oo))]
|
||||
if itermax is not None:
|
||||
counter = counter or Dummy(integer=True)
|
||||
v_counter = Variable.deduced(counter, 0)
|
||||
declars.append(Declaration(v_counter))
|
||||
whl_bdy.append(AddAugmentedAssignment(counter, 1))
|
||||
req = And(req, Lt(counter, itermax))
|
||||
whl = While(req, CodeBlock(*whl_bdy))
|
||||
blck = declars
|
||||
if debug:
|
||||
blck.append(Print([wrt], r"{}=%12.5g\n".format(wrt.name)))
|
||||
blck += [whl]
|
||||
return Wrapper(CodeBlock(*blck))
|
||||
|
||||
|
||||
def _symbol_of(arg):
|
||||
if isinstance(arg, Declaration):
|
||||
arg = arg.variable.symbol
|
||||
elif isinstance(arg, Variable):
|
||||
arg = arg.symbol
|
||||
return arg
|
||||
|
||||
|
||||
def newtons_method_function(expr, wrt, params=None, func_name="newton", attrs=Tuple(), *, delta=None, **kwargs):
|
||||
""" Generates an AST for a function implementing the Newton-Raphson method.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : expression
|
||||
wrt : Symbol
|
||||
With respect to, i.e. what is the variable
|
||||
params : iterable of symbols
|
||||
Symbols appearing in expr that are taken as constants during the iterations
|
||||
(these will be accepted as parameters to the generated function).
|
||||
func_name : str
|
||||
Name of the generated function.
|
||||
attrs : Tuple
|
||||
Attribute instances passed as ``attrs`` to ``FunctionDefinition``.
|
||||
\\*\\*kwargs :
|
||||
Keyword arguments passed to :func:`sympy.codegen.algorithms.newtons_method`.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import symbols, cos
|
||||
>>> from sympy.codegen.algorithms import newtons_method_function
|
||||
>>> from sympy.codegen.pyutils import render_as_module
|
||||
>>> x = symbols('x')
|
||||
>>> expr = cos(x) - x**3
|
||||
>>> func = newtons_method_function(expr, x)
|
||||
>>> py_mod = render_as_module(func) # source code as string
|
||||
>>> namespace = {}
|
||||
>>> exec(py_mod, namespace, namespace)
|
||||
>>> res = eval('newton(0.5)', namespace)
|
||||
>>> abs(res - 0.865474033102) < 1e-12
|
||||
True
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
sympy.codegen.algorithms.newtons_method
|
||||
|
||||
"""
|
||||
if params is None:
|
||||
params = (wrt,)
|
||||
pointer_subs = {p.symbol: Symbol('(*%s)' % p.symbol.name)
|
||||
for p in params if isinstance(p, Pointer)}
|
||||
if delta is None:
|
||||
delta = Symbol('d_' + wrt.name)
|
||||
if expr.has(delta):
|
||||
delta = None # will use Dummy
|
||||
algo = newtons_method(expr, wrt, delta=delta, **kwargs).xreplace(pointer_subs)
|
||||
if isinstance(algo, Scope):
|
||||
algo = algo.body
|
||||
not_in_params = expr.free_symbols.difference({_symbol_of(p) for p in params})
|
||||
if not_in_params:
|
||||
raise ValueError("Missing symbols in params: %s" % ', '.join(map(str, not_in_params)))
|
||||
declars = tuple(Variable(p, real) for p in params)
|
||||
body = CodeBlock(algo, Return(wrt))
|
||||
return FunctionDefinition(real, func_name, declars, body, attrs=attrs)
|
||||
@@ -0,0 +1,187 @@
|
||||
import math
|
||||
from sympy.sets.sets import Interval
|
||||
from sympy.calculus.singularities import is_increasing, is_decreasing
|
||||
from sympy.codegen.rewriting import Optimization
|
||||
from sympy.core.function import UndefinedFunction
|
||||
|
||||
"""
|
||||
This module collects classes useful for approximate rewriting of expressions.
|
||||
This can be beneficial when generating numeric code for which performance is
|
||||
of greater importance than precision (e.g. for preconditioners used in iterative
|
||||
methods).
|
||||
"""
|
||||
|
||||
class SumApprox(Optimization):
|
||||
"""
|
||||
Approximates sum by neglecting small terms.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
If terms are expressions which can be determined to be monotonic, then
|
||||
bounds for those expressions are added.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
bounds : dict
|
||||
Mapping expressions to length 2 tuple of bounds (low, high).
|
||||
reltol : number
|
||||
Threshold for when to ignore a term. Taken relative to the largest
|
||||
lower bound among bounds.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import exp
|
||||
>>> from sympy.abc import x, y, z
|
||||
>>> from sympy.codegen.rewriting import optimize
|
||||
>>> from sympy.codegen.approximations import SumApprox
|
||||
>>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
|
||||
>>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
|
||||
>>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
|
||||
>>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
|
||||
>>> expr = 3*(x + y + exp(z))
|
||||
>>> optimize(expr, [sum_approx3])
|
||||
3*(x + y + exp(z))
|
||||
>>> optimize(expr, [sum_approx2])
|
||||
3*y + 3*exp(z)
|
||||
>>> optimize(expr, [sum_approx1])
|
||||
3*y
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, bounds, reltol, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.bounds = bounds
|
||||
self.reltol = reltol
|
||||
|
||||
def __call__(self, expr):
|
||||
return expr.factor().replace(self.query, lambda arg: self.value(arg))
|
||||
|
||||
def query(self, expr):
|
||||
return expr.is_Add
|
||||
|
||||
def value(self, add):
|
||||
for term in add.args:
|
||||
if term.is_number or term in self.bounds or len(term.free_symbols) != 1:
|
||||
continue
|
||||
fs, = term.free_symbols
|
||||
if fs not in self.bounds:
|
||||
continue
|
||||
intrvl = Interval(*self.bounds[fs])
|
||||
if is_increasing(term, intrvl, fs):
|
||||
self.bounds[term] = (
|
||||
term.subs({fs: self.bounds[fs][0]}),
|
||||
term.subs({fs: self.bounds[fs][1]})
|
||||
)
|
||||
elif is_decreasing(term, intrvl, fs):
|
||||
self.bounds[term] = (
|
||||
term.subs({fs: self.bounds[fs][1]}),
|
||||
term.subs({fs: self.bounds[fs][0]})
|
||||
)
|
||||
else:
|
||||
return add
|
||||
|
||||
if all(term.is_number or term in self.bounds for term in add.args):
|
||||
bounds = [(term, term) if term.is_number else self.bounds[term] for term in add.args]
|
||||
largest_abs_guarantee = 0
|
||||
for lo, hi in bounds:
|
||||
if lo <= 0 <= hi:
|
||||
continue
|
||||
largest_abs_guarantee = max(largest_abs_guarantee,
|
||||
min(abs(lo), abs(hi)))
|
||||
new_terms = []
|
||||
for term, (lo, hi) in zip(add.args, bounds):
|
||||
if max(abs(lo), abs(hi)) >= largest_abs_guarantee*self.reltol:
|
||||
new_terms.append(term)
|
||||
return add.func(*new_terms)
|
||||
else:
|
||||
return add
|
||||
|
||||
|
||||
class SeriesApprox(Optimization):
|
||||
""" Approximates functions by expanding them as a series.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
bounds : dict
|
||||
Mapping expressions to length 2 tuple of bounds (low, high).
|
||||
reltol : number
|
||||
Threshold for when to ignore a term. Taken relative to the largest
|
||||
lower bound among bounds.
|
||||
max_order : int
|
||||
Largest order to include in series expansion
|
||||
n_point_checks : int (even)
|
||||
The validity of an expansion (with respect to reltol) is checked at
|
||||
discrete points (linearly spaced over the bounds of the variable). The
|
||||
number of points used in this numerical check is given by this number.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import sin, pi
|
||||
>>> from sympy.abc import x, y
|
||||
>>> from sympy.codegen.rewriting import optimize
|
||||
>>> from sympy.codegen.approximations import SeriesApprox
|
||||
>>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
|
||||
>>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
|
||||
>>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
|
||||
>>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
|
||||
>>> expr = sin(x)*sin(y)
|
||||
>>> optimize(expr, [series_approx2])
|
||||
x*(-y + (y - pi)**3/6 + pi)
|
||||
>>> optimize(expr, [series_approx3])
|
||||
(-x**3/6 + x)*sin(y)
|
||||
>>> optimize(expr, [series_approx8])
|
||||
sin(x)*sin(y)
|
||||
|
||||
"""
|
||||
def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.bounds = bounds
|
||||
self.reltol = reltol
|
||||
self.max_order = max_order
|
||||
if n_point_checks % 2 == 1:
|
||||
raise ValueError("Checking the solution at expansion point is not helpful")
|
||||
self.n_point_checks = n_point_checks
|
||||
self._prec = math.ceil(-math.log10(self.reltol))
|
||||
|
||||
def __call__(self, expr):
|
||||
return expr.factor().replace(self.query, lambda arg: self.value(arg))
|
||||
|
||||
def query(self, expr):
|
||||
return (expr.is_Function and not isinstance(expr, UndefinedFunction)
|
||||
and len(expr.args) == 1)
|
||||
|
||||
def value(self, fexpr):
|
||||
free_symbols = fexpr.free_symbols
|
||||
if len(free_symbols) != 1:
|
||||
return fexpr
|
||||
symb, = free_symbols
|
||||
if symb not in self.bounds:
|
||||
return fexpr
|
||||
lo, hi = self.bounds[symb]
|
||||
x0 = (lo + hi)/2
|
||||
cheapest = None
|
||||
for n in range(self.max_order+1, 0, -1):
|
||||
fseri = fexpr.series(symb, x0=x0, n=n).removeO()
|
||||
n_ok = True
|
||||
for idx in range(self.n_point_checks):
|
||||
x = lo + idx*(hi - lo)/(self.n_point_checks - 1)
|
||||
val = fseri.xreplace({symb: x})
|
||||
ref = fexpr.xreplace({symb: x})
|
||||
if abs((1 - val/ref).evalf(self._prec)) > self.reltol:
|
||||
n_ok = False
|
||||
break
|
||||
|
||||
if n_ok:
|
||||
cheapest = fseri
|
||||
else:
|
||||
break
|
||||
|
||||
if cheapest is None:
|
||||
return fexpr
|
||||
else:
|
||||
return cheapest
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,558 @@
|
||||
"""
|
||||
This module contains SymPy functions mathcin corresponding to special math functions in the
|
||||
C standard library (since C99, also available in C++11).
|
||||
|
||||
The functions defined in this module allows the user to express functions such as ``expm1``
|
||||
as a SymPy function for symbolic manipulation.
|
||||
|
||||
"""
|
||||
from sympy.core.function import ArgumentIndexError, Function
|
||||
from sympy.core.numbers import Rational
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.functions.elementary.exponential import exp, log
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.logic.boolalg import BooleanFunction, true, false
|
||||
|
||||
def _expm1(x):
|
||||
return exp(x) - S.One
|
||||
|
||||
|
||||
class expm1(Function):
|
||||
"""
|
||||
Represents the exponential function minus one.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The benefit of using ``expm1(x)`` over ``exp(x) - 1``
|
||||
is that the latter is prone to cancellation under finite precision
|
||||
arithmetic when x is close to zero.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import expm1
|
||||
>>> '%.0e' % expm1(1e-99).evalf()
|
||||
'1e-99'
|
||||
>>> from math import exp
|
||||
>>> exp(1e-99) - 1
|
||||
0.0
|
||||
>>> expm1(x).diff(x)
|
||||
exp(x)
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
log1p
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return exp(*self.args)
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _expm1(*self.args)
|
||||
|
||||
def _eval_rewrite_as_exp(self, arg, **kwargs):
|
||||
return exp(arg) - S.One
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_exp
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
exp_arg = exp.eval(arg)
|
||||
if exp_arg is not None:
|
||||
return exp_arg - S.One
|
||||
|
||||
def _eval_is_real(self):
|
||||
return self.args[0].is_real
|
||||
|
||||
def _eval_is_finite(self):
|
||||
return self.args[0].is_finite
|
||||
|
||||
|
||||
def _log1p(x):
|
||||
return log(x + S.One)
|
||||
|
||||
|
||||
class log1p(Function):
|
||||
"""
|
||||
Represents the natural logarithm of a number plus one.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The benefit of using ``log1p(x)`` over ``log(x + 1)``
|
||||
is that the latter is prone to cancellation under finite precision
|
||||
arithmetic when x is close to zero.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import log1p
|
||||
>>> from sympy import expand_log
|
||||
>>> '%.0e' % expand_log(log1p(1e-99)).evalf()
|
||||
'1e-99'
|
||||
>>> from math import log
|
||||
>>> log(1 + 1e-99)
|
||||
0.0
|
||||
>>> log1p(x).diff(x)
|
||||
1/(x + 1)
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
expm1
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return S.One/(self.args[0] + S.One)
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _log1p(*self.args)
|
||||
|
||||
def _eval_rewrite_as_log(self, arg, **kwargs):
|
||||
return _log1p(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_log
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg.is_Rational:
|
||||
return log(arg + S.One)
|
||||
elif not arg.is_Float: # not safe to add 1 to Float
|
||||
return log.eval(arg + S.One)
|
||||
elif arg.is_number:
|
||||
return log(Rational(arg) + S.One)
|
||||
|
||||
def _eval_is_real(self):
|
||||
return (self.args[0] + S.One).is_nonnegative
|
||||
|
||||
def _eval_is_finite(self):
|
||||
if (self.args[0] + S.One).is_zero:
|
||||
return False
|
||||
return self.args[0].is_finite
|
||||
|
||||
def _eval_is_positive(self):
|
||||
return self.args[0].is_positive
|
||||
|
||||
def _eval_is_zero(self):
|
||||
return self.args[0].is_zero
|
||||
|
||||
def _eval_is_nonnegative(self):
|
||||
return self.args[0].is_nonnegative
|
||||
|
||||
_Two = S(2)
|
||||
|
||||
def _exp2(x):
|
||||
return Pow(_Two, x)
|
||||
|
||||
class exp2(Function):
|
||||
"""
|
||||
Represents the exponential function with base two.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The benefit of using ``exp2(x)`` over ``2**x``
|
||||
is that the latter is not as efficient under finite precision
|
||||
arithmetic.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import exp2
|
||||
>>> exp2(2).evalf() == 4.0
|
||||
True
|
||||
>>> exp2(x).diff(x)
|
||||
log(2)*exp2(x)
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
log2
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return self*log(_Two)
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
def _eval_rewrite_as_Pow(self, arg, **kwargs):
|
||||
return _exp2(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_Pow
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _exp2(*self.args)
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg.is_number:
|
||||
return _exp2(arg)
|
||||
|
||||
|
||||
def _log2(x):
|
||||
return log(x)/log(_Two)
|
||||
|
||||
|
||||
class log2(Function):
|
||||
"""
|
||||
Represents the logarithm function with base two.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The benefit of using ``log2(x)`` over ``log(x)/log(2)``
|
||||
is that the latter is not as efficient under finite precision
|
||||
arithmetic.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import log2
|
||||
>>> log2(4).evalf() == 2.0
|
||||
True
|
||||
>>> log2(x).diff(x)
|
||||
1/(x*log(2))
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
exp2
|
||||
log10
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return S.One/(log(_Two)*self.args[0])
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg.is_number:
|
||||
result = log.eval(arg, base=_Two)
|
||||
if result.is_Atom:
|
||||
return result
|
||||
elif arg.is_Pow and arg.base == _Two:
|
||||
return arg.exp
|
||||
|
||||
def _eval_evalf(self, *args, **kwargs):
|
||||
return self.rewrite(log).evalf(*args, **kwargs)
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _log2(*self.args)
|
||||
|
||||
def _eval_rewrite_as_log(self, arg, **kwargs):
|
||||
return _log2(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_log
|
||||
|
||||
|
||||
def _fma(x, y, z):
|
||||
return x*y + z
|
||||
|
||||
|
||||
class fma(Function):
|
||||
"""
|
||||
Represents "fused multiply add".
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The benefit of using ``fma(x, y, z)`` over ``x*y + z``
|
||||
is that, under finite precision arithmetic, the former is
|
||||
supported by special instructions on some CPUs.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x, y, z
|
||||
>>> from sympy.codegen.cfunctions import fma
|
||||
>>> fma(x, y, z).diff(x)
|
||||
y
|
||||
|
||||
"""
|
||||
nargs = 3
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex in (1, 2):
|
||||
return self.args[2 - argindex]
|
||||
elif argindex == 3:
|
||||
return S.One
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _fma(*self.args)
|
||||
|
||||
def _eval_rewrite_as_tractable(self, arg, limitvar=None, **kwargs):
|
||||
return _fma(arg)
|
||||
|
||||
|
||||
_Ten = S(10)
|
||||
|
||||
|
||||
def _log10(x):
|
||||
return log(x)/log(_Ten)
|
||||
|
||||
|
||||
class log10(Function):
|
||||
"""
|
||||
Represents the logarithm function with base ten.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import log10
|
||||
>>> log10(100).evalf() == 2.0
|
||||
True
|
||||
>>> log10(x).diff(x)
|
||||
1/(x*log(10))
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
log2
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return S.One/(log(_Ten)*self.args[0])
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg.is_number:
|
||||
result = log.eval(arg, base=_Ten)
|
||||
if result.is_Atom:
|
||||
return result
|
||||
elif arg.is_Pow and arg.base == _Ten:
|
||||
return arg.exp
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _log10(*self.args)
|
||||
|
||||
def _eval_rewrite_as_log(self, arg, **kwargs):
|
||||
return _log10(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_log
|
||||
|
||||
|
||||
def _Sqrt(x):
|
||||
return Pow(x, S.Half)
|
||||
|
||||
|
||||
class Sqrt(Function): # 'sqrt' already defined in sympy.functions.elementary.miscellaneous
|
||||
"""
|
||||
Represents the square root function.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The reason why one would use ``Sqrt(x)`` over ``sqrt(x)``
|
||||
is that the latter is internally represented as ``Pow(x, S.Half)`` which
|
||||
may not be what one wants when doing code-generation.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import Sqrt
|
||||
>>> Sqrt(x)
|
||||
Sqrt(x)
|
||||
>>> Sqrt(x).diff(x)
|
||||
1/(2*sqrt(x))
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
Cbrt
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return Pow(self.args[0], Rational(-1, 2))/_Two
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _Sqrt(*self.args)
|
||||
|
||||
def _eval_rewrite_as_Pow(self, arg, **kwargs):
|
||||
return _Sqrt(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_Pow
|
||||
|
||||
|
||||
def _Cbrt(x):
|
||||
return Pow(x, Rational(1, 3))
|
||||
|
||||
|
||||
class Cbrt(Function): # 'cbrt' already defined in sympy.functions.elementary.miscellaneous
|
||||
"""
|
||||
Represents the cube root function.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The reason why one would use ``Cbrt(x)`` over ``cbrt(x)``
|
||||
is that the latter is internally represented as ``Pow(x, Rational(1, 3))`` which
|
||||
may not be what one wants when doing code-generation.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import Cbrt
|
||||
>>> Cbrt(x)
|
||||
Cbrt(x)
|
||||
>>> Cbrt(x).diff(x)
|
||||
1/(3*x**(2/3))
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
Sqrt
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return Pow(self.args[0], Rational(-_Two/3))/3
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _Cbrt(*self.args)
|
||||
|
||||
def _eval_rewrite_as_Pow(self, arg, **kwargs):
|
||||
return _Cbrt(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_Pow
|
||||
|
||||
|
||||
def _hypot(x, y):
|
||||
return sqrt(Pow(x, 2) + Pow(y, 2))
|
||||
|
||||
|
||||
class hypot(Function):
|
||||
"""
|
||||
Represents the hypotenuse function.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The hypotenuse function is provided by e.g. the math library
|
||||
in the C99 standard, hence one may want to represent the function
|
||||
symbolically when doing code-generation.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x, y
|
||||
>>> from sympy.codegen.cfunctions import hypot
|
||||
>>> hypot(3, 4).evalf() == 5.0
|
||||
True
|
||||
>>> hypot(x, y)
|
||||
hypot(x, y)
|
||||
>>> hypot(x, y).diff(x)
|
||||
x/hypot(x, y)
|
||||
|
||||
"""
|
||||
nargs = 2
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex in (1, 2):
|
||||
return 2*self.args[argindex-1]/(_Two*self.func(*self.args))
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _hypot(*self.args)
|
||||
|
||||
def _eval_rewrite_as_Pow(self, arg, **kwargs):
|
||||
return _hypot(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_Pow
|
||||
|
||||
|
||||
class isnan(BooleanFunction):
|
||||
nargs = 1
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg is S.NaN:
|
||||
return true
|
||||
elif arg.is_number:
|
||||
return false
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class isinf(BooleanFunction):
|
||||
nargs = 1
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg.is_infinite:
|
||||
return true
|
||||
elif arg.is_finite:
|
||||
return false
|
||||
else:
|
||||
return None
|
||||
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
AST nodes specific to the C family of languages
|
||||
"""
|
||||
|
||||
from sympy.codegen.ast import (
|
||||
Attribute, Declaration, Node, String, Token, Type, none,
|
||||
FunctionCall, CodeBlock
|
||||
)
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.sympify import sympify
|
||||
|
||||
void = Type('void')
|
||||
|
||||
restrict = Attribute('restrict') # guarantees no pointer aliasing
|
||||
volatile = Attribute('volatile')
|
||||
static = Attribute('static')
|
||||
|
||||
|
||||
def alignof(arg):
|
||||
""" Generate of FunctionCall instance for calling 'alignof' """
|
||||
return FunctionCall('alignof', [String(arg) if isinstance(arg, str) else arg])
|
||||
|
||||
|
||||
def sizeof(arg):
|
||||
""" Generate of FunctionCall instance for calling 'sizeof'
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.ast import real
|
||||
>>> from sympy.codegen.cnodes import sizeof
|
||||
>>> from sympy import ccode
|
||||
>>> ccode(sizeof(real))
|
||||
'sizeof(double)'
|
||||
"""
|
||||
return FunctionCall('sizeof', [String(arg) if isinstance(arg, str) else arg])
|
||||
|
||||
|
||||
class CommaOperator(Basic):
|
||||
""" Represents the comma operator in C """
|
||||
def __new__(cls, *args):
|
||||
return Basic.__new__(cls, *[sympify(arg) for arg in args])
|
||||
|
||||
|
||||
class Label(Node):
|
||||
""" Label for use with e.g. goto statement.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import ccode, Symbol
|
||||
>>> from sympy.codegen.cnodes import Label, PreIncrement
|
||||
>>> print(ccode(Label('foo')))
|
||||
foo:
|
||||
>>> print(ccode(Label('bar', [PreIncrement(Symbol('a'))])))
|
||||
bar:
|
||||
++(a);
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('name', 'body')
|
||||
defaults = {'body': none}
|
||||
_construct_name = String
|
||||
|
||||
@classmethod
|
||||
def _construct_body(cls, itr):
|
||||
if isinstance(itr, CodeBlock):
|
||||
return itr
|
||||
else:
|
||||
return CodeBlock(*itr)
|
||||
|
||||
|
||||
class goto(Token):
|
||||
""" Represents goto in C """
|
||||
__slots__ = _fields = ('label',)
|
||||
_construct_label = Label
|
||||
|
||||
|
||||
class PreDecrement(Basic):
|
||||
""" Represents the pre-decrement operator
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cnodes import PreDecrement
|
||||
>>> from sympy import ccode
|
||||
>>> ccode(PreDecrement(x))
|
||||
'--(x)'
|
||||
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
class PostDecrement(Basic):
|
||||
""" Represents the post-decrement operator
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cnodes import PostDecrement
|
||||
>>> from sympy import ccode
|
||||
>>> ccode(PostDecrement(x))
|
||||
'(x)--'
|
||||
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
class PreIncrement(Basic):
|
||||
""" Represents the pre-increment operator
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cnodes import PreIncrement
|
||||
>>> from sympy import ccode
|
||||
>>> ccode(PreIncrement(x))
|
||||
'++(x)'
|
||||
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
class PostIncrement(Basic):
|
||||
""" Represents the post-increment operator
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cnodes import PostIncrement
|
||||
>>> from sympy import ccode
|
||||
>>> ccode(PostIncrement(x))
|
||||
'(x)++'
|
||||
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
class struct(Node):
|
||||
""" Represents a struct in C """
|
||||
__slots__ = _fields = ('name', 'declarations')
|
||||
defaults = {'name': none}
|
||||
_construct_name = String
|
||||
|
||||
@classmethod
|
||||
def _construct_declarations(cls, args):
|
||||
return Tuple(*[Declaration(arg) for arg in args])
|
||||
|
||||
|
||||
class union(struct):
|
||||
""" Represents a union in C """
|
||||
__slots__ = ()
|
||||
@@ -0,0 +1,8 @@
|
||||
from sympy.printing.c import C99CodePrinter
|
||||
|
||||
def render_as_source_file(content, Printer=C99CodePrinter, settings=None):
|
||||
""" Renders a C source file (with required #include statements) """
|
||||
printer = Printer(settings or {})
|
||||
code_str = printer.doprint(content)
|
||||
includes = '\n'.join(['#include <%s>' % h for h in printer.headers])
|
||||
return includes + '\n\n' + code_str
|
||||
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
AST nodes specific to C++.
|
||||
"""
|
||||
|
||||
from sympy.codegen.ast import Attribute, String, Token, Type, none
|
||||
|
||||
class using(Token):
|
||||
""" Represents a 'using' statement in C++ """
|
||||
__slots__ = _fields = ('type', 'alias')
|
||||
defaults = {'alias': none}
|
||||
_construct_type = Type
|
||||
_construct_alias = String
|
||||
|
||||
constexpr = Attribute('constexpr')
|
||||
@@ -0,0 +1,658 @@
|
||||
"""
|
||||
AST nodes specific to Fortran.
|
||||
|
||||
The functions defined in this module allows the user to express functions such as ``dsign``
|
||||
as a SymPy function for symbolic manipulation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from sympy.codegen.ast import (
|
||||
Attribute, CodeBlock, FunctionCall, Node, none, String,
|
||||
Token, _mk_Tuple, Variable
|
||||
)
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.expr import Expr
|
||||
from sympy.core.function import Function
|
||||
from sympy.core.numbers import Float, Integer
|
||||
from sympy.core.symbol import Str
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.logic import true, false
|
||||
from sympy.utilities.iterables import iterable
|
||||
|
||||
|
||||
|
||||
pure = Attribute('pure')
|
||||
elemental = Attribute('elemental') # (all elemental procedures are also pure)
|
||||
|
||||
intent_in = Attribute('intent_in')
|
||||
intent_out = Attribute('intent_out')
|
||||
intent_inout = Attribute('intent_inout')
|
||||
|
||||
allocatable = Attribute('allocatable')
|
||||
|
||||
class Program(Token):
|
||||
""" Represents a 'program' block in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.ast import Print
|
||||
>>> from sympy.codegen.fnodes import Program
|
||||
>>> prog = Program('myprogram', [Print([42])])
|
||||
>>> from sympy import fcode
|
||||
>>> print(fcode(prog, source_format='free'))
|
||||
program myprogram
|
||||
print *, 42
|
||||
end program
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('name', 'body')
|
||||
_construct_name = String
|
||||
_construct_body = staticmethod(lambda body: CodeBlock(*body))
|
||||
|
||||
|
||||
class use_rename(Token):
|
||||
""" Represents a renaming in a use statement in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import use_rename, use
|
||||
>>> from sympy import fcode
|
||||
>>> ren = use_rename("thingy", "convolution2d")
|
||||
>>> print(fcode(ren, source_format='free'))
|
||||
thingy => convolution2d
|
||||
>>> full = use('signallib', only=['snr', ren])
|
||||
>>> print(fcode(full, source_format='free'))
|
||||
use signallib, only: snr, thingy => convolution2d
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('local', 'original')
|
||||
_construct_local = String
|
||||
_construct_original = String
|
||||
|
||||
def _name(arg):
|
||||
if hasattr(arg, 'name'):
|
||||
return arg.name
|
||||
else:
|
||||
return String(arg)
|
||||
|
||||
class use(Token):
|
||||
""" Represents a use statement in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import use
|
||||
>>> from sympy import fcode
|
||||
>>> fcode(use('signallib'), source_format='free')
|
||||
'use signallib'
|
||||
>>> fcode(use('signallib', [('metric', 'snr')]), source_format='free')
|
||||
'use signallib, metric => snr'
|
||||
>>> fcode(use('signallib', only=['snr', 'convolution2d']), source_format='free')
|
||||
'use signallib, only: snr, convolution2d'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('namespace', 'rename', 'only')
|
||||
defaults = {'rename': none, 'only': none}
|
||||
_construct_namespace = staticmethod(_name)
|
||||
_construct_rename = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else use_rename(*arg) for arg in args]))
|
||||
_construct_only = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else _name(arg) for arg in args]))
|
||||
|
||||
|
||||
class Module(Token):
|
||||
""" Represents a module in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import Module
|
||||
>>> from sympy import fcode
|
||||
>>> print(fcode(Module('signallib', ['implicit none'], []), source_format='free'))
|
||||
module signallib
|
||||
implicit none
|
||||
<BLANKLINE>
|
||||
contains
|
||||
<BLANKLINE>
|
||||
<BLANKLINE>
|
||||
end module
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('name', 'declarations', 'definitions')
|
||||
defaults = {'declarations': Tuple()}
|
||||
_construct_name = String
|
||||
|
||||
@classmethod
|
||||
def _construct_declarations(cls, args):
|
||||
args = [Str(arg) if isinstance(arg, str) else arg for arg in args]
|
||||
return CodeBlock(*args)
|
||||
|
||||
_construct_definitions = staticmethod(lambda arg: CodeBlock(*arg))
|
||||
|
||||
|
||||
class Subroutine(Node):
|
||||
""" Represents a subroutine in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode, symbols
|
||||
>>> from sympy.codegen.ast import Print
|
||||
>>> from sympy.codegen.fnodes import Subroutine
|
||||
>>> x, y = symbols('x y', real=True)
|
||||
>>> sub = Subroutine('mysub', [x, y], [Print([x**2 + y**2, x*y])])
|
||||
>>> print(fcode(sub, source_format='free', standard=2003))
|
||||
subroutine mysub(x, y)
|
||||
real*8 :: x
|
||||
real*8 :: y
|
||||
print *, x**2 + y**2, x*y
|
||||
end subroutine
|
||||
|
||||
"""
|
||||
__slots__ = ('name', 'parameters', 'body')
|
||||
_fields = __slots__ + Node._fields
|
||||
_construct_name = String
|
||||
_construct_parameters = staticmethod(lambda params: Tuple(*map(Variable.deduced, params)))
|
||||
|
||||
@classmethod
|
||||
def _construct_body(cls, itr):
|
||||
if isinstance(itr, CodeBlock):
|
||||
return itr
|
||||
else:
|
||||
return CodeBlock(*itr)
|
||||
|
||||
class SubroutineCall(Token):
|
||||
""" Represents a call to a subroutine in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import SubroutineCall
|
||||
>>> from sympy import fcode
|
||||
>>> fcode(SubroutineCall('mysub', 'x y'.split()))
|
||||
' call mysub(x, y)'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('name', 'subroutine_args')
|
||||
_construct_name = staticmethod(_name)
|
||||
_construct_subroutine_args = staticmethod(_mk_Tuple)
|
||||
|
||||
|
||||
class Do(Token):
|
||||
""" Represents a Do loop in in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode, symbols
|
||||
>>> from sympy.codegen.ast import aug_assign, Print
|
||||
>>> from sympy.codegen.fnodes import Do
|
||||
>>> i, n = symbols('i n', integer=True)
|
||||
>>> r = symbols('r', real=True)
|
||||
>>> body = [aug_assign(r, '+', 1/i), Print([i, r])]
|
||||
>>> do1 = Do(body, i, 1, n)
|
||||
>>> print(fcode(do1, source_format='free'))
|
||||
do i = 1, n
|
||||
r = r + 1d0/i
|
||||
print *, i, r
|
||||
end do
|
||||
>>> do2 = Do(body, i, 1, n, 2)
|
||||
>>> print(fcode(do2, source_format='free'))
|
||||
do i = 1, n, 2
|
||||
r = r + 1d0/i
|
||||
print *, i, r
|
||||
end do
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = _fields = ('body', 'counter', 'first', 'last', 'step', 'concurrent')
|
||||
defaults = {'step': Integer(1), 'concurrent': false}
|
||||
_construct_body = staticmethod(lambda body: CodeBlock(*body))
|
||||
_construct_counter = staticmethod(sympify)
|
||||
_construct_first = staticmethod(sympify)
|
||||
_construct_last = staticmethod(sympify)
|
||||
_construct_step = staticmethod(sympify)
|
||||
_construct_concurrent = staticmethod(lambda arg: true if arg else false)
|
||||
|
||||
|
||||
class ArrayConstructor(Token):
|
||||
""" Represents an array constructor.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.fnodes import ArrayConstructor
|
||||
>>> ac = ArrayConstructor([1, 2, 3])
|
||||
>>> fcode(ac, standard=95, source_format='free')
|
||||
'(/1, 2, 3/)'
|
||||
>>> fcode(ac, standard=2003, source_format='free')
|
||||
'[1, 2, 3]'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('elements',)
|
||||
_construct_elements = staticmethod(_mk_Tuple)
|
||||
|
||||
|
||||
class ImpliedDoLoop(Token):
|
||||
""" Represents an implied do loop in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Symbol, fcode
|
||||
>>> from sympy.codegen.fnodes import ImpliedDoLoop, ArrayConstructor
|
||||
>>> i = Symbol('i', integer=True)
|
||||
>>> idl = ImpliedDoLoop(i**3, i, -3, 3, 2) # -27, -1, 1, 27
|
||||
>>> ac = ArrayConstructor([-28, idl, 28]) # -28, -27, -1, 1, 27, 28
|
||||
>>> fcode(ac, standard=2003, source_format='free')
|
||||
'[-28, (i**3, i = -3, 3, 2), 28]'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('expr', 'counter', 'first', 'last', 'step')
|
||||
defaults = {'step': Integer(1)}
|
||||
_construct_expr = staticmethod(sympify)
|
||||
_construct_counter = staticmethod(sympify)
|
||||
_construct_first = staticmethod(sympify)
|
||||
_construct_last = staticmethod(sympify)
|
||||
_construct_step = staticmethod(sympify)
|
||||
|
||||
|
||||
class Extent(Basic):
|
||||
""" Represents a dimension extent.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import Extent
|
||||
>>> e = Extent(-3, 3) # -3, -2, -1, 0, 1, 2, 3
|
||||
>>> from sympy import fcode
|
||||
>>> fcode(e, source_format='free')
|
||||
'-3:3'
|
||||
>>> from sympy.codegen.ast import Variable, real
|
||||
>>> from sympy.codegen.fnodes import dimension, intent_out
|
||||
>>> dim = dimension(e, e)
|
||||
>>> arr = Variable('x', real, attrs=[dim, intent_out])
|
||||
>>> fcode(arr.as_Declaration(), source_format='free', standard=2003)
|
||||
'real*8, dimension(-3:3, -3:3), intent(out) :: x'
|
||||
|
||||
"""
|
||||
def __new__(cls, *args):
|
||||
if len(args) == 2:
|
||||
low, high = args
|
||||
return Basic.__new__(cls, sympify(low), sympify(high))
|
||||
elif len(args) == 0 or (len(args) == 1 and args[0] in (':', None)):
|
||||
return Basic.__new__(cls) # assumed shape
|
||||
else:
|
||||
raise ValueError("Expected 0 or 2 args (or one argument == None or ':')")
|
||||
|
||||
def _sympystr(self, printer):
|
||||
if len(self.args) == 0:
|
||||
return ':'
|
||||
return ":".join(str(arg) for arg in self.args)
|
||||
|
||||
assumed_extent = Extent() # or Extent(':'), Extent(None)
|
||||
|
||||
|
||||
def dimension(*args):
|
||||
""" Creates a 'dimension' Attribute with (up to 7) extents.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.fnodes import dimension, intent_in
|
||||
>>> dim = dimension('2', ':') # 2 rows, runtime determined number of columns
|
||||
>>> from sympy.codegen.ast import Variable, integer
|
||||
>>> arr = Variable('a', integer, attrs=[dim, intent_in])
|
||||
>>> fcode(arr.as_Declaration(), source_format='free', standard=2003)
|
||||
'integer*4, dimension(2, :), intent(in) :: a'
|
||||
|
||||
"""
|
||||
if len(args) > 7:
|
||||
raise ValueError("Fortran only supports up to 7 dimensional arrays")
|
||||
parameters = []
|
||||
for arg in args:
|
||||
if isinstance(arg, Extent):
|
||||
parameters.append(arg)
|
||||
elif isinstance(arg, str):
|
||||
if arg == ':':
|
||||
parameters.append(Extent())
|
||||
else:
|
||||
parameters.append(String(arg))
|
||||
elif iterable(arg):
|
||||
parameters.append(Extent(*arg))
|
||||
else:
|
||||
parameters.append(sympify(arg))
|
||||
if len(args) == 0:
|
||||
raise ValueError("Need at least one dimension")
|
||||
return Attribute('dimension', parameters)
|
||||
|
||||
|
||||
assumed_size = dimension('*')
|
||||
|
||||
def array(symbol, dim, intent=None, *, attrs=(), value=None, type=None):
|
||||
""" Convenience function for creating a Variable instance for a Fortran array.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
symbol : symbol
|
||||
dim : Attribute or iterable
|
||||
If dim is an ``Attribute`` it need to have the name 'dimension'. If it is
|
||||
not an ``Attribute``, then it is passed to :func:`dimension` as ``*dim``
|
||||
intent : str
|
||||
One of: 'in', 'out', 'inout' or None
|
||||
\\*\\*kwargs:
|
||||
Keyword arguments for ``Variable`` ('type' & 'value')
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.ast import integer, real
|
||||
>>> from sympy.codegen.fnodes import array
|
||||
>>> arr = array('a', '*', 'in', type=integer)
|
||||
>>> print(fcode(arr.as_Declaration(), source_format='free', standard=2003))
|
||||
integer*4, dimension(*), intent(in) :: a
|
||||
>>> x = array('x', [3, ':', ':'], intent='out', type=real)
|
||||
>>> print(fcode(x.as_Declaration(value=1), source_format='free', standard=2003))
|
||||
real*8, dimension(3, :, :), intent(out) :: x = 1
|
||||
|
||||
"""
|
||||
if isinstance(dim, Attribute):
|
||||
if str(dim.name) != 'dimension':
|
||||
raise ValueError("Got an unexpected Attribute argument as dim: %s" % str(dim))
|
||||
else:
|
||||
dim = dimension(*dim)
|
||||
|
||||
attrs = list(attrs) + [dim]
|
||||
if intent is not None:
|
||||
if intent not in (intent_in, intent_out, intent_inout):
|
||||
intent = {'in': intent_in, 'out': intent_out, 'inout': intent_inout}[intent]
|
||||
attrs.append(intent)
|
||||
if type is None:
|
||||
return Variable.deduced(symbol, value=value, attrs=attrs)
|
||||
else:
|
||||
return Variable(symbol, type, value=value, attrs=attrs)
|
||||
|
||||
def _printable(arg):
|
||||
return String(arg) if isinstance(arg, str) else sympify(arg)
|
||||
|
||||
|
||||
def allocated(array):
|
||||
""" Creates an AST node for a function call to Fortran's "allocated(...)"
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.fnodes import allocated
|
||||
>>> alloc = allocated('x')
|
||||
>>> fcode(alloc, source_format='free')
|
||||
'allocated(x)'
|
||||
|
||||
"""
|
||||
return FunctionCall('allocated', [_printable(array)])
|
||||
|
||||
|
||||
def lbound(array, dim=None, kind=None):
|
||||
""" Creates an AST node for a function call to Fortran's "lbound(...)"
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
array : Symbol or String
|
||||
dim : expr
|
||||
kind : expr
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.fnodes import lbound
|
||||
>>> lb = lbound('arr', dim=2)
|
||||
>>> fcode(lb, source_format='free')
|
||||
'lbound(arr, 2)'
|
||||
|
||||
"""
|
||||
return FunctionCall(
|
||||
'lbound',
|
||||
[_printable(array)] +
|
||||
([_printable(dim)] if dim else []) +
|
||||
([_printable(kind)] if kind else [])
|
||||
)
|
||||
|
||||
|
||||
def ubound(array, dim=None, kind=None):
|
||||
return FunctionCall(
|
||||
'ubound',
|
||||
[_printable(array)] +
|
||||
([_printable(dim)] if dim else []) +
|
||||
([_printable(kind)] if kind else [])
|
||||
)
|
||||
|
||||
|
||||
def shape(source, kind=None):
|
||||
""" Creates an AST node for a function call to Fortran's "shape(...)"
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
source : Symbol or String
|
||||
kind : expr
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.fnodes import shape
|
||||
>>> shp = shape('x')
|
||||
>>> fcode(shp, source_format='free')
|
||||
'shape(x)'
|
||||
|
||||
"""
|
||||
return FunctionCall(
|
||||
'shape',
|
||||
[_printable(source)] +
|
||||
([_printable(kind)] if kind else [])
|
||||
)
|
||||
|
||||
|
||||
def size(array, dim=None, kind=None):
|
||||
""" Creates an AST node for a function call to Fortran's "size(...)"
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode, Symbol
|
||||
>>> from sympy.codegen.ast import FunctionDefinition, real, Return
|
||||
>>> from sympy.codegen.fnodes import array, sum_, size
|
||||
>>> a = Symbol('a', real=True)
|
||||
>>> body = [Return((sum_(a**2)/size(a))**.5)]
|
||||
>>> arr = array(a, dim=[':'], intent='in')
|
||||
>>> fd = FunctionDefinition(real, 'rms', [arr], body)
|
||||
>>> print(fcode(fd, source_format='free', standard=2003))
|
||||
real*8 function rms(a)
|
||||
real*8, dimension(:), intent(in) :: a
|
||||
rms = sqrt(sum(a**2)*1d0/size(a))
|
||||
end function
|
||||
|
||||
"""
|
||||
return FunctionCall(
|
||||
'size',
|
||||
[_printable(array)] +
|
||||
([_printable(dim)] if dim else []) +
|
||||
([_printable(kind)] if kind else [])
|
||||
)
|
||||
|
||||
|
||||
def reshape(source, shape, pad=None, order=None):
|
||||
""" Creates an AST node for a function call to Fortran's "reshape(...)"
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
source : Symbol or String
|
||||
shape : ArrayExpr
|
||||
|
||||
"""
|
||||
return FunctionCall(
|
||||
'reshape',
|
||||
[_printable(source), _printable(shape)] +
|
||||
([_printable(pad)] if pad else []) +
|
||||
([_printable(order)] if pad else [])
|
||||
)
|
||||
|
||||
|
||||
def bind_C(name=None):
|
||||
""" Creates an Attribute ``bind_C`` with a name.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
name : str
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode, Symbol
|
||||
>>> from sympy.codegen.ast import FunctionDefinition, real, Return
|
||||
>>> from sympy.codegen.fnodes import array, sum_, bind_C
|
||||
>>> a = Symbol('a', real=True)
|
||||
>>> s = Symbol('s', integer=True)
|
||||
>>> arr = array(a, dim=[s], intent='in')
|
||||
>>> body = [Return((sum_(a**2)/s)**.5)]
|
||||
>>> fd = FunctionDefinition(real, 'rms', [arr, s], body, attrs=[bind_C('rms')])
|
||||
>>> print(fcode(fd, source_format='free', standard=2003))
|
||||
real*8 function rms(a, s) bind(C, name="rms")
|
||||
real*8, dimension(s), intent(in) :: a
|
||||
integer*4 :: s
|
||||
rms = sqrt(sum(a**2)/s)
|
||||
end function
|
||||
|
||||
"""
|
||||
return Attribute('bind_C', [String(name)] if name else [])
|
||||
|
||||
class GoTo(Token):
|
||||
""" Represents a goto statement in Fortran
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import GoTo
|
||||
>>> go = GoTo([10, 20, 30], 'i')
|
||||
>>> from sympy import fcode
|
||||
>>> fcode(go, source_format='free')
|
||||
'go to (10, 20, 30), i'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('labels', 'expr')
|
||||
defaults = {'expr': none}
|
||||
_construct_labels = staticmethod(_mk_Tuple)
|
||||
_construct_expr = staticmethod(sympify)
|
||||
|
||||
|
||||
class FortranReturn(Token):
|
||||
""" AST node explicitly mapped to a fortran "return".
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Because a return statement in fortran is different from C, and
|
||||
in order to aid reuse of our codegen ASTs the ordinary
|
||||
``.codegen.ast.Return`` is interpreted as assignment to
|
||||
the result variable of the function. If one for some reason needs
|
||||
to generate a fortran RETURN statement, this node should be used.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import FortranReturn
|
||||
>>> from sympy import fcode
|
||||
>>> fcode(FortranReturn('x'))
|
||||
' return x'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('return_value',)
|
||||
defaults = {'return_value': none}
|
||||
_construct_return_value = staticmethod(sympify)
|
||||
|
||||
|
||||
class FFunction(Function):
|
||||
_required_standard = 77
|
||||
|
||||
def _fcode(self, printer):
|
||||
name = self.__class__.__name__
|
||||
if printer._settings['standard'] < self._required_standard:
|
||||
raise NotImplementedError("%s requires Fortran %d or newer" %
|
||||
(name, self._required_standard))
|
||||
return '{}({})'.format(name, ', '.join(map(printer._print, self.args)))
|
||||
|
||||
|
||||
class F95Function(FFunction):
|
||||
_required_standard = 95
|
||||
|
||||
|
||||
class isign(FFunction):
|
||||
""" Fortran sign intrinsic for integer arguments. """
|
||||
nargs = 2
|
||||
|
||||
|
||||
class dsign(FFunction):
|
||||
""" Fortran sign intrinsic for double precision arguments. """
|
||||
nargs = 2
|
||||
|
||||
|
||||
class cmplx(FFunction):
|
||||
""" Fortran complex conversion function. """
|
||||
nargs = 2 # may be extended to (2, 3) at a later point
|
||||
|
||||
|
||||
class kind(FFunction):
|
||||
""" Fortran kind function. """
|
||||
nargs = 1
|
||||
|
||||
|
||||
class merge(F95Function):
|
||||
""" Fortran merge function """
|
||||
nargs = 3
|
||||
|
||||
|
||||
class _literal(Float):
|
||||
_token: str
|
||||
_decimals: int
|
||||
|
||||
def _fcode(self, printer, *args, **kwargs):
|
||||
mantissa, sgnd_ex = ('%.{}e'.format(self._decimals) % self).split('e')
|
||||
mantissa = mantissa.strip('0').rstrip('.')
|
||||
ex_sgn, ex_num = sgnd_ex[0], sgnd_ex[1:].lstrip('0')
|
||||
ex_sgn = '' if ex_sgn == '+' else ex_sgn
|
||||
return (mantissa or '0') + self._token + ex_sgn + (ex_num or '0')
|
||||
|
||||
|
||||
class literal_sp(_literal):
|
||||
""" Fortran single precision real literal """
|
||||
_token = 'e'
|
||||
_decimals = 9
|
||||
|
||||
|
||||
class literal_dp(_literal):
|
||||
""" Fortran double precision real literal """
|
||||
_token = 'd'
|
||||
_decimals = 17
|
||||
|
||||
|
||||
class sum_(Token, Expr):
|
||||
__slots__ = _fields = ('array', 'dim', 'mask')
|
||||
defaults = {'dim': none, 'mask': none}
|
||||
_construct_array = staticmethod(sympify)
|
||||
_construct_dim = staticmethod(sympify)
|
||||
|
||||
|
||||
class product_(Token, Expr):
|
||||
__slots__ = _fields = ('array', 'dim', 'mask')
|
||||
defaults = {'dim': none, 'mask': none}
|
||||
_construct_array = staticmethod(sympify)
|
||||
_construct_dim = staticmethod(sympify)
|
||||
@@ -0,0 +1,40 @@
|
||||
from itertools import chain
|
||||
from sympy.codegen.fnodes import Module
|
||||
from sympy.core.symbol import Dummy
|
||||
from sympy.printing.fortran import FCodePrinter
|
||||
|
||||
""" This module collects utilities for rendering Fortran code. """
|
||||
|
||||
|
||||
def render_as_module(definitions, name, declarations=(), printer_settings=None):
|
||||
""" Creates a ``Module`` instance and renders it as a string.
|
||||
|
||||
This generates Fortran source code for a module with the correct ``use`` statements.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
definitions : iterable
|
||||
Passed to :class:`sympy.codegen.fnodes.Module`.
|
||||
name : str
|
||||
Passed to :class:`sympy.codegen.fnodes.Module`.
|
||||
declarations : iterable
|
||||
Passed to :class:`sympy.codegen.fnodes.Module`. It will be extended with
|
||||
use statements, 'implicit none' and public list generated from ``definitions``.
|
||||
printer_settings : dict
|
||||
Passed to ``FCodePrinter`` (default: ``{'standard': 2003, 'source_format': 'free'}``).
|
||||
|
||||
"""
|
||||
printer_settings = printer_settings or {'standard': 2003, 'source_format': 'free'}
|
||||
printer = FCodePrinter(printer_settings)
|
||||
dummy = Dummy()
|
||||
if isinstance(definitions, Module):
|
||||
raise ValueError("This function expects to construct a module on its own.")
|
||||
mod = Module(name, chain(declarations, [dummy]), definitions)
|
||||
fstr = printer.doprint(mod)
|
||||
module_use_str = ' %s\n' % ' \n'.join(['use %s, only: %s' % (k, ', '.join(v)) for
|
||||
k, v in printer.module_uses.items()])
|
||||
module_use_str += ' implicit none\n'
|
||||
module_use_str += ' private\n'
|
||||
module_use_str += ' public %s\n' % ', '.join([str(node.name) for node in definitions if getattr(node, 'name', None)])
|
||||
return fstr.replace(printer.doprint(dummy), module_use_str)
|
||||
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Additional AST nodes for operations on matrices. The nodes in this module
|
||||
are meant to represent optimization of matrix expressions within codegen's
|
||||
target languages that cannot be represented by SymPy expressions.
|
||||
|
||||
As an example, we can use :meth:`sympy.codegen.rewriting.optimize` and the
|
||||
``matin_opt`` optimization provided in :mod:`sympy.codegen.rewriting` to
|
||||
transform matrix multiplication under certain assumptions:
|
||||
|
||||
>>> from sympy import symbols, MatrixSymbol
|
||||
>>> n = symbols('n', integer=True)
|
||||
>>> A = MatrixSymbol('A', n, n)
|
||||
>>> x = MatrixSymbol('x', n, 1)
|
||||
>>> expr = A**(-1) * x
|
||||
>>> from sympy import assuming, Q
|
||||
>>> from sympy.codegen.rewriting import matinv_opt, optimize
|
||||
>>> with assuming(Q.fullrank(A)):
|
||||
... optimize(expr, [matinv_opt])
|
||||
MatrixSolve(A, vector=x)
|
||||
"""
|
||||
|
||||
from .ast import Token
|
||||
from sympy.matrices import MatrixExpr
|
||||
from sympy.core.sympify import sympify
|
||||
|
||||
|
||||
class MatrixSolve(Token, MatrixExpr):
|
||||
"""Represents an operation to solve a linear matrix equation.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
matrix : MatrixSymbol
|
||||
|
||||
Matrix representing the coefficients of variables in the linear
|
||||
equation. This matrix must be square and full-rank (i.e. all columns must
|
||||
be linearly independent) for the solving operation to be valid.
|
||||
|
||||
vector : MatrixSymbol
|
||||
|
||||
One-column matrix representing the solutions to the equations
|
||||
represented in ``matrix``.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import symbols, MatrixSymbol
|
||||
>>> from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
>>> n = symbols('n', integer=True)
|
||||
>>> A = MatrixSymbol('A', n, n)
|
||||
>>> x = MatrixSymbol('x', n, 1)
|
||||
>>> from sympy.printing.numpy import NumPyPrinter
|
||||
>>> NumPyPrinter().doprint(MatrixSolve(A, x))
|
||||
'numpy.linalg.solve(A, x)'
|
||||
>>> from sympy import octave_code
|
||||
>>> octave_code(MatrixSolve(A, x))
|
||||
'A \\\\ x'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('matrix', 'vector')
|
||||
|
||||
_construct_matrix = staticmethod(sympify)
|
||||
_construct_vector = staticmethod(sympify)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.vector.shape
|
||||
|
||||
def _eval_derivative(self, x):
|
||||
A, b = self.matrix, self.vector
|
||||
return MatrixSolve(A, b.diff(x) - A.diff(x) * MatrixSolve(A, b))
|
||||
@@ -0,0 +1,177 @@
|
||||
from sympy.core.function import Add, ArgumentIndexError, Function
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.sorting import default_sort_key
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.functions.elementary.exponential import exp, log
|
||||
from sympy.functions.elementary.miscellaneous import Max, Min
|
||||
from .ast import Token, none
|
||||
|
||||
|
||||
def _logaddexp(x1, x2, *, evaluate=True):
|
||||
return log(Add(exp(x1, evaluate=evaluate), exp(x2, evaluate=evaluate), evaluate=evaluate))
|
||||
|
||||
|
||||
_two = S.One*2
|
||||
_ln2 = log(_two)
|
||||
|
||||
|
||||
def _lb(x, *, evaluate=True):
|
||||
return log(x, evaluate=evaluate)/_ln2
|
||||
|
||||
|
||||
def _exp2(x, *, evaluate=True):
|
||||
return Pow(_two, x, evaluate=evaluate)
|
||||
|
||||
|
||||
def _logaddexp2(x1, x2, *, evaluate=True):
|
||||
return _lb(Add(_exp2(x1, evaluate=evaluate),
|
||||
_exp2(x2, evaluate=evaluate), evaluate=evaluate))
|
||||
|
||||
|
||||
class logaddexp(Function):
|
||||
""" Logarithm of the sum of exponentiations of the inputs.
|
||||
|
||||
Helper class for use with e.g. numpy.logaddexp
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html
|
||||
"""
|
||||
nargs = 2
|
||||
|
||||
def __new__(cls, *args):
|
||||
return Function.__new__(cls, *sorted(args, key=default_sort_key))
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
wrt, other = self.args
|
||||
elif argindex == 2:
|
||||
other, wrt = self.args
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
return S.One/(S.One + exp(other-wrt))
|
||||
|
||||
def _eval_rewrite_as_log(self, x1, x2, **kwargs):
|
||||
return _logaddexp(x1, x2)
|
||||
|
||||
def _eval_evalf(self, *args, **kwargs):
|
||||
return self.rewrite(log).evalf(*args, **kwargs)
|
||||
|
||||
def _eval_simplify(self, *args, **kwargs):
|
||||
a, b = (x.simplify(**kwargs) for x in self.args)
|
||||
candidate = _logaddexp(a, b)
|
||||
if candidate != _logaddexp(a, b, evaluate=False):
|
||||
return candidate
|
||||
else:
|
||||
return logaddexp(a, b)
|
||||
|
||||
|
||||
class logaddexp2(Function):
|
||||
""" Logarithm of the sum of exponentiations of the inputs in base-2.
|
||||
|
||||
Helper class for use with e.g. numpy.logaddexp2
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.logaddexp2.html
|
||||
"""
|
||||
nargs = 2
|
||||
|
||||
def __new__(cls, *args):
|
||||
return Function.__new__(cls, *sorted(args, key=default_sort_key))
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
wrt, other = self.args
|
||||
elif argindex == 2:
|
||||
other, wrt = self.args
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
return S.One/(S.One + _exp2(other-wrt))
|
||||
|
||||
def _eval_rewrite_as_log(self, x1, x2, **kwargs):
|
||||
return _logaddexp2(x1, x2)
|
||||
|
||||
def _eval_evalf(self, *args, **kwargs):
|
||||
return self.rewrite(log).evalf(*args, **kwargs)
|
||||
|
||||
def _eval_simplify(self, *args, **kwargs):
|
||||
a, b = (x.simplify(**kwargs).factor() for x in self.args)
|
||||
candidate = _logaddexp2(a, b)
|
||||
if candidate != _logaddexp2(a, b, evaluate=False):
|
||||
return candidate
|
||||
else:
|
||||
return logaddexp2(a, b)
|
||||
|
||||
|
||||
class amin(Token):
|
||||
""" Minimum value along an axis.
|
||||
|
||||
Helper class for use with e.g. numpy.amin
|
||||
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.amin.html
|
||||
"""
|
||||
__slots__ = _fields = ('array', 'axis')
|
||||
defaults = {'axis': none}
|
||||
_construct_axis = staticmethod(sympify)
|
||||
|
||||
|
||||
class amax(Token):
|
||||
""" Maximum value along an axis.
|
||||
|
||||
Helper class for use with e.g. numpy.amax
|
||||
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.amax.html
|
||||
"""
|
||||
__slots__ = _fields = ('array', 'axis')
|
||||
defaults = {'axis': none}
|
||||
_construct_axis = staticmethod(sympify)
|
||||
|
||||
|
||||
class maximum(Function):
|
||||
""" Element-wise maximum of array elements.
|
||||
|
||||
Helper class for use with e.g. numpy.maximum
|
||||
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.maximum.html
|
||||
"""
|
||||
|
||||
def _eval_rewrite_as_Max(self, *args):
|
||||
return Max(*self.args)
|
||||
|
||||
|
||||
class minimum(Function):
|
||||
""" Element-wise minimum of array elements.
|
||||
|
||||
Helper class for use with e.g. numpy.minimum
|
||||
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.minimum.html
|
||||
"""
|
||||
|
||||
def _eval_rewrite_as_Min(self, *args):
|
||||
return Min(*self.args)
|
||||
@@ -0,0 +1,11 @@
|
||||
from .abstract_nodes import List as AbstractList
|
||||
from .ast import Token
|
||||
|
||||
|
||||
class List(AbstractList):
|
||||
pass
|
||||
|
||||
|
||||
class NumExprEvaluate(Token):
|
||||
"""represents a call to :class:`numexpr`s :func:`evaluate`"""
|
||||
__slots__ = _fields = ('expr',)
|
||||
@@ -0,0 +1,24 @@
|
||||
from sympy.printing.pycode import PythonCodePrinter
|
||||
|
||||
""" This module collects utilities for rendering Python code. """
|
||||
|
||||
|
||||
def render_as_module(content, standard='python3'):
|
||||
"""Renders Python code as a module (with the required imports).
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
standard :
|
||||
See the parameter ``standard`` in
|
||||
:meth:`sympy.printing.pycode.pycode`
|
||||
"""
|
||||
|
||||
printer = PythonCodePrinter({'standard':standard})
|
||||
pystr = printer.doprint(content)
|
||||
if printer._settings['fully_qualified_modules']:
|
||||
module_imports_str = '\n'.join('import %s' % k for k in printer.module_imports)
|
||||
else:
|
||||
module_imports_str = '\n'.join(['from %s import %s' % (k, ', '.join(v)) for
|
||||
k, v in printer.module_imports.items()])
|
||||
return module_imports_str + '\n\n' + pystr
|
||||
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
Classes and functions useful for rewriting expressions for optimized code
|
||||
generation. Some languages (or standards thereof), e.g. C99, offer specialized
|
||||
math functions for better performance and/or precision.
|
||||
|
||||
Using the ``optimize`` function in this module, together with a collection of
|
||||
rules (represented as instances of ``Optimization``), one can rewrite the
|
||||
expressions for this purpose::
|
||||
|
||||
>>> from sympy import Symbol, exp, log
|
||||
>>> from sympy.codegen.rewriting import optimize, optims_c99
|
||||
>>> x = Symbol('x')
|
||||
>>> optimize(3*exp(2*x) - 3, optims_c99)
|
||||
3*expm1(2*x)
|
||||
>>> optimize(exp(2*x) - 1 - exp(-33), optims_c99)
|
||||
expm1(2*x) - exp(-33)
|
||||
>>> optimize(log(3*x + 3), optims_c99)
|
||||
log1p(x) + log(3)
|
||||
>>> optimize(log(2*x + 3), optims_c99)
|
||||
log(2*x + 3)
|
||||
|
||||
The ``optims_c99`` imported above is tuple containing the following instances
|
||||
(which may be imported from ``sympy.codegen.rewriting``):
|
||||
|
||||
- ``expm1_opt``
|
||||
- ``log1p_opt``
|
||||
- ``exp2_opt``
|
||||
- ``log2_opt``
|
||||
- ``log2const_opt``
|
||||
|
||||
|
||||
"""
|
||||
from sympy.core.function import expand_log
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import Wild
|
||||
from sympy.functions.elementary.complexes import sign
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import (Max, Min)
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin, sinc)
|
||||
from sympy.assumptions import Q, ask
|
||||
from sympy.codegen.cfunctions import log1p, log2, exp2, expm1
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.core.expr import UnevaluatedExpr
|
||||
from sympy.core.power import Pow
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
|
||||
from sympy.codegen.scipy_nodes import cosm1, powm1
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.utilities.iterables import sift
|
||||
|
||||
|
||||
class Optimization:
|
||||
""" Abstract base class for rewriting optimization.
|
||||
|
||||
Subclasses should implement ``__call__`` taking an expression
|
||||
as argument.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
cost_function : callable returning number
|
||||
priority : number
|
||||
|
||||
"""
|
||||
def __init__(self, cost_function=None, priority=1):
|
||||
self.cost_function = cost_function
|
||||
self.priority=priority
|
||||
|
||||
def cheapest(self, *args):
|
||||
return min(args, key=self.cost_function)
|
||||
|
||||
|
||||
class ReplaceOptim(Optimization):
|
||||
""" Rewriting optimization calling replace on expressions.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The instance can be used as a function on expressions for which
|
||||
it will apply the ``replace`` method (see
|
||||
:meth:`sympy.core.basic.Basic.replace`).
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
query :
|
||||
First argument passed to replace.
|
||||
value :
|
||||
Second argument passed to replace.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Symbol
|
||||
>>> from sympy.codegen.rewriting import ReplaceOptim
|
||||
>>> from sympy.codegen.cfunctions import exp2
|
||||
>>> x = Symbol('x')
|
||||
>>> exp2_opt = ReplaceOptim(lambda p: p.is_Pow and p.base == 2,
|
||||
... lambda p: exp2(p.exp))
|
||||
>>> exp2_opt(2**x)
|
||||
exp2(x)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, query, value, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.query = query
|
||||
self.value = value
|
||||
|
||||
def __call__(self, expr):
|
||||
return expr.replace(self.query, self.value)
|
||||
|
||||
|
||||
def optimize(expr, optimizations):
|
||||
""" Apply optimizations to an expression.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : expression
|
||||
optimizations : iterable of ``Optimization`` instances
|
||||
The optimizations will be sorted with respect to ``priority`` (highest first).
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import log, Symbol
|
||||
>>> from sympy.codegen.rewriting import optims_c99, optimize
|
||||
>>> x = Symbol('x')
|
||||
>>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99)
|
||||
log1p(x**2) + log2(x + 3)
|
||||
|
||||
"""
|
||||
|
||||
for optim in sorted(optimizations, key=lambda opt: opt.priority, reverse=True):
|
||||
new_expr = optim(expr)
|
||||
if optim.cost_function is None:
|
||||
expr = new_expr
|
||||
else:
|
||||
expr = optim.cheapest(expr, new_expr)
|
||||
return expr
|
||||
|
||||
|
||||
exp2_opt = ReplaceOptim(
|
||||
lambda p: p.is_Pow and p.base == 2,
|
||||
lambda p: exp2(p.exp)
|
||||
)
|
||||
|
||||
|
||||
_d = Wild('d', properties=[lambda x: x.is_Dummy])
|
||||
_u = Wild('u', properties=[lambda x: not x.is_number and not x.is_Add])
|
||||
_v = Wild('v')
|
||||
_w = Wild('w')
|
||||
_n = Wild('n', properties=[lambda x: x.is_number])
|
||||
|
||||
sinc_opt1 = ReplaceOptim(
|
||||
sin(_w)/_w, sinc(_w)
|
||||
)
|
||||
sinc_opt2 = ReplaceOptim(
|
||||
sin(_n*_w)/_w, _n*sinc(_n*_w)
|
||||
)
|
||||
sinc_opts = (sinc_opt1, sinc_opt2)
|
||||
|
||||
log2_opt = ReplaceOptim(_v*log(_w)/log(2), _v*log2(_w), cost_function=lambda expr: expr.count(
|
||||
lambda e: ( # division & eval of transcendentals are expensive floating point operations...
|
||||
e.is_Pow and e.exp.is_negative # division
|
||||
or (isinstance(e, (log, log2)) and not e.args[0].is_number)) # transcendental
|
||||
)
|
||||
)
|
||||
|
||||
log2const_opt = ReplaceOptim(log(2)*log2(_w), log(_w))
|
||||
|
||||
logsumexp_2terms_opt = ReplaceOptim(
|
||||
lambda l: (isinstance(l, log)
|
||||
and l.args[0].is_Add
|
||||
and len(l.args[0].args) == 2
|
||||
and all(isinstance(t, exp) for t in l.args[0].args)),
|
||||
lambda l: (
|
||||
Max(*[e.args[0] for e in l.args[0].args]) +
|
||||
log1p(exp(Min(*[e.args[0] for e in l.args[0].args])))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class FuncMinusOneOptim(ReplaceOptim):
|
||||
"""Specialization of ReplaceOptim for functions evaluating "f(x) - 1".
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Numerical functions which go toward one as x go toward zero is often best
|
||||
implemented by a dedicated function in order to avoid catastrophic
|
||||
cancellation. One such example is ``expm1(x)`` in the C standard library
|
||||
which evaluates ``exp(x) - 1``. Such functions preserves many more
|
||||
significant digits when its argument is much smaller than one, compared
|
||||
to subtracting one afterwards.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
func :
|
||||
The function which is subtracted by one.
|
||||
func_m_1 :
|
||||
The specialized function evaluating ``func(x) - 1``.
|
||||
opportunistic : bool
|
||||
When ``True``, apply the transformation as long as the magnitude of the
|
||||
remaining number terms decreases. When ``False``, only apply the
|
||||
transformation if it completely eliminates the number term.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import symbols, exp
|
||||
>>> from sympy.codegen.rewriting import FuncMinusOneOptim
|
||||
>>> from sympy.codegen.cfunctions import expm1
|
||||
>>> x, y = symbols('x y')
|
||||
>>> expm1_opt = FuncMinusOneOptim(exp, expm1)
|
||||
>>> expm1_opt(exp(x) + 2*exp(5*y) - 3)
|
||||
expm1(x) + 2*expm1(5*y)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, func, func_m_1, opportunistic=True):
|
||||
weight = 10 # <-- this is an arbitrary number (heuristic)
|
||||
super().__init__(lambda e: e.is_Add, self.replace_in_Add,
|
||||
cost_function=lambda expr: expr.count_ops() - weight*expr.count(func_m_1))
|
||||
self.func = func
|
||||
self.func_m_1 = func_m_1
|
||||
self.opportunistic = opportunistic
|
||||
|
||||
def _group_Add_terms(self, add):
|
||||
numbers, non_num = sift(add.args, lambda arg: arg.is_number, binary=True)
|
||||
numsum = sum(numbers)
|
||||
terms_with_func, other = sift(non_num, lambda arg: arg.has(self.func), binary=True)
|
||||
return numsum, terms_with_func, other
|
||||
|
||||
def replace_in_Add(self, e):
|
||||
""" passed as second argument to Basic.replace(...) """
|
||||
numsum, terms_with_func, other_non_num_terms = self._group_Add_terms(e)
|
||||
if numsum == 0:
|
||||
return e
|
||||
substituted, untouched = [], []
|
||||
for with_func in terms_with_func:
|
||||
if with_func.is_Mul:
|
||||
func, coeff = sift(with_func.args, lambda arg: arg.func == self.func, binary=True)
|
||||
if len(func) == 1 and len(coeff) == 1:
|
||||
func, coeff = func[0], coeff[0]
|
||||
else:
|
||||
coeff = None
|
||||
elif with_func.func == self.func:
|
||||
func, coeff = with_func, S.One
|
||||
else:
|
||||
coeff = None
|
||||
|
||||
if coeff is not None and coeff.is_number and sign(coeff) == -sign(numsum):
|
||||
if self.opportunistic:
|
||||
do_substitute = abs(coeff+numsum) < abs(numsum)
|
||||
else:
|
||||
do_substitute = coeff+numsum == 0
|
||||
|
||||
if do_substitute: # advantageous substitution
|
||||
numsum += coeff
|
||||
substituted.append(coeff*self.func_m_1(*func.args))
|
||||
continue
|
||||
untouched.append(with_func)
|
||||
|
||||
return e.func(numsum, *substituted, *untouched, *other_non_num_terms)
|
||||
|
||||
def __call__(self, expr):
|
||||
alt1 = super().__call__(expr)
|
||||
alt2 = super().__call__(expr.factor())
|
||||
return self.cheapest(alt1, alt2)
|
||||
|
||||
|
||||
expm1_opt = FuncMinusOneOptim(exp, expm1)
|
||||
cosm1_opt = FuncMinusOneOptim(cos, cosm1)
|
||||
powm1_opt = FuncMinusOneOptim(Pow, powm1)
|
||||
|
||||
log1p_opt = ReplaceOptim(
|
||||
lambda e: isinstance(e, log),
|
||||
lambda l: expand_log(l.replace(
|
||||
log, lambda arg: log(arg.factor())
|
||||
)).replace(log(_u+1), log1p(_u))
|
||||
)
|
||||
|
||||
def create_expand_pow_optimization(limit, *, base_req=lambda b: b.is_symbol):
|
||||
""" Creates an instance of :class:`ReplaceOptim` for expanding ``Pow``.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The requirements for expansions are that the base needs to be a symbol
|
||||
and the exponent needs to be an Integer (and be less than or equal to
|
||||
``limit``).
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
limit : int
|
||||
The highest power which is expanded into multiplication.
|
||||
base_req : function returning bool
|
||||
Requirement on base for expansion to happen, default is to return
|
||||
the ``is_symbol`` attribute of the base.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Symbol, sin
|
||||
>>> from sympy.codegen.rewriting import create_expand_pow_optimization
|
||||
>>> x = Symbol('x')
|
||||
>>> expand_opt = create_expand_pow_optimization(3)
|
||||
>>> expand_opt(x**5 + x**3)
|
||||
x**5 + x*x*x
|
||||
>>> expand_opt(x**5 + x**3 + sin(x)**3)
|
||||
x**5 + sin(x)**3 + x*x*x
|
||||
>>> opt2 = create_expand_pow_optimization(3, base_req=lambda b: not b.is_Function)
|
||||
>>> opt2((x+1)**2 + sin(x)**2)
|
||||
sin(x)**2 + (x + 1)*(x + 1)
|
||||
|
||||
"""
|
||||
return ReplaceOptim(
|
||||
lambda e: e.is_Pow and base_req(e.base) and e.exp.is_Integer and abs(e.exp) <= limit,
|
||||
lambda p: (
|
||||
UnevaluatedExpr(Mul(*([p.base]*+p.exp), evaluate=False)) if p.exp > 0 else
|
||||
1/UnevaluatedExpr(Mul(*([p.base]*-p.exp), evaluate=False))
|
||||
))
|
||||
|
||||
# Optimization procedures for turning A**(-1) * x into MatrixSolve(A, x)
|
||||
def _matinv_predicate(expr):
|
||||
# TODO: We should be able to support more than 2 elements
|
||||
if expr.is_MatMul and len(expr.args) == 2:
|
||||
left, right = expr.args
|
||||
if left.is_Inverse and right.shape[1] == 1:
|
||||
inv_arg = left.arg
|
||||
if isinstance(inv_arg, MatrixSymbol):
|
||||
return bool(ask(Q.fullrank(left.arg)))
|
||||
|
||||
return False
|
||||
|
||||
def _matinv_transform(expr):
|
||||
left, right = expr.args
|
||||
inv_arg = left.arg
|
||||
return MatrixSolve(inv_arg, right)
|
||||
|
||||
|
||||
matinv_opt = ReplaceOptim(_matinv_predicate, _matinv_transform)
|
||||
|
||||
|
||||
logaddexp_opt = ReplaceOptim(log(exp(_v)+exp(_w)), logaddexp(_v, _w))
|
||||
logaddexp2_opt = ReplaceOptim(log(Pow(2, _v)+Pow(2, _w)), logaddexp2(_v, _w)*log(2))
|
||||
|
||||
# Collections of optimizations:
|
||||
optims_c99 = (expm1_opt, log1p_opt, exp2_opt, log2_opt, log2const_opt)
|
||||
|
||||
optims_numpy = optims_c99 + (logaddexp_opt, logaddexp2_opt,) + sinc_opts
|
||||
|
||||
optims_scipy = (cosm1_opt, powm1_opt)
|
||||
@@ -0,0 +1,79 @@
|
||||
from sympy.core.function import Add, ArgumentIndexError, Function
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.functions.elementary.exponential import log
|
||||
from sympy.functions.elementary.trigonometric import cos, sin
|
||||
|
||||
|
||||
def _cosm1(x, *, evaluate=True):
|
||||
return Add(cos(x, evaluate=evaluate), -S.One, evaluate=evaluate)
|
||||
|
||||
|
||||
class cosm1(Function):
|
||||
""" Minus one plus cosine of x, i.e. cos(x) - 1. For use when x is close to zero.
|
||||
|
||||
Helper class for use with e.g. scipy.special.cosm1
|
||||
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.cosm1.html
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return -sin(*self.args)
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
def _eval_rewrite_as_cos(self, x, **kwargs):
|
||||
return _cosm1(x)
|
||||
|
||||
def _eval_evalf(self, *args, **kwargs):
|
||||
return self.rewrite(cos).evalf(*args, **kwargs)
|
||||
|
||||
def _eval_simplify(self, **kwargs):
|
||||
x, = self.args
|
||||
candidate = _cosm1(x.simplify(**kwargs))
|
||||
if candidate != _cosm1(x, evaluate=False):
|
||||
return candidate
|
||||
else:
|
||||
return cosm1(x)
|
||||
|
||||
|
||||
def _powm1(x, y, *, evaluate=True):
|
||||
return Add(Pow(x, y, evaluate=evaluate), -S.One, evaluate=evaluate)
|
||||
|
||||
|
||||
class powm1(Function):
|
||||
""" Minus one plus x to the power of y, i.e. x**y - 1. For use when x is close to one or y is close to zero.
|
||||
|
||||
Helper class for use with e.g. scipy.special.powm1
|
||||
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.powm1.html
|
||||
"""
|
||||
nargs = 2
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return Pow(self.args[0], self.args[1])*self.args[1]/self.args[0]
|
||||
elif argindex == 2:
|
||||
return log(self.args[0])*Pow(*self.args)
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
def _eval_rewrite_as_Pow(self, x, y, **kwargs):
|
||||
return _powm1(x, y)
|
||||
|
||||
def _eval_evalf(self, *args, **kwargs):
|
||||
return self.rewrite(Pow).evalf(*args, **kwargs)
|
||||
|
||||
def _eval_simplify(self, **kwargs):
|
||||
x, y = self.args
|
||||
candidate = _powm1(x.simplify(**kwargs), y.simplify(**kwargs))
|
||||
if candidate != _powm1(x, y, evaluate=False):
|
||||
return candidate
|
||||
else:
|
||||
return powm1(x, y)
|
||||
@@ -0,0 +1,14 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.codegen.abstract_nodes import List
|
||||
|
||||
|
||||
def test_List():
|
||||
l = List(2, 3, 4)
|
||||
assert l == List(2, 3, 4)
|
||||
assert str(l) == "[2, 3, 4]"
|
||||
x, y, z = symbols('x y z')
|
||||
l = List(x**2,y**3,z**4)
|
||||
# contrary to python's built-in list, we can call e.g. "replace" on List.
|
||||
m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp)
|
||||
assert m == [x**2, y-3, z-4]
|
||||
hash(m)
|
||||
@@ -0,0 +1,180 @@
|
||||
import tempfile
|
||||
from sympy import log, Min, Max, sqrt
|
||||
from sympy.core.numbers import Float
|
||||
from sympy.core.symbol import Symbol, symbols
|
||||
from sympy.functions.elementary.trigonometric import cos
|
||||
from sympy.codegen.ast import Assignment, Raise, RuntimeError_, QuotedString
|
||||
from sympy.codegen.algorithms import newtons_method, newtons_method_function
|
||||
from sympy.codegen.cfunctions import expm1
|
||||
from sympy.codegen.fnodes import bind_C
|
||||
from sympy.codegen.futils import render_as_module as f_module
|
||||
from sympy.codegen.pyutils import render_as_module as py_module
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c, has_fortran
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, raises, skip_under_pyodide
|
||||
|
||||
cython = import_module('cython')
|
||||
wurlitzer = import_module('wurlitzer')
|
||||
|
||||
def test_newtons_method():
|
||||
x, dx, atol = symbols('x dx atol')
|
||||
expr = cos(x) - x**3
|
||||
algo = newtons_method(expr, x, atol, dx)
|
||||
assert algo.has(Assignment(dx, -expr/expr.diff(x)))
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_newtons_method_function__ccode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x)
|
||||
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
|
||||
compile_kw = {"std": 'c99'}
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton.c', ('#include <math.h>\n'
|
||||
'#include <stdio.h>\n') + ccode(func)),
|
||||
('_newton.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double)\n"
|
||||
"def py_newton(x):\n"
|
||||
" return newton(x)\n"))
|
||||
], build_dir=folder, compile_kwargs=compile_kw)
|
||||
assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_newtons_method_function__fcode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x, attrs=[bind_C(name='newton')])
|
||||
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_fortran():
|
||||
skip("No Fortran compiler found.")
|
||||
|
||||
f_mod = f_module([func], 'mod_newton')
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton.f90', f_mod),
|
||||
('_newton.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double*)\n"
|
||||
"def py_newton(double x):\n"
|
||||
" return newton(&x)\n"))
|
||||
], build_dir=folder)
|
||||
assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
def test_newtons_method_function__pycode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x)
|
||||
py_mod = py_module(func)
|
||||
namespace = {}
|
||||
exec(py_mod, namespace, namespace)
|
||||
res = eval('newton(0.5)', namespace)
|
||||
assert abs(res - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
@may_xfail
|
||||
@skip_under_pyodide("Emscripten does not support process spawning")
|
||||
def test_newtons_method_function__ccode_parameters():
|
||||
args = x, A, k, p = symbols('x A k p')
|
||||
expr = A*cos(k*x) - p*x**3
|
||||
raises(ValueError, lambda: newtons_method_function(expr, x))
|
||||
use_wurlitzer = wurlitzer
|
||||
|
||||
func = newtons_method_function(expr, x, args, debug=use_wurlitzer)
|
||||
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
|
||||
compile_kw = {"std": 'c99'}
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton_par.c', ('#include <math.h>\n'
|
||||
'#include <stdio.h>\n') + ccode(func)),
|
||||
('_newton_par.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double, double, double, double)\n"
|
||||
"def py_newton(x, A=1, k=1, p=1):\n"
|
||||
" return newton(x, A, k, p)\n"))
|
||||
], compile_kwargs=compile_kw, build_dir=folder)
|
||||
|
||||
if use_wurlitzer:
|
||||
with wurlitzer.pipes() as (out, err):
|
||||
result = mod.py_newton(0.5)
|
||||
else:
|
||||
result = mod.py_newton(0.5)
|
||||
|
||||
assert abs(result - 0.865474033102) < 1e-12
|
||||
|
||||
if not use_wurlitzer:
|
||||
skip("C-level output only tested when package 'wurlitzer' is available.")
|
||||
|
||||
out, err = out.read(), err.read()
|
||||
assert err == ''
|
||||
assert out == """\
|
||||
x= 0.5
|
||||
x= 1.1121 d_x= 0.61214
|
||||
x= 0.90967 d_x= -0.20247
|
||||
x= 0.86726 d_x= -0.042409
|
||||
x= 0.86548 d_x= -0.0017867
|
||||
x= 0.86547 d_x= -3.1022e-06
|
||||
x= 0.86547 d_x= -9.3421e-12
|
||||
x= 0.86547 d_x= 3.6902e-17
|
||||
""" # try to run tests with LC_ALL=C if this assertion fails
|
||||
|
||||
|
||||
def test_newtons_method_function__rtol_cse_nan():
|
||||
a, b, c, N_geo, N_tot = symbols('a b c N_geo N_tot', real=True, nonnegative=True)
|
||||
i = Symbol('i', integer=True, nonnegative=True)
|
||||
N_ari = N_tot - N_geo - 1
|
||||
delta_ari = (c-b)/N_ari
|
||||
ln_delta_geo = log(b) + log(-expm1((log(a)-log(b))/N_geo))
|
||||
eqb_log = ln_delta_geo - log(delta_ari)
|
||||
|
||||
def _clamp(low, expr, high):
|
||||
return Min(Max(low, expr), high)
|
||||
|
||||
meth_kw = {
|
||||
'clamped_newton': {'delta_fn': lambda e, x: _clamp(
|
||||
(sqrt(a*x)-x)*0.99,
|
||||
-e/e.diff(x),
|
||||
(sqrt(c*x)-x)*0.99
|
||||
)},
|
||||
'halley': {'delta_fn': lambda e, x: (-2*(e*e.diff(x))/(2*e.diff(x)**2 - e*e.diff(x, 2)))},
|
||||
'halley_alt': {'delta_fn': lambda e, x: (-e/e.diff(x)/(1-e/e.diff(x)*e.diff(x,2)/2/e.diff(x)))},
|
||||
}
|
||||
args = eqb_log, b
|
||||
for use_cse in [False, True]:
|
||||
kwargs = {
|
||||
'params': (b, a, c, N_geo, N_tot), 'itermax': 60, 'debug': True, 'cse': use_cse,
|
||||
'counter': i, 'atol': 1e-100, 'rtol': 2e-16, 'bounds': (a,c),
|
||||
'handle_nan': Raise(RuntimeError_(QuotedString("encountered NaN.")))
|
||||
}
|
||||
func = {k: newtons_method_function(*args, func_name=f"{k}_b", **dict(kwargs, **kw)) for k, kw in meth_kw.items()}
|
||||
py_mod = {k: py_module(v) for k, v in func.items()}
|
||||
namespace = {}
|
||||
root_find_b = {}
|
||||
for k, v in py_mod.items():
|
||||
ns = namespace[k] = {}
|
||||
exec(v, ns, ns)
|
||||
root_find_b[k] = ns[f'{k}_b']
|
||||
ref = Float('13.2261515064168768938151923226496')
|
||||
reftol = {'clamped_newton': 2e-16, 'halley': 2e-16, 'halley_alt': 3e-16}
|
||||
guess = 4.0
|
||||
for meth, func in root_find_b.items():
|
||||
result = func(guess, 1e-2, 1e2, 50, 100)
|
||||
req = ref*reftol[meth]
|
||||
if use_cse:
|
||||
req *= 2
|
||||
assert abs(result - ref) < req
|
||||
@@ -0,0 +1,58 @@
|
||||
# This file contains tests that exercise multiple AST nodes
|
||||
|
||||
import tempfile
|
||||
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, skip_under_pyodide
|
||||
from sympy.codegen.ast import (
|
||||
FunctionDefinition, FunctionPrototype, Variable, Pointer, real, Assignment,
|
||||
integer, CodeBlock, While
|
||||
)
|
||||
from sympy.codegen.cnodes import void, PreIncrement
|
||||
from sympy.codegen.cutils import render_as_source_file
|
||||
|
||||
cython = import_module('cython')
|
||||
np = import_module('numpy')
|
||||
|
||||
def _mk_func1():
|
||||
declars = n, inp, out = Variable('n', integer), Pointer('inp', real), Pointer('out', real)
|
||||
i = Variable('i', integer)
|
||||
whl = While(i<n, [Assignment(out[i], inp[i]), PreIncrement(i)])
|
||||
body = CodeBlock(i.as_Declaration(value=0), whl)
|
||||
return FunctionDefinition(void, 'our_test_function', declars, body)
|
||||
|
||||
|
||||
def _render_compile_import(funcdef, build_dir):
|
||||
code_str = render_as_source_file(funcdef, settings={"contract": False})
|
||||
declar = ccode(FunctionPrototype.from_FunctionDefinition(funcdef))
|
||||
return compile_link_import_strings([
|
||||
('our_test_func.c', code_str),
|
||||
('_our_test_func.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern {declar}\n"
|
||||
"def _{fname}({typ}[:] inp, {typ}[:] out):\n"
|
||||
" {fname}(inp.size, &inp[0], &out[0])").format(
|
||||
declar=declar, fname=funcdef.name, typ='double'
|
||||
))
|
||||
], build_dir=build_dir)
|
||||
|
||||
|
||||
@may_xfail
|
||||
@skip_under_pyodide("Emscripten does not support process spawning")
|
||||
def test_copying_function():
|
||||
if not np:
|
||||
skip("numpy not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
if not cython:
|
||||
skip("Cython not found.")
|
||||
|
||||
info = None
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = _render_compile_import(_mk_func1(), build_dir=folder)
|
||||
inp = np.arange(10.0)
|
||||
out = np.empty_like(inp)
|
||||
mod._our_test_function(inp, out)
|
||||
assert np.allclose(inp, out)
|
||||
@@ -0,0 +1,53 @@
|
||||
import math
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.codegen.rewriting import optimize
|
||||
from sympy.codegen.approximations import SumApprox, SeriesApprox
|
||||
|
||||
|
||||
def test_SumApprox_trivial():
|
||||
x = symbols('x')
|
||||
expr1 = 1 + x
|
||||
sum_approx = SumApprox(bounds={x: (-1e-20, 1e-20)}, reltol=1e-16)
|
||||
apx1 = optimize(expr1, [sum_approx])
|
||||
assert apx1 - 1 == 0
|
||||
|
||||
|
||||
def test_SumApprox_monotone_terms():
|
||||
x, y, z = symbols('x y z')
|
||||
expr1 = exp(z)*(x**2 + y**2 + 1)
|
||||
bnds1 = {x: (0, 1e-3), y: (100, 1000)}
|
||||
sum_approx_m2 = SumApprox(bounds=bnds1, reltol=1e-2)
|
||||
sum_approx_m5 = SumApprox(bounds=bnds1, reltol=1e-5)
|
||||
sum_approx_m11 = SumApprox(bounds=bnds1, reltol=1e-11)
|
||||
assert (optimize(expr1, [sum_approx_m2])/exp(z) - (y**2)).simplify() == 0
|
||||
assert (optimize(expr1, [sum_approx_m5])/exp(z) - (y**2 + 1)).simplify() == 0
|
||||
assert (optimize(expr1, [sum_approx_m11])/exp(z) - (y**2 + 1 + x**2)).simplify() == 0
|
||||
|
||||
|
||||
def test_SeriesApprox_trivial():
|
||||
x, z = symbols('x z')
|
||||
for factor in [1, exp(z)]:
|
||||
x = symbols('x')
|
||||
expr1 = exp(x)*factor
|
||||
bnds1 = {x: (-1, 1)}
|
||||
series_approx_50 = SeriesApprox(bounds=bnds1, reltol=0.50)
|
||||
series_approx_10 = SeriesApprox(bounds=bnds1, reltol=0.10)
|
||||
series_approx_05 = SeriesApprox(bounds=bnds1, reltol=0.05)
|
||||
c = (bnds1[x][1] + bnds1[x][0])/2 # 0.0
|
||||
f0 = math.exp(c) # 1.0
|
||||
|
||||
ref_50 = f0 + x + x**2/2
|
||||
ref_10 = f0 + x + x**2/2 + x**3/6
|
||||
ref_05 = f0 + x + x**2/2 + x**3/6 + x**4/24
|
||||
|
||||
res_50 = optimize(expr1, [series_approx_50])
|
||||
res_10 = optimize(expr1, [series_approx_10])
|
||||
res_05 = optimize(expr1, [series_approx_05])
|
||||
|
||||
assert (res_50/factor - ref_50).simplify() == 0
|
||||
assert (res_10/factor - ref_10).simplify() == 0
|
||||
assert (res_05/factor - ref_05).simplify() == 0
|
||||
|
||||
max_ord3 = SeriesApprox(bounds=bnds1, reltol=0.05, max_order=3)
|
||||
assert optimize(expr1, [max_ord3]) == expr1
|
||||
@@ -0,0 +1,661 @@
|
||||
import math
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.numbers import nan, oo, Float, Integer
|
||||
from sympy.core.relational import Lt
|
||||
from sympy.core.symbol import symbols, Symbol
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.sets.fancysets import Range
|
||||
from sympy.tensor.indexed import Idx, IndexedBase
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
|
||||
from sympy.codegen.ast import (
|
||||
Assignment, Attribute, aug_assign, CodeBlock, For, Type, Variable, Pointer, Declaration,
|
||||
AddAugmentedAssignment, SubAugmentedAssignment, MulAugmentedAssignment,
|
||||
DivAugmentedAssignment, ModAugmentedAssignment, value_const, pointer_const,
|
||||
integer, real, complex_, int8, uint8, float16 as f16, float32 as f32,
|
||||
float64 as f64, float80 as f80, float128 as f128, complex64 as c64, complex128 as c128,
|
||||
While, Scope, String, Print, QuotedString, FunctionPrototype, FunctionDefinition, Return,
|
||||
FunctionCall, untyped, IntBaseType, intc, Node, none, NoneToken, Token, Comment
|
||||
)
|
||||
|
||||
x, y, z, t, x0, x1, x2, a, b = symbols("x, y, z, t, x0, x1, x2, a, b")
|
||||
n = symbols("n", integer=True)
|
||||
A = MatrixSymbol('A', 3, 1)
|
||||
mat = Matrix([1, 2, 3])
|
||||
B = IndexedBase('B')
|
||||
i = Idx("i", n)
|
||||
A22 = MatrixSymbol('A22',2,2)
|
||||
B22 = MatrixSymbol('B22',2,2)
|
||||
|
||||
|
||||
def test_Assignment():
|
||||
# Here we just do things to show they don't error
|
||||
Assignment(x, y)
|
||||
Assignment(x, 0)
|
||||
Assignment(A, mat)
|
||||
Assignment(A[1,0], 0)
|
||||
Assignment(A[1,0], x)
|
||||
Assignment(B[i], x)
|
||||
Assignment(B[i], 0)
|
||||
a = Assignment(x, y)
|
||||
assert a.func(*a.args) == a
|
||||
assert a.op == ':='
|
||||
# Here we test things to show that they error
|
||||
# Matrix to scalar
|
||||
raises(ValueError, lambda: Assignment(B[i], A))
|
||||
raises(ValueError, lambda: Assignment(B[i], mat))
|
||||
raises(ValueError, lambda: Assignment(x, mat))
|
||||
raises(ValueError, lambda: Assignment(x, A))
|
||||
raises(ValueError, lambda: Assignment(A[1,0], mat))
|
||||
# Scalar to matrix
|
||||
raises(ValueError, lambda: Assignment(A, x))
|
||||
raises(ValueError, lambda: Assignment(A, 0))
|
||||
# Non-atomic lhs
|
||||
raises(TypeError, lambda: Assignment(mat, A))
|
||||
raises(TypeError, lambda: Assignment(0, x))
|
||||
raises(TypeError, lambda: Assignment(x*x, 1))
|
||||
raises(TypeError, lambda: Assignment(A + A, mat))
|
||||
raises(TypeError, lambda: Assignment(B, 0))
|
||||
|
||||
|
||||
def test_AugAssign():
|
||||
# Here we just do things to show they don't error
|
||||
aug_assign(x, '+', y)
|
||||
aug_assign(x, '+', 0)
|
||||
aug_assign(A, '+', mat)
|
||||
aug_assign(A[1, 0], '+', 0)
|
||||
aug_assign(A[1, 0], '+', x)
|
||||
aug_assign(B[i], '+', x)
|
||||
aug_assign(B[i], '+', 0)
|
||||
|
||||
# Check creation via aug_assign vs constructor
|
||||
for binop, cls in [
|
||||
('+', AddAugmentedAssignment),
|
||||
('-', SubAugmentedAssignment),
|
||||
('*', MulAugmentedAssignment),
|
||||
('/', DivAugmentedAssignment),
|
||||
('%', ModAugmentedAssignment),
|
||||
]:
|
||||
a = aug_assign(x, binop, y)
|
||||
b = cls(x, y)
|
||||
assert a.func(*a.args) == a == b
|
||||
assert a.binop == binop
|
||||
assert a.op == binop + '='
|
||||
|
||||
# Here we test things to show that they error
|
||||
# Matrix to scalar
|
||||
raises(ValueError, lambda: aug_assign(B[i], '+', A))
|
||||
raises(ValueError, lambda: aug_assign(B[i], '+', mat))
|
||||
raises(ValueError, lambda: aug_assign(x, '+', mat))
|
||||
raises(ValueError, lambda: aug_assign(x, '+', A))
|
||||
raises(ValueError, lambda: aug_assign(A[1, 0], '+', mat))
|
||||
# Scalar to matrix
|
||||
raises(ValueError, lambda: aug_assign(A, '+', x))
|
||||
raises(ValueError, lambda: aug_assign(A, '+', 0))
|
||||
# Non-atomic lhs
|
||||
raises(TypeError, lambda: aug_assign(mat, '+', A))
|
||||
raises(TypeError, lambda: aug_assign(0, '+', x))
|
||||
raises(TypeError, lambda: aug_assign(x * x, '+', 1))
|
||||
raises(TypeError, lambda: aug_assign(A + A, '+', mat))
|
||||
raises(TypeError, lambda: aug_assign(B, '+', 0))
|
||||
|
||||
|
||||
def test_Assignment_printing():
|
||||
assignment_classes = [
|
||||
Assignment,
|
||||
AddAugmentedAssignment,
|
||||
SubAugmentedAssignment,
|
||||
MulAugmentedAssignment,
|
||||
DivAugmentedAssignment,
|
||||
ModAugmentedAssignment,
|
||||
]
|
||||
pairs = [
|
||||
(x, 2 * y + 2),
|
||||
(B[i], x),
|
||||
(A22, B22),
|
||||
(A[0, 0], x),
|
||||
]
|
||||
|
||||
for cls in assignment_classes:
|
||||
for lhs, rhs in pairs:
|
||||
a = cls(lhs, rhs)
|
||||
assert repr(a) == '%s(%s, %s)' % (cls.__name__, repr(lhs), repr(rhs))
|
||||
|
||||
|
||||
def test_CodeBlock():
|
||||
c = CodeBlock(Assignment(x, 1), Assignment(y, x + 1))
|
||||
assert c.func(*c.args) == c
|
||||
|
||||
assert c.left_hand_sides == Tuple(x, y)
|
||||
assert c.right_hand_sides == Tuple(1, x + 1)
|
||||
|
||||
def test_CodeBlock_topological_sort():
|
||||
assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(t, x),
|
||||
Assignment(y, 2),
|
||||
]
|
||||
|
||||
ordered_assignments = [
|
||||
# Note that the unrelated z=1 and y=2 are kept in that order
|
||||
Assignment(z, 1),
|
||||
Assignment(y, 2),
|
||||
Assignment(x, y + z),
|
||||
Assignment(t, x),
|
||||
]
|
||||
c1 = CodeBlock.topological_sort(assignments)
|
||||
assert c1 == CodeBlock(*ordered_assignments)
|
||||
|
||||
# Cycle
|
||||
invalid_assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(y, x),
|
||||
Assignment(y, 2),
|
||||
]
|
||||
|
||||
raises(ValueError, lambda: CodeBlock.topological_sort(invalid_assignments))
|
||||
|
||||
# Free symbols
|
||||
free_assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, a * b),
|
||||
Assignment(t, x),
|
||||
Assignment(y, b + 3),
|
||||
]
|
||||
|
||||
free_assignments_ordered = [
|
||||
Assignment(z, a * b),
|
||||
Assignment(y, b + 3),
|
||||
Assignment(x, y + z),
|
||||
Assignment(t, x),
|
||||
]
|
||||
|
||||
c2 = CodeBlock.topological_sort(free_assignments)
|
||||
assert c2 == CodeBlock(*free_assignments_ordered)
|
||||
|
||||
def test_CodeBlock_free_symbols():
|
||||
c1 = CodeBlock(
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(t, x),
|
||||
Assignment(y, 2),
|
||||
)
|
||||
assert c1.free_symbols == set()
|
||||
|
||||
c2 = CodeBlock(
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, a * b),
|
||||
Assignment(t, x),
|
||||
Assignment(y, b + 3),
|
||||
)
|
||||
assert c2.free_symbols == {a, b}
|
||||
|
||||
def test_CodeBlock_cse():
|
||||
c1 = CodeBlock(
|
||||
Assignment(y, 1),
|
||||
Assignment(x, sin(y)),
|
||||
Assignment(z, sin(y)),
|
||||
Assignment(t, x*z),
|
||||
)
|
||||
assert c1.cse() == CodeBlock(
|
||||
Assignment(y, 1),
|
||||
Assignment(x0, sin(y)),
|
||||
Assignment(x, x0),
|
||||
Assignment(z, x0),
|
||||
Assignment(t, x*z),
|
||||
)
|
||||
|
||||
# Multiple assignments to same symbol not supported
|
||||
raises(NotImplementedError, lambda: CodeBlock(
|
||||
Assignment(x, 1),
|
||||
Assignment(y, 1), Assignment(y, 2)
|
||||
).cse())
|
||||
|
||||
# Check auto-generated symbols do not collide with existing ones
|
||||
c2 = CodeBlock(
|
||||
Assignment(x0, sin(y) + 1),
|
||||
Assignment(x1, 2 * sin(y)),
|
||||
Assignment(z, x * y),
|
||||
)
|
||||
assert c2.cse() == CodeBlock(
|
||||
Assignment(x2, sin(y)),
|
||||
Assignment(x0, x2 + 1),
|
||||
Assignment(x1, 2 * x2),
|
||||
Assignment(z, x * y),
|
||||
)
|
||||
|
||||
|
||||
def test_CodeBlock_cse__issue_14118():
|
||||
# see https://github.com/sympy/sympy/issues/14118
|
||||
c = CodeBlock(
|
||||
Assignment(A22, Matrix([[x, sin(y)],[3, 4]])),
|
||||
Assignment(B22, Matrix([[sin(y), 2*sin(y)], [sin(y)**2, 7]]))
|
||||
)
|
||||
assert c.cse() == CodeBlock(
|
||||
Assignment(x0, sin(y)),
|
||||
Assignment(A22, Matrix([[x, x0],[3, 4]])),
|
||||
Assignment(B22, Matrix([[x0, 2*x0], [x0**2, 7]]))
|
||||
)
|
||||
|
||||
def test_For():
|
||||
f = For(n, Range(0, 3), (Assignment(A[n, 0], x + n), aug_assign(x, '+', y)))
|
||||
f = For(n, (1, 2, 3, 4, 5), (Assignment(A[n, 0], x + n),))
|
||||
assert f.func(*f.args) == f
|
||||
raises(TypeError, lambda: For(n, x, (x + y,)))
|
||||
|
||||
|
||||
def test_none():
|
||||
assert none.is_Atom
|
||||
assert none == none
|
||||
class Foo(Token):
|
||||
pass
|
||||
foo = Foo()
|
||||
assert foo != none
|
||||
assert none == None
|
||||
assert none == NoneToken()
|
||||
assert none.func(*none.args) == none
|
||||
|
||||
|
||||
def test_String():
|
||||
st = String('foobar')
|
||||
assert st.is_Atom
|
||||
assert st == String('foobar')
|
||||
assert st.text == 'foobar'
|
||||
assert st.func(**st.kwargs()) == st
|
||||
assert st.func(*st.args) == st
|
||||
|
||||
|
||||
class Signifier(String):
|
||||
pass
|
||||
|
||||
si = Signifier('foobar')
|
||||
assert si != st
|
||||
assert si.text == st.text
|
||||
s = String('foo')
|
||||
assert str(s) == 'foo'
|
||||
assert repr(s) == "String('foo')"
|
||||
|
||||
def test_Comment():
|
||||
c = Comment('foobar')
|
||||
assert c.text == 'foobar'
|
||||
assert str(c) == 'foobar'
|
||||
|
||||
def test_Node():
|
||||
n = Node()
|
||||
assert n == Node()
|
||||
assert n.func(*n.args) == n
|
||||
|
||||
|
||||
def test_Type():
|
||||
t = Type('MyType')
|
||||
assert len(t.args) == 1
|
||||
assert t.name == String('MyType')
|
||||
assert str(t) == 'MyType'
|
||||
assert repr(t) == "Type(String('MyType'))"
|
||||
assert Type(t) == t
|
||||
assert t.func(*t.args) == t
|
||||
t1 = Type('t1')
|
||||
t2 = Type('t2')
|
||||
assert t1 != t2
|
||||
assert t1 == t1 and t2 == t2
|
||||
t1b = Type('t1')
|
||||
assert t1 == t1b
|
||||
assert t2 != t1b
|
||||
|
||||
|
||||
def test_Type__from_expr():
|
||||
assert Type.from_expr(i) == integer
|
||||
u = symbols('u', real=True)
|
||||
assert Type.from_expr(u) == real
|
||||
assert Type.from_expr(n) == integer
|
||||
assert Type.from_expr(3) == integer
|
||||
assert Type.from_expr(3.0) == real
|
||||
assert Type.from_expr(3+1j) == complex_
|
||||
raises(ValueError, lambda: Type.from_expr(sum))
|
||||
|
||||
|
||||
def test_Type__cast_check__integers():
|
||||
# Rounding
|
||||
raises(ValueError, lambda: integer.cast_check(3.5))
|
||||
assert integer.cast_check('3') == 3
|
||||
assert integer.cast_check(Float('3.0000000000000000000')) == 3
|
||||
assert integer.cast_check(Float('3.0000000000000000001')) == 3 # unintuitive maybe?
|
||||
|
||||
# Range
|
||||
assert int8.cast_check(127.0) == 127
|
||||
raises(ValueError, lambda: int8.cast_check(128))
|
||||
assert int8.cast_check(-128) == -128
|
||||
raises(ValueError, lambda: int8.cast_check(-129))
|
||||
|
||||
assert uint8.cast_check(0) == 0
|
||||
assert uint8.cast_check(128) == 128
|
||||
raises(ValueError, lambda: uint8.cast_check(256.0))
|
||||
raises(ValueError, lambda: uint8.cast_check(-1))
|
||||
|
||||
def test_Attribute():
|
||||
noexcept = Attribute('noexcept')
|
||||
assert noexcept == Attribute('noexcept')
|
||||
alignas16 = Attribute('alignas', [16])
|
||||
alignas32 = Attribute('alignas', [32])
|
||||
assert alignas16 != alignas32
|
||||
assert alignas16.func(*alignas16.args) == alignas16
|
||||
|
||||
|
||||
def test_Variable():
|
||||
v = Variable(x, type=real)
|
||||
assert v == Variable(v)
|
||||
assert v == Variable('x', type=real)
|
||||
assert v.symbol == x
|
||||
assert v.type == real
|
||||
assert value_const not in v.attrs
|
||||
assert v.func(*v.args) == v
|
||||
assert str(v) == 'Variable(x, type=real)'
|
||||
|
||||
w = Variable(y, f32, attrs={value_const})
|
||||
assert w.symbol == y
|
||||
assert w.type == f32
|
||||
assert value_const in w.attrs
|
||||
assert w.func(*w.args) == w
|
||||
|
||||
v_n = Variable(n, type=Type.from_expr(n))
|
||||
assert v_n.type == integer
|
||||
assert v_n.func(*v_n.args) == v_n
|
||||
v_i = Variable(i, type=Type.from_expr(n))
|
||||
assert v_i.type == integer
|
||||
assert v_i != v_n
|
||||
|
||||
a_i = Variable.deduced(i)
|
||||
assert a_i.type == integer
|
||||
assert Variable.deduced(Symbol('x', real=True)).type == real
|
||||
assert a_i.func(*a_i.args) == a_i
|
||||
|
||||
v_n2 = Variable.deduced(n, value=3.5, cast_check=False)
|
||||
assert v_n2.func(*v_n2.args) == v_n2
|
||||
assert abs(v_n2.value - 3.5) < 1e-15
|
||||
raises(ValueError, lambda: Variable.deduced(n, value=3.5, cast_check=True))
|
||||
|
||||
v_n3 = Variable.deduced(n)
|
||||
assert v_n3.type == integer
|
||||
assert str(v_n3) == 'Variable(n, type=integer)'
|
||||
assert Variable.deduced(z, value=3).type == integer
|
||||
assert Variable.deduced(z, value=3.0).type == real
|
||||
assert Variable.deduced(z, value=3.0+1j).type == complex_
|
||||
|
||||
|
||||
def test_Pointer():
|
||||
p = Pointer(x)
|
||||
assert p.symbol == x
|
||||
assert p.type == untyped
|
||||
assert value_const not in p.attrs
|
||||
assert pointer_const not in p.attrs
|
||||
assert p.func(*p.args) == p
|
||||
|
||||
u = symbols('u', real=True)
|
||||
pu = Pointer(u, type=Type.from_expr(u), attrs={value_const, pointer_const})
|
||||
assert pu.symbol is u
|
||||
assert pu.type == real
|
||||
assert value_const in pu.attrs
|
||||
assert pointer_const in pu.attrs
|
||||
assert pu.func(*pu.args) == pu
|
||||
|
||||
i = symbols('i', integer=True)
|
||||
deref = pu[i]
|
||||
assert deref.indices == (i,)
|
||||
|
||||
|
||||
def test_Declaration():
|
||||
u = symbols('u', real=True)
|
||||
vu = Variable(u, type=Type.from_expr(u))
|
||||
assert Declaration(vu).variable.type == real
|
||||
vn = Variable(n, type=Type.from_expr(n))
|
||||
assert Declaration(vn).variable.type == integer
|
||||
|
||||
# PR 19107, does not allow comparison between expressions and Basic
|
||||
# lt = StrictLessThan(vu, vn)
|
||||
# assert isinstance(lt, StrictLessThan)
|
||||
|
||||
vuc = Variable(u, Type.from_expr(u), value=3.0, attrs={value_const})
|
||||
assert value_const in vuc.attrs
|
||||
assert pointer_const not in vuc.attrs
|
||||
decl = Declaration(vuc)
|
||||
assert decl.variable == vuc
|
||||
assert isinstance(decl.variable.value, Float)
|
||||
assert decl.variable.value == 3.0
|
||||
assert decl.func(*decl.args) == decl
|
||||
assert vuc.as_Declaration() == decl
|
||||
assert vuc.as_Declaration(value=None, attrs=None) == Declaration(vu)
|
||||
|
||||
vy = Variable(y, type=integer, value=3)
|
||||
decl2 = Declaration(vy)
|
||||
assert decl2.variable == vy
|
||||
assert decl2.variable.value == Integer(3)
|
||||
|
||||
vi = Variable(i, type=Type.from_expr(i), value=3.0)
|
||||
decl3 = Declaration(vi)
|
||||
assert decl3.variable.type == integer
|
||||
assert decl3.variable.value == 3.0
|
||||
|
||||
raises(ValueError, lambda: Declaration(vi, 42))
|
||||
|
||||
|
||||
def test_IntBaseType():
|
||||
assert intc.name == String('intc')
|
||||
assert intc.args == (intc.name,)
|
||||
assert str(IntBaseType('a').name) == 'a'
|
||||
|
||||
|
||||
def test_FloatType():
|
||||
assert f16.dig == 3
|
||||
assert f32.dig == 6
|
||||
assert f64.dig == 15
|
||||
assert f80.dig == 18
|
||||
assert f128.dig == 33
|
||||
|
||||
assert f16.decimal_dig == 5
|
||||
assert f32.decimal_dig == 9
|
||||
assert f64.decimal_dig == 17
|
||||
assert f80.decimal_dig == 21
|
||||
assert f128.decimal_dig == 36
|
||||
|
||||
assert f16.max_exponent == 16
|
||||
assert f32.max_exponent == 128
|
||||
assert f64.max_exponent == 1024
|
||||
assert f80.max_exponent == 16384
|
||||
assert f128.max_exponent == 16384
|
||||
|
||||
assert f16.min_exponent == -13
|
||||
assert f32.min_exponent == -125
|
||||
assert f64.min_exponent == -1021
|
||||
assert f80.min_exponent == -16381
|
||||
assert f128.min_exponent == -16381
|
||||
|
||||
assert abs(f16.eps / Float('0.00097656', precision=16) - 1) < 0.1*10**-f16.dig
|
||||
assert abs(f32.eps / Float('1.1920929e-07', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.eps / Float('2.2204460492503131e-16', precision=64) - 1) < 0.1*10**-f64.dig
|
||||
assert abs(f80.eps / Float('1.08420217248550443401e-19', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.eps / Float(' 1.92592994438723585305597794258492732e-34', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
assert abs(f16.max / Float('65504', precision=16) - 1) < .1*10**-f16.dig
|
||||
assert abs(f32.max / Float('3.40282347e+38', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.max / Float('1.79769313486231571e+308', precision=64) - 1) < 0.1*10**-f64.dig # cf. np.finfo(np.float64).max
|
||||
assert abs(f80.max / Float('1.18973149535723176502e+4932', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.max / Float('1.18973149535723176508575932662800702e+4932', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
# cf. np.finfo(np.float32).tiny
|
||||
assert abs(f16.tiny / Float('6.1035e-05', precision=16) - 1) < 0.1*10**-f16.dig
|
||||
assert abs(f32.tiny / Float('1.17549435e-38', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.tiny / Float('2.22507385850720138e-308', precision=64) - 1) < 0.1*10**-f64.dig
|
||||
assert abs(f80.tiny / Float('3.36210314311209350626e-4932', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.tiny / Float('3.3621031431120935062626778173217526e-4932', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
assert f64.cast_check(0.5) == Float(0.5, 17)
|
||||
assert abs(f64.cast_check(3.7) - 3.7) < 3e-17
|
||||
assert isinstance(f64.cast_check(3), (Float, float))
|
||||
|
||||
assert f64.cast_nocheck(oo) == float('inf')
|
||||
assert f64.cast_nocheck(-oo) == float('-inf')
|
||||
assert f64.cast_nocheck(float(oo)) == float('inf')
|
||||
assert f64.cast_nocheck(float(-oo)) == float('-inf')
|
||||
assert math.isnan(f64.cast_nocheck(nan))
|
||||
|
||||
assert f32 != f64
|
||||
assert f64 == f64.func(*f64.args)
|
||||
|
||||
|
||||
def test_Type__cast_check__floating_point():
|
||||
raises(ValueError, lambda: f32.cast_check(123.45678949))
|
||||
raises(ValueError, lambda: f32.cast_check(12.345678949))
|
||||
raises(ValueError, lambda: f32.cast_check(1.2345678949))
|
||||
raises(ValueError, lambda: f32.cast_check(.12345678949))
|
||||
assert abs(123.456789049 - f32.cast_check(123.456789049) - 4.9e-8) < 1e-8
|
||||
assert abs(0.12345678904 - f32.cast_check(0.12345678904) - 4e-11) < 1e-11
|
||||
|
||||
dcm21 = Float('0.123456789012345670499') # 21 decimals
|
||||
assert abs(dcm21 - f64.cast_check(dcm21) - 4.99e-19) < 1e-19
|
||||
|
||||
f80.cast_check(Float('0.12345678901234567890103', precision=88))
|
||||
raises(ValueError, lambda: f80.cast_check(Float('0.12345678901234567890149', precision=88)))
|
||||
|
||||
v10 = 12345.67894
|
||||
raises(ValueError, lambda: f32.cast_check(v10))
|
||||
assert abs(Float(str(v10), precision=64+8) - f64.cast_check(v10)) < v10*1e-16
|
||||
|
||||
assert abs(f32.cast_check(2147483647) - 2147483650) < 1
|
||||
|
||||
|
||||
def test_Type__cast_check__complex_floating_point():
|
||||
val9_11 = 123.456789049 + 0.123456789049j
|
||||
raises(ValueError, lambda: c64.cast_check(.12345678949 + .12345678949j))
|
||||
assert abs(val9_11 - c64.cast_check(val9_11) - 4.9e-8) < 1e-8
|
||||
|
||||
dcm21 = Float('0.123456789012345670499') + 1e-20j # 21 decimals
|
||||
assert abs(dcm21 - c128.cast_check(dcm21) - 4.99e-19) < 1e-19
|
||||
v19 = Float('0.1234567890123456749') + 1j*Float('0.1234567890123456749')
|
||||
raises(ValueError, lambda: c128.cast_check(v19))
|
||||
|
||||
|
||||
def test_While():
|
||||
xpp = AddAugmentedAssignment(x, 1)
|
||||
whl1 = While(x < 2, [xpp])
|
||||
assert whl1.condition.args[0] == x
|
||||
assert whl1.condition.args[1] == 2
|
||||
assert whl1.condition == Lt(x, 2, evaluate=False)
|
||||
assert whl1.body.args == (xpp,)
|
||||
assert whl1.func(*whl1.args) == whl1
|
||||
|
||||
cblk = CodeBlock(AddAugmentedAssignment(x, 1))
|
||||
whl2 = While(x < 2, cblk)
|
||||
assert whl1 == whl2
|
||||
assert whl1 != While(x < 3, [xpp])
|
||||
|
||||
|
||||
def test_Scope():
|
||||
assign = Assignment(x, y)
|
||||
incr = AddAugmentedAssignment(x, 1)
|
||||
scp = Scope([assign, incr])
|
||||
cblk = CodeBlock(assign, incr)
|
||||
assert scp.body == cblk
|
||||
assert scp == Scope(cblk)
|
||||
assert scp != Scope([incr, assign])
|
||||
assert scp.func(*scp.args) == scp
|
||||
|
||||
|
||||
def test_Print():
|
||||
fmt = "%d %.3f"
|
||||
ps = Print([n, x], fmt)
|
||||
assert str(ps.format_string) == fmt
|
||||
assert ps.print_args == Tuple(n, x)
|
||||
assert ps.args == (Tuple(n, x), QuotedString(fmt), none)
|
||||
assert ps == Print((n, x), fmt)
|
||||
assert ps != Print([x, n], fmt)
|
||||
assert ps.func(*ps.args) == ps
|
||||
|
||||
ps2 = Print([n, x])
|
||||
assert ps2 == Print([n, x])
|
||||
assert ps2 != ps
|
||||
assert ps2.format_string == None
|
||||
|
||||
|
||||
def test_FunctionPrototype_and_FunctionDefinition():
|
||||
vx = Variable(x, type=real)
|
||||
vn = Variable(n, type=integer)
|
||||
fp1 = FunctionPrototype(real, 'power', [vx, vn])
|
||||
assert fp1.return_type == real
|
||||
assert fp1.name == String('power')
|
||||
assert fp1.parameters == Tuple(vx, vn)
|
||||
assert fp1 == FunctionPrototype(real, 'power', [vx, vn])
|
||||
assert fp1 != FunctionPrototype(real, 'power', [vn, vx])
|
||||
assert fp1.func(*fp1.args) == fp1
|
||||
|
||||
|
||||
body = [Assignment(x, x**n), Return(x)]
|
||||
fd1 = FunctionDefinition(real, 'power', [vx, vn], body)
|
||||
assert fd1.return_type == real
|
||||
assert str(fd1.name) == 'power'
|
||||
assert fd1.parameters == Tuple(vx, vn)
|
||||
assert fd1.body == CodeBlock(*body)
|
||||
assert fd1 == FunctionDefinition(real, 'power', [vx, vn], body)
|
||||
assert fd1 != FunctionDefinition(real, 'power', [vx, vn], body[::-1])
|
||||
assert fd1.func(*fd1.args) == fd1
|
||||
|
||||
fp2 = FunctionPrototype.from_FunctionDefinition(fd1)
|
||||
assert fp2 == fp1
|
||||
|
||||
fd2 = FunctionDefinition.from_FunctionPrototype(fp1, body)
|
||||
assert fd2 == fd1
|
||||
|
||||
|
||||
def test_Return():
|
||||
rs = Return(x)
|
||||
assert rs.args == (x,)
|
||||
assert rs == Return(x)
|
||||
assert rs != Return(y)
|
||||
assert rs.func(*rs.args) == rs
|
||||
|
||||
|
||||
def test_FunctionCall():
|
||||
fc = FunctionCall('power', (x, 3))
|
||||
assert fc.function_args[0] == x
|
||||
assert fc.function_args[1] == 3
|
||||
assert len(fc.function_args) == 2
|
||||
assert isinstance(fc.function_args[1], Integer)
|
||||
assert fc == FunctionCall('power', (x, 3))
|
||||
assert fc != FunctionCall('power', (3, x))
|
||||
assert fc != FunctionCall('Power', (x, 3))
|
||||
assert fc.func(*fc.args) == fc
|
||||
|
||||
fc2 = FunctionCall('fma', [2, 3, 4])
|
||||
assert len(fc2.function_args) == 3
|
||||
assert fc2.function_args[0] == 2
|
||||
assert fc2.function_args[1] == 3
|
||||
assert fc2.function_args[2] == 4
|
||||
assert str(fc2) in ( # not sure if QuotedString is a better default...
|
||||
'FunctionCall(fma, function_args=(2, 3, 4))',
|
||||
'FunctionCall("fma", function_args=(2, 3, 4))',
|
||||
)
|
||||
|
||||
def test_ast_replace():
|
||||
x = Variable('x', real)
|
||||
y = Variable('y', real)
|
||||
n = Variable('n', integer)
|
||||
|
||||
pwer = FunctionDefinition(real, 'pwer', [x, n], [pow(x.symbol, n.symbol)])
|
||||
pname = pwer.name
|
||||
pcall = FunctionCall('pwer', [y, 3])
|
||||
|
||||
tree1 = CodeBlock(pwer, pcall)
|
||||
assert str(tree1.args[0].name) == 'pwer'
|
||||
assert str(tree1.args[1].name) == 'pwer'
|
||||
for a, b in zip(tree1, [pwer, pcall]):
|
||||
assert a == b
|
||||
|
||||
tree2 = tree1.replace(pname, String('power'))
|
||||
assert str(tree1.args[0].name) == 'pwer'
|
||||
assert str(tree1.args[1].name) == 'pwer'
|
||||
assert str(tree2.args[0].name) == 'power'
|
||||
assert str(tree2.args[1].name) == 'power'
|
||||
@@ -0,0 +1,186 @@
|
||||
from sympy.core.numbers import (Rational, pi)
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.codegen.cfunctions import (
|
||||
expm1, log1p, exp2, log2, fma, log10, Sqrt, Cbrt, hypot, isnan, isinf
|
||||
)
|
||||
from sympy.core.function import expand_log
|
||||
|
||||
|
||||
def test_expm1():
|
||||
# Eval
|
||||
assert expm1(0) == 0
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand and rewrite
|
||||
assert expm1(x).expand(func=True) - exp(x) == -1
|
||||
assert expm1(x).rewrite('tractable') - exp(x) == -1
|
||||
assert expm1(x).rewrite('exp') - exp(x) == -1
|
||||
|
||||
# Precision
|
||||
assert not ((exp(1e-10).evalf() - 1) - 1e-10 - 5e-21) < 1e-22 # for comparison
|
||||
assert abs(expm1(1e-10).evalf() - 1e-10 - 5e-21) < 1e-22
|
||||
|
||||
# Properties
|
||||
assert expm1(x).is_real
|
||||
assert expm1(x).is_finite
|
||||
|
||||
# Diff
|
||||
assert expm1(42*x).diff(x) - 42*exp(42*x) == 0
|
||||
assert expm1(42*x).diff(x) - expm1(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_log1p():
|
||||
# Eval
|
||||
assert log1p(0) == 0
|
||||
d = S(10)
|
||||
assert expand_log(log1p(d**-1000) - log(d**1000 + 1) + log(d**1000)) == 0
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand and rewrite
|
||||
assert log1p(x).expand(func=True) - log(x + 1) == 0
|
||||
assert log1p(x).rewrite('tractable') - log(x + 1) == 0
|
||||
assert log1p(x).rewrite('log') - log(x + 1) == 0
|
||||
|
||||
# Precision
|
||||
assert not abs(log(1e-99 + 1).evalf() - 1e-99) < 1e-100 # for comparison
|
||||
assert abs(expand_log(log1p(1e-99)).evalf() - 1e-99) < 1e-100
|
||||
|
||||
# Properties
|
||||
assert log1p(-2**Rational(-1, 2)).is_real
|
||||
|
||||
assert not log1p(-1).is_finite
|
||||
assert log1p(pi).is_finite
|
||||
|
||||
assert not log1p(x).is_positive
|
||||
assert log1p(Symbol('y', positive=True)).is_positive
|
||||
|
||||
assert not log1p(x).is_zero
|
||||
assert log1p(Symbol('z', zero=True)).is_zero
|
||||
|
||||
assert not log1p(x).is_nonnegative
|
||||
assert log1p(Symbol('o', nonnegative=True)).is_nonnegative
|
||||
|
||||
# Diff
|
||||
assert log1p(42*x).diff(x) - 42/(42*x + 1) == 0
|
||||
assert log1p(42*x).diff(x) - log1p(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_exp2():
|
||||
# Eval
|
||||
assert exp2(2) == 4
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand
|
||||
assert exp2(x).expand(func=True) - 2**x == 0
|
||||
|
||||
# Diff
|
||||
assert exp2(42*x).diff(x) - 42*exp2(42*x)*log(2) == 0
|
||||
assert exp2(42*x).diff(x) - exp2(42*x).diff(x) == 0
|
||||
|
||||
|
||||
def test_log2():
|
||||
# Eval
|
||||
assert log2(8) == 3
|
||||
assert log2(pi) != log(pi)/log(2) # log2 should *save* (CPU) instructions
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
assert log2(x) != log(x)/log(2)
|
||||
assert log2(2**x) == x
|
||||
|
||||
# Expand
|
||||
assert log2(x).expand(func=True) - log(x)/log(2) == 0
|
||||
|
||||
# Diff
|
||||
assert log2(42*x).diff() - 1/(log(2)*x) == 0
|
||||
assert log2(42*x).diff() - log2(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_fma():
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
# Expand
|
||||
assert fma(x, y, z).expand(func=True) - x*y - z == 0
|
||||
|
||||
expr = fma(17*x, 42*y, 101*z)
|
||||
|
||||
# Diff
|
||||
assert expr.diff(x) - expr.expand(func=True).diff(x) == 0
|
||||
assert expr.diff(y) - expr.expand(func=True).diff(y) == 0
|
||||
assert expr.diff(z) - expr.expand(func=True).diff(z) == 0
|
||||
|
||||
assert expr.diff(x) - 17*42*y == 0
|
||||
assert expr.diff(y) - 17*42*x == 0
|
||||
assert expr.diff(z) - 101 == 0
|
||||
|
||||
|
||||
def test_log10():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert log10(x).expand(func=True) - log(x)/log(10) == 0
|
||||
|
||||
# Diff
|
||||
assert log10(42*x).diff(x) - 1/(log(10)*x) == 0
|
||||
assert log10(42*x).diff(x) - log10(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_Cbrt():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert Cbrt(x).expand(func=True) - x**Rational(1, 3) == 0
|
||||
|
||||
# Diff
|
||||
assert Cbrt(42*x).diff(x) - 42*(42*x)**(Rational(1, 3) - 1)/3 == 0
|
||||
assert Cbrt(42*x).diff(x) - Cbrt(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_Sqrt():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert Sqrt(x).expand(func=True) - x**S.Half == 0
|
||||
|
||||
# Diff
|
||||
assert Sqrt(42*x).diff(x) - 42*(42*x)**(S.Half - 1)/2 == 0
|
||||
assert Sqrt(42*x).diff(x) - Sqrt(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_hypot():
|
||||
x, y = symbols('x y')
|
||||
|
||||
# Expand
|
||||
assert hypot(x, y).expand(func=True) - (x**2 + y**2)**S.Half == 0
|
||||
|
||||
# Diff
|
||||
assert hypot(17*x, 42*y).diff(x).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(x) == 0
|
||||
assert hypot(17*x, 42*y).diff(y).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(y) == 0
|
||||
|
||||
assert hypot(17*x, 42*y).diff(x).expand(func=True) - 2*17*17*x*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0
|
||||
assert hypot(17*x, 42*y).diff(y).expand(func=True) - 2*42*42*y*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0
|
||||
|
||||
|
||||
def test_isnan_isinf():
|
||||
x = Symbol('x')
|
||||
|
||||
# isinf
|
||||
assert isinf(+S.Infinity) == True
|
||||
assert isinf(-S.Infinity) == True
|
||||
assert isinf(S.Pi) == False
|
||||
isinfx = isinf(x)
|
||||
assert isinfx not in (False, True)
|
||||
assert isinfx.func is isinf
|
||||
assert isinfx.args == (x,)
|
||||
|
||||
# isnan
|
||||
assert isnan(S.NaN) == True
|
||||
assert isnan(S.Pi) == False
|
||||
isnanx = isnan(x)
|
||||
assert isnanx not in (False, True)
|
||||
assert isnanx.func is isnan
|
||||
assert isnanx.args == (x,)
|
||||
@@ -0,0 +1,112 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.codegen.ast import Declaration, Variable, float64, int64, String, CodeBlock
|
||||
from sympy.codegen.cnodes import (
|
||||
alignof, CommaOperator, goto, Label, PreDecrement, PostDecrement, PreIncrement, PostIncrement,
|
||||
sizeof, union, struct
|
||||
)
|
||||
|
||||
x, y = symbols('x y')
|
||||
|
||||
|
||||
def test_alignof():
|
||||
ax = alignof(x)
|
||||
assert ccode(ax) == 'alignof(x)'
|
||||
assert ax.func(*ax.args) == ax
|
||||
|
||||
|
||||
def test_CommaOperator():
|
||||
expr = CommaOperator(PreIncrement(x), 2*x)
|
||||
assert ccode(expr) == '(++(x), 2*x)'
|
||||
assert expr.func(*expr.args) == expr
|
||||
|
||||
|
||||
def test_goto_Label():
|
||||
s = 'early_exit'
|
||||
g = goto(s)
|
||||
assert g.func(*g.args) == g
|
||||
assert g != goto('foobar')
|
||||
assert ccode(g) == 'goto early_exit'
|
||||
|
||||
l1 = Label(s)
|
||||
assert ccode(l1) == 'early_exit:'
|
||||
assert l1 == Label('early_exit')
|
||||
assert l1 != Label('foobar')
|
||||
|
||||
body = [PreIncrement(x)]
|
||||
l2 = Label(s, body)
|
||||
assert l2.name == String("early_exit")
|
||||
assert l2.body == CodeBlock(PreIncrement(x))
|
||||
assert ccode(l2) == ("early_exit:\n"
|
||||
"++(x);")
|
||||
|
||||
body = [PreIncrement(x), PreDecrement(y)]
|
||||
l2 = Label(s, body)
|
||||
assert l2.name == String("early_exit")
|
||||
assert l2.body == CodeBlock(PreIncrement(x), PreDecrement(y))
|
||||
assert ccode(l2) == ("early_exit:\n"
|
||||
"{\n ++(x);\n --(y);\n}")
|
||||
|
||||
|
||||
def test_PreDecrement():
|
||||
p = PreDecrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '--(x)'
|
||||
|
||||
|
||||
def test_PostDecrement():
|
||||
p = PostDecrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '(x)--'
|
||||
|
||||
|
||||
def test_PreIncrement():
|
||||
p = PreIncrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '++(x)'
|
||||
|
||||
|
||||
def test_PostIncrement():
|
||||
p = PostIncrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '(x)++'
|
||||
|
||||
|
||||
def test_sizeof():
|
||||
typename = 'unsigned int'
|
||||
sz = sizeof(typename)
|
||||
assert ccode(sz) == 'sizeof(%s)' % typename
|
||||
assert sz.func(*sz.args) == sz
|
||||
assert not sz.is_Atom
|
||||
assert sz.atoms() == {String('unsigned int'), String('sizeof')}
|
||||
|
||||
|
||||
def test_struct():
|
||||
vx, vy = Variable(x, type=float64), Variable(y, type=float64)
|
||||
s = struct('vec2', [vx, vy])
|
||||
assert s.func(*s.args) == s
|
||||
assert s == struct('vec2', (vx, vy))
|
||||
assert s != struct('vec2', (vy, vx))
|
||||
assert str(s.name) == 'vec2'
|
||||
assert len(s.declarations) == 2
|
||||
assert all(isinstance(arg, Declaration) for arg in s.declarations)
|
||||
assert ccode(s) == (
|
||||
"struct vec2 {\n"
|
||||
" double x;\n"
|
||||
" double y;\n"
|
||||
"}")
|
||||
|
||||
|
||||
def test_union():
|
||||
vx, vy = Variable(x, type=float64), Variable(y, type=int64)
|
||||
u = union('dualuse', [vx, vy])
|
||||
assert u.func(*u.args) == u
|
||||
assert u == union('dualuse', (vx, vy))
|
||||
assert str(u.name) == 'dualuse'
|
||||
assert len(u.declarations) == 2
|
||||
assert all(isinstance(arg, Declaration) for arg in u.declarations)
|
||||
assert ccode(u) == (
|
||||
"union dualuse {\n"
|
||||
" double x;\n"
|
||||
" int64_t y;\n"
|
||||
"}")
|
||||
@@ -0,0 +1,14 @@
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.codegen.ast import Type
|
||||
from sympy.codegen.cxxnodes import using
|
||||
from sympy.printing.codeprinter import cxxcode
|
||||
|
||||
x = Symbol('x')
|
||||
|
||||
def test_using():
|
||||
v = Type('std::vector')
|
||||
u1 = using(v)
|
||||
assert cxxcode(u1) == 'using std::vector'
|
||||
|
||||
u2 = using(v, 'vec')
|
||||
assert cxxcode(u2) == 'using vec = std::vector'
|
||||
@@ -0,0 +1,213 @@
|
||||
import os
|
||||
import tempfile
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.codegen.ast import (
|
||||
Assignment, Print, Declaration, FunctionDefinition, Return, real,
|
||||
FunctionCall, Variable, Element, integer
|
||||
)
|
||||
from sympy.codegen.fnodes import (
|
||||
allocatable, ArrayConstructor, isign, dsign, cmplx, kind, literal_dp,
|
||||
Program, Module, use, Subroutine, dimension, assumed_extent, ImpliedDoLoop,
|
||||
intent_out, size, Do, SubroutineCall, sum_, array, bind_C
|
||||
)
|
||||
from sympy.codegen.futils import render_as_module
|
||||
from sympy.core.expr import unchanged
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import fcode
|
||||
from sympy.utilities._compilation import has_fortran, compile_run_strings, compile_link_import_strings
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, XFAIL
|
||||
|
||||
cython = import_module('cython')
|
||||
np = import_module('numpy')
|
||||
|
||||
|
||||
def test_size():
|
||||
x = Symbol('x', real=True)
|
||||
sx = size(x)
|
||||
assert fcode(sx, source_format='free') == 'size(x)'
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_size_assumed_shape():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
a = Symbol('a', real=True)
|
||||
body = [Return((sum_(a**2)/size(a))**.5)]
|
||||
arr = array(a, dim=[':'], intent='in')
|
||||
fd = FunctionDefinition(real, 'rms', [arr], body)
|
||||
render_as_module([fd], 'mod_rms')
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('rms.f90', render_as_module([fd], 'mod_rms')),
|
||||
('main.f90', (
|
||||
'program myprog\n'
|
||||
'use mod_rms, only: rms\n'
|
||||
'real*8, dimension(4), parameter :: x = [4, 2, 2, 2]\n'
|
||||
'print "(f7.5)", dsqrt(7d0) - rms(x)\n'
|
||||
'end program\n'
|
||||
))
|
||||
], clean=True)
|
||||
assert '0.00000' in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@XFAIL # https://github.com/sympy/sympy/issues/20265
|
||||
@may_xfail
|
||||
def test_ImpliedDoLoop():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
a, i = symbols('a i', integer=True)
|
||||
idl = ImpliedDoLoop(i**3, i, -3, 3, 2)
|
||||
ac = ArrayConstructor([-28, idl, 28])
|
||||
a = array(a, dim=[':'], attrs=[allocatable])
|
||||
prog = Program('idlprog', [
|
||||
a.as_Declaration(),
|
||||
Assignment(a, ac),
|
||||
Print([a])
|
||||
])
|
||||
fsrc = fcode(prog, standard=2003, source_format='free')
|
||||
(stdout, stderr), info = compile_run_strings([('main.f90', fsrc)], clean=True)
|
||||
for numstr in '-28 -27 -1 1 27 28'.split():
|
||||
assert numstr in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_Program():
|
||||
x = Symbol('x', real=True)
|
||||
vx = Variable.deduced(x, 42)
|
||||
decl = Declaration(vx)
|
||||
prnt = Print([x, x+1])
|
||||
prog = Program('foo', [decl, prnt])
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([('main.f90', fcode(prog, standard=90))], clean=True)
|
||||
assert '42' in stdout
|
||||
assert '43' in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_Module():
|
||||
x = Symbol('x', real=True)
|
||||
v_x = Variable.deduced(x)
|
||||
sq = FunctionDefinition(real, 'sqr', [v_x], [Return(x**2)])
|
||||
mod_sq = Module('mod_sq', [], [sq])
|
||||
sq_call = FunctionCall('sqr', [42.])
|
||||
prg_sq = Program('foobar', [
|
||||
use('mod_sq', only=['sqr']),
|
||||
Print(['"Square of 42 = "', sq_call])
|
||||
])
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('mod_sq.f90', fcode(mod_sq, standard=90)),
|
||||
('main.f90', fcode(prg_sq, standard=90))
|
||||
], clean=True)
|
||||
assert '42' in stdout
|
||||
assert str(42**2) in stdout
|
||||
assert stderr == ''
|
||||
|
||||
|
||||
@XFAIL # https://github.com/sympy/sympy/issues/20265
|
||||
@may_xfail
|
||||
def test_Subroutine():
|
||||
# Code to generate the subroutine in the example from
|
||||
# http://www.fortran90.org/src/best-practices.html#arrays
|
||||
r = Symbol('r', real=True)
|
||||
i = Symbol('i', integer=True)
|
||||
v_r = Variable.deduced(r, attrs=(dimension(assumed_extent), intent_out))
|
||||
v_i = Variable.deduced(i)
|
||||
v_n = Variable('n', integer)
|
||||
do_loop = Do([
|
||||
Assignment(Element(r, [i]), literal_dp(1)/i**2)
|
||||
], i, 1, v_n)
|
||||
sub = Subroutine("f", [v_r], [
|
||||
Declaration(v_n),
|
||||
Declaration(v_i),
|
||||
Assignment(v_n, size(r)),
|
||||
do_loop
|
||||
])
|
||||
x = Symbol('x', real=True)
|
||||
v_x3 = Variable.deduced(x, attrs=[dimension(3)])
|
||||
mod = Module('mymod', definitions=[sub])
|
||||
prog = Program('foo', [
|
||||
use(mod, only=[sub]),
|
||||
Declaration(v_x3),
|
||||
SubroutineCall(sub, [v_x3]),
|
||||
Print([sum_(v_x3), v_x3])
|
||||
])
|
||||
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('a.f90', fcode(mod, standard=90)),
|
||||
('b.f90', fcode(prog, standard=90))
|
||||
], clean=True)
|
||||
ref = [1.0/i**2 for i in range(1, 4)]
|
||||
assert str(sum(ref))[:-3] in stdout
|
||||
for _ in ref:
|
||||
assert str(_)[:-3] in stdout
|
||||
assert stderr == ''
|
||||
|
||||
|
||||
def test_isign():
|
||||
x = Symbol('x', integer=True)
|
||||
assert unchanged(isign, 1, x)
|
||||
assert fcode(isign(1, x), standard=95, source_format='free') == 'isign(1, x)'
|
||||
|
||||
|
||||
def test_dsign():
|
||||
x = Symbol('x')
|
||||
assert unchanged(dsign, 1, x)
|
||||
assert fcode(dsign(literal_dp(1), x), standard=95, source_format='free') == 'dsign(1d0, x)'
|
||||
|
||||
|
||||
def test_cmplx():
|
||||
x = Symbol('x')
|
||||
assert unchanged(cmplx, 1, x)
|
||||
|
||||
|
||||
def test_kind():
|
||||
x = Symbol('x')
|
||||
assert unchanged(kind, x)
|
||||
|
||||
|
||||
def test_literal_dp():
|
||||
assert fcode(literal_dp(0), source_format='free') == '0d0'
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_bind_C():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
if not cython:
|
||||
skip("Cython not found.")
|
||||
if not np:
|
||||
skip("NumPy not found.")
|
||||
|
||||
a = Symbol('a', real=True)
|
||||
s = Symbol('s', integer=True)
|
||||
body = [Return((sum_(a**2)/s)**.5)]
|
||||
arr = array(a, dim=[s], intent='in')
|
||||
fd = FunctionDefinition(real, 'rms', [arr, s], body, attrs=[bind_C('rms')])
|
||||
f_mod = render_as_module([fd], 'mod_rms')
|
||||
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('rms.f90', f_mod),
|
||||
('_rms.pyx', (
|
||||
"#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double rms(double*, int*)\n"
|
||||
"def py_rms(double[::1] x):\n"
|
||||
" cdef int s = x.size\n"
|
||||
" return rms(&x[0], &s)\n"))
|
||||
], build_dir=folder)
|
||||
assert abs(mod.py_rms(np.array([2., 4., 2., 2.])) - 7**0.5) < 1e-14
|
||||
@@ -0,0 +1,50 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.core.function import Function
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.matrices.dense import zeros
|
||||
from sympy.simplify.simplify import simplify
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from sympy.printing.numpy import NumPyPrinter
|
||||
from sympy.testing.pytest import skip
|
||||
from sympy.external import import_module
|
||||
|
||||
|
||||
def test_matrix_solve_issue_24862():
|
||||
A = Matrix(3, 3, symbols('a:9'))
|
||||
b = Matrix(3, 1, symbols('b:3'))
|
||||
hash(MatrixSolve(A, b))
|
||||
|
||||
|
||||
def test_matrix_solve_derivative_exact():
|
||||
q = symbols('q')
|
||||
a11, a12, a21, a22, b1, b2 = (
|
||||
f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
|
||||
A = Matrix([[a11, a12], [a21, a22]])
|
||||
b = Matrix([b1, b2])
|
||||
x_lu = A.LUsolve(b)
|
||||
dxdq_lu = A.LUsolve(b.diff(q) - A.diff(q) * A.LUsolve(b))
|
||||
assert simplify(x_lu.diff(q) - dxdq_lu) == zeros(2, 1)
|
||||
# dxdq_ms is the MatrixSolve equivalent of dxdq_lu
|
||||
dxdq_ms = MatrixSolve(A, b.diff(q) - A.diff(q) * MatrixSolve(A, b))
|
||||
assert MatrixSolve(A, b).diff(q) == dxdq_ms
|
||||
|
||||
|
||||
def test_matrix_solve_derivative_numpy():
|
||||
np = import_module('numpy')
|
||||
if not np:
|
||||
skip("numpy not installed.")
|
||||
q = symbols('q')
|
||||
a11, a12, a21, a22, b1, b2 = (
|
||||
f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
|
||||
A = Matrix([[a11, a12], [a21, a22]])
|
||||
b = Matrix([b1, b2])
|
||||
dx_lu = A.LUsolve(b).diff(q)
|
||||
subs = {a11.diff(q): 0.2, a12.diff(q): 0.3, a21.diff(q): 0.1,
|
||||
a22.diff(q): 0.5, b1.diff(q): 0.4, b2.diff(q): 0.9,
|
||||
a11: 1.3, a12: 0.5, a21: 1.2, a22: 4, b1: 6.2, b2: 3.5}
|
||||
p, p_vals = zip(*subs.items())
|
||||
dx_sm = MatrixSolve(A, b).diff(q)
|
||||
np.testing.assert_allclose(
|
||||
lambdify(p, dx_sm, printer=NumPyPrinter)(*p_vals),
|
||||
lambdify(p, dx_lu, printer=NumPyPrinter)(*p_vals))
|
||||
@@ -0,0 +1,69 @@
|
||||
from itertools import product
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import Max, Min
|
||||
from sympy.printing.repr import srepr
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2, minimum, maximum, amax, amin
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
def test_logaddexp():
|
||||
lae_xy = logaddexp(x, y)
|
||||
ref_xy = log(exp(x) + exp(y))
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
lae_xy.diff(wrt, deriv_order) -
|
||||
ref_xy.diff(wrt, deriv_order)
|
||||
).rewrite(log).simplify() == 0
|
||||
|
||||
one_third_e = 1*exp(1)/3
|
||||
two_thirds_e = 2*exp(1)/3
|
||||
logThirdE = log(one_third_e)
|
||||
logTwoThirdsE = log(two_thirds_e)
|
||||
lae_sum_to_e = logaddexp(logThirdE, logTwoThirdsE)
|
||||
assert lae_sum_to_e.rewrite(log) == 1
|
||||
assert lae_sum_to_e.simplify() == 1
|
||||
was = logaddexp(2, 3)
|
||||
assert srepr(was) == srepr(was.simplify()) # cannot simplify with 2, 3
|
||||
|
||||
|
||||
def test_logaddexp2():
|
||||
lae2_xy = logaddexp2(x, y)
|
||||
ref2_xy = log(2**x + 2**y)/log(2)
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
lae2_xy.diff(wrt, deriv_order) -
|
||||
ref2_xy.diff(wrt, deriv_order)
|
||||
).rewrite(log).cancel() == 0
|
||||
|
||||
def lb(x):
|
||||
return log(x)/log(2)
|
||||
|
||||
two_thirds = S.One*2/3
|
||||
four_thirds = 2*two_thirds
|
||||
lbTwoThirds = lb(two_thirds)
|
||||
lbFourThirds = lb(four_thirds)
|
||||
lae2_sum_to_2 = logaddexp2(lbTwoThirds, lbFourThirds)
|
||||
assert lae2_sum_to_2.rewrite(log) == 1
|
||||
assert lae2_sum_to_2.simplify() == 1
|
||||
was = logaddexp2(x, y)
|
||||
assert srepr(was) == srepr(was.simplify()) # cannot simplify with x, y
|
||||
|
||||
|
||||
def test_minimum_maximum():
|
||||
for MM, mm in zip([Min, Max], [minimum, maximum]):
|
||||
ref = MM(x, y, z)
|
||||
m = mm(x, y, z)
|
||||
assert m != ref
|
||||
assert m.rewrite(MM) == ref
|
||||
|
||||
|
||||
def test_amin_amax():
|
||||
for am in [amin, amax]:
|
||||
assert am(x).array == x
|
||||
assert am(x).axis == None
|
||||
assert am(x, axis=3).axis == 3
|
||||
with raises(ValueError):
|
||||
am(x, y, z)
|
||||
@@ -0,0 +1,13 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.codegen.pynodes import List
|
||||
|
||||
|
||||
def test_List():
|
||||
l = List(2, 3, 4)
|
||||
assert l == List(2, 3, 4)
|
||||
assert str(l) == "[2, 3, 4]"
|
||||
x, y, z = symbols('x y z')
|
||||
l = List(x**2,y**3,z**4)
|
||||
# contrary to python's built-in list, we can call e.g. "replace" on List.
|
||||
m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp)
|
||||
assert m == [x**2, y-3, z-4]
|
||||
@@ -0,0 +1,7 @@
|
||||
from sympy.codegen.ast import Print
|
||||
from sympy.codegen.pyutils import render_as_module
|
||||
|
||||
def test_standard():
|
||||
ast = Print('x y'.split(), r"coordinate: %12.5g %12.5g\n")
|
||||
assert render_as_module(ast, standard='python3') == \
|
||||
'\n\nprint("coordinate: %12.5g %12.5g\\n" % (x, y), end="")'
|
||||
@@ -0,0 +1,479 @@
|
||||
import tempfile
|
||||
from sympy.core.numbers import pi, Rational
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin, sinc)
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.assumptions import assuming, Q
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.codegen.cfunctions import log2, exp2, expm1, log1p
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
|
||||
from sympy.codegen.scipy_nodes import cosm1, powm1
|
||||
from sympy.codegen.rewriting import (
|
||||
optimize, cosm1_opt, log2_opt, exp2_opt, expm1_opt, log1p_opt, powm1_opt, optims_c99,
|
||||
create_expand_pow_optimization, matinv_opt, logaddexp_opt, logaddexp2_opt,
|
||||
optims_numpy, optims_scipy, sinc_opts, FuncMinusOneOptim
|
||||
)
|
||||
from sympy.testing.pytest import XFAIL, skip
|
||||
from sympy.utilities import lambdify
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
|
||||
cython = import_module('cython')
|
||||
numpy = import_module('numpy')
|
||||
scipy = import_module('scipy')
|
||||
|
||||
|
||||
def test_log2_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = 7*log(3*x + 5)/(log(2))
|
||||
opt1 = optimize(expr1, [log2_opt])
|
||||
assert opt1 == 7*log2(3*x + 5)
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
expr2 = 3*log(5*x + 7)/(13*log(2))
|
||||
opt2 = optimize(expr2, [log2_opt])
|
||||
assert opt2 == 3*log2(5*x + 7)/13
|
||||
assert opt2.rewrite(log) == expr2
|
||||
|
||||
expr3 = log(x)/log(2)
|
||||
opt3 = optimize(expr3, [log2_opt])
|
||||
assert opt3 == log2(x)
|
||||
assert opt3.rewrite(log) == expr3
|
||||
|
||||
expr4 = log(x)/log(2) + log(x+1)
|
||||
opt4 = optimize(expr4, [log2_opt])
|
||||
assert opt4 == log2(x) + log(2)*log2(x+1)
|
||||
assert opt4.rewrite(log) == expr4
|
||||
|
||||
expr5 = log(17)
|
||||
opt5 = optimize(expr5, [log2_opt])
|
||||
assert opt5 == expr5
|
||||
|
||||
expr6 = log(x + 3)/log(2)
|
||||
opt6 = optimize(expr6, [log2_opt])
|
||||
assert str(opt6) == 'log2(x + 3)'
|
||||
assert opt6.rewrite(log) == expr6
|
||||
|
||||
|
||||
def test_exp2_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = 1 + 2**x
|
||||
opt1 = optimize(expr1, [exp2_opt])
|
||||
assert opt1 == 1 + exp2(x)
|
||||
assert opt1.rewrite(Pow) == expr1
|
||||
|
||||
expr2 = 1 + 3**x
|
||||
assert expr2 == optimize(expr2, [exp2_opt])
|
||||
|
||||
|
||||
def test_expm1_opt():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = exp(x) - 1
|
||||
opt1 = optimize(expr1, [expm1_opt])
|
||||
assert expm1(x) - opt1 == 0
|
||||
assert opt1.rewrite(exp) == expr1
|
||||
|
||||
expr2 = 3*exp(x) - 3
|
||||
opt2 = optimize(expr2, [expm1_opt])
|
||||
assert 3*expm1(x) == opt2
|
||||
assert opt2.rewrite(exp) == expr2
|
||||
|
||||
expr3 = 3*exp(x) - 5
|
||||
opt3 = optimize(expr3, [expm1_opt])
|
||||
assert 3*expm1(x) - 2 == opt3
|
||||
assert opt3.rewrite(exp) == expr3
|
||||
expm1_opt_non_opportunistic = FuncMinusOneOptim(exp, expm1, opportunistic=False)
|
||||
assert expr3 == optimize(expr3, [expm1_opt_non_opportunistic])
|
||||
assert opt1 == optimize(expr1, [expm1_opt_non_opportunistic])
|
||||
assert opt2 == optimize(expr2, [expm1_opt_non_opportunistic])
|
||||
|
||||
expr4 = 3*exp(x) + log(x) - 3
|
||||
opt4 = optimize(expr4, [expm1_opt])
|
||||
assert 3*expm1(x) + log(x) == opt4
|
||||
assert opt4.rewrite(exp) == expr4
|
||||
|
||||
expr5 = 3*exp(2*x) - 3
|
||||
opt5 = optimize(expr5, [expm1_opt])
|
||||
assert 3*expm1(2*x) == opt5
|
||||
assert opt5.rewrite(exp) == expr5
|
||||
|
||||
expr6 = (2*exp(x) + 1)/(exp(x) + 1) + 1
|
||||
opt6 = optimize(expr6, [expm1_opt])
|
||||
assert opt6.count_ops() <= expr6.count_ops()
|
||||
|
||||
def ev(e):
|
||||
return e.subs(x, 3).evalf()
|
||||
assert abs(ev(expr6) - ev(opt6)) < 1e-15
|
||||
|
||||
y = Symbol('y')
|
||||
expr7 = (2*exp(x) - 1)/(1 - exp(y)) - 1/(1-exp(y))
|
||||
opt7 = optimize(expr7, [expm1_opt])
|
||||
assert -2*expm1(x)/expm1(y) == opt7
|
||||
assert (opt7.rewrite(exp) - expr7).factor() == 0
|
||||
|
||||
expr8 = (1+exp(x))**2 - 4
|
||||
opt8 = optimize(expr8, [expm1_opt])
|
||||
tgt8a = (exp(x) + 3)*expm1(x)
|
||||
tgt8b = 2*expm1(x) + expm1(2*x)
|
||||
# Both tgt8a & tgt8b seem to give full precision (~16 digits for double)
|
||||
# for x=1e-7 (compare with expr8 which only achieves ~8 significant digits).
|
||||
# If we can show that either tgt8a or tgt8b is preferable, we can
|
||||
# change this test to ensure the preferable version is returned.
|
||||
assert (tgt8a - tgt8b).rewrite(exp).factor() == 0
|
||||
assert opt8 in (tgt8a, tgt8b)
|
||||
assert (opt8.rewrite(exp) - expr8).factor() == 0
|
||||
|
||||
expr9 = sin(expr8)
|
||||
opt9 = optimize(expr9, [expm1_opt])
|
||||
tgt9a = sin(tgt8a)
|
||||
tgt9b = sin(tgt8b)
|
||||
assert opt9 in (tgt9a, tgt9b)
|
||||
assert (opt9.rewrite(exp) - expr9.rewrite(exp)).factor().is_zero
|
||||
|
||||
|
||||
def test_expm1_two_exp_terms():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = exp(x) + exp(y) - 2
|
||||
opt1 = optimize(expr1, [expm1_opt])
|
||||
assert opt1 == expm1(x) + expm1(y)
|
||||
|
||||
|
||||
def test_cosm1_opt():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = cos(x) - 1
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert cosm1(x) - opt1 == 0
|
||||
assert opt1.rewrite(cos) == expr1
|
||||
|
||||
expr2 = 3*cos(x) - 3
|
||||
opt2 = optimize(expr2, [cosm1_opt])
|
||||
assert 3*cosm1(x) == opt2
|
||||
assert opt2.rewrite(cos) == expr2
|
||||
|
||||
expr3 = 3*cos(x) - 5
|
||||
opt3 = optimize(expr3, [cosm1_opt])
|
||||
assert 3*cosm1(x) - 2 == opt3
|
||||
assert opt3.rewrite(cos) == expr3
|
||||
cosm1_opt_non_opportunistic = FuncMinusOneOptim(cos, cosm1, opportunistic=False)
|
||||
assert expr3 == optimize(expr3, [cosm1_opt_non_opportunistic])
|
||||
assert opt1 == optimize(expr1, [cosm1_opt_non_opportunistic])
|
||||
assert opt2 == optimize(expr2, [cosm1_opt_non_opportunistic])
|
||||
|
||||
expr4 = 3*cos(x) + log(x) - 3
|
||||
opt4 = optimize(expr4, [cosm1_opt])
|
||||
assert 3*cosm1(x) + log(x) == opt4
|
||||
assert opt4.rewrite(cos) == expr4
|
||||
|
||||
expr5 = 3*cos(2*x) - 3
|
||||
opt5 = optimize(expr5, [cosm1_opt])
|
||||
assert 3*cosm1(2*x) == opt5
|
||||
assert opt5.rewrite(cos) == expr5
|
||||
|
||||
expr6 = 2 - 2*cos(x)
|
||||
opt6 = optimize(expr6, [cosm1_opt])
|
||||
assert -2*cosm1(x) == opt6
|
||||
assert opt6.rewrite(cos) == expr6
|
||||
|
||||
|
||||
def test_cosm1_two_cos_terms():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = cos(x) + cos(y) - 2
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert opt1 == cosm1(x) + cosm1(y)
|
||||
|
||||
|
||||
def test_expm1_cosm1_mixed():
|
||||
x = Symbol('x')
|
||||
expr1 = exp(x) + cos(x) - 2
|
||||
opt1 = optimize(expr1, [expm1_opt, cosm1_opt])
|
||||
assert opt1 == cosm1(x) + expm1(x)
|
||||
|
||||
|
||||
def _check_num_lambdify(expr, opt, val_subs, approx_ref, lambdify_kw=None, poorness=1e10):
|
||||
""" poorness=1e10 signifies that `expr` loses precision of at least ten decimal digits. """
|
||||
num_ref = expr.subs(val_subs).evalf()
|
||||
eps = numpy.finfo(numpy.float64).eps
|
||||
assert abs(num_ref - approx_ref) < approx_ref*eps
|
||||
f1 = lambdify(list(val_subs.keys()), opt, **(lambdify_kw or {}))
|
||||
args_float = tuple(map(float, val_subs.values()))
|
||||
num_err1 = abs(f1(*args_float) - approx_ref)
|
||||
assert num_err1 < abs(num_ref*eps)
|
||||
f2 = lambdify(list(val_subs.keys()), expr, **(lambdify_kw or {}))
|
||||
num_err2 = abs(f2(*args_float) - approx_ref)
|
||||
assert num_err2 > abs(num_ref*eps*poorness) # this only ensures that the *test* works as intended
|
||||
|
||||
|
||||
def test_cosm1_apart():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = 1/cos(x) - 1
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert opt1 == -cosm1(x)/cos(x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr1, opt1, {x: S(10)**-30}, 5e-61, lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
expr2 = 2/cos(x) - 2
|
||||
opt2 = optimize(expr2, optims_scipy)
|
||||
assert opt2 == -2*cosm1(x)/cos(x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr2, opt2, {x: S(10)**-30}, 1e-60, lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
expr3 = pi/cos(3*x) - pi
|
||||
opt3 = optimize(expr3, [cosm1_opt])
|
||||
assert opt3 == -pi*cosm1(3*x)/cos(3*x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr3, opt3, {x: S(10)**-30/3}, float(5e-61*pi), lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
|
||||
def test_powm1():
|
||||
args = x, y = map(Symbol, "xy")
|
||||
|
||||
expr1 = x**y - 1
|
||||
opt1 = optimize(expr1, [powm1_opt])
|
||||
assert opt1 == powm1(x, y)
|
||||
for arg in args:
|
||||
assert expr1.diff(arg) == opt1.diff(arg)
|
||||
if scipy and tuple(map(int, scipy.version.version.split('.')[:3])) >= (1, 10, 0):
|
||||
subs1_a = {x: Rational(*(1.0+1e-13).as_integer_ratio()), y: pi}
|
||||
ref1_f64_a = 3.139081648208105e-13
|
||||
_check_num_lambdify(expr1, opt1, subs1_a, ref1_f64_a, lambdify_kw={"modules": 'scipy'}, poorness=10**11)
|
||||
|
||||
subs1_b = {x: pi, y: Rational(*(1e-10).as_integer_ratio())}
|
||||
ref1_f64_b = 1.1447298859149205e-10
|
||||
_check_num_lambdify(expr1, opt1, subs1_b, ref1_f64_b, lambdify_kw={"modules": 'scipy'}, poorness=10**9)
|
||||
|
||||
|
||||
def test_log1p_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = log(x + 1)
|
||||
opt1 = optimize(expr1, [log1p_opt])
|
||||
assert log1p(x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
expr2 = log(3*x + 3)
|
||||
opt2 = optimize(expr2, [log1p_opt])
|
||||
assert log1p(x) + log(3) == opt2
|
||||
assert (opt2.rewrite(log) - expr2).simplify() == 0
|
||||
|
||||
expr3 = log(2*x + 1)
|
||||
opt3 = optimize(expr3, [log1p_opt])
|
||||
assert log1p(2*x) - opt3 == 0
|
||||
assert opt3.rewrite(log) == expr3
|
||||
|
||||
expr4 = log(x+3)
|
||||
opt4 = optimize(expr4, [log1p_opt])
|
||||
assert str(opt4) == 'log(x + 3)'
|
||||
|
||||
|
||||
def test_optims_c99():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1
|
||||
opt1 = optimize(expr1, optims_c99).simplify()
|
||||
assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x)
|
||||
assert opt1.rewrite(exp).rewrite(log).rewrite(Pow) == expr1
|
||||
|
||||
expr2 = log(x)/log(2) + log(x + 1)
|
||||
opt2 = optimize(expr2, optims_c99)
|
||||
assert opt2 == log2(x) + log1p(x)
|
||||
assert opt2.rewrite(log) == expr2
|
||||
|
||||
expr3 = log(x)/log(2) + log(17*x + 17)
|
||||
opt3 = optimize(expr3, optims_c99)
|
||||
delta3 = opt3 - (log2(x) + log(17) + log1p(x))
|
||||
assert delta3 == 0
|
||||
assert (opt3.rewrite(log) - expr3).simplify() == 0
|
||||
|
||||
expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17)
|
||||
opt4 = optimize(expr4, optims_c99).simplify()
|
||||
delta4 = opt4 - (exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x))
|
||||
assert delta4 == 0
|
||||
assert (opt4.rewrite(exp).rewrite(log).rewrite(Pow) - expr4).simplify() == 0
|
||||
|
||||
expr5 = 3*exp(2*x) - 3
|
||||
opt5 = optimize(expr5, optims_c99)
|
||||
delta5 = opt5 - 3*expm1(2*x)
|
||||
assert delta5 == 0
|
||||
assert opt5.rewrite(exp) == expr5
|
||||
|
||||
expr6 = exp(2*x) - 3
|
||||
opt6 = optimize(expr6, optims_c99)
|
||||
assert opt6 in (expm1(2*x) - 2, expr6) # expm1(2*x) - 2 is not better or worse
|
||||
|
||||
expr7 = log(3*x + 3)
|
||||
opt7 = optimize(expr7, optims_c99)
|
||||
delta7 = opt7 - (log(3) + log1p(x))
|
||||
assert delta7 == 0
|
||||
assert (opt7.rewrite(log) - expr7).simplify() == 0
|
||||
|
||||
expr8 = log(2*x + 3)
|
||||
opt8 = optimize(expr8, optims_c99)
|
||||
assert opt8 == expr8
|
||||
|
||||
|
||||
def test_create_expand_pow_optimization():
|
||||
cc = lambda x: ccode(
|
||||
optimize(x, [create_expand_pow_optimization(4)]))
|
||||
x = Symbol('x')
|
||||
assert cc(x**4) == 'x*x*x*x'
|
||||
assert cc(x**4 + x**2) == 'x*x + x*x*x*x'
|
||||
assert cc(x**5 + x**4) == 'pow(x, 5) + x*x*x*x'
|
||||
assert cc(sin(x)**4) == 'pow(sin(x), 4)'
|
||||
# gh issue 15335
|
||||
assert cc(x**(-4)) == '1.0/(x*x*x*x)'
|
||||
assert cc(x**(-5)) == 'pow(x, -5)'
|
||||
assert cc(-x**4) == '-(x*x*x*x)'
|
||||
assert cc(x**4 - x**2) == '-(x*x) + x*x*x*x'
|
||||
i = Symbol('i', integer=True)
|
||||
assert cc(x**i - x**2) == 'pow(x, i) - (x*x)'
|
||||
y = Symbol('y', real=True)
|
||||
assert cc(Abs(exp(y**4))) == "exp(y*y*y*y)"
|
||||
|
||||
# gh issue 20753
|
||||
cc2 = lambda x: ccode(optimize(x, [create_expand_pow_optimization(
|
||||
4, base_req=lambda b: b.is_Function)]))
|
||||
assert cc2(x**3 + sin(x)**3) == "pow(x, 3) + sin(x)*sin(x)*sin(x)"
|
||||
|
||||
|
||||
def test_matsolve():
|
||||
n = Symbol('n', integer=True)
|
||||
A = MatrixSymbol('A', n, n)
|
||||
x = MatrixSymbol('x', n, 1)
|
||||
|
||||
with assuming(Q.fullrank(A)):
|
||||
assert optimize(A**(-1) * x, [matinv_opt]) == MatrixSolve(A, x)
|
||||
assert optimize(A**(-1) * x + x, [matinv_opt]) == MatrixSolve(A, x) + x
|
||||
|
||||
|
||||
def test_logaddexp_opt():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = log(exp(x) + exp(y))
|
||||
opt1 = optimize(expr1, [logaddexp_opt])
|
||||
assert logaddexp(x, y) - opt1 == 0
|
||||
assert logaddexp(y, x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
|
||||
def test_logaddexp2_opt():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = log(2**x + 2**y)/log(2)
|
||||
opt1 = optimize(expr1, [logaddexp2_opt])
|
||||
assert logaddexp2(x, y) - opt1 == 0
|
||||
assert logaddexp2(y, x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
|
||||
def test_sinc_opts():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, sinc_opts) == v
|
||||
|
||||
x = Symbol('x')
|
||||
check({
|
||||
sin(x)/x : sinc(x),
|
||||
sin(2*x)/(2*x) : sinc(2*x),
|
||||
sin(3*x)/x : 3*sinc(3*x),
|
||||
x*sin(x) : x*sin(x)
|
||||
})
|
||||
|
||||
y = Symbol('y')
|
||||
check({
|
||||
sin(x*y)/(x*y) : sinc(x*y),
|
||||
y*sin(x/y)/x : sinc(x/y),
|
||||
sin(sin(x))/sin(x) : sinc(sin(x)),
|
||||
sin(3*sin(x))/sin(x) : 3*sinc(3*sin(x)),
|
||||
sin(x)/y : sin(x)/y
|
||||
})
|
||||
|
||||
|
||||
def test_optims_numpy():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, optims_numpy) == v
|
||||
|
||||
x = Symbol('x')
|
||||
check({
|
||||
sin(2*x)/(2*x) + exp(2*x) - 1: sinc(2*x) + expm1(2*x),
|
||||
log(x+3)/log(2) + log(x**2 + 1): log1p(x**2) + log2(x+3)
|
||||
})
|
||||
|
||||
|
||||
@XFAIL # room for improvement, ideally this test case should pass.
|
||||
def test_optims_numpy_TODO():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, optims_numpy) == v
|
||||
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
check({
|
||||
log(x*y)*sin(x*y)*log(x*y+1)/(log(2)*x*y): log2(x*y)*sinc(x*y)*log1p(x*y),
|
||||
exp(x*sin(y)/y) - 1: expm1(x*sinc(y))
|
||||
})
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_compiled_ccode_with_rewriting():
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
|
||||
x = Symbol('x')
|
||||
about_two = 2**(58/S(117))*3**(97/S(117))*5**(4/S(39))*7**(92/S(117))/S(30)*pi
|
||||
# about_two: 1.999999999999581826
|
||||
unchanged = 2*exp(x) - about_two
|
||||
xval = S(10)**-11
|
||||
ref = unchanged.subs(x, xval).n(19) # 2.0418173913673213e-11
|
||||
|
||||
rewritten = optimize(2*exp(x) - about_two, [expm1_opt])
|
||||
|
||||
# Unfortunately, we need to call ``.n()`` on our expressions before we hand them
|
||||
# to ``ccode``, and we need to request a large number of significant digits.
|
||||
# In this test, results converged for double precision when the following number
|
||||
# of significant digits were chosen:
|
||||
NUMBER_OF_DIGITS = 25 # TODO: this should ideally be automatically handled.
|
||||
|
||||
func_c = '''
|
||||
#include <math.h>
|
||||
|
||||
double func_unchanged(double x) {
|
||||
return %(unchanged)s;
|
||||
}
|
||||
double func_rewritten(double x) {
|
||||
return %(rewritten)s;
|
||||
}
|
||||
''' % {"unchanged": ccode(unchanged.n(NUMBER_OF_DIGITS)),
|
||||
"rewritten": ccode(rewritten.n(NUMBER_OF_DIGITS))}
|
||||
|
||||
func_pyx = '''
|
||||
#cython: language_level=3
|
||||
cdef extern double func_unchanged(double)
|
||||
cdef extern double func_rewritten(double)
|
||||
def py_unchanged(x):
|
||||
return func_unchanged(x)
|
||||
def py_rewritten(x):
|
||||
return func_rewritten(x)
|
||||
'''
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings(
|
||||
[('func.c', func_c), ('_func.pyx', func_pyx)],
|
||||
build_dir=folder, compile_kwargs={"std": 'c99'}
|
||||
)
|
||||
err_rewritten = abs(mod.py_rewritten(1e-11) - ref)
|
||||
err_unchanged = abs(mod.py_unchanged(1e-11) - ref)
|
||||
assert 1e-27 < err_rewritten < 1e-25 # highly accurate.
|
||||
assert 1e-19 < err_unchanged < 1e-16 # quite poor.
|
||||
|
||||
# Tolerances used above were determined as follows:
|
||||
# >>> no_opt = unchanged.subs(x, xval.evalf()).evalf()
|
||||
# >>> with_opt = rewritten.n(25).subs(x, 1e-11).evalf()
|
||||
# >>> with_opt - ref, no_opt - ref
|
||||
# (1.1536301877952077e-26, 1.6547074214222335e-18)
|
||||
@@ -0,0 +1,44 @@
|
||||
from itertools import product
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import exp, log
|
||||
from sympy.functions.elementary.trigonometric import cos
|
||||
from sympy.core.numbers import pi
|
||||
from sympy.codegen.scipy_nodes import cosm1, powm1
|
||||
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
|
||||
def test_cosm1():
|
||||
cm1_xy = cosm1(x*y)
|
||||
ref_xy = cos(x*y) - 1
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
cm1_xy.diff(wrt, deriv_order) -
|
||||
ref_xy.diff(wrt, deriv_order)
|
||||
).rewrite(cos).simplify() == 0
|
||||
|
||||
expr_minus2 = cosm1(pi)
|
||||
assert expr_minus2.rewrite(cos) == -2
|
||||
assert cosm1(3.14).simplify() == cosm1(3.14) # cannot simplify with 3.14
|
||||
assert cosm1(pi/2).simplify() == -1
|
||||
assert (1/cos(x) - 1 + cosm1(x)/cos(x)).simplify() == 0
|
||||
|
||||
|
||||
def test_powm1():
|
||||
cases = {
|
||||
powm1(x, y): x**y - 1,
|
||||
powm1(x*y, z): (x*y)**z - 1,
|
||||
powm1(x, y*z): x**(y*z)-1,
|
||||
powm1(x*y*z, x*y*z): (x*y*z)**(x*y*z)-1
|
||||
}
|
||||
for pm1_e, ref_e in cases.items():
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
der = pm1_e.diff(wrt, deriv_order)
|
||||
ref = ref_e.diff(wrt, deriv_order)
|
||||
delta = (der - ref).rewrite(Pow)
|
||||
assert delta.simplify() == 0
|
||||
|
||||
eulers_constant_m1 = powm1(x, 1/log(x))
|
||||
assert eulers_constant_m1.rewrite(Pow) == exp(1) - 1
|
||||
assert eulers_constant_m1.simplify() == exp(1) - 1
|
||||
Reference in New Issue
Block a user