diff --git a/midas/checker/types.py b/midas/checker/types.py index 71765ad..ea94b45 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -10,15 +10,8 @@ class BaseType: @dataclass(frozen=True, kw_only=True) class SimpleType: - base: BaseType - - -@dataclass(frozen=True, kw_only=True) -class Operation: - left: Type - operator: str - right: Type - result: Type + name: str + base: BaseType | SimpleType @dataclass(frozen=True, kw_only=True) diff --git a/midas/resolver/midas.py b/midas/resolver/midas.py new file mode 100644 index 0000000..66c3152 --- /dev/null +++ b/midas/resolver/midas.py @@ -0,0 +1,114 @@ +from typing import Optional + +import midas.ast.midas as m +from midas.checker.types import BaseType, SimpleType, Type + + +class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]): + def __init__(self) -> None: + self._types: dict[str, Type] = {} + self._operations: dict[tuple[Type, str, Type], Type] = {} + + self._define_builtin() + + def get_type(self, name: str) -> Type: + type: Optional[Type] = self._types.get(name) + if type is None: + raise NameError(f"Undefined type {name}") + return type + + def get_operation_result(self, left: Type, operator: str, right: Type) -> Type: + operation: tuple[Type, str, Type] = (left, operator, right) + result: Optional[Type] = self._operations.get(operation) + if result is None: + raise ValueError( + f"Undefined operation {operator} between {left} and {right}" + ) + return result + + def _define_builtin(self): + self.define_type("int", BaseType(name="int")) + self.define_type("float", BaseType(name="float")) + self.define_type("bool", BaseType(name="bool")) + self.define_operation( + left=self.get_type("int"), + operator="__add__", + right=self.get_type("int"), + result=self.get_type("int"), + ) + + def define_type(self, name: str, type: Type) -> Type: + if name in self._types: + raise ValueError(f"Type {name} already defined") + self._types[name] = type + return type + + def define_operation(self, left: Type, operator: str, right: Type, result: Type): + operation: tuple[Type, str, Type] = (left, operator, right) + if operation in self._operations: + raise ValueError( + f"Operation {operator} already defined between {left} and {right}" + ) + self._operations[operation] = result + + def resolve(self, stmts: list[m.Stmt]): + for stmt in stmts: + stmt.accept(self) + + def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None: + # TODO generics, optional, constraint + base: Type = self.get_type(stmt.base.name.lexeme) + match base: + case BaseType() | SimpleType(): + type = SimpleType( + name=stmt.name.lexeme, + base=base, + ) + self.define_type(type.name, type) + case _: + raise TypeError(f"Invalid base {base} for simple type") + + def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None: ... + + def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... + + def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: + base: Type = stmt.type.accept(self) + for op in stmt.operations: + right: Type = op.operand.accept(self) + result: Type = op.result.accept(self) + self.define_operation( + left=base, + operator=op.name.lexeme, + right=right, + result=result, + ) + + def visit_op_stmt(self, stmt: m.OpStmt) -> None: ... + + def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... + + def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> Type: + return self.get_type(expr.name.lexeme) + + def visit_logical_expr(self, expr: m.LogicalExpr) -> Type: ... + + def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: ... + + def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: ... + + def visit_get_expr(self, expr: m.GetExpr) -> Type: ... + + def visit_variable_expr(self, expr: m.VariableExpr) -> Type: ... + + def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type: + return expr.expr.accept(self) + + def visit_literal_expr(self, expr: m.LiteralExpr) -> Type: ... + + def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type: ... + + def visit_template_expr(self, expr: m.TemplateExpr) -> Type: ... + + def visit_type_expr(self, expr: m.TypeExpr) -> Type: + return self.get_type(expr.name.lexeme)