chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
""" Rewrite Rules
|
||||
|
||||
DISCLAIMER: This module is experimental. The interface is subject to change.
|
||||
|
||||
A rule is a function that transforms one expression into another
|
||||
|
||||
Rule :: Expr -> Expr
|
||||
|
||||
A strategy is a function that says how a rule should be applied to a syntax
|
||||
tree. In general strategies take rules and produce a new rule
|
||||
|
||||
Strategy :: [Rules], Other-stuff -> Rule
|
||||
|
||||
This allows developers to separate a mathematical transformation from the
|
||||
algorithmic details of applying that transformation. The goal is to separate
|
||||
the work of mathematical programming from algorithmic programming.
|
||||
|
||||
Submodules
|
||||
|
||||
strategies.rl - some fundamental rules
|
||||
strategies.core - generic non-SymPy specific strategies
|
||||
strategies.traverse - strategies that traverse a SymPy tree
|
||||
strategies.tools - some conglomerate strategies that do depend on SymPy
|
||||
"""
|
||||
|
||||
from . import rl
|
||||
from . import traverse
|
||||
from .rl import rm_id, unpack, flatten, sort, glom, distribute, rebuild
|
||||
from .util import new
|
||||
from .core import (
|
||||
condition, debug, chain, null_safe, do_one, exhaust, minimize, tryit)
|
||||
from .tools import canon, typed
|
||||
from . import branch
|
||||
|
||||
__all__ = [
|
||||
'rl',
|
||||
|
||||
'traverse',
|
||||
|
||||
'rm_id', 'unpack', 'flatten', 'sort', 'glom', 'distribute', 'rebuild',
|
||||
|
||||
'new',
|
||||
|
||||
'condition', 'debug', 'chain', 'null_safe', 'do_one', 'exhaust',
|
||||
'minimize', 'tryit',
|
||||
|
||||
'canon', 'typed',
|
||||
|
||||
'branch',
|
||||
]
|
||||
@@ -0,0 +1,14 @@
|
||||
from . import traverse
|
||||
from .core import (
|
||||
condition, debug, multiplex, exhaust, notempty,
|
||||
chain, onaction, sfilter, yieldify, do_one, identity)
|
||||
from .tools import canon
|
||||
|
||||
__all__ = [
|
||||
'traverse',
|
||||
|
||||
'condition', 'debug', 'multiplex', 'exhaust', 'notempty', 'chain',
|
||||
'onaction', 'sfilter', 'yieldify', 'do_one', 'identity',
|
||||
|
||||
'canon',
|
||||
]
|
||||
@@ -0,0 +1,116 @@
|
||||
""" Generic SymPy-Independent Strategies """
|
||||
|
||||
|
||||
def identity(x):
|
||||
yield x
|
||||
|
||||
|
||||
def exhaust(brule):
|
||||
""" Apply a branching rule repeatedly until it has no effect """
|
||||
def exhaust_brl(expr):
|
||||
seen = {expr}
|
||||
for nexpr in brule(expr):
|
||||
if nexpr not in seen:
|
||||
seen.add(nexpr)
|
||||
yield from exhaust_brl(nexpr)
|
||||
if seen == {expr}:
|
||||
yield expr
|
||||
return exhaust_brl
|
||||
|
||||
|
||||
def onaction(brule, fn):
|
||||
def onaction_brl(expr):
|
||||
for result in brule(expr):
|
||||
if result != expr:
|
||||
fn(brule, expr, result)
|
||||
yield result
|
||||
return onaction_brl
|
||||
|
||||
|
||||
def debug(brule, file=None):
|
||||
""" Print the input and output expressions at each rule application """
|
||||
if not file:
|
||||
from sys import stdout
|
||||
file = stdout
|
||||
|
||||
def write(brl, expr, result):
|
||||
file.write("Rule: %s\n" % brl.__name__)
|
||||
file.write("In: %s\nOut: %s\n\n" % (expr, result))
|
||||
|
||||
return onaction(brule, write)
|
||||
|
||||
|
||||
def multiplex(*brules):
|
||||
""" Multiplex many branching rules into one """
|
||||
def multiplex_brl(expr):
|
||||
seen = set()
|
||||
for brl in brules:
|
||||
for nexpr in brl(expr):
|
||||
if nexpr not in seen:
|
||||
seen.add(nexpr)
|
||||
yield nexpr
|
||||
return multiplex_brl
|
||||
|
||||
|
||||
def condition(cond, brule):
|
||||
""" Only apply branching rule if condition is true """
|
||||
def conditioned_brl(expr):
|
||||
if cond(expr):
|
||||
yield from brule(expr)
|
||||
else:
|
||||
pass
|
||||
return conditioned_brl
|
||||
|
||||
|
||||
def sfilter(pred, brule):
|
||||
""" Yield only those results which satisfy the predicate """
|
||||
def filtered_brl(expr):
|
||||
yield from filter(pred, brule(expr))
|
||||
return filtered_brl
|
||||
|
||||
|
||||
def notempty(brule):
|
||||
def notempty_brl(expr):
|
||||
yielded = False
|
||||
for nexpr in brule(expr):
|
||||
yielded = True
|
||||
yield nexpr
|
||||
if not yielded:
|
||||
yield expr
|
||||
return notempty_brl
|
||||
|
||||
|
||||
def do_one(*brules):
|
||||
""" Execute one of the branching rules """
|
||||
def do_one_brl(expr):
|
||||
yielded = False
|
||||
for brl in brules:
|
||||
for nexpr in brl(expr):
|
||||
yielded = True
|
||||
yield nexpr
|
||||
if yielded:
|
||||
return
|
||||
return do_one_brl
|
||||
|
||||
|
||||
def chain(*brules):
|
||||
"""
|
||||
Compose a sequence of brules so that they apply to the expr sequentially
|
||||
"""
|
||||
def chain_brl(expr):
|
||||
if not brules:
|
||||
yield expr
|
||||
return
|
||||
|
||||
head, tail = brules[0], brules[1:]
|
||||
for nexpr in head(expr):
|
||||
yield from chain(*tail)(nexpr)
|
||||
|
||||
return chain_brl
|
||||
|
||||
|
||||
def yieldify(rl):
|
||||
""" Turn a rule into a branching rule """
|
||||
def brl(expr):
|
||||
yield rl(expr)
|
||||
return brl
|
||||
@@ -0,0 +1,117 @@
|
||||
from sympy.strategies.branch.core import (
|
||||
exhaust, debug, multiplex, condition, notempty, chain, onaction, sfilter,
|
||||
yieldify, do_one, identity)
|
||||
|
||||
|
||||
def posdec(x):
|
||||
if x > 0:
|
||||
yield x - 1
|
||||
else:
|
||||
yield x
|
||||
|
||||
|
||||
def branch5(x):
|
||||
if 0 < x < 5:
|
||||
yield x - 1
|
||||
elif 5 < x < 10:
|
||||
yield x + 1
|
||||
elif x == 5:
|
||||
yield x + 1
|
||||
yield x - 1
|
||||
else:
|
||||
yield x
|
||||
|
||||
|
||||
def even(x):
|
||||
return x % 2 == 0
|
||||
|
||||
|
||||
def inc(x):
|
||||
yield x + 1
|
||||
|
||||
|
||||
def one_to_n(n):
|
||||
yield from range(n)
|
||||
|
||||
|
||||
def test_exhaust():
|
||||
brl = exhaust(branch5)
|
||||
assert set(brl(3)) == {0}
|
||||
assert set(brl(7)) == {10}
|
||||
assert set(brl(5)) == {0, 10}
|
||||
|
||||
|
||||
def test_debug():
|
||||
from io import StringIO
|
||||
file = StringIO()
|
||||
rl = debug(posdec, file)
|
||||
list(rl(5))
|
||||
log = file.getvalue()
|
||||
file.close()
|
||||
|
||||
assert posdec.__name__ in log
|
||||
assert '5' in log
|
||||
assert '4' in log
|
||||
|
||||
|
||||
def test_multiplex():
|
||||
brl = multiplex(posdec, branch5)
|
||||
assert set(brl(3)) == {2}
|
||||
assert set(brl(7)) == {6, 8}
|
||||
assert set(brl(5)) == {4, 6}
|
||||
|
||||
|
||||
def test_condition():
|
||||
brl = condition(even, branch5)
|
||||
assert set(brl(4)) == set(branch5(4))
|
||||
assert set(brl(5)) == set()
|
||||
|
||||
|
||||
def test_sfilter():
|
||||
brl = sfilter(even, one_to_n)
|
||||
assert set(brl(10)) == {0, 2, 4, 6, 8}
|
||||
|
||||
|
||||
def test_notempty():
|
||||
def ident_if_even(x):
|
||||
if even(x):
|
||||
yield x
|
||||
|
||||
brl = notempty(ident_if_even)
|
||||
assert set(brl(4)) == {4}
|
||||
assert set(brl(5)) == {5}
|
||||
|
||||
|
||||
def test_chain():
|
||||
assert list(chain()(2)) == [2] # identity
|
||||
assert list(chain(inc, inc)(2)) == [4]
|
||||
assert list(chain(branch5, inc)(4)) == [4]
|
||||
assert set(chain(branch5, inc)(5)) == {5, 7}
|
||||
assert list(chain(inc, branch5)(5)) == [7]
|
||||
|
||||
|
||||
def test_onaction():
|
||||
L = []
|
||||
|
||||
def record(fn, input, output):
|
||||
L.append((input, output))
|
||||
|
||||
list(onaction(inc, record)(2))
|
||||
assert L == [(2, 3)]
|
||||
|
||||
list(onaction(identity, record)(2))
|
||||
assert L == [(2, 3)]
|
||||
|
||||
|
||||
def test_yieldify():
|
||||
yinc = yieldify(lambda x: x + 1)
|
||||
assert list(yinc(3)) == [4]
|
||||
|
||||
|
||||
def test_do_one():
|
||||
def bad(expr):
|
||||
raise ValueError
|
||||
|
||||
assert list(do_one(inc)(3)) == [4]
|
||||
assert list(do_one(inc, bad)(3)) == [4]
|
||||
assert list(do_one(inc, posdec)(3)) == [4]
|
||||
@@ -0,0 +1,42 @@
|
||||
from sympy.strategies.branch.tools import canon
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.numbers import Integer
|
||||
from sympy.core.singleton import S
|
||||
|
||||
|
||||
def posdec(x):
|
||||
if isinstance(x, Integer) and x > 0:
|
||||
yield x - 1
|
||||
else:
|
||||
yield x
|
||||
|
||||
|
||||
def branch5(x):
|
||||
if isinstance(x, Integer):
|
||||
if 0 < x < 5:
|
||||
yield x - 1
|
||||
elif 5 < x < 10:
|
||||
yield x + 1
|
||||
elif x == 5:
|
||||
yield x + 1
|
||||
yield x - 1
|
||||
else:
|
||||
yield x
|
||||
|
||||
|
||||
def test_zero_ints():
|
||||
expr = Basic(S(2), Basic(S(5), S(3)), S(8))
|
||||
expected = {Basic(S(0), Basic(S(0), S(0)), S(0))}
|
||||
|
||||
brl = canon(posdec)
|
||||
assert set(brl(expr)) == expected
|
||||
|
||||
|
||||
def test_split5():
|
||||
expr = Basic(S(2), Basic(S(5), S(3)), S(8))
|
||||
expected = {
|
||||
Basic(S(0), Basic(S(0), S(0)), S(10)),
|
||||
Basic(S(0), Basic(S(10), S(0)), S(10))}
|
||||
|
||||
brl = canon(branch5)
|
||||
assert set(brl(expr)) == expected
|
||||
@@ -0,0 +1,53 @@
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.numbers import Integer
|
||||
from sympy.core.singleton import S
|
||||
from sympy.strategies.branch.traverse import top_down, sall
|
||||
from sympy.strategies.branch.core import do_one, identity
|
||||
|
||||
|
||||
def inc(x):
|
||||
if isinstance(x, Integer):
|
||||
yield x + 1
|
||||
|
||||
|
||||
def test_top_down_easy():
|
||||
expr = Basic(S(1), S(2))
|
||||
expected = Basic(S(2), S(3))
|
||||
brl = top_down(inc)
|
||||
|
||||
assert set(brl(expr)) == {expected}
|
||||
|
||||
|
||||
def test_top_down_big_tree():
|
||||
expr = Basic(S(1), Basic(S(2)), Basic(S(3), Basic(S(4)), S(5)))
|
||||
expected = Basic(S(2), Basic(S(3)), Basic(S(4), Basic(S(5)), S(6)))
|
||||
brl = top_down(inc)
|
||||
|
||||
assert set(brl(expr)) == {expected}
|
||||
|
||||
|
||||
def test_top_down_harder_function():
|
||||
def split5(x):
|
||||
if x == 5:
|
||||
yield x - 1
|
||||
yield x + 1
|
||||
|
||||
expr = Basic(Basic(S(5), S(6)), S(1))
|
||||
expected = {Basic(Basic(S(4), S(6)), S(1)), Basic(Basic(S(6), S(6)), S(1))}
|
||||
brl = top_down(split5)
|
||||
|
||||
assert set(brl(expr)) == expected
|
||||
|
||||
|
||||
def test_sall():
|
||||
expr = Basic(S(1), S(2))
|
||||
expected = Basic(S(2), S(3))
|
||||
brl = sall(inc)
|
||||
|
||||
assert list(brl(expr)) == [expected]
|
||||
|
||||
expr = Basic(S(1), S(2), Basic(S(3), S(4)))
|
||||
expected = Basic(S(2), S(3), Basic(S(3), S(4)))
|
||||
brl = sall(do_one(inc, identity))
|
||||
|
||||
assert list(brl(expr)) == [expected]
|
||||
@@ -0,0 +1,12 @@
|
||||
from .core import exhaust, multiplex
|
||||
from .traverse import top_down
|
||||
|
||||
|
||||
def canon(*rules):
|
||||
""" Strategy for canonicalization
|
||||
|
||||
Apply each branching rule in a top-down fashion through the tree.
|
||||
Multiplex through all branching rule traversals
|
||||
Keep doing this until there is no change.
|
||||
"""
|
||||
return exhaust(multiplex(*map(top_down, rules)))
|
||||
@@ -0,0 +1,25 @@
|
||||
""" Branching Strategies to Traverse a Tree """
|
||||
from itertools import product
|
||||
from sympy.strategies.util import basic_fns
|
||||
from .core import chain, identity, do_one
|
||||
|
||||
|
||||
def top_down(brule, fns=basic_fns):
|
||||
""" Apply a rule down a tree running it on the top nodes first """
|
||||
return chain(do_one(brule, identity),
|
||||
lambda expr: sall(top_down(brule, fns), fns)(expr))
|
||||
|
||||
|
||||
def sall(brule, fns=basic_fns):
|
||||
""" Strategic all - apply rule to args """
|
||||
op, new, children, leaf = map(fns.get, ('op', 'new', 'children', 'leaf'))
|
||||
|
||||
def all_rl(expr):
|
||||
if leaf(expr):
|
||||
yield expr
|
||||
else:
|
||||
myop = op(expr)
|
||||
argss = product(*map(brule, children(expr)))
|
||||
for args in argss:
|
||||
yield new(myop, *args)
|
||||
return all_rl
|
||||
@@ -0,0 +1,151 @@
|
||||
""" Generic SymPy-Independent Strategies """
|
||||
from __future__ import annotations
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import TypeVar
|
||||
from sys import stdout
|
||||
|
||||
|
||||
_S = TypeVar('_S')
|
||||
_T = TypeVar('_T')
|
||||
|
||||
|
||||
def identity(x: _T) -> _T:
|
||||
return x
|
||||
|
||||
|
||||
def exhaust(rule: Callable[[_T], _T]) -> Callable[[_T], _T]:
|
||||
""" Apply a rule repeatedly until it has no effect """
|
||||
def exhaustive_rl(expr: _T) -> _T:
|
||||
new, old = rule(expr), expr
|
||||
while new != old:
|
||||
new, old = rule(new), new
|
||||
return new
|
||||
return exhaustive_rl
|
||||
|
||||
|
||||
def memoize(rule: Callable[[_S], _T]) -> Callable[[_S], _T]:
|
||||
"""Memoized version of a rule
|
||||
|
||||
Notes
|
||||
=====
|
||||
|
||||
This cache can grow infinitely, so it is not recommended to use this
|
||||
than ``functools.lru_cache`` unless you need very heavy computation.
|
||||
"""
|
||||
cache: dict[_S, _T] = {}
|
||||
|
||||
def memoized_rl(expr: _S) -> _T:
|
||||
if expr in cache:
|
||||
return cache[expr]
|
||||
else:
|
||||
result = rule(expr)
|
||||
cache[expr] = result
|
||||
return result
|
||||
return memoized_rl
|
||||
|
||||
|
||||
def condition(
|
||||
cond: Callable[[_T], bool], rule: Callable[[_T], _T]
|
||||
) -> Callable[[_T], _T]:
|
||||
""" Only apply rule if condition is true """
|
||||
def conditioned_rl(expr: _T) -> _T:
|
||||
if cond(expr):
|
||||
return rule(expr)
|
||||
return expr
|
||||
return conditioned_rl
|
||||
|
||||
|
||||
def chain(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
|
||||
"""
|
||||
Compose a sequence of rules so that they apply to the expr sequentially
|
||||
"""
|
||||
def chain_rl(expr: _T) -> _T:
|
||||
for rule in rules:
|
||||
expr = rule(expr)
|
||||
return expr
|
||||
return chain_rl
|
||||
|
||||
|
||||
def debug(rule, file=None):
|
||||
""" Print out before and after expressions each time rule is used """
|
||||
if file is None:
|
||||
file = stdout
|
||||
|
||||
def debug_rl(*args, **kwargs):
|
||||
expr = args[0]
|
||||
result = rule(*args, **kwargs)
|
||||
if result != expr:
|
||||
file.write("Rule: %s\n" % rule.__name__)
|
||||
file.write("In: %s\nOut: %s\n\n" % (expr, result))
|
||||
return result
|
||||
return debug_rl
|
||||
|
||||
|
||||
def null_safe(rule: Callable[[_T], _T | None]) -> Callable[[_T], _T]:
|
||||
""" Return original expr if rule returns None """
|
||||
def null_safe_rl(expr: _T) -> _T:
|
||||
result = rule(expr)
|
||||
if result is None:
|
||||
return expr
|
||||
return result
|
||||
return null_safe_rl
|
||||
|
||||
|
||||
def tryit(rule: Callable[[_T], _T], exception) -> Callable[[_T], _T]:
|
||||
""" Return original expr if rule raises exception """
|
||||
def try_rl(expr: _T) -> _T:
|
||||
try:
|
||||
return rule(expr)
|
||||
except exception:
|
||||
return expr
|
||||
return try_rl
|
||||
|
||||
|
||||
def do_one(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
|
||||
""" Try each of the rules until one works. Then stop. """
|
||||
def do_one_rl(expr: _T) -> _T:
|
||||
for rl in rules:
|
||||
result = rl(expr)
|
||||
if result != expr:
|
||||
return result
|
||||
return expr
|
||||
return do_one_rl
|
||||
|
||||
|
||||
def switch(
|
||||
key: Callable[[_S], _T],
|
||||
ruledict: Mapping[_T, Callable[[_S], _S]]
|
||||
) -> Callable[[_S], _S]:
|
||||
""" Select a rule based on the result of key called on the function """
|
||||
def switch_rl(expr: _S) -> _S:
|
||||
rl = ruledict.get(key(expr), identity)
|
||||
return rl(expr)
|
||||
return switch_rl
|
||||
|
||||
|
||||
# XXX Untyped default argument for minimize function
|
||||
# where python requires SupportsRichComparison type
|
||||
def _identity(x):
|
||||
return x
|
||||
|
||||
|
||||
def minimize(
|
||||
*rules: Callable[[_S], _T],
|
||||
objective=_identity
|
||||
) -> Callable[[_S], _T]:
|
||||
""" Select result of rules that minimizes objective
|
||||
|
||||
>>> from sympy.strategies import minimize
|
||||
>>> inc = lambda x: x + 1
|
||||
>>> dec = lambda x: x - 1
|
||||
>>> rl = minimize(inc, dec)
|
||||
>>> rl(4)
|
||||
3
|
||||
|
||||
>>> rl = minimize(inc, dec, objective=lambda x: -x) # maximize
|
||||
>>> rl(4)
|
||||
5
|
||||
"""
|
||||
def minrule(expr: _S) -> _T:
|
||||
return min([rule(expr) for rule in rules], key=objective)
|
||||
return minrule
|
||||
@@ -0,0 +1,176 @@
|
||||
""" Generic Rules for SymPy
|
||||
|
||||
This file assumes knowledge of Basic and little else.
|
||||
"""
|
||||
from sympy.utilities.iterables import sift
|
||||
from .util import new
|
||||
|
||||
|
||||
# Functions that create rules
|
||||
def rm_id(isid, new=new):
|
||||
""" Create a rule to remove identities.
|
||||
|
||||
isid - fn :: x -> Bool --- whether or not this element is an identity.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.strategies import rm_id
|
||||
>>> from sympy import Basic, S
|
||||
>>> remove_zeros = rm_id(lambda x: x==0)
|
||||
>>> remove_zeros(Basic(S(1), S(0), S(2)))
|
||||
Basic(1, 2)
|
||||
>>> remove_zeros(Basic(S(0), S(0))) # If only identities then we keep one
|
||||
Basic(0)
|
||||
|
||||
See Also:
|
||||
unpack
|
||||
"""
|
||||
def ident_remove(expr):
|
||||
""" Remove identities """
|
||||
ids = list(map(isid, expr.args))
|
||||
if sum(ids) == 0: # No identities. Common case
|
||||
return expr
|
||||
elif sum(ids) != len(ids): # there is at least one non-identity
|
||||
return new(expr.__class__,
|
||||
*[arg for arg, x in zip(expr.args, ids) if not x])
|
||||
else:
|
||||
return new(expr.__class__, expr.args[0])
|
||||
|
||||
return ident_remove
|
||||
|
||||
|
||||
def glom(key, count, combine):
|
||||
""" Create a rule to conglomerate identical args.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.strategies import glom
|
||||
>>> from sympy import Add
|
||||
>>> from sympy.abc import x
|
||||
|
||||
>>> key = lambda x: x.as_coeff_Mul()[1]
|
||||
>>> count = lambda x: x.as_coeff_Mul()[0]
|
||||
>>> combine = lambda cnt, arg: cnt * arg
|
||||
>>> rl = glom(key, count, combine)
|
||||
|
||||
>>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
|
||||
3*x + 5
|
||||
|
||||
Wait, how are key, count and combine supposed to work?
|
||||
|
||||
>>> key(2*x)
|
||||
x
|
||||
>>> count(2*x)
|
||||
2
|
||||
>>> combine(2, x)
|
||||
2*x
|
||||
"""
|
||||
def conglomerate(expr):
|
||||
""" Conglomerate together identical args x + x -> 2x """
|
||||
groups = sift(expr.args, key)
|
||||
counts = {k: sum(map(count, args)) for k, args in groups.items()}
|
||||
newargs = [combine(cnt, mat) for mat, cnt in counts.items()]
|
||||
if set(newargs) != set(expr.args):
|
||||
return new(type(expr), *newargs)
|
||||
else:
|
||||
return expr
|
||||
|
||||
return conglomerate
|
||||
|
||||
|
||||
def sort(key, new=new):
|
||||
""" Create a rule to sort by a key function.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.strategies import sort
|
||||
>>> from sympy import Basic, S
|
||||
>>> sort_rl = sort(str)
|
||||
>>> sort_rl(Basic(S(3), S(1), S(2)))
|
||||
Basic(1, 2, 3)
|
||||
"""
|
||||
|
||||
def sort_rl(expr):
|
||||
return new(expr.__class__, *sorted(expr.args, key=key))
|
||||
return sort_rl
|
||||
|
||||
|
||||
def distribute(A, B):
|
||||
""" Turns an A containing Bs into a B of As
|
||||
|
||||
where A, B are container types
|
||||
|
||||
>>> from sympy.strategies import distribute
|
||||
>>> from sympy import Add, Mul, symbols
|
||||
>>> x, y = symbols('x,y')
|
||||
>>> dist = distribute(Mul, Add)
|
||||
>>> expr = Mul(2, x+y, evaluate=False)
|
||||
>>> expr
|
||||
2*(x + y)
|
||||
>>> dist(expr)
|
||||
2*x + 2*y
|
||||
"""
|
||||
|
||||
def distribute_rl(expr):
|
||||
for i, arg in enumerate(expr.args):
|
||||
if isinstance(arg, B):
|
||||
first, b, tail = expr.args[:i], expr.args[i], expr.args[i + 1:]
|
||||
return B(*[A(*(first + (arg,) + tail)) for arg in b.args])
|
||||
return expr
|
||||
return distribute_rl
|
||||
|
||||
|
||||
def subs(a, b):
|
||||
""" Replace expressions exactly """
|
||||
def subs_rl(expr):
|
||||
if expr == a:
|
||||
return b
|
||||
else:
|
||||
return expr
|
||||
return subs_rl
|
||||
|
||||
|
||||
# Functions that are rules
|
||||
def unpack(expr):
|
||||
""" Rule to unpack singleton args
|
||||
|
||||
>>> from sympy.strategies import unpack
|
||||
>>> from sympy import Basic, S
|
||||
>>> unpack(Basic(S(2)))
|
||||
2
|
||||
"""
|
||||
if len(expr.args) == 1:
|
||||
return expr.args[0]
|
||||
else:
|
||||
return expr
|
||||
|
||||
|
||||
def flatten(expr, new=new):
|
||||
""" Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e)) """
|
||||
cls = expr.__class__
|
||||
args = []
|
||||
for arg in expr.args:
|
||||
if arg.__class__ == cls:
|
||||
args.extend(arg.args)
|
||||
else:
|
||||
args.append(arg)
|
||||
return new(expr.__class__, *args)
|
||||
|
||||
|
||||
def rebuild(expr):
|
||||
""" Rebuild a SymPy tree.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
This function recursively calls constructors in the expression tree.
|
||||
This forces canonicalization and removes ugliness introduced by the use of
|
||||
Basic.__new__
|
||||
"""
|
||||
if expr.is_Atom:
|
||||
return expr
|
||||
else:
|
||||
return expr.func(*list(map(rebuild, expr.args)))
|
||||
@@ -0,0 +1,118 @@
|
||||
from __future__ import annotations
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.strategies.core import (
|
||||
null_safe, exhaust, memoize, condition,
|
||||
chain, tryit, do_one, debug, switch, minimize)
|
||||
from io import StringIO
|
||||
|
||||
|
||||
def posdec(x: int) -> int:
|
||||
if x > 0:
|
||||
return x - 1
|
||||
return x
|
||||
|
||||
|
||||
def inc(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
|
||||
def dec(x: int) -> int:
|
||||
return x - 1
|
||||
|
||||
|
||||
def test_null_safe():
|
||||
def rl(expr: int) -> int | None:
|
||||
if expr == 1:
|
||||
return 2
|
||||
return None
|
||||
|
||||
safe_rl = null_safe(rl)
|
||||
assert rl(1) == safe_rl(1)
|
||||
assert rl(3) is None
|
||||
assert safe_rl(3) == 3
|
||||
|
||||
|
||||
def test_exhaust():
|
||||
sink = exhaust(posdec)
|
||||
assert sink(5) == 0
|
||||
assert sink(10) == 0
|
||||
|
||||
|
||||
def test_memoize():
|
||||
rl = memoize(posdec)
|
||||
assert rl(5) == posdec(5)
|
||||
assert rl(5) == posdec(5)
|
||||
assert rl(-2) == posdec(-2)
|
||||
|
||||
|
||||
def test_condition():
|
||||
rl = condition(lambda x: x % 2 == 0, posdec)
|
||||
assert rl(5) == 5
|
||||
assert rl(4) == 3
|
||||
|
||||
|
||||
def test_chain():
|
||||
rl = chain(posdec, posdec)
|
||||
assert rl(5) == 3
|
||||
assert rl(1) == 0
|
||||
|
||||
|
||||
def test_tryit():
|
||||
def rl(expr: Basic) -> Basic:
|
||||
assert False
|
||||
|
||||
safe_rl = tryit(rl, AssertionError)
|
||||
assert safe_rl(S(1)) == S(1)
|
||||
|
||||
|
||||
def test_do_one():
|
||||
rl = do_one(posdec, posdec)
|
||||
assert rl(5) == 4
|
||||
|
||||
def rl1(x: int) -> int:
|
||||
if x == 1:
|
||||
return 2
|
||||
return x
|
||||
|
||||
def rl2(x: int) -> int:
|
||||
if x == 2:
|
||||
return 3
|
||||
return x
|
||||
|
||||
rule = do_one(rl1, rl2)
|
||||
assert rule(1) == 2
|
||||
assert rule(rule(1)) == 3
|
||||
|
||||
|
||||
def test_debug():
|
||||
file = StringIO()
|
||||
rl = debug(posdec, file)
|
||||
rl(5)
|
||||
log = file.getvalue()
|
||||
file.close()
|
||||
|
||||
assert posdec.__name__ in log
|
||||
assert '5' in log
|
||||
assert '4' in log
|
||||
|
||||
|
||||
def test_switch():
|
||||
def key(x: int) -> int:
|
||||
return x % 3
|
||||
|
||||
rl = switch(key, {0: inc, 1: dec})
|
||||
assert rl(3) == 4
|
||||
assert rl(4) == 3
|
||||
assert rl(5) == 5
|
||||
|
||||
|
||||
def test_minimize():
|
||||
def key(x: int) -> int:
|
||||
return -x
|
||||
|
||||
rl = minimize(inc, dec)
|
||||
assert rl(4) == 3
|
||||
|
||||
rl = minimize(inc, dec, objective=key)
|
||||
assert rl(4) == 5
|
||||
@@ -0,0 +1,78 @@
|
||||
from sympy.core.singleton import S
|
||||
from sympy.strategies.rl import (
|
||||
rm_id, glom, flatten, unpack, sort, distribute, subs, rebuild)
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.abc import x
|
||||
|
||||
|
||||
def test_rm_id():
|
||||
rmzeros = rm_id(lambda x: x == 0)
|
||||
assert rmzeros(Basic(S(0), S(1))) == Basic(S(1))
|
||||
assert rmzeros(Basic(S(0), S(0))) == Basic(S(0))
|
||||
assert rmzeros(Basic(S(2), S(1))) == Basic(S(2), S(1))
|
||||
|
||||
|
||||
def test_glom():
|
||||
def key(x):
|
||||
return x.as_coeff_Mul()[1]
|
||||
|
||||
def count(x):
|
||||
return x.as_coeff_Mul()[0]
|
||||
|
||||
def newargs(cnt, arg):
|
||||
return cnt * arg
|
||||
|
||||
rl = glom(key, count, newargs)
|
||||
|
||||
result = rl(Add(x, -x, 3 * x, 2, 3, evaluate=False))
|
||||
expected = Add(3 * x, 5)
|
||||
assert set(result.args) == set(expected.args)
|
||||
|
||||
|
||||
def test_flatten():
|
||||
assert flatten(Basic(S(1), S(2), Basic(S(3), S(4)))) == \
|
||||
Basic(S(1), S(2), S(3), S(4))
|
||||
|
||||
|
||||
def test_unpack():
|
||||
assert unpack(Basic(S(2))) == 2
|
||||
assert unpack(Basic(S(2), S(3))) == Basic(S(2), S(3))
|
||||
|
||||
|
||||
def test_sort():
|
||||
assert sort(str)(Basic(S(3), S(1), S(2))) == Basic(S(1), S(2), S(3))
|
||||
|
||||
|
||||
def test_distribute():
|
||||
class T1(Basic):
|
||||
pass
|
||||
|
||||
class T2(Basic):
|
||||
pass
|
||||
|
||||
distribute_t12 = distribute(T1, T2)
|
||||
assert distribute_t12(T1(S(1), S(2), T2(S(3), S(4)), S(5))) == \
|
||||
T2(T1(S(1), S(2), S(3), S(5)), T1(S(1), S(2), S(4), S(5)))
|
||||
assert distribute_t12(T1(S(1), S(2), S(3))) == T1(S(1), S(2), S(3))
|
||||
|
||||
|
||||
def test_distribute_add_mul():
|
||||
x, y = symbols('x, y')
|
||||
expr = Mul(2, Add(x, y), evaluate=False)
|
||||
expected = Add(Mul(2, x), Mul(2, y))
|
||||
distribute_mul = distribute(Mul, Add)
|
||||
assert distribute_mul(expr) == expected
|
||||
|
||||
|
||||
def test_subs():
|
||||
rl = subs(1, 2)
|
||||
assert rl(1) == 2
|
||||
assert rl(3) == 3
|
||||
|
||||
|
||||
def test_rebuild():
|
||||
expr = Basic.__new__(Add, S(1), S(2))
|
||||
assert rebuild(expr) == 3
|
||||
@@ -0,0 +1,32 @@
|
||||
from sympy.strategies.tools import subs, typed
|
||||
from sympy.strategies.rl import rm_id
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.singleton import S
|
||||
|
||||
|
||||
def test_subs():
|
||||
from sympy.core.symbol import symbols
|
||||
a, b, c, d, e, f = symbols('a,b,c,d,e,f')
|
||||
mapping = {a: d, d: a, Basic(e): Basic(f)}
|
||||
expr = Basic(a, Basic(b, c), Basic(d, Basic(e)))
|
||||
result = Basic(d, Basic(b, c), Basic(a, Basic(f)))
|
||||
assert subs(mapping)(expr) == result
|
||||
|
||||
|
||||
def test_subs_empty():
|
||||
assert subs({})(Basic(S(1), S(2))) == Basic(S(1), S(2))
|
||||
|
||||
|
||||
def test_typed():
|
||||
class A(Basic):
|
||||
pass
|
||||
|
||||
class B(Basic):
|
||||
pass
|
||||
|
||||
rmzeros = rm_id(lambda x: x == S(0))
|
||||
rmones = rm_id(lambda x: x == S(1))
|
||||
remove_something = typed({A: rmzeros, B: rmones})
|
||||
|
||||
assert remove_something(A(S(0), S(1))) == A(S(1))
|
||||
assert remove_something(B(S(0), S(1))) == B(S(0))
|
||||
@@ -0,0 +1,84 @@
|
||||
from sympy.strategies.traverse import (
|
||||
top_down, bottom_up, sall, top_down_once, bottom_up_once, basic_fns)
|
||||
from sympy.strategies.rl import rebuild
|
||||
from sympy.strategies.util import expr_fns
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.numbers import Integer
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import Str, Symbol
|
||||
from sympy.abc import x, y, z
|
||||
|
||||
|
||||
def zero_symbols(expression):
|
||||
return S.Zero if isinstance(expression, Symbol) else expression
|
||||
|
||||
|
||||
def test_sall():
|
||||
zero_onelevel = sall(zero_symbols)
|
||||
|
||||
assert zero_onelevel(Basic(x, y, Basic(x, z))) == \
|
||||
Basic(S(0), S(0), Basic(x, z))
|
||||
|
||||
|
||||
def test_bottom_up():
|
||||
_test_global_traversal(bottom_up)
|
||||
_test_stop_on_non_basics(bottom_up)
|
||||
|
||||
|
||||
def test_top_down():
|
||||
_test_global_traversal(top_down)
|
||||
_test_stop_on_non_basics(top_down)
|
||||
|
||||
|
||||
def _test_global_traversal(trav):
|
||||
zero_all_symbols = trav(zero_symbols)
|
||||
|
||||
assert zero_all_symbols(Basic(x, y, Basic(x, z))) == \
|
||||
Basic(S(0), S(0), Basic(S(0), S(0)))
|
||||
|
||||
|
||||
def _test_stop_on_non_basics(trav):
|
||||
def add_one_if_can(expr):
|
||||
try:
|
||||
return expr + 1
|
||||
except TypeError:
|
||||
return expr
|
||||
|
||||
expr = Basic(S(1), Str('a'), Basic(S(2), Str('b')))
|
||||
expected = Basic(S(2), Str('a'), Basic(S(3), Str('b')))
|
||||
rl = trav(add_one_if_can)
|
||||
|
||||
assert rl(expr) == expected
|
||||
|
||||
|
||||
class Basic2(Basic):
|
||||
pass
|
||||
|
||||
|
||||
def rl(x):
|
||||
if x.args and not isinstance(x.args[0], Integer):
|
||||
return Basic2(*x.args)
|
||||
return x
|
||||
|
||||
|
||||
def test_top_down_once():
|
||||
top_rl = top_down_once(rl)
|
||||
|
||||
assert top_rl(Basic(S(1.0), S(2.0), Basic(S(3), S(4)))) == \
|
||||
Basic2(S(1.0), S(2.0), Basic(S(3), S(4)))
|
||||
|
||||
|
||||
def test_bottom_up_once():
|
||||
bottom_rl = bottom_up_once(rl)
|
||||
|
||||
assert bottom_rl(Basic(S(1), S(2), Basic(S(3.0), S(4.0)))) == \
|
||||
Basic(S(1), S(2), Basic2(S(3.0), S(4.0)))
|
||||
|
||||
|
||||
def test_expr_fns():
|
||||
expr = x + y**3
|
||||
e = bottom_up(lambda v: v + 1, expr_fns)(expr)
|
||||
b = bottom_up(lambda v: Basic.__new__(Add, v, S(1)), basic_fns)(expr)
|
||||
|
||||
assert rebuild(b) == e
|
||||
@@ -0,0 +1,92 @@
|
||||
from sympy.strategies.tree import treeapply, greedy, allresults, brute
|
||||
from functools import partial, reduce
|
||||
|
||||
|
||||
def inc(x):
|
||||
return x + 1
|
||||
|
||||
|
||||
def dec(x):
|
||||
return x - 1
|
||||
|
||||
|
||||
def double(x):
|
||||
return 2 * x
|
||||
|
||||
|
||||
def square(x):
|
||||
return x**2
|
||||
|
||||
|
||||
def add(*args):
|
||||
return sum(args)
|
||||
|
||||
|
||||
def mul(*args):
|
||||
return reduce(lambda a, b: a * b, args, 1)
|
||||
|
||||
|
||||
def test_treeapply():
|
||||
tree = ([3, 3], [4, 1], 2)
|
||||
assert treeapply(tree, {list: min, tuple: max}) == 3
|
||||
assert treeapply(tree, {list: add, tuple: mul}) == 60
|
||||
|
||||
|
||||
def test_treeapply_leaf():
|
||||
assert treeapply(3, {}, leaf=lambda x: x**2) == 9
|
||||
tree = ([3, 3], [4, 1], 2)
|
||||
treep1 = ([4, 4], [5, 2], 3)
|
||||
assert treeapply(tree, {list: min, tuple: max}, leaf=lambda x: x + 1) == \
|
||||
treeapply(treep1, {list: min, tuple: max})
|
||||
|
||||
|
||||
def test_treeapply_strategies():
|
||||
from sympy.strategies import chain, minimize
|
||||
join = {list: chain, tuple: minimize}
|
||||
|
||||
assert treeapply(inc, join) == inc
|
||||
assert treeapply((inc, dec), join)(5) == minimize(inc, dec)(5)
|
||||
assert treeapply([inc, dec], join)(5) == chain(inc, dec)(5)
|
||||
tree = (inc, [dec, double]) # either inc or dec-then-double
|
||||
assert treeapply(tree, join)(5) == 6
|
||||
assert treeapply(tree, join)(1) == 0
|
||||
|
||||
maximize = partial(minimize, objective=lambda x: -x)
|
||||
join = {list: chain, tuple: maximize}
|
||||
fn = treeapply(tree, join)
|
||||
assert fn(4) == 6 # highest value comes from the dec then double
|
||||
assert fn(1) == 2 # highest value comes from the inc
|
||||
|
||||
|
||||
def test_greedy():
|
||||
tree = [inc, (dec, double)] # either inc or dec-then-double
|
||||
|
||||
fn = greedy(tree, objective=lambda x: -x)
|
||||
assert fn(4) == 6 # highest value comes from the dec then double
|
||||
assert fn(1) == 2 # highest value comes from the inc
|
||||
|
||||
tree = [inc, dec, [inc, dec, [(inc, inc), (dec, dec)]]]
|
||||
lowest = greedy(tree)
|
||||
assert lowest(10) == 8
|
||||
|
||||
highest = greedy(tree, objective=lambda x: -x)
|
||||
assert highest(10) == 12
|
||||
|
||||
|
||||
def test_allresults():
|
||||
# square = lambda x: x**2
|
||||
|
||||
assert set(allresults(inc)(3)) == {inc(3)}
|
||||
assert set(allresults([inc, dec])(3)) == {2, 4}
|
||||
assert set(allresults((inc, dec))(3)) == {3}
|
||||
assert set(allresults([inc, (dec, double)])(4)) == {5, 6}
|
||||
|
||||
|
||||
def test_brute():
|
||||
tree = ([inc, dec], square)
|
||||
fn = brute(tree, lambda x: -x)
|
||||
|
||||
assert fn(2) == (2 + 1)**2
|
||||
assert fn(-2) == (-2 - 1)**2
|
||||
|
||||
assert brute(inc)(1) == 2
|
||||
@@ -0,0 +1,53 @@
|
||||
from . import rl
|
||||
from .core import do_one, exhaust, switch
|
||||
from .traverse import top_down
|
||||
|
||||
|
||||
def subs(d, **kwargs):
|
||||
""" Full simultaneous exact substitution.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.strategies.tools import subs
|
||||
>>> from sympy import Basic, S
|
||||
>>> mapping = {S(1): S(4), S(4): S(1), Basic(S(5)): Basic(S(6), S(7))}
|
||||
>>> expr = Basic(S(1), Basic(S(2), S(3)), Basic(S(4), Basic(S(5))))
|
||||
>>> subs(mapping)(expr)
|
||||
Basic(4, Basic(2, 3), Basic(1, Basic(6, 7)))
|
||||
"""
|
||||
if d:
|
||||
return top_down(do_one(*map(rl.subs, *zip(*d.items()))), **kwargs)
|
||||
else:
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def canon(*rules, **kwargs):
|
||||
""" Strategy for canonicalization.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Apply each rule in a bottom_up fashion through the tree.
|
||||
Do each one in turn.
|
||||
Keep doing this until there is no change.
|
||||
"""
|
||||
return exhaust(top_down(exhaust(do_one(*rules)), **kwargs))
|
||||
|
||||
|
||||
def typed(ruletypes):
|
||||
""" Apply rules based on the expression type
|
||||
|
||||
inputs:
|
||||
ruletypes -- a dict mapping {Type: rule}
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.strategies import rm_id, typed
|
||||
>>> from sympy import Add, Mul
|
||||
>>> rm_zeros = rm_id(lambda x: x==0)
|
||||
>>> rm_ones = rm_id(lambda x: x==1)
|
||||
>>> remove_idents = typed({Add: rm_zeros, Mul: rm_ones})
|
||||
"""
|
||||
return switch(type, ruletypes)
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Strategies to Traverse a Tree."""
|
||||
from sympy.strategies.util import basic_fns
|
||||
from sympy.strategies.core import chain, do_one
|
||||
|
||||
|
||||
def top_down(rule, fns=basic_fns):
|
||||
"""Apply a rule down a tree running it on the top nodes first."""
|
||||
return chain(rule, lambda expr: sall(top_down(rule, fns), fns)(expr))
|
||||
|
||||
|
||||
def bottom_up(rule, fns=basic_fns):
|
||||
"""Apply a rule down a tree running it on the bottom nodes first."""
|
||||
return chain(lambda expr: sall(bottom_up(rule, fns), fns)(expr), rule)
|
||||
|
||||
|
||||
def top_down_once(rule, fns=basic_fns):
|
||||
"""Apply a rule down a tree - stop on success."""
|
||||
return do_one(rule, lambda expr: sall(top_down(rule, fns), fns)(expr))
|
||||
|
||||
|
||||
def bottom_up_once(rule, fns=basic_fns):
|
||||
"""Apply a rule up a tree - stop on success."""
|
||||
return do_one(lambda expr: sall(bottom_up(rule, fns), fns)(expr), rule)
|
||||
|
||||
|
||||
def sall(rule, fns=basic_fns):
|
||||
"""Strategic all - apply rule to args."""
|
||||
op, new, children, leaf = map(fns.get, ('op', 'new', 'children', 'leaf'))
|
||||
|
||||
def all_rl(expr):
|
||||
if leaf(expr):
|
||||
return expr
|
||||
else:
|
||||
args = map(rule, children(expr))
|
||||
return new(op(expr), *args)
|
||||
|
||||
return all_rl
|
||||
@@ -0,0 +1,139 @@
|
||||
from functools import partial
|
||||
from sympy.strategies import chain, minimize
|
||||
from sympy.strategies.core import identity
|
||||
import sympy.strategies.branch as branch
|
||||
from sympy.strategies.branch import yieldify
|
||||
|
||||
|
||||
def treeapply(tree, join, leaf=identity):
|
||||
""" Apply functions onto recursive containers (tree).
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
join - a dictionary mapping container types to functions
|
||||
e.g. ``{list: minimize, tuple: chain}``
|
||||
|
||||
Keys are containers/iterables. Values are functions [a] -> a.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.strategies.tree import treeapply
|
||||
>>> tree = [(3, 2), (4, 1)]
|
||||
>>> treeapply(tree, {list: max, tuple: min})
|
||||
2
|
||||
|
||||
>>> add = lambda *args: sum(args)
|
||||
>>> def mul(*args):
|
||||
... total = 1
|
||||
... for arg in args:
|
||||
... total *= arg
|
||||
... return total
|
||||
>>> treeapply(tree, {list: mul, tuple: add})
|
||||
25
|
||||
"""
|
||||
for typ in join:
|
||||
if isinstance(tree, typ):
|
||||
return join[typ](*map(partial(treeapply, join=join, leaf=leaf),
|
||||
tree))
|
||||
return leaf(tree)
|
||||
|
||||
|
||||
def greedy(tree, objective=identity, **kwargs):
|
||||
""" Execute a strategic tree. Select alternatives greedily
|
||||
|
||||
Trees
|
||||
-----
|
||||
|
||||
Nodes in a tree can be either
|
||||
|
||||
function - a leaf
|
||||
list - a selection among operations
|
||||
tuple - a sequence of chained operations
|
||||
|
||||
Textual examples
|
||||
----------------
|
||||
|
||||
Text: Run f, then run g, e.g. ``lambda x: g(f(x))``
|
||||
Code: ``(f, g)``
|
||||
|
||||
Text: Run either f or g, whichever minimizes the objective
|
||||
Code: ``[f, g]``
|
||||
|
||||
Textx: Run either f or g, whichever is better, then run h
|
||||
Code: ``([f, g], h)``
|
||||
|
||||
Text: Either expand then simplify or try factor then foosimp. Finally print
|
||||
Code: ``([(expand, simplify), (factor, foosimp)], print)``
|
||||
|
||||
Objective
|
||||
---------
|
||||
|
||||
"Better" is determined by the objective keyword. This function makes
|
||||
choices to minimize the objective. It defaults to the identity.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.strategies.tree import greedy
|
||||
>>> inc = lambda x: x + 1
|
||||
>>> dec = lambda x: x - 1
|
||||
>>> double = lambda x: 2*x
|
||||
|
||||
>>> tree = [inc, (dec, double)] # either inc or dec-then-double
|
||||
>>> fn = greedy(tree)
|
||||
>>> fn(4) # lowest value comes from the inc
|
||||
5
|
||||
>>> fn(1) # lowest value comes from dec then double
|
||||
0
|
||||
|
||||
This function selects between options in a tuple. The result is chosen
|
||||
that minimizes the objective function.
|
||||
|
||||
>>> fn = greedy(tree, objective=lambda x: -x) # maximize
|
||||
>>> fn(4) # highest value comes from the dec then double
|
||||
6
|
||||
>>> fn(1) # highest value comes from the inc
|
||||
2
|
||||
|
||||
Greediness
|
||||
----------
|
||||
|
||||
This is a greedy algorithm. In the example:
|
||||
|
||||
([a, b], c) # do either a or b, then do c
|
||||
|
||||
the choice between running ``a`` or ``b`` is made without foresight to c
|
||||
"""
|
||||
optimize = partial(minimize, objective=objective)
|
||||
return treeapply(tree, {list: optimize, tuple: chain}, **kwargs)
|
||||
|
||||
|
||||
def allresults(tree, leaf=yieldify):
|
||||
""" Execute a strategic tree. Return all possibilities.
|
||||
|
||||
Returns a lazy iterator of all possible results
|
||||
|
||||
Exhaustiveness
|
||||
--------------
|
||||
|
||||
This is an exhaustive algorithm. In the example
|
||||
|
||||
([a, b], [c, d])
|
||||
|
||||
All of the results from
|
||||
|
||||
(a, c), (b, c), (a, d), (b, d)
|
||||
|
||||
are returned. This can lead to combinatorial blowup.
|
||||
|
||||
See sympy.strategies.greedy for details on input
|
||||
"""
|
||||
return treeapply(tree, {list: branch.multiplex, tuple: branch.chain},
|
||||
leaf=leaf)
|
||||
|
||||
|
||||
def brute(tree, objective=identity, **kwargs):
|
||||
return lambda expr: min(tuple(allresults(tree, **kwargs)(expr)),
|
||||
key=objective)
|
||||
@@ -0,0 +1,17 @@
|
||||
from sympy.core.basic import Basic
|
||||
|
||||
new = Basic.__new__
|
||||
|
||||
|
||||
def assoc(d, k, v):
|
||||
d = d.copy()
|
||||
d[k] = v
|
||||
return d
|
||||
|
||||
|
||||
basic_fns = {'op': type,
|
||||
'new': Basic.__new__,
|
||||
'leaf': lambda x: not isinstance(x, Basic) or x.is_Atom,
|
||||
'children': lambda x: x.args}
|
||||
|
||||
expr_fns = assoc(basic_fns, 'new', lambda op, *args: op(*args))
|
||||
Reference in New Issue
Block a user