Files
midas/midas/checker/midas.py

438 lines
15 KiB
Python

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import midas.ast.midas as m
from midas.ast.location import Location
from midas.checker.builtins import define_builtins
from midas.checker.dispatcher import CallDispatcher, CallResult
from midas.checker.environment import Environment
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
Function,
GenericType,
Predicate,
Type,
TypeVar,
UnknownType,
)
from midas.checker.variance import VarianceInferrer
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
@dataclass(frozen=True, kw_only=True)
class TypedParamSpec:
    """Parameter specification whose argument types have been resolved.

    Produced by ``MidasTyper._visit_param_spec`` and unpacked into a
    ``Function`` as ``pos_args`` / ``args`` / ``kw_args``.
    """

    # Arguments built from the ``pos`` group of the parsed spec
    # (presumably positional-only -- confirm against the parser).
    pos: list[Function.Argument]
    # Arguments built from the ``mixed`` group of the parsed spec.
    mixed: list[Function.Argument]
    # Arguments built from the ``kw`` group of the parsed spec
    # (presumably keyword-only -- confirm against the parser).
    kw: list[Function.Argument]
class ReturnException(Exception):
    """Marker exception type.

    NOTE(review): it is neither raised nor caught in the visible part of this
    module; its intended use should be confirmed against the other checkers.
    """

    pass
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
    """Association between a call-site expression and a function argument.

    NOTE(review): not referenced by the visible part of this module.
    """

    # The argument expression as written in the source
    expr: m.Expr
    # The type associated with ``expr``
    type: Type
    # The declared function argument that ``expr`` is associated with
    argument: Function.Argument
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
    """A function paired with the arguments mapped onto it.

    NOTE(review): not referenced by the visible part of this module.
    """

    # The function being considered
    function: Function
    # The arguments mapped onto ``function``
    mapped: list[MappedArgument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
    """A resolver which evaluates Midas type definitions and builds a registry.

    It acts as a visitor over three kinds of AST nodes: statements (which
    populate the ``TypesRegistry``), expressions (which evaluate to a
    ``Type``) and type annotations (which are turned into checker ``Type``
    objects).
    """
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
    """Create the typer and load the bundled Midas builtins.

    Args:
        types (TypesRegistry): the registry that definitions are written to
        reporter (Reporter): the reporter from which per-file reporters
            are derived
    """
    self.logger: logging.Logger = logging.getLogger("MidasTyper")
    self.reporter: FileReporter = reporter.for_file(None)
    self.types: TypesRegistry = types
    self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr](
        self.types, self.reporter
    )
    # Type variables of the statement currently being resolved
    self._local_variables: dict[str, TypeVar] = {}
    # Parameters of the predicate currently being resolved
    self._predicate_params: dict[str, Type] = {}
    # Name of the type/alias currently being defined (for diagnostics)
    self._current_name: Optional[str] = None

    define_builtins(self.types)
    builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
    # Read with an explicit encoding: the default of read_text() depends on
    # the locale, which would make loading the bundled file platform
    # dependent.
    self.process(builtins_path.read_text(encoding="utf-8"), str(builtins_path))

    # These two rely on the builtins having been processed above.
    self._bool: Type = self.get_type("bool")
    self._preamble: Environment = Preamble(self.types)
def set_reporter(self, reporter: FileReporter):
    """Route diagnostics of both the typer and its dispatcher to a reporter

    Args:
        reporter (FileReporter): the reporter to use from now on
    """
    self.dispatcher.set_reporter(reporter)
    self.reporter = reporter
def process(self, source: str, path: Optional[str]):
    """Lex, parse and resolve a piece of Midas source code

    Parser errors are forwarded to the reporter; the statements returned by
    the parser are resolved regardless.

    Args:
        source (str): the Midas source code
        path (Optional[str]): the path used to derive the file reporter
    """
    self.set_reporter(self.reporter.for_file(path))

    tokens: list[Token] = MidasLexer(source).process()
    parser: MidasParser = MidasParser(tokens)
    statements: list[m.Stmt] = parser.parse()

    for syntax_error in parser.errors:
        self.reporter.error(
            syntax_error.token.get_location(), syntax_error.message
        )

    self.resolve(statements)
def type_of(self, expr: m.Expr) -> Type:
    """Evaluate the type of an expression by visiting it

    Args:
        expr (m.Expr): the expression

    Returns:
        Type: the type the expression evaluates to
    """
    return expr.accept(self)
