# -*- coding:utf-8 -*-
#
# Copyright (C) 2019-2020, Maximilian Köhl <mkoehl@cs.uni-saarland.de>
from __future__ import annotations
import typing as t
import abc
import dataclasses
import enum
import fractions
import math
import numbers
from . import context, errors, operators, properties, types
if t.TYPE_CHECKING:
from . import distributions
class Expression(properties.Property, abc.ABC):
    """Abstract base class of all expression nodes.

    Subclasses have to provide :attr:`children`, :meth:`is_constant_in`,
    and :meth:`infer_type`. The remaining members are either derived from
    :attr:`children` or are thin constructors for compound expressions
    having this expression as an operand.
    """

    @property
    @abc.abstractmethod
    def children(self) -> t.Sequence[Expression]:
        """The direct subexpressions of this expression."""
        raise NotImplementedError()

    @abc.abstractmethod
    def is_constant_in(self, scope: context.Scope) -> bool:
        """Returns `True` only if the expression has a constant value in the given scope.

        Arguments:
            scope: The scope to use.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def infer_type(self, scope: context.Scope) -> types.Type:
        """Infers the type of the expression in the given scope.

        Arguments:
            scope: The scope to use.
        """
        raise NotImplementedError()

    @property
    def traverse(self) -> t.Iterable[Expression]:
        """Yields this expression followed by all nested subexpressions.

        The traversal is pre-order: a node is yielded before its children.
        """
        yield self
        for child in self.children:
            yield from child.traverse

    @property
    def subexpressions(self) -> t.AbstractSet[Expression]:
        """The set of all expressions yielded by :attr:`traverse`."""
        return frozenset(self.traverse)

    @property
    def is_sampling_free(self) -> bool:
        """`True` if no expression yielded by :attr:`traverse` is a `Sample`."""
        return not any(isinstance(node, Sample) for node in self.traverse)

    @property
    def identifiers(self) -> t.AbstractSet[Identifier]:
        """The set of all `Identifier` nodes yielded by :attr:`traverse`."""
        return frozenset(
            node for node in self.traverse if isinstance(node, Identifier)
        )

    def lor(self, other: Expression) -> Expression:
        """Builds the disjunction `self OR other`.

        Returns `NotImplemented` if `other` is not an `Expression`.
        """
        if not isinstance(other, Expression):
            return NotImplemented
        return Boolean(operators.BooleanOperator.OR, self, other)

    def land(self, other: Expression) -> Expression:
        """Builds the conjunction of `self` and `other`."""
        return land(self, other)

    def lnot(self) -> Expression:
        """Builds the negation of this expression."""
        # Inside the method body `lnot` resolves to the module-level function.
        return lnot(self)

    def add(self, other: MaybeExpression) -> Expression:
        """Builds `self ADD other`; `other` is passed through `convert`."""
        return Arithmetic(operators.ArithmeticOperator.ADD, self, convert(other))

    def radd(self, other: MaybeExpression) -> Expression:
        """Builds `other ADD self`; `other` is passed through `convert`."""
        return Arithmetic(operators.ArithmeticOperator.ADD, convert(other), self)

    def sub(self, other: MaybeExpression) -> Expression:
        """Builds `self SUB other`; `other` is passed through `convert`."""
        return Arithmetic(operators.ArithmeticOperator.SUB, self, convert(other))

    def rsub(self, other: MaybeExpression) -> Expression:
        """Builds `other SUB self`; `other` is passed through `convert`."""
        return Arithmetic(operators.ArithmeticOperator.SUB, convert(other), self)

    def mul(self, other: MaybeExpression) -> Expression:
        """Builds `self MUL other`; `other` is passed through `convert`."""
        return Arithmetic(operators.ArithmeticOperator.MUL, self, convert(other))

    def rmul(self, other: MaybeExpression) -> Expression:
        """Builds `other MUL self`; `other` is passed through `convert`."""
        return Arithmetic(operators.ArithmeticOperator.MUL, convert(other), self)

    def mod(self, other: MaybeExpression) -> Expression:
        """Builds `self MOD other`; `other` is passed through `convert`."""
        return Arithmetic(operators.ArithmeticOperator.MOD, self, convert(other))

    def rmod(self, other: MaybeExpression) -> Expression:
        """Builds `other MOD self`; `other` is passed through `convert`."""
        return Arithmetic(operators.ArithmeticOperator.MOD, convert(other), self)

    def floordiv(self, other: MaybeExpression) -> Expression:
        """Builds `self FLOOR_DIV other`; `other` is passed through `convert`."""
        # `self` is the dividend, so it has to be the left operand, exactly
        # like in the other non-reflected arithmetic constructors.
        return Arithmetic(
            operators.ArithmeticOperator.FLOOR_DIV, self, convert(other)
        )

    def rfloordiv(self, other: MaybeExpression) -> Expression:
        """Builds `other FLOOR_DIV self`; `other` is passed through `convert`."""
        return Arithmetic(
            operators.ArithmeticOperator.FLOOR_DIV, convert(other), self
        )

    def eq(self, other: Expression) -> Expression:
        """Builds the `EQ` equality between `self` and `other`."""
        return Equality(operators.EqualityOperator.EQ, self, other)

    def neq(self, other: Expression) -> Expression:
        """Builds the `NEQ` equality between `self` and `other`."""
        return Equality(operators.EqualityOperator.NEQ, self, other)

    def lt(self, other: Expression) -> Expression:
        """Builds the `LT` comparison with `self` on the left."""
        return Comparison(operators.ComparisonOperator.LT, self, other)

    def le(self, other: Expression) -> Expression:
        """Builds the `LE` comparison with `self` on the left."""
        return Comparison(operators.ComparisonOperator.LE, self, other)

    def ge(self, other: Expression) -> Expression:
        """Builds the `GE` comparison with `self` on the left."""
        return Comparison(operators.ComparisonOperator.GE, self, other)

    def gt(self, other: Expression) -> Expression:
        """Builds the `GT` comparison with `self` on the left."""
        return Comparison(operators.ComparisonOperator.GT, self, other)
class _Leaf(Expression):
    """Base class for expressions that have no subexpressions."""

    @property
    def children(self) -> t.Sequence[Expression]:
        """Always the empty tuple."""
        return tuple()
class Constant(_Leaf, abc.ABC):
    """Abstract base class of leaf expressions that are constant in any scope."""

    def is_constant_in(self, scope: context.Scope) -> bool:
        """Returns `True` regardless of the given scope."""
        return True
@dataclasses.dataclass(frozen=True)
class BooleanConstant(Constant):
    """A constant holding a Python `bool`."""

    # The wrapped truth value.
    boolean: bool

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns `types.BOOL`; the scope is not consulted."""
        return types.BOOL
