""" Sympy primitives for representing atoms of ga expressions """

from typing import Union

from sympy import Symbol, AtomicExpr, S, Basic, sympify, MatrixExpr
from sympy import Determinant as _Determinant
from sympy.core import numbers
from sympy.core.function import AppliedUndef, UndefinedFunction
from sympy.printing.pretty.stringpict import prettyForm, stringPict
from sympy.printing.pretty.pretty_symbology import U

# Public API of this module; the remaining names are implementation helpers.
__all__ = [
    'BasisVectorSymbol',
    'BasisBaseSymbol',
    'BasisBladeSymbol',
    'BasisBladeNoWedgeSymbol',
    'DotProductSymbol',
]

def _all_same(items):
    return all(x == items[0] for x in items)

class BasisVectorSymbol(Symbol):
    """ A noncommutative symbol standing for a single basis vector.

    Its LaTeX form is rendered in bold.
    """

    is_commutative = False

    def _latex(self, print_obj):
        # Older sympy releases reject the ``style=`` keyword with a TypeError;
        # in that case wrap the plain rendering in \mathbf by hand.
        try:
            rendered = print_obj._print_Symbol(self, style="bold")
        except TypeError:
            plain = print_obj._print_Symbol(self)
            rendered = r"\mathbf{{{}}}".format(plain)
        return rendered


class _GradedSymbol(AtomicExpr):
    """ Base class for all graded symbols

    Constructing this from a single symbol returns that symbol itself.
    Constructing from no symbols returns the scalar `S.One`.
    This may change in future.
    """
    # A grade-0 (scalar) instance would be commutative and contradict this
    # flag, but __new__ never builds one (it returns S.One instead), so
    # False holds for every instance this class actually creates.
    is_commutative = False

    def __new__(cls, *args: BasisVectorSymbol) -> Union[
        numbers.One, BasisVectorSymbol, "_GradedSymbol"
    ]:
        if len(args) == 0:
            return S.One
        elif len(args) == 1:
            return args[0]
        else:
            return super().__new__(cls, *args)


class _JoinedPrinterMixin(Basic):
    """ Helper class to print `Basic.args` joined by a separator symbol.

    Subclasses must populate:

    * ``_op_sympystr`` -- separator string used for ``str()`` output
    * ``_op_pretty`` -- separator ``prettyForm`` used by the pretty printer
    * ``_op_latex`` -- separator string used for LaTeX output
    """

    def _sympystr(self, printer):
        return self._op_sympystr.join(
            printer._print(v) for v in self.args
        )

    def _pretty(self, printer):
        ret = []
        for i, v in enumerate(self.args):
            if i != 0:
                ret.append(self._op_pretty)
            ret.append(printer._print(v))
        # stringPict.next lays the pieces out side by side and returns the
        # (picture, baseline) pair that prettyForm's constructor expects.
        return prettyForm(*stringPict.next(*ret))

    def _latex(self, printer):
        return self._op_latex.join(
            printer._print(v) for v in self.args
        )


class BasisBaseSymbol(_GradedSymbol, _JoinedPrinterMixin):
    r""" A basis base of a non-orthogonal algebra, for instance :math:`e_1 e_2`.

    Factors are separated by ``*`` in str and pretty output and simply
    juxtaposed in LaTeX.
    """

    _op_latex = ''
    _op_pretty = prettyForm('*')
    _op_sympystr = '*'


class BasisBladeSymbol(_GradedSymbol, _JoinedPrinterMixin):
    r""" A basis blade written with explicit wedges, for instance :math:`e_1 \wedge e_2`.

    Factors are joined by ``^`` in str and pretty output and by ``\wedge``
    in LaTeX.
    """

    _op_latex = r'\wedge '
    _op_pretty = prettyForm('^')
    _op_sympystr = '^'


class BasisBladeNoWedgeSymbol(BasisBladeSymbol):
    r""" A basis blade with shortened rendering such as :math:`e_{12}`

    Every factor must be a basis vector named ``<root>_<index>``, and all
    factors must share the same root. The blade is then printed as if it were
    the single basis vector ``<root>_<index1><index2>...``.
    """

    def _split_name(self):
        """ Split the factor names into a shared root and joined indices.

        Returns
        -------
        tuple of (str, str)
            The common root and the concatenation of all factor indices.

        Raises
        ------
        ValueError
            If a factor name does not contain exactly one ``_``, or if the
            factors do not all share the same root.
        """
        sub_str = []
        root_str = []
        for basis_vec in self.args:
            split_lst = basis_vec.name.split('_')
            if len(split_lst) != 2:
                raise ValueError('Incompatible basis vector {} for wedgeless printing'.format(basis_vec))
            root, sub = split_lst
            root_str.append(root)
            sub_str.append(sub)
        if not _all_same(root_str):
            raise ValueError('No unique root symbol to use for wedgeless printing')
        return root_str[0], ''.join(sub_str)

    def __common_printer(self, printer):
        # print as if we were a basis vector
        root, sub = self._split_name()
        return printer._print(BasisVectorSymbol("{}_{}".format(root, sub)))

    # the same delegation serves the str, pretty and LaTeX printers
    _sympystr = _pretty = _latex = __common_printer


class DotProductSymbol(AtomicExpr):
    """ A symbol used to represent a dot product, like :class:`sympy.DotProduct`

    Expects exactly two arguments, the left and right operands.
    """
    is_real = True

    def _sympystr(self, printer):
        a, b = self.args
        return "({}.{})".format(printer._print(a), printer._print(b))

    def _latex(self, printer):
        a, b = self.args
        return r"\left ({}\cdot {}\right ) ".format(printer._print(a), printer._print(b))

    def _pretty(self, printer):
        a, b = self.args
        # place "a", the dot operator and "b" side by side; stringPict.next
        # returns the (picture, baseline) pair that prettyForm expects
        pform = prettyForm(*stringPict.next(
            printer._print(a),
            printer._print(U('DOT OPERATOR')),
            printer._print(b),
        ))
        return prettyForm(*pform.parens())


class MatrixFunction(UndefinedFunction):
    """ Like a MatrixSymbol, but for functions.

    ``MatrixFunction(name, m, n)`` creates a new undefined-function class
    whose applications are matrix expressions (``AppliedUndef`` and
    ``MatrixExpr``) of shape ``(m, n)``.
    """
    def __new__(mcl, name, m, n, **kwargs):
        cls = super().__new__(mcl, name, (AppliedUndef, MatrixExpr), {}, **kwargs)
        # (rows, cols): the row count must come from ``m``, not ``n``,
        # otherwise non-square shapes are silently reported as n x n
        cls.shape = sympify(m, strict=True), sympify(n, strict=True)
        return cls


# workaround until sympy/sympy#19354 is merged: when sympy's Determinant is
# not flagged as commutative, subclass it to force ``is_commutative = True``;
# otherwise re-export sympy's class unchanged.
if _Determinant.is_commutative is not True:
    class Determinant(_Determinant):
        is_commutative = True
else:
    Determinant = _Determinant