def get_type(self, name: str) -> Type:
    """Look up a type by name

    Type variables of the statement being resolved take precedence over
    the types stored in the registry.

    Args:
        name (str): the name of the type

    Raises:
        NameError: if the type is not defined

    Returns:
        Type: the type
    """
    local: Optional[TypeVar] = self._local_variables.get(name)
    if local is not None:
        return local
    return self.types.get_type(name)
def get_variable(self, name: str) -> Type:
    """Look up the type of a variable by name

    The lookup order is: parameters of the predicate being resolved,
    registered predicates, then the preamble.

    Args:
        name (str): the name of the variable

    Raises:
        NameError: if no variable with that name is known

    Returns:
        Type: the type of the variable
    """
    if name in self._predicate_params:
        return self._predicate_params[name]

    if (predicate := self.types.lookup_predicate(name)) is not None:
        return predicate.type

    if (global_ := self._preamble.get(name)) is not None:
        return global_

    raise NameError(f"Unknown variable '{name}'")
def resolve(self, stmts: list[m.Stmt]):
    """Process a sequence of statements, then infer generic variances

    Args:
        stmts (list[m.Stmt]): the statements
    """
    for stmt in stmts:
        stmt.accept(self)

    # Iterate over a snapshot of the names: entries are rewritten inside the
    # loop, and iterating the live view would raise a RuntimeError if
    # inference ever added or removed a registry entry.
    for name in list(self.types._types):
        candidate = self.types._types[name]
        if isinstance(candidate, GenericType):
            # A fresh inferrer is used for every generic type
            inferrer = VarianceInferrer(self.types)
            self.types._types[name] = inferrer.infer(candidate)
def assert_bool(self, expr: m.Expr):
    """Report an error if an expression does not evaluate to a boolean

    Args:
        expr (m.Expr): the expression to check
    """
    actual: Type = self.type_of(expr)
    if self.types.is_subtype(actual, self._bool):
        return
    self.reporter.error(expr.location, f"Must be a boolean but is {actual}")
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
    """Define a new named type, generic if it declares type parameters

    Args:
        stmt (m.TypeStmt): the type statement
    """
    name: str = stmt.name.lexeme
    self._current_name = name
    try:
        params: list[TypeVar] = self._resolve_type_params(stmt.params)
        body: Type = stmt.type.accept(self)

        defined: Type
        if params:
            defined = GenericType(name=name, params=params, body=body)
        else:
            defined = DerivedType(name=name, type=body)
        self.types.define_type(name, defined)
    finally:
        # Always drop the per-statement state, even if resolution raised,
        # so that type parameters and the current name cannot leak into
        # the next statement.
        self._local_variables.clear()
        self._current_name = None
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
    """Register a name that refers directly to the aliased type

    Unlike a type statement, the resolved type is stored as-is, without
    being wrapped in a ``DerivedType`` or ``GenericType``.

    Args:
        stmt (m.AliasStmt): the alias statement
    """
    name: str = stmt.name.lexeme
    self._current_name = name
    try:
        aliased: Type = stmt.type.accept(self)
        self.types.define_type(name, aliased)
    finally:
        # Reset even on failure so a stale name cannot trigger the
        # "recursive type" hint for a later statement.
        self._current_name = None