class NumericConstant(Constant, abc.ABC):
    """Abstract base class of constants that can be viewed as a `float`."""

    @property
    @abc.abstractmethod
    def as_float(self) -> float:
        """The value of the constant as a Python `float`."""
        raise NotImplementedError()
# Module-level instances of the two Boolean constants.
TRUE = BooleanConstant(True)
FALSE = BooleanConstant(False)
@dataclasses.dataclass(frozen=True)
class IntegerConstant(NumericConstant):
    """A constant holding a Python `int`."""

    # The wrapped integer value.
    integer: int

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns `types.INT`; the scope is not consulted."""
        return types.INT

    @property
    def as_float(self) -> float:
        """The integer converted with `float()`."""
        # TODO: emit a warning to keep track of imprecisions
        value = self.integer
        return float(value)
# Maps the symbol of each named real to its enum member. It is filled
# by `NamedReal.__init__` while the enum members are being created.
_NAMED_REAL_MAP: t.Dict[str, NamedReal] = {}


class NamedReal(enum.Enum):
    """Real constants which are referred to by a symbol.

    The value of each member is a `(symbol, float_value)` pair which is
    unpacked into the attributes of the same names.
    """

    PI = ("π", math.pi)
    E = ("e", math.e)

    # Symbol identifying the constant, e.g., in string form.
    symbol: str
    # Floating-point value of the constant.
    float_value: float

    def __init__(self, symbol: str, float_value: float) -> None:
        # Make the member retrievable by its symbol.
        _NAMED_REAL_MAP[symbol] = self
        self.float_value = float_value
        self.symbol = symbol
# A real value is either one of the named reals or an exact fraction.
Real = t.Union[NamedReal, fractions.Fraction]
@dataclasses.dataclass(frozen=True)
class RealConstant(NumericConstant):
    """A constant holding a named real or a fraction."""

    # The wrapped value.
    real: Real

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns `types.REAL`; the scope is not consulted."""
        return types.REAL

    @property
    def as_float(self) -> float:
        """The floating-point value of the wrapped real."""
        # TODO: emit a warning to keep track of imprecisions
        real = self.real
        if not isinstance(real, NamedReal):
            return float(real)
        # Named reals carry their own floating-point value.
        return real.float_value
