chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,490 @@
|
||||
'''
|
||||
Use llvmlite to create executable functions from SymPy expressions
|
||||
|
||||
This module requires llvmlite (https://github.com/numba/llvmlite).
|
||||
'''
|
||||
|
||||
import ctypes
|
||||
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.printer import Printer
|
||||
from sympy.core.singleton import S
|
||||
from sympy.tensor.indexed import IndexedBase
|
||||
from sympy.utilities.decorator import doctest_depends_on
|
||||
|
||||
llvmlite = import_module('llvmlite')
|
||||
if llvmlite:
|
||||
ll = import_module('llvmlite.ir').ir
|
||||
llvm = import_module('llvmlite.binding').binding
|
||||
llvm.initialize()
|
||||
llvm.initialize_native_target()
|
||||
llvm.initialize_native_asmprinter()
|
||||
|
||||
|
||||
__doctest_requires__ = {('llvm_callable'): ['llvmlite']}
|
||||
|
||||
|
||||
class LLVMJitPrinter(Printer):
|
||||
'''Convert expressions to LLVM IR'''
|
||||
def __init__(self, module, builder, fn, *args, **kwargs):
|
||||
self.func_arg_map = kwargs.pop("func_arg_map", {})
|
||||
if not llvmlite:
|
||||
raise ImportError("llvmlite is required for LLVMJITPrinter")
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fp_type = ll.DoubleType()
|
||||
self.module = module
|
||||
self.builder = builder
|
||||
self.fn = fn
|
||||
self.ext_fn = {} # keep track of wrappers to external functions
|
||||
self.tmp_var = {}
|
||||
|
||||
def _add_tmp_var(self, name, value):
|
||||
self.tmp_var[name] = value
|
||||
|
||||
def _print_Number(self, n):
|
||||
return ll.Constant(self.fp_type, float(n))
|
||||
|
||||
def _print_Integer(self, expr):
|
||||
return ll.Constant(self.fp_type, float(expr.p))
|
||||
|
||||
def _print_Symbol(self, s):
|
||||
val = self.tmp_var.get(s)
|
||||
if not val:
|
||||
# look up parameter with name s
|
||||
val = self.func_arg_map.get(s)
|
||||
if not val:
|
||||
raise LookupError("Symbol not found: %s" % s)
|
||||
return val
|
||||
|
||||
def _print_Pow(self, expr):
|
||||
base0 = self._print(expr.base)
|
||||
if expr.exp == S.NegativeOne:
|
||||
return self.builder.fdiv(ll.Constant(self.fp_type, 1.0), base0)
|
||||
if expr.exp == S.Half:
|
||||
fn = self.ext_fn.get("sqrt")
|
||||
if not fn:
|
||||
fn_type = ll.FunctionType(self.fp_type, [self.fp_type])
|
||||
fn = ll.Function(self.module, fn_type, "sqrt")
|
||||
self.ext_fn["sqrt"] = fn
|
||||
return self.builder.call(fn, [base0], "sqrt")
|
||||
if expr.exp == 2:
|
||||
return self.builder.fmul(base0, base0)
|
||||
|
||||
exp0 = self._print(expr.exp)
|
||||
fn = self.ext_fn.get("pow")
|
||||
if not fn:
|
||||
fn_type = ll.FunctionType(self.fp_type, [self.fp_type, self.fp_type])
|
||||
fn = ll.Function(self.module, fn_type, "pow")
|
||||
self.ext_fn["pow"] = fn
|
||||
return self.builder.call(fn, [base0, exp0], "pow")
|
||||
|
||||
def _print_Mul(self, expr):
|
||||
nodes = [self._print(a) for a in expr.args]
|
||||
e = nodes[0]
|
||||
for node in nodes[1:]:
|
||||
e = self.builder.fmul(e, node)
|
||||
return e
|
||||
|
||||
def _print_Add(self, expr):
|
||||
nodes = [self._print(a) for a in expr.args]
|
||||
e = nodes[0]
|
||||
for node in nodes[1:]:
|
||||
e = self.builder.fadd(e, node)
|
||||
return e
|
||||
|
||||
# TODO - assumes all called functions take one double precision argument.
|
||||
# Should have a list of math library functions to validate this.
|
||||
def _print_Function(self, expr):
|
||||
name = expr.func.__name__
|
||||
e0 = self._print(expr.args[0])
|
||||
fn = self.ext_fn.get(name)
|
||||
if not fn:
|
||||
fn_type = ll.FunctionType(self.fp_type, [self.fp_type])
|
||||
fn = ll.Function(self.module, fn_type, name)
|
||||
self.ext_fn[name] = fn
|
||||
return self.builder.call(fn, [e0], name)
|
||||
|
||||
def emptyPrinter(self, expr):
|
||||
raise TypeError("Unsupported type for LLVM JIT conversion: %s"
|
||||
% type(expr))
|
||||
|
||||
|
||||
# Used when parameters are passed by array. Often used in callbacks to
|
||||
# handle a variable number of parameters.
|
||||
class LLVMJitCallbackPrinter(LLVMJitPrinter):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _print_Indexed(self, expr):
|
||||
array, idx = self.func_arg_map[expr.base]
|
||||
offset = int(expr.indices[0].evalf())
|
||||
array_ptr = self.builder.gep(array, [ll.Constant(ll.IntType(32), offset)])
|
||||
fp_array_ptr = self.builder.bitcast(array_ptr, ll.PointerType(self.fp_type))
|
||||
value = self.builder.load(fp_array_ptr)
|
||||
return value
|
||||
|
||||
def _print_Symbol(self, s):
|
||||
val = self.tmp_var.get(s)
|
||||
if val:
|
||||
return val
|
||||
|
||||
array, idx = self.func_arg_map.get(s, [None, 0])
|
||||
if not array:
|
||||
raise LookupError("Symbol not found: %s" % s)
|
||||
array_ptr = self.builder.gep(array, [ll.Constant(ll.IntType(32), idx)])
|
||||
fp_array_ptr = self.builder.bitcast(array_ptr,
|
||||
ll.PointerType(self.fp_type))
|
||||
value = self.builder.load(fp_array_ptr)
|
||||
return value
|
||||
|
||||
|
||||
# ensure lifetime of the execution engine persists (else call to compiled
|
||||
# function will seg fault)
|
||||
exe_engines = []
|
||||
|
||||
# ensure names for generated functions are unique
|
||||
link_names = set()
|
||||
current_link_suffix = 0
|
||||
|
||||
|
||||
class LLVMJitCode:
|
||||
def __init__(self, signature):
|
||||
self.signature = signature
|
||||
self.fp_type = ll.DoubleType()
|
||||
self.module = ll.Module('mod1')
|
||||
self.fn = None
|
||||
self.llvm_arg_types = []
|
||||
self.llvm_ret_type = self.fp_type
|
||||
self.param_dict = {} # map symbol name to LLVM function argument
|
||||
self.link_name = ''
|
||||
|
||||
def _from_ctype(self, ctype):
|
||||
if ctype == ctypes.c_int:
|
||||
return ll.IntType(32)
|
||||
if ctype == ctypes.c_double:
|
||||
return self.fp_type
|
||||
if ctype == ctypes.POINTER(ctypes.c_double):
|
||||
return ll.PointerType(self.fp_type)
|
||||
if ctype == ctypes.c_void_p:
|
||||
return ll.PointerType(ll.IntType(32))
|
||||
if ctype == ctypes.py_object:
|
||||
return ll.PointerType(ll.IntType(32))
|
||||
|
||||
print("Unhandled ctype = %s" % str(ctype))
|
||||
|
||||
def _create_args(self, func_args):
|
||||
"""Create types for function arguments"""
|
||||
self.llvm_ret_type = self._from_ctype(self.signature.ret_type)
|
||||
self.llvm_arg_types = \
|
||||
[self._from_ctype(a) for a in self.signature.arg_ctypes]
|
||||
|
||||
def _create_function_base(self):
|
||||
"""Create function with name and type signature"""
|
||||
global current_link_suffix
|
||||
default_link_name = 'jit_func'
|
||||
current_link_suffix += 1
|
||||
self.link_name = default_link_name + str(current_link_suffix)
|
||||
link_names.add(self.link_name)
|
||||
|
||||
fn_type = ll.FunctionType(self.llvm_ret_type, self.llvm_arg_types)
|
||||
self.fn = ll.Function(self.module, fn_type, name=self.link_name)
|
||||
|
||||
def _create_param_dict(self, func_args):
|
||||
"""Mapping of symbolic values to function arguments"""
|
||||
for i, a in enumerate(func_args):
|
||||
self.fn.args[i].name = str(a)
|
||||
self.param_dict[a] = self.fn.args[i]
|
||||
|
||||
def _create_function(self, expr):
|
||||
"""Create function body and return LLVM IR"""
|
||||
bb_entry = self.fn.append_basic_block('entry')
|
||||
builder = ll.IRBuilder(bb_entry)
|
||||
|
||||
lj = LLVMJitPrinter(self.module, builder, self.fn,
|
||||
func_arg_map=self.param_dict)
|
||||
|
||||
ret = self._convert_expr(lj, expr)
|
||||
lj.builder.ret(self._wrap_return(lj, ret))
|
||||
|
||||
strmod = str(self.module)
|
||||
return strmod
|
||||
|
||||
def _wrap_return(self, lj, vals):
|
||||
# Return a single double if there is one return value,
|
||||
# else return a tuple of doubles.
|
||||
|
||||
# Don't wrap return value in this case
|
||||
if self.signature.ret_type == ctypes.c_double:
|
||||
return vals[0]
|
||||
|
||||
# Use this instead of a real PyObject*
|
||||
void_ptr = ll.PointerType(ll.IntType(32))
|
||||
|
||||
# Create a wrapped double: PyObject* PyFloat_FromDouble(double v)
|
||||
wrap_type = ll.FunctionType(void_ptr, [self.fp_type])
|
||||
wrap_fn = ll.Function(lj.module, wrap_type, "PyFloat_FromDouble")
|
||||
|
||||
wrapped_vals = [lj.builder.call(wrap_fn, [v]) for v in vals]
|
||||
if len(vals) == 1:
|
||||
final_val = wrapped_vals[0]
|
||||
else:
|
||||
# Create a tuple: PyObject* PyTuple_Pack(Py_ssize_t n, ...)
|
||||
|
||||
# This should be Py_ssize_t
|
||||
tuple_arg_types = [ll.IntType(32)]
|
||||
|
||||
tuple_arg_types.extend([void_ptr]*len(vals))
|
||||
tuple_type = ll.FunctionType(void_ptr, tuple_arg_types)
|
||||
tuple_fn = ll.Function(lj.module, tuple_type, "PyTuple_Pack")
|
||||
|
||||
tuple_args = [ll.Constant(ll.IntType(32), len(wrapped_vals))]
|
||||
tuple_args.extend(wrapped_vals)
|
||||
|
||||
final_val = lj.builder.call(tuple_fn, tuple_args)
|
||||
|
||||
return final_val
|
||||
|
||||
def _convert_expr(self, lj, expr):
|
||||
try:
|
||||
# Match CSE return data structure.
|
||||
if len(expr) == 2:
|
||||
tmp_exprs = expr[0]
|
||||
final_exprs = expr[1]
|
||||
if len(final_exprs) != 1 and self.signature.ret_type == ctypes.c_double:
|
||||
raise NotImplementedError("Return of multiple expressions not supported for this callback")
|
||||
for name, e in tmp_exprs:
|
||||
val = lj._print(e)
|
||||
lj._add_tmp_var(name, val)
|
||||
except TypeError:
|
||||
final_exprs = [expr]
|
||||
|
||||
vals = [lj._print(e) for e in final_exprs]
|
||||
|
||||
return vals
|
||||
|
||||
def _compile_function(self, strmod):
|
||||
llmod = llvm.parse_assembly(strmod)
|
||||
|
||||
pmb = llvm.create_pass_manager_builder()
|
||||
pmb.opt_level = 2
|
||||
pass_manager = llvm.create_module_pass_manager()
|
||||
pmb.populate(pass_manager)
|
||||
|
||||
pass_manager.run(llmod)
|
||||
|
||||
target_machine = \
|
||||
llvm.Target.from_default_triple().create_target_machine()
|
||||
exe_eng = llvm.create_mcjit_compiler(llmod, target_machine)
|
||||
exe_eng.finalize_object()
|
||||
exe_engines.append(exe_eng)
|
||||
|
||||
if False:
|
||||
print("Assembly")
|
||||
print(target_machine.emit_assembly(llmod))
|
||||
|
||||
fptr = exe_eng.get_function_address(self.link_name)
|
||||
|
||||
return fptr
|
||||
|
||||
|
||||
class LLVMJitCodeCallback(LLVMJitCode):
|
||||
def __init__(self, signature):
|
||||
super().__init__(signature)
|
||||
|
||||
def _create_param_dict(self, func_args):
|
||||
for i, a in enumerate(func_args):
|
||||
if isinstance(a, IndexedBase):
|
||||
self.param_dict[a] = (self.fn.args[i], i)
|
||||
self.fn.args[i].name = str(a)
|
||||
else:
|
||||
self.param_dict[a] = (self.fn.args[self.signature.input_arg],
|
||||
i)
|
||||
|
||||
def _create_function(self, expr):
|
||||
"""Create function body and return LLVM IR"""
|
||||
bb_entry = self.fn.append_basic_block('entry')
|
||||
builder = ll.IRBuilder(bb_entry)
|
||||
|
||||
lj = LLVMJitCallbackPrinter(self.module, builder, self.fn,
|
||||
func_arg_map=self.param_dict)
|
||||
|
||||
ret = self._convert_expr(lj, expr)
|
||||
|
||||
if self.signature.ret_arg:
|
||||
output_fp_ptr = builder.bitcast(self.fn.args[self.signature.ret_arg],
|
||||
ll.PointerType(self.fp_type))
|
||||
for i, val in enumerate(ret):
|
||||
index = ll.Constant(ll.IntType(32), i)
|
||||
output_array_ptr = builder.gep(output_fp_ptr, [index])
|
||||
builder.store(val, output_array_ptr)
|
||||
builder.ret(ll.Constant(ll.IntType(32), 0)) # return success
|
||||
else:
|
||||
lj.builder.ret(self._wrap_return(lj, ret))
|
||||
|
||||
strmod = str(self.module)
|
||||
return strmod
|
||||
|
||||
|
||||
class CodeSignature:
|
||||
def __init__(self, ret_type):
|
||||
self.ret_type = ret_type
|
||||
self.arg_ctypes = []
|
||||
|
||||
# Input argument array element index
|
||||
self.input_arg = 0
|
||||
|
||||
# For the case output value is referenced through a parameter rather
|
||||
# than the return value
|
||||
self.ret_arg = None
|
||||
|
||||
|
||||
def _llvm_jit_code(args, expr, signature, callback_type):
|
||||
"""Create a native code function from a SymPy expression"""
|
||||
if callback_type is None:
|
||||
jit = LLVMJitCode(signature)
|
||||
else:
|
||||
jit = LLVMJitCodeCallback(signature)
|
||||
|
||||
jit._create_args(args)
|
||||
jit._create_function_base()
|
||||
jit._create_param_dict(args)
|
||||
strmod = jit._create_function(expr)
|
||||
if False:
|
||||
print("LLVM IR")
|
||||
print(strmod)
|
||||
fptr = jit._compile_function(strmod)
|
||||
return fptr
|
||||
|
||||
|
||||
@doctest_depends_on(modules=('llvmlite', 'scipy'))
|
||||
def llvm_callable(args, expr, callback_type=None):
|
||||
'''Compile function from a SymPy expression
|
||||
|
||||
Expressions are evaluated using double precision arithmetic.
|
||||
Some single argument math functions (exp, sin, cos, etc.) are supported
|
||||
in expressions.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
args : List of Symbol
|
||||
Arguments to the generated function. Usually the free symbols in
|
||||
the expression. Currently each one is assumed to convert to
|
||||
a double precision scalar.
|
||||
expr : Expr, or (Replacements, Expr) as returned from 'cse'
|
||||
Expression to compile.
|
||||
callback_type : string
|
||||
Create function with signature appropriate to use as a callback.
|
||||
Currently supported:
|
||||
'scipy.integrate'
|
||||
'scipy.integrate.test'
|
||||
'cubature'
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
Compiled function that can evaluate the expression.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> import sympy.printing.llvmjitcode as jit
|
||||
>>> from sympy.abc import a
|
||||
>>> e = a*a + a + 1
|
||||
>>> e1 = jit.llvm_callable([a], e)
|
||||
>>> e.subs(a, 1.1) # Evaluate via substitution
|
||||
3.31000000000000
|
||||
>>> e1(1.1) # Evaluate using JIT-compiled code
|
||||
3.3100000000000005
|
||||
|
||||
|
||||
Callbacks for integration functions can be JIT compiled.
|
||||
|
||||
>>> import sympy.printing.llvmjitcode as jit
|
||||
>>> from sympy.abc import a
|
||||
>>> from sympy import integrate
|
||||
>>> from scipy.integrate import quad
|
||||
>>> e = a*a
|
||||
>>> e1 = jit.llvm_callable([a], e, callback_type='scipy.integrate')
|
||||
>>> integrate(e, (a, 0.0, 2.0))
|
||||
2.66666666666667
|
||||
>>> quad(e1, 0.0, 2.0)[0]
|
||||
2.66666666666667
|
||||
|
||||
The 'cubature' callback is for the Python wrapper around the
|
||||
cubature package ( https://github.com/saullocastro/cubature )
|
||||
and ( http://ab-initio.mit.edu/wiki/index.php/Cubature )
|
||||
|
||||
There are two signatures for the SciPy integration callbacks.
|
||||
The first ('scipy.integrate') is the function to be passed to the
|
||||
integration routine, and will pass the signature checks.
|
||||
The second ('scipy.integrate.test') is only useful for directly calling
|
||||
the function using ctypes variables. It will not pass the signature checks
|
||||
for scipy.integrate.
|
||||
|
||||
The return value from the cse module can also be compiled. This
|
||||
can improve the performance of the compiled function. If multiple
|
||||
expressions are given to cse, the compiled function returns a tuple.
|
||||
The 'cubature' callback handles multiple expressions (set `fdim`
|
||||
to match in the integration call.)
|
||||
|
||||
>>> import sympy.printing.llvmjitcode as jit
|
||||
>>> from sympy import cse
|
||||
>>> from sympy.abc import x,y
|
||||
>>> e1 = x*x + y*y
|
||||
>>> e2 = 4*(x*x + y*y) + 8.0
|
||||
>>> after_cse = cse([e1,e2])
|
||||
>>> after_cse
|
||||
([(x0, x**2), (x1, y**2)], [x0 + x1, 4*x0 + 4*x1 + 8.0])
|
||||
>>> j1 = jit.llvm_callable([x,y], after_cse)
|
||||
>>> j1(1.0, 2.0)
|
||||
(5.0, 28.0)
|
||||
'''
|
||||
|
||||
if not llvmlite:
|
||||
raise ImportError("llvmlite is required for llvmjitcode")
|
||||
|
||||
signature = CodeSignature(ctypes.py_object)
|
||||
|
||||
arg_ctypes = []
|
||||
if callback_type is None:
|
||||
for _ in args:
|
||||
arg_ctype = ctypes.c_double
|
||||
arg_ctypes.append(arg_ctype)
|
||||
elif callback_type in ('scipy.integrate', 'scipy.integrate.test'):
|
||||
signature.ret_type = ctypes.c_double
|
||||
arg_ctypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_double)]
|
||||
arg_ctypes_formal = [ctypes.c_int, ctypes.c_double]
|
||||
signature.input_arg = 1
|
||||
elif callback_type == 'cubature':
|
||||
arg_ctypes = [ctypes.c_int,
|
||||
ctypes.POINTER(ctypes.c_double),
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(ctypes.c_double)
|
||||
]
|
||||
signature.ret_type = ctypes.c_int
|
||||
signature.input_arg = 1
|
||||
signature.ret_arg = 4
|
||||
else:
|
||||
raise ValueError("Unknown callback type: %s" % callback_type)
|
||||
|
||||
signature.arg_ctypes = arg_ctypes
|
||||
|
||||
fptr = _llvm_jit_code(args, expr, signature, callback_type)
|
||||
|
||||
if callback_type and callback_type == 'scipy.integrate':
|
||||
arg_ctypes = arg_ctypes_formal
|
||||
|
||||
# PYFUNCTYPE holds the GIL which is needed to prevent a segfault when
|
||||
# calling PyFloat_FromDouble on Python 3.10. Probably it is better to use
|
||||
# ctypes.c_double when returning a float rather than using ctypes.py_object
|
||||
# and returning a PyFloat from inside the jitted function (i.e. let ctypes
|
||||
# handle the conversion from double to PyFloat).
|
||||
if signature.ret_type == ctypes.py_object:
|
||||
FUNCTYPE = ctypes.PYFUNCTYPE
|
||||
else:
|
||||
FUNCTYPE = ctypes.CFUNCTYPE
|
||||
|
||||
cfunc = FUNCTYPE(signature.ret_type, *arg_ctypes)(fptr)
|
||||
return cfunc
|
||||
Reference in New Issue
Block a user