diff --git a/midas/checker/evaluator.py b/midas/checker/evaluator.py new file mode 100644 index 0000000..c21f1c2 --- /dev/null +++ b/midas/checker/evaluator.py @@ -0,0 +1,172 @@ +from dataclasses import dataclass +from typing import Any, Callable, Optional + +import midas.ast.midas as m +from midas.checker.preamble import Preamble +from midas.checker.registry import TypesRegistry +from midas.checker.reporter import FileReporter +from midas.checker.types import Function, Predicate +from midas.lexer.token import TokenType + + +@dataclass(frozen=True, kw_only=True) +class PartialPredicate(Predicate): + scope: dict[str, Any] + + +class Evaluator(m.Expr.Visitor[Any]): + def __init__(self, types: TypesRegistry, reporter: Optional[FileReporter] = None): + self.types: TypesRegistry = types + self.reporter: Optional[FileReporter] = reporter + self.preamble: Preamble = Preamble(self.types) + self.scopes: list[dict[str, Any]] = [{}] + + def evaluate(self, expr: m.Expr) -> Any: + value: Any = expr.accept(self) + if self.reporter is not None: + self.reporter.debug(expr.location, f"Value: {value}") + return value + + def get_value(self, name: str) -> Any: + scope: dict[str, Any] = self.scopes[-1] + return scope[name] + + def set_value(self, name: str, value: Any, force_declare: bool = False): + if not force_declare: + for scope in reversed(self.scopes): + if name in scope: + scope[name] = value + return + self.scopes[-1][name] = value + + def visit_logical_expr(self, expr: m.LogicalExpr) -> Any: + def left(): + return self.evaluate(expr.left) + + def right(): + return self.evaluate(expr.right) + + match expr.operator.type: + case TokenType.AND: + return left() and right() + case _: + raise NotImplementedError + + def visit_binary_expr(self, expr: m.BinaryExpr) -> Any: + left: Any = self.evaluate(expr.left) + right: Any = self.evaluate(expr.right) + match expr.operator.type: + case TokenType.MINUS: + return left - right + case TokenType.STAR: + return left * right + case TokenType.SLASH: + return left / right + case TokenType.GREATER: + return left > right + case TokenType.GREATER_EQUAL: + return left >= right + case TokenType.LESS: + return left < right + case TokenType.LESS_EQUAL: + return left <= right + case TokenType.EQUAL_EQUAL: + return left == right + case TokenType.BANG_EQUAL: + return left != right + case _: + raise NotImplementedError + + def visit_unary_expr(self, expr: m.UnaryExpr) -> Any: + right: Any = self.evaluate(expr.right) + match expr.operator.type: + case TokenType.MINUS: + return -right + case _: + raise NotImplementedError + + def visit_call_expr(self, expr: m.CallExpr) -> Any: + callee: Any = self.evaluate(expr.callee) + args: list[Any] = [self.evaluate(arg) for arg in expr.arguments] + kwargs: dict[str, Any] = { + name: self.evaluate(arg) for name, arg in expr.keywords.items() + } + + match callee: + case Predicate(): + return self._evaluate_predicate(callee, args, kwargs) + case _ if callable(callee): + return callee(*args, **kwargs) + case _: + return NotImplementedError + + def visit_get_expr(self, expr: m.GetExpr) -> Any: + obj: Any = self.evaluate(expr.expr) + return getattr(obj, expr.name.lexeme) + + def visit_variable_expr(self, expr: m.VariableExpr) -> Any: + name: str = expr.name.lexeme + for scope in reversed(self.scopes): + if name in scope: + return scope[name] + + predicate: Optional[Predicate] = self.types.lookup_predicate(name) + if predicate is not None: + if predicate.alias: + return self.evaluate(predicate.body) + return predicate + + glob: Optional[Callable] = self.preamble.get_py_func(name) + if glob is not None: + return glob + raise NameError(f"Unknown variable '{name}'") + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> Any: + return self.evaluate(expr.expr) + + def visit_literal_expr(self, expr: m.LiteralExpr) -> Any: + return expr.value + + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Any: + return self.get_value("_") + + def _evaluate_predicate( + self, predicate: Predicate, args: list[Any], kwargs: dict[str, Any] + ) -> Any: + res: Any = None + if isinstance(predicate, PartialPredicate): + self.scopes.append(predicate.scope) + else: + self.scopes.append({}) + match predicate.type: + case Function(returns=Function() as inner): + self._map_args(predicate.type, args, kwargs) + res = PartialPredicate( + type=inner, + body=predicate.body, + alias=False, + scope=self.scopes[-1], + ) + + case Function(): + self._map_args(predicate.type, args, kwargs) + res = self.evaluate(predicate.body) + + case _: + raise NotImplementedError + self.scopes.pop() + return res + + def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]): + positional: list[Function.Argument] = function.pos_args + function.args + keywords: dict[str, Function.Argument] = { + arg.name: arg for arg in function.args + function.kw_args + } + + for i, arg in enumerate(args): + param: Function.Argument = positional[i] + self.set_value(param.name, arg) + + for name, arg in kwargs.items(): + param: Function.Argument = keywords[name] + self.set_value(param.name, arg)