feat(gen): generate predicate functions
This commit is contained in:
@@ -352,3 +352,6 @@ class TypesRegistry:
|
||||
case _:
|
||||
self.logger.debug(f"Can't get member on {type}")
|
||||
return None
|
||||
|
||||
def lookup_predicate(self, name: str) -> Optional[Predicate]:
|
||||
return self._predicates.get(name)
|
||||
|
||||
@@ -38,5 +38,5 @@ def compile(
|
||||
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
|
||||
sys.exit(1)
|
||||
|
||||
generator = Generator(workdir=source_path.parent)
|
||||
generator = Generator(workdir=source_path.parent, types=checker.types)
|
||||
generator.generate(typed_ast, source_path)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import Function, Predicate, Type
|
||||
from midas.lexer.token import TokenType
|
||||
|
||||
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
|
||||
@@ -31,6 +34,97 @@ COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
|
||||
|
||||
|
||||
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
def __init__(self, types: TypesRegistry):
|
||||
self.types: TypesRegistry = types
|
||||
self._id: int = 0
|
||||
self._definitions: list[ast.stmt] = []
|
||||
self._aliases: dict[str, str] = {}
|
||||
|
||||
def get_definitions(self) -> list[ast.stmt]:
|
||||
return self._definitions
|
||||
|
||||
def generate(self, expr: m.Expr) -> ast.expr:
|
||||
match expr:
|
||||
case m.VariableExpr():
|
||||
return expr.accept(self)
|
||||
case _:
|
||||
func = Function(
|
||||
pos_args=[],
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="_",
|
||||
type=self.types.get_type("Any"),
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
kw_args=[],
|
||||
returns=self.types.get_type("bool"),
|
||||
)
|
||||
alias: str = self.make_alias(None)
|
||||
definition: ast.stmt = self.make_definition(
|
||||
alias, Predicate(type=func, body=expr)
|
||||
)
|
||||
self._definitions.append(definition)
|
||||
return ast.Name(id=alias)
|
||||
|
||||
def make_alias(self, name: Optional[str]) -> str:
|
||||
suffix: str = f"_{name}" if name is not None else ""
|
||||
alias: str = f"__midas_p{self._id}{suffix}__"
|
||||
self._id += 1
|
||||
return alias
|
||||
|
||||
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
|
||||
body: list[ast.stmt] = [ast.Return(value=predicate.body.accept(self))]
|
||||
return self.make_func(name, body, predicate.type)
|
||||
|
||||
def make_args(self, func: Function) -> ast.arguments:
|
||||
return ast.arguments(
|
||||
posonlyargs=[ast.arg(arg=arg.name) for arg in func.pos_args],
|
||||
args=[ast.arg(arg=arg.name) for arg in func.args],
|
||||
kwonlyargs=[ast.arg(arg=arg.name) for arg in func.kw_args],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
)
|
||||
|
||||
def make_func(
|
||||
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
|
||||
) -> ast.stmt:
|
||||
match type:
|
||||
case Function(returns=Function()):
|
||||
inner_name: str = f"inner{level}"
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(type),
|
||||
body=[
|
||||
self.make_func(inner_name, inner_body, type.returns, level + 1),
|
||||
ast.Return(value=ast.Name(id=inner_name)),
|
||||
],
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
case Function():
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(type),
|
||||
body=inner_body,
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Expected function, got {type}")
|
||||
|
||||
def get_predicate(self, name: str) -> Optional[ast.expr]:
|
||||
if name not in self._aliases:
|
||||
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||
if predicate is None:
|
||||
return None
|
||||
alias: str = self.make_alias(name)
|
||||
self._aliases[name] = alias
|
||||
self._definitions.append(self.make_definition(alias, predicate))
|
||||
|
||||
return ast.Name(id=self._aliases[name])
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
|
||||
return ast.BoolOp(
|
||||
op=LOGICAL_OPERATORS[expr.operator.type](),
|
||||
@@ -79,8 +173,10 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
|
||||
# TODO: lookup predicate
|
||||
return ast.Name(id=expr.name.lexeme)
|
||||
name: str = expr.name.lexeme
|
||||
if (p := self.get_predicate(name)) is not None:
|
||||
return p
|
||||
return ast.Name(id=name)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
|
||||
return expr.accept(self)
|
||||
|
||||
@@ -8,6 +8,7 @@ import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
@@ -35,7 +36,7 @@ class Scope:
|
||||
|
||||
|
||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def __init__(self, workdir: Path) -> None:
|
||||
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
|
||||
self.workdir: Path = workdir.resolve()
|
||||
self.build_dir: Path = self.workdir / "build" / "midas"
|
||||
if self.build_dir.exists():
|
||||
@@ -48,15 +49,18 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
judgements=[],
|
||||
)
|
||||
self._alias_count: int = 0
|
||||
self._predicate_count: int = 0
|
||||
self._scopes: list[Scope] = []
|
||||
|
||||
self._constraint_generator: ConstraintGenerator = ConstraintGenerator()
|
||||
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
||||
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
||||
|
||||
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
||||
self.rel_src_path = src_path.relative_to(self.workdir)
|
||||
self._typed_ast = typed_ast
|
||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||
module = ast.Module(body=body, type_ignores=[])
|
||||
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
||||
module = ast.Module(body=predicates + body, type_ignores=[])
|
||||
module = ast.fix_missing_locations(module)
|
||||
return module
|
||||
|
||||
@@ -253,7 +257,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
return generated
|
||||
|
||||
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
||||
name: str = f"__midas_alias_{self._alias_count}__"
|
||||
name: str = f"__midas_a{self._alias_count}__"
|
||||
alias = ast.Name(id=name)
|
||||
self._alias_count += 1
|
||||
self._scopes[-1].aliases.append(name)
|
||||
@@ -361,9 +365,13 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
def _make_constraint_assert(
|
||||
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
||||
):
|
||||
test: ast.expr = constraint.accept(self._constraint_generator)
|
||||
test_func: ast.expr = self._get_constraint(constraint)
|
||||
self._add_assert(
|
||||
test,
|
||||
ast.Call(
|
||||
func=test_func,
|
||||
args=[expr],
|
||||
keywords=[],
|
||||
),
|
||||
self._make_constraint_assert_message(src_location, expr, constraint),
|
||||
)
|
||||
|
||||
@@ -377,3 +385,12 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
return ast.Constant(
|
||||
f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'"
|
||||
)
|
||||
|
||||
def _get_constraint(self, expr: m.Expr) -> ast.expr:
|
||||
for expr2, constraint in self._constraints:
|
||||
if expr2 == expr:
|
||||
return constraint
|
||||
|
||||
constraint: ast.expr = self._constraint_generator.generate(expr)
|
||||
self._constraints.append((expr, constraint))
|
||||
return constraint
|
||||
|
||||
@@ -340,7 +340,7 @@ class MidasParser(Parser):
|
||||
|
||||
def call(self) -> Expr:
|
||||
expr: Expr = self.reference()
|
||||
if self.match(TokenType.LEFT_PAREN):
|
||||
while self.match(TokenType.LEFT_PAREN):
|
||||
expr = self.finish_call(expr)
|
||||
return expr
|
||||
|
||||
|
||||
Reference in New Issue
Block a user