def visit_member_stmt(self, stmt: m.MemberStmt) -> None:
    """Do nothing: members are registered by ``visit_extend_stmt``, which
    reads them directly from the extend statement instead of visiting them.

    Args:
        stmt (m.MemberStmt): the member statement (ignored)
    """
    return None
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
    """Attach the members of an extend block to an existing type

    An error is reported if the extended type is unknown; the members are
    still resolved and handed to the registry afterwards.

    Args:
        stmt (m.ExtendStmt): the extend statement
    """
    base_name: str = stmt.name.lexeme
    try:
        # Makes the block's type parameters visible to the member types
        self._resolve_type_params(stmt.params)

        try:
            self.get_type(base_name)
        except NameError:
            self.reporter.error(
                stmt.name.get_location(), f"Unknown type '{base_name}'"
            )

        for member in stmt.members:
            member_type: Type = member.type.accept(self)
            self.types.define_member(
                base_name,
                member.name.lexeme,
                member_type,
                member.kind,
            )
    finally:
        # The type parameters are scoped to this block only. Without this,
        # they would stay registered and shadow or leak into the types of
        # following statements (visit_type_stmt clears them the same way).
        self._local_variables.clear()
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
    """Type-check a predicate definition and register it

    The body is evaluated with the ``mixed`` parameters of every parameter
    list in scope. A predicate with parameter lists is registered as a
    (possibly curried) function returning a boolean; one without any is
    registered as an alias carrying the type of its body.

    Args:
        stmt (m.PredicateStmt): the predicate statement
    """
    try:
        for spec in stmt.params:
            for param in spec.mixed:
                assert param.name is not None
                self._predicate_params[param.name.lexeme] = param.type.accept(
                    self
                )

        type_: Type = self.type_of(stmt.body)
        params: list[TypedParamSpec] = [
            self._visit_param_spec(spec) for spec in stmt.params
        ]

        if not self._is_valid_predicate(type_):
            self.reporter.error(
                stmt.body.location,
                f"Predicate function body must evaluate to a boolean, got {type_}",
            )
    finally:
        # Evaluating the body can raise (e.g. NameError from get_variable);
        # the parameters must not remain in scope for later statements.
        self._predicate_params = {}

    if params:
        # Build the curried signature from the innermost parameter list
        # outwards, with a boolean as the final result.
        type_ = self._bool
        for typed_spec in reversed(params):
            type_ = Function(
                pos_args=typed_spec.pos,
                args=typed_spec.mixed,
                kw_args=typed_spec.kw,
                returns=type_,
            )

    self.types.define_predicate(
        stmt.name.lexeme,
        Predicate(
            type=type_,
            body=stmt.body,
            alias=not params,
        ),
    )
def _is_valid_predicate(self, body: Type) -> bool:
    """Check that a type is a boolean, or a function eventually returning one

    Args:
        body (Type): the type of a predicate body

    Returns:
        bool: whether the innermost return type is a subtype of bool
    """
    innermost: Type = body
    # Peel off nested function types down to the final return type
    while isinstance(innermost, Function):
        innermost = innermost.returns
    return bool(self.types.is_subtype(innermost, self._bool))
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
    """Check that both operands are booleans; the result is a boolean

    Args:
        expr (m.LogicalExpr): the logical expression

    Returns:
        Type: the boolean type
    """
    for operand in (expr.left, expr.right):
        self.assert_bool(operand)
    return self._bool
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
    """Evaluate a binary operation through the operator's method name

    Args:
        expr (m.BinaryExpr): the binary expression

    Returns:
        Type: the result type, or UnknownType for an unsupported operator
    """
    method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type)
    if method is not None:
        return self._visit_binary_expr(
            expr.location, expr.left, expr.right, method
        )

    message: str = f"Unsupported operator {expr.operator.lexeme}"
    self.logger.warning(message)
    self.reporter.warning(expr.location, message)
    return UnknownType()
def _visit_binary_expr(
    self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str
) -> Type:
    """Resolve ``method`` on the left operand and call it with the right one

    Args:
        location (Location): where to report diagnostics
        left_expr (m.Expr): the left operand
        right_expr (m.Expr): the right operand
        method (str): the name of the member implementing the operator

    Returns:
        Type: the result type, or UnknownType if the member does not exist
    """
    left: Type = self.type_of(left_expr)
    right: Type = self.type_of(right_expr)

    operation: Optional[Type] = self.types.lookup_member(left, method)
    if operation is not None:
        outcome: CallResult = self.dispatcher.get_result(
            location=location,
            callee=operation,
            positional=[(right_expr, right)],
            keywords={},
        )
        return outcome.result

    self.reporter.error(
        location,
        f"Undefined operation {method} between {left} and {right}",
    )
    return UnknownType()
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
    """Evaluate a unary operation through the operator's method name

    Args:
        expr (m.UnaryExpr): the unary expression

    Returns:
        Type: the result type, or UnknownType if the operator is not
            supported or not defined for the operand
    """
    method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
    if method is None:
        message: str = f"Unsupported operator {expr.operator.lexeme}"
        self.logger.warning(message)
        self.reporter.warning(expr.location, message)
        return UnknownType()

    operand: Type = self.type_of(expr.right)
    operation: Optional[Type] = self.types.lookup_member(operand, method)
    if operation is not None:
        # The operator method is called without any argument
        outcome: CallResult = self.dispatcher.get_result(
            location=expr.location,
            callee=operation,
            positional=[],
            keywords={},
        )
        return outcome.result

    self.reporter.error(
        expr.location,
        f"Undefined operation {method} for {operand}",
    )
    return UnknownType()
