chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
from .plot import plot_backends
|
||||
from .plot_implicit import plot_implicit
|
||||
from .textplot import textplot
|
||||
from .pygletplot import PygletPlot
|
||||
from .plot import PlotGrid
|
||||
from .plot import (plot, plot_parametric, plot3d, plot3d_parametric_surface,
|
||||
plot3d_parametric_line, plot_contour)
|
||||
|
||||
__all__ = [
|
||||
'plot_backends',
|
||||
|
||||
'plot_implicit',
|
||||
|
||||
'textplot',
|
||||
|
||||
'PygletPlot',
|
||||
|
||||
'PlotGrid',
|
||||
|
||||
'plot', 'plot_parametric', 'plot3d', 'plot3d_parametric_surface',
|
||||
'plot3d_parametric_line', 'plot_contour'
|
||||
]
|
||||
@@ -0,0 +1,419 @@
|
||||
from sympy.plotting.series import BaseSeries, GenericDataSeries
|
||||
from sympy.utilities.exceptions import sympy_deprecation_warning
|
||||
from sympy.utilities.iterables import is_sequence
|
||||
|
||||
|
||||
__doctest_requires__ = {
|
||||
('Plot.append', 'Plot.extend'): ['matplotlib'],
|
||||
}
|
||||
|
||||
|
||||
# Global variable
|
||||
# Set to False when running tests / doctests so that the plots don't show.
|
||||
_show = True
|
||||
|
||||
def unset_show():
|
||||
"""
|
||||
Disable show(). For use in the tests.
|
||||
"""
|
||||
global _show
|
||||
_show = False
|
||||
|
||||
|
||||
def _deprecation_msg_m_a_r_f(attr):
|
||||
sympy_deprecation_warning(
|
||||
f"The `{attr}` property is deprecated. The `{attr}` keyword "
|
||||
"argument should be passed to a plotting function, which generates "
|
||||
"the appropriate data series. If needed, index the plot object to "
|
||||
"retrieve a specific data series.",
|
||||
deprecated_since_version="1.13",
|
||||
active_deprecations_target="deprecated-markers-annotations-fill-rectangles",
|
||||
stacklevel=4)
|
||||
|
||||
|
||||
def _create_generic_data_series(**kwargs):
|
||||
keywords = ["annotations", "markers", "fill", "rectangles"]
|
||||
series = []
|
||||
for kw in keywords:
|
||||
dictionaries = kwargs.pop(kw, [])
|
||||
if dictionaries is None:
|
||||
dictionaries = []
|
||||
if isinstance(dictionaries, dict):
|
||||
dictionaries = [dictionaries]
|
||||
for d in dictionaries:
|
||||
args = d.pop("args", [])
|
||||
series.append(GenericDataSeries(kw, *args, **d))
|
||||
return series
|
||||
|
||||
|
||||
class Plot:
|
||||
"""Base class for all backends. A backend represents the plotting library,
|
||||
which implements the necessary functionalities in order to use SymPy
|
||||
plotting functions.
|
||||
|
||||
For interactive work the function :func:`plot` is better suited.
|
||||
|
||||
This class permits the plotting of SymPy expressions using numerous
|
||||
backends (:external:mod:`matplotlib`, textplot, the old pyglet module for SymPy, Google
|
||||
charts api, etc).
|
||||
|
||||
The figure can contain an arbitrary number of plots of SymPy expressions,
|
||||
lists of coordinates of points, etc. Plot has a private attribute _series that
|
||||
contains all data series to be plotted (expressions for lines or surfaces,
|
||||
lists of points, etc (all subclasses of BaseSeries)). Those data series are
|
||||
instances of classes not imported by ``from sympy import *``.
|
||||
|
||||
The customization of the figure is on two levels. Global options that
|
||||
concern the figure as a whole (e.g. title, xlabel, scale, etc) and
|
||||
per-data series options (e.g. name) and aesthetics (e.g. color, point shape,
|
||||
line type, etc.).
|
||||
|
||||
The difference between options and aesthetics is that an aesthetic can be
|
||||
a function of the coordinates (or parameters in a parametric plot). The
|
||||
supported values for an aesthetic are:
|
||||
|
||||
- None (the backend uses default values)
|
||||
- a constant
|
||||
- a function of one variable (the first coordinate or parameter)
|
||||
- a function of two variables (the first and second coordinate or parameters)
|
||||
- a function of three variables (only in nonparametric 3D plots)
|
||||
|
||||
Their implementation depends on the backend so they may not work in some
|
||||
backends.
|
||||
|
||||
If the plot is parametric and the arity of the aesthetic function permits
|
||||
it the aesthetic is calculated over parameters and not over coordinates.
|
||||
If the arity does not permit calculation over parameters the calculation is
|
||||
done over coordinates.
|
||||
|
||||
Only cartesian coordinates are supported for the moment, but you can use
|
||||
the parametric plots to plot in polar, spherical and cylindrical
|
||||
coordinates.
|
||||
|
||||
The arguments for the constructor Plot must be subclasses of BaseSeries.
|
||||
|
||||
Any global option can be specified as a keyword argument.
|
||||
|
||||
The global options for a figure are:
|
||||
|
||||
- title : str
|
||||
- xlabel : str or Symbol
|
||||
- ylabel : str or Symbol
|
||||
- zlabel : str or Symbol
|
||||
- legend : bool
|
||||
- xscale : {'linear', 'log'}
|
||||
- yscale : {'linear', 'log'}
|
||||
- axis : bool
|
||||
- axis_center : tuple of two floats or {'center', 'auto'}
|
||||
- xlim : tuple of two floats
|
||||
- ylim : tuple of two floats
|
||||
- aspect_ratio : tuple of two floats or {'auto'}
|
||||
- autoscale : bool
|
||||
- margin : float in [0, 1]
|
||||
- backend : {'default', 'matplotlib', 'text'} or a subclass of BaseBackend
|
||||
- size : optional tuple of two floats, (width, height); default: None
|
||||
|
||||
The per data series options and aesthetics are:
|
||||
There are none in the base series. See below for options for subclasses.
|
||||
|
||||
Some data series support additional aesthetics or options:
|
||||
|
||||
:class:`~.LineOver1DRangeSeries`, :class:`~.Parametric2DLineSeries`, and
|
||||
:class:`~.Parametric3DLineSeries` support the following:
|
||||
|
||||
Aesthetics:
|
||||
|
||||
- line_color : string, or float, or function, optional
|
||||
Specifies the color for the plot, which depends on the backend being
|
||||
used.
|
||||
|
||||
For example, if ``MatplotlibBackend`` is being used, then
|
||||
Matplotlib string colors are acceptable (``"red"``, ``"r"``,
|
||||
``"cyan"``, ``"c"``, ...).
|
||||
Alternatively, we can use a float number, 0 < color < 1, wrapped in a
|
||||
string (for example, ``line_color="0.5"``) to specify grayscale colors.
|
||||
Alternatively, We can specify a function returning a single
|
||||
float value: this will be used to apply a color-loop (for example,
|
||||
``line_color=lambda x: math.cos(x)``).
|
||||
|
||||
Note that by setting line_color, it would be applied simultaneously
|
||||
to all the series.
|
||||
|
||||
Options:
|
||||
|
||||
- label : str
|
||||
- steps : bool
|
||||
- integers_only : bool
|
||||
|
||||
:class:`~.SurfaceOver2DRangeSeries` and :class:`~.ParametricSurfaceSeries`
|
||||
support the following:
|
||||
|
||||
Aesthetics:
|
||||
|
||||
- surface_color : function which returns a float.
|
||||
|
||||
Notes
|
||||
=====
|
||||
|
||||
How the plotting module works:
|
||||
|
||||
1. Whenever a plotting function is called, the provided expressions are
|
||||
processed and a list of instances of the
|
||||
:class:`~sympy.plotting.series.BaseSeries` class is created, containing
|
||||
the necessary information to plot the expressions
|
||||
(e.g. the expression, ranges, series name, ...). Eventually, these
|
||||
objects will generate the numerical data to be plotted.
|
||||
2. A subclass of :class:`~.Plot` class is instantiaed (referred to as
|
||||
backend, from now on), which stores the list of series and the main
|
||||
attributes of the plot (e.g. axis labels, title, ...).
|
||||
The backend implements the logic to generate the actual figure with
|
||||
some plotting library.
|
||||
3. When the ``show`` command is executed, series are processed one by one
|
||||
to generate numerical data and add it to the figure. The backend is also
|
||||
going to set the axis labels, title, ..., according to the values stored
|
||||
in the Plot instance.
|
||||
|
||||
The backend should check if it supports the data series that it is given
|
||||
(e.g. :class:`TextBackend` supports only
|
||||
:class:`~sympy.plotting.series.LineOver1DRangeSeries`).
|
||||
|
||||
It is the backend responsibility to know how to use the class of data series
|
||||
that it's given. Note that the current implementation of the ``*Series``
|
||||
classes is "matplotlib-centric": the numerical data returned by the
|
||||
``get_points`` and ``get_meshes`` methods is meant to be used directly by
|
||||
Matplotlib. Therefore, the new backend will have to pre-process the
|
||||
numerical data to make it compatible with the chosen plotting library.
|
||||
Keep in mind that future SymPy versions may improve the ``*Series`` classes
|
||||
in order to return numerical data "non-matplotlib-centric", hence if you code
|
||||
a new backend you have the responsibility to check if its working on each
|
||||
SymPy release.
|
||||
|
||||
Please explore the :class:`MatplotlibBackend` source code to understand
|
||||
how a backend should be coded.
|
||||
|
||||
In order to be used by SymPy plotting functions, a backend must implement
|
||||
the following methods:
|
||||
|
||||
* show(self): used to loop over the data series, generate the numerical
|
||||
data, plot it and set the axis labels, title, ...
|
||||
* save(self, path): used to save the current plot to the specified file
|
||||
path.
|
||||
* close(self): used to close the current plot backend (note: some plotting
|
||||
library does not support this functionality. In that case, just raise a
|
||||
warning).
|
||||
"""
|
||||
|
||||
def __init__(self, *args,
|
||||
title=None, xlabel=None, ylabel=None, zlabel=None, aspect_ratio='auto',
|
||||
xlim=None, ylim=None, axis_center='auto', axis=True,
|
||||
xscale='linear', yscale='linear', legend=False, autoscale=True,
|
||||
margin=0, annotations=None, markers=None, rectangles=None,
|
||||
fill=None, backend='default', size=None, **kwargs):
|
||||
|
||||
# Options for the graph as a whole.
|
||||
# The possible values for each option are described in the docstring of
|
||||
# Plot. They are based purely on convention, no checking is done.
|
||||
self.title = title
|
||||
self.xlabel = xlabel
|
||||
self.ylabel = ylabel
|
||||
self.zlabel = zlabel
|
||||
self.aspect_ratio = aspect_ratio
|
||||
self.axis_center = axis_center
|
||||
self.axis = axis
|
||||
self.xscale = xscale
|
||||
self.yscale = yscale
|
||||
self.legend = legend
|
||||
self.autoscale = autoscale
|
||||
self.margin = margin
|
||||
self._annotations = annotations
|
||||
self._markers = markers
|
||||
self._rectangles = rectangles
|
||||
self._fill = fill
|
||||
|
||||
# Contains the data objects to be plotted. The backend should be smart
|
||||
# enough to iterate over this list.
|
||||
self._series = []
|
||||
self._series.extend(args)
|
||||
self._series.extend(_create_generic_data_series(
|
||||
annotations=annotations, markers=markers, rectangles=rectangles,
|
||||
fill=fill))
|
||||
|
||||
is_real = \
|
||||
lambda lim: all(getattr(i, 'is_real', True) for i in lim)
|
||||
is_finite = \
|
||||
lambda lim: all(getattr(i, 'is_finite', True) for i in lim)
|
||||
|
||||
# reduce code repetition
|
||||
def check_and_set(t_name, t):
|
||||
if t:
|
||||
if not is_real(t):
|
||||
raise ValueError(
|
||||
"All numbers from {}={} must be real".format(t_name, t))
|
||||
if not is_finite(t):
|
||||
raise ValueError(
|
||||
"All numbers from {}={} must be finite".format(t_name, t))
|
||||
setattr(self, t_name, (float(t[0]), float(t[1])))
|
||||
|
||||
self.xlim = None
|
||||
check_and_set("xlim", xlim)
|
||||
self.ylim = None
|
||||
check_and_set("ylim", ylim)
|
||||
self.size = None
|
||||
check_and_set("size", size)
|
||||
|
||||
@property
|
||||
def _backend(self):
|
||||
return self
|
||||
|
||||
@property
|
||||
def backend(self):
|
||||
return type(self)
|
||||
|
||||
def __str__(self):
|
||||
series_strs = [('[%d]: ' % i) + str(s)
|
||||
for i, s in enumerate(self._series)]
|
||||
return 'Plot object containing:\n' + '\n'.join(series_strs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self._series[index]
|
||||
|
||||
def __setitem__(self, index, *args):
|
||||
if len(args) == 1 and isinstance(args[0], BaseSeries):
|
||||
self._series[index] = args
|
||||
|
||||
def __delitem__(self, index):
|
||||
del self._series[index]
|
||||
|
||||
def append(self, arg):
|
||||
"""Adds an element from a plot's series to an existing plot.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the
|
||||
second plot's first series object to the first, use the
|
||||
``append`` method, like so:
|
||||
|
||||
.. plot::
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> from sympy import symbols
|
||||
>>> from sympy.plotting import plot
|
||||
>>> x = symbols('x')
|
||||
>>> p1 = plot(x*x, show=False)
|
||||
>>> p2 = plot(x, show=False)
|
||||
>>> p1.append(p2[0])
|
||||
>>> p1
|
||||
Plot object containing:
|
||||
[0]: cartesian line: x**2 for x over (-10.0, 10.0)
|
||||
[1]: cartesian line: x for x over (-10.0, 10.0)
|
||||
>>> p1.show()
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
extend
|
||||
|
||||
"""
|
||||
if isinstance(arg, BaseSeries):
|
||||
self._series.append(arg)
|
||||
else:
|
||||
raise TypeError('Must specify element of plot to append.')
|
||||
|
||||
def extend(self, arg):
|
||||
"""Adds all series from another plot.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the
|
||||
second plot to the first, use the ``extend`` method, like so:
|
||||
|
||||
.. plot::
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> from sympy import symbols
|
||||
>>> from sympy.plotting import plot
|
||||
>>> x = symbols('x')
|
||||
>>> p1 = plot(x**2, show=False)
|
||||
>>> p2 = plot(x, -x, show=False)
|
||||
>>> p1.extend(p2)
|
||||
>>> p1
|
||||
Plot object containing:
|
||||
[0]: cartesian line: x**2 for x over (-10.0, 10.0)
|
||||
[1]: cartesian line: x for x over (-10.0, 10.0)
|
||||
[2]: cartesian line: -x for x over (-10.0, 10.0)
|
||||
>>> p1.show()
|
||||
|
||||
"""
|
||||
if isinstance(arg, Plot):
|
||||
self._series.extend(arg._series)
|
||||
elif is_sequence(arg):
|
||||
self._series.extend(arg)
|
||||
else:
|
||||
raise TypeError('Expecting Plot or sequence of BaseSeries')
|
||||
|
||||
def show(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, path):
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# deprecations
|
||||
|
||||
@property
|
||||
def markers(self):
|
||||
""".. deprecated:: 1.13"""
|
||||
_deprecation_msg_m_a_r_f("markers")
|
||||
return self._markers
|
||||
|
||||
@markers.setter
|
||||
def markers(self, v):
|
||||
""".. deprecated:: 1.13"""
|
||||
_deprecation_msg_m_a_r_f("markers")
|
||||
self._series.extend(_create_generic_data_series(markers=v))
|
||||
self._markers = v
|
||||
|
||||
@property
|
||||
def annotations(self):
|
||||
""".. deprecated:: 1.13"""
|
||||
_deprecation_msg_m_a_r_f("annotations")
|
||||
return self._annotations
|
||||
|
||||
@annotations.setter
|
||||
def annotations(self, v):
|
||||
""".. deprecated:: 1.13"""
|
||||
_deprecation_msg_m_a_r_f("annotations")
|
||||
self._series.extend(_create_generic_data_series(annotations=v))
|
||||
self._annotations = v
|
||||
|
||||
@property
|
||||
def rectangles(self):
|
||||
""".. deprecated:: 1.13"""
|
||||
_deprecation_msg_m_a_r_f("rectangles")
|
||||
return self._rectangles
|
||||
|
||||
@rectangles.setter
|
||||
def rectangles(self, v):
|
||||
""".. deprecated:: 1.13"""
|
||||
_deprecation_msg_m_a_r_f("rectangles")
|
||||
self._series.extend(_create_generic_data_series(rectangles=v))
|
||||
self._rectangles = v
|
||||
|
||||
@property
|
||||
def fill(self):
|
||||
""".. deprecated:: 1.13"""
|
||||
_deprecation_msg_m_a_r_f("fill")
|
||||
return self._fill
|
||||
|
||||
@fill.setter
|
||||
def fill(self, v):
|
||||
""".. deprecated:: 1.13"""
|
||||
_deprecation_msg_m_a_r_f("fill")
|
||||
self._series.extend(_create_generic_data_series(fill=v))
|
||||
self._fill = v
|
||||
@@ -0,0 +1,5 @@
|
||||
from sympy.plotting.backends.matplotlibbackend.matplotlib import (
|
||||
MatplotlibBackend, _matplotlib_list
|
||||
)
|
||||
|
||||
__all__ = ["MatplotlibBackend", "_matplotlib_list"]
|
||||
@@ -0,0 +1,318 @@
|
||||
from collections.abc import Callable
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.external import import_module
|
||||
import sympy.plotting.backends.base_backend as base_backend
|
||||
from sympy.printing.latex import latex
|
||||
|
||||
|
||||
# N.B.
|
||||
# When changing the minimum module version for matplotlib, please change
|
||||
# the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py`
|
||||
|
||||
|
||||
def _str_or_latex(label):
|
||||
if isinstance(label, Basic):
|
||||
return latex(label, mode='inline')
|
||||
return str(label)
|
||||
|
||||
|
||||
def _matplotlib_list(interval_list):
|
||||
"""
|
||||
Returns lists for matplotlib ``fill`` command from a list of bounding
|
||||
rectangular intervals
|
||||
"""
|
||||
xlist = []
|
||||
ylist = []
|
||||
if len(interval_list):
|
||||
for intervals in interval_list:
|
||||
intervalx = intervals[0]
|
||||
intervaly = intervals[1]
|
||||
xlist.extend([intervalx.start, intervalx.start,
|
||||
intervalx.end, intervalx.end, None])
|
||||
ylist.extend([intervaly.start, intervaly.end,
|
||||
intervaly.end, intervaly.start, None])
|
||||
else:
|
||||
#XXX Ugly hack. Matplotlib does not accept empty lists for ``fill``
|
||||
xlist.extend((None, None, None, None))
|
||||
ylist.extend((None, None, None, None))
|
||||
return xlist, ylist
|
||||
|
||||
|
||||
# Don't have to check for the success of importing matplotlib in each case;
|
||||
# we will only be using this backend if we can successfully import matploblib
|
||||
class MatplotlibBackend(base_backend.Plot):
|
||||
""" This class implements the functionalities to use Matplotlib with SymPy
|
||||
plotting functions.
|
||||
"""
|
||||
|
||||
def __init__(self, *series, **kwargs):
|
||||
super().__init__(*series, **kwargs)
|
||||
self.matplotlib = import_module('matplotlib',
|
||||
import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
|
||||
min_module_version='1.1.0', catch=(RuntimeError,))
|
||||
self.plt = self.matplotlib.pyplot
|
||||
self.cm = self.matplotlib.cm
|
||||
self.LineCollection = self.matplotlib.collections.LineCollection
|
||||
self.aspect = kwargs.get('aspect_ratio', 'auto')
|
||||
if self.aspect != 'auto':
|
||||
self.aspect = float(self.aspect[1]) / self.aspect[0]
|
||||
# PlotGrid can provide its figure and axes to be populated with
|
||||
# the data from the series.
|
||||
self._plotgrid_fig = kwargs.pop("fig", None)
|
||||
self._plotgrid_ax = kwargs.pop("ax", None)
|
||||
|
||||
def _create_figure(self):
|
||||
def set_spines(ax):
|
||||
ax.spines['left'].set_position('zero')
|
||||
ax.spines['right'].set_color('none')
|
||||
ax.spines['bottom'].set_position('zero')
|
||||
ax.spines['top'].set_color('none')
|
||||
ax.xaxis.set_ticks_position('bottom')
|
||||
ax.yaxis.set_ticks_position('left')
|
||||
|
||||
if self._plotgrid_fig is not None:
|
||||
self.fig = self._plotgrid_fig
|
||||
self.ax = self._plotgrid_ax
|
||||
if not any(s.is_3D for s in self._series):
|
||||
set_spines(self.ax)
|
||||
else:
|
||||
self.fig = self.plt.figure(figsize=self.size)
|
||||
if any(s.is_3D for s in self._series):
|
||||
self.ax = self.fig.add_subplot(1, 1, 1, projection="3d")
|
||||
else:
|
||||
self.ax = self.fig.add_subplot(1, 1, 1)
|
||||
set_spines(self.ax)
|
||||
|
||||
@staticmethod
|
||||
def get_segments(x, y, z=None):
|
||||
""" Convert two list of coordinates to a list of segments to be used
|
||||
with Matplotlib's :external:class:`~matplotlib.collections.LineCollection`.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
x : list
|
||||
List of x-coordinates
|
||||
|
||||
y : list
|
||||
List of y-coordinates
|
||||
|
||||
z : list
|
||||
List of z-coordinates for a 3D line.
|
||||
"""
|
||||
np = import_module('numpy')
|
||||
if z is not None:
|
||||
dim = 3
|
||||
points = (x, y, z)
|
||||
else:
|
||||
dim = 2
|
||||
points = (x, y)
|
||||
points = np.ma.array(points).T.reshape(-1, 1, dim)
|
||||
return np.ma.concatenate([points[:-1], points[1:]], axis=1)
|
||||
|
||||
def _process_series(self, series, ax):
|
||||
np = import_module('numpy')
|
||||
mpl_toolkits = import_module(
|
||||
'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']})
|
||||
|
||||
# XXX Workaround for matplotlib issue
|
||||
# https://github.com/matplotlib/matplotlib/issues/17130
|
||||
xlims, ylims, zlims = [], [], []
|
||||
|
||||
for s in series:
|
||||
# Create the collections
|
||||
if s.is_2Dline:
|
||||
if s.is_parametric:
|
||||
x, y, param = s.get_data()
|
||||
else:
|
||||
x, y = s.get_data()
|
||||
if (isinstance(s.line_color, (int, float)) or
|
||||
callable(s.line_color)):
|
||||
segments = self.get_segments(x, y)
|
||||
collection = self.LineCollection(segments)
|
||||
collection.set_array(s.get_color_array())
|
||||
ax.add_collection(collection)
|
||||
else:
|
||||
lbl = _str_or_latex(s.label)
|
||||
line, = ax.plot(x, y, label=lbl, color=s.line_color)
|
||||
elif s.is_contour:
|
||||
ax.contour(*s.get_data())
|
||||
elif s.is_3Dline:
|
||||
x, y, z, param = s.get_data()
|
||||
if (isinstance(s.line_color, (int, float)) or
|
||||
callable(s.line_color)):
|
||||
art3d = mpl_toolkits.mplot3d.art3d
|
||||
segments = self.get_segments(x, y, z)
|
||||
collection = art3d.Line3DCollection(segments)
|
||||
collection.set_array(s.get_color_array())
|
||||
ax.add_collection(collection)
|
||||
else:
|
||||
lbl = _str_or_latex(s.label)
|
||||
ax.plot(x, y, z, label=lbl, color=s.line_color)
|
||||
|
||||
xlims.append(s._xlim)
|
||||
ylims.append(s._ylim)
|
||||
zlims.append(s._zlim)
|
||||
elif s.is_3Dsurface:
|
||||
if s.is_parametric:
|
||||
x, y, z, u, v = s.get_data()
|
||||
else:
|
||||
x, y, z = s.get_data()
|
||||
collection = ax.plot_surface(x, y, z,
|
||||
cmap=getattr(self.cm, 'viridis', self.cm.jet),
|
||||
rstride=1, cstride=1, linewidth=0.1)
|
||||
if isinstance(s.surface_color, (float, int, Callable)):
|
||||
color_array = s.get_color_array()
|
||||
color_array = color_array.reshape(color_array.size)
|
||||
collection.set_array(color_array)
|
||||
else:
|
||||
collection.set_color(s.surface_color)
|
||||
|
||||
xlims.append(s._xlim)
|
||||
ylims.append(s._ylim)
|
||||
zlims.append(s._zlim)
|
||||
elif s.is_implicit:
|
||||
points = s.get_data()
|
||||
if len(points) == 2:
|
||||
# interval math plotting
|
||||
x, y = _matplotlib_list(points[0])
|
||||
ax.fill(x, y, facecolor=s.line_color, edgecolor='None')
|
||||
else:
|
||||
# use contourf or contour depending on whether it is
|
||||
# an inequality or equality.
|
||||
# XXX: ``contour`` plots multiple lines. Should be fixed.
|
||||
ListedColormap = self.matplotlib.colors.ListedColormap
|
||||
colormap = ListedColormap(["white", s.line_color])
|
||||
xarray, yarray, zarray, plot_type = points
|
||||
if plot_type == 'contour':
|
||||
ax.contour(xarray, yarray, zarray, cmap=colormap)
|
||||
else:
|
||||
ax.contourf(xarray, yarray, zarray, cmap=colormap)
|
||||
elif s.is_generic:
|
||||
if s.type == "markers":
|
||||
# s.rendering_kw["color"] = s.line_color
|
||||
ax.plot(*s.args, **s.rendering_kw)
|
||||
elif s.type == "annotations":
|
||||
ax.annotate(*s.args, **s.rendering_kw)
|
||||
elif s.type == "fill":
|
||||
# s.rendering_kw["color"] = s.line_color
|
||||
ax.fill_between(*s.args, **s.rendering_kw)
|
||||
elif s.type == "rectangles":
|
||||
# s.rendering_kw["color"] = s.line_color
|
||||
ax.add_patch(
|
||||
self.matplotlib.patches.Rectangle(
|
||||
*s.args, **s.rendering_kw))
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'{} is not supported in the SymPy plotting module '
|
||||
'with matplotlib backend. Please report this issue.'
|
||||
.format(ax))
|
||||
|
||||
Axes3D = mpl_toolkits.mplot3d.Axes3D
|
||||
if not isinstance(ax, Axes3D):
|
||||
ax.autoscale_view(
|
||||
scalex=ax.get_autoscalex_on(),
|
||||
scaley=ax.get_autoscaley_on())
|
||||
else:
|
||||
# XXX Workaround for matplotlib issue
|
||||
# https://github.com/matplotlib/matplotlib/issues/17130
|
||||
if xlims:
|
||||
xlims = np.array(xlims)
|
||||
xlim = (np.amin(xlims[:, 0]), np.amax(xlims[:, 1]))
|
||||
ax.set_xlim(xlim)
|
||||
else:
|
||||
ax.set_xlim([0, 1])
|
||||
|
||||
if ylims:
|
||||
ylims = np.array(ylims)
|
||||
ylim = (np.amin(ylims[:, 0]), np.amax(ylims[:, 1]))
|
||||
ax.set_ylim(ylim)
|
||||
else:
|
||||
ax.set_ylim([0, 1])
|
||||
|
||||
if zlims:
|
||||
zlims = np.array(zlims)
|
||||
zlim = (np.amin(zlims[:, 0]), np.amax(zlims[:, 1]))
|
||||
ax.set_zlim(zlim)
|
||||
else:
|
||||
ax.set_zlim([0, 1])
|
||||
|
||||
# Set global options.
|
||||
# TODO The 3D stuff
|
||||
# XXX The order of those is important.
|
||||
if self.xscale and not isinstance(ax, Axes3D):
|
||||
ax.set_xscale(self.xscale)
|
||||
if self.yscale and not isinstance(ax, Axes3D):
|
||||
ax.set_yscale(self.yscale)
|
||||
if not isinstance(ax, Axes3D) or self.matplotlib.__version__ >= '1.2.0': # XXX in the distant future remove this check
|
||||
ax.set_autoscale_on(self.autoscale)
|
||||
if self.axis_center:
|
||||
val = self.axis_center
|
||||
if isinstance(ax, Axes3D):
|
||||
pass
|
||||
elif val == 'center':
|
||||
ax.spines['left'].set_position('center')
|
||||
ax.spines['bottom'].set_position('center')
|
||||
elif val == 'auto':
|
||||
xl, xh = ax.get_xlim()
|
||||
yl, yh = ax.get_ylim()
|
||||
pos_left = ('data', 0) if xl*xh <= 0 else 'center'
|
||||
pos_bottom = ('data', 0) if yl*yh <= 0 else 'center'
|
||||
ax.spines['left'].set_position(pos_left)
|
||||
ax.spines['bottom'].set_position(pos_bottom)
|
||||
else:
|
||||
ax.spines['left'].set_position(('data', val[0]))
|
||||
ax.spines['bottom'].set_position(('data', val[1]))
|
||||
if not self.axis:
|
||||
ax.set_axis_off()
|
||||
if self.legend:
|
||||
if ax.legend():
|
||||
ax.legend_.set_visible(self.legend)
|
||||
if self.margin:
|
||||
ax.set_xmargin(self.margin)
|
||||
ax.set_ymargin(self.margin)
|
||||
if self.title:
|
||||
ax.set_title(self.title)
|
||||
if self.xlabel:
|
||||
xlbl = _str_or_latex(self.xlabel)
|
||||
ax.set_xlabel(xlbl, position=(1, 0))
|
||||
if self.ylabel:
|
||||
ylbl = _str_or_latex(self.ylabel)
|
||||
ax.set_ylabel(ylbl, position=(0, 1))
|
||||
if isinstance(ax, Axes3D) and self.zlabel:
|
||||
zlbl = _str_or_latex(self.zlabel)
|
||||
ax.set_zlabel(zlbl, position=(0, 1))
|
||||
|
||||
# xlim and ylim should always be set at last so that plot limits
|
||||
# doesn't get altered during the process.
|
||||
if self.xlim:
|
||||
ax.set_xlim(self.xlim)
|
||||
if self.ylim:
|
||||
ax.set_ylim(self.ylim)
|
||||
self.ax.set_aspect(self.aspect)
|
||||
|
||||
|
||||
def process_series(self):
|
||||
"""
|
||||
Iterates over every ``Plot`` object and further calls
|
||||
_process_series()
|
||||
"""
|
||||
self._create_figure()
|
||||
self._process_series(self._series, self.ax)
|
||||
|
||||
def show(self):
|
||||
self.process_series()
|
||||
#TODO after fixing https://github.com/ipython/ipython/issues/1255
|
||||
# you can uncomment the next line and remove the pyplot.show() call
|
||||
#self.fig.show()
|
||||
if base_backend._show:
|
||||
self.fig.tight_layout()
|
||||
self.plt.show()
|
||||
else:
|
||||
self.close()
|
||||
|
||||
def save(self, path):
|
||||
self.process_series()
|
||||
self.fig.savefig(path)
|
||||
|
||||
def close(self):
|
||||
self.plt.close(self.fig)
|
||||
@@ -0,0 +1,3 @@
|
||||
from sympy.plotting.backends.textbackend.text import TextBackend
|
||||
|
||||
__all__ = ["TextBackend"]
|
||||
@@ -0,0 +1,24 @@
|
||||
import sympy.plotting.backends.base_backend as base_backend
|
||||
from sympy.plotting.series import LineOver1DRangeSeries
|
||||
from sympy.plotting.textplot import textplot
|
||||
|
||||
|
||||
class TextBackend(base_backend.Plot):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def show(self):
|
||||
if not base_backend._show:
|
||||
return
|
||||
if len(self._series) != 1:
|
||||
raise ValueError(
|
||||
'The TextBackend supports only one graph per Plot.')
|
||||
elif not isinstance(self._series[0], LineOver1DRangeSeries):
|
||||
raise ValueError(
|
||||
'The TextBackend supports only expressions over a 1D range')
|
||||
else:
|
||||
ser = self._series[0]
|
||||
textplot(ser.expr, ser.start, ser.end)
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
@@ -0,0 +1,641 @@
|
||||
""" rewrite of lambdify - This stuff is not stable at all.
|
||||
|
||||
It is for internal use in the new plotting module.
|
||||
It may (will! see the Q'n'A in the source) be rewritten.
|
||||
|
||||
It's completely self contained. Especially it does not use lambdarepr.
|
||||
|
||||
It does not aim to replace the current lambdify. Most importantly it will never
|
||||
ever support anything else than SymPy expressions (no Matrices, dictionaries
|
||||
and so on).
|
||||
"""
|
||||
|
||||
|
||||
import re
|
||||
from sympy.core.numbers import (I, NumberSymbol, oo, zoo)
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.utilities.iterables import numbered_symbols
|
||||
|
||||
# We parse the expression string into a tree that identifies functions. Then
|
||||
# we translate the names of the functions and we translate also some strings
|
||||
# that are not names of functions (all this according to translation
|
||||
# dictionaries).
|
||||
# If the translation goes to another module (like numpy) the
|
||||
# module is imported and 'func' is translated to 'module.func'.
|
||||
# If a function can not be translated, the inner nodes of that part of the
|
||||
# tree are not translated. So if we have Integral(sqrt(x)), sqrt is not
|
||||
# translated to np.sqrt and the Integral does not crash.
|
||||
# A namespace for all this is generated by crawling the (func, args) tree of
|
||||
# the expression. The creation of this namespace involves many ugly
|
||||
# workarounds.
|
||||
# The namespace consists of all the names needed for the SymPy expression and
|
||||
# all the name of modules used for translation. Those modules are imported only
|
||||
# as a name (import numpy as np) in order to keep the namespace small and
|
||||
# manageable.
|
||||
|
||||
# Please, if there is a bug, do not try to fix it here! Rewrite this by using
|
||||
# the method proposed in the last Q'n'A below. That way the new function will
|
||||
# work just as well, be just as simple, but it wont need any new workarounds.
|
||||
# If you insist on fixing it here, look at the workarounds in the function
|
||||
# sympy_expression_namespace and in lambdify.
|
||||
|
||||
# Q: Why are you not using Python abstract syntax tree?
|
||||
# A: Because it is more complicated and not much more powerful in this case.
|
||||
|
||||
# Q: What if I have Symbol('sin') or g=Function('f')?
|
||||
# A: You will break the algorithm. We should use srepr to defend against this?
|
||||
# The problem with Symbol('sin') is that it will be printed as 'sin'. The
|
||||
# parser will distinguish it from the function 'sin' because functions are
|
||||
# detected thanks to the opening parenthesis, but the lambda expression won't
|
||||
# understand the difference if we have also the sin function.
|
||||
# The solution (complicated) is to use srepr and maybe ast.
|
||||
# The problem with the g=Function('f') is that it will be printed as 'f' but in
|
||||
# the global namespace we have only 'g'. But as the same printer is used in the
|
||||
# constructor of the namespace there will be no problem.
|
||||
|
||||
# Q: What if some of the printers are not printing as expected?
|
||||
# A: The algorithm wont work. You must use srepr for those cases. But even
|
||||
# srepr may not print well. All problems with printers should be considered
|
||||
# bugs.
|
||||
|
||||
# Q: What about _imp_ functions?
|
||||
# A: Those are taken care for by evalf. A special case treatment will work
|
||||
# faster but it's not worth the code complexity.
|
||||
|
||||
# Q: Will ast fix all possible problems?
|
||||
# A: No. You will always have to use some printer. Even srepr may not work in
|
||||
# some cases. But if the printer does not work, that should be considered a
|
||||
# bug.
|
||||
|
||||
# Q: Is there same way to fix all possible problems?
|
||||
# A: Probably by constructing our strings ourself by traversing the (func,
|
||||
# args) tree and creating the namespace at the same time. That actually sounds
|
||||
# good.
|
||||
|
||||
from sympy.external import import_module
|
||||
import warnings
|
||||
|
||||
#TODO debugging output
|
||||
|
||||
|
||||
class vectorized_lambdify:
|
||||
""" Return a sufficiently smart, vectorized and lambdified function.
|
||||
|
||||
Returns only reals.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
This function uses experimental_lambdify to created a lambdified
|
||||
expression ready to be used with numpy. Many of the functions in SymPy
|
||||
are not implemented in numpy so in some cases we resort to Python cmath or
|
||||
even to evalf.
|
||||
|
||||
The following translations are tried:
|
||||
only numpy complex
|
||||
- on errors raised by SymPy trying to work with ndarray:
|
||||
only Python cmath and then vectorize complex128
|
||||
|
||||
When using Python cmath there is no need for evalf or float/complex
|
||||
because Python cmath calls those.
|
||||
|
||||
This function never tries to mix numpy directly with evalf because numpy
|
||||
does not understand SymPy Float. If this is needed one can use the
|
||||
float_wrap_evalf/complex_wrap_evalf options of experimental_lambdify or
|
||||
better one can be explicit about the dtypes that numpy works with.
|
||||
Check numpy bug http://projects.scipy.org/numpy/ticket/1013 to know what
|
||||
types of errors to expect.
|
||||
"""
|
||||
def __init__(self, args, expr):
|
||||
self.args = args
|
||||
self.expr = expr
|
||||
self.np = import_module('numpy')
|
||||
|
||||
self.lambda_func_1 = experimental_lambdify(
|
||||
args, expr, use_np=True)
|
||||
self.vector_func_1 = self.lambda_func_1
|
||||
|
||||
self.lambda_func_2 = experimental_lambdify(
|
||||
args, expr, use_python_cmath=True)
|
||||
self.vector_func_2 = self.np.vectorize(
|
||||
self.lambda_func_2, otypes=[complex])
|
||||
|
||||
self.vector_func = self.vector_func_1
|
||||
self.failure = False
|
||||
|
||||
def __call__(self, *args):
|
||||
np = self.np
|
||||
|
||||
try:
|
||||
temp_args = (np.array(a, dtype=complex) for a in args)
|
||||
results = self.vector_func(*temp_args)
|
||||
results = np.ma.masked_where(
|
||||
np.abs(results.imag) > 1e-7 * np.abs(results),
|
||||
results.real, copy=False)
|
||||
return results
|
||||
except ValueError:
|
||||
if self.failure:
|
||||
raise
|
||||
|
||||
self.failure = True
|
||||
self.vector_func = self.vector_func_2
|
||||
warnings.warn(
|
||||
'The evaluation of the expression is problematic. '
|
||||
'We are trying a failback method that may still work. '
|
||||
'Please report this as a bug.')
|
||||
return self.__call__(*args)
|
||||
|
||||
|
||||
class lambdify:
|
||||
"""Returns the lambdified function.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
This function uses experimental_lambdify to create a lambdified
|
||||
expression. It uses cmath to lambdify the expression. If the function
|
||||
is not implemented in Python cmath, Python cmath calls evalf on those
|
||||
functions.
|
||||
"""
|
||||
|
||||
def __init__(self, args, expr):
|
||||
self.args = args
|
||||
self.expr = expr
|
||||
self.lambda_func_1 = experimental_lambdify(
|
||||
args, expr, use_python_cmath=True, use_evalf=True)
|
||||
self.lambda_func_2 = experimental_lambdify(
|
||||
args, expr, use_python_math=True, use_evalf=True)
|
||||
self.lambda_func_3 = experimental_lambdify(
|
||||
args, expr, use_evalf=True, complex_wrap_evalf=True)
|
||||
self.lambda_func = self.lambda_func_1
|
||||
self.failure = False
|
||||
|
||||
def __call__(self, args):
|
||||
try:
|
||||
#The result can be sympy.Float. Hence wrap it with complex type.
|
||||
result = complex(self.lambda_func(args))
|
||||
if abs(result.imag) > 1e-7 * abs(result):
|
||||
return None
|
||||
return result.real
|
||||
except (ZeroDivisionError, OverflowError):
|
||||
return None
|
||||
except TypeError as e:
|
||||
if self.failure:
|
||||
raise e
|
||||
|
||||
if self.lambda_func == self.lambda_func_1:
|
||||
self.lambda_func = self.lambda_func_2
|
||||
return self.__call__(args)
|
||||
|
||||
self.failure = True
|
||||
self.lambda_func = self.lambda_func_3
|
||||
warnings.warn(
|
||||
'The evaluation of the expression is problematic. '
|
||||
'We are trying a failback method that may still work. '
|
||||
'Please report this as a bug.', stacklevel=2)
|
||||
return self.__call__(args)
|
||||
|
||||
|
||||
def experimental_lambdify(*args, **kwargs):
|
||||
l = Lambdifier(*args, **kwargs)
|
||||
return l
|
||||
|
||||
|
||||
class Lambdifier:
|
||||
def __init__(self, args, expr, print_lambda=False, use_evalf=False,
|
||||
float_wrap_evalf=False, complex_wrap_evalf=False,
|
||||
use_np=False, use_python_math=False, use_python_cmath=False,
|
||||
use_interval=False):
|
||||
|
||||
self.print_lambda = print_lambda
|
||||
self.use_evalf = use_evalf
|
||||
self.float_wrap_evalf = float_wrap_evalf
|
||||
self.complex_wrap_evalf = complex_wrap_evalf
|
||||
self.use_np = use_np
|
||||
self.use_python_math = use_python_math
|
||||
self.use_python_cmath = use_python_cmath
|
||||
self.use_interval = use_interval
|
||||
|
||||
# Constructing the argument string
|
||||
# - check
|
||||
if not all(isinstance(a, Symbol) for a in args):
|
||||
raise ValueError('The arguments must be Symbols.')
|
||||
# - use numbered symbols
|
||||
syms = numbered_symbols(exclude=expr.free_symbols)
|
||||
newargs = [next(syms) for _ in args]
|
||||
expr = expr.xreplace(dict(zip(args, newargs)))
|
||||
argstr = ', '.join([str(a) for a in newargs])
|
||||
del syms, newargs, args
|
||||
|
||||
# Constructing the translation dictionaries and making the translation
|
||||
self.dict_str = self.get_dict_str()
|
||||
self.dict_fun = self.get_dict_fun()
|
||||
exprstr = str(expr)
|
||||
newexpr = self.tree2str_translate(self.str2tree(exprstr))
|
||||
|
||||
# Constructing the namespaces
|
||||
namespace = {}
|
||||
namespace.update(self.sympy_atoms_namespace(expr))
|
||||
namespace.update(self.sympy_expression_namespace(expr))
|
||||
# XXX Workaround
|
||||
# Ugly workaround because Pow(a,Half) prints as sqrt(a)
|
||||
# and sympy_expression_namespace can not catch it.
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
namespace.update({'sqrt': sqrt})
|
||||
namespace.update({'Eq': lambda x, y: x == y})
|
||||
namespace.update({'Ne': lambda x, y: x != y})
|
||||
# End workaround.
|
||||
if use_python_math:
|
||||
namespace.update({'math': __import__('math')})
|
||||
if use_python_cmath:
|
||||
namespace.update({'cmath': __import__('cmath')})
|
||||
if use_np:
|
||||
try:
|
||||
namespace.update({'np': __import__('numpy')})
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'experimental_lambdify failed to import numpy.')
|
||||
if use_interval:
|
||||
namespace.update({'imath': __import__(
|
||||
'sympy.plotting.intervalmath', fromlist=['intervalmath'])})
|
||||
namespace.update({'math': __import__('math')})
|
||||
|
||||
# Construct the lambda
|
||||
if self.print_lambda:
|
||||
print(newexpr)
|
||||
eval_str = 'lambda %s : ( %s )' % (argstr, newexpr)
|
||||
self.eval_str = eval_str
|
||||
exec("MYNEWLAMBDA = %s" % eval_str, namespace)
|
||||
self.lambda_func = namespace['MYNEWLAMBDA']
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.lambda_func(*args, **kwargs)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Dicts for translating from SymPy to other modules
|
||||
##############################################################################
|
||||
###
|
||||
# builtins
|
||||
###
|
||||
# Functions with different names in builtins
|
||||
builtin_functions_different = {
|
||||
'Min': 'min',
|
||||
'Max': 'max',
|
||||
'Abs': 'abs',
|
||||
}
|
||||
|
||||
# Strings that should be translated
|
||||
builtin_not_functions = {
|
||||
'I': '1j',
|
||||
# 'oo': '1e400',
|
||||
}
|
||||
|
||||
###
|
||||
# numpy
|
||||
###
|
||||
|
||||
# Functions that are the same in numpy
|
||||
numpy_functions_same = [
|
||||
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'exp', 'log',
|
||||
'sqrt', 'floor', 'conjugate', 'sign',
|
||||
]
|
||||
|
||||
# Functions with different names in numpy
|
||||
numpy_functions_different = {
|
||||
"acos": "arccos",
|
||||
"acosh": "arccosh",
|
||||
"arg": "angle",
|
||||
"asin": "arcsin",
|
||||
"asinh": "arcsinh",
|
||||
"atan": "arctan",
|
||||
"atan2": "arctan2",
|
||||
"atanh": "arctanh",
|
||||
"ceiling": "ceil",
|
||||
"im": "imag",
|
||||
"ln": "log",
|
||||
"Max": "amax",
|
||||
"Min": "amin",
|
||||
"re": "real",
|
||||
"Abs": "abs",
|
||||
}
|
||||
|
||||
# Strings that should be translated
|
||||
numpy_not_functions = {
|
||||
'pi': 'np.pi',
|
||||
'oo': 'np.inf',
|
||||
'E': 'np.e',
|
||||
}
|
||||
|
||||
###
|
||||
# Python math
|
||||
###
|
||||
|
||||
# Functions that are the same in math
|
||||
math_functions_same = [
|
||||
'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2',
|
||||
'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh',
|
||||
'exp', 'log', 'erf', 'sqrt', 'floor', 'factorial', 'gamma',
|
||||
]
|
||||
|
||||
# Functions with different names in math
|
||||
math_functions_different = {
|
||||
'ceiling': 'ceil',
|
||||
'ln': 'log',
|
||||
'loggamma': 'lgamma'
|
||||
}
|
||||
|
||||
# Strings that should be translated
|
||||
math_not_functions = {
|
||||
'pi': 'math.pi',
|
||||
'E': 'math.e',
|
||||
}
|
||||
|
||||
###
|
||||
# Python cmath
|
||||
###
|
||||
|
||||
# Functions that are the same in cmath
|
||||
cmath_functions_same = [
|
||||
'sin', 'cos', 'tan', 'asin', 'acos', 'atan',
|
||||
'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh',
|
||||
'exp', 'log', 'sqrt',
|
||||
]
|
||||
|
||||
# Functions with different names in cmath
|
||||
cmath_functions_different = {
|
||||
'ln': 'log',
|
||||
'arg': 'phase',
|
||||
}
|
||||
|
||||
# Strings that should be translated
|
||||
cmath_not_functions = {
|
||||
'pi': 'cmath.pi',
|
||||
'E': 'cmath.e',
|
||||
}
|
||||
|
||||
###
|
||||
# intervalmath
|
||||
###
|
||||
|
||||
interval_not_functions = {
|
||||
'pi': 'math.pi',
|
||||
'E': 'math.e'
|
||||
}
|
||||
|
||||
interval_functions_same = [
|
||||
'sin', 'cos', 'exp', 'tan', 'atan', 'log',
|
||||
'sqrt', 'cosh', 'sinh', 'tanh', 'floor',
|
||||
'acos', 'asin', 'acosh', 'asinh', 'atanh',
|
||||
'Abs', 'And', 'Or'
|
||||
]
|
||||
|
||||
interval_functions_different = {
|
||||
'Min': 'imin',
|
||||
'Max': 'imax',
|
||||
'ceiling': 'ceil',
|
||||
|
||||
}
|
||||
|
||||
###
|
||||
# mpmath, etc
|
||||
###
|
||||
#TODO
|
||||
|
||||
###
|
||||
# Create the final ordered tuples of dictionaries
|
||||
###
|
||||
|
||||
# For strings
|
||||
def get_dict_str(self):
|
||||
dict_str = dict(self.builtin_not_functions)
|
||||
if self.use_np:
|
||||
dict_str.update(self.numpy_not_functions)
|
||||
if self.use_python_math:
|
||||
dict_str.update(self.math_not_functions)
|
||||
if self.use_python_cmath:
|
||||
dict_str.update(self.cmath_not_functions)
|
||||
if self.use_interval:
|
||||
dict_str.update(self.interval_not_functions)
|
||||
return dict_str
|
||||
|
||||
# For functions
|
||||
def get_dict_fun(self):
|
||||
dict_fun = dict(self.builtin_functions_different)
|
||||
if self.use_np:
|
||||
for s in self.numpy_functions_same:
|
||||
dict_fun[s] = 'np.' + s
|
||||
for k, v in self.numpy_functions_different.items():
|
||||
dict_fun[k] = 'np.' + v
|
||||
if self.use_python_math:
|
||||
for s in self.math_functions_same:
|
||||
dict_fun[s] = 'math.' + s
|
||||
for k, v in self.math_functions_different.items():
|
||||
dict_fun[k] = 'math.' + v
|
||||
if self.use_python_cmath:
|
||||
for s in self.cmath_functions_same:
|
||||
dict_fun[s] = 'cmath.' + s
|
||||
for k, v in self.cmath_functions_different.items():
|
||||
dict_fun[k] = 'cmath.' + v
|
||||
if self.use_interval:
|
||||
for s in self.interval_functions_same:
|
||||
dict_fun[s] = 'imath.' + s
|
||||
for k, v in self.interval_functions_different.items():
|
||||
dict_fun[k] = 'imath.' + v
|
||||
return dict_fun
|
||||
|
||||
##############################################################################
|
||||
# The translator functions, tree parsers, etc.
|
||||
##############################################################################
|
||||
|
||||
def str2tree(self, exprstr):
|
||||
"""Converts an expression string to a tree.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Functions are represented by ('func_name(', tree_of_arguments).
|
||||
Other expressions are (head_string, mid_tree, tail_str).
|
||||
Expressions that do not contain functions are directly returned.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x, y, z
|
||||
>>> from sympy import Integral, sin
|
||||
>>> from sympy.plotting.experimental_lambdify import Lambdifier
|
||||
>>> str2tree = Lambdifier([x], x).str2tree
|
||||
|
||||
>>> str2tree(str(Integral(x, (x, 1, y))))
|
||||
('', ('Integral(', 'x, (x, 1, y)'), ')')
|
||||
>>> str2tree(str(x+y))
|
||||
'x + y'
|
||||
>>> str2tree(str(x+y*sin(z)+1))
|
||||
('x + y*', ('sin(', 'z'), ') + 1')
|
||||
>>> str2tree('sin(y*(y + 1.1) + (sin(y)))')
|
||||
('', ('sin(', ('y*(y + 1.1) + (', ('sin(', 'y'), '))')), ')')
|
||||
"""
|
||||
#matches the first 'function_name('
|
||||
first_par = re.search(r'(\w+\()', exprstr)
|
||||
if first_par is None:
|
||||
return exprstr
|
||||
else:
|
||||
start = first_par.start()
|
||||
end = first_par.end()
|
||||
head = exprstr[:start]
|
||||
func = exprstr[start:end]
|
||||
tail = exprstr[end:]
|
||||
count = 0
|
||||
for i, c in enumerate(tail):
|
||||
if c == '(':
|
||||
count += 1
|
||||
elif c == ')':
|
||||
count -= 1
|
||||
if count == -1:
|
||||
break
|
||||
func_tail = self.str2tree(tail[:i])
|
||||
tail = self.str2tree(tail[i:])
|
||||
return (head, (func, func_tail), tail)
|
||||
|
||||
@classmethod
|
||||
def tree2str(cls, tree):
|
||||
"""Converts a tree to string without translations.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x, y, z
|
||||
>>> from sympy import sin
|
||||
>>> from sympy.plotting.experimental_lambdify import Lambdifier
|
||||
>>> str2tree = Lambdifier([x], x).str2tree
|
||||
>>> tree2str = Lambdifier([x], x).tree2str
|
||||
|
||||
>>> tree2str(str2tree(str(x+y*sin(z)+1)))
|
||||
'x + y*sin(z) + 1'
|
||||
"""
|
||||
if isinstance(tree, str):
|
||||
return tree
|
||||
else:
|
||||
return ''.join(map(cls.tree2str, tree))
|
||||
|
||||
def tree2str_translate(self, tree):
|
||||
"""Converts a tree to string with translations.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Function names are translated by translate_func.
|
||||
Other strings are translated by translate_str.
|
||||
"""
|
||||
if isinstance(tree, str):
|
||||
return self.translate_str(tree)
|
||||
elif isinstance(tree, tuple) and len(tree) == 2:
|
||||
return self.translate_func(tree[0][:-1], tree[1])
|
||||
else:
|
||||
return ''.join([self.tree2str_translate(t) for t in tree])
|
||||
|
||||
def translate_str(self, estr):
|
||||
"""Translate substrings of estr using in order the dictionaries in
|
||||
dict_tuple_str."""
|
||||
for pattern, repl in self.dict_str.items():
|
||||
estr = re.sub(pattern, repl, estr)
|
||||
return estr
|
||||
|
||||
def translate_func(self, func_name, argtree):
|
||||
"""Translate function names and the tree of arguments.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
If the function name is not in the dictionaries of dict_tuple_fun then the
|
||||
function is surrounded by a float((...).evalf()).
|
||||
|
||||
The use of float is necessary as np.<function>(sympy.Float(..)) raises an
|
||||
error."""
|
||||
if func_name in self.dict_fun:
|
||||
new_name = self.dict_fun[func_name]
|
||||
argstr = self.tree2str_translate(argtree)
|
||||
return new_name + '(' + argstr
|
||||
elif func_name in ['Eq', 'Ne']:
|
||||
op = {'Eq': '==', 'Ne': '!='}
|
||||
return "(lambda x, y: x {} y)({}".format(op[func_name], self.tree2str_translate(argtree))
|
||||
else:
|
||||
template = '(%s(%s)).evalf(' if self.use_evalf else '%s(%s'
|
||||
if self.float_wrap_evalf:
|
||||
template = 'float(%s)' % template
|
||||
elif self.complex_wrap_evalf:
|
||||
template = 'complex(%s)' % template
|
||||
|
||||
# Wrapping should only happen on the outermost expression, which
|
||||
# is the only thing we know will be a number.
|
||||
float_wrap_evalf = self.float_wrap_evalf
|
||||
complex_wrap_evalf = self.complex_wrap_evalf
|
||||
self.float_wrap_evalf = False
|
||||
self.complex_wrap_evalf = False
|
||||
ret = template % (func_name, self.tree2str_translate(argtree))
|
||||
self.float_wrap_evalf = float_wrap_evalf
|
||||
self.complex_wrap_evalf = complex_wrap_evalf
|
||||
return ret
|
||||
|
||||
##############################################################################
|
||||
# The namespace constructors
|
||||
##############################################################################
|
||||
|
||||
@classmethod
|
||||
def sympy_expression_namespace(cls, expr):
|
||||
"""Traverses the (func, args) tree of an expression and creates a SymPy
|
||||
namespace. All other modules are imported only as a module name. That way
|
||||
the namespace is not polluted and rests quite small. It probably causes much
|
||||
more variable lookups and so it takes more time, but there are no tests on
|
||||
that for the moment."""
|
||||
if expr is None:
|
||||
return {}
|
||||
else:
|
||||
funcname = str(expr.func)
|
||||
# XXX Workaround
|
||||
# Here we add an ugly workaround because str(func(x))
|
||||
# is not always the same as str(func). Eg
|
||||
# >>> str(Integral(x))
|
||||
# "Integral(x)"
|
||||
# >>> str(Integral)
|
||||
# "<class 'sympy.integrals.integrals.Integral'>"
|
||||
# >>> str(sqrt(x))
|
||||
# "sqrt(x)"
|
||||
# >>> str(sqrt)
|
||||
# "<function sqrt at 0x3d92de8>"
|
||||
# >>> str(sin(x))
|
||||
# "sin(x)"
|
||||
# >>> str(sin)
|
||||
# "sin"
|
||||
# Either one of those can be used but not all at the same time.
|
||||
# The code considers the sin example as the right one.
|
||||
regexlist = [
|
||||
r'<class \'sympy[\w.]*?.([\w]*)\'>$',
|
||||
# the example Integral
|
||||
r'<function ([\w]*) at 0x[\w]*>$', # the example sqrt
|
||||
]
|
||||
for r in regexlist:
|
||||
m = re.match(r, funcname)
|
||||
if m is not None:
|
||||
funcname = m.groups()[0]
|
||||
# End of the workaround
|
||||
# XXX debug: print funcname
|
||||
args_dict = {}
|
||||
for a in expr.args:
|
||||
if (isinstance(a, (Symbol, NumberSymbol)) or a in [I, zoo, oo]):
|
||||
continue
|
||||
else:
|
||||
args_dict.update(cls.sympy_expression_namespace(a))
|
||||
args_dict.update({funcname: expr.func})
|
||||
return args_dict
|
||||
|
||||
@staticmethod
|
||||
def sympy_atoms_namespace(expr):
|
||||
"""For no real reason this function is separated from
|
||||
sympy_expression_namespace. It can be moved to it."""
|
||||
atoms = expr.atoms(Symbol, NumberSymbol, I, zoo, oo)
|
||||
d = {}
|
||||
for a in atoms:
|
||||
# XXX debug: print 'atom:' + str(a)
|
||||
d[str(a)] = a
|
||||
return d
|
||||
@@ -0,0 +1,12 @@
|
||||
from .interval_arithmetic import interval
|
||||
from .lib_interval import (Abs, exp, log, log10, sin, cos, tan, sqrt,
|
||||
imin, imax, sinh, cosh, tanh, acosh, asinh, atanh,
|
||||
asin, acos, atan, ceil, floor, And, Or)
|
||||
|
||||
__all__ = [
|
||||
'interval',
|
||||
|
||||
'Abs', 'exp', 'log', 'log10', 'sin', 'cos', 'tan', 'sqrt', 'imin', 'imax',
|
||||
'sinh', 'cosh', 'tanh', 'acosh', 'asinh', 'atanh', 'asin', 'acos', 'atan',
|
||||
'ceil', 'floor', 'And', 'Or',
|
||||
]
|
||||
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
Interval Arithmetic for plotting.
|
||||
This module does not implement interval arithmetic accurately and
|
||||
hence cannot be used for purposes other than plotting. If you want
|
||||
to use interval arithmetic, use mpmath's interval arithmetic.
|
||||
|
||||
The module implements interval arithmetic using numpy and
|
||||
python floating points. The rounding up and down is not handled
|
||||
and hence this is not an accurate implementation of interval
|
||||
arithmetic.
|
||||
|
||||
The module uses numpy for speed which cannot be achieved with mpmath.
|
||||
"""
|
||||
|
||||
# Q: Why use numpy? Why not simply use mpmath's interval arithmetic?
|
||||
# A: mpmath's interval arithmetic simulates a floating point unit
|
||||
# and hence is slow, while numpy evaluations are orders of magnitude
|
||||
# faster.
|
||||
|
||||
# Q: Why create a separate class for intervals? Why not use SymPy's
|
||||
# Interval Sets?
|
||||
# A: The functionalities that will be required for plotting is quite
|
||||
# different from what Interval Sets implement.
|
||||
|
||||
# Q: Why is rounding up and down according to IEEE754 not handled?
|
||||
# A: It is not possible to do it in both numpy and python. An external
|
||||
# library has to used, which defeats the whole purpose i.e., speed. Also
|
||||
# rounding is handled for very few functions in those libraries.
|
||||
|
||||
# Q Will my plots be affected?
|
||||
# A It will not affect most of the plots. The interval arithmetic
|
||||
# module based suffers the same problems as that of floating point
|
||||
# arithmetic.
|
||||
|
||||
from sympy.core.numbers import int_valued
|
||||
from sympy.core.logic import fuzzy_and
|
||||
from sympy.simplify.simplify import nsimplify
|
||||
|
||||
from .interval_membership import intervalMembership
|
||||
|
||||
|
||||
class interval:
|
||||
""" Represents an interval containing floating points as start and
|
||||
end of the interval
|
||||
The is_valid variable tracks whether the interval obtained as the
|
||||
result of the function is in the domain and is continuous.
|
||||
- True: Represents the interval result of a function is continuous and
|
||||
in the domain of the function.
|
||||
- False: The interval argument of the function was not in the domain of
|
||||
the function, hence the is_valid of the result interval is False
|
||||
- None: The function was not continuous over the interval or
|
||||
the function's argument interval is partly in the domain of the
|
||||
function
|
||||
|
||||
A comparison between an interval and a real number, or a
|
||||
comparison between two intervals may return ``intervalMembership``
|
||||
of two 3-valued logic values.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, is_valid=True, **kwargs):
|
||||
self.is_valid = is_valid
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], interval):
|
||||
self.start, self.end = args[0].start, args[0].end
|
||||
else:
|
||||
self.start = float(args[0])
|
||||
self.end = float(args[0])
|
||||
elif len(args) == 2:
|
||||
if args[0] < args[1]:
|
||||
self.start = float(args[0])
|
||||
self.end = float(args[1])
|
||||
else:
|
||||
self.start = float(args[1])
|
||||
self.end = float(args[0])
|
||||
|
||||
else:
|
||||
raise ValueError("interval takes a maximum of two float values "
|
||||
"as arguments")
|
||||
|
||||
@property
|
||||
def mid(self):
|
||||
return (self.start + self.end) / 2.0
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
return self.end - self.start
|
||||
|
||||
def __repr__(self):
|
||||
return "interval(%f, %f)" % (self.start, self.end)
|
||||
|
||||
def __str__(self):
|
||||
return "[%f, %f]" % (self.start, self.end)
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
if self.end < other:
|
||||
return intervalMembership(True, self.is_valid)
|
||||
elif self.start > other:
|
||||
return intervalMembership(False, self.is_valid)
|
||||
else:
|
||||
return intervalMembership(None, self.is_valid)
|
||||
|
||||
elif isinstance(other, interval):
|
||||
valid = fuzzy_and([self.is_valid, other.is_valid])
|
||||
if self.end < other. start:
|
||||
return intervalMembership(True, valid)
|
||||
if self.start > other.end:
|
||||
return intervalMembership(False, valid)
|
||||
return intervalMembership(None, valid)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
if self.start > other:
|
||||
return intervalMembership(True, self.is_valid)
|
||||
elif self.end < other:
|
||||
return intervalMembership(False, self.is_valid)
|
||||
else:
|
||||
return intervalMembership(None, self.is_valid)
|
||||
elif isinstance(other, interval):
|
||||
return other.__lt__(self)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
if self.start == other and self.end == other:
|
||||
return intervalMembership(True, self.is_valid)
|
||||
if other in self:
|
||||
return intervalMembership(None, self.is_valid)
|
||||
else:
|
||||
return intervalMembership(False, self.is_valid)
|
||||
|
||||
if isinstance(other, interval):
|
||||
valid = fuzzy_and([self.is_valid, other.is_valid])
|
||||
if self.start == other.start and self.end == other.end:
|
||||
return intervalMembership(True, valid)
|
||||
elif self.__lt__(other)[0] is not None:
|
||||
return intervalMembership(False, valid)
|
||||
else:
|
||||
return intervalMembership(None, valid)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
if self.start == other and self.end == other:
|
||||
return intervalMembership(False, self.is_valid)
|
||||
if other in self:
|
||||
return intervalMembership(None, self.is_valid)
|
||||
else:
|
||||
return intervalMembership(True, self.is_valid)
|
||||
|
||||
if isinstance(other, interval):
|
||||
valid = fuzzy_and([self.is_valid, other.is_valid])
|
||||
if self.start == other.start and self.end == other.end:
|
||||
return intervalMembership(False, valid)
|
||||
if not self.__lt__(other)[0] is None:
|
||||
return intervalMembership(True, valid)
|
||||
return intervalMembership(None, valid)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
if self.end <= other:
|
||||
return intervalMembership(True, self.is_valid)
|
||||
if self.start > other:
|
||||
return intervalMembership(False, self.is_valid)
|
||||
else:
|
||||
return intervalMembership(None, self.is_valid)
|
||||
|
||||
if isinstance(other, interval):
|
||||
valid = fuzzy_and([self.is_valid, other.is_valid])
|
||||
if self.end <= other.start:
|
||||
return intervalMembership(True, valid)
|
||||
if self.start > other.end:
|
||||
return intervalMembership(False, valid)
|
||||
return intervalMembership(None, valid)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __ge__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
if self.start >= other:
|
||||
return intervalMembership(True, self.is_valid)
|
||||
elif self.end < other:
|
||||
return intervalMembership(False, self.is_valid)
|
||||
else:
|
||||
return intervalMembership(None, self.is_valid)
|
||||
elif isinstance(other, interval):
|
||||
return other.__le__(self)
|
||||
|
||||
def __add__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
if self.is_valid:
|
||||
return interval(self.start + other, self.end + other)
|
||||
else:
|
||||
start = self.start + other
|
||||
end = self.end + other
|
||||
return interval(start, end, is_valid=self.is_valid)
|
||||
|
||||
elif isinstance(other, interval):
|
||||
start = self.start + other.start
|
||||
end = self.end + other.end
|
||||
valid = fuzzy_and([self.is_valid, other.is_valid])
|
||||
return interval(start, end, is_valid=valid)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
__radd__ = __add__
|
||||
|
||||
def __sub__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
start = self.start - other
|
||||
end = self.end - other
|
||||
return interval(start, end, is_valid=self.is_valid)
|
||||
|
||||
elif isinstance(other, interval):
|
||||
start = self.start - other.end
|
||||
end = self.end - other.start
|
||||
valid = fuzzy_and([self.is_valid, other.is_valid])
|
||||
return interval(start, end, is_valid=valid)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __rsub__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
start = other - self.end
|
||||
end = other - self.start
|
||||
return interval(start, end, is_valid=self.is_valid)
|
||||
elif isinstance(other, interval):
|
||||
return other.__sub__(self)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __neg__(self):
|
||||
if self.is_valid:
|
||||
return interval(-self.end, -self.start)
|
||||
else:
|
||||
return interval(-self.end, -self.start, is_valid=self.is_valid)
|
||||
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, interval):
|
||||
if self.is_valid is False or other.is_valid is False:
|
||||
return interval(-float('inf'), float('inf'), is_valid=False)
|
||||
elif self.is_valid is None or other.is_valid is None:
|
||||
return interval(-float('inf'), float('inf'), is_valid=None)
|
||||
else:
|
||||
inters = []
|
||||
inters.append(self.start * other.start)
|
||||
inters.append(self.end * other.start)
|
||||
inters.append(self.start * other.end)
|
||||
inters.append(self.end * other.end)
|
||||
start = min(inters)
|
||||
end = max(inters)
|
||||
return interval(start, end)
|
||||
elif isinstance(other, (int, float)):
|
||||
return interval(self.start*other, self.end*other, is_valid=self.is_valid)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
__rmul__ = __mul__
|
||||
|
||||
def __contains__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
return self.start <= other and self.end >= other
|
||||
else:
|
||||
return self.start <= other.start and other.end <= self.end
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
other = interval(other)
|
||||
return other.__truediv__(self)
|
||||
elif isinstance(other, interval):
|
||||
return other.__truediv__(self)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __truediv__(self, other):
|
||||
# Both None and False are handled
|
||||
if not self.is_valid:
|
||||
# Don't divide as the value is not valid
|
||||
return interval(-float('inf'), float('inf'), is_valid=self.is_valid)
|
||||
if isinstance(other, (int, float)):
|
||||
if other == 0:
|
||||
# Divide by zero encountered. valid nowhere
|
||||
return interval(-float('inf'), float('inf'), is_valid=False)
|
||||
else:
|
||||
return interval(self.start / other, self.end / other)
|
||||
|
||||
elif isinstance(other, interval):
|
||||
if other.is_valid is False or self.is_valid is False:
|
||||
return interval(-float('inf'), float('inf'), is_valid=False)
|
||||
elif other.is_valid is None or self.is_valid is None:
|
||||
return interval(-float('inf'), float('inf'), is_valid=None)
|
||||
else:
|
||||
# denominator contains both signs, i.e. being divided by zero
|
||||
# return the whole real line with is_valid = None
|
||||
if 0 in other:
|
||||
return interval(-float('inf'), float('inf'), is_valid=None)
|
||||
|
||||
# denominator negative
|
||||
this = self
|
||||
if other.end < 0:
|
||||
this = -this
|
||||
other = -other
|
||||
|
||||
# denominator positive
|
||||
inters = []
|
||||
inters.append(this.start / other.start)
|
||||
inters.append(this.end / other.start)
|
||||
inters.append(this.start / other.end)
|
||||
inters.append(this.end / other.end)
|
||||
start = max(inters)
|
||||
end = min(inters)
|
||||
return interval(start, end)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __pow__(self, other):
|
||||
# Implements only power to an integer.
|
||||
from .lib_interval import exp, log
|
||||
if not self.is_valid:
|
||||
return self
|
||||
if isinstance(other, interval):
|
||||
return exp(other * log(self))
|
||||
elif isinstance(other, (float, int)):
|
||||
if other < 0:
|
||||
return 1 / self.__pow__(abs(other))
|
||||
else:
|
||||
if int_valued(other):
|
||||
return _pow_int(self, other)
|
||||
else:
|
||||
return _pow_float(self, other)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __rpow__(self, other):
|
||||
if isinstance(other, (float, int)):
|
||||
if not self.is_valid:
|
||||
#Don't do anything
|
||||
return self
|
||||
elif other < 0:
|
||||
if self.width > 0:
|
||||
return interval(-float('inf'), float('inf'), is_valid=False)
|
||||
else:
|
||||
power_rational = nsimplify(self.start)
|
||||
num, denom = power_rational.as_numer_denom()
|
||||
if denom % 2 == 0:
|
||||
return interval(-float('inf'), float('inf'),
|
||||
is_valid=False)
|
||||
else:
|
||||
start = -abs(other)**self.start
|
||||
end = start
|
||||
return interval(start, end)
|
||||
else:
|
||||
return interval(other**self.start, other**self.end)
|
||||
elif isinstance(other, interval):
|
||||
return other.__pow__(self)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.is_valid, self.start, self.end))
|
||||
|
||||
|
||||
def _pow_float(inter, power):
|
||||
"""Evaluates an interval raised to a floating point."""
|
||||
power_rational = nsimplify(power)
|
||||
num, denom = power_rational.as_numer_denom()
|
||||
if num % 2 == 0:
|
||||
start = abs(inter.start)**power
|
||||
end = abs(inter.end)**power
|
||||
if start < 0:
|
||||
ret = interval(0, max(start, end))
|
||||
else:
|
||||
ret = interval(start, end)
|
||||
return ret
|
||||
elif denom % 2 == 0:
|
||||
if inter.end < 0:
|
||||
return interval(-float('inf'), float('inf'), is_valid=False)
|
||||
elif inter.start < 0:
|
||||
return interval(0, inter.end**power, is_valid=None)
|
||||
else:
|
||||
return interval(inter.start**power, inter.end**power)
|
||||
else:
|
||||
if inter.start < 0:
|
||||
start = -abs(inter.start)**power
|
||||
else:
|
||||
start = inter.start**power
|
||||
|
||||
if inter.end < 0:
|
||||
end = -abs(inter.end)**power
|
||||
else:
|
||||
end = inter.end**power
|
||||
|
||||
return interval(start, end, is_valid=inter.is_valid)
|
||||
|
||||
|
||||
def _pow_int(inter, power):
|
||||
"""Evaluates an interval raised to an integer power"""
|
||||
power = int(power)
|
||||
if power & 1:
|
||||
return interval(inter.start**power, inter.end**power)
|
||||
else:
|
||||
if inter.start < 0 and inter.end > 0:
|
||||
start = 0
|
||||
end = max(inter.start**power, inter.end**power)
|
||||
return interval(start, end)
|
||||
else:
|
||||
return interval(inter.start**power, inter.end**power)
|
||||
@@ -0,0 +1,78 @@
|
||||
from sympy.core.logic import fuzzy_and, fuzzy_or, fuzzy_not, fuzzy_xor
|
||||
|
||||
|
||||
class intervalMembership:
|
||||
"""Represents a boolean expression returned by the comparison of
|
||||
the interval object.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
(a, b) : (bool, bool)
|
||||
The first value determines the comparison as follows:
|
||||
- True: If the comparison is True throughout the intervals.
|
||||
- False: If the comparison is False throughout the intervals.
|
||||
- None: If the comparison is True for some part of the intervals.
|
||||
|
||||
The second value is determined as follows:
|
||||
- True: If both the intervals in comparison are valid.
|
||||
- False: If at least one of the intervals is False, else
|
||||
- None
|
||||
"""
|
||||
def __init__(self, a, b):
|
||||
self._wrapped = (a, b)
|
||||
|
||||
def __getitem__(self, i):
|
||||
try:
|
||||
return self._wrapped[i]
|
||||
except IndexError:
|
||||
raise IndexError(
|
||||
"{} must be a valid indexing for the 2-tuple."
|
||||
.format(i))
|
||||
|
||||
def __len__(self):
|
||||
return 2
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._wrapped)
|
||||
|
||||
def __str__(self):
|
||||
return "intervalMembership({}, {})".format(*self)
|
||||
__repr__ = __str__
|
||||
|
||||
def __and__(self, other):
|
||||
if not isinstance(other, intervalMembership):
|
||||
raise ValueError(
|
||||
"The comparison is not supported for {}.".format(other))
|
||||
|
||||
a1, b1 = self
|
||||
a2, b2 = other
|
||||
return intervalMembership(fuzzy_and([a1, a2]), fuzzy_and([b1, b2]))
|
||||
|
||||
def __or__(self, other):
|
||||
if not isinstance(other, intervalMembership):
|
||||
raise ValueError(
|
||||
"The comparison is not supported for {}.".format(other))
|
||||
|
||||
a1, b1 = self
|
||||
a2, b2 = other
|
||||
return intervalMembership(fuzzy_or([a1, a2]), fuzzy_and([b1, b2]))
|
||||
|
||||
def __invert__(self):
|
||||
a, b = self
|
||||
return intervalMembership(fuzzy_not(a), b)
|
||||
|
||||
def __xor__(self, other):
|
||||
if not isinstance(other, intervalMembership):
|
||||
raise ValueError(
|
||||
"The comparison is not supported for {}.".format(other))
|
||||
|
||||
a1, b1 = self
|
||||
a2, b2 = other
|
||||
return intervalMembership(fuzzy_xor([a1, a2]), fuzzy_and([b1, b2]))
|
||||
|
||||
def __eq__(self, other):
|
||||
return self._wrapped == other
|
||||
|
||||
def __ne__(self, other):
|
||||
return self._wrapped != other
|
||||
@@ -0,0 +1,452 @@
|
||||
""" The module contains implemented functions for interval arithmetic."""
|
||||
from functools import reduce
|
||||
|
||||
from sympy.plotting.intervalmath import interval
|
||||
from sympy.external import import_module
|
||||
|
||||
|
||||
def Abs(x):
|
||||
if isinstance(x, (int, float)):
|
||||
return interval(abs(x))
|
||||
elif isinstance(x, interval):
|
||||
if x.start < 0 and x.end > 0:
|
||||
return interval(0, max(abs(x.start), abs(x.end)), is_valid=x.is_valid)
|
||||
else:
|
||||
return interval(abs(x.start), abs(x.end))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
#Monotonic
|
||||
|
||||
|
||||
def exp(x):
|
||||
"""evaluates the exponential of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
return interval(np.exp(x), np.exp(x))
|
||||
elif isinstance(x, interval):
|
||||
return interval(np.exp(x.start), np.exp(x.end), is_valid=x.is_valid)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
#Monotonic
|
||||
def log(x):
|
||||
"""evaluates the natural logarithm of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
if x <= 0:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
else:
|
||||
return interval(np.log(x))
|
||||
elif isinstance(x, interval):
|
||||
if not x.is_valid:
|
||||
return interval(-np.inf, np.inf, is_valid=x.is_valid)
|
||||
elif x.end <= 0:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
elif x.start <= 0:
|
||||
return interval(-np.inf, np.inf, is_valid=None)
|
||||
|
||||
return interval(np.log(x.start), np.log(x.end))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
#Monotonic
|
||||
def log10(x):
|
||||
"""evaluates the logarithm to the base 10 of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
if x <= 0:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
else:
|
||||
return interval(np.log10(x))
|
||||
elif isinstance(x, interval):
|
||||
if not x.is_valid:
|
||||
return interval(-np.inf, np.inf, is_valid=x.is_valid)
|
||||
elif x.end <= 0:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
elif x.start <= 0:
|
||||
return interval(-np.inf, np.inf, is_valid=None)
|
||||
return interval(np.log10(x.start), np.log10(x.end))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
#Monotonic
|
||||
def atan(x):
|
||||
"""evaluates the tan inverse of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
return interval(np.arctan(x))
|
||||
elif isinstance(x, interval):
|
||||
start = np.arctan(x.start)
|
||||
end = np.arctan(x.end)
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
#periodic
|
||||
def sin(x):
|
||||
"""evaluates the sine of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
return interval(np.sin(x))
|
||||
elif isinstance(x, interval):
|
||||
if not x.is_valid:
|
||||
return interval(-1, 1, is_valid=x.is_valid)
|
||||
na, __ = divmod(x.start, np.pi / 2.0)
|
||||
nb, __ = divmod(x.end, np.pi / 2.0)
|
||||
start = min(np.sin(x.start), np.sin(x.end))
|
||||
end = max(np.sin(x.start), np.sin(x.end))
|
||||
if nb - na > 4:
|
||||
return interval(-1, 1, is_valid=x.is_valid)
|
||||
elif na == nb:
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
else:
|
||||
if (na - 1) // 4 != (nb - 1) // 4:
|
||||
#sin has max
|
||||
end = 1
|
||||
if (na - 3) // 4 != (nb - 3) // 4:
|
||||
#sin has min
|
||||
start = -1
|
||||
return interval(start, end)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
#periodic
|
||||
def cos(x):
|
||||
"""Evaluates the cos of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
return interval(np.sin(x))
|
||||
elif isinstance(x, interval):
|
||||
if not (np.isfinite(x.start) and np.isfinite(x.end)):
|
||||
return interval(-1, 1, is_valid=x.is_valid)
|
||||
na, __ = divmod(x.start, np.pi / 2.0)
|
||||
nb, __ = divmod(x.end, np.pi / 2.0)
|
||||
start = min(np.cos(x.start), np.cos(x.end))
|
||||
end = max(np.cos(x.start), np.cos(x.end))
|
||||
if nb - na > 4:
|
||||
#differ more than 2*pi
|
||||
return interval(-1, 1, is_valid=x.is_valid)
|
||||
elif na == nb:
|
||||
#in the same quadarant
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
else:
|
||||
if (na) // 4 != (nb) // 4:
|
||||
#cos has max
|
||||
end = 1
|
||||
if (na - 2) // 4 != (nb - 2) // 4:
|
||||
#cos has min
|
||||
start = -1
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def tan(x):
|
||||
"""Evaluates the tan of an interval"""
|
||||
return sin(x) / cos(x)
|
||||
|
||||
|
||||
#Monotonic
|
||||
def sqrt(x):
|
||||
"""Evaluates the square root of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
if x > 0:
|
||||
return interval(np.sqrt(x))
|
||||
else:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
elif isinstance(x, interval):
|
||||
#Outside the domain
|
||||
if x.end < 0:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
#Partially outside the domain
|
||||
elif x.start < 0:
|
||||
return interval(-np.inf, np.inf, is_valid=None)
|
||||
else:
|
||||
return interval(np.sqrt(x.start), np.sqrt(x.end),
|
||||
is_valid=x.is_valid)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def imin(*args):
|
||||
"""Evaluates the minimum of a list of intervals"""
|
||||
np = import_module('numpy')
|
||||
if not all(isinstance(arg, (int, float, interval)) for arg in args):
|
||||
return NotImplementedError
|
||||
else:
|
||||
new_args = [a for a in args if isinstance(a, (int, float))
|
||||
or a.is_valid]
|
||||
if len(new_args) == 0:
|
||||
if all(a.is_valid is False for a in args):
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
else:
|
||||
return interval(-np.inf, np.inf, is_valid=None)
|
||||
start_array = [a if isinstance(a, (int, float)) else a.start
|
||||
for a in new_args]
|
||||
|
||||
end_array = [a if isinstance(a, (int, float)) else a.end
|
||||
for a in new_args]
|
||||
return interval(min(start_array), min(end_array))
|
||||
|
||||
|
||||
def imax(*args):
|
||||
"""Evaluates the maximum of a list of intervals"""
|
||||
np = import_module('numpy')
|
||||
if not all(isinstance(arg, (int, float, interval)) for arg in args):
|
||||
return NotImplementedError
|
||||
else:
|
||||
new_args = [a for a in args if isinstance(a, (int, float))
|
||||
or a.is_valid]
|
||||
if len(new_args) == 0:
|
||||
if all(a.is_valid is False for a in args):
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
else:
|
||||
return interval(-np.inf, np.inf, is_valid=None)
|
||||
start_array = [a if isinstance(a, (int, float)) else a.start
|
||||
for a in new_args]
|
||||
|
||||
end_array = [a if isinstance(a, (int, float)) else a.end
|
||||
for a in new_args]
|
||||
|
||||
return interval(max(start_array), max(end_array))
|
||||
|
||||
|
||||
#Monotonic
|
||||
def sinh(x):
|
||||
"""Evaluates the hyperbolic sine of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
return interval(np.sinh(x), np.sinh(x))
|
||||
elif isinstance(x, interval):
|
||||
return interval(np.sinh(x.start), np.sinh(x.end), is_valid=x.is_valid)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def cosh(x):
|
||||
"""Evaluates the hyperbolic cos of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
return interval(np.cosh(x), np.cosh(x))
|
||||
elif isinstance(x, interval):
|
||||
#both signs
|
||||
if x.start < 0 and x.end > 0:
|
||||
end = max(np.cosh(x.start), np.cosh(x.end))
|
||||
return interval(1, end, is_valid=x.is_valid)
|
||||
else:
|
||||
#Monotonic
|
||||
start = np.cosh(x.start)
|
||||
end = np.cosh(x.end)
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
#Monotonic
|
||||
def tanh(x):
|
||||
"""Evaluates the hyperbolic tan of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
return interval(np.tanh(x), np.tanh(x))
|
||||
elif isinstance(x, interval):
|
||||
return interval(np.tanh(x.start), np.tanh(x.end), is_valid=x.is_valid)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def asin(x):
|
||||
"""Evaluates the inverse sine of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
#Outside the domain
|
||||
if abs(x) > 1:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
else:
|
||||
return interval(np.arcsin(x), np.arcsin(x))
|
||||
elif isinstance(x, interval):
|
||||
#Outside the domain
|
||||
if x.is_valid is False or x.start > 1 or x.end < -1:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
#Partially outside the domain
|
||||
elif x.start < -1 or x.end > 1:
|
||||
return interval(-np.inf, np.inf, is_valid=None)
|
||||
else:
|
||||
start = np.arcsin(x.start)
|
||||
end = np.arcsin(x.end)
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
|
||||
|
||||
def acos(x):
|
||||
"""Evaluates the inverse cos of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
if abs(x) > 1:
|
||||
#Outside the domain
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
else:
|
||||
return interval(np.arccos(x), np.arccos(x))
|
||||
elif isinstance(x, interval):
|
||||
#Outside the domain
|
||||
if x.is_valid is False or x.start > 1 or x.end < -1:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
#Partially outside the domain
|
||||
elif x.start < -1 or x.end > 1:
|
||||
return interval(-np.inf, np.inf, is_valid=None)
|
||||
else:
|
||||
start = np.arccos(x.start)
|
||||
end = np.arccos(x.end)
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
|
||||
|
||||
def ceil(x):
|
||||
"""Evaluates the ceiling of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
return interval(np.ceil(x))
|
||||
elif isinstance(x, interval):
|
||||
if x.is_valid is False:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
else:
|
||||
start = np.ceil(x.start)
|
||||
end = np.ceil(x.end)
|
||||
#Continuous over the interval
|
||||
if start == end:
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
else:
|
||||
#Not continuous over the interval
|
||||
return interval(start, end, is_valid=None)
|
||||
else:
|
||||
return NotImplementedError
|
||||
|
||||
|
||||
def floor(x):
|
||||
"""Evaluates the floor of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
return interval(np.floor(x))
|
||||
elif isinstance(x, interval):
|
||||
if x.is_valid is False:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
else:
|
||||
start = np.floor(x.start)
|
||||
end = np.floor(x.end)
|
||||
#continuous over the argument
|
||||
if start == end:
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
else:
|
||||
#not continuous over the interval
|
||||
return interval(start, end, is_valid=None)
|
||||
else:
|
||||
return NotImplementedError
|
||||
|
||||
|
||||
def acosh(x):
|
||||
"""Evaluates the inverse hyperbolic cosine of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
#Outside the domain
|
||||
if x < 1:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
else:
|
||||
return interval(np.arccosh(x))
|
||||
elif isinstance(x, interval):
|
||||
#Outside the domain
|
||||
if x.end < 1:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
#Partly outside the domain
|
||||
elif x.start < 1:
|
||||
return interval(-np.inf, np.inf, is_valid=None)
|
||||
else:
|
||||
start = np.arccosh(x.start)
|
||||
end = np.arccosh(x.end)
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
else:
|
||||
return NotImplementedError
|
||||
|
||||
|
||||
#Monotonic
|
||||
def asinh(x):
|
||||
"""Evaluates the inverse hyperbolic sine of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
return interval(np.arcsinh(x))
|
||||
elif isinstance(x, interval):
|
||||
start = np.arcsinh(x.start)
|
||||
end = np.arcsinh(x.end)
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
else:
|
||||
return NotImplementedError
|
||||
|
||||
|
||||
def atanh(x):
|
||||
"""Evaluates the inverse hyperbolic tangent of an interval"""
|
||||
np = import_module('numpy')
|
||||
if isinstance(x, (int, float)):
|
||||
#Outside the domain
|
||||
if abs(x) >= 1:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
else:
|
||||
return interval(np.arctanh(x))
|
||||
elif isinstance(x, interval):
|
||||
#outside the domain
|
||||
if x.is_valid is False or x.start >= 1 or x.end <= -1:
|
||||
return interval(-np.inf, np.inf, is_valid=False)
|
||||
#partly outside the domain
|
||||
elif x.start <= -1 or x.end >= 1:
|
||||
return interval(-np.inf, np.inf, is_valid=None)
|
||||
else:
|
||||
start = np.arctanh(x.start)
|
||||
end = np.arctanh(x.end)
|
||||
return interval(start, end, is_valid=x.is_valid)
|
||||
else:
|
||||
return NotImplementedError
|
||||
|
||||
|
||||
#Three valued logic for interval plotting.
|
||||
|
||||
def And(*args):
|
||||
"""Defines the three valued ``And`` behaviour for a 2-tuple of
|
||||
three valued logic values"""
|
||||
def reduce_and(cmp_intervala, cmp_intervalb):
|
||||
if cmp_intervala[0] is False or cmp_intervalb[0] is False:
|
||||
first = False
|
||||
elif cmp_intervala[0] is None or cmp_intervalb[0] is None:
|
||||
first = None
|
||||
else:
|
||||
first = True
|
||||
if cmp_intervala[1] is False or cmp_intervalb[1] is False:
|
||||
second = False
|
||||
elif cmp_intervala[1] is None or cmp_intervalb[1] is None:
|
||||
second = None
|
||||
else:
|
||||
second = True
|
||||
return (first, second)
|
||||
return reduce(reduce_and, args)
|
||||
|
||||
|
||||
def Or(*args):
|
||||
"""Defines the three valued ``Or`` behaviour for a 2-tuple of
|
||||
three valued logic values"""
|
||||
def reduce_or(cmp_intervala, cmp_intervalb):
|
||||
if cmp_intervala[0] is True or cmp_intervalb[0] is True:
|
||||
first = True
|
||||
elif cmp_intervala[0] is None or cmp_intervalb[0] is None:
|
||||
first = None
|
||||
else:
|
||||
first = False
|
||||
|
||||
if cmp_intervala[1] is True or cmp_intervalb[1] is True:
|
||||
second = True
|
||||
elif cmp_intervala[1] is None or cmp_intervalb[1] is None:
|
||||
second = None
|
||||
else:
|
||||
second = False
|
||||
return (first, second)
|
||||
return reduce(reduce_or, args)
|
||||
@@ -0,0 +1,415 @@
|
||||
from sympy.external import import_module
|
||||
from sympy.plotting.intervalmath import (
|
||||
Abs, acos, acosh, And, asin, asinh, atan, atanh, ceil, cos, cosh,
|
||||
exp, floor, imax, imin, interval, log, log10, Or, sin, sinh, sqrt,
|
||||
tan, tanh,
|
||||
)
|
||||
|
||||
np = import_module('numpy')
|
||||
if not np:
|
||||
disabled = True
|
||||
|
||||
|
||||
#requires Numpy. Hence included in interval_functions
|
||||
|
||||
|
||||
def test_interval_pow():
|
||||
a = 2**interval(1, 2) == interval(2, 4)
|
||||
assert a == (True, True)
|
||||
a = interval(1, 2)**interval(1, 2) == interval(1, 4)
|
||||
assert a == (True, True)
|
||||
a = interval(-1, 1)**interval(0.5, 2)
|
||||
assert a.is_valid is None
|
||||
a = interval(-2, -1) ** interval(1, 2)
|
||||
assert a.is_valid is False
|
||||
a = interval(-2, -1) ** (1.0 / 2)
|
||||
assert a.is_valid is False
|
||||
a = interval(-1, 1)**(1.0 / 2)
|
||||
assert a.is_valid is None
|
||||
a = interval(-1, 1)**(1.0 / 3) == interval(-1, 1)
|
||||
assert a == (True, True)
|
||||
a = interval(-1, 1)**2 == interval(0, 1)
|
||||
assert a == (True, True)
|
||||
a = interval(-1, 1) ** (1.0 / 29) == interval(-1, 1)
|
||||
assert a == (True, True)
|
||||
a = -2**interval(1, 1) == interval(-2, -2)
|
||||
assert a == (True, True)
|
||||
|
||||
a = interval(1, 2, is_valid=False)**2
|
||||
assert a.is_valid is False
|
||||
|
||||
a = (-3)**interval(1, 2)
|
||||
assert a.is_valid is False
|
||||
a = (-4)**interval(0.5, 0.5)
|
||||
assert a.is_valid is False
|
||||
assert ((-3)**interval(1, 1) == interval(-3, -3)) == (True, True)
|
||||
|
||||
a = interval(8, 64)**(2.0 / 3)
|
||||
assert abs(a.start - 4) < 1e-10 # eps
|
||||
assert abs(a.end - 16) < 1e-10
|
||||
a = interval(-8, 64)**(2.0 / 3)
|
||||
assert abs(a.start - 4) < 1e-10 # eps
|
||||
assert abs(a.end - 16) < 1e-10
|
||||
|
||||
|
||||
def test_exp():
|
||||
a = exp(interval(-np.inf, 0))
|
||||
assert a.start == np.exp(-np.inf)
|
||||
assert a.end == np.exp(0)
|
||||
a = exp(interval(1, 2))
|
||||
assert a.start == np.exp(1)
|
||||
assert a.end == np.exp(2)
|
||||
a = exp(1)
|
||||
assert a.start == np.exp(1)
|
||||
assert a.end == np.exp(1)
|
||||
|
||||
|
||||
def test_log():
|
||||
a = log(interval(1, 2))
|
||||
assert a.start == 0
|
||||
assert a.end == np.log(2)
|
||||
a = log(interval(-1, 1))
|
||||
assert a.is_valid is None
|
||||
a = log(interval(-3, -1))
|
||||
assert a.is_valid is False
|
||||
a = log(-3)
|
||||
assert a.is_valid is False
|
||||
a = log(2)
|
||||
assert a.start == np.log(2)
|
||||
assert a.end == np.log(2)
|
||||
|
||||
|
||||
def test_log10():
|
||||
a = log10(interval(1, 2))
|
||||
assert a.start == 0
|
||||
assert a.end == np.log10(2)
|
||||
a = log10(interval(-1, 1))
|
||||
assert a.is_valid is None
|
||||
a = log10(interval(-3, -1))
|
||||
assert a.is_valid is False
|
||||
a = log10(-3)
|
||||
assert a.is_valid is False
|
||||
a = log10(2)
|
||||
assert a.start == np.log10(2)
|
||||
assert a.end == np.log10(2)
|
||||
|
||||
|
||||
def test_atan():
|
||||
a = atan(interval(0, 1))
|
||||
assert a.start == np.arctan(0)
|
||||
assert a.end == np.arctan(1)
|
||||
a = atan(1)
|
||||
assert a.start == np.arctan(1)
|
||||
assert a.end == np.arctan(1)
|
||||
|
||||
|
||||
def test_sin():
|
||||
a = sin(interval(0, np.pi / 4))
|
||||
assert a.start == np.sin(0)
|
||||
assert a.end == np.sin(np.pi / 4)
|
||||
|
||||
a = sin(interval(-np.pi / 4, np.pi / 4))
|
||||
assert a.start == np.sin(-np.pi / 4)
|
||||
assert a.end == np.sin(np.pi / 4)
|
||||
|
||||
a = sin(interval(np.pi / 4, 3 * np.pi / 4))
|
||||
assert a.start == np.sin(np.pi / 4)
|
||||
assert a.end == 1
|
||||
|
||||
a = sin(interval(7 * np.pi / 6, 7 * np.pi / 4))
|
||||
assert a.start == -1
|
||||
assert a.end == np.sin(7 * np.pi / 6)
|
||||
|
||||
a = sin(interval(0, 3 * np.pi))
|
||||
assert a.start == -1
|
||||
assert a.end == 1
|
||||
|
||||
a = sin(interval(np.pi / 3, 7 * np.pi / 4))
|
||||
assert a.start == -1
|
||||
assert a.end == 1
|
||||
|
||||
a = sin(np.pi / 4)
|
||||
assert a.start == np.sin(np.pi / 4)
|
||||
assert a.end == np.sin(np.pi / 4)
|
||||
|
||||
a = sin(interval(1, 2, is_valid=False))
|
||||
assert a.is_valid is False
|
||||
|
||||
|
||||
def test_cos():
|
||||
a = cos(interval(0, np.pi / 4))
|
||||
assert a.start == np.cos(np.pi / 4)
|
||||
assert a.end == 1
|
||||
|
||||
a = cos(interval(-np.pi / 4, np.pi / 4))
|
||||
assert a.start == np.cos(-np.pi / 4)
|
||||
assert a.end == 1
|
||||
|
||||
a = cos(interval(np.pi / 4, 3 * np.pi / 4))
|
||||
assert a.start == np.cos(3 * np.pi / 4)
|
||||
assert a.end == np.cos(np.pi / 4)
|
||||
|
||||
a = cos(interval(3 * np.pi / 4, 5 * np.pi / 4))
|
||||
assert a.start == -1
|
||||
assert a.end == np.cos(3 * np.pi / 4)
|
||||
|
||||
a = cos(interval(0, 3 * np.pi))
|
||||
assert a.start == -1
|
||||
assert a.end == 1
|
||||
|
||||
a = cos(interval(- np.pi / 3, 5 * np.pi / 4))
|
||||
assert a.start == -1
|
||||
assert a.end == 1
|
||||
|
||||
a = cos(interval(1, 2, is_valid=False))
|
||||
assert a.is_valid is False
|
||||
|
||||
|
||||
def test_tan():
|
||||
a = tan(interval(0, np.pi / 4))
|
||||
assert a.start == 0
|
||||
# must match lib_interval definition of tan:
|
||||
assert a.end == np.sin(np.pi / 4)/np.cos(np.pi / 4)
|
||||
|
||||
a = tan(interval(np.pi / 4, 3 * np.pi / 4))
|
||||
#discontinuity
|
||||
assert a.is_valid is None
|
||||
|
||||
|
||||
def test_sqrt():
|
||||
a = sqrt(interval(1, 4))
|
||||
assert a.start == 1
|
||||
assert a.end == 2
|
||||
|
||||
a = sqrt(interval(0.01, 1))
|
||||
assert a.start == np.sqrt(0.01)
|
||||
assert a.end == 1
|
||||
|
||||
a = sqrt(interval(-1, 1))
|
||||
assert a.is_valid is None
|
||||
|
||||
a = sqrt(interval(-3, -1))
|
||||
assert a.is_valid is False
|
||||
|
||||
a = sqrt(4)
|
||||
assert (a == interval(2, 2)) == (True, True)
|
||||
|
||||
a = sqrt(-3)
|
||||
assert a.is_valid is False
|
||||
|
||||
|
||||
def test_imin():
|
||||
a = imin(interval(1, 3), interval(2, 5), interval(-1, 3))
|
||||
assert a.start == -1
|
||||
assert a.end == 3
|
||||
|
||||
a = imin(-2, interval(1, 4))
|
||||
assert a.start == -2
|
||||
assert a.end == -2
|
||||
|
||||
a = imin(5, interval(3, 4), interval(-2, 2, is_valid=False))
|
||||
assert a.start == 3
|
||||
assert a.end == 4
|
||||
|
||||
|
||||
def test_imax():
|
||||
a = imax(interval(-2, 2), interval(2, 7), interval(-3, 9))
|
||||
assert a.start == 2
|
||||
assert a.end == 9
|
||||
|
||||
a = imax(8, interval(1, 4))
|
||||
assert a.start == 8
|
||||
assert a.end == 8
|
||||
|
||||
a = imax(interval(1, 2), interval(3, 4), interval(-2, 2, is_valid=False))
|
||||
assert a.start == 3
|
||||
assert a.end == 4
|
||||
|
||||
|
||||
def test_sinh():
|
||||
a = sinh(interval(-1, 1))
|
||||
assert a.start == np.sinh(-1)
|
||||
assert a.end == np.sinh(1)
|
||||
|
||||
a = sinh(1)
|
||||
assert a.start == np.sinh(1)
|
||||
assert a.end == np.sinh(1)
|
||||
|
||||
|
||||
def test_cosh():
|
||||
a = cosh(interval(1, 2))
|
||||
assert a.start == np.cosh(1)
|
||||
assert a.end == np.cosh(2)
|
||||
a = cosh(interval(-2, -1))
|
||||
assert a.start == np.cosh(-1)
|
||||
assert a.end == np.cosh(-2)
|
||||
|
||||
a = cosh(interval(-2, 1))
|
||||
assert a.start == 1
|
||||
assert a.end == np.cosh(-2)
|
||||
|
||||
a = cosh(1)
|
||||
assert a.start == np.cosh(1)
|
||||
assert a.end == np.cosh(1)
|
||||
|
||||
|
||||
def test_tanh():
|
||||
a = tanh(interval(-3, 3))
|
||||
assert a.start == np.tanh(-3)
|
||||
assert a.end == np.tanh(3)
|
||||
|
||||
a = tanh(3)
|
||||
assert a.start == np.tanh(3)
|
||||
assert a.end == np.tanh(3)
|
||||
|
||||
|
||||
def test_asin():
|
||||
a = asin(interval(-0.5, 0.5))
|
||||
assert a.start == np.arcsin(-0.5)
|
||||
assert a.end == np.arcsin(0.5)
|
||||
|
||||
a = asin(interval(-1.5, 1.5))
|
||||
assert a.is_valid is None
|
||||
a = asin(interval(-2, -1.5))
|
||||
assert a.is_valid is False
|
||||
|
||||
a = asin(interval(0, 2))
|
||||
assert a.is_valid is None
|
||||
|
||||
a = asin(interval(2, 5))
|
||||
assert a.is_valid is False
|
||||
|
||||
a = asin(0.5)
|
||||
assert a.start == np.arcsin(0.5)
|
||||
assert a.end == np.arcsin(0.5)
|
||||
|
||||
a = asin(1.5)
|
||||
assert a.is_valid is False
|
||||
|
||||
|
||||
def test_acos():
|
||||
a = acos(interval(-0.5, 0.5))
|
||||
assert a.start == np.arccos(0.5)
|
||||
assert a.end == np.arccos(-0.5)
|
||||
|
||||
a = acos(interval(-1.5, 1.5))
|
||||
assert a.is_valid is None
|
||||
a = acos(interval(-2, -1.5))
|
||||
assert a.is_valid is False
|
||||
|
||||
a = acos(interval(0, 2))
|
||||
assert a.is_valid is None
|
||||
|
||||
a = acos(interval(2, 5))
|
||||
assert a.is_valid is False
|
||||
|
||||
a = acos(0.5)
|
||||
assert a.start == np.arccos(0.5)
|
||||
assert a.end == np.arccos(0.5)
|
||||
|
||||
a = acos(1.5)
|
||||
assert a.is_valid is False
|
||||
|
||||
|
||||
def test_ceil():
|
||||
a = ceil(interval(0.2, 0.5))
|
||||
assert a.start == 1
|
||||
assert a.end == 1
|
||||
|
||||
a = ceil(interval(0.5, 1.5))
|
||||
assert a.start == 1
|
||||
assert a.end == 2
|
||||
assert a.is_valid is None
|
||||
|
||||
a = ceil(interval(-5, 5))
|
||||
assert a.is_valid is None
|
||||
|
||||
a = ceil(5.4)
|
||||
assert a.start == 6
|
||||
assert a.end == 6
|
||||
|
||||
|
||||
def test_floor():
|
||||
a = floor(interval(0.2, 0.5))
|
||||
assert a.start == 0
|
||||
assert a.end == 0
|
||||
|
||||
a = floor(interval(0.5, 1.5))
|
||||
assert a.start == 0
|
||||
assert a.end == 1
|
||||
assert a.is_valid is None
|
||||
|
||||
a = floor(interval(-5, 5))
|
||||
assert a.is_valid is None
|
||||
|
||||
a = floor(5.4)
|
||||
assert a.start == 5
|
||||
assert a.end == 5
|
||||
|
||||
|
||||
def test_asinh():
|
||||
a = asinh(interval(1, 2))
|
||||
assert a.start == np.arcsinh(1)
|
||||
assert a.end == np.arcsinh(2)
|
||||
|
||||
a = asinh(0.5)
|
||||
assert a.start == np.arcsinh(0.5)
|
||||
assert a.end == np.arcsinh(0.5)
|
||||
|
||||
|
||||
def test_acosh():
|
||||
a = acosh(interval(3, 5))
|
||||
assert a.start == np.arccosh(3)
|
||||
assert a.end == np.arccosh(5)
|
||||
|
||||
a = acosh(interval(0, 3))
|
||||
assert a.is_valid is None
|
||||
a = acosh(interval(-3, 0.5))
|
||||
assert a.is_valid is False
|
||||
|
||||
a = acosh(0.5)
|
||||
assert a.is_valid is False
|
||||
|
||||
a = acosh(2)
|
||||
assert a.start == np.arccosh(2)
|
||||
assert a.end == np.arccosh(2)
|
||||
|
||||
|
||||
def test_atanh():
|
||||
a = atanh(interval(-0.5, 0.5))
|
||||
assert a.start == np.arctanh(-0.5)
|
||||
assert a.end == np.arctanh(0.5)
|
||||
|
||||
a = atanh(interval(0, 3))
|
||||
assert a.is_valid is None
|
||||
|
||||
a = atanh(interval(-3, -2))
|
||||
assert a.is_valid is False
|
||||
|
||||
a = atanh(0.5)
|
||||
assert a.start == np.arctanh(0.5)
|
||||
assert a.end == np.arctanh(0.5)
|
||||
|
||||
a = atanh(1.5)
|
||||
assert a.is_valid is False
|
||||
|
||||
|
||||
def test_Abs():
|
||||
assert (Abs(interval(-0.5, 0.5)) == interval(0, 0.5)) == (True, True)
|
||||
assert (Abs(interval(-3, -2)) == interval(2, 3)) == (True, True)
|
||||
assert (Abs(-3) == interval(3, 3)) == (True, True)
|
||||
|
||||
|
||||
def test_And():
|
||||
args = [(True, True), (True, False), (True, None)]
|
||||
assert And(*args) == (True, False)
|
||||
|
||||
args = [(False, True), (None, None), (True, True)]
|
||||
assert And(*args) == (False, None)
|
||||
|
||||
|
||||
def test_Or():
|
||||
args = [(True, True), (True, False), (False, None)]
|
||||
assert Or(*args) == (True, True)
|
||||
args = [(None, None), (False, None), (False, False)]
|
||||
assert Or(*args) == (None, None)
|
||||
@@ -0,0 +1,150 @@
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.plotting.intervalmath import interval
|
||||
from sympy.plotting.intervalmath.interval_membership import intervalMembership
|
||||
from sympy.plotting.experimental_lambdify import experimental_lambdify
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
|
||||
def test_creation():
|
||||
assert intervalMembership(True, True)
|
||||
raises(TypeError, lambda: intervalMembership(True))
|
||||
raises(TypeError, lambda: intervalMembership(True, True, True))
|
||||
|
||||
|
||||
def test_getitem():
|
||||
a = intervalMembership(True, False)
|
||||
assert a[0] is True
|
||||
assert a[1] is False
|
||||
raises(IndexError, lambda: a[2])
|
||||
|
||||
|
||||
def test_str():
|
||||
a = intervalMembership(True, False)
|
||||
assert str(a) == 'intervalMembership(True, False)'
|
||||
assert repr(a) == 'intervalMembership(True, False)'
|
||||
|
||||
|
||||
def test_equivalence():
|
||||
a = intervalMembership(True, True)
|
||||
b = intervalMembership(True, False)
|
||||
assert (a == b) is False
|
||||
assert (a != b) is True
|
||||
|
||||
a = intervalMembership(True, False)
|
||||
b = intervalMembership(True, False)
|
||||
assert (a == b) is True
|
||||
assert (a != b) is False
|
||||
|
||||
|
||||
def test_not():
|
||||
x = Symbol('x')
|
||||
|
||||
r1 = x > -1
|
||||
r2 = x <= -1
|
||||
|
||||
i = interval
|
||||
|
||||
f1 = experimental_lambdify((x,), r1)
|
||||
f2 = experimental_lambdify((x,), r2)
|
||||
|
||||
tt = i(-0.1, 0.1, is_valid=True)
|
||||
tn = i(-0.1, 0.1, is_valid=None)
|
||||
tf = i(-0.1, 0.1, is_valid=False)
|
||||
|
||||
assert f1(tt) == ~f2(tt)
|
||||
assert f1(tn) == ~f2(tn)
|
||||
assert f1(tf) == ~f2(tf)
|
||||
|
||||
nt = i(0.9, 1.1, is_valid=True)
|
||||
nn = i(0.9, 1.1, is_valid=None)
|
||||
nf = i(0.9, 1.1, is_valid=False)
|
||||
|
||||
assert f1(nt) == ~f2(nt)
|
||||
assert f1(nn) == ~f2(nn)
|
||||
assert f1(nf) == ~f2(nf)
|
||||
|
||||
ft = i(1.9, 2.1, is_valid=True)
|
||||
fn = i(1.9, 2.1, is_valid=None)
|
||||
ff = i(1.9, 2.1, is_valid=False)
|
||||
|
||||
assert f1(ft) == ~f2(ft)
|
||||
assert f1(fn) == ~f2(fn)
|
||||
assert f1(ff) == ~f2(ff)
|
||||
|
||||
|
||||
def test_boolean():
|
||||
# There can be 9*9 test cases in full mapping of the cartesian product.
|
||||
# But we only consider 3*3 cases for simplicity.
|
||||
s = [
|
||||
intervalMembership(False, False),
|
||||
intervalMembership(None, None),
|
||||
intervalMembership(True, True)
|
||||
]
|
||||
|
||||
# Reduced tests for 'And'
|
||||
a1 = [
|
||||
intervalMembership(False, False),
|
||||
intervalMembership(False, False),
|
||||
intervalMembership(False, False),
|
||||
intervalMembership(False, False),
|
||||
intervalMembership(None, None),
|
||||
intervalMembership(None, None),
|
||||
intervalMembership(False, False),
|
||||
intervalMembership(None, None),
|
||||
intervalMembership(True, True)
|
||||
]
|
||||
a1_iter = iter(a1)
|
||||
for i in range(len(s)):
|
||||
for j in range(len(s)):
|
||||
assert s[i] & s[j] == next(a1_iter)
|
||||
|
||||
# Reduced tests for 'Or'
|
||||
a1 = [
|
||||
intervalMembership(False, False),
|
||||
intervalMembership(None, False),
|
||||
intervalMembership(True, False),
|
||||
intervalMembership(None, False),
|
||||
intervalMembership(None, None),
|
||||
intervalMembership(True, None),
|
||||
intervalMembership(True, False),
|
||||
intervalMembership(True, None),
|
||||
intervalMembership(True, True)
|
||||
]
|
||||
a1_iter = iter(a1)
|
||||
for i in range(len(s)):
|
||||
for j in range(len(s)):
|
||||
assert s[i] | s[j] == next(a1_iter)
|
||||
|
||||
# Reduced tests for 'Xor'
|
||||
a1 = [
|
||||
intervalMembership(False, False),
|
||||
intervalMembership(None, False),
|
||||
intervalMembership(True, False),
|
||||
intervalMembership(None, False),
|
||||
intervalMembership(None, None),
|
||||
intervalMembership(None, None),
|
||||
intervalMembership(True, False),
|
||||
intervalMembership(None, None),
|
||||
intervalMembership(False, True)
|
||||
]
|
||||
a1_iter = iter(a1)
|
||||
for i in range(len(s)):
|
||||
for j in range(len(s)):
|
||||
assert s[i] ^ s[j] == next(a1_iter)
|
||||
|
||||
# Reduced tests for 'Not'
|
||||
a1 = [
|
||||
intervalMembership(True, False),
|
||||
intervalMembership(None, None),
|
||||
intervalMembership(False, True)
|
||||
]
|
||||
a1_iter = iter(a1)
|
||||
for i in range(len(s)):
|
||||
assert ~s[i] == next(a1_iter)
|
||||
|
||||
|
||||
def test_boolean_errors():
|
||||
a = intervalMembership(True, True)
|
||||
raises(ValueError, lambda: a & 1)
|
||||
raises(ValueError, lambda: a | 1)
|
||||
raises(ValueError, lambda: a ^ 1)
|
||||
@@ -0,0 +1,213 @@
|
||||
from sympy.plotting.intervalmath import interval
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
|
||||
def test_interval():
|
||||
assert (interval(1, 1) == interval(1, 1, is_valid=True)) == (True, True)
|
||||
assert (interval(1, 1) == interval(1, 1, is_valid=False)) == (True, False)
|
||||
assert (interval(1, 1) == interval(1, 1, is_valid=None)) == (True, None)
|
||||
assert (interval(1, 1.5) == interval(1, 2)) == (None, True)
|
||||
assert (interval(0, 1) == interval(2, 3)) == (False, True)
|
||||
assert (interval(0, 1) == interval(1, 2)) == (None, True)
|
||||
assert (interval(1, 2) != interval(1, 2)) == (False, True)
|
||||
assert (interval(1, 3) != interval(2, 3)) == (None, True)
|
||||
assert (interval(1, 3) != interval(-5, -3)) == (True, True)
|
||||
assert (
|
||||
interval(1, 3, is_valid=False) != interval(-5, -3)) == (True, False)
|
||||
assert (interval(1, 3, is_valid=None) != interval(-5, 3)) == (None, None)
|
||||
assert (interval(4, 4) != 4) == (False, True)
|
||||
assert (interval(1, 1) == 1) == (True, True)
|
||||
assert (interval(1, 3, is_valid=False) == interval(1, 3)) == (True, False)
|
||||
assert (interval(1, 3, is_valid=None) == interval(1, 3)) == (True, None)
|
||||
inter = interval(-5, 5)
|
||||
assert (interval(inter) == interval(-5, 5)) == (True, True)
|
||||
assert inter.width == 10
|
||||
assert 0 in inter
|
||||
assert -5 in inter
|
||||
assert 5 in inter
|
||||
assert interval(0, 3) in inter
|
||||
assert interval(-6, 2) not in inter
|
||||
assert -5.05 not in inter
|
||||
assert 5.3 not in inter
|
||||
interb = interval(-float('inf'), float('inf'))
|
||||
assert 0 in inter
|
||||
assert inter in interb
|
||||
assert interval(0, float('inf')) in interb
|
||||
assert interval(-float('inf'), 5) in interb
|
||||
assert interval(-1e50, 1e50) in interb
|
||||
assert (
|
||||
-interval(-1, -2, is_valid=False) == interval(1, 2)) == (True, False)
|
||||
raises(ValueError, lambda: interval(1, 2, 3))
|
||||
|
||||
|
||||
def test_interval_add():
|
||||
assert (interval(1, 2) + interval(2, 3) == interval(3, 5)) == (True, True)
|
||||
assert (1 + interval(1, 2) == interval(2, 3)) == (True, True)
|
||||
assert (interval(1, 2) + 1 == interval(2, 3)) == (True, True)
|
||||
compare = (1 + interval(0, float('inf')) == interval(1, float('inf')))
|
||||
assert compare == (True, True)
|
||||
a = 1 + interval(2, 5, is_valid=False)
|
||||
assert a.is_valid is False
|
||||
a = 1 + interval(2, 5, is_valid=None)
|
||||
assert a.is_valid is None
|
||||
a = interval(2, 5, is_valid=False) + interval(3, 5, is_valid=None)
|
||||
assert a.is_valid is False
|
||||
a = interval(3, 5) + interval(-1, 1, is_valid=None)
|
||||
assert a.is_valid is None
|
||||
a = interval(2, 5, is_valid=False) + 1
|
||||
assert a.is_valid is False
|
||||
|
||||
|
||||
def test_interval_sub():
|
||||
assert (interval(1, 2) - interval(1, 5) == interval(-4, 1)) == (True, True)
|
||||
assert (interval(1, 2) - 1 == interval(0, 1)) == (True, True)
|
||||
assert (1 - interval(1, 2) == interval(-1, 0)) == (True, True)
|
||||
a = 1 - interval(1, 2, is_valid=False)
|
||||
assert a.is_valid is False
|
||||
a = interval(1, 4, is_valid=None) - 1
|
||||
assert a.is_valid is None
|
||||
a = interval(1, 3, is_valid=False) - interval(1, 3)
|
||||
assert a.is_valid is False
|
||||
a = interval(1, 3, is_valid=None) - interval(1, 3)
|
||||
assert a.is_valid is None
|
||||
|
||||
|
||||
def test_interval_inequality():
|
||||
assert (interval(1, 2) < interval(3, 4)) == (True, True)
|
||||
assert (interval(1, 2) < interval(2, 4)) == (None, True)
|
||||
assert (interval(1, 2) < interval(-2, 0)) == (False, True)
|
||||
assert (interval(1, 2) <= interval(2, 4)) == (True, True)
|
||||
assert (interval(1, 2) <= interval(1.5, 6)) == (None, True)
|
||||
assert (interval(2, 3) <= interval(1, 2)) == (None, True)
|
||||
assert (interval(2, 3) <= interval(1, 1.5)) == (False, True)
|
||||
assert (
|
||||
interval(1, 2, is_valid=False) <= interval(-2, 0)) == (False, False)
|
||||
assert (interval(1, 2, is_valid=None) <= interval(-2, 0)) == (False, None)
|
||||
assert (interval(1, 2) <= 1.5) == (None, True)
|
||||
assert (interval(1, 2) <= 3) == (True, True)
|
||||
assert (interval(1, 2) <= 0) == (False, True)
|
||||
assert (interval(5, 8) > interval(2, 3)) == (True, True)
|
||||
assert (interval(2, 5) > interval(1, 3)) == (None, True)
|
||||
assert (interval(2, 3) > interval(3.1, 5)) == (False, True)
|
||||
|
||||
assert (interval(-1, 1) == 0) == (None, True)
|
||||
assert (interval(-1, 1) == 2) == (False, True)
|
||||
assert (interval(-1, 1) != 0) == (None, True)
|
||||
assert (interval(-1, 1) != 2) == (True, True)
|
||||
|
||||
assert (interval(3, 5) > 2) == (True, True)
|
||||
assert (interval(3, 5) < 2) == (False, True)
|
||||
assert (interval(1, 5) < 2) == (None, True)
|
||||
assert (interval(1, 5) > 2) == (None, True)
|
||||
assert (interval(0, 1) > 2) == (False, True)
|
||||
assert (interval(1, 2) >= interval(0, 1)) == (True, True)
|
||||
assert (interval(1, 2) >= interval(0, 1.5)) == (None, True)
|
||||
assert (interval(1, 2) >= interval(3, 4)) == (False, True)
|
||||
assert (interval(1, 2) >= 0) == (True, True)
|
||||
assert (interval(1, 2) >= 1.2) == (None, True)
|
||||
assert (interval(1, 2) >= 3) == (False, True)
|
||||
assert (2 > interval(0, 1)) == (True, True)
|
||||
a = interval(-1, 1, is_valid=False) < interval(2, 5, is_valid=None)
|
||||
assert a == (True, False)
|
||||
a = interval(-1, 1, is_valid=None) < interval(2, 5, is_valid=False)
|
||||
assert a == (True, False)
|
||||
a = interval(-1, 1, is_valid=None) < interval(2, 5, is_valid=None)
|
||||
assert a == (True, None)
|
||||
a = interval(-1, 1, is_valid=False) > interval(-5, -2, is_valid=None)
|
||||
assert a == (True, False)
|
||||
a = interval(-1, 1, is_valid=None) > interval(-5, -2, is_valid=False)
|
||||
assert a == (True, False)
|
||||
a = interval(-1, 1, is_valid=None) > interval(-5, -2, is_valid=None)
|
||||
assert a == (True, None)
|
||||
|
||||
|
||||
def test_interval_mul():
|
||||
assert (
|
||||
interval(1, 5) * interval(2, 10) == interval(2, 50)) == (True, True)
|
||||
a = interval(-1, 1) * interval(2, 10) == interval(-10, 10)
|
||||
assert a == (True, True)
|
||||
|
||||
a = interval(-1, 1) * interval(-5, 3) == interval(-5, 5)
|
||||
assert a == (True, True)
|
||||
|
||||
assert (interval(1, 3) * 2 == interval(2, 6)) == (True, True)
|
||||
assert (3 * interval(-1, 2) == interval(-3, 6)) == (True, True)
|
||||
|
||||
a = 3 * interval(1, 2, is_valid=False)
|
||||
assert a.is_valid is False
|
||||
|
||||
a = 3 * interval(1, 2, is_valid=None)
|
||||
assert a.is_valid is None
|
||||
|
||||
a = interval(1, 5, is_valid=False) * interval(1, 2, is_valid=None)
|
||||
assert a.is_valid is False
|
||||
|
||||
|
||||
def test_interval_div():
|
||||
div = interval(1, 2, is_valid=False) / 3
|
||||
assert div == interval(-float('inf'), float('inf'), is_valid=False)
|
||||
|
||||
div = interval(1, 2, is_valid=None) / 3
|
||||
assert div == interval(-float('inf'), float('inf'), is_valid=None)
|
||||
|
||||
div = 3 / interval(1, 2, is_valid=None)
|
||||
assert div == interval(-float('inf'), float('inf'), is_valid=None)
|
||||
a = interval(1, 2) / 0
|
||||
assert a.is_valid is False
|
||||
a = interval(0.5, 1) / interval(-1, 0)
|
||||
assert a.is_valid is None
|
||||
a = interval(0, 1) / interval(0, 1)
|
||||
assert a.is_valid is None
|
||||
|
||||
a = interval(-1, 1) / interval(-1, 1)
|
||||
assert a.is_valid is None
|
||||
|
||||
a = interval(-1, 2) / interval(0.5, 1) == interval(-2.0, 4.0)
|
||||
assert a == (True, True)
|
||||
a = interval(0, 1) / interval(0.5, 1) == interval(0.0, 2.0)
|
||||
assert a == (True, True)
|
||||
a = interval(-1, 0) / interval(0.5, 1) == interval(-2.0, 0.0)
|
||||
assert a == (True, True)
|
||||
a = interval(-0.5, -0.25) / interval(0.5, 1) == interval(-1.0, -0.25)
|
||||
assert a == (True, True)
|
||||
a = interval(0.5, 1) / interval(0.5, 1) == interval(0.5, 2.0)
|
||||
assert a == (True, True)
|
||||
a = interval(0.5, 4) / interval(0.5, 1) == interval(0.5, 8.0)
|
||||
assert a == (True, True)
|
||||
a = interval(-1, -0.5) / interval(0.5, 1) == interval(-2.0, -0.5)
|
||||
assert a == (True, True)
|
||||
a = interval(-4, -0.5) / interval(0.5, 1) == interval(-8.0, -0.5)
|
||||
assert a == (True, True)
|
||||
a = interval(-1, 2) / interval(-2, -0.5) == interval(-4.0, 2.0)
|
||||
assert a == (True, True)
|
||||
a = interval(0, 1) / interval(-2, -0.5) == interval(-2.0, 0.0)
|
||||
assert a == (True, True)
|
||||
a = interval(-1, 0) / interval(-2, -0.5) == interval(0.0, 2.0)
|
||||
assert a == (True, True)
|
||||
a = interval(-0.5, -0.25) / interval(-2, -0.5) == interval(0.125, 1.0)
|
||||
assert a == (True, True)
|
||||
a = interval(0.5, 1) / interval(-2, -0.5) == interval(-2.0, -0.25)
|
||||
assert a == (True, True)
|
||||
a = interval(0.5, 4) / interval(-2, -0.5) == interval(-8.0, -0.25)
|
||||
assert a == (True, True)
|
||||
a = interval(-1, -0.5) / interval(-2, -0.5) == interval(0.25, 2.0)
|
||||
assert a == (True, True)
|
||||
a = interval(-4, -0.5) / interval(-2, -0.5) == interval(0.25, 8.0)
|
||||
assert a == (True, True)
|
||||
a = interval(-5, 5, is_valid=False) / 2
|
||||
assert a.is_valid is False
|
||||
|
||||
def test_hashable():
|
||||
'''
|
||||
test that interval objects are hashable.
|
||||
this is required in order to be able to put them into the cache, which
|
||||
appears to be necessary for plotting in py3k. For details, see:
|
||||
|
||||
https://github.com/sympy/sympy/pull/2101
|
||||
https://github.com/sympy/sympy/issues/6533
|
||||
'''
|
||||
hash(interval(1, 1))
|
||||
hash(interval(1, 1, is_valid=True))
|
||||
hash(interval(-4, -0.5))
|
||||
hash(interval(-2, -0.5))
|
||||
hash(interval(0.25, 8.0))
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,233 @@
|
||||
"""Implicit plotting module for SymPy.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The module implements a data series called ImplicitSeries which is used by
|
||||
``Plot`` class to plot implicit plots for different backends. The module,
|
||||
by default, implements plotting using interval arithmetic. It switches to a
|
||||
fall back algorithm if the expression cannot be plotted using interval arithmetic.
|
||||
It is also possible to specify to use the fall back algorithm for all plots.
|
||||
|
||||
Boolean combinations of expressions cannot be plotted by the fall back
|
||||
algorithm.
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
sympy.plotting.plot
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [1] Jeffrey Allen Tupper. Reliable Two-Dimensional Graphing Methods for
|
||||
Mathematical Formulae with Two Free Variables.
|
||||
|
||||
.. [2] Jeffrey Allen Tupper. Graphing Equations with Generalized Interval
|
||||
Arithmetic. Master's thesis. University of Toronto, 1996
|
||||
|
||||
"""
|
||||
|
||||
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.symbol import (Dummy, Symbol)
|
||||
from sympy.polys.polyutils import _sort_gens
|
||||
from sympy.plotting.series import ImplicitSeries, _set_discretization_points
|
||||
from sympy.plotting.plot import plot_factory
|
||||
from sympy.utilities.decorator import doctest_depends_on
|
||||
from sympy.utilities.iterables import flatten
|
||||
|
||||
|
||||
__doctest_requires__ = {'plot_implicit': ['matplotlib']}
|
||||
|
||||
|
||||
@doctest_depends_on(modules=('matplotlib',))
|
||||
def plot_implicit(expr, x_var=None, y_var=None, adaptive=True, depth=0,
|
||||
n=300, line_color="blue", show=True, **kwargs):
|
||||
"""A plot function to plot implicit equations / inequalities.
|
||||
|
||||
Arguments
|
||||
=========
|
||||
|
||||
- expr : The equation / inequality that is to be plotted.
|
||||
- x_var (optional) : symbol to plot on x-axis or tuple giving symbol
|
||||
and range as ``(symbol, xmin, xmax)``
|
||||
- y_var (optional) : symbol to plot on y-axis or tuple giving symbol
|
||||
and range as ``(symbol, ymin, ymax)``
|
||||
|
||||
If neither ``x_var`` nor ``y_var`` are given then the free symbols in the
|
||||
expression will be assigned in the order they are sorted.
|
||||
|
||||
The following keyword arguments can also be used:
|
||||
|
||||
- ``adaptive`` Boolean. The default value is set to True. It has to be
|
||||
set to False if you want to use a mesh grid.
|
||||
|
||||
- ``depth`` integer. The depth of recursion for adaptive mesh grid.
|
||||
Default value is 0. Takes value in the range (0, 4).
|
||||
|
||||
- ``n`` integer. The number of points if adaptive mesh grid is not
|
||||
used. Default value is 300. This keyword argument replaces ``points``,
|
||||
which should be considered deprecated.
|
||||
|
||||
- ``show`` Boolean. Default value is True. If set to False, the plot will
|
||||
not be shown. See ``Plot`` for further information.
|
||||
|
||||
- ``title`` string. The title for the plot.
|
||||
|
||||
- ``xlabel`` string. The label for the x-axis
|
||||
|
||||
- ``ylabel`` string. The label for the y-axis
|
||||
|
||||
Aesthetics options:
|
||||
|
||||
- ``line_color``: float or string. Specifies the color for the plot.
|
||||
See ``Plot`` to see how to set color for the plots.
|
||||
Default value is "Blue"
|
||||
|
||||
plot_implicit, by default, uses interval arithmetic to plot functions. If
|
||||
the expression cannot be plotted using interval arithmetic, it defaults to
|
||||
a generating a contour using a mesh grid of fixed number of points. By
|
||||
setting adaptive to False, you can force plot_implicit to use the mesh
|
||||
grid. The mesh grid method can be effective when adaptive plotting using
|
||||
interval arithmetic, fails to plot with small line width.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
Plot expressions:
|
||||
|
||||
.. plot::
|
||||
:context: reset
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> from sympy import plot_implicit, symbols, Eq, And
|
||||
>>> x, y = symbols('x y')
|
||||
|
||||
Without any ranges for the symbols in the expression:
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> p1 = plot_implicit(Eq(x**2 + y**2, 5))
|
||||
|
||||
With the range for the symbols:
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> p2 = plot_implicit(
|
||||
... Eq(x**2 + y**2, 3), (x, -3, 3), (y, -3, 3))
|
||||
|
||||
With depth of recursion as argument:
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> p3 = plot_implicit(
|
||||
... Eq(x**2 + y**2, 5), (x, -4, 4), (y, -4, 4), depth = 2)
|
||||
|
||||
Using mesh grid and not using adaptive meshing:
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> p4 = plot_implicit(
|
||||
... Eq(x**2 + y**2, 5), (x, -5, 5), (y, -2, 2),
|
||||
... adaptive=False)
|
||||
|
||||
Using mesh grid without using adaptive meshing with number of points
|
||||
specified:
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> p5 = plot_implicit(
|
||||
... Eq(x**2 + y**2, 5), (x, -5, 5), (y, -2, 2),
|
||||
... adaptive=False, n=400)
|
||||
|
||||
Plotting regions:
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> p6 = plot_implicit(y > x**2)
|
||||
|
||||
Plotting Using boolean conjunctions:
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> p7 = plot_implicit(And(y > x, y > -x))
|
||||
|
||||
When plotting an expression with a single variable (y - 1, for example),
|
||||
specify the x or the y variable explicitly:
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> p8 = plot_implicit(y - 1, y_var=y)
|
||||
>>> p9 = plot_implicit(x - 1, x_var=x)
|
||||
"""
|
||||
|
||||
xyvar = [i for i in (x_var, y_var) if i is not None]
|
||||
free_symbols = expr.free_symbols
|
||||
range_symbols = Tuple(*flatten(xyvar)).free_symbols
|
||||
undeclared = free_symbols - range_symbols
|
||||
if len(free_symbols & range_symbols) > 2:
|
||||
raise NotImplementedError("Implicit plotting is not implemented for "
|
||||
"more than 2 variables")
|
||||
|
||||
#Create default ranges if the range is not provided.
|
||||
default_range = Tuple(-5, 5)
|
||||
def _range_tuple(s):
|
||||
if isinstance(s, Symbol):
|
||||
return Tuple(s) + default_range
|
||||
if len(s) == 3:
|
||||
return Tuple(*s)
|
||||
raise ValueError('symbol or `(symbol, min, max)` expected but got %s' % s)
|
||||
|
||||
if len(xyvar) == 0:
|
||||
xyvar = list(_sort_gens(free_symbols))
|
||||
var_start_end_x = _range_tuple(xyvar[0])
|
||||
x = var_start_end_x[0]
|
||||
if len(xyvar) != 2:
|
||||
if x in undeclared or not undeclared:
|
||||
xyvar.append(Dummy('f(%s)' % x.name))
|
||||
else:
|
||||
xyvar.append(undeclared.pop())
|
||||
var_start_end_y = _range_tuple(xyvar[1])
|
||||
|
||||
kwargs = _set_discretization_points(kwargs, ImplicitSeries)
|
||||
series_argument = ImplicitSeries(
|
||||
expr, var_start_end_x, var_start_end_y,
|
||||
adaptive=adaptive, depth=depth,
|
||||
n=n, line_color=line_color)
|
||||
|
||||
#set the x and y limits
|
||||
kwargs['xlim'] = tuple(float(x) for x in var_start_end_x[1:])
|
||||
kwargs['ylim'] = tuple(float(y) for y in var_start_end_y[1:])
|
||||
# set the x and y labels
|
||||
kwargs.setdefault('xlabel', var_start_end_x[0])
|
||||
kwargs.setdefault('ylabel', var_start_end_y[0])
|
||||
p = plot_factory(series_argument, **kwargs)
|
||||
if show:
|
||||
p.show()
|
||||
return p
|
||||
@@ -0,0 +1,188 @@
|
||||
|
||||
from sympy.external import import_module
|
||||
import sympy.plotting.backends.base_backend as base_backend
|
||||
|
||||
|
||||
# N.B.
|
||||
# When changing the minimum module version for matplotlib, please change
|
||||
# the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py`
|
||||
|
||||
|
||||
__doctest_requires__ = {
|
||||
("PlotGrid",): ["matplotlib"],
|
||||
}
|
||||
|
||||
|
||||
class PlotGrid:
|
||||
"""This class helps to plot subplots from already created SymPy plots
|
||||
in a single figure.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> from sympy import symbols
|
||||
>>> from sympy.plotting import plot, plot3d, PlotGrid
|
||||
>>> x, y = symbols('x, y')
|
||||
>>> p1 = plot(x, x**2, x**3, (x, -5, 5))
|
||||
>>> p2 = plot((x**2, (x, -6, 6)), (x, (x, -5, 5)))
|
||||
>>> p3 = plot(x**3, (x, -5, 5))
|
||||
>>> p4 = plot3d(x*y, (x, -5, 5), (y, -5, 5))
|
||||
|
||||
Plotting vertically in a single line:
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> PlotGrid(2, 1, p1, p2)
|
||||
PlotGrid object containing:
|
||||
Plot[0]:Plot object containing:
|
||||
[0]: cartesian line: x for x over (-5.0, 5.0)
|
||||
[1]: cartesian line: x**2 for x over (-5.0, 5.0)
|
||||
[2]: cartesian line: x**3 for x over (-5.0, 5.0)
|
||||
Plot[1]:Plot object containing:
|
||||
[0]: cartesian line: x**2 for x over (-6.0, 6.0)
|
||||
[1]: cartesian line: x for x over (-5.0, 5.0)
|
||||
|
||||
Plotting horizontally in a single line:
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> PlotGrid(1, 3, p2, p3, p4)
|
||||
PlotGrid object containing:
|
||||
Plot[0]:Plot object containing:
|
||||
[0]: cartesian line: x**2 for x over (-6.0, 6.0)
|
||||
[1]: cartesian line: x for x over (-5.0, 5.0)
|
||||
Plot[1]:Plot object containing:
|
||||
[0]: cartesian line: x**3 for x over (-5.0, 5.0)
|
||||
Plot[2]:Plot object containing:
|
||||
[0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0)
|
||||
|
||||
Plotting in a grid form:
|
||||
|
||||
.. plot::
|
||||
:context: close-figs
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> PlotGrid(2, 2, p1, p2, p3, p4)
|
||||
PlotGrid object containing:
|
||||
Plot[0]:Plot object containing:
|
||||
[0]: cartesian line: x for x over (-5.0, 5.0)
|
||||
[1]: cartesian line: x**2 for x over (-5.0, 5.0)
|
||||
[2]: cartesian line: x**3 for x over (-5.0, 5.0)
|
||||
Plot[1]:Plot object containing:
|
||||
[0]: cartesian line: x**2 for x over (-6.0, 6.0)
|
||||
[1]: cartesian line: x for x over (-5.0, 5.0)
|
||||
Plot[2]:Plot object containing:
|
||||
[0]: cartesian line: x**3 for x over (-5.0, 5.0)
|
||||
Plot[3]:Plot object containing:
|
||||
[0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0)
|
||||
|
||||
"""
|
||||
def __init__(self, nrows, ncolumns, *args, show=True, size=None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
==========
|
||||
|
||||
nrows :
|
||||
The number of rows that should be in the grid of the
|
||||
required subplot.
|
||||
ncolumns :
|
||||
The number of columns that should be in the grid
|
||||
of the required subplot.
|
||||
|
||||
nrows and ncolumns together define the required grid.
|
||||
|
||||
Arguments
|
||||
=========
|
||||
|
||||
A list of predefined plot objects entered in a row-wise sequence
|
||||
i.e. plot objects which are to be in the top row of the required
|
||||
grid are written first, then the second row objects and so on
|
||||
|
||||
Keyword arguments
|
||||
=================
|
||||
|
||||
show : Boolean
|
||||
The default value is set to ``True``. Set show to ``False`` and
|
||||
the function will not display the subplot. The returned instance
|
||||
of the ``PlotGrid`` class can then be used to save or display the
|
||||
plot by calling the ``save()`` and ``show()`` methods
|
||||
respectively.
|
||||
size : (float, float), optional
|
||||
A tuple in the form (width, height) in inches to specify the size of
|
||||
the overall figure. The default value is set to ``None``, meaning
|
||||
the size will be set by the default backend.
|
||||
"""
|
||||
self.matplotlib = import_module('matplotlib',
|
||||
import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
|
||||
min_module_version='1.1.0', catch=(RuntimeError,))
|
||||
self.nrows = nrows
|
||||
self.ncolumns = ncolumns
|
||||
self._series = []
|
||||
self._fig = None
|
||||
self.args = args
|
||||
for arg in args:
|
||||
self._series.append(arg._series)
|
||||
self.size = size
|
||||
if show and self.matplotlib:
|
||||
self.show()
|
||||
|
||||
def _create_figure(self):
|
||||
gs = self.matplotlib.gridspec.GridSpec(self.nrows, self.ncolumns)
|
||||
mapping = {}
|
||||
c = 0
|
||||
for i in range(self.nrows):
|
||||
for j in range(self.ncolumns):
|
||||
if c < len(self.args):
|
||||
mapping[gs[i, j]] = self.args[c]
|
||||
c += 1
|
||||
|
||||
kw = {} if not self.size else {"figsize": self.size}
|
||||
self._fig = self.matplotlib.pyplot.figure(**kw)
|
||||
for spec, p in mapping.items():
|
||||
kw = ({"projection": "3d"} if (len(p._series) > 0 and
|
||||
p._series[0].is_3D) else {})
|
||||
cur_ax = self._fig.add_subplot(spec, **kw)
|
||||
p._plotgrid_fig = self._fig
|
||||
p._plotgrid_ax = cur_ax
|
||||
p.process_series()
|
||||
|
||||
@property
|
||||
def fig(self):
|
||||
if not self._fig:
|
||||
self._create_figure()
|
||||
return self._fig
|
||||
|
||||
@property
|
||||
def _backend(self):
|
||||
return self
|
||||
|
||||
def close(self):
|
||||
self.matplotlib.pyplot.close(self.fig)
|
||||
|
||||
def show(self):
|
||||
if base_backend._show:
|
||||
self.fig.tight_layout()
|
||||
self.matplotlib.pyplot.show()
|
||||
else:
|
||||
self.close()
|
||||
|
||||
def save(self, path):
|
||||
self.fig.savefig(path)
|
||||
|
||||
def __str__(self):
|
||||
plot_strs = [('Plot[%d]:' % i) + str(plot)
|
||||
for i, plot in enumerate(self.args)]
|
||||
|
||||
return 'PlotGrid object containing:\n' + '\n'.join(plot_strs)
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Plotting module that can plot 2D and 3D functions
|
||||
"""
|
||||
|
||||
from sympy.utilities.decorator import doctest_depends_on
|
||||
|
||||
@doctest_depends_on(modules=('pyglet',))
|
||||
def PygletPlot(*args, **kwargs):
|
||||
"""
|
||||
|
||||
Plot Examples
|
||||
=============
|
||||
|
||||
See examples/advanced/pyglet_plotting.py for many more examples.
|
||||
|
||||
>>> from sympy.plotting.pygletplot import PygletPlot as Plot
|
||||
>>> from sympy.abc import x, y, z
|
||||
|
||||
>>> Plot(x*y**3-y*x**3)
|
||||
[0]: -x**3*y + x*y**3, 'mode=cartesian'
|
||||
|
||||
>>> p = Plot()
|
||||
>>> p[1] = x*y
|
||||
>>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4)
|
||||
|
||||
>>> p = Plot()
|
||||
>>> p[1] = x**2+y**2
|
||||
>>> p[2] = -x**2-y**2
|
||||
|
||||
|
||||
Variable Intervals
|
||||
==================
|
||||
|
||||
The basic format is [var, min, max, steps], but the
|
||||
syntax is flexible and arguments left out are taken
|
||||
from the defaults for the current coordinate mode:
|
||||
|
||||
>>> Plot(x**2) # implies [x,-5,5,100]
|
||||
[0]: x**2, 'mode=cartesian'
|
||||
|
||||
>>> Plot(x**2, [], []) # [x,-1,1,40], [y,-1,1,40]
|
||||
[0]: x**2, 'mode=cartesian'
|
||||
>>> Plot(x**2-y**2, [100], [100]) # [x,-1,1,100], [y,-1,1,100]
|
||||
[0]: x**2 - y**2, 'mode=cartesian'
|
||||
>>> Plot(x**2, [x,-13,13,100])
|
||||
[0]: x**2, 'mode=cartesian'
|
||||
>>> Plot(x**2, [-13,13]) # [x,-13,13,100]
|
||||
[0]: x**2, 'mode=cartesian'
|
||||
>>> Plot(x**2, [x,-13,13]) # [x,-13,13,100]
|
||||
[0]: x**2, 'mode=cartesian'
|
||||
>>> Plot(1*x, [], [x], mode='cylindrical')
|
||||
... # [unbound_theta,0,2*Pi,40], [x,-1,1,20]
|
||||
[0]: x, 'mode=cartesian'
|
||||
|
||||
|
||||
Coordinate Modes
|
||||
================
|
||||
|
||||
Plot supports several curvilinear coordinate modes, and
|
||||
they independent for each plotted function. You can specify
|
||||
a coordinate mode explicitly with the 'mode' named argument,
|
||||
but it can be automatically determined for Cartesian or
|
||||
parametric plots, and therefore must only be specified for
|
||||
polar, cylindrical, and spherical modes.
|
||||
|
||||
Specifically, Plot(function arguments) and Plot[n] =
|
||||
(function arguments) will interpret your arguments as a
|
||||
Cartesian plot if you provide one function and a parametric
|
||||
plot if you provide two or three functions. Similarly, the
|
||||
arguments will be interpreted as a curve if one variable is
|
||||
used, and a surface if two are used.
|
||||
|
||||
Supported mode names by number of variables:
|
||||
|
||||
1: parametric, cartesian, polar
|
||||
2: parametric, cartesian, cylindrical = polar, spherical
|
||||
|
||||
>>> Plot(1, mode='spherical')
|
||||
|
||||
|
||||
Calculator-like Interface
|
||||
=========================
|
||||
|
||||
>>> p = Plot(visible=False)
|
||||
>>> f = x**2
|
||||
>>> p[1] = f
|
||||
>>> p[2] = f.diff(x)
|
||||
>>> p[3] = f.diff(x).diff(x)
|
||||
>>> p
|
||||
[1]: x**2, 'mode=cartesian'
|
||||
[2]: 2*x, 'mode=cartesian'
|
||||
[3]: 2, 'mode=cartesian'
|
||||
>>> p.show()
|
||||
>>> p.clear()
|
||||
>>> p
|
||||
<blank plot>
|
||||
>>> p[1] = x**2+y**2
|
||||
>>> p[1].style = 'solid'
|
||||
>>> p[2] = -x**2-y**2
|
||||
>>> p[2].style = 'wireframe'
|
||||
>>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4)
|
||||
>>> p[1].style = 'both'
|
||||
>>> p[2].style = 'both'
|
||||
>>> p.close()
|
||||
|
||||
|
||||
Plot Window Keyboard Controls
|
||||
=============================
|
||||
|
||||
Screen Rotation:
|
||||
X,Y axis Arrow Keys, A,S,D,W, Numpad 4,6,8,2
|
||||
Z axis Q,E, Numpad 7,9
|
||||
|
||||
Model Rotation:
|
||||
Z axis Z,C, Numpad 1,3
|
||||
|
||||
Zoom: R,F, PgUp,PgDn, Numpad +,-
|
||||
|
||||
Reset Camera: X, Numpad 5
|
||||
|
||||
Camera Presets:
|
||||
XY F1
|
||||
XZ F2
|
||||
YZ F3
|
||||
Perspective F4
|
||||
|
||||
Sensitivity Modifier: SHIFT
|
||||
|
||||
Axes Toggle:
|
||||
Visible F5
|
||||
Colors F6
|
||||
|
||||
Close Window: ESCAPE
|
||||
|
||||
=============================
|
||||
"""
|
||||
|
||||
from sympy.plotting.pygletplot.plot import PygletPlot
|
||||
return PygletPlot(*args, **kwargs)
|
||||
@@ -0,0 +1,336 @@
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from .util import interpolate, rinterpolate, create_bounds, update_bounds
|
||||
from sympy.utilities.iterables import sift
|
||||
|
||||
|
||||
class ColorGradient:
|
||||
colors = [0.4, 0.4, 0.4], [0.9, 0.9, 0.9]
|
||||
intervals = 0.0, 1.0
|
||||
|
||||
def __init__(self, *args):
|
||||
if len(args) == 2:
|
||||
self.colors = list(args)
|
||||
self.intervals = [0.0, 1.0]
|
||||
elif len(args) > 0:
|
||||
if len(args) % 2 != 0:
|
||||
raise ValueError("len(args) should be even")
|
||||
self.colors = [args[i] for i in range(1, len(args), 2)]
|
||||
self.intervals = [args[i] for i in range(0, len(args), 2)]
|
||||
assert len(self.colors) == len(self.intervals)
|
||||
|
||||
def copy(self):
|
||||
c = ColorGradient()
|
||||
c.colors = [e[::] for e in self.colors]
|
||||
c.intervals = self.intervals[::]
|
||||
return c
|
||||
|
||||
def _find_interval(self, v):
|
||||
m = len(self.intervals)
|
||||
i = 0
|
||||
while i < m - 1 and self.intervals[i] <= v:
|
||||
i += 1
|
||||
return i
|
||||
|
||||
def _interpolate_axis(self, axis, v):
|
||||
i = self._find_interval(v)
|
||||
v = rinterpolate(self.intervals[i - 1], self.intervals[i], v)
|
||||
return interpolate(self.colors[i - 1][axis], self.colors[i][axis], v)
|
||||
|
||||
def __call__(self, r, g, b):
|
||||
c = self._interpolate_axis
|
||||
return c(0, r), c(1, g), c(2, b)
|
||||
|
||||
default_color_schemes = {} # defined at the bottom of this file
|
||||
|
||||
|
||||
class ColorScheme:
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.f, self.gradient = None, ColorGradient()
|
||||
|
||||
if len(args) == 1 and not isinstance(args[0], Basic) and callable(args[0]):
|
||||
self.f = args[0]
|
||||
elif len(args) == 1 and isinstance(args[0], str):
|
||||
if args[0] in default_color_schemes:
|
||||
cs = default_color_schemes[args[0]]
|
||||
self.f, self.gradient = cs.f, cs.gradient.copy()
|
||||
else:
|
||||
self.f = lambdify('x,y,z,u,v', args[0])
|
||||
else:
|
||||
self.f, self.gradient = self._interpret_args(args)
|
||||
self._test_color_function()
|
||||
if not isinstance(self.gradient, ColorGradient):
|
||||
raise ValueError("Color gradient not properly initialized. "
|
||||
"(Not a ColorGradient instance.)")
|
||||
|
||||
def _interpret_args(self, args):
|
||||
f, gradient = None, self.gradient
|
||||
atoms, lists = self._sort_args(args)
|
||||
s = self._pop_symbol_list(lists)
|
||||
s = self._fill_in_vars(s)
|
||||
|
||||
# prepare the error message for lambdification failure
|
||||
f_str = ', '.join(str(fa) for fa in atoms)
|
||||
s_str = (str(sa) for sa in s)
|
||||
s_str = ', '.join(sa for sa in s_str if sa.find('unbound') < 0)
|
||||
f_error = ValueError("Could not interpret arguments "
|
||||
"%s as functions of %s." % (f_str, s_str))
|
||||
|
||||
# try to lambdify args
|
||||
if len(atoms) == 1:
|
||||
fv = atoms[0]
|
||||
try:
|
||||
f = lambdify(s, [fv, fv, fv])
|
||||
except TypeError:
|
||||
raise f_error
|
||||
|
||||
elif len(atoms) == 3:
|
||||
fr, fg, fb = atoms
|
||||
try:
|
||||
f = lambdify(s, [fr, fg, fb])
|
||||
except TypeError:
|
||||
raise f_error
|
||||
|
||||
else:
|
||||
raise ValueError("A ColorScheme must provide 1 or 3 "
|
||||
"functions in x, y, z, u, and/or v.")
|
||||
|
||||
# try to intrepret any given color information
|
||||
if len(lists) == 0:
|
||||
gargs = []
|
||||
|
||||
elif len(lists) == 1:
|
||||
gargs = lists[0]
|
||||
|
||||
elif len(lists) == 2:
|
||||
try:
|
||||
(r1, g1, b1), (r2, g2, b2) = lists
|
||||
except TypeError:
|
||||
raise ValueError("If two color arguments are given, "
|
||||
"they must be given in the format "
|
||||
"(r1, g1, b1), (r2, g2, b2).")
|
||||
gargs = lists
|
||||
|
||||
elif len(lists) == 3:
|
||||
try:
|
||||
(r1, r2), (g1, g2), (b1, b2) = lists
|
||||
except Exception:
|
||||
raise ValueError("If three color arguments are given, "
|
||||
"they must be given in the format "
|
||||
"(r1, r2), (g1, g2), (b1, b2). To create "
|
||||
"a multi-step gradient, use the syntax "
|
||||
"[0, colorStart, step1, color1, ..., 1, "
|
||||
"colorEnd].")
|
||||
gargs = [[r1, g1, b1], [r2, g2, b2]]
|
||||
|
||||
else:
|
||||
raise ValueError("Don't know what to do with collection "
|
||||
"arguments %s." % (', '.join(str(l) for l in lists)))
|
||||
|
||||
if gargs:
|
||||
try:
|
||||
gradient = ColorGradient(*gargs)
|
||||
except Exception as ex:
|
||||
raise ValueError(("Could not initialize a gradient "
|
||||
"with arguments %s. Inner "
|
||||
"exception: %s") % (gargs, str(ex)))
|
||||
|
||||
return f, gradient
|
||||
|
||||
def _pop_symbol_list(self, lists):
|
||||
symbol_lists = []
|
||||
for l in lists:
|
||||
mark = True
|
||||
for s in l:
|
||||
if s is not None and not isinstance(s, Symbol):
|
||||
mark = False
|
||||
break
|
||||
if mark:
|
||||
lists.remove(l)
|
||||
symbol_lists.append(l)
|
||||
if len(symbol_lists) == 1:
|
||||
return symbol_lists[0]
|
||||
elif len(symbol_lists) == 0:
|
||||
return []
|
||||
else:
|
||||
raise ValueError("Only one list of Symbols "
|
||||
"can be given for a color scheme.")
|
||||
|
||||
def _fill_in_vars(self, args):
|
||||
defaults = symbols('x,y,z,u,v')
|
||||
v_error = ValueError("Could not find what to plot.")
|
||||
if len(args) == 0:
|
||||
return defaults
|
||||
if not isinstance(args, (tuple, list)):
|
||||
raise v_error
|
||||
if len(args) == 0:
|
||||
return defaults
|
||||
for s in args:
|
||||
if s is not None and not isinstance(s, Symbol):
|
||||
raise v_error
|
||||
# when vars are given explicitly, any vars
|
||||
# not given are marked 'unbound' as to not
|
||||
# be accidentally used in an expression
|
||||
vars = [Symbol('unbound%i' % (i)) for i in range(1, 6)]
|
||||
# interpret as t
|
||||
if len(args) == 1:
|
||||
vars[3] = args[0]
|
||||
# interpret as u,v
|
||||
elif len(args) == 2:
|
||||
if args[0] is not None:
|
||||
vars[3] = args[0]
|
||||
if args[1] is not None:
|
||||
vars[4] = args[1]
|
||||
# interpret as x,y,z
|
||||
elif len(args) >= 3:
|
||||
# allow some of x,y,z to be
|
||||
# left unbound if not given
|
||||
if args[0] is not None:
|
||||
vars[0] = args[0]
|
||||
if args[1] is not None:
|
||||
vars[1] = args[1]
|
||||
if args[2] is not None:
|
||||
vars[2] = args[2]
|
||||
# interpret the rest as t
|
||||
if len(args) >= 4:
|
||||
vars[3] = args[3]
|
||||
# ...or u,v
|
||||
if len(args) >= 5:
|
||||
vars[4] = args[4]
|
||||
return vars
|
||||
|
||||
def _sort_args(self, args):
|
||||
lists, atoms = sift(args,
|
||||
lambda a: isinstance(a, (tuple, list)), binary=True)
|
||||
return atoms, lists
|
||||
|
||||
def _test_color_function(self):
|
||||
if not callable(self.f):
|
||||
raise ValueError("Color function is not callable.")
|
||||
try:
|
||||
result = self.f(0, 0, 0, 0, 0)
|
||||
if len(result) != 3:
|
||||
raise ValueError("length should be equal to 3")
|
||||
except TypeError:
|
||||
raise ValueError("Color function needs to accept x,y,z,u,v, "
|
||||
"as arguments even if it doesn't use all of them.")
|
||||
except AssertionError:
|
||||
raise ValueError("Color function needs to return 3-tuple r,g,b.")
|
||||
except Exception:
|
||||
pass # color function probably not valid at 0,0,0,0,0
|
||||
|
||||
def __call__(self, x, y, z, u, v):
|
||||
try:
|
||||
return self.f(x, y, z, u, v)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def apply_to_curve(self, verts, u_set, set_len=None, inc_pos=None):
|
||||
"""
|
||||
Apply this color scheme to a
|
||||
set of vertices over a single
|
||||
independent variable u.
|
||||
"""
|
||||
bounds = create_bounds()
|
||||
cverts = []
|
||||
if callable(set_len):
|
||||
set_len(len(u_set)*2)
|
||||
# calculate f() = r,g,b for each vert
|
||||
# and find the min and max for r,g,b
|
||||
for _u in range(len(u_set)):
|
||||
if verts[_u] is None:
|
||||
cverts.append(None)
|
||||
else:
|
||||
x, y, z = verts[_u]
|
||||
u, v = u_set[_u], None
|
||||
c = self(x, y, z, u, v)
|
||||
if c is not None:
|
||||
c = list(c)
|
||||
update_bounds(bounds, c)
|
||||
cverts.append(c)
|
||||
if callable(inc_pos):
|
||||
inc_pos()
|
||||
# scale and apply gradient
|
||||
for _u in range(len(u_set)):
|
||||
if cverts[_u] is not None:
|
||||
for _c in range(3):
|
||||
# scale from [f_min, f_max] to [0,1]
|
||||
cverts[_u][_c] = rinterpolate(bounds[_c][0], bounds[_c][1],
|
||||
cverts[_u][_c])
|
||||
# apply gradient
|
||||
cverts[_u] = self.gradient(*cverts[_u])
|
||||
if callable(inc_pos):
|
||||
inc_pos()
|
||||
return cverts
|
||||
|
||||
def apply_to_surface(self, verts, u_set, v_set, set_len=None, inc_pos=None):
|
||||
"""
|
||||
Apply this color scheme to a
|
||||
set of vertices over two
|
||||
independent variables u and v.
|
||||
"""
|
||||
bounds = create_bounds()
|
||||
cverts = []
|
||||
if callable(set_len):
|
||||
set_len(len(u_set)*len(v_set)*2)
|
||||
# calculate f() = r,g,b for each vert
|
||||
# and find the min and max for r,g,b
|
||||
for _u in range(len(u_set)):
|
||||
column = []
|
||||
for _v in range(len(v_set)):
|
||||
if verts[_u][_v] is None:
|
||||
column.append(None)
|
||||
else:
|
||||
x, y, z = verts[_u][_v]
|
||||
u, v = u_set[_u], v_set[_v]
|
||||
c = self(x, y, z, u, v)
|
||||
if c is not None:
|
||||
c = list(c)
|
||||
update_bounds(bounds, c)
|
||||
column.append(c)
|
||||
if callable(inc_pos):
|
||||
inc_pos()
|
||||
cverts.append(column)
|
||||
# scale and apply gradient
|
||||
for _u in range(len(u_set)):
|
||||
for _v in range(len(v_set)):
|
||||
if cverts[_u][_v] is not None:
|
||||
# scale from [f_min, f_max] to [0,1]
|
||||
for _c in range(3):
|
||||
cverts[_u][_v][_c] = rinterpolate(bounds[_c][0],
|
||||
bounds[_c][1], cverts[_u][_v][_c])
|
||||
# apply gradient
|
||||
cverts[_u][_v] = self.gradient(*cverts[_u][_v])
|
||||
if callable(inc_pos):
|
||||
inc_pos()
|
||||
return cverts
|
||||
|
||||
def str_base(self):
|
||||
return ", ".join(str(a) for a in self.args)
|
||||
|
||||
def __repr__(self):
|
||||
return "%s" % (self.str_base())
|
||||
|
||||
|
||||
x, y, z, t, u, v = symbols('x,y,z,t,u,v')
|
||||
|
||||
default_color_schemes['rainbow'] = ColorScheme(z, y, x)
|
||||
default_color_schemes['zfade'] = ColorScheme(z, (0.4, 0.4, 0.97),
|
||||
(0.97, 0.4, 0.4), (None, None, z))
|
||||
default_color_schemes['zfade3'] = ColorScheme(z, (None, None, z),
|
||||
[0.00, (0.2, 0.2, 1.0),
|
||||
0.35, (0.2, 0.8, 0.4),
|
||||
0.50, (0.3, 0.9, 0.3),
|
||||
0.65, (0.4, 0.8, 0.2),
|
||||
1.00, (1.0, 0.2, 0.2)])
|
||||
|
||||
default_color_schemes['zfade4'] = ColorScheme(z, (None, None, z),
|
||||
[0.0, (0.3, 0.3, 1.0),
|
||||
0.30, (0.3, 1.0, 0.3),
|
||||
0.55, (0.95, 1.0, 0.2),
|
||||
0.65, (1.0, 0.95, 0.2),
|
||||
0.85, (1.0, 0.7, 0.2),
|
||||
1.0, (1.0, 0.3, 0.2)])
|
||||
@@ -0,0 +1,106 @@
|
||||
from pyglet.window import Window
|
||||
from pyglet.clock import Clock
|
||||
|
||||
from threading import Thread, Lock
|
||||
|
||||
gl_lock = Lock()
|
||||
|
||||
|
||||
class ManagedWindow(Window):
|
||||
"""
|
||||
A pyglet window with an event loop which executes automatically
|
||||
in a separate thread. Behavior is added by creating a subclass
|
||||
which overrides setup, update, and/or draw.
|
||||
"""
|
||||
fps_limit = 30
|
||||
default_win_args = {"width": 600,
|
||||
"height": 500,
|
||||
"vsync": False,
|
||||
"resizable": True}
|
||||
|
||||
def __init__(self, **win_args):
|
||||
"""
|
||||
It is best not to override this function in the child
|
||||
class, unless you need to take additional arguments.
|
||||
Do any OpenGL initialization calls in setup().
|
||||
"""
|
||||
|
||||
# check if this is run from the doctester
|
||||
if win_args.get('runfromdoctester', False):
|
||||
return
|
||||
|
||||
self.win_args = dict(self.default_win_args, **win_args)
|
||||
self.Thread = Thread(target=self.__event_loop__)
|
||||
self.Thread.start()
|
||||
|
||||
def __event_loop__(self, **win_args):
|
||||
"""
|
||||
The event loop thread function. Do not override or call
|
||||
directly (it is called by __init__).
|
||||
"""
|
||||
gl_lock.acquire()
|
||||
try:
|
||||
try:
|
||||
super().__init__(**self.win_args)
|
||||
self.switch_to()
|
||||
self.setup()
|
||||
except Exception as e:
|
||||
print("Window initialization failed: %s" % (str(e)))
|
||||
self.has_exit = True
|
||||
finally:
|
||||
gl_lock.release()
|
||||
|
||||
clock = Clock()
|
||||
clock.fps_limit = self.fps_limit
|
||||
while not self.has_exit:
|
||||
dt = clock.tick()
|
||||
gl_lock.acquire()
|
||||
try:
|
||||
try:
|
||||
self.switch_to()
|
||||
self.dispatch_events()
|
||||
self.clear()
|
||||
self.update(dt)
|
||||
self.draw()
|
||||
self.flip()
|
||||
except Exception as e:
|
||||
print("Uncaught exception in event loop: %s" % str(e))
|
||||
self.has_exit = True
|
||||
finally:
|
||||
gl_lock.release()
|
||||
super().close()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Closes the window.
|
||||
"""
|
||||
self.has_exit = True
|
||||
|
||||
def setup(self):
|
||||
"""
|
||||
Called once before the event loop begins.
|
||||
Override this method in a child class. This
|
||||
is the best place to put things like OpenGL
|
||||
initialization calls.
|
||||
"""
|
||||
pass
|
||||
|
||||
def update(self, dt):
|
||||
"""
|
||||
Called before draw during each iteration of
|
||||
the event loop. dt is the elapsed time in
|
||||
seconds since the last update. OpenGL rendering
|
||||
calls are best put in draw() rather than here.
|
||||
"""
|
||||
pass
|
||||
|
||||
def draw(self):
|
||||
"""
|
||||
Called after update during each iteration of
|
||||
the event loop. Put OpenGL rendering calls
|
||||
here.
|
||||
"""
|
||||
pass
|
||||
|
||||
if __name__ == '__main__':
|
||||
ManagedWindow()
|
||||
@@ -0,0 +1,464 @@
|
||||
from threading import RLock
|
||||
|
||||
# it is sufficient to import "pyglet" here once
|
||||
try:
|
||||
import pyglet.gl as pgl
|
||||
except ImportError:
|
||||
raise ImportError("pyglet is required for plotting.\n "
|
||||
"visit https://pyglet.org/")
|
||||
|
||||
from sympy.core.numbers import Integer
|
||||
from sympy.external.gmpy import SYMPY_INTS
|
||||
from sympy.geometry.entity import GeometryEntity
|
||||
from sympy.plotting.pygletplot.plot_axes import PlotAxes
|
||||
from sympy.plotting.pygletplot.plot_mode import PlotMode
|
||||
from sympy.plotting.pygletplot.plot_object import PlotObject
|
||||
from sympy.plotting.pygletplot.plot_window import PlotWindow
|
||||
from sympy.plotting.pygletplot.util import parse_option_string
|
||||
from sympy.utilities.decorator import doctest_depends_on
|
||||
from sympy.utilities.iterables import is_sequence
|
||||
|
||||
from time import sleep
|
||||
from os import getcwd, listdir
|
||||
|
||||
import ctypes
|
||||
|
||||
@doctest_depends_on(modules=('pyglet',))
|
||||
class PygletPlot:
|
||||
"""
|
||||
Plot Examples
|
||||
=============
|
||||
|
||||
See examples/advanced/pyglet_plotting.py for many more examples.
|
||||
|
||||
>>> from sympy.plotting.pygletplot import PygletPlot as Plot
|
||||
>>> from sympy.abc import x, y, z
|
||||
|
||||
>>> Plot(x*y**3-y*x**3)
|
||||
[0]: -x**3*y + x*y**3, 'mode=cartesian'
|
||||
|
||||
>>> p = Plot()
|
||||
>>> p[1] = x*y
|
||||
>>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4)
|
||||
|
||||
>>> p = Plot()
|
||||
>>> p[1] = x**2+y**2
|
||||
>>> p[2] = -x**2-y**2
|
||||
|
||||
|
||||
Variable Intervals
|
||||
==================
|
||||
|
||||
The basic format is [var, min, max, steps], but the
|
||||
syntax is flexible and arguments left out are taken
|
||||
from the defaults for the current coordinate mode:
|
||||
|
||||
>>> Plot(x**2) # implies [x,-5,5,100]
|
||||
[0]: x**2, 'mode=cartesian'
|
||||
>>> Plot(x**2, [], []) # [x,-1,1,40], [y,-1,1,40]
|
||||
[0]: x**2, 'mode=cartesian'
|
||||
>>> Plot(x**2-y**2, [100], [100]) # [x,-1,1,100], [y,-1,1,100]
|
||||
[0]: x**2 - y**2, 'mode=cartesian'
|
||||
>>> Plot(x**2, [x,-13,13,100])
|
||||
[0]: x**2, 'mode=cartesian'
|
||||
>>> Plot(x**2, [-13,13]) # [x,-13,13,100]
|
||||
[0]: x**2, 'mode=cartesian'
|
||||
>>> Plot(x**2, [x,-13,13]) # [x,-13,13,10]
|
||||
[0]: x**2, 'mode=cartesian'
|
||||
>>> Plot(1*x, [], [x], mode='cylindrical')
|
||||
... # [unbound_theta,0,2*Pi,40], [x,-1,1,20]
|
||||
[0]: x, 'mode=cartesian'
|
||||
|
||||
|
||||
Coordinate Modes
|
||||
================
|
||||
|
||||
Plot supports several curvilinear coordinate modes, and
|
||||
they independent for each plotted function. You can specify
|
||||
a coordinate mode explicitly with the 'mode' named argument,
|
||||
but it can be automatically determined for Cartesian or
|
||||
parametric plots, and therefore must only be specified for
|
||||
polar, cylindrical, and spherical modes.
|
||||
|
||||
Specifically, Plot(function arguments) and Plot[n] =
|
||||
(function arguments) will interpret your arguments as a
|
||||
Cartesian plot if you provide one function and a parametric
|
||||
plot if you provide two or three functions. Similarly, the
|
||||
arguments will be interpreted as a curve if one variable is
|
||||
used, and a surface if two are used.
|
||||
|
||||
Supported mode names by number of variables:
|
||||
|
||||
1: parametric, cartesian, polar
|
||||
2: parametric, cartesian, cylindrical = polar, spherical
|
||||
|
||||
>>> Plot(1, mode='spherical')
|
||||
|
||||
|
||||
Calculator-like Interface
|
||||
=========================
|
||||
|
||||
>>> p = Plot(visible=False)
|
||||
>>> f = x**2
|
||||
>>> p[1] = f
|
||||
>>> p[2] = f.diff(x)
|
||||
>>> p[3] = f.diff(x).diff(x)
|
||||
>>> p
|
||||
[1]: x**2, 'mode=cartesian'
|
||||
[2]: 2*x, 'mode=cartesian'
|
||||
[3]: 2, 'mode=cartesian'
|
||||
>>> p.show()
|
||||
>>> p.clear()
|
||||
>>> p
|
||||
<blank plot>
|
||||
>>> p[1] = x**2+y**2
|
||||
>>> p[1].style = 'solid'
|
||||
>>> p[2] = -x**2-y**2
|
||||
>>> p[2].style = 'wireframe'
|
||||
>>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4)
|
||||
>>> p[1].style = 'both'
|
||||
>>> p[2].style = 'both'
|
||||
>>> p.close()
|
||||
|
||||
|
||||
Plot Window Keyboard Controls
|
||||
=============================
|
||||
|
||||
Screen Rotation:
|
||||
X,Y axis Arrow Keys, A,S,D,W, Numpad 4,6,8,2
|
||||
Z axis Q,E, Numpad 7,9
|
||||
|
||||
Model Rotation:
|
||||
Z axis Z,C, Numpad 1,3
|
||||
|
||||
Zoom: R,F, PgUp,PgDn, Numpad +,-
|
||||
|
||||
Reset Camera: X, Numpad 5
|
||||
|
||||
Camera Presets:
|
||||
XY F1
|
||||
XZ F2
|
||||
YZ F3
|
||||
Perspective F4
|
||||
|
||||
Sensitivity Modifier: SHIFT
|
||||
|
||||
Axes Toggle:
|
||||
Visible F5
|
||||
Colors F6
|
||||
|
||||
Close Window: ESCAPE
|
||||
|
||||
=============================
|
||||
|
||||
"""
|
||||
|
||||
@doctest_depends_on(modules=('pyglet',))
|
||||
def __init__(self, *fargs, **win_args):
|
||||
"""
|
||||
Positional Arguments
|
||||
====================
|
||||
|
||||
Any given positional arguments are used to
|
||||
initialize a plot function at index 1. In
|
||||
other words...
|
||||
|
||||
>>> from sympy.plotting.pygletplot import PygletPlot as Plot
|
||||
>>> from sympy.abc import x
|
||||
>>> p = Plot(x**2, visible=False)
|
||||
|
||||
...is equivalent to...
|
||||
|
||||
>>> p = Plot(visible=False)
|
||||
>>> p[1] = x**2
|
||||
|
||||
Note that in earlier versions of the plotting
|
||||
module, you were able to specify multiple
|
||||
functions in the initializer. This functionality
|
||||
has been dropped in favor of better automatic
|
||||
plot plot_mode detection.
|
||||
|
||||
|
||||
Named Arguments
|
||||
===============
|
||||
|
||||
axes
|
||||
An option string of the form
|
||||
"key1=value1; key2 = value2" which
|
||||
can use the following options:
|
||||
|
||||
style = ordinate
|
||||
none OR frame OR box OR ordinate
|
||||
|
||||
stride = 0.25
|
||||
val OR (val_x, val_y, val_z)
|
||||
|
||||
overlay = True (draw on top of plot)
|
||||
True OR False
|
||||
|
||||
colored = False (False uses Black,
|
||||
True uses colors
|
||||
R,G,B = X,Y,Z)
|
||||
True OR False
|
||||
|
||||
label_axes = False (display axis names
|
||||
at endpoints)
|
||||
True OR False
|
||||
|
||||
visible = True (show immediately
|
||||
True OR False
|
||||
|
||||
|
||||
The following named arguments are passed as
|
||||
arguments to window initialization:
|
||||
|
||||
antialiasing = True
|
||||
True OR False
|
||||
|
||||
ortho = False
|
||||
True OR False
|
||||
|
||||
invert_mouse_zoom = False
|
||||
True OR False
|
||||
|
||||
"""
|
||||
# Register the plot modes
|
||||
from . import plot_modes # noqa
|
||||
|
||||
self._win_args = win_args
|
||||
self._window = None
|
||||
|
||||
self._render_lock = RLock()
|
||||
|
||||
self._functions = {}
|
||||
self._pobjects = []
|
||||
self._screenshot = ScreenShot(self)
|
||||
|
||||
axe_options = parse_option_string(win_args.pop('axes', ''))
|
||||
self.axes = PlotAxes(**axe_options)
|
||||
self._pobjects.append(self.axes)
|
||||
|
||||
self[0] = fargs
|
||||
if win_args.get('visible', True):
|
||||
self.show()
|
||||
|
||||
## Window Interfaces
|
||||
|
||||
def show(self):
|
||||
"""
|
||||
Creates and displays a plot window, or activates it
|
||||
(gives it focus) if it has already been created.
|
||||
"""
|
||||
if self._window and not self._window.has_exit:
|
||||
self._window.activate()
|
||||
else:
|
||||
self._win_args['visible'] = True
|
||||
self.axes.reset_resources()
|
||||
|
||||
#if hasattr(self, '_doctest_depends_on'):
|
||||
# self._win_args['runfromdoctester'] = True
|
||||
|
||||
self._window = PlotWindow(self, **self._win_args)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Closes the plot window.
|
||||
"""
|
||||
if self._window:
|
||||
self._window.close()
|
||||
|
||||
def saveimage(self, outfile=None, format='', size=(600, 500)):
|
||||
"""
|
||||
Saves a screen capture of the plot window to an
|
||||
image file.
|
||||
|
||||
If outfile is given, it can either be a path
|
||||
or a file object. Otherwise a png image will
|
||||
be saved to the current working directory.
|
||||
If the format is omitted, it is determined from
|
||||
the filename extension.
|
||||
"""
|
||||
self._screenshot.save(outfile, format, size)
|
||||
|
||||
## Function List Interfaces
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Clears the function list of this plot.
|
||||
"""
|
||||
self._render_lock.acquire()
|
||||
self._functions = {}
|
||||
self.adjust_all_bounds()
|
||||
self._render_lock.release()
|
||||
|
||||
def __getitem__(self, i):
|
||||
"""
|
||||
Returns the function at position i in the
|
||||
function list.
|
||||
"""
|
||||
return self._functions[i]
|
||||
|
||||
def __setitem__(self, i, args):
|
||||
"""
|
||||
Parses and adds a PlotMode to the function
|
||||
list.
|
||||
"""
|
||||
if not (isinstance(i, (SYMPY_INTS, Integer)) and i >= 0):
|
||||
raise ValueError("Function index must "
|
||||
"be an integer >= 0.")
|
||||
|
||||
if isinstance(args, PlotObject):
|
||||
f = args
|
||||
else:
|
||||
if (not is_sequence(args)) or isinstance(args, GeometryEntity):
|
||||
args = [args]
|
||||
if len(args) == 0:
|
||||
return # no arguments given
|
||||
kwargs = {"bounds_callback": self.adjust_all_bounds}
|
||||
f = PlotMode(*args, **kwargs)
|
||||
|
||||
if f:
|
||||
self._render_lock.acquire()
|
||||
self._functions[i] = f
|
||||
self._render_lock.release()
|
||||
else:
|
||||
raise ValueError("Failed to parse '%s'."
|
||||
% ', '.join(str(a) for a in args))
|
||||
|
||||
def __delitem__(self, i):
|
||||
"""
|
||||
Removes the function in the function list at
|
||||
position i.
|
||||
"""
|
||||
self._render_lock.acquire()
|
||||
del self._functions[i]
|
||||
self.adjust_all_bounds()
|
||||
self._render_lock.release()
|
||||
|
||||
def firstavailableindex(self):
|
||||
"""
|
||||
Returns the first unused index in the function list.
|
||||
"""
|
||||
i = 0
|
||||
self._render_lock.acquire()
|
||||
while i in self._functions:
|
||||
i += 1
|
||||
self._render_lock.release()
|
||||
return i
|
||||
|
||||
def append(self, *args):
|
||||
"""
|
||||
Parses and adds a PlotMode to the function
|
||||
list at the first available index.
|
||||
"""
|
||||
self.__setitem__(self.firstavailableindex(), args)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Returns the number of functions in the function list.
|
||||
"""
|
||||
return len(self._functions)
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Allows iteration of the function list.
|
||||
"""
|
||||
return self._functions.itervalues()
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
Returns a string containing a new-line separated
|
||||
list of the functions in the function list.
|
||||
"""
|
||||
s = ""
|
||||
if len(self._functions) == 0:
|
||||
s += "<blank plot>"
|
||||
else:
|
||||
self._render_lock.acquire()
|
||||
s += "\n".join(["%s[%i]: %s" % ("", i, str(self._functions[i]))
|
||||
for i in self._functions])
|
||||
self._render_lock.release()
|
||||
return s
|
||||
|
||||
def adjust_all_bounds(self):
|
||||
self._render_lock.acquire()
|
||||
self.axes.reset_bounding_box()
|
||||
for f in self._functions:
|
||||
self.axes.adjust_bounds(self._functions[f].bounds)
|
||||
self._render_lock.release()
|
||||
|
||||
def wait_for_calculations(self):
|
||||
sleep(0)
|
||||
self._render_lock.acquire()
|
||||
for f in self._functions:
|
||||
a = self._functions[f]._get_calculating_verts
|
||||
b = self._functions[f]._get_calculating_cverts
|
||||
while a() or b():
|
||||
sleep(0)
|
||||
self._render_lock.release()
|
||||
|
||||
class ScreenShot:
|
||||
def __init__(self, plot):
|
||||
self._plot = plot
|
||||
self.screenshot_requested = False
|
||||
self.outfile = None
|
||||
self.format = ''
|
||||
self.invisibleMode = False
|
||||
self.flag = 0
|
||||
|
||||
def __bool__(self):
|
||||
return self.screenshot_requested
|
||||
|
||||
def _execute_saving(self):
|
||||
if self.flag < 3:
|
||||
self.flag += 1
|
||||
return
|
||||
|
||||
size_x, size_y = self._plot._window.get_size()
|
||||
size = size_x*size_y*4*ctypes.sizeof(ctypes.c_ubyte)
|
||||
image = ctypes.create_string_buffer(size)
|
||||
pgl.glReadPixels(0, 0, size_x, size_y, pgl.GL_RGBA, pgl.GL_UNSIGNED_BYTE, image)
|
||||
from PIL import Image
|
||||
im = Image.frombuffer('RGBA', (size_x, size_y),
|
||||
image.raw, 'raw', 'RGBA', 0, 1)
|
||||
im.transpose(Image.FLIP_TOP_BOTTOM).save(self.outfile, self.format)
|
||||
|
||||
self.flag = 0
|
||||
self.screenshot_requested = False
|
||||
if self.invisibleMode:
|
||||
self._plot._window.close()
|
||||
|
||||
def save(self, outfile=None, format='', size=(600, 500)):
|
||||
self.outfile = outfile
|
||||
self.format = format
|
||||
self.size = size
|
||||
self.screenshot_requested = True
|
||||
|
||||
if not self._plot._window or self._plot._window.has_exit:
|
||||
self._plot._win_args['visible'] = False
|
||||
|
||||
self._plot._win_args['width'] = size[0]
|
||||
self._plot._win_args['height'] = size[1]
|
||||
|
||||
self._plot.axes.reset_resources()
|
||||
self._plot._window = PlotWindow(self._plot, **self._plot._win_args)
|
||||
self.invisibleMode = True
|
||||
|
||||
if self.outfile is None:
|
||||
self.outfile = self._create_unique_path()
|
||||
print(self.outfile)
|
||||
|
||||
def _create_unique_path(self):
|
||||
cwd = getcwd()
|
||||
l = listdir(cwd)
|
||||
path = ''
|
||||
i = 0
|
||||
while True:
|
||||
if not 'plot_%s.png' % i in l:
|
||||
path = cwd + '/plot_%s.png' % i
|
||||
break
|
||||
i += 1
|
||||
return path
|
||||
@@ -0,0 +1,251 @@
|
||||
import pyglet.gl as pgl
|
||||
from pyglet import font
|
||||
|
||||
from sympy.core import S
|
||||
from sympy.plotting.pygletplot.plot_object import PlotObject
|
||||
from sympy.plotting.pygletplot.util import billboard_matrix, dot_product, \
|
||||
get_direction_vectors, strided_range, vec_mag, vec_sub
|
||||
from sympy.utilities.iterables import is_sequence
|
||||
|
||||
|
||||
class PlotAxes(PlotObject):
|
||||
|
||||
def __init__(self, *args,
|
||||
style='', none=None, frame=None, box=None, ordinate=None,
|
||||
stride=0.25,
|
||||
visible='', overlay='', colored='', label_axes='', label_ticks='',
|
||||
tick_length=0.1,
|
||||
font_face='Arial', font_size=28,
|
||||
**kwargs):
|
||||
# initialize style parameter
|
||||
style = style.lower()
|
||||
|
||||
# allow alias kwargs to override style kwarg
|
||||
if none is not None:
|
||||
style = 'none'
|
||||
if frame is not None:
|
||||
style = 'frame'
|
||||
if box is not None:
|
||||
style = 'box'
|
||||
if ordinate is not None:
|
||||
style = 'ordinate'
|
||||
|
||||
if style in ['', 'ordinate']:
|
||||
self._render_object = PlotAxesOrdinate(self)
|
||||
elif style in ['frame', 'box']:
|
||||
self._render_object = PlotAxesFrame(self)
|
||||
elif style in ['none']:
|
||||
self._render_object = None
|
||||
else:
|
||||
raise ValueError(("Unrecognized axes style %s.") % (style))
|
||||
|
||||
# initialize stride parameter
|
||||
try:
|
||||
stride = eval(stride)
|
||||
except TypeError:
|
||||
pass
|
||||
if is_sequence(stride):
|
||||
if len(stride) != 3:
|
||||
raise ValueError("length should be equal to 3")
|
||||
self._stride = stride
|
||||
else:
|
||||
self._stride = [stride, stride, stride]
|
||||
self._tick_length = float(tick_length)
|
||||
|
||||
# setup bounding box and ticks
|
||||
self._origin = [0, 0, 0]
|
||||
self.reset_bounding_box()
|
||||
|
||||
def flexible_boolean(input, default):
|
||||
if input in [True, False]:
|
||||
return input
|
||||
if input in ('f', 'F', 'false', 'False'):
|
||||
return False
|
||||
if input in ('t', 'T', 'true', 'True'):
|
||||
return True
|
||||
return default
|
||||
|
||||
# initialize remaining parameters
|
||||
self.visible = flexible_boolean(kwargs, True)
|
||||
self._overlay = flexible_boolean(overlay, True)
|
||||
self._colored = flexible_boolean(colored, False)
|
||||
self._label_axes = flexible_boolean(label_axes, False)
|
||||
self._label_ticks = flexible_boolean(label_ticks, True)
|
||||
|
||||
# setup label font
|
||||
self.font_face = font_face
|
||||
self.font_size = font_size
|
||||
|
||||
# this is also used to reinit the
|
||||
# font on window close/reopen
|
||||
self.reset_resources()
|
||||
|
||||
def reset_resources(self):
|
||||
self.label_font = None
|
||||
|
||||
def reset_bounding_box(self):
|
||||
self._bounding_box = [[None, None], [None, None], [None, None]]
|
||||
self._axis_ticks = [[], [], []]
|
||||
|
||||
def draw(self):
|
||||
if self._render_object:
|
||||
pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT | pgl.GL_DEPTH_BUFFER_BIT)
|
||||
if self._overlay:
|
||||
pgl.glDisable(pgl.GL_DEPTH_TEST)
|
||||
self._render_object.draw()
|
||||
pgl.glPopAttrib()
|
||||
|
||||
def adjust_bounds(self, child_bounds):
|
||||
b = self._bounding_box
|
||||
c = child_bounds
|
||||
for i in range(3):
|
||||
if abs(c[i][0]) is S.Infinity or abs(c[i][1]) is S.Infinity:
|
||||
continue
|
||||
b[i][0] = c[i][0] if b[i][0] is None else min([b[i][0], c[i][0]])
|
||||
b[i][1] = c[i][1] if b[i][1] is None else max([b[i][1], c[i][1]])
|
||||
self._bounding_box = b
|
||||
self._recalculate_axis_ticks(i)
|
||||
|
||||
def _recalculate_axis_ticks(self, axis):
|
||||
b = self._bounding_box
|
||||
if b[axis][0] is None or b[axis][1] is None:
|
||||
self._axis_ticks[axis] = []
|
||||
else:
|
||||
self._axis_ticks[axis] = strided_range(b[axis][0], b[axis][1],
|
||||
self._stride[axis])
|
||||
|
||||
def toggle_visible(self):
|
||||
self.visible = not self.visible
|
||||
|
||||
def toggle_colors(self):
|
||||
self._colored = not self._colored
|
||||
|
||||
|
||||
class PlotAxesBase(PlotObject):
|
||||
|
||||
def __init__(self, parent_axes):
|
||||
self._p = parent_axes
|
||||
|
||||
def draw(self):
|
||||
color = [([0.2, 0.1, 0.3], [0.2, 0.1, 0.3], [0.2, 0.1, 0.3]),
|
||||
([0.9, 0.3, 0.5], [0.5, 1.0, 0.5], [0.3, 0.3, 0.9])][self._p._colored]
|
||||
self.draw_background(color)
|
||||
self.draw_axis(2, color[2])
|
||||
self.draw_axis(1, color[1])
|
||||
self.draw_axis(0, color[0])
|
||||
|
||||
def draw_background(self, color):
|
||||
pass # optional
|
||||
|
||||
def draw_axis(self, axis, color):
|
||||
raise NotImplementedError()
|
||||
|
||||
def draw_text(self, text, position, color, scale=1.0):
|
||||
if len(color) == 3:
|
||||
color = (color[0], color[1], color[2], 1.0)
|
||||
|
||||
if self._p.label_font is None:
|
||||
self._p.label_font = font.load(self._p.font_face,
|
||||
self._p.font_size,
|
||||
bold=True, italic=False)
|
||||
|
||||
label = font.Text(self._p.label_font, text,
|
||||
color=color,
|
||||
valign=font.Text.BASELINE,
|
||||
halign=font.Text.CENTER)
|
||||
|
||||
pgl.glPushMatrix()
|
||||
pgl.glTranslatef(*position)
|
||||
billboard_matrix()
|
||||
scale_factor = 0.005 * scale
|
||||
pgl.glScalef(scale_factor, scale_factor, scale_factor)
|
||||
pgl.glColor4f(0, 0, 0, 0)
|
||||
label.draw()
|
||||
pgl.glPopMatrix()
|
||||
|
||||
def draw_line(self, v, color):
|
||||
o = self._p._origin
|
||||
pgl.glBegin(pgl.GL_LINES)
|
||||
pgl.glColor3f(*color)
|
||||
pgl.glVertex3f(v[0][0] + o[0], v[0][1] + o[1], v[0][2] + o[2])
|
||||
pgl.glVertex3f(v[1][0] + o[0], v[1][1] + o[1], v[1][2] + o[2])
|
||||
pgl.glEnd()
|
||||
|
||||
|
||||
class PlotAxesOrdinate(PlotAxesBase):
|
||||
|
||||
def __init__(self, parent_axes):
|
||||
super().__init__(parent_axes)
|
||||
|
||||
def draw_axis(self, axis, color):
|
||||
ticks = self._p._axis_ticks[axis]
|
||||
radius = self._p._tick_length / 2.0
|
||||
if len(ticks) < 2:
|
||||
return
|
||||
|
||||
# calculate the vector for this axis
|
||||
axis_lines = [[0, 0, 0], [0, 0, 0]]
|
||||
axis_lines[0][axis], axis_lines[1][axis] = ticks[0], ticks[-1]
|
||||
axis_vector = vec_sub(axis_lines[1], axis_lines[0])
|
||||
|
||||
# calculate angle to the z direction vector
|
||||
pos_z = get_direction_vectors()[2]
|
||||
d = abs(dot_product(axis_vector, pos_z))
|
||||
d = d / vec_mag(axis_vector)
|
||||
|
||||
# don't draw labels if we're looking down the axis
|
||||
labels_visible = abs(d - 1.0) > 0.02
|
||||
|
||||
# draw the ticks and labels
|
||||
for tick in ticks:
|
||||
self.draw_tick_line(axis, color, radius, tick, labels_visible)
|
||||
|
||||
# draw the axis line and labels
|
||||
self.draw_axis_line(axis, color, ticks[0], ticks[-1], labels_visible)
|
||||
|
||||
def draw_axis_line(self, axis, color, a_min, a_max, labels_visible):
|
||||
axis_line = [[0, 0, 0], [0, 0, 0]]
|
||||
axis_line[0][axis], axis_line[1][axis] = a_min, a_max
|
||||
self.draw_line(axis_line, color)
|
||||
if labels_visible:
|
||||
self.draw_axis_line_labels(axis, color, axis_line)
|
||||
|
||||
def draw_axis_line_labels(self, axis, color, axis_line):
|
||||
if not self._p._label_axes:
|
||||
return
|
||||
axis_labels = [axis_line[0][::], axis_line[1][::]]
|
||||
axis_labels[0][axis] -= 0.3
|
||||
axis_labels[1][axis] += 0.3
|
||||
a_str = ['X', 'Y', 'Z'][axis]
|
||||
self.draw_text("-" + a_str, axis_labels[0], color)
|
||||
self.draw_text("+" + a_str, axis_labels[1], color)
|
||||
|
||||
def draw_tick_line(self, axis, color, radius, tick, labels_visible):
|
||||
tick_axis = {0: 1, 1: 0, 2: 1}[axis]
|
||||
tick_line = [[0, 0, 0], [0, 0, 0]]
|
||||
tick_line[0][axis] = tick_line[1][axis] = tick
|
||||
tick_line[0][tick_axis], tick_line[1][tick_axis] = -radius, radius
|
||||
self.draw_line(tick_line, color)
|
||||
if labels_visible:
|
||||
self.draw_tick_line_label(axis, color, radius, tick)
|
||||
|
||||
def draw_tick_line_label(self, axis, color, radius, tick):
|
||||
if not self._p._label_axes:
|
||||
return
|
||||
tick_label_vector = [0, 0, 0]
|
||||
tick_label_vector[axis] = tick
|
||||
tick_label_vector[{0: 1, 1: 0, 2: 1}[axis]] = [-1, 1, 1][
|
||||
axis] * radius * 3.5
|
||||
self.draw_text(str(tick), tick_label_vector, color, scale=0.5)
|
||||
|
||||
|
||||
class PlotAxesFrame(PlotAxesBase):
|
||||
|
||||
def __init__(self, parent_axes):
|
||||
super().__init__(parent_axes)
|
||||
|
||||
def draw_background(self, color):
|
||||
pass
|
||||
|
||||
def draw_axis(self, axis, color):
|
||||
raise NotImplementedError()
|
||||
@@ -0,0 +1,124 @@
|
||||
import pyglet.gl as pgl
|
||||
from sympy.plotting.pygletplot.plot_rotation import get_spherical_rotatation
|
||||
from sympy.plotting.pygletplot.util import get_model_matrix, model_to_screen, \
|
||||
screen_to_model, vec_subs
|
||||
|
||||
|
||||
class PlotCamera:
|
||||
|
||||
min_dist = 0.05
|
||||
max_dist = 500.0
|
||||
|
||||
min_ortho_dist = 100.0
|
||||
max_ortho_dist = 10000.0
|
||||
|
||||
_default_dist = 6.0
|
||||
_default_ortho_dist = 600.0
|
||||
|
||||
rot_presets = {
|
||||
'xy': (0, 0, 0),
|
||||
'xz': (-90, 0, 0),
|
||||
'yz': (0, 90, 0),
|
||||
'perspective': (-45, 0, -45)
|
||||
}
|
||||
|
||||
def __init__(self, window, ortho=False):
|
||||
self.window = window
|
||||
self.axes = self.window.plot.axes
|
||||
self.ortho = ortho
|
||||
self.reset()
|
||||
|
||||
def init_rot_matrix(self):
|
||||
pgl.glPushMatrix()
|
||||
pgl.glLoadIdentity()
|
||||
self._rot = get_model_matrix()
|
||||
pgl.glPopMatrix()
|
||||
|
||||
def set_rot_preset(self, preset_name):
|
||||
self.init_rot_matrix()
|
||||
if preset_name not in self.rot_presets:
|
||||
raise ValueError(
|
||||
"%s is not a valid rotation preset." % preset_name)
|
||||
r = self.rot_presets[preset_name]
|
||||
self.euler_rotate(r[0], 1, 0, 0)
|
||||
self.euler_rotate(r[1], 0, 1, 0)
|
||||
self.euler_rotate(r[2], 0, 0, 1)
|
||||
|
||||
def reset(self):
|
||||
self._dist = 0.0
|
||||
self._x, self._y = 0.0, 0.0
|
||||
self._rot = None
|
||||
if self.ortho:
|
||||
self._dist = self._default_ortho_dist
|
||||
else:
|
||||
self._dist = self._default_dist
|
||||
self.init_rot_matrix()
|
||||
|
||||
def mult_rot_matrix(self, rot):
|
||||
pgl.glPushMatrix()
|
||||
pgl.glLoadMatrixf(rot)
|
||||
pgl.glMultMatrixf(self._rot)
|
||||
self._rot = get_model_matrix()
|
||||
pgl.glPopMatrix()
|
||||
|
||||
def setup_projection(self):
|
||||
pgl.glMatrixMode(pgl.GL_PROJECTION)
|
||||
pgl.glLoadIdentity()
|
||||
if self.ortho:
|
||||
# yep, this is pseudo ortho (don't tell anyone)
|
||||
pgl.gluPerspective(
|
||||
0.3, float(self.window.width)/float(self.window.height),
|
||||
self.min_ortho_dist - 0.01, self.max_ortho_dist + 0.01)
|
||||
else:
|
||||
pgl.gluPerspective(
|
||||
30.0, float(self.window.width)/float(self.window.height),
|
||||
self.min_dist - 0.01, self.max_dist + 0.01)
|
||||
pgl.glMatrixMode(pgl.GL_MODELVIEW)
|
||||
|
||||
def _get_scale(self):
|
||||
return 1.0, 1.0, 1.0
|
||||
|
||||
def apply_transformation(self):
|
||||
pgl.glLoadIdentity()
|
||||
pgl.glTranslatef(self._x, self._y, -self._dist)
|
||||
if self._rot is not None:
|
||||
pgl.glMultMatrixf(self._rot)
|
||||
pgl.glScalef(*self._get_scale())
|
||||
|
||||
def spherical_rotate(self, p1, p2, sensitivity=1.0):
|
||||
mat = get_spherical_rotatation(p1, p2, self.window.width,
|
||||
self.window.height, sensitivity)
|
||||
if mat is not None:
|
||||
self.mult_rot_matrix(mat)
|
||||
|
||||
def euler_rotate(self, angle, x, y, z):
|
||||
pgl.glPushMatrix()
|
||||
pgl.glLoadMatrixf(self._rot)
|
||||
pgl.glRotatef(angle, x, y, z)
|
||||
self._rot = get_model_matrix()
|
||||
pgl.glPopMatrix()
|
||||
|
||||
def zoom_relative(self, clicks, sensitivity):
|
||||
|
||||
if self.ortho:
|
||||
dist_d = clicks * sensitivity * 50.0
|
||||
min_dist = self.min_ortho_dist
|
||||
max_dist = self.max_ortho_dist
|
||||
else:
|
||||
dist_d = clicks * sensitivity
|
||||
min_dist = self.min_dist
|
||||
max_dist = self.max_dist
|
||||
|
||||
new_dist = (self._dist - dist_d)
|
||||
if (clicks < 0 and new_dist < max_dist) or new_dist > min_dist:
|
||||
self._dist = new_dist
|
||||
|
||||
def mouse_translate(self, x, y, dx, dy):
|
||||
pgl.glPushMatrix()
|
||||
pgl.glLoadIdentity()
|
||||
pgl.glTranslatef(0, 0, -self._dist)
|
||||
z = model_to_screen(0, 0, 0)[2]
|
||||
d = vec_subs(screen_to_model(x, y, z), screen_to_model(x - dx, y - dy, z))
|
||||
pgl.glPopMatrix()
|
||||
self._x += d[0]
|
||||
self._y += d[1]
|
||||
@@ -0,0 +1,218 @@
|
||||
from pyglet.window import key
|
||||
from pyglet.window.mouse import LEFT, RIGHT, MIDDLE
|
||||
from sympy.plotting.pygletplot.util import get_direction_vectors, get_basis_vectors
|
||||
|
||||
|
||||
class PlotController:
|
||||
|
||||
normal_mouse_sensitivity = 4.0
|
||||
modified_mouse_sensitivity = 1.0
|
||||
|
||||
normal_key_sensitivity = 160.0
|
||||
modified_key_sensitivity = 40.0
|
||||
|
||||
keymap = {
|
||||
key.LEFT: 'left',
|
||||
key.A: 'left',
|
||||
key.NUM_4: 'left',
|
||||
|
||||
key.RIGHT: 'right',
|
||||
key.D: 'right',
|
||||
key.NUM_6: 'right',
|
||||
|
||||
key.UP: 'up',
|
||||
key.W: 'up',
|
||||
key.NUM_8: 'up',
|
||||
|
||||
key.DOWN: 'down',
|
||||
key.S: 'down',
|
||||
key.NUM_2: 'down',
|
||||
|
||||
key.Z: 'rotate_z_neg',
|
||||
key.NUM_1: 'rotate_z_neg',
|
||||
|
||||
key.C: 'rotate_z_pos',
|
||||
key.NUM_3: 'rotate_z_pos',
|
||||
|
||||
key.Q: 'spin_left',
|
||||
key.NUM_7: 'spin_left',
|
||||
key.E: 'spin_right',
|
||||
key.NUM_9: 'spin_right',
|
||||
|
||||
key.X: 'reset_camera',
|
||||
key.NUM_5: 'reset_camera',
|
||||
|
||||
key.NUM_ADD: 'zoom_in',
|
||||
key.PAGEUP: 'zoom_in',
|
||||
key.R: 'zoom_in',
|
||||
|
||||
key.NUM_SUBTRACT: 'zoom_out',
|
||||
key.PAGEDOWN: 'zoom_out',
|
||||
key.F: 'zoom_out',
|
||||
|
||||
key.RSHIFT: 'modify_sensitivity',
|
||||
key.LSHIFT: 'modify_sensitivity',
|
||||
|
||||
key.F1: 'rot_preset_xy',
|
||||
key.F2: 'rot_preset_xz',
|
||||
key.F3: 'rot_preset_yz',
|
||||
key.F4: 'rot_preset_perspective',
|
||||
|
||||
key.F5: 'toggle_axes',
|
||||
key.F6: 'toggle_axe_colors',
|
||||
|
||||
key.F8: 'save_image'
|
||||
}
|
||||
|
||||
def __init__(self, window, *, invert_mouse_zoom=False, **kwargs):
|
||||
self.invert_mouse_zoom = invert_mouse_zoom
|
||||
self.window = window
|
||||
self.camera = window.camera
|
||||
self.action = {
|
||||
# Rotation around the view Y (up) vector
|
||||
'left': False,
|
||||
'right': False,
|
||||
# Rotation around the view X vector
|
||||
'up': False,
|
||||
'down': False,
|
||||
# Rotation around the view Z vector
|
||||
'spin_left': False,
|
||||
'spin_right': False,
|
||||
# Rotation around the model Z vector
|
||||
'rotate_z_neg': False,
|
||||
'rotate_z_pos': False,
|
||||
# Reset to the default rotation
|
||||
'reset_camera': False,
|
||||
# Performs camera z-translation
|
||||
'zoom_in': False,
|
||||
'zoom_out': False,
|
||||
# Use alternative sensitivity (speed)
|
||||
'modify_sensitivity': False,
|
||||
# Rotation presets
|
||||
'rot_preset_xy': False,
|
||||
'rot_preset_xz': False,
|
||||
'rot_preset_yz': False,
|
||||
'rot_preset_perspective': False,
|
||||
# axes
|
||||
'toggle_axes': False,
|
||||
'toggle_axe_colors': False,
|
||||
# screenshot
|
||||
'save_image': False
|
||||
}
|
||||
|
||||
def update(self, dt):
|
||||
z = 0
|
||||
if self.action['zoom_out']:
|
||||
z -= 1
|
||||
if self.action['zoom_in']:
|
||||
z += 1
|
||||
if z != 0:
|
||||
self.camera.zoom_relative(z/10.0, self.get_key_sensitivity()/10.0)
|
||||
|
||||
dx, dy, dz = 0, 0, 0
|
||||
if self.action['left']:
|
||||
dx -= 1
|
||||
if self.action['right']:
|
||||
dx += 1
|
||||
if self.action['up']:
|
||||
dy -= 1
|
||||
if self.action['down']:
|
||||
dy += 1
|
||||
if self.action['spin_left']:
|
||||
dz += 1
|
||||
if self.action['spin_right']:
|
||||
dz -= 1
|
||||
|
||||
if not self.is_2D():
|
||||
if dx != 0:
|
||||
self.camera.euler_rotate(dx*dt*self.get_key_sensitivity(),
|
||||
*(get_direction_vectors()[1]))
|
||||
if dy != 0:
|
||||
self.camera.euler_rotate(dy*dt*self.get_key_sensitivity(),
|
||||
*(get_direction_vectors()[0]))
|
||||
if dz != 0:
|
||||
self.camera.euler_rotate(dz*dt*self.get_key_sensitivity(),
|
||||
*(get_direction_vectors()[2]))
|
||||
else:
|
||||
self.camera.mouse_translate(0, 0, dx*dt*self.get_key_sensitivity(),
|
||||
-dy*dt*self.get_key_sensitivity())
|
||||
|
||||
rz = 0
|
||||
if self.action['rotate_z_neg'] and not self.is_2D():
|
||||
rz -= 1
|
||||
if self.action['rotate_z_pos'] and not self.is_2D():
|
||||
rz += 1
|
||||
|
||||
if rz != 0:
|
||||
self.camera.euler_rotate(rz*dt*self.get_key_sensitivity(),
|
||||
*(get_basis_vectors()[2]))
|
||||
|
||||
if self.action['reset_camera']:
|
||||
self.camera.reset()
|
||||
|
||||
if self.action['rot_preset_xy']:
|
||||
self.camera.set_rot_preset('xy')
|
||||
if self.action['rot_preset_xz']:
|
||||
self.camera.set_rot_preset('xz')
|
||||
if self.action['rot_preset_yz']:
|
||||
self.camera.set_rot_preset('yz')
|
||||
if self.action['rot_preset_perspective']:
|
||||
self.camera.set_rot_preset('perspective')
|
||||
|
||||
if self.action['toggle_axes']:
|
||||
self.action['toggle_axes'] = False
|
||||
self.camera.axes.toggle_visible()
|
||||
|
||||
if self.action['toggle_axe_colors']:
|
||||
self.action['toggle_axe_colors'] = False
|
||||
self.camera.axes.toggle_colors()
|
||||
|
||||
if self.action['save_image']:
|
||||
self.action['save_image'] = False
|
||||
self.window.plot.saveimage()
|
||||
|
||||
return True
|
||||
|
||||
def get_mouse_sensitivity(self):
|
||||
if self.action['modify_sensitivity']:
|
||||
return self.modified_mouse_sensitivity
|
||||
else:
|
||||
return self.normal_mouse_sensitivity
|
||||
|
||||
def get_key_sensitivity(self):
|
||||
if self.action['modify_sensitivity']:
|
||||
return self.modified_key_sensitivity
|
||||
else:
|
||||
return self.normal_key_sensitivity
|
||||
|
||||
def on_key_press(self, symbol, modifiers):
|
||||
if symbol in self.keymap:
|
||||
self.action[self.keymap[symbol]] = True
|
||||
|
||||
def on_key_release(self, symbol, modifiers):
|
||||
if symbol in self.keymap:
|
||||
self.action[self.keymap[symbol]] = False
|
||||
|
||||
def on_mouse_drag(self, x, y, dx, dy, buttons, modifiers):
|
||||
if buttons & LEFT:
|
||||
if self.is_2D():
|
||||
self.camera.mouse_translate(x, y, dx, dy)
|
||||
else:
|
||||
self.camera.spherical_rotate((x - dx, y - dy), (x, y),
|
||||
self.get_mouse_sensitivity())
|
||||
if buttons & MIDDLE:
|
||||
self.camera.zoom_relative([1, -1][self.invert_mouse_zoom]*dy,
|
||||
self.get_mouse_sensitivity()/20.0)
|
||||
if buttons & RIGHT:
|
||||
self.camera.mouse_translate(x, y, dx, dy)
|
||||
|
||||
def on_mouse_scroll(self, x, y, dx, dy):
|
||||
self.camera.zoom_relative([1, -1][self.invert_mouse_zoom]*dy,
|
||||
self.get_mouse_sensitivity())
|
||||
|
||||
def is_2D(self):
|
||||
functions = self.window.plot._functions
|
||||
for i in functions:
|
||||
if len(functions[i].i_vars) > 1 or len(functions[i].d_vars) > 2:
|
||||
return False
|
||||
return True
|
||||
@@ -0,0 +1,82 @@
|
||||
import pyglet.gl as pgl
|
||||
from sympy.core import S
|
||||
from sympy.plotting.pygletplot.plot_mode_base import PlotModeBase
|
||||
|
||||
|
||||
class PlotCurve(PlotModeBase):
|
||||
|
||||
style_override = 'wireframe'
|
||||
|
||||
def _on_calculate_verts(self):
|
||||
self.t_interval = self.intervals[0]
|
||||
self.t_set = list(self.t_interval.frange())
|
||||
self.bounds = [[S.Infinity, S.NegativeInfinity, 0],
|
||||
[S.Infinity, S.NegativeInfinity, 0],
|
||||
[S.Infinity, S.NegativeInfinity, 0]]
|
||||
evaluate = self._get_evaluator()
|
||||
|
||||
self._calculating_verts_pos = 0.0
|
||||
self._calculating_verts_len = float(self.t_interval.v_len)
|
||||
|
||||
self.verts = []
|
||||
b = self.bounds
|
||||
for t in self.t_set:
|
||||
try:
|
||||
_e = evaluate(t) # calculate vertex
|
||||
except (NameError, ZeroDivisionError):
|
||||
_e = None
|
||||
if _e is not None: # update bounding box
|
||||
for axis in range(3):
|
||||
b[axis][0] = min([b[axis][0], _e[axis]])
|
||||
b[axis][1] = max([b[axis][1], _e[axis]])
|
||||
self.verts.append(_e)
|
||||
self._calculating_verts_pos += 1.0
|
||||
|
||||
for axis in range(3):
|
||||
b[axis][2] = b[axis][1] - b[axis][0]
|
||||
if b[axis][2] == 0.0:
|
||||
b[axis][2] = 1.0
|
||||
|
||||
self.push_wireframe(self.draw_verts(False))
|
||||
|
||||
def _on_calculate_cverts(self):
|
||||
if not self.verts or not self.color:
|
||||
return
|
||||
|
||||
def set_work_len(n):
|
||||
self._calculating_cverts_len = float(n)
|
||||
|
||||
def inc_work_pos():
|
||||
self._calculating_cverts_pos += 1.0
|
||||
set_work_len(1)
|
||||
self._calculating_cverts_pos = 0
|
||||
self.cverts = self.color.apply_to_curve(self.verts,
|
||||
self.t_set,
|
||||
set_len=set_work_len,
|
||||
inc_pos=inc_work_pos)
|
||||
self.push_wireframe(self.draw_verts(True))
|
||||
|
||||
def calculate_one_cvert(self, t):
|
||||
vert = self.verts[t]
|
||||
return self.color(vert[0], vert[1], vert[2],
|
||||
self.t_set[t], None)
|
||||
|
||||
def draw_verts(self, use_cverts):
|
||||
def f():
|
||||
pgl.glBegin(pgl.GL_LINE_STRIP)
|
||||
for t in range(len(self.t_set)):
|
||||
p = self.verts[t]
|
||||
if p is None:
|
||||
pgl.glEnd()
|
||||
pgl.glBegin(pgl.GL_LINE_STRIP)
|
||||
continue
|
||||
if use_cverts:
|
||||
c = self.cverts[t]
|
||||
if c is None:
|
||||
c = (0, 0, 0)
|
||||
pgl.glColor3f(*c)
|
||||
else:
|
||||
pgl.glColor3f(*self.default_wireframe_color)
|
||||
pgl.glVertex3f(*p)
|
||||
pgl.glEnd()
|
||||
return f
|
||||
@@ -0,0 +1,181 @@
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.core.numbers import Integer
|
||||
|
||||
|
||||
class PlotInterval:
|
||||
"""
|
||||
"""
|
||||
_v, _v_min, _v_max, _v_steps = None, None, None, None
|
||||
|
||||
def require_all_args(f):
|
||||
def check(self, *args, **kwargs):
|
||||
for g in [self._v, self._v_min, self._v_max, self._v_steps]:
|
||||
if g is None:
|
||||
raise ValueError("PlotInterval is incomplete.")
|
||||
return f(self, *args, **kwargs)
|
||||
return check
|
||||
|
||||
def __init__(self, *args):
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], PlotInterval):
|
||||
self.fill_from(args[0])
|
||||
return
|
||||
elif isinstance(args[0], str):
|
||||
try:
|
||||
args = eval(args[0])
|
||||
except TypeError:
|
||||
s_eval_error = "Could not interpret string %s."
|
||||
raise ValueError(s_eval_error % (args[0]))
|
||||
elif isinstance(args[0], (tuple, list)):
|
||||
args = args[0]
|
||||
else:
|
||||
raise ValueError("Not an interval.")
|
||||
if not isinstance(args, (tuple, list)) or len(args) > 4:
|
||||
f_error = "PlotInterval must be a tuple or list of length 4 or less."
|
||||
raise ValueError(f_error)
|
||||
|
||||
args = list(args)
|
||||
if len(args) > 0 and (args[0] is None or isinstance(args[0], Symbol)):
|
||||
self.v = args.pop(0)
|
||||
if len(args) in [2, 3]:
|
||||
self.v_min = args.pop(0)
|
||||
self.v_max = args.pop(0)
|
||||
if len(args) == 1:
|
||||
self.v_steps = args.pop(0)
|
||||
elif len(args) == 1:
|
||||
self.v_steps = args.pop(0)
|
||||
|
||||
def get_v(self):
|
||||
return self._v
|
||||
|
||||
def set_v(self, v):
|
||||
if v is None:
|
||||
self._v = None
|
||||
return
|
||||
if not isinstance(v, Symbol):
|
||||
raise ValueError("v must be a SymPy Symbol.")
|
||||
self._v = v
|
||||
|
||||
def get_v_min(self):
|
||||
return self._v_min
|
||||
|
||||
def set_v_min(self, v_min):
|
||||
if v_min is None:
|
||||
self._v_min = None
|
||||
return
|
||||
try:
|
||||
self._v_min = sympify(v_min)
|
||||
float(self._v_min.evalf())
|
||||
except TypeError:
|
||||
raise ValueError("v_min could not be interpreted as a number.")
|
||||
|
||||
def get_v_max(self):
|
||||
return self._v_max
|
||||
|
||||
def set_v_max(self, v_max):
|
||||
if v_max is None:
|
||||
self._v_max = None
|
||||
return
|
||||
try:
|
||||
self._v_max = sympify(v_max)
|
||||
float(self._v_max.evalf())
|
||||
except TypeError:
|
||||
raise ValueError("v_max could not be interpreted as a number.")
|
||||
|
||||
def get_v_steps(self):
|
||||
return self._v_steps
|
||||
|
||||
def set_v_steps(self, v_steps):
|
||||
if v_steps is None:
|
||||
self._v_steps = None
|
||||
return
|
||||
if isinstance(v_steps, int):
|
||||
v_steps = Integer(v_steps)
|
||||
elif not isinstance(v_steps, Integer):
|
||||
raise ValueError("v_steps must be an int or SymPy Integer.")
|
||||
if v_steps <= S.Zero:
|
||||
raise ValueError("v_steps must be positive.")
|
||||
self._v_steps = v_steps
|
||||
|
||||
@require_all_args
|
||||
def get_v_len(self):
|
||||
return self.v_steps + 1
|
||||
|
||||
v = property(get_v, set_v)
|
||||
v_min = property(get_v_min, set_v_min)
|
||||
v_max = property(get_v_max, set_v_max)
|
||||
v_steps = property(get_v_steps, set_v_steps)
|
||||
v_len = property(get_v_len)
|
||||
|
||||
def fill_from(self, b):
|
||||
if b.v is not None:
|
||||
self.v = b.v
|
||||
if b.v_min is not None:
|
||||
self.v_min = b.v_min
|
||||
if b.v_max is not None:
|
||||
self.v_max = b.v_max
|
||||
if b.v_steps is not None:
|
||||
self.v_steps = b.v_steps
|
||||
|
||||
@staticmethod
|
||||
def try_parse(*args):
|
||||
"""
|
||||
Returns a PlotInterval if args can be interpreted
|
||||
as such, otherwise None.
|
||||
"""
|
||||
if len(args) == 1 and isinstance(args[0], PlotInterval):
|
||||
return args[0]
|
||||
try:
|
||||
return PlotInterval(*args)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def _str_base(self):
|
||||
return ",".join([str(self.v), str(self.v_min),
|
||||
str(self.v_max), str(self.v_steps)])
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
A string representing the interval in class constructor form.
|
||||
"""
|
||||
return "PlotInterval(%s)" % (self._str_base())
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
A string representing the interval in list form.
|
||||
"""
|
||||
return "[%s]" % (self._str_base())
|
||||
|
||||
@require_all_args
|
||||
def assert_complete(self):
|
||||
pass
|
||||
|
||||
@require_all_args
|
||||
def vrange(self):
|
||||
"""
|
||||
Yields v_steps+1 SymPy numbers ranging from
|
||||
v_min to v_max.
|
||||
"""
|
||||
d = (self.v_max - self.v_min) / self.v_steps
|
||||
for i in range(self.v_steps + 1):
|
||||
a = self.v_min + (d * Integer(i))
|
||||
yield a
|
||||
|
||||
@require_all_args
|
||||
def vrange2(self):
|
||||
"""
|
||||
Yields v_steps pairs of SymPy numbers ranging from
|
||||
(v_min, v_min + step) to (v_max - step, v_max).
|
||||
"""
|
||||
d = (self.v_max - self.v_min) / self.v_steps
|
||||
a = self.v_min + (d * S.Zero)
|
||||
for i in range(self.v_steps):
|
||||
b = self.v_min + (d * Integer(i + 1))
|
||||
yield a, b
|
||||
a = b
|
||||
|
||||
def frange(self):
|
||||
for i in self.vrange():
|
||||
yield float(i.evalf())
|
||||
@@ -0,0 +1,400 @@
|
||||
from .plot_interval import PlotInterval
|
||||
from .plot_object import PlotObject
|
||||
from .util import parse_option_string
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.geometry.entity import GeometryEntity
|
||||
from sympy.utilities.iterables import is_sequence
|
||||
|
||||
|
||||
class PlotMode(PlotObject):
|
||||
"""
|
||||
Grandparent class for plotting
|
||||
modes. Serves as interface for
|
||||
registration, lookup, and init
|
||||
of modes.
|
||||
|
||||
To create a new plot mode,
|
||||
inherit from PlotModeBase
|
||||
or one of its children, such
|
||||
as PlotSurface or PlotCurve.
|
||||
"""
|
||||
|
||||
## Class-level attributes
|
||||
## used to register and lookup
|
||||
## plot modes. See PlotModeBase
|
||||
## for descriptions and usage.
|
||||
|
||||
i_vars, d_vars = '', ''
|
||||
intervals = []
|
||||
aliases = []
|
||||
is_default = False
|
||||
|
||||
## Draw is the only method here which
|
||||
## is meant to be overridden in child
|
||||
## classes, and PlotModeBase provides
|
||||
## a base implementation.
|
||||
def draw(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
## Everything else in this file has to
|
||||
## do with registration and retrieval
|
||||
## of plot modes. This is where I've
|
||||
## hidden much of the ugliness of automatic
|
||||
## plot mode divination...
|
||||
|
||||
## Plot mode registry data structures
|
||||
_mode_alias_list = []
|
||||
_mode_map = {
|
||||
1: {1: {}, 2: {}},
|
||||
2: {1: {}, 2: {}},
|
||||
3: {1: {}, 2: {}},
|
||||
} # [d][i][alias_str]: class
|
||||
_mode_default_map = {
|
||||
1: {},
|
||||
2: {},
|
||||
3: {},
|
||||
} # [d][i]: class
|
||||
_i_var_max, _d_var_max = 2, 3
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""
|
||||
This is the function which interprets
|
||||
arguments given to Plot.__init__ and
|
||||
Plot.__setattr__. Returns an initialized
|
||||
instance of the appropriate child class.
|
||||
"""
|
||||
|
||||
newargs, newkwargs = PlotMode._extract_options(args, kwargs)
|
||||
mode_arg = newkwargs.get('mode', '')
|
||||
|
||||
# Interpret the arguments
|
||||
d_vars, intervals = PlotMode._interpret_args(newargs)
|
||||
i_vars = PlotMode._find_i_vars(d_vars, intervals)
|
||||
i, d = max([len(i_vars), len(intervals)]), len(d_vars)
|
||||
|
||||
# Find the appropriate mode
|
||||
subcls = PlotMode._get_mode(mode_arg, i, d)
|
||||
|
||||
# Create the object
|
||||
o = object.__new__(subcls)
|
||||
|
||||
# Do some setup for the mode instance
|
||||
o.d_vars = d_vars
|
||||
o._fill_i_vars(i_vars)
|
||||
o._fill_intervals(intervals)
|
||||
o.options = newkwargs
|
||||
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def _get_mode(mode_arg, i_var_count, d_var_count):
|
||||
"""
|
||||
Tries to return an appropriate mode class.
|
||||
Intended to be called only by __new__.
|
||||
|
||||
mode_arg
|
||||
Can be a string or a class. If it is a
|
||||
PlotMode subclass, it is simply returned.
|
||||
If it is a string, it can an alias for
|
||||
a mode or an empty string. In the latter
|
||||
case, we try to find a default mode for
|
||||
the i_var_count and d_var_count.
|
||||
|
||||
i_var_count
|
||||
The number of independent variables
|
||||
needed to evaluate the d_vars.
|
||||
|
||||
d_var_count
|
||||
The number of dependent variables;
|
||||
usually the number of functions to
|
||||
be evaluated in plotting.
|
||||
|
||||
For example, a Cartesian function y = f(x) has
|
||||
one i_var (x) and one d_var (y). A parametric
|
||||
form x,y,z = f(u,v), f(u,v), f(u,v) has two
|
||||
two i_vars (u,v) and three d_vars (x,y,z).
|
||||
"""
|
||||
# if the mode_arg is simply a PlotMode class,
|
||||
# check that the mode supports the numbers
|
||||
# of independent and dependent vars, then
|
||||
# return it
|
||||
try:
|
||||
m = None
|
||||
if issubclass(mode_arg, PlotMode):
|
||||
m = mode_arg
|
||||
except TypeError:
|
||||
pass
|
||||
if m:
|
||||
if not m._was_initialized:
|
||||
raise ValueError(("To use unregistered plot mode %s "
|
||||
"you must first call %s._init_mode().")
|
||||
% (m.__name__, m.__name__))
|
||||
if d_var_count != m.d_var_count:
|
||||
raise ValueError(("%s can only plot functions "
|
||||
"with %i dependent variables.")
|
||||
% (m.__name__,
|
||||
m.d_var_count))
|
||||
if i_var_count > m.i_var_count:
|
||||
raise ValueError(("%s cannot plot functions "
|
||||
"with more than %i independent "
|
||||
"variables.")
|
||||
% (m.__name__,
|
||||
m.i_var_count))
|
||||
return m
|
||||
# If it is a string, there are two possibilities.
|
||||
if isinstance(mode_arg, str):
|
||||
i, d = i_var_count, d_var_count
|
||||
if i > PlotMode._i_var_max:
|
||||
raise ValueError(var_count_error(True, True))
|
||||
if d > PlotMode._d_var_max:
|
||||
raise ValueError(var_count_error(False, True))
|
||||
# If the string is '', try to find a suitable
|
||||
# default mode
|
||||
if not mode_arg:
|
||||
return PlotMode._get_default_mode(i, d)
|
||||
# Otherwise, interpret the string as a mode
|
||||
# alias (e.g. 'cartesian', 'parametric', etc)
|
||||
else:
|
||||
return PlotMode._get_aliased_mode(mode_arg, i, d)
|
||||
else:
|
||||
raise ValueError("PlotMode argument must be "
|
||||
"a class or a string")
|
||||
|
||||
@staticmethod
|
||||
def _get_default_mode(i, d, i_vars=-1):
|
||||
if i_vars == -1:
|
||||
i_vars = i
|
||||
try:
|
||||
return PlotMode._mode_default_map[d][i]
|
||||
except KeyError:
|
||||
# Keep looking for modes in higher i var counts
|
||||
# which support the given d var count until we
|
||||
# reach the max i_var count.
|
||||
if i < PlotMode._i_var_max:
|
||||
return PlotMode._get_default_mode(i + 1, d, i_vars)
|
||||
else:
|
||||
raise ValueError(("Couldn't find a default mode "
|
||||
"for %i independent and %i "
|
||||
"dependent variables.") % (i_vars, d))
|
||||
|
||||
@staticmethod
|
||||
def _get_aliased_mode(alias, i, d, i_vars=-1):
|
||||
if i_vars == -1:
|
||||
i_vars = i
|
||||
if alias not in PlotMode._mode_alias_list:
|
||||
raise ValueError(("Couldn't find a mode called"
|
||||
" %s. Known modes: %s.")
|
||||
% (alias, ", ".join(PlotMode._mode_alias_list)))
|
||||
try:
|
||||
return PlotMode._mode_map[d][i][alias]
|
||||
except TypeError:
|
||||
# Keep looking for modes in higher i var counts
|
||||
# which support the given d var count and alias
|
||||
# until we reach the max i_var count.
|
||||
if i < PlotMode._i_var_max:
|
||||
return PlotMode._get_aliased_mode(alias, i + 1, d, i_vars)
|
||||
else:
|
||||
raise ValueError(("Couldn't find a %s mode "
|
||||
"for %i independent and %i "
|
||||
"dependent variables.")
|
||||
% (alias, i_vars, d))
|
||||
|
||||
@classmethod
|
||||
def _register(cls):
|
||||
"""
|
||||
Called once for each user-usable plot mode.
|
||||
For Cartesian2D, it is invoked after the
|
||||
class definition: Cartesian2D._register()
|
||||
"""
|
||||
name = cls.__name__
|
||||
cls._init_mode()
|
||||
|
||||
try:
|
||||
i, d = cls.i_var_count, cls.d_var_count
|
||||
# Add the mode to _mode_map under all
|
||||
# given aliases
|
||||
for a in cls.aliases:
|
||||
if a not in PlotMode._mode_alias_list:
|
||||
# Also track valid aliases, so
|
||||
# we can quickly know when given
|
||||
# an invalid one in _get_mode.
|
||||
PlotMode._mode_alias_list.append(a)
|
||||
PlotMode._mode_map[d][i][a] = cls
|
||||
if cls.is_default:
|
||||
# If this mode was marked as the
|
||||
# default for this d,i combination,
|
||||
# also set that.
|
||||
PlotMode._mode_default_map[d][i] = cls
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(("Failed to register "
|
||||
"plot mode %s. Reason: %s")
|
||||
% (name, (str(e))))
|
||||
|
||||
@classmethod
|
||||
def _init_mode(cls):
|
||||
"""
|
||||
Initializes the plot mode based on
|
||||
the 'mode-specific parameters' above.
|
||||
Only intended to be called by
|
||||
PlotMode._register(). To use a mode without
|
||||
registering it, you can directly call
|
||||
ModeSubclass._init_mode().
|
||||
"""
|
||||
def symbols_list(symbol_str):
|
||||
return [Symbol(s) for s in symbol_str]
|
||||
|
||||
# Convert the vars strs into
|
||||
# lists of symbols.
|
||||
cls.i_vars = symbols_list(cls.i_vars)
|
||||
cls.d_vars = symbols_list(cls.d_vars)
|
||||
|
||||
# Var count is used often, calculate
|
||||
# it once here
|
||||
cls.i_var_count = len(cls.i_vars)
|
||||
cls.d_var_count = len(cls.d_vars)
|
||||
|
||||
if cls.i_var_count > PlotMode._i_var_max:
|
||||
raise ValueError(var_count_error(True, False))
|
||||
if cls.d_var_count > PlotMode._d_var_max:
|
||||
raise ValueError(var_count_error(False, False))
|
||||
|
||||
# Try to use first alias as primary_alias
|
||||
if len(cls.aliases) > 0:
|
||||
cls.primary_alias = cls.aliases[0]
|
||||
else:
|
||||
cls.primary_alias = cls.__name__
|
||||
|
||||
di = cls.intervals
|
||||
if len(di) != cls.i_var_count:
|
||||
raise ValueError("Plot mode must provide a "
|
||||
"default interval for each i_var.")
|
||||
for i in range(cls.i_var_count):
|
||||
# default intervals must be given [min,max,steps]
|
||||
# (no var, but they must be in the same order as i_vars)
|
||||
if len(di[i]) != 3:
|
||||
raise ValueError("length should be equal to 3")
|
||||
|
||||
# Initialize an incomplete interval,
|
||||
# to later be filled with a var when
|
||||
# the mode is instantiated.
|
||||
di[i] = PlotInterval(None, *di[i])
|
||||
|
||||
# To prevent people from using modes
|
||||
# without these required fields set up.
|
||||
cls._was_initialized = True
|
||||
|
||||
_was_initialized = False
|
||||
|
||||
## Initializer Helper Methods
|
||||
|
||||
@staticmethod
|
||||
def _find_i_vars(functions, intervals):
|
||||
i_vars = []
|
||||
|
||||
# First, collect i_vars in the
|
||||
# order they are given in any
|
||||
# intervals.
|
||||
for i in intervals:
|
||||
if i.v is None:
|
||||
continue
|
||||
elif i.v in i_vars:
|
||||
raise ValueError(("Multiple intervals given "
|
||||
"for %s.") % (str(i.v)))
|
||||
i_vars.append(i.v)
|
||||
|
||||
# Then, find any remaining
|
||||
# i_vars in given functions
|
||||
# (aka d_vars)
|
||||
for f in functions:
|
||||
for a in f.free_symbols:
|
||||
if a not in i_vars:
|
||||
i_vars.append(a)
|
||||
|
||||
return i_vars
|
||||
|
||||
def _fill_i_vars(self, i_vars):
|
||||
# copy default i_vars
|
||||
self.i_vars = [Symbol(str(i)) for i in self.i_vars]
|
||||
# replace with given i_vars
|
||||
for i in range(len(i_vars)):
|
||||
self.i_vars[i] = i_vars[i]
|
||||
|
||||
def _fill_intervals(self, intervals):
|
||||
# copy default intervals
|
||||
self.intervals = [PlotInterval(i) for i in self.intervals]
|
||||
# track i_vars used so far
|
||||
v_used = []
|
||||
# fill copy of default
|
||||
# intervals with given info
|
||||
for i in range(len(intervals)):
|
||||
self.intervals[i].fill_from(intervals[i])
|
||||
if self.intervals[i].v is not None:
|
||||
v_used.append(self.intervals[i].v)
|
||||
# Find any orphan intervals and
|
||||
# assign them i_vars
|
||||
for i in range(len(self.intervals)):
|
||||
if self.intervals[i].v is None:
|
||||
u = [v for v in self.i_vars if v not in v_used]
|
||||
if len(u) == 0:
|
||||
raise ValueError("length should not be equal to 0")
|
||||
self.intervals[i].v = u[0]
|
||||
v_used.append(u[0])
|
||||
|
||||
@staticmethod
|
||||
def _interpret_args(args):
|
||||
interval_wrong_order = "PlotInterval %s was given before any function(s)."
|
||||
interpret_error = "Could not interpret %s as a function or interval."
|
||||
|
||||
functions, intervals = [], []
|
||||
if isinstance(args[0], GeometryEntity):
|
||||
for coords in list(args[0].arbitrary_point()):
|
||||
functions.append(coords)
|
||||
intervals.append(PlotInterval.try_parse(args[0].plot_interval()))
|
||||
else:
|
||||
for a in args:
|
||||
i = PlotInterval.try_parse(a)
|
||||
if i is not None:
|
||||
if len(functions) == 0:
|
||||
raise ValueError(interval_wrong_order % (str(i)))
|
||||
else:
|
||||
intervals.append(i)
|
||||
else:
|
||||
if is_sequence(a, include=str):
|
||||
raise ValueError(interpret_error % (str(a)))
|
||||
try:
|
||||
f = sympify(a)
|
||||
functions.append(f)
|
||||
except TypeError:
|
||||
raise ValueError(interpret_error % str(a))
|
||||
|
||||
return functions, intervals
|
||||
|
||||
@staticmethod
|
||||
def _extract_options(args, kwargs):
|
||||
newkwargs, newargs = {}, []
|
||||
for a in args:
|
||||
if isinstance(a, str):
|
||||
newkwargs = dict(newkwargs, **parse_option_string(a))
|
||||
else:
|
||||
newargs.append(a)
|
||||
newkwargs = dict(newkwargs, **kwargs)
|
||||
return newargs, newkwargs
|
||||
|
||||
|
||||
def var_count_error(is_independent, is_plotting):
|
||||
"""
|
||||
Used to format an error message which differs
|
||||
slightly in 4 places.
|
||||
"""
|
||||
if is_plotting:
|
||||
v = "Plotting"
|
||||
else:
|
||||
v = "Registering plot modes"
|
||||
if is_independent:
|
||||
n, s = PlotMode._i_var_max, "independent"
|
||||
else:
|
||||
n, s = PlotMode._d_var_max, "dependent"
|
||||
return ("%s with more than %i %s variables "
|
||||
"is not supported.") % (v, n, s)
|
||||
@@ -0,0 +1,378 @@
|
||||
import pyglet.gl as pgl
|
||||
from sympy.core import S
|
||||
from sympy.plotting.pygletplot.color_scheme import ColorScheme
|
||||
from sympy.plotting.pygletplot.plot_mode import PlotMode
|
||||
from sympy.utilities.iterables import is_sequence
|
||||
from time import sleep
|
||||
from threading import Thread, Event, RLock
|
||||
import warnings
|
||||
|
||||
|
||||
class PlotModeBase(PlotMode):
|
||||
"""
|
||||
Intended parent class for plotting
|
||||
modes. Provides base functionality
|
||||
in conjunction with its parent,
|
||||
PlotMode.
|
||||
"""
|
||||
|
||||
##
|
||||
## Class-Level Attributes
|
||||
##
|
||||
|
||||
"""
|
||||
The following attributes are meant
|
||||
to be set at the class level, and serve
|
||||
as parameters to the plot mode registry
|
||||
(in PlotMode). See plot_modes.py for
|
||||
concrete examples.
|
||||
"""
|
||||
|
||||
"""
|
||||
i_vars
|
||||
'x' for Cartesian2D
|
||||
'xy' for Cartesian3D
|
||||
etc.
|
||||
|
||||
d_vars
|
||||
'y' for Cartesian2D
|
||||
'r' for Polar
|
||||
etc.
|
||||
"""
|
||||
i_vars, d_vars = '', ''
|
||||
|
||||
"""
|
||||
intervals
|
||||
Default intervals for each i_var, and in the
|
||||
same order. Specified [min, max, steps].
|
||||
No variable can be given (it is bound later).
|
||||
"""
|
||||
intervals = []
|
||||
|
||||
"""
|
||||
aliases
|
||||
A list of strings which can be used to
|
||||
access this mode.
|
||||
'cartesian' for Cartesian2D and Cartesian3D
|
||||
'polar' for Polar
|
||||
'cylindrical', 'polar' for Cylindrical
|
||||
|
||||
Note that _init_mode chooses the first alias
|
||||
in the list as the mode's primary_alias, which
|
||||
will be displayed to the end user in certain
|
||||
contexts.
|
||||
"""
|
||||
aliases = []
|
||||
|
||||
"""
|
||||
is_default
|
||||
Whether to set this mode as the default
|
||||
for arguments passed to PlotMode() containing
|
||||
the same number of d_vars as this mode and
|
||||
at most the same number of i_vars.
|
||||
"""
|
||||
is_default = False
|
||||
|
||||
"""
|
||||
All of the above attributes are defined in PlotMode.
|
||||
The following ones are specific to PlotModeBase.
|
||||
"""
|
||||
|
||||
"""
|
||||
A list of the render styles. Do not modify.
|
||||
"""
|
||||
styles = {'wireframe': 1, 'solid': 2, 'both': 3}
|
||||
|
||||
"""
|
||||
style_override
|
||||
Always use this style if not blank.
|
||||
"""
|
||||
style_override = ''
|
||||
|
||||
"""
|
||||
default_wireframe_color
|
||||
default_solid_color
|
||||
Can be used when color is None or being calculated.
|
||||
Used by PlotCurve and PlotSurface, but not anywhere
|
||||
in PlotModeBase.
|
||||
"""
|
||||
|
||||
default_wireframe_color = (0.85, 0.85, 0.85)
|
||||
default_solid_color = (0.6, 0.6, 0.9)
|
||||
default_rot_preset = 'xy'
|
||||
|
||||
##
|
||||
## Instance-Level Attributes
|
||||
##
|
||||
|
||||
## 'Abstract' member functions
|
||||
def _get_evaluator(self):
|
||||
if self.use_lambda_eval:
|
||||
try:
|
||||
e = self._get_lambda_evaluator()
|
||||
return e
|
||||
except Exception:
|
||||
warnings.warn("\nWarning: creating lambda evaluator failed. "
|
||||
"Falling back on SymPy subs evaluator.")
|
||||
return self._get_sympy_evaluator()
|
||||
|
||||
def _get_sympy_evaluator(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_lambda_evaluator(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _on_calculate_verts(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _on_calculate_cverts(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
## Base member functions
|
||||
def __init__(self, *args, bounds_callback=None, **kwargs):
|
||||
self.verts = []
|
||||
self.cverts = []
|
||||
self.bounds = [[S.Infinity, S.NegativeInfinity, 0],
|
||||
[S.Infinity, S.NegativeInfinity, 0],
|
||||
[S.Infinity, S.NegativeInfinity, 0]]
|
||||
self.cbounds = [[S.Infinity, S.NegativeInfinity, 0],
|
||||
[S.Infinity, S.NegativeInfinity, 0],
|
||||
[S.Infinity, S.NegativeInfinity, 0]]
|
||||
|
||||
self._draw_lock = RLock()
|
||||
|
||||
self._calculating_verts = Event()
|
||||
self._calculating_cverts = Event()
|
||||
self._calculating_verts_pos = 0.0
|
||||
self._calculating_verts_len = 0.0
|
||||
self._calculating_cverts_pos = 0.0
|
||||
self._calculating_cverts_len = 0.0
|
||||
|
||||
self._max_render_stack_size = 3
|
||||
self._draw_wireframe = [-1]
|
||||
self._draw_solid = [-1]
|
||||
|
||||
self._style = None
|
||||
self._color = None
|
||||
|
||||
self.predraw = []
|
||||
self.postdraw = []
|
||||
|
||||
self.use_lambda_eval = self.options.pop('use_sympy_eval', None) is None
|
||||
self.style = self.options.pop('style', '')
|
||||
self.color = self.options.pop('color', 'rainbow')
|
||||
self.bounds_callback = bounds_callback
|
||||
|
||||
self._on_calculate()
|
||||
|
||||
def synchronized(f):
|
||||
def w(self, *args, **kwargs):
|
||||
self._draw_lock.acquire()
|
||||
try:
|
||||
r = f(self, *args, **kwargs)
|
||||
return r
|
||||
finally:
|
||||
self._draw_lock.release()
|
||||
return w
|
||||
|
||||
@synchronized
|
||||
def push_wireframe(self, function):
|
||||
"""
|
||||
Push a function which performs gl commands
|
||||
used to build a display list. (The list is
|
||||
built outside of the function)
|
||||
"""
|
||||
assert callable(function)
|
||||
self._draw_wireframe.append(function)
|
||||
if len(self._draw_wireframe) > self._max_render_stack_size:
|
||||
del self._draw_wireframe[1] # leave marker element
|
||||
|
||||
@synchronized
|
||||
def push_solid(self, function):
|
||||
"""
|
||||
Push a function which performs gl commands
|
||||
used to build a display list. (The list is
|
||||
built outside of the function)
|
||||
"""
|
||||
assert callable(function)
|
||||
self._draw_solid.append(function)
|
||||
if len(self._draw_solid) > self._max_render_stack_size:
|
||||
del self._draw_solid[1] # leave marker element
|
||||
|
||||
def _create_display_list(self, function):
|
||||
dl = pgl.glGenLists(1)
|
||||
pgl.glNewList(dl, pgl.GL_COMPILE)
|
||||
function()
|
||||
pgl.glEndList()
|
||||
return dl
|
||||
|
||||
def _render_stack_top(self, render_stack):
|
||||
top = render_stack[-1]
|
||||
if top == -1:
|
||||
return -1 # nothing to display
|
||||
elif callable(top):
|
||||
dl = self._create_display_list(top)
|
||||
render_stack[-1] = (dl, top)
|
||||
return dl # display newly added list
|
||||
elif len(top) == 2:
|
||||
if pgl.GL_TRUE == pgl.glIsList(top[0]):
|
||||
return top[0] # display stored list
|
||||
dl = self._create_display_list(top[1])
|
||||
render_stack[-1] = (dl, top[1])
|
||||
return dl # display regenerated list
|
||||
|
||||
def _draw_solid_display_list(self, dl):
|
||||
pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT)
|
||||
pgl.glPolygonMode(pgl.GL_FRONT_AND_BACK, pgl.GL_FILL)
|
||||
pgl.glCallList(dl)
|
||||
pgl.glPopAttrib()
|
||||
|
||||
def _draw_wireframe_display_list(self, dl):
|
||||
pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT)
|
||||
pgl.glPolygonMode(pgl.GL_FRONT_AND_BACK, pgl.GL_LINE)
|
||||
pgl.glEnable(pgl.GL_POLYGON_OFFSET_LINE)
|
||||
pgl.glPolygonOffset(-0.005, -50.0)
|
||||
pgl.glCallList(dl)
|
||||
pgl.glPopAttrib()
|
||||
|
||||
@synchronized
|
||||
def draw(self):
|
||||
for f in self.predraw:
|
||||
if callable(f):
|
||||
f()
|
||||
if self.style_override:
|
||||
style = self.styles[self.style_override]
|
||||
else:
|
||||
style = self.styles[self._style]
|
||||
# Draw solid component if style includes solid
|
||||
if style & 2:
|
||||
dl = self._render_stack_top(self._draw_solid)
|
||||
if dl > 0 and pgl.GL_TRUE == pgl.glIsList(dl):
|
||||
self._draw_solid_display_list(dl)
|
||||
# Draw wireframe component if style includes wireframe
|
||||
if style & 1:
|
||||
dl = self._render_stack_top(self._draw_wireframe)
|
||||
if dl > 0 and pgl.GL_TRUE == pgl.glIsList(dl):
|
||||
self._draw_wireframe_display_list(dl)
|
||||
for f in self.postdraw:
|
||||
if callable(f):
|
||||
f()
|
||||
|
||||
def _on_change_color(self, color):
|
||||
Thread(target=self._calculate_cverts).start()
|
||||
|
||||
def _on_calculate(self):
|
||||
Thread(target=self._calculate_all).start()
|
||||
|
||||
def _calculate_all(self):
|
||||
self._calculate_verts()
|
||||
self._calculate_cverts()
|
||||
|
||||
def _calculate_verts(self):
|
||||
if self._calculating_verts.is_set():
|
||||
return
|
||||
self._calculating_verts.set()
|
||||
try:
|
||||
self._on_calculate_verts()
|
||||
finally:
|
||||
self._calculating_verts.clear()
|
||||
if callable(self.bounds_callback):
|
||||
self.bounds_callback()
|
||||
|
||||
def _calculate_cverts(self):
|
||||
if self._calculating_verts.is_set():
|
||||
return
|
||||
while self._calculating_cverts.is_set():
|
||||
sleep(0) # wait for previous calculation
|
||||
self._calculating_cverts.set()
|
||||
try:
|
||||
self._on_calculate_cverts()
|
||||
finally:
|
||||
self._calculating_cverts.clear()
|
||||
|
||||
def _get_calculating_verts(self):
|
||||
return self._calculating_verts.is_set()
|
||||
|
||||
def _get_calculating_verts_pos(self):
|
||||
return self._calculating_verts_pos
|
||||
|
||||
def _get_calculating_verts_len(self):
|
||||
return self._calculating_verts_len
|
||||
|
||||
def _get_calculating_cverts(self):
|
||||
return self._calculating_cverts.is_set()
|
||||
|
||||
def _get_calculating_cverts_pos(self):
|
||||
return self._calculating_cverts_pos
|
||||
|
||||
def _get_calculating_cverts_len(self):
|
||||
return self._calculating_cverts_len
|
||||
|
||||
## Property handlers
|
||||
def _get_style(self):
|
||||
return self._style
|
||||
|
||||
@synchronized
|
||||
def _set_style(self, v):
|
||||
if v is None:
|
||||
return
|
||||
if v == '':
|
||||
step_max = 0
|
||||
for i in self.intervals:
|
||||
if i.v_steps is None:
|
||||
continue
|
||||
step_max = max([step_max, int(i.v_steps)])
|
||||
v = ['both', 'solid'][step_max > 40]
|
||||
if v not in self.styles:
|
||||
raise ValueError("v should be there in self.styles")
|
||||
if v == self._style:
|
||||
return
|
||||
self._style = v
|
||||
|
||||
def _get_color(self):
|
||||
return self._color
|
||||
|
||||
@synchronized
|
||||
def _set_color(self, v):
|
||||
try:
|
||||
if v is not None:
|
||||
if is_sequence(v):
|
||||
v = ColorScheme(*v)
|
||||
else:
|
||||
v = ColorScheme(v)
|
||||
if repr(v) == repr(self._color):
|
||||
return
|
||||
self._on_change_color(v)
|
||||
self._color = v
|
||||
except Exception as e:
|
||||
raise RuntimeError("Color change failed. "
|
||||
"Reason: %s" % (str(e)))
|
||||
|
||||
style = property(_get_style, _set_style)
|
||||
color = property(_get_color, _set_color)
|
||||
|
||||
calculating_verts = property(_get_calculating_verts)
|
||||
calculating_verts_pos = property(_get_calculating_verts_pos)
|
||||
calculating_verts_len = property(_get_calculating_verts_len)
|
||||
|
||||
calculating_cverts = property(_get_calculating_cverts)
|
||||
calculating_cverts_pos = property(_get_calculating_cverts_pos)
|
||||
calculating_cverts_len = property(_get_calculating_cverts_len)
|
||||
|
||||
## String representations
|
||||
|
||||
def __str__(self):
|
||||
f = ", ".join(str(d) for d in self.d_vars)
|
||||
o = "'mode=%s'" % (self.primary_alias)
|
||||
return ", ".join([f, o])
|
||||
|
||||
def __repr__(self):
|
||||
f = ", ".join(str(d) for d in self.d_vars)
|
||||
i = ", ".join(str(i) for i in self.intervals)
|
||||
d = [('mode', self.primary_alias),
|
||||
('color', str(self.color)),
|
||||
('style', str(self.style))]
|
||||
|
||||
o = "'%s'" % ("; ".join("%s=%s" % (k, v)
|
||||
for k, v in d if v != 'None'))
|
||||
return ", ".join([f, i, o])
|
||||
@@ -0,0 +1,209 @@
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from sympy.core.numbers import pi
|
||||
from sympy.functions import sin, cos
|
||||
from sympy.plotting.pygletplot.plot_curve import PlotCurve
|
||||
from sympy.plotting.pygletplot.plot_surface import PlotSurface
|
||||
|
||||
from math import sin as p_sin
|
||||
from math import cos as p_cos
|
||||
|
||||
|
||||
def float_vec3(f):
|
||||
def inner(*args):
|
||||
v = f(*args)
|
||||
return float(v[0]), float(v[1]), float(v[2])
|
||||
return inner
|
||||
|
||||
|
||||
class Cartesian2D(PlotCurve):
|
||||
i_vars, d_vars = 'x', 'y'
|
||||
intervals = [[-5, 5, 100]]
|
||||
aliases = ['cartesian']
|
||||
is_default = True
|
||||
|
||||
def _get_sympy_evaluator(self):
|
||||
fy = self.d_vars[0]
|
||||
x = self.t_interval.v
|
||||
|
||||
@float_vec3
|
||||
def e(_x):
|
||||
return (_x, fy.subs(x, _x), 0.0)
|
||||
return e
|
||||
|
||||
def _get_lambda_evaluator(self):
|
||||
fy = self.d_vars[0]
|
||||
x = self.t_interval.v
|
||||
return lambdify([x], [x, fy, 0.0])
|
||||
|
||||
|
||||
class Cartesian3D(PlotSurface):
|
||||
i_vars, d_vars = 'xy', 'z'
|
||||
intervals = [[-1, 1, 40], [-1, 1, 40]]
|
||||
aliases = ['cartesian', 'monge']
|
||||
is_default = True
|
||||
|
||||
def _get_sympy_evaluator(self):
|
||||
fz = self.d_vars[0]
|
||||
x = self.u_interval.v
|
||||
y = self.v_interval.v
|
||||
|
||||
@float_vec3
|
||||
def e(_x, _y):
|
||||
return (_x, _y, fz.subs(x, _x).subs(y, _y))
|
||||
return e
|
||||
|
||||
def _get_lambda_evaluator(self):
|
||||
fz = self.d_vars[0]
|
||||
x = self.u_interval.v
|
||||
y = self.v_interval.v
|
||||
return lambdify([x, y], [x, y, fz])
|
||||
|
||||
|
||||
class ParametricCurve2D(PlotCurve):
|
||||
i_vars, d_vars = 't', 'xy'
|
||||
intervals = [[0, 2*pi, 100]]
|
||||
aliases = ['parametric']
|
||||
is_default = True
|
||||
|
||||
def _get_sympy_evaluator(self):
|
||||
fx, fy = self.d_vars
|
||||
t = self.t_interval.v
|
||||
|
||||
@float_vec3
|
||||
def e(_t):
|
||||
return (fx.subs(t, _t), fy.subs(t, _t), 0.0)
|
||||
return e
|
||||
|
||||
def _get_lambda_evaluator(self):
|
||||
fx, fy = self.d_vars
|
||||
t = self.t_interval.v
|
||||
return lambdify([t], [fx, fy, 0.0])
|
||||
|
||||
|
||||
class ParametricCurve3D(PlotCurve):
|
||||
i_vars, d_vars = 't', 'xyz'
|
||||
intervals = [[0, 2*pi, 100]]
|
||||
aliases = ['parametric']
|
||||
is_default = True
|
||||
|
||||
def _get_sympy_evaluator(self):
|
||||
fx, fy, fz = self.d_vars
|
||||
t = self.t_interval.v
|
||||
|
||||
@float_vec3
|
||||
def e(_t):
|
||||
return (fx.subs(t, _t), fy.subs(t, _t), fz.subs(t, _t))
|
||||
return e
|
||||
|
||||
def _get_lambda_evaluator(self):
|
||||
fx, fy, fz = self.d_vars
|
||||
t = self.t_interval.v
|
||||
return lambdify([t], [fx, fy, fz])
|
||||
|
||||
|
||||
class ParametricSurface(PlotSurface):
|
||||
i_vars, d_vars = 'uv', 'xyz'
|
||||
intervals = [[-1, 1, 40], [-1, 1, 40]]
|
||||
aliases = ['parametric']
|
||||
is_default = True
|
||||
|
||||
def _get_sympy_evaluator(self):
|
||||
fx, fy, fz = self.d_vars
|
||||
u = self.u_interval.v
|
||||
v = self.v_interval.v
|
||||
|
||||
@float_vec3
|
||||
def e(_u, _v):
|
||||
return (fx.subs(u, _u).subs(v, _v),
|
||||
fy.subs(u, _u).subs(v, _v),
|
||||
fz.subs(u, _u).subs(v, _v))
|
||||
return e
|
||||
|
||||
def _get_lambda_evaluator(self):
|
||||
fx, fy, fz = self.d_vars
|
||||
u = self.u_interval.v
|
||||
v = self.v_interval.v
|
||||
return lambdify([u, v], [fx, fy, fz])
|
||||
|
||||
|
||||
class Polar(PlotCurve):
|
||||
i_vars, d_vars = 't', 'r'
|
||||
intervals = [[0, 2*pi, 100]]
|
||||
aliases = ['polar']
|
||||
is_default = False
|
||||
|
||||
def _get_sympy_evaluator(self):
|
||||
fr = self.d_vars[0]
|
||||
t = self.t_interval.v
|
||||
|
||||
def e(_t):
|
||||
_r = float(fr.subs(t, _t))
|
||||
return (_r*p_cos(_t), _r*p_sin(_t), 0.0)
|
||||
return e
|
||||
|
||||
def _get_lambda_evaluator(self):
|
||||
fr = self.d_vars[0]
|
||||
t = self.t_interval.v
|
||||
fx, fy = fr*cos(t), fr*sin(t)
|
||||
return lambdify([t], [fx, fy, 0.0])
|
||||
|
||||
|
||||
class Cylindrical(PlotSurface):
|
||||
i_vars, d_vars = 'th', 'r'
|
||||
intervals = [[0, 2*pi, 40], [-1, 1, 20]]
|
||||
aliases = ['cylindrical', 'polar']
|
||||
is_default = False
|
||||
|
||||
def _get_sympy_evaluator(self):
|
||||
fr = self.d_vars[0]
|
||||
t = self.u_interval.v
|
||||
h = self.v_interval.v
|
||||
|
||||
def e(_t, _h):
|
||||
_r = float(fr.subs(t, _t).subs(h, _h))
|
||||
return (_r*p_cos(_t), _r*p_sin(_t), _h)
|
||||
return e
|
||||
|
||||
def _get_lambda_evaluator(self):
|
||||
fr = self.d_vars[0]
|
||||
t = self.u_interval.v
|
||||
h = self.v_interval.v
|
||||
fx, fy = fr*cos(t), fr*sin(t)
|
||||
return lambdify([t, h], [fx, fy, h])
|
||||
|
||||
|
||||
class Spherical(PlotSurface):
|
||||
i_vars, d_vars = 'tp', 'r'
|
||||
intervals = [[0, 2*pi, 40], [0, pi, 20]]
|
||||
aliases = ['spherical']
|
||||
is_default = False
|
||||
|
||||
def _get_sympy_evaluator(self):
|
||||
fr = self.d_vars[0]
|
||||
t = self.u_interval.v
|
||||
p = self.v_interval.v
|
||||
|
||||
def e(_t, _p):
|
||||
_r = float(fr.subs(t, _t).subs(p, _p))
|
||||
return (_r*p_cos(_t)*p_sin(_p),
|
||||
_r*p_sin(_t)*p_sin(_p),
|
||||
_r*p_cos(_p))
|
||||
return e
|
||||
|
||||
def _get_lambda_evaluator(self):
|
||||
fr = self.d_vars[0]
|
||||
t = self.u_interval.v
|
||||
p = self.v_interval.v
|
||||
fx = fr * cos(t) * sin(p)
|
||||
fy = fr * sin(t) * sin(p)
|
||||
fz = fr * cos(p)
|
||||
return lambdify([t, p], [fx, fy, fz])
|
||||
|
||||
Cartesian2D._register()
|
||||
Cartesian3D._register()
|
||||
ParametricCurve2D._register()
|
||||
ParametricCurve3D._register()
|
||||
ParametricSurface._register()
|
||||
Polar._register()
|
||||
Cylindrical._register()
|
||||
Spherical._register()
|
||||
@@ -0,0 +1,17 @@
|
||||
class PlotObject:
|
||||
"""
|
||||
Base class for objects which can be displayed in
|
||||
a Plot.
|
||||
"""
|
||||
visible = True
|
||||
|
||||
def _draw(self):
|
||||
if self.visible:
|
||||
self.draw()
|
||||
|
||||
def draw(self):
|
||||
"""
|
||||
OpenGL rendering code for the plot object.
|
||||
Override in base class.
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,68 @@
|
||||
try:
|
||||
from ctypes import c_float
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import pyglet.gl as pgl
|
||||
from math import sqrt as _sqrt, acos as _acos, pi
|
||||
|
||||
|
||||
def cross(a, b):
|
||||
return (a[1] * b[2] - a[2] * b[1],
|
||||
a[2] * b[0] - a[0] * b[2],
|
||||
a[0] * b[1] - a[1] * b[0])
|
||||
|
||||
|
||||
def dot(a, b):
|
||||
return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
|
||||
|
||||
|
||||
def mag(a):
|
||||
return _sqrt(a[0]**2 + a[1]**2 + a[2]**2)
|
||||
|
||||
|
||||
def norm(a):
|
||||
m = mag(a)
|
||||
return (a[0] / m, a[1] / m, a[2] / m)
|
||||
|
||||
|
||||
def get_sphere_mapping(x, y, width, height):
|
||||
x = min([max([x, 0]), width])
|
||||
y = min([max([y, 0]), height])
|
||||
|
||||
sr = _sqrt((width/2)**2 + (height/2)**2)
|
||||
sx = ((x - width / 2) / sr)
|
||||
sy = ((y - height / 2) / sr)
|
||||
|
||||
sz = 1.0 - sx**2 - sy**2
|
||||
|
||||
if sz > 0.0:
|
||||
sz = _sqrt(sz)
|
||||
return (sx, sy, sz)
|
||||
else:
|
||||
sz = 0
|
||||
return norm((sx, sy, sz))
|
||||
|
||||
rad2deg = 180.0 / pi
|
||||
|
||||
|
||||
def get_spherical_rotatation(p1, p2, width, height, theta_multiplier):
|
||||
v1 = get_sphere_mapping(p1[0], p1[1], width, height)
|
||||
v2 = get_sphere_mapping(p2[0], p2[1], width, height)
|
||||
|
||||
d = min(max([dot(v1, v2), -1]), 1)
|
||||
|
||||
if abs(d - 1.0) < 0.000001:
|
||||
return None
|
||||
|
||||
raxis = norm( cross(v1, v2) )
|
||||
rtheta = theta_multiplier * rad2deg * _acos(d)
|
||||
|
||||
pgl.glPushMatrix()
|
||||
pgl.glLoadIdentity()
|
||||
pgl.glRotatef(rtheta, *raxis)
|
||||
mat = (c_float*16)()
|
||||
pgl.glGetFloatv(pgl.GL_MODELVIEW_MATRIX, mat)
|
||||
pgl.glPopMatrix()
|
||||
|
||||
return mat
|
||||
@@ -0,0 +1,102 @@
|
||||
import pyglet.gl as pgl
|
||||
|
||||
from sympy.core import S
|
||||
from sympy.plotting.pygletplot.plot_mode_base import PlotModeBase
|
||||
|
||||
|
||||
class PlotSurface(PlotModeBase):
|
||||
|
||||
default_rot_preset = 'perspective'
|
||||
|
||||
def _on_calculate_verts(self):
|
||||
self.u_interval = self.intervals[0]
|
||||
self.u_set = list(self.u_interval.frange())
|
||||
self.v_interval = self.intervals[1]
|
||||
self.v_set = list(self.v_interval.frange())
|
||||
self.bounds = [[S.Infinity, S.NegativeInfinity, 0],
|
||||
[S.Infinity, S.NegativeInfinity, 0],
|
||||
[S.Infinity, S.NegativeInfinity, 0]]
|
||||
evaluate = self._get_evaluator()
|
||||
|
||||
self._calculating_verts_pos = 0.0
|
||||
self._calculating_verts_len = float(
|
||||
self.u_interval.v_len*self.v_interval.v_len)
|
||||
|
||||
verts = []
|
||||
b = self.bounds
|
||||
for u in self.u_set:
|
||||
column = []
|
||||
for v in self.v_set:
|
||||
try:
|
||||
_e = evaluate(u, v) # calculate vertex
|
||||
except ZeroDivisionError:
|
||||
_e = None
|
||||
if _e is not None: # update bounding box
|
||||
for axis in range(3):
|
||||
b[axis][0] = min([b[axis][0], _e[axis]])
|
||||
b[axis][1] = max([b[axis][1], _e[axis]])
|
||||
column.append(_e)
|
||||
self._calculating_verts_pos += 1.0
|
||||
|
||||
verts.append(column)
|
||||
for axis in range(3):
|
||||
b[axis][2] = b[axis][1] - b[axis][0]
|
||||
if b[axis][2] == 0.0:
|
||||
b[axis][2] = 1.0
|
||||
|
||||
self.verts = verts
|
||||
self.push_wireframe(self.draw_verts(False, False))
|
||||
self.push_solid(self.draw_verts(False, True))
|
||||
|
||||
def _on_calculate_cverts(self):
|
||||
if not self.verts or not self.color:
|
||||
return
|
||||
|
||||
def set_work_len(n):
|
||||
self._calculating_cverts_len = float(n)
|
||||
|
||||
def inc_work_pos():
|
||||
self._calculating_cverts_pos += 1.0
|
||||
set_work_len(1)
|
||||
self._calculating_cverts_pos = 0
|
||||
self.cverts = self.color.apply_to_surface(self.verts,
|
||||
self.u_set,
|
||||
self.v_set,
|
||||
set_len=set_work_len,
|
||||
inc_pos=inc_work_pos)
|
||||
self.push_solid(self.draw_verts(True, True))
|
||||
|
||||
def calculate_one_cvert(self, u, v):
|
||||
vert = self.verts[u][v]
|
||||
return self.color(vert[0], vert[1], vert[2],
|
||||
self.u_set[u], self.v_set[v])
|
||||
|
||||
def draw_verts(self, use_cverts, use_solid_color):
|
||||
def f():
|
||||
for u in range(1, len(self.u_set)):
|
||||
pgl.glBegin(pgl.GL_QUAD_STRIP)
|
||||
for v in range(len(self.v_set)):
|
||||
pa = self.verts[u - 1][v]
|
||||
pb = self.verts[u][v]
|
||||
if pa is None or pb is None:
|
||||
pgl.glEnd()
|
||||
pgl.glBegin(pgl.GL_QUAD_STRIP)
|
||||
continue
|
||||
if use_cverts:
|
||||
ca = self.cverts[u - 1][v]
|
||||
cb = self.cverts[u][v]
|
||||
if ca is None:
|
||||
ca = (0, 0, 0)
|
||||
if cb is None:
|
||||
cb = (0, 0, 0)
|
||||
else:
|
||||
if use_solid_color:
|
||||
ca = cb = self.default_solid_color
|
||||
else:
|
||||
ca = cb = self.default_wireframe_color
|
||||
pgl.glColor3f(*ca)
|
||||
pgl.glVertex3f(*pa)
|
||||
pgl.glColor3f(*cb)
|
||||
pgl.glVertex3f(*pb)
|
||||
pgl.glEnd()
|
||||
return f
|
||||
@@ -0,0 +1,144 @@
|
||||
from time import perf_counter
|
||||
|
||||
|
||||
import pyglet.gl as pgl
|
||||
|
||||
from sympy.plotting.pygletplot.managed_window import ManagedWindow
|
||||
from sympy.plotting.pygletplot.plot_camera import PlotCamera
|
||||
from sympy.plotting.pygletplot.plot_controller import PlotController
|
||||
|
||||
|
||||
class PlotWindow(ManagedWindow):
|
||||
|
||||
def __init__(self, plot, antialiasing=True, ortho=False,
|
||||
invert_mouse_zoom=False, linewidth=1.5, caption="SymPy Plot",
|
||||
**kwargs):
|
||||
"""
|
||||
Named Arguments
|
||||
===============
|
||||
|
||||
antialiasing = True
|
||||
True OR False
|
||||
ortho = False
|
||||
True OR False
|
||||
invert_mouse_zoom = False
|
||||
True OR False
|
||||
"""
|
||||
self.plot = plot
|
||||
|
||||
self.camera = None
|
||||
self._calculating = False
|
||||
|
||||
self.antialiasing = antialiasing
|
||||
self.ortho = ortho
|
||||
self.invert_mouse_zoom = invert_mouse_zoom
|
||||
self.linewidth = linewidth
|
||||
self.title = caption
|
||||
self.last_caption_update = 0
|
||||
self.caption_update_interval = 0.2
|
||||
self.drawing_first_object = True
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def setup(self):
|
||||
self.camera = PlotCamera(self, ortho=self.ortho)
|
||||
self.controller = PlotController(self,
|
||||
invert_mouse_zoom=self.invert_mouse_zoom)
|
||||
self.push_handlers(self.controller)
|
||||
|
||||
pgl.glClearColor(1.0, 1.0, 1.0, 0.0)
|
||||
pgl.glClearDepth(1.0)
|
||||
|
||||
pgl.glDepthFunc(pgl.GL_LESS)
|
||||
pgl.glEnable(pgl.GL_DEPTH_TEST)
|
||||
|
||||
pgl.glEnable(pgl.GL_LINE_SMOOTH)
|
||||
pgl.glShadeModel(pgl.GL_SMOOTH)
|
||||
pgl.glLineWidth(self.linewidth)
|
||||
|
||||
pgl.glEnable(pgl.GL_BLEND)
|
||||
pgl.glBlendFunc(pgl.GL_SRC_ALPHA, pgl.GL_ONE_MINUS_SRC_ALPHA)
|
||||
|
||||
if self.antialiasing:
|
||||
pgl.glHint(pgl.GL_LINE_SMOOTH_HINT, pgl.GL_NICEST)
|
||||
pgl.glHint(pgl.GL_POLYGON_SMOOTH_HINT, pgl.GL_NICEST)
|
||||
|
||||
self.camera.setup_projection()
|
||||
|
||||
def on_resize(self, w, h):
|
||||
super().on_resize(w, h)
|
||||
if self.camera is not None:
|
||||
self.camera.setup_projection()
|
||||
|
||||
def update(self, dt):
|
||||
self.controller.update(dt)
|
||||
|
||||
def draw(self):
|
||||
self.plot._render_lock.acquire()
|
||||
self.camera.apply_transformation()
|
||||
|
||||
calc_verts_pos, calc_verts_len = 0, 0
|
||||
calc_cverts_pos, calc_cverts_len = 0, 0
|
||||
|
||||
should_update_caption = (perf_counter() - self.last_caption_update >
|
||||
self.caption_update_interval)
|
||||
|
||||
if len(self.plot._functions.values()) == 0:
|
||||
self.drawing_first_object = True
|
||||
|
||||
iterfunctions = iter(self.plot._functions.values())
|
||||
|
||||
for r in iterfunctions:
|
||||
if self.drawing_first_object:
|
||||
self.camera.set_rot_preset(r.default_rot_preset)
|
||||
self.drawing_first_object = False
|
||||
|
||||
pgl.glPushMatrix()
|
||||
r._draw()
|
||||
pgl.glPopMatrix()
|
||||
|
||||
# might as well do this while we are
|
||||
# iterating and have the lock rather
|
||||
# than locking and iterating twice
|
||||
# per frame:
|
||||
|
||||
if should_update_caption:
|
||||
try:
|
||||
if r.calculating_verts:
|
||||
calc_verts_pos += r.calculating_verts_pos
|
||||
calc_verts_len += r.calculating_verts_len
|
||||
if r.calculating_cverts:
|
||||
calc_cverts_pos += r.calculating_cverts_pos
|
||||
calc_cverts_len += r.calculating_cverts_len
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
for r in self.plot._pobjects:
|
||||
pgl.glPushMatrix()
|
||||
r._draw()
|
||||
pgl.glPopMatrix()
|
||||
|
||||
if should_update_caption:
|
||||
self.update_caption(calc_verts_pos, calc_verts_len,
|
||||
calc_cverts_pos, calc_cverts_len)
|
||||
self.last_caption_update = perf_counter()
|
||||
|
||||
if self.plot._screenshot:
|
||||
self.plot._screenshot._execute_saving()
|
||||
|
||||
self.plot._render_lock.release()
|
||||
|
||||
def update_caption(self, calc_verts_pos, calc_verts_len,
|
||||
calc_cverts_pos, calc_cverts_len):
|
||||
caption = self.title
|
||||
if calc_verts_len or calc_cverts_len:
|
||||
caption += " (calculating"
|
||||
if calc_verts_len > 0:
|
||||
p = (calc_verts_pos / calc_verts_len) * 100
|
||||
caption += " vertices %i%%" % (p)
|
||||
if calc_cverts_len > 0:
|
||||
p = (calc_cverts_pos / calc_cverts_len) * 100
|
||||
caption += " colors %i%%" % (p)
|
||||
caption += ")"
|
||||
if self.caption != caption:
|
||||
self.set_caption(caption)
|
||||
@@ -0,0 +1,88 @@
|
||||
from sympy.external.importtools import import_module
|
||||
|
||||
disabled = False
|
||||
|
||||
# if pyglet.gl fails to import, e.g. opengl is missing, we disable the tests
|
||||
pyglet_gl = import_module("pyglet.gl", catch=(OSError,))
|
||||
pyglet_window = import_module("pyglet.window", catch=(OSError,))
|
||||
if not pyglet_gl or not pyglet_window:
|
||||
disabled = True
|
||||
|
||||
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import log
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin)
|
||||
x, y, z = symbols('x, y, z')
|
||||
|
||||
|
||||
def test_plot_2d():
|
||||
from sympy.plotting.pygletplot import PygletPlot
|
||||
p = PygletPlot(x, [x, -5, 5, 4], visible=False)
|
||||
p.wait_for_calculations()
|
||||
|
||||
|
||||
def test_plot_2d_discontinuous():
|
||||
from sympy.plotting.pygletplot import PygletPlot
|
||||
p = PygletPlot(1/x, [x, -1, 1, 2], visible=False)
|
||||
p.wait_for_calculations()
|
||||
|
||||
|
||||
def test_plot_3d():
|
||||
from sympy.plotting.pygletplot import PygletPlot
|
||||
p = PygletPlot(x*y, [x, -5, 5, 5], [y, -5, 5, 5], visible=False)
|
||||
p.wait_for_calculations()
|
||||
|
||||
|
||||
def test_plot_3d_discontinuous():
|
||||
from sympy.plotting.pygletplot import PygletPlot
|
||||
p = PygletPlot(1/x, [x, -3, 3, 6], [y, -1, 1, 1], visible=False)
|
||||
p.wait_for_calculations()
|
||||
|
||||
|
||||
def test_plot_2d_polar():
|
||||
from sympy.plotting.pygletplot import PygletPlot
|
||||
p = PygletPlot(1/x, [x, -1, 1, 4], 'mode=polar', visible=False)
|
||||
p.wait_for_calculations()
|
||||
|
||||
|
||||
def test_plot_3d_cylinder():
|
||||
from sympy.plotting.pygletplot import PygletPlot
|
||||
p = PygletPlot(
|
||||
1/y, [x, 0, 6.282, 4], [y, -1, 1, 4], 'mode=polar;style=solid',
|
||||
visible=False)
|
||||
p.wait_for_calculations()
|
||||
|
||||
|
||||
def test_plot_3d_spherical():
|
||||
from sympy.plotting.pygletplot import PygletPlot
|
||||
p = PygletPlot(
|
||||
1, [x, 0, 6.282, 4], [y, 0, 3.141,
|
||||
4], 'mode=spherical;style=wireframe',
|
||||
visible=False)
|
||||
p.wait_for_calculations()
|
||||
|
||||
|
||||
def test_plot_2d_parametric():
|
||||
from sympy.plotting.pygletplot import PygletPlot
|
||||
p = PygletPlot(sin(x), cos(x), [x, 0, 6.282, 4], visible=False)
|
||||
p.wait_for_calculations()
|
||||
|
||||
|
||||
def test_plot_3d_parametric():
|
||||
from sympy.plotting.pygletplot import PygletPlot
|
||||
p = PygletPlot(sin(x), cos(x), x/5.0, [x, 0, 6.282, 4], visible=False)
|
||||
p.wait_for_calculations()
|
||||
|
||||
|
||||
def _test_plot_log():
|
||||
from sympy.plotting.pygletplot import PygletPlot
|
||||
p = PygletPlot(log(x), [x, 0, 6.282, 4], 'mode=polar', visible=False)
|
||||
p.wait_for_calculations()
|
||||
|
||||
|
||||
def test_plot_integral():
|
||||
# Make sure it doesn't treat x as an independent variable
|
||||
from sympy.plotting.pygletplot import PygletPlot
|
||||
from sympy.integrals.integrals import Integral
|
||||
p = PygletPlot(Integral(z*x, (x, 1, z), (z, 1, y)), visible=False)
|
||||
p.wait_for_calculations()
|
||||
@@ -0,0 +1,188 @@
|
||||
try:
|
||||
from ctypes import c_float, c_int, c_double
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import pyglet.gl as pgl
|
||||
from sympy.core import S
|
||||
|
||||
|
||||
def get_model_matrix(array_type=c_float, glGetMethod=pgl.glGetFloatv):
|
||||
"""
|
||||
Returns the current modelview matrix.
|
||||
"""
|
||||
m = (array_type*16)()
|
||||
glGetMethod(pgl.GL_MODELVIEW_MATRIX, m)
|
||||
return m
|
||||
|
||||
|
||||
def get_projection_matrix(array_type=c_float, glGetMethod=pgl.glGetFloatv):
|
||||
"""
|
||||
Returns the current modelview matrix.
|
||||
"""
|
||||
m = (array_type*16)()
|
||||
glGetMethod(pgl.GL_PROJECTION_MATRIX, m)
|
||||
return m
|
||||
|
||||
|
||||
def get_viewport():
|
||||
"""
|
||||
Returns the current viewport.
|
||||
"""
|
||||
m = (c_int*4)()
|
||||
pgl.glGetIntegerv(pgl.GL_VIEWPORT, m)
|
||||
return m
|
||||
|
||||
|
||||
def get_direction_vectors():
|
||||
m = get_model_matrix()
|
||||
return ((m[0], m[4], m[8]),
|
||||
(m[1], m[5], m[9]),
|
||||
(m[2], m[6], m[10]))
|
||||
|
||||
|
||||
def get_view_direction_vectors():
|
||||
m = get_model_matrix()
|
||||
return ((m[0], m[1], m[2]),
|
||||
(m[4], m[5], m[6]),
|
||||
(m[8], m[9], m[10]))
|
||||
|
||||
|
||||
def get_basis_vectors():
|
||||
return ((1, 0, 0), (0, 1, 0), (0, 0, 1))
|
||||
|
||||
|
||||
def screen_to_model(x, y, z):
|
||||
m = get_model_matrix(c_double, pgl.glGetDoublev)
|
||||
p = get_projection_matrix(c_double, pgl.glGetDoublev)
|
||||
w = get_viewport()
|
||||
mx, my, mz = c_double(), c_double(), c_double()
|
||||
pgl.gluUnProject(x, y, z, m, p, w, mx, my, mz)
|
||||
return float(mx.value), float(my.value), float(mz.value)
|
||||
|
||||
|
||||
def model_to_screen(x, y, z):
|
||||
m = get_model_matrix(c_double, pgl.glGetDoublev)
|
||||
p = get_projection_matrix(c_double, pgl.glGetDoublev)
|
||||
w = get_viewport()
|
||||
mx, my, mz = c_double(), c_double(), c_double()
|
||||
pgl.gluProject(x, y, z, m, p, w, mx, my, mz)
|
||||
return float(mx.value), float(my.value), float(mz.value)
|
||||
|
||||
|
||||
def vec_subs(a, b):
|
||||
return tuple(a[i] - b[i] for i in range(len(a)))
|
||||
|
||||
|
||||
def billboard_matrix():
|
||||
"""
|
||||
Removes rotational components of
|
||||
current matrix so that primitives
|
||||
are always drawn facing the viewer.
|
||||
|
||||
|1|0|0|x|
|
||||
|0|1|0|x|
|
||||
|0|0|1|x| (x means left unchanged)
|
||||
|x|x|x|x|
|
||||
"""
|
||||
m = get_model_matrix()
|
||||
# XXX: for i in range(11): m[i] = i ?
|
||||
m[0] = 1
|
||||
m[1] = 0
|
||||
m[2] = 0
|
||||
m[4] = 0
|
||||
m[5] = 1
|
||||
m[6] = 0
|
||||
m[8] = 0
|
||||
m[9] = 0
|
||||
m[10] = 1
|
||||
pgl.glLoadMatrixf(m)
|
||||
|
||||
|
||||
def create_bounds():
|
||||
return [[S.Infinity, S.NegativeInfinity, 0],
|
||||
[S.Infinity, S.NegativeInfinity, 0],
|
||||
[S.Infinity, S.NegativeInfinity, 0]]
|
||||
|
||||
|
||||
def update_bounds(b, v):
|
||||
if v is None:
|
||||
return
|
||||
for axis in range(3):
|
||||
b[axis][0] = min([b[axis][0], v[axis]])
|
||||
b[axis][1] = max([b[axis][1], v[axis]])
|
||||
|
||||
|
||||
def interpolate(a_min, a_max, a_ratio):
|
||||
return a_min + a_ratio * (a_max - a_min)
|
||||
|
||||
|
||||
def rinterpolate(a_min, a_max, a_value):
|
||||
a_range = a_max - a_min
|
||||
if a_max == a_min:
|
||||
a_range = 1.0
|
||||
return (a_value - a_min) / float(a_range)
|
||||
|
||||
|
||||
def interpolate_color(color1, color2, ratio):
|
||||
return tuple(interpolate(color1[i], color2[i], ratio) for i in range(3))
|
||||
|
||||
|
||||
def scale_value(v, v_min, v_len):
|
||||
return (v - v_min) / v_len
|
||||
|
||||
|
||||
def scale_value_list(flist):
|
||||
v_min, v_max = min(flist), max(flist)
|
||||
v_len = v_max - v_min
|
||||
return [scale_value(f, v_min, v_len) for f in flist]
|
||||
|
||||
|
||||
def strided_range(r_min, r_max, stride, max_steps=50):
|
||||
o_min, o_max = r_min, r_max
|
||||
if abs(r_min - r_max) < 0.001:
|
||||
return []
|
||||
try:
|
||||
range(int(r_min - r_max))
|
||||
except (TypeError, OverflowError):
|
||||
return []
|
||||
if r_min > r_max:
|
||||
raise ValueError("r_min cannot be greater than r_max")
|
||||
r_min_s = (r_min % stride)
|
||||
r_max_s = stride - (r_max % stride)
|
||||
if abs(r_max_s - stride) < 0.001:
|
||||
r_max_s = 0.0
|
||||
r_min -= r_min_s
|
||||
r_max += r_max_s
|
||||
r_steps = int((r_max - r_min)/stride)
|
||||
if max_steps and r_steps > max_steps:
|
||||
return strided_range(o_min, o_max, stride*2)
|
||||
return [r_min] + [r_min + e*stride for e in range(1, r_steps + 1)] + [r_max]
|
||||
|
||||
|
||||
def parse_option_string(s):
|
||||
if not isinstance(s, str):
|
||||
return None
|
||||
options = {}
|
||||
for token in s.split(';'):
|
||||
pieces = token.split('=')
|
||||
if len(pieces) == 1:
|
||||
option, value = pieces[0], ""
|
||||
elif len(pieces) == 2:
|
||||
option, value = pieces
|
||||
else:
|
||||
raise ValueError("Plot option string '%s' is malformed." % (s))
|
||||
options[option.strip()] = value.strip()
|
||||
return options
|
||||
|
||||
|
||||
def dot_product(v1, v2):
|
||||
return sum(v1[i]*v2[i] for i in range(3))
|
||||
|
||||
|
||||
def vec_sub(v1, v2):
|
||||
return tuple(v1[i] - v2[i] for i in range(3))
|
||||
|
||||
|
||||
def vec_mag(v):
|
||||
return sum(v[i]**2 for i in range(3))**(0.5)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,77 @@
|
||||
from sympy.core.symbol import symbols, Symbol
|
||||
from sympy.functions import Max
|
||||
from sympy.plotting.experimental_lambdify import experimental_lambdify
|
||||
from sympy.plotting.intervalmath.interval_arithmetic import \
|
||||
interval, intervalMembership
|
||||
|
||||
|
||||
# Tests for exception handling in experimental_lambdify
|
||||
def test_experimental_lambify():
|
||||
x = Symbol('x')
|
||||
f = experimental_lambdify([x], Max(x, 5))
|
||||
# XXX should f be tested? If f(2) is attempted, an
|
||||
# error is raised because a complex produced during wrapping of the arg
|
||||
# is being compared with an int.
|
||||
assert Max(2, 5) == 5
|
||||
assert Max(5, 7) == 7
|
||||
|
||||
x = Symbol('x-3')
|
||||
f = experimental_lambdify([x], x + 1)
|
||||
assert f(1) == 2
|
||||
|
||||
|
||||
def test_composite_boolean_region():
|
||||
x, y = symbols('x y')
|
||||
|
||||
r1 = (x - 1)**2 + y**2 < 2
|
||||
r2 = (x + 1)**2 + y**2 < 2
|
||||
|
||||
f = experimental_lambdify((x, y), r1 & r2)
|
||||
a = (interval(-0.1, 0.1), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(True, True)
|
||||
a = (interval(-1.1, -0.9), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
a = (interval(0.9, 1.1), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
a = (interval(-0.1, 0.1), interval(1.9, 2.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
|
||||
f = experimental_lambdify((x, y), r1 | r2)
|
||||
a = (interval(-0.1, 0.1), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(True, True)
|
||||
a = (interval(-1.1, -0.9), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(True, True)
|
||||
a = (interval(0.9, 1.1), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(True, True)
|
||||
a = (interval(-0.1, 0.1), interval(1.9, 2.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
|
||||
f = experimental_lambdify((x, y), r1 & ~r2)
|
||||
a = (interval(-0.1, 0.1), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
a = (interval(-1.1, -0.9), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
a = (interval(0.9, 1.1), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(True, True)
|
||||
a = (interval(-0.1, 0.1), interval(1.9, 2.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
|
||||
f = experimental_lambdify((x, y), ~r1 & r2)
|
||||
a = (interval(-0.1, 0.1), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
a = (interval(-1.1, -0.9), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(True, True)
|
||||
a = (interval(0.9, 1.1), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
a = (interval(-0.1, 0.1), interval(1.9, 2.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
|
||||
f = experimental_lambdify((x, y), ~r1 & ~r2)
|
||||
a = (interval(-0.1, 0.1), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
a = (interval(-1.1, -0.9), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
a = (interval(0.9, 1.1), interval(-0.1, 0.1))
|
||||
assert f(*a) == intervalMembership(False, True)
|
||||
a = (interval(-0.1, 0.1), interval(1.9, 2.1))
|
||||
assert f(*a) == intervalMembership(True, True)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,146 @@
|
||||
from sympy.core.numbers import (I, pi)
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.elementary.complexes import re
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin, tan)
|
||||
from sympy.logic.boolalg import (And, Or)
|
||||
from sympy.plotting.plot_implicit import plot_implicit
|
||||
from sympy.plotting.plot import unset_show
|
||||
from tempfile import NamedTemporaryFile, mkdtemp
|
||||
from sympy.testing.pytest import skip, warns, XFAIL
|
||||
from sympy.external import import_module
|
||||
from sympy.testing.tmpfiles import TmpFileManager
|
||||
|
||||
import os
|
||||
|
||||
#Set plots not to show
|
||||
unset_show()
|
||||
|
||||
def tmp_file(dir=None, name=''):
|
||||
return NamedTemporaryFile(
|
||||
suffix='.png', dir=dir, delete=False).name
|
||||
|
||||
def plot_and_save(expr, *args, name='', dir=None, **kwargs):
|
||||
p = plot_implicit(expr, *args, **kwargs)
|
||||
p.save(tmp_file(dir=dir, name=name))
|
||||
# Close the plot to avoid a warning from matplotlib
|
||||
p._backend.close()
|
||||
|
||||
def plot_implicit_tests(name):
|
||||
temp_dir = mkdtemp()
|
||||
TmpFileManager.tmp_folder(temp_dir)
|
||||
x = Symbol('x')
|
||||
y = Symbol('y')
|
||||
#implicit plot tests
|
||||
plot_and_save(Eq(y, cos(x)), (x, -5, 5), (y, -2, 2), name=name, dir=temp_dir)
|
||||
plot_and_save(Eq(y**2, x**3 - x), (x, -5, 5),
|
||||
(y, -4, 4), name=name, dir=temp_dir)
|
||||
plot_and_save(y > 1 / x, (x, -5, 5),
|
||||
(y, -2, 2), name=name, dir=temp_dir)
|
||||
plot_and_save(y < 1 / tan(x), (x, -5, 5),
|
||||
(y, -2, 2), name=name, dir=temp_dir)
|
||||
plot_and_save(y >= 2 * sin(x) * cos(x), (x, -5, 5),
|
||||
(y, -2, 2), name=name, dir=temp_dir)
|
||||
plot_and_save(y <= x**2, (x, -3, 3),
|
||||
(y, -1, 5), name=name, dir=temp_dir)
|
||||
|
||||
#Test all input args for plot_implicit
|
||||
plot_and_save(Eq(y**2, x**3 - x), dir=temp_dir)
|
||||
plot_and_save(Eq(y**2, x**3 - x), adaptive=False, dir=temp_dir)
|
||||
plot_and_save(Eq(y**2, x**3 - x), adaptive=False, n=500, dir=temp_dir)
|
||||
plot_and_save(y > x, (x, -5, 5), dir=temp_dir)
|
||||
plot_and_save(And(y > exp(x), y > x + 2), dir=temp_dir)
|
||||
plot_and_save(Or(y > x, y > -x), dir=temp_dir)
|
||||
plot_and_save(x**2 - 1, (x, -5, 5), dir=temp_dir)
|
||||
plot_and_save(x**2 - 1, dir=temp_dir)
|
||||
plot_and_save(y > x, depth=-5, dir=temp_dir)
|
||||
plot_and_save(y > x, depth=5, dir=temp_dir)
|
||||
plot_and_save(y > cos(x), adaptive=False, dir=temp_dir)
|
||||
plot_and_save(y < cos(x), adaptive=False, dir=temp_dir)
|
||||
plot_and_save(And(y > cos(x), Or(y > x, Eq(y, x))), dir=temp_dir)
|
||||
plot_and_save(y - cos(pi / x), dir=temp_dir)
|
||||
|
||||
plot_and_save(x**2 - 1, title='An implicit plot', dir=temp_dir)
|
||||
|
||||
@XFAIL
|
||||
def test_no_adaptive_meshing():
|
||||
matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,))
|
||||
if matplotlib:
|
||||
try:
|
||||
temp_dir = mkdtemp()
|
||||
TmpFileManager.tmp_folder(temp_dir)
|
||||
x = Symbol('x')
|
||||
y = Symbol('y')
|
||||
# Test plots which cannot be rendered using the adaptive algorithm
|
||||
|
||||
# This works, but it triggers a deprecation warning from sympify(). The
|
||||
# code needs to be updated to detect if interval math is supported without
|
||||
# relying on random AttributeErrors.
|
||||
with warns(UserWarning, match="Adaptive meshing could not be applied"):
|
||||
plot_and_save(Eq(y, re(cos(x) + I*sin(x))), name='test', dir=temp_dir)
|
||||
finally:
|
||||
TmpFileManager.cleanup()
|
||||
else:
|
||||
skip("Matplotlib not the default backend")
|
||||
def test_line_color():
|
||||
x, y = symbols('x, y')
|
||||
p = plot_implicit(x**2 + y**2 - 1, line_color="green", show=False)
|
||||
assert p._series[0].line_color == "green"
|
||||
p = plot_implicit(x**2 + y**2 - 1, line_color='r', show=False)
|
||||
assert p._series[0].line_color == "r"
|
||||
|
||||
def test_matplotlib():
|
||||
matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,))
|
||||
if matplotlib:
|
||||
try:
|
||||
plot_implicit_tests('test')
|
||||
test_line_color()
|
||||
finally:
|
||||
TmpFileManager.cleanup()
|
||||
else:
|
||||
skip("Matplotlib not the default backend")
|
||||
|
||||
|
||||
def test_region_and():
|
||||
matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,))
|
||||
if not matplotlib:
|
||||
skip("Matplotlib not the default backend")
|
||||
|
||||
from matplotlib.testing.compare import compare_images
|
||||
test_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
try:
|
||||
temp_dir = mkdtemp()
|
||||
TmpFileManager.tmp_folder(temp_dir)
|
||||
|
||||
x, y = symbols('x y')
|
||||
|
||||
r1 = (x - 1)**2 + y**2 < 2
|
||||
r2 = (x + 1)**2 + y**2 < 2
|
||||
|
||||
test_filename = tmp_file(dir=temp_dir, name="test_region_and")
|
||||
cmp_filename = os.path.join(test_directory, "test_region_and.png")
|
||||
p = plot_implicit(r1 & r2, x, y)
|
||||
p.save(test_filename)
|
||||
compare_images(cmp_filename, test_filename, 0.005)
|
||||
|
||||
test_filename = tmp_file(dir=temp_dir, name="test_region_or")
|
||||
cmp_filename = os.path.join(test_directory, "test_region_or.png")
|
||||
p = plot_implicit(r1 | r2, x, y)
|
||||
p.save(test_filename)
|
||||
compare_images(cmp_filename, test_filename, 0.005)
|
||||
|
||||
test_filename = tmp_file(dir=temp_dir, name="test_region_not")
|
||||
cmp_filename = os.path.join(test_directory, "test_region_not.png")
|
||||
p = plot_implicit(~r1, x, y)
|
||||
p.save(test_filename)
|
||||
compare_images(cmp_filename, test_filename, 0.005)
|
||||
|
||||
test_filename = tmp_file(dir=temp_dir, name="test_region_xor")
|
||||
cmp_filename = os.path.join(test_directory, "test_region_xor.png")
|
||||
p = plot_implicit(r1 ^ r2, x, y)
|
||||
p.save(test_filename)
|
||||
compare_images(cmp_filename, test_filename, 0.005)
|
||||
finally:
|
||||
TmpFileManager.cleanup()
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 6.7 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 7.8 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 8.6 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 9.8 KiB |
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,203 @@
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.functions.elementary.exponential import log
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.plotting.textplot import textplot_str
|
||||
|
||||
from sympy.utilities.exceptions import ignore_warnings
|
||||
|
||||
|
||||
def test_axes_alignment():
|
||||
x = Symbol('x')
|
||||
lines = [
|
||||
' 1 | ..',
|
||||
' | ... ',
|
||||
' | .. ',
|
||||
' | ... ',
|
||||
' | ... ',
|
||||
' | .. ',
|
||||
' | ... ',
|
||||
' | ... ',
|
||||
' | .. ',
|
||||
' | ... ',
|
||||
' 0 |--------------------------...--------------------------',
|
||||
' | ... ',
|
||||
' | .. ',
|
||||
' | ... ',
|
||||
' | ... ',
|
||||
' | .. ',
|
||||
' | ... ',
|
||||
' | ... ',
|
||||
' | .. ',
|
||||
' | ... ',
|
||||
' -1 |_______________________________________________________',
|
||||
' -1 0 1'
|
||||
]
|
||||
assert lines == list(textplot_str(x, -1, 1))
|
||||
|
||||
lines = [
|
||||
' 1 | ..',
|
||||
' | .... ',
|
||||
' | ... ',
|
||||
' | ... ',
|
||||
' | .... ',
|
||||
' | ... ',
|
||||
' | ... ',
|
||||
' | .... ',
|
||||
' 0 |--------------------------...--------------------------',
|
||||
' | .... ',
|
||||
' | ... ',
|
||||
' | ... ',
|
||||
' | .... ',
|
||||
' | ... ',
|
||||
' | ... ',
|
||||
' | .... ',
|
||||
' -1 |_______________________________________________________',
|
||||
' -1 0 1'
|
||||
]
|
||||
assert lines == list(textplot_str(x, -1, 1, H=17))
|
||||
|
||||
|
||||
def test_singularity():
|
||||
x = Symbol('x')
|
||||
lines = [
|
||||
' 54 | . ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' 27.5 |--.----------------------------------------------------',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | . ',
|
||||
' | \\ ',
|
||||
' | \\ ',
|
||||
' | .. ',
|
||||
' | ... ',
|
||||
' | ............. ',
|
||||
' 1 |_______________________________________________________',
|
||||
' 0 0.5 1'
|
||||
]
|
||||
assert lines == list(textplot_str(1/x, 0, 1))
|
||||
|
||||
lines = [
|
||||
' 0 | ......',
|
||||
' | ........ ',
|
||||
' | ........ ',
|
||||
' | ...... ',
|
||||
' | ..... ',
|
||||
' | .... ',
|
||||
' | ... ',
|
||||
' | .. ',
|
||||
' | ... ',
|
||||
' | / ',
|
||||
' -2 |-------..----------------------------------------------',
|
||||
' | / ',
|
||||
' | / ',
|
||||
' | / ',
|
||||
' | . ',
|
||||
' | ',
|
||||
' | . ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' -4 |_______________________________________________________',
|
||||
' 0 0.5 1'
|
||||
]
|
||||
# RuntimeWarning: divide by zero encountered in log
|
||||
with ignore_warnings(RuntimeWarning):
|
||||
assert lines == list(textplot_str(log(x), 0, 1))
|
||||
|
||||
|
||||
def test_sinc():
|
||||
x = Symbol('x')
|
||||
lines = [
|
||||
' 1 | . . ',
|
||||
' | . . ',
|
||||
' | ',
|
||||
' | . . ',
|
||||
' | ',
|
||||
' | . . ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | . . ',
|
||||
' | ',
|
||||
' 0.4 |-------------------------------------------------------',
|
||||
' | . . ',
|
||||
' | ',
|
||||
' | . . ',
|
||||
' | ',
|
||||
' | ..... ..... ',
|
||||
' | .. \\ . . / .. ',
|
||||
' | / \\ / \\ ',
|
||||
' |/ \\ . . / \\',
|
||||
' | \\ / \\ / ',
|
||||
' -0.2 |_______________________________________________________',
|
||||
' -10 0 10'
|
||||
]
|
||||
# RuntimeWarning: invalid value encountered in double_scalars
|
||||
with ignore_warnings(RuntimeWarning):
|
||||
assert lines == list(textplot_str(sin(x)/x, -10, 10))
|
||||
|
||||
|
||||
def test_imaginary():
|
||||
x = Symbol('x')
|
||||
lines = [
|
||||
' 1 | ..',
|
||||
' | .. ',
|
||||
' | ... ',
|
||||
' | .. ',
|
||||
' | .. ',
|
||||
' | .. ',
|
||||
' | .. ',
|
||||
' | .. ',
|
||||
' | .. ',
|
||||
' | / ',
|
||||
' 0.5 |----------------------------------/--------------------',
|
||||
' | .. ',
|
||||
' | / ',
|
||||
' | . ',
|
||||
' | ',
|
||||
' | . ',
|
||||
' | . ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' 0 |_______________________________________________________',
|
||||
' -1 0 1'
|
||||
]
|
||||
# RuntimeWarning: invalid value encountered in sqrt
|
||||
with ignore_warnings(RuntimeWarning):
|
||||
assert list(textplot_str(sqrt(x), -1, 1)) == lines
|
||||
|
||||
lines = [
|
||||
' 1 | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' 0 |-------------------------------------------------------',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' | ',
|
||||
' -1 |_______________________________________________________',
|
||||
' -1 0 1'
|
||||
]
|
||||
assert list(textplot_str(S.ImaginaryUnit, -1, 1)) == lines
|
||||
@@ -0,0 +1,110 @@
|
||||
from pytest import raises
|
||||
from sympy import (
|
||||
symbols, Expr, Tuple, Integer, cos, solveset, FiniteSet, ImageSet)
|
||||
from sympy.plotting.utils import (
|
||||
_create_ranges, _plot_sympify, extract_solution)
|
||||
from sympy.physics.mechanics import ReferenceFrame, Vector as MechVector
|
||||
from sympy.vector import CoordSys3D, Vector
|
||||
|
||||
|
||||
def test_plot_sympify():
|
||||
x, y = symbols("x, y")
|
||||
|
||||
# argument is already sympified
|
||||
args = x + y
|
||||
r = _plot_sympify(args)
|
||||
assert r == args
|
||||
|
||||
# one argument needs to be sympified
|
||||
args = (x + y, 1)
|
||||
r = _plot_sympify(args)
|
||||
assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2
|
||||
assert isinstance(r[0], Expr)
|
||||
assert isinstance(r[1], Integer)
|
||||
|
||||
# string and dict should not be sympified
|
||||
args = (x + y, (x, 0, 1), "str", 1, {1: 1, 2: 2.0})
|
||||
r = _plot_sympify(args)
|
||||
assert isinstance(r, (list, tuple, Tuple)) and len(r) == 5
|
||||
assert isinstance(r[0], Expr)
|
||||
assert isinstance(r[1], Tuple)
|
||||
assert isinstance(r[2], str)
|
||||
assert isinstance(r[3], Integer)
|
||||
assert isinstance(r[4], dict) and isinstance(r[4][1], int) and isinstance(r[4][2], float)
|
||||
|
||||
# nested arguments containing strings
|
||||
args = ((x + y, (y, 0, 1), "a"), (x + 1, (x, 0, 1), "$f_{1}$"))
|
||||
r = _plot_sympify(args)
|
||||
assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2
|
||||
assert isinstance(r[0], Tuple)
|
||||
assert isinstance(r[0][1], Tuple)
|
||||
assert isinstance(r[0][1][1], Integer)
|
||||
assert isinstance(r[0][2], str)
|
||||
assert isinstance(r[1], Tuple)
|
||||
assert isinstance(r[1][1], Tuple)
|
||||
assert isinstance(r[1][1][1], Integer)
|
||||
assert isinstance(r[1][2], str)
|
||||
|
||||
# vectors from sympy.physics.vectors module are not sympified
|
||||
# vectors from sympy.vectors are sympified
|
||||
# in both cases, no error should be raised
|
||||
R = ReferenceFrame("R")
|
||||
v1 = 2 * R.x + R.y
|
||||
C = CoordSys3D("C")
|
||||
v2 = 2 * C.i + C.j
|
||||
args = (v1, v2)
|
||||
r = _plot_sympify(args)
|
||||
assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2
|
||||
assert isinstance(v1, MechVector)
|
||||
assert isinstance(v2, Vector)
|
||||
|
||||
|
||||
def test_create_ranges():
|
||||
x, y = symbols("x, y")
|
||||
|
||||
# user don't provide any range -> return a default range
|
||||
r = _create_ranges({x}, [], 1)
|
||||
assert isinstance(r, (list, tuple, Tuple)) and len(r) == 1
|
||||
assert isinstance(r[0], (Tuple, tuple))
|
||||
assert r[0] == (x, -10, 10)
|
||||
|
||||
r = _create_ranges({x, y}, [], 2)
|
||||
assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2
|
||||
assert isinstance(r[0], (Tuple, tuple))
|
||||
assert isinstance(r[1], (Tuple, tuple))
|
||||
assert r[0] == (x, -10, 10) or (y, -10, 10)
|
||||
assert r[1] == (y, -10, 10) or (x, -10, 10)
|
||||
assert r[0] != r[1]
|
||||
|
||||
# not enough ranges provided by the user -> create default ranges
|
||||
r = _create_ranges(
|
||||
{x, y},
|
||||
[
|
||||
(x, 0, 1),
|
||||
],
|
||||
2,
|
||||
)
|
||||
assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2
|
||||
assert isinstance(r[0], (Tuple, tuple))
|
||||
assert isinstance(r[1], (Tuple, tuple))
|
||||
assert r[0] == (x, 0, 1) or (y, -10, 10)
|
||||
assert r[1] == (y, -10, 10) or (x, 0, 1)
|
||||
assert r[0] != r[1]
|
||||
|
||||
# too many free symbols
|
||||
raises(ValueError, lambda: _create_ranges({x, y}, [], 1))
|
||||
raises(ValueError, lambda: _create_ranges({x, y}, [(x, 0, 5), (y, 0, 1)], 1))
|
||||
|
||||
|
||||
def test_extract_solution():
|
||||
x = symbols("x")
|
||||
|
||||
sol = solveset(cos(10 * x))
|
||||
assert sol.has(ImageSet)
|
||||
res = extract_solution(sol)
|
||||
assert len(res) == 20
|
||||
assert isinstance(res, FiniteSet)
|
||||
|
||||
res = extract_solution(sol, 20)
|
||||
assert len(res) == 40
|
||||
assert isinstance(res, FiniteSet)
|
||||
@@ -0,0 +1,168 @@
|
||||
from sympy.core.numbers import Float
|
||||
from sympy.core.symbol import Dummy
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def is_valid(x):
|
||||
"""Check if a floating point number is valid"""
|
||||
if x is None:
|
||||
return False
|
||||
if isinstance(x, complex):
|
||||
return False
|
||||
return not math.isinf(x) and not math.isnan(x)
|
||||
|
||||
|
||||
def rescale(y, W, H, mi, ma):
|
||||
"""Rescale the given array `y` to fit into the integer values
|
||||
between `0` and `H-1` for the values between ``mi`` and ``ma``.
|
||||
"""
|
||||
y_new = []
|
||||
|
||||
norm = ma - mi
|
||||
offset = (ma + mi) / 2
|
||||
|
||||
for x in range(W):
|
||||
if is_valid(y[x]):
|
||||
normalized = (y[x] - offset) / norm
|
||||
if not is_valid(normalized):
|
||||
y_new.append(None)
|
||||
else:
|
||||
rescaled = Float((normalized*H + H/2) * (H-1)/H).round()
|
||||
rescaled = int(rescaled)
|
||||
y_new.append(rescaled)
|
||||
else:
|
||||
y_new.append(None)
|
||||
return y_new
|
||||
|
||||
|
||||
def linspace(start, stop, num):
|
||||
return [start + (stop - start) * x / (num-1) for x in range(num)]
|
||||
|
||||
|
||||
def textplot_str(expr, a, b, W=55, H=21):
|
||||
"""Generator for the lines of the plot"""
|
||||
free = expr.free_symbols
|
||||
if len(free) > 1:
|
||||
raise ValueError(
|
||||
"The expression must have a single variable. (Got {})"
|
||||
.format(free))
|
||||
x = free.pop() if free else Dummy()
|
||||
f = lambdify([x], expr)
|
||||
if isinstance(a, complex):
|
||||
if a.imag == 0:
|
||||
a = a.real
|
||||
if isinstance(b, complex):
|
||||
if b.imag == 0:
|
||||
b = b.real
|
||||
a = float(a)
|
||||
b = float(b)
|
||||
|
||||
# Calculate function values
|
||||
x = linspace(a, b, W)
|
||||
y = []
|
||||
for val in x:
|
||||
try:
|
||||
y.append(f(val))
|
||||
# Not sure what exceptions to catch here or why...
|
||||
except (ValueError, TypeError, ZeroDivisionError):
|
||||
y.append(None)
|
||||
|
||||
# Normalize height to screen space
|
||||
y_valid = list(filter(is_valid, y))
|
||||
if y_valid:
|
||||
ma = max(y_valid)
|
||||
mi = min(y_valid)
|
||||
if ma == mi:
|
||||
if ma:
|
||||
mi, ma = sorted([0, 2*ma])
|
||||
else:
|
||||
mi, ma = -1, 1
|
||||
else:
|
||||
mi, ma = -1, 1
|
||||
y_range = ma - mi
|
||||
precision = math.floor(math.log10(y_range)) - 1
|
||||
precision *= -1
|
||||
mi = round(mi, precision)
|
||||
ma = round(ma, precision)
|
||||
y = rescale(y, W, H, mi, ma)
|
||||
|
||||
y_bins = linspace(mi, ma, H)
|
||||
|
||||
# Draw plot
|
||||
margin = 7
|
||||
for h in range(H - 1, -1, -1):
|
||||
s = [' '] * W
|
||||
for i in range(W):
|
||||
if y[i] == h:
|
||||
if (i == 0 or y[i - 1] == h - 1) and (i == W - 1 or y[i + 1] == h + 1):
|
||||
s[i] = '/'
|
||||
elif (i == 0 or y[i - 1] == h + 1) and (i == W - 1 or y[i + 1] == h - 1):
|
||||
s[i] = '\\'
|
||||
else:
|
||||
s[i] = '.'
|
||||
|
||||
if h == 0:
|
||||
for i in range(W):
|
||||
s[i] = '_'
|
||||
|
||||
# Print y values
|
||||
if h in (0, H//2, H - 1):
|
||||
prefix = ("%g" % y_bins[h]).rjust(margin)[:margin]
|
||||
else:
|
||||
prefix = " "*margin
|
||||
s = "".join(s)
|
||||
if h == H//2:
|
||||
s = s.replace(" ", "-")
|
||||
yield prefix + " |" + s
|
||||
|
||||
# Print x values
|
||||
bottom = " " * (margin + 2)
|
||||
bottom += ("%g" % x[0]).ljust(W//2)
|
||||
if W % 2 == 1:
|
||||
bottom += ("%g" % x[W//2]).ljust(W//2)
|
||||
else:
|
||||
bottom += ("%g" % x[W//2]).ljust(W//2-1)
|
||||
bottom += "%g" % x[-1]
|
||||
yield bottom
|
||||
|
||||
|
||||
def textplot(expr, a, b, W=55, H=21):
|
||||
r"""
|
||||
Print a crude ASCII art plot of the SymPy expression 'expr' (which
|
||||
should contain a single symbol, e.g. x or something else) over the
|
||||
interval [a, b].
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Symbol, sin
|
||||
>>> from sympy.plotting import textplot
|
||||
>>> t = Symbol('t')
|
||||
>>> textplot(sin(t)*t, 0, 15)
|
||||
14 | ...
|
||||
| .
|
||||
| .
|
||||
| .
|
||||
| .
|
||||
| ...
|
||||
| / . .
|
||||
| /
|
||||
| / .
|
||||
| . . .
|
||||
1.5 |----.......--------------------------------------------
|
||||
|.... \ . .
|
||||
| \ / .
|
||||
| .. / .
|
||||
| \ / .
|
||||
| ....
|
||||
| .
|
||||
| . .
|
||||
|
|
||||
| . .
|
||||
-11 |_______________________________________________________
|
||||
0 7.5 15
|
||||
"""
|
||||
for line in textplot_str(expr, a, b, W, H):
|
||||
print(line)
|
||||
@@ -0,0 +1,323 @@
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.expr import Expr
|
||||
from sympy.core.function import AppliedUndef
|
||||
from sympy.core.relational import Relational
|
||||
from sympy.core.symbol import Dummy
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.logic.boolalg import BooleanFunction
|
||||
from sympy.sets.fancysets import ImageSet
|
||||
from sympy.sets.sets import FiniteSet
|
||||
from sympy.tensor.indexed import Indexed
|
||||
|
||||
|
||||
def _get_free_symbols(exprs):
|
||||
"""Returns the free symbols of a symbolic expression.
|
||||
|
||||
If the expression contains any of these elements, assume that they are
|
||||
the "free symbols" of the expression:
|
||||
|
||||
* indexed objects
|
||||
* applied undefined function (useful for sympy.physics.mechanics module)
|
||||
"""
|
||||
if not isinstance(exprs, (list, tuple, set)):
|
||||
exprs = [exprs]
|
||||
if all(callable(e) for e in exprs):
|
||||
return set()
|
||||
|
||||
free = set().union(*[e.atoms(Indexed) for e in exprs])
|
||||
free = free.union(*[e.atoms(AppliedUndef) for e in exprs])
|
||||
return free or set().union(*[e.free_symbols for e in exprs])
|
||||
|
||||
|
||||
def extract_solution(set_sol, n=10):
|
||||
"""Extract numerical solutions from a set solution (computed by solveset,
|
||||
linsolve, nonlinsolve). Often, it is not trivial do get something useful
|
||||
out of them.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
n : int, optional
|
||||
In order to replace ImageSet with FiniteSet, an iterator is created
|
||||
for each ImageSet contained in `set_sol`, starting from 0 up to `n`.
|
||||
Default value: 10.
|
||||
"""
|
||||
images = set_sol.find(ImageSet)
|
||||
for im in images:
|
||||
it = iter(im)
|
||||
s = FiniteSet(*[next(it) for n in range(0, n)])
|
||||
set_sol = set_sol.subs(im, s)
|
||||
return set_sol
|
||||
|
||||
|
||||
def _plot_sympify(args):
|
||||
"""This function recursively loop over the arguments passed to the plot
|
||||
functions: the sympify function will be applied to all arguments except
|
||||
those of type string/dict.
|
||||
|
||||
Generally, users can provide the following arguments to a plot function:
|
||||
|
||||
expr, range1 [tuple, opt], ..., label [str, opt], rendering_kw [dict, opt]
|
||||
|
||||
`expr, range1, ...` can be sympified, whereas `label, rendering_kw` can't.
|
||||
In particular, whenever a special character like $, {, }, ... is used in
|
||||
the `label`, sympify will raise an error.
|
||||
"""
|
||||
if isinstance(args, Expr):
|
||||
return args
|
||||
|
||||
args = list(args)
|
||||
for i, a in enumerate(args):
|
||||
if isinstance(a, (list, tuple)):
|
||||
args[i] = Tuple(*_plot_sympify(a), sympify=False)
|
||||
elif not (isinstance(a, (str, dict)) or callable(a)
|
||||
# NOTE: check if it is a vector from sympy.physics.vector module
|
||||
# without importing the module (because it slows down SymPy's
|
||||
# import process and triggers SymPy's optional-dependencies
|
||||
# tests to fail).
|
||||
or ((a.__class__.__name__ == "Vector") and not isinstance(a, Basic))
|
||||
):
|
||||
args[i] = sympify(a)
|
||||
return args
|
||||
|
||||
|
||||
def _create_ranges(exprs, ranges, npar, label="", params=None):
|
||||
"""This function does two things:
|
||||
|
||||
1. Check if the number of free symbols is in agreement with the type of
|
||||
plot chosen. For example, plot() requires 1 free symbol;
|
||||
plot3d() requires 2 free symbols.
|
||||
2. Sometime users create plots without providing ranges for the variables.
|
||||
Here we create the necessary ranges.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
exprs : iterable
|
||||
The expressions from which to extract the free symbols
|
||||
ranges : iterable
|
||||
The limiting ranges provided by the user
|
||||
npar : int
|
||||
The number of free symbols required by the plot functions.
|
||||
For example,
|
||||
npar=1 for plot, npar=2 for plot3d, ...
|
||||
params : dict
|
||||
A dictionary mapping symbols to parameters for interactive plot.
|
||||
"""
|
||||
get_default_range = lambda symbol: Tuple(symbol, -10, 10)
|
||||
|
||||
free_symbols = _get_free_symbols(exprs)
|
||||
if params is not None:
|
||||
free_symbols = free_symbols.difference(params.keys())
|
||||
|
||||
if len(free_symbols) > npar:
|
||||
raise ValueError(
|
||||
"Too many free symbols.\n"
|
||||
+ "Expected {} free symbols.\n".format(npar)
|
||||
+ "Received {}: {}".format(len(free_symbols), free_symbols)
|
||||
)
|
||||
|
||||
if len(ranges) > npar:
|
||||
raise ValueError(
|
||||
"Too many ranges. Received %s, expected %s" % (len(ranges), npar))
|
||||
|
||||
# free symbols in the ranges provided by the user
|
||||
rfs = set().union([r[0] for r in ranges])
|
||||
if len(rfs) != len(ranges):
|
||||
raise ValueError("Multiple ranges with the same symbol")
|
||||
|
||||
if len(ranges) < npar:
|
||||
symbols = free_symbols.difference(rfs)
|
||||
if symbols != set():
|
||||
# add a range for each missing free symbols
|
||||
for s in symbols:
|
||||
ranges.append(get_default_range(s))
|
||||
# if there is still room, fill them with dummys
|
||||
for i in range(npar - len(ranges)):
|
||||
ranges.append(get_default_range(Dummy()))
|
||||
|
||||
if len(free_symbols) == npar:
|
||||
# there could be times when this condition is not met, for example
|
||||
# plotting the function f(x, y) = x (which is a plane); in this case,
|
||||
# free_symbols = {x} whereas rfs = {x, y} (or x and Dummy)
|
||||
rfs = set().union([r[0] for r in ranges])
|
||||
if len(free_symbols.difference(rfs)) > 0:
|
||||
raise ValueError(
|
||||
"Incompatible free symbols of the expressions with "
|
||||
"the ranges.\n"
|
||||
+ "Free symbols in the expressions: {}\n".format(free_symbols)
|
||||
+ "Free symbols in the ranges: {}".format(rfs)
|
||||
)
|
||||
return ranges
|
||||
|
||||
|
||||
def _is_range(r):
|
||||
"""A range is defined as (symbol, start, end). start and end should
|
||||
be numbers.
|
||||
"""
|
||||
# TODO: prange check goes here
|
||||
return (
|
||||
isinstance(r, Tuple)
|
||||
and (len(r) == 3)
|
||||
and (not isinstance(r.args[1], str)) and r.args[1].is_number
|
||||
and (not isinstance(r.args[2], str)) and r.args[2].is_number
|
||||
)
|
||||
|
||||
|
||||
def _unpack_args(*args):
|
||||
"""Given a list/tuple of arguments previously processed by _plot_sympify()
|
||||
and/or _check_arguments(), separates and returns its components:
|
||||
expressions, ranges, label and rendering keywords.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import cos, sin, symbols
|
||||
>>> from sympy.plotting.utils import _plot_sympify, _unpack_args
|
||||
>>> x, y = symbols('x, y')
|
||||
>>> args = (sin(x), (x, -10, 10), "f1")
|
||||
>>> args = _plot_sympify(args)
|
||||
>>> _unpack_args(*args)
|
||||
([sin(x)], [(x, -10, 10)], 'f1', None)
|
||||
|
||||
>>> args = (sin(x**2 + y**2), (x, -2, 2), (y, -3, 3), "f2")
|
||||
>>> args = _plot_sympify(args)
|
||||
>>> _unpack_args(*args)
|
||||
([sin(x**2 + y**2)], [(x, -2, 2), (y, -3, 3)], 'f2', None)
|
||||
|
||||
>>> args = (sin(x + y), cos(x - y), x + y, (x, -2, 2), (y, -3, 3), "f3")
|
||||
>>> args = _plot_sympify(args)
|
||||
>>> _unpack_args(*args)
|
||||
([sin(x + y), cos(x - y), x + y], [(x, -2, 2), (y, -3, 3)], 'f3', None)
|
||||
"""
|
||||
ranges = [t for t in args if _is_range(t)]
|
||||
labels = [t for t in args if isinstance(t, str)]
|
||||
label = None if not labels else labels[0]
|
||||
rendering_kw = [t for t in args if isinstance(t, dict)]
|
||||
rendering_kw = None if not rendering_kw else rendering_kw[0]
|
||||
# NOTE: why None? because args might have been preprocessed by
|
||||
# _check_arguments, so None might represent the rendering_kw
|
||||
results = [not (_is_range(a) or isinstance(a, (str, dict)) or (a is None)) for a in args]
|
||||
exprs = [a for a, b in zip(args, results) if b]
|
||||
return exprs, ranges, label, rendering_kw
|
||||
|
||||
|
||||
def _check_arguments(args, nexpr, npar, **kwargs):
|
||||
"""Checks the arguments and converts into tuples of the
|
||||
form (exprs, ranges, label, rendering_kw).
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
args
|
||||
The arguments provided to the plot functions
|
||||
nexpr
|
||||
The number of sub-expression forming an expression to be plotted.
|
||||
For example:
|
||||
nexpr=1 for plot.
|
||||
nexpr=2 for plot_parametric: a curve is represented by a tuple of two
|
||||
elements.
|
||||
nexpr=1 for plot3d.
|
||||
nexpr=3 for plot3d_parametric_line: a curve is represented by a tuple
|
||||
of three elements.
|
||||
npar
|
||||
The number of free symbols required by the plot functions. For example,
|
||||
npar=1 for plot, npar=2 for plot3d, ...
|
||||
**kwargs :
|
||||
keyword arguments passed to the plotting function. It will be used to
|
||||
verify if ``params`` has ben provided.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
.. plot::
|
||||
:context: reset
|
||||
:format: doctest
|
||||
:include-source: True
|
||||
|
||||
>>> from sympy import cos, sin, symbols
|
||||
>>> from sympy.plotting.plot import _check_arguments
|
||||
>>> x = symbols('x')
|
||||
>>> _check_arguments([cos(x), sin(x)], 2, 1)
|
||||
[(cos(x), sin(x), (x, -10, 10), None, None)]
|
||||
|
||||
>>> _check_arguments([cos(x), sin(x), "test"], 2, 1)
|
||||
[(cos(x), sin(x), (x, -10, 10), 'test', None)]
|
||||
|
||||
>>> _check_arguments([cos(x), sin(x), "test", {"a": 0, "b": 1}], 2, 1)
|
||||
[(cos(x), sin(x), (x, -10, 10), 'test', {'a': 0, 'b': 1})]
|
||||
|
||||
>>> _check_arguments([x, x**2], 1, 1)
|
||||
[(x, (x, -10, 10), None, None), (x**2, (x, -10, 10), None, None)]
|
||||
"""
|
||||
if not args:
|
||||
return []
|
||||
output = []
|
||||
params = kwargs.get("params", None)
|
||||
|
||||
if all(isinstance(a, (Expr, Relational, BooleanFunction)) for a in args[:nexpr]):
|
||||
# In this case, with a single plot command, we are plotting either:
|
||||
# 1. one expression
|
||||
# 2. multiple expressions over the same range
|
||||
|
||||
exprs, ranges, label, rendering_kw = _unpack_args(*args)
|
||||
free_symbols = set().union(*[e.free_symbols for e in exprs])
|
||||
ranges = _create_ranges(exprs, ranges, npar, label, params)
|
||||
|
||||
if nexpr > 1:
|
||||
# in case of plot_parametric or plot3d_parametric_line, there will
|
||||
# be 2 or 3 expressions defining a curve. Group them together.
|
||||
if len(exprs) == nexpr:
|
||||
exprs = (tuple(exprs),)
|
||||
for expr in exprs:
|
||||
# need this if-else to deal with both plot/plot3d and
|
||||
# plot_parametric/plot3d_parametric_line
|
||||
is_expr = isinstance(expr, (Expr, Relational, BooleanFunction))
|
||||
e = (expr,) if is_expr else expr
|
||||
output.append((*e, *ranges, label, rendering_kw))
|
||||
|
||||
else:
|
||||
# In this case, we are plotting multiple expressions, each one with its
|
||||
# range. Each "expression" to be plotted has the following form:
|
||||
# (expr, range, label) where label is optional
|
||||
|
||||
_, ranges, labels, rendering_kw = _unpack_args(*args)
|
||||
labels = [labels] if labels else []
|
||||
|
||||
# number of expressions
|
||||
n = (len(ranges) + len(labels) +
|
||||
(len(rendering_kw) if rendering_kw is not None else 0))
|
||||
new_args = args[:-n] if n > 0 else args
|
||||
|
||||
# at this point, new_args might just be [expr]. But I need it to be
|
||||
# [[expr]] in order to be able to loop over
|
||||
# [expr, range [opt], label [opt]]
|
||||
if not isinstance(new_args[0], (list, tuple, Tuple)):
|
||||
new_args = [new_args]
|
||||
|
||||
# Each arg has the form (expr1, expr2, ..., range1 [optional], ...,
|
||||
# label [optional], rendering_kw [optional])
|
||||
for arg in new_args:
|
||||
# look for "local" range and label. If there is not, use "global".
|
||||
l = [a for a in arg if isinstance(a, str)]
|
||||
if not l:
|
||||
l = labels
|
||||
r = [a for a in arg if _is_range(a)]
|
||||
if not r:
|
||||
r = ranges.copy()
|
||||
rend_kw = [a for a in arg if isinstance(a, dict)]
|
||||
rend_kw = rendering_kw if len(rend_kw) == 0 else rend_kw[0]
|
||||
|
||||
# NOTE: arg = arg[:nexpr] may raise an exception if lambda
|
||||
# functions are used. Execute the following instead:
|
||||
arg = [arg[i] for i in range(nexpr)]
|
||||
free_symbols = set()
|
||||
if all(not callable(a) for a in arg):
|
||||
free_symbols = free_symbols.union(*[a.free_symbols for a in arg])
|
||||
if len(r) != npar:
|
||||
r = _create_ranges(arg, r, npar, "", params)
|
||||
|
||||
label = None if not l else l[0]
|
||||
output.append((*arg, *r, label, rend_kw))
|
||||
return output
|
||||
Reference in New Issue
Block a user