@dataclasses.dataclass(frozen=True)
class Identifier(_Leaf):
    """A leaf expression referring to a declaration by its name."""

    # The name which is looked up in the scope.
    name: str

    def is_constant_in(self, scope: context.Scope) -> bool:
        """Delegates to the declaration found for the name in `scope`."""
        declaration = scope.lookup(self.name)
        return declaration.is_constant_in(scope)

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns the `typ` of the declaration found for the name in `scope`."""
        declaration = scope.lookup(self.name)
        return declaration.typ
# XXX: this class should be abstract, however, then it would not type-check
# https://github.com/python/mypy/issues/5374
@dataclasses.dataclass(frozen=True)
class BinaryExpression(Expression):
    """Base class of expressions applying an operator to two operands."""

    operator: operators.BinaryOperator
    left: Expression
    right: Expression

    @property
    def children(self) -> t.Sequence[Expression]:
        """The left and the right operand, in this order."""
        return (self.left, self.right)

    def is_constant_in(self, scope: context.Scope) -> bool:
        """Returns whether both operands are constant in `scope`."""
        # The right operand is only consulted if the left one is constant.
        if not self.left.is_constant_in(scope):
            return False
        return self.right.is_constant_in(scope)

    # XXX: this method shall be implemented by all subclasses
    def infer_type(self, scope: context.Scope) -> types.Type:
        raise NotImplementedError()
class Boolean(BinaryExpression):
    """A binary expression with a Boolean operator."""

    operator: operators.BooleanOperator

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns `types.BOOL` after checking both operands.

        Raises:
            errors.InvalidTypeError: If the type of an operand is not
                `types.BOOL`. The left operand is checked first.
        """
        for operand in (self.left, self.right):
            operand_type = scope.get_type(operand)
            if operand_type != types.BOOL:
                raise errors.InvalidTypeError(
                    f"expected types.BOOL but got {operand_type}"
                )
        return types.BOOL
# Arithmetic operators for which `Arithmetic.infer_type` returns
# `types.REAL` even if both operands are assignable to `types.INT`.
_REAL_OPERATORS = {
operators.ArithmeticOperator.REAL_DIV,
operators.ArithmeticOperator.LOG,
operators.ArithmeticOperator.POW,
}
class Arithmetic(BinaryExpression):
    """A binary expression with an arithmetic operator."""

    operator: operators.ArithmeticOperator

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns `types.INT` or `types.REAL` for numeric operands.

        The result is `types.INT` only if both operand types are assignable
        to `types.INT` and the operator is not in `_REAL_OPERATORS`.

        Raises:
            errors.InvalidTypeError: If an operand type is not numeric.
        """
        left_type = scope.get_type(self.left)
        right_type = scope.get_type(self.right)
        if not (left_type.is_numeric and right_type.is_numeric):
            raise errors.InvalidTypeError(
                "operands of arithmetic expressions must have a numeric type"
            )
        integral = (
            types.INT.is_assignable_from(left_type)
            and types.INT.is_assignable_from(right_type)
            and self.operator not in _REAL_OPERATORS
        )
        return types.INT if integral else types.REAL
class Equality(BinaryExpression):
    """A binary expression with an equality operator."""

    operator: operators.EqualityOperator

    def get_common_type(self, scope: context.Scope) -> types.Type:
        """Returns the operand type the other operand type is assignable to.

        The left type is preferred if both directions are possible.

        Raises:
            AssertionError: If neither type is assignable from the other.
        """
        left_type = scope.get_type(self.left)
        right_type = scope.get_type(self.right)
        if left_type.is_assignable_from(right_type):
            return left_type
        if right_type.is_assignable_from(left_type):
            return right_type
        raise AssertionError(
            "type-inference should ensure that some of the above is true"
        )

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns `types.BOOL` if the operand types are compatible.

        Raises:
            errors.InvalidTypeError: If neither operand type is assignable
                from the other.
        """
        left_type = scope.get_type(self.left)
        right_type = scope.get_type(self.right)
        # XXX: JANI specifies that “left and right must be assignable to some common type”
        compatible = left_type.is_assignable_from(
            right_type
        ) or right_type.is_assignable_from(left_type)
        if not compatible:
            raise errors.InvalidTypeError(
                "invalid combination of type for equality comparison"
            )
        return types.BOOL