def visit_call_expr(self, expr: m.CallExpr) -> Type:
    """Evaluate a call by dispatching the typed arguments to the callee

    Args:
        expr (m.CallExpr): the call expression

    Returns:
        Type: the result type computed by the dispatcher
    """
    callee: Type = expr.callee.accept(self)

    # Each argument is passed along with its expression
    positional: list[tuple[m.Expr, Type]] = []
    for arg in expr.arguments:
        positional.append((arg, self.type_of(arg)))

    keywords: dict[str, tuple[m.Expr, Type]] = {}
    for name, arg in expr.keywords.items():
        keywords[name] = (arg, self.type_of(arg))

    outcome: CallResult = self.dispatcher.get_result(
        location=expr.location,
        callee=callee,
        positional=positional,
        keywords=keywords,
    )
    return outcome.result
def visit_get_expr(self, expr: m.GetExpr) -> Type:
    """Evaluate a member access

    Args:
        expr (m.GetExpr): the member access expression

    Returns:
        Type: the type of the member, or UnknownType if it does not exist
    """
    owner: Type = expr.expr.accept(self)
    member: Optional[Type] = self.types.lookup_member(owner, expr.name.lexeme)
    if member is not None:
        return member

    self.reporter.error(
        expr.location, f"Unknown member '{expr.name.lexeme}' of {owner}"
    )
    return UnknownType()
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
    """Evaluate a variable reference

    An unknown variable is reported as an error instead of letting the
    NameError from ``get_variable`` abort the whole resolution, in the same
    way ``visit_named_type`` handles unknown types.

    Args:
        expr (m.VariableExpr): the variable expression

    Returns:
        Type: the type of the variable, or UnknownType if it is not defined
    """
    name: str = expr.name.lexeme
    try:
        return self.get_variable(name)
    except NameError:
        self.reporter.error(expr.location, f"Unknown variable '{name}'")
        return UnknownType()
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
    """Evaluate a parenthesised expression as its inner expression

    Args:
        expr (m.GroupingExpr): the grouping expression

    Returns:
        Type: the type of the inner expression
    """
    return self.type_of(expr.expr)
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
    """Map a literal onto the Midas type matching its Python value

    Args:
        expr (m.LiteralExpr): the literal expression

    Returns:
        Type: the matching registry type, or UnknownType (with a warning)
            for an unsupported kind of value
    """
    # bool must be tested before int, since bool is a subclass of int
    literal_types: tuple[tuple[type, str], ...] = (
        (bool, "bool"),
        (int, "int"),
        (float, "float"),
        (str, "str"),
    )
    for python_type, midas_name in literal_types:
        if isinstance(expr.value, python_type):
            return self.types.get_type(midas_name)

    self.reporter.warning(expr.location, f"Unknown literal {expr}")
    return UnknownType()
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
    """Evaluate the wildcard ``_`` as the variable named ``_``

    If no such variable is in scope, an error is reported instead of
    letting the NameError from ``get_variable`` abort the resolution.

    Args:
        expr (m.WildcardExpr): the wildcard expression

    Returns:
        Type: the type bound to ``_``, or UnknownType if there is none
    """
    try:
        return self.get_variable("_")
    except NameError:
        self.reporter.error(
            expr.location, "Wildcard '_' is not defined in this context"
        )
        return UnknownType()
def visit_named_type(self, type: m.NamedType) -> Type:
    """Resolve a type referenced by name

    Args:
        type (m.NamedType): the named type node

    Returns:
        Type: the resolved type, or UnknownType (with an error) if the name
            is not defined
    """
    name: str = type.name.lexeme
    try:
        resolved: Type = self.get_type(name)
    except NameError:
        # Referring to the type being defined gets an extra hint
        hint: str = (
            ". Recursive types are not supported, use an extend block"
            if self._current_name == name
            else ""
        )
        self.reporter.error(type.name.get_location(), f"Undefined type {name}" + hint)
        return UnknownType()
    return resolved
def visit_generic_type(self, type: m.GenericType) -> Type:
    """Apply type arguments to a generic type

    Args:
        type (m.GenericType): the generic type application node

    Returns:
        Type: the applied type, or UnknownType (with an error) if the
            registry fails to apply the arguments
    """
    base: Type = type.type.accept(self)
    arguments: list[Type] = []
    for arg in type.args:
        arguments.append(arg.accept(self))

    try:
        applied: Type = self.types.apply_generic(base, arguments)
    except Exception as e:
        self.reporter.error(type.location, f"Cannot apply generic type: {e}")
        return UnknownType()
    return applied
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
    """Resolve the base of a constrained type

    The constraint expression is carried over as-is; it is not evaluated
    here.

    Args:
        type (m.ConstraintType): the constrained type node

    Returns:
        Type: the constraint type
    """
    base: Type = type.type.accept(self)
    return ConstraintType(type=base, constraint=type.constraint)
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
    """Resolve the members of a complex type

    Args:
        type (m.ComplexType): the complex type node

    Returns:
        ComplexType: the type, with members keyed by name
    """
    members: dict[str, Type] = {}
    for member in type.members:
        members[member.name.lexeme] = member.type.accept(self)
    return ComplexType(members=members)
def visit_extension_type(self, type: m.ExtensionType) -> Type:
    """Resolve a base type together with its complex extension

    Args:
        type (m.ExtensionType): the extension type node

    Returns:
        Type: the extension type
    """
    base: Type = type.base.accept(self)
    extension: ComplexType = self.visit_complex_type(type.extension)
    return ExtensionType(base=base, extension=extension)
def visit_function_type(self, type: m.FunctionType) -> Type:
    """Resolve a function type from its parameters and return type

    Args:
        type (m.FunctionType): the function type node

    Returns:
        Type: the function type
    """
    signature: TypedParamSpec = self._visit_param_spec(type.params)
    returns: Type = type.returns.accept(self)
    return Function(
        pos_args=signature.pos,
        args=signature.mixed,
        kw_args=signature.kw,
        returns=returns,
    )
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
    """Resolve the argument types of a parameter specification

    Arguments are numbered consecutively across the ``pos``, ``mixed`` and
    ``kw`` groups; an unnamed argument is named after its position.

    Args:
        spec (m.ParamSpec): the parsed parameter specification

    Returns:
        TypedParamSpec: the resolved arguments, grouped like the input
    """

    def convert(
        group: list[m.FunctionType.Argument], offset: int
    ) -> list[Function.Argument]:
        converted: list[Function.Argument] = []
        for index, arg in enumerate(group, start=offset):
            label: str = str(index) if arg.name is None else arg.name.lexeme
            converted.append(
                Function.Argument(
                    pos=index,
                    name=label,
                    type=arg.type.accept(self),
                    required=arg.required,
                )
            )
        return converted

    mixed_offset: int = len(spec.pos)
    kw_offset: int = mixed_offset + len(spec.mixed)

    pos: list[Function.Argument] = convert(spec.pos, 0)
    mixed: list[Function.Argument] = convert(spec.mixed, mixed_offset)
    kw: list[Function.Argument] = convert(spec.kw, kw_offset)
    return TypedParamSpec(pos=pos, mixed=mixed, kw=kw)
def visit_frame_type(self, type: m.FrameType) -> Type:
    """Resolve a data frame type from its column declarations

    Args:
        type (m.FrameType): the frame type node

    Returns:
        Type: the data frame type, with columns indexed in declaration order
    """
    columns: list[DataFrameType.Column] = [
        DataFrameType.Column(
            index=index,
            name=column.name.lexeme,
            # The declared type is wrapped as the column's element type
            type=ColumnType(type=column.type.accept(self)),
        )
        for index, column in enumerate(type.columns)
    ]
    return DataFrameType(columns=columns)
def _resolve_type_params(self, params: list[m.TypeParam]):
    """Create a type variable for each type parameter and bring it in scope

    Each variable is registered as a local variable as soon as it is built,
    so the bound of a later parameter can refer to an earlier one.

    Args:
        params (list[m.TypeParam]): the declared type parameters

    Returns:
        list[TypeVar]: the type variables, in declaration order
    """
    resolved: list[TypeVar] = []
    for param in params:
        name: str = param.name.lexeme
        bound: Optional[Type] = (
            param.bound.accept(self) if param.bound is not None else None
        )
        type_var = TypeVar(name=name, bound=bound)
        self._local_variables[name] = type_var
        resolved.append(type_var)
    return resolved