class Comparison(BinaryExpression):
    """A binary expression with a comparison operator."""

    operator: operators.ComparisonOperator

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns `types.BOOL` after checking both operands.

        Raises:
            errors.InvalidTypeError: If the type of an operand is not
                numeric. The left operand is checked first.
        """
        for operand in (self.left, self.right):
            operand_type = scope.get_type(operand)
            if not operand_type.is_numeric:
                raise errors.InvalidTypeError(
                    f"expected numeric type but got {operand_type}"
                )
        return types.BOOL
@dataclasses.dataclass(frozen=True)
class Conditional(Expression):
    """An expression choosing between two expressions based on a condition."""

    condition: Expression
    consequence: Expression
    alternative: Expression

    @property
    def children(self) -> t.Sequence[Expression]:
        """The condition, the consequence, and the alternative, in this order."""
        return (self.condition, self.consequence, self.alternative)

    def is_constant_in(self, scope: context.Scope) -> bool:
        """Returns whether all three parts are constant in `scope`."""
        parts = (self.condition, self.consequence, self.alternative)
        return all(part.is_constant_in(scope) for part in parts)

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns the branch type the other branch type is assignable to.

        The type of the consequence is preferred if both directions are
        possible.

        Raises:
            errors.InvalidTypeError: If the condition is not of type
                `types.BOOL` or neither branch type is assignable from
                the other.
        """
        condition_type = scope.get_type(self.condition)
        if condition_type != types.BOOL:
            raise errors.InvalidTypeError(
                f"expected `types.BOOL` but got `{condition_type}`"
            )
        then_type = scope.get_type(self.consequence)
        else_type = scope.get_type(self.alternative)
        if then_type.is_assignable_from(else_type):
            return then_type
        if else_type.is_assignable_from(then_type):
            return else_type
        raise errors.InvalidTypeError(
            "invalid combination of consequence and alternative types"
        )
# XXX: this class should be abstract, however, then it would not type-check
# https://github.com/python/mypy/issues/5374
@dataclasses.dataclass(frozen=True)
class UnaryExpression(Expression):
    """Base class of expressions applying an operator to a single operand."""

    operator: operators.UnaryOperator
    operand: Expression

    @property
    def children(self) -> t.Sequence[Expression]:
        """A one-element tuple containing the operand."""
        only_child = self.operand
        return (only_child,)

    def is_constant_in(self, scope: context.Scope) -> bool:
        """Returns whether the operand is constant in `scope`."""
        operand = self.operand
        return operand.is_constant_in(scope)

    # XXX: this method shall be implemented by all subclasses
    def infer_type(self, scope: context.Scope) -> types.Type:
        raise NotImplementedError()
class Round(UnaryExpression):
    """A unary expression with a rounding operator."""

    operator: operators.RoundOperator

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns `types.INT` for a numeric operand.

        Raises:
            errors.InvalidTypeError: If the operand type is not numeric.
        """
        operand_type = scope.get_type(self.operand)
        if operand_type.is_numeric:
            return types.INT
        raise errors.InvalidTypeError(
            f"expected a numeric type but got {operand_type}"
        )
class Not(UnaryExpression):
    """A unary expression with the negation operator."""

    operator: operators.NotOperator

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns `types.BOOL` for a Boolean operand.

        Raises:
            errors.InvalidTypeError: If the operand type is not `types.BOOL`.
        """
        negated_type = scope.get_type(self.operand)
        if negated_type != types.BOOL:
            message = f"expected `types.BOOL` but got {negated_type}"
            raise errors.InvalidTypeError(message)
        return types.BOOL
@dataclasses.dataclass(frozen=True)
class Sample(Expression):
    """An expression sampling from a named distribution."""

    distribution: distributions.NamedDistribution
    arguments: t.Sequence[Expression]

    def __post_init__(self) -> None:
        """Checks that there is exactly one argument per parameter.

        Raises:
            ValueError: If the number of arguments differs from the number
                of parameter types of the distribution.
        """
        provided = len(self.arguments)
        expected = len(self.distribution.parameter_types)
        if provided != expected:
            raise ValueError("parameter and arguments arity mismatch")

    @property
    def children(self) -> t.Sequence[Expression]:
        """The arguments passed to the distribution."""
        return self.arguments

    def is_constant_in(self, scope: context.Scope) -> bool:
        """Returns `False` regardless of the given scope."""
        return False

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns the result type of the distribution.

        Raises:
            errors.InvalidTypeError: If a parameter type is not assignable
                from the type of the corresponding argument.
        """
        # we already know that the arity of the parameters and arguments match
        pairs = zip(self.arguments, self.distribution.parameter_types)
        for argument, parameter_type in pairs:
            argument_type = scope.get_type(argument)
            if parameter_type.is_assignable_from(argument_type):
                continue
            raise errors.InvalidTypeError(
                f"parameter type `{parameter_type}` is not assignable "
                f"from argument type `{argument_type}`"
            )
        return self.distribution.result_type
@dataclasses.dataclass(frozen=True)
class Selection(Expression):
    """An expression made of a name and a Boolean condition."""

    name: str
    condition: Expression

    @property
    def children(self) -> t.Sequence[Expression]:
        """A one-element tuple containing the condition."""
        return (self.condition,)

    def is_constant_in(self, scope: context.Scope) -> bool:
        """Returns `False` regardless of the given scope."""
        return False

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns the `typ` of the declaration found for `name` in `scope`.

        Raises:
            errors.InvalidTypeError: If the condition is not of type
                `types.BOOL`.
        """
        if scope.get_type(self.condition) != types.BOOL:
            raise errors.InvalidTypeError("condition must have type `types.BOOL`")
        declaration = scope.lookup(self.name)
        # The name is expected to resolve to a variable declaration.
        assert isinstance(declaration, context.VariableDeclaration)
        return declaration.typ
@dataclasses.dataclass(frozen=True)
class Derivative(Expression):
    """A childless expression of type `types.REAL` carrying an identifier."""

    identifier: str

    @property
    def children(self) -> t.Sequence[Expression]:
        """Always the empty tuple."""
        return tuple()

    def is_constant_in(self, scope: context.Scope) -> bool:
        """Returns `False` regardless of the given scope."""
        return False

    def infer_type(self, scope: context.Scope) -> types.Type:
        """Returns `types.REAL`; the scope is not consulted."""
        return types.REAL
def ite(
    condition: Expression, consequence: Expression, alternative: Expression
) -> Expression:
    """Builds a `Conditional` from a condition and its two branches."""
    return Conditional(
        condition=condition, consequence=consequence, alternative=alternative
    )
# The symbols of the members of `NamedReal` as string literals.
PythonRealString = t.Literal["π", "e"]
# Python-side spellings of a real value.
PythonReal = t.Union[numbers.Real, float, PythonRealString, NamedReal]
# Python-side spellings of a numeric value.
PythonNumeric = t.Union[int, PythonReal]
# Python values which `const` is declared to accept.
PythonValue = t.Union[bool, PythonNumeric]
class ConversionError(ValueError):
    """Raised by `const` for Python values it is unable to convert."""
def const(value: PythonValue) -> Constant:
    """Converts a Python value into the corresponding constant expression.

    Arguments:
        value: A `bool`, an `int`, a `fractions.Fraction`, a `float`, a
            member of `NamedReal`, or the symbol of a named real.

    Returns:
        A `BooleanConstant` for Booleans, an `IntegerConstant` for integers,
        and a `RealConstant` otherwise.

    Raises:
        ConversionError: If the value is of an unsupported type.
        KeyError: If the value is a string which is not the symbol of a
            named real.
    """
    # `bool` is a subclass of `int` and thus has to be checked first.
    if isinstance(value, bool):
        return BooleanConstant(value)
    if isinstance(value, int):
        return IntegerConstant(value)
    # `PythonValue` includes `NamedReal`, and `RealConstant` stores named
    # reals as they are, so members are wrapped directly.
    if isinstance(value, NamedReal):
        return RealConstant(value)
    if isinstance(value, (fractions.Fraction, float)):
        return RealConstant(fractions.Fraction(value))
    if isinstance(value, str):
        return RealConstant(_NAMED_REAL_MAP[value])
    raise ConversionError(f"unable to convert Python value {value!r} to Momba value")
# Either an expression or a Python value which `convert` turns into one.
MaybeExpression = t.Union[PythonValue, Expression]
def convert(value: MaybeExpression) -> Expression:
    """Returns expressions unchanged and passes anything else to `const`."""
    return value if isinstance(value, Expression) else const(value)
# Type of functions building an expression out of two expressions.
BinaryConstructor = t.Callable[[Expression, Expression], Expression]
def lor(*expressions: Expression) -> Expression:
    """Builds the disjunction of the given expressions.

    For exactly two expressions a single `OR` node is returned. For any
    other number the expressions are chained from left to right starting
    from the constant `False`, which is also the result for no expressions.
    """
    if len(expressions) == 2:
        first, second = expressions
        return Boolean(operators.BooleanOperator.OR, first, second)
    disjunction: Expression = const(False)
    for expression in expressions:
        disjunction = Boolean(operators.BooleanOperator.OR, disjunction, expression)
    return disjunction
def land(*expressions: Expression) -> Expression:
    """Builds the conjunction of the given expressions.

    For exactly two expressions a single `AND` node is returned. For any
    other number the expressions are chained from left to right starting
    from the constant `True`, which is also the result for no expressions.
    """
    if len(expressions) == 2:
        first, second = expressions
        return Boolean(operators.BooleanOperator.AND, first, second)
    conjunction: Expression = const(True)
    for expression in expressions:
        conjunction = Boolean(
            operators.BooleanOperator.AND, conjunction, expression
        )
    return conjunction
def xor(left: Expression, right: Expression) -> Expression:
    """Builds a `Boolean` with the `XOR` operator."""
    operator = operators.BooleanOperator.XOR
    return Boolean(operator, left, right)
def implies(left: Expression, right: Expression) -> Expression:
    """Builds a `Boolean` with the `IMPLY` operator, `left` being the premise."""
    operator = operators.BooleanOperator.IMPLY
    return Boolean(operator, left, right)
def equiv(left: Expression, right: Expression) -> Expression:
    """Builds a `Boolean` with the `EQUIV` operator."""
    operator = operators.BooleanOperator.EQUIV
    return Boolean(operator, left, right)
def eq(left: Expression, right: Expression) -> BinaryExpression:
    """Builds an `Equality` with the `EQ` operator."""
    operator = operators.EqualityOperator.EQ
    return Equality(operator, left, right)
def neq(left: Expression, right: Expression) -> BinaryExpression:
    """Builds an `Equality` with the `NEQ` operator."""
    operator = operators.EqualityOperator.NEQ
    return Equality(operator, left, right)
def lt(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds a `Comparison` with the `LT` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Comparison(operators.ComparisonOperator.LT, lhs, rhs)
def le(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds a `Comparison` with the `LE` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Comparison(operators.ComparisonOperator.LE, lhs, rhs)
def ge(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds a `Comparison` with the `GE` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Comparison(operators.ComparisonOperator.GE, lhs, rhs)
def gt(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds a `Comparison` with the `GT` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Comparison(operators.ComparisonOperator.GT, lhs, rhs)
def add(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds an `Arithmetic` with the `ADD` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Arithmetic(operators.ArithmeticOperator.ADD, lhs, rhs)
def sub(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds an `Arithmetic` with the `SUB` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Arithmetic(operators.ArithmeticOperator.SUB, lhs, rhs)
def mul(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds an `Arithmetic` with the `MUL` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Arithmetic(operators.ArithmeticOperator.MUL, lhs, rhs)
def mod(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds an `Arithmetic` with the `MOD` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Arithmetic(operators.ArithmeticOperator.MOD, lhs, rhs)
def minimum(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds an `Arithmetic` with the `MIN` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Arithmetic(operators.ArithmeticOperator.MIN, lhs, rhs)
def maximum(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds an `Arithmetic` with the `MAX` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Arithmetic(operators.ArithmeticOperator.MAX, lhs, rhs)
def real_div(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds an `Arithmetic` with the `REAL_DIV` operator from the converted operands."""
    dividend, divisor = convert(left), convert(right)
    return Arithmetic(operators.ArithmeticOperator.REAL_DIV, dividend, divisor)
def floor_div(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds an `Arithmetic` with the `FLOOR_DIV` operator from the converted operands."""
    dividend, divisor = convert(left), convert(right)
    return Arithmetic(operators.ArithmeticOperator.FLOOR_DIV, dividend, divisor)
def power(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds an `Arithmetic` with the `POW` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Arithmetic(operators.ArithmeticOperator.POW, lhs, rhs)
def log(left: MaybeExpression, right: MaybeExpression) -> BinaryExpression:
    """Builds an `Arithmetic` with the `LOG` operator from the converted operands."""
    lhs, rhs = convert(left), convert(right)
    return Arithmetic(operators.ArithmeticOperator.LOG, lhs, rhs)
# Type of functions building an expression out of a single expression.
UnaryConstructor = t.Callable[[Expression], Expression]
def lnot(operand: Expression) -> Expression:
    """Builds a `Not` with the `NOT` operator around the operand."""
    return Not(operator=operators.NotOperator.NOT, operand=operand)
def floor(operand: MaybeExpression) -> Expression:
    """Builds a `Round` with the `FLOOR` operator around the converted operand."""
    converted = convert(operand)
    return Round(operators.RoundOperator.FLOOR, converted)
def ceil(operand: MaybeExpression) -> Expression:
    """Builds a `Round` with the `CEIL` operator around the converted operand.

    Arguments:
        operand: An expression or a Python value; the latter is turned
            into a constant by `convert`, exactly as for `floor`.
    """
    return Round(operators.RoundOperator.CEIL, convert(operand))
def normalize_xor(expr: Expression) -> Expression:
    """Rewrites an `XOR` expression in terms of `OR`, `AND`, and `NOT`.

    Arguments:
        expr: A `Boolean` expression with the `XOR` operator.

    Returns:
        The expression `(¬left ∧ right) ∨ (left ∧ ¬right)`.
    """
    assert isinstance(expr, Boolean) and expr.operator is operators.BooleanOperator.XOR
    # Exactly one of the operands has to hold, hence, the second disjunct
    # must negate the right operand instead of repeating the first disjunct.
    only_right = land(lnot(expr.left), expr.right)
    only_left = land(expr.left, lnot(expr.right))
    return lor(only_right, only_left)
def normalize_equiv(expr: Expression) -> Expression:
    """Rewrites an `EQUIV` expression as the conjunction of both implications."""
    assert (
        isinstance(expr, Boolean) and expr.operator is operators.BooleanOperator.EQUIV
    )
    forward = implies(expr.left, expr.right)
    backward = implies(expr.right, expr.left)
    return land(forward, backward)
def normalize_floor_div(expr: Expression) -> Expression:
    """Rewrites a `FLOOR_DIV` expression as the floor of a `REAL_DIV`."""
    assert (
        isinstance(expr, Arithmetic)
        and expr.operator is operators.ArithmeticOperator.FLOOR_DIV
    )
    quotient = real_div(expr.left, expr.right)
    return floor(quotient)
# Aliases of the Boolean constructors under names with a `logic_` prefix.
logic_not = lnot
logic_or = lor
logic_and = land
logic_xor = xor
logic_implies = implies
logic_equiv = equiv
def logic_rimplies(left: Expression, right: Expression) -> Expression:
    """Builds the reversed implication, i.e., `right` is the premise."""
    premise, conclusion = right, left
    return Boolean(operators.BooleanOperator.IMPLY, premise, conclusion)
def identifier(name: str) -> Identifier:
    """Builds an `Identifier` for the given name."""
    return Identifier(name=name)