Compare commits

...

8 Commits

15 changed files with 344 additions and 37 deletions

View File

@@ -0,0 +1,9 @@
def minimum(x: int, y: int):  # returns the smaller of x and y (y when they are equal)
if x < y:
return x
else:
return y
a = 15
b = 72
c = minimum(a, b)

View File

@@ -76,6 +76,12 @@ class ReturnStmt:
value: Optional[Expr]
class IfStmt:
test: Expr
body: list[Stmt]
orelse: list[Stmt]
###<

View File

@@ -419,7 +419,14 @@ class PythonAstPrinter(
self._mark_last()
self._print_argument(arg)
self._write_optional_child("returns", stmt.returns, last=True)
self._write_optional_child("returns", stmt.returns)
self._write_line("body", last=True)
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
def _print_argument(self, arg: p.Function.Argument) -> None:
self._write_line("FunctionArgument")
@@ -454,6 +461,26 @@ class PythonAstPrinter(
with self._child_level():
self._write_optional_child("value", stmt.value, last=True)
# Print an IfStmt node: an "IfStmt" line, then "test", "body" and "orelse"
# entries, visiting the test expression and every statement of both branches.
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self._write_line("IfStmt")
with self._child_level():
self._write_line("test")
# NOTE(review): unlike "body" and "orelse", the test expression is visited
# without a _child_level() block of its own - confirm this is intended.
stmt.test.accept(self)
self._write_line("body")
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
# _mark_last() is called only for the final statement of the branch.
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
# "orelse" is the final entry written for the IfStmt, hence last=True.
self._write_line("orelse", last=True)
with self._child_level():
for i, else_stmt in enumerate(stmt.orelse):
self._idx = i
if i == len(stmt.orelse) - 1:
self._mark_last()
else_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr")
with self._child_level():

View File

@@ -103,6 +103,9 @@ class Stmt(ABC):
@abstractmethod
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
# Visitor hook for IfStmt nodes; IfStmt.accept() dispatches to this method.
@abstractmethod
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
@dataclass(frozen=True)
class ExpressionStmt(Stmt):
@@ -164,6 +167,16 @@ class ReturnStmt(Stmt):
return visitor.visit_return_stmt(self)
# AST node for an if statement (immutable dataclass).
@dataclass(frozen=True)
class IfStmt(Stmt):
# Condition expression of the if statement.
test: Expr
# Statements of the if branch.
body: list[Stmt]
# Statements of the else branch.
orelse: list[Stmt]
# Dispatch to the visitor's visit_if_stmt and return its result.
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_if_stmt(self)
###############
# Expressions #
###############

View File

@@ -8,7 +8,7 @@ import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment
from midas.checker.operators import OPERATOR_METHODS
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
@@ -85,21 +85,29 @@ class Checker(
"""
return expr.accept(self)
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> None:
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool:
    """Evaluate a sequence of statements

    The checker's current environment is swapped for ``env`` while the block
    is evaluated and is always restored afterwards, even if a statement
    raises an unexpected exception.

    Evaluation stops at the first statement that raises ``ReturnException``;
    if more statements follow it, an "Unreachable statement" warning is
    reported at the location of the next one.

    Args:
        block (list[p.Stmt]): the statements to evaluate
        env (Environment): the environment in which to evaluate

    Returns:
        bool: whether a return statement is present in the block
    """
    previous_env: Environment = self.env
    self.env = env

    returned: bool = False
    try:
        for i, stmt in enumerate(block):
            try:
                stmt.accept(self)
            except ReturnException:
                returned = True
                # Anything after a returning statement can never run
                if i < len(block) - 1:
                    self.warning(block[i + 1].location, "Unreachable statement")
                break
    finally:
        # Restore the caller's environment on every exit path so a failure
        # inside the block cannot leave the checker stuck in ``env``
        self.env = previous_env
    return returned
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
"""Type check a sequence of statements and returns diagnostics
@@ -216,11 +224,14 @@ class Checker(
for arg in pos_args + args + kw_args:
env.define(arg.name, arg.type)
self.evaluate_block(stmt.body, env)
returned: bool = self.evaluate_block(stmt.body, env)
inferred_return: Type = UnknownType()
if len(env.return_types) == 1:
inferred_return = list(env.return_types)[0]
elif len(env.return_types) > 1:
if not returned:
env.return_types.append(UnitType())
return_types: set[Type] = set(env.return_types)
if len(return_types) == 1:
inferred_return = list(return_types)[0]
elif len(return_types) > 1:
self.error(
stmt.location,
f"Mixed return types: {env.return_types}",
@@ -276,6 +287,27 @@ class Checker(
self.env.return_types.append(type)
raise ReturnException()
# Type-check an if statement: the test must have the "bool" type, then both
# branches are evaluated and the return types they collected are appended to
# the current environment's return types.
# Raises ReturnException when both the body and the else branch returned.
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
# Not evaluated in sub-environment because assignments in the test leak out of the if
# For example:
# if (m := 1 + 1) < 2:
# ...
# print(m) # <- m is still defined
test_type: Type = stmt.test.accept(self)
# TODO Allow subtypes or any type
# A non-bool test is reported as an error; the branches are still evaluated.
if test_type != self.ctx.get_type("bool"):
self.error(
stmt.test.location, f"If test must be a boolean, got {test_type}"
)
# NOTE(review): the same Environment instance is used for both branches -
# confirm that names defined in the body should be visible in the else branch.
env: Environment = Environment(self.env)
body_returned: bool = self.evaluate_block(stmt.body, env)
else_returned: bool = self.evaluate_block(stmt.orelse, env)
self.env.return_types.extend(env.return_types)
# Only when every path through the if returns does the if itself count as returning.
if body_returned and else_returned:
raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None:
@@ -294,7 +326,23 @@ class Checker(
return UnknownType()
return result
def visit_compare_expr(self, expr: p.CompareExpr) -> Type: ...
# Type-check a comparison: map the operator class to a method name through
# COMPARATOR_METHODS, evaluate both operands and ask the context for the
# result type of that operation.
# Returns UnknownType() when the operator has no mapping (warning) or when the
# context defines no such operation for the operand types (error).
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
if method is None:
# Reported twice: once through the logger and once as a diagnostic.
self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType()
left: Type = self.evaluate(expr.left)
right: Type = self.evaluate(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None:
self.error(
expr.location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
return result
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...

View File

@@ -148,6 +148,8 @@ class PythonHighlighter(
self.wrap(stmt, "function")
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
self._highlight_function_argument(arg)
for body_stmt in stmt.body:
body_stmt.accept(self)
def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
self.wrap(arg, "argument")
@@ -157,9 +159,23 @@ class PythonHighlighter(
# Delegate to the type node of the assignment.
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
stmt.type.accept(self)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: ...
# Visit every assignment target, then the assigned value.
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
for target in stmt.targets:
target.accept(self)
stmt.value.accept(self)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: ...
# Wrap the statement with the "return" label, then visit the returned
# expression when there is one (a bare `return` has value None).
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
self.wrap(stmt, "return")
if stmt.value is not None:
stmt.value.accept(self)
# Wrap the statement with the "if" label, then visit the test expression and
# every statement of the body and of the else branch, in that order.
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self.wrap(stmt, "if")
stmt.test.accept(self)
for body_stmt in stmt.body:
body_stmt.accept(self)
for else_stmt in stmt.orelse:
else_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
@@ -167,7 +183,13 @@ class PythonHighlighter(
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
def visit_call_expr(self, expr: p.CallExpr) -> None: ...
# Wrap the expression with the "call" label, then visit the callee, the
# positional arguments and finally the keyword argument values.
def visit_call_expr(self, expr: p.CallExpr) -> None:
self.wrap(expr, "call")
expr.callee.accept(self)
for arg in expr.arguments:
arg.accept(self)
# expr.keywords is a mapping; only its values are visited.
for arg in expr.keywords.values():
arg.accept(self)
def visit_get_expr(self, expr: p.GetExpr) -> None: ...

View File

@@ -1,4 +1,6 @@
span {
--opacity: 0.4;
&.error {
--col: 255, 0, 0;
}

View File

@@ -3,7 +3,7 @@ import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TextIO
from typing import Optional, TextIO, get_args
import click
@@ -13,7 +13,13 @@ from midas.ast.location import Location
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic
from midas.cli.highlighter import DiagnosticsHighlighter, Highlighter, MidasHighlighter, PythonHighlighter
from midas.checker.types import Type
from midas.cli.highlighter import (
DiagnosticsHighlighter,
Highlighter,
MidasHighlighter,
PythonHighlighter,
)
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token, TokenType
from midas.parser.midas import MidasParser
@@ -46,7 +52,9 @@ def compile(highlight: Optional[TextIO], file: TextIO):
print(
json.dumps(
UniversalJSONDumper.dump(
checker.global_env, [("Environment", "_children")]
checker.global_env,
[("Environment", "_children")],
lambda obj: isinstance(obj, get_args(Type)),
),
indent=4,
)

View File

@@ -16,6 +16,7 @@ from midas.ast.python import (
FrameType,
Function,
GetExpr,
IfStmt,
LiteralExpr,
LogicalExpr,
MidasType,
@@ -82,6 +83,9 @@ class PythonParser:
value=self.parse_expr(value) if value is not None else None,
)
case ast.If():
return self.parse_if(node)
case _:
print(f"Unsupported statement: {ast.unparse(node)}")
return None
@@ -147,6 +151,30 @@ class PythonParser:
),
)
def parse_if(self, node: ast.If) -> IfStmt:
    """Convert an ``ast.If`` node into an ``IfStmt``

    Both branches are parsed statement by statement with ``parse_stmt``.
    An ``elif`` needs no special handling: it appears in ``node.orelse``
    as a nested ``ast.If`` and is parsed like any other statement.

    Args:
        node (ast.If): the Python AST node to convert

    Returns:
        IfStmt: the converted statement, with its test, body and else branch
    """

    def parse_block(nodes: list[ast.stmt]) -> list[Stmt]:
        # Flatten the results of parse_stmt: it may yield a single Stmt,
        # several statements, or None for an unsupported statement.
        parsed: list[Stmt] = []
        for child in nodes:
            result = self.parse_stmt(child)
            if isinstance(result, Stmt):
                parsed.append(result)
            elif result is not None:
                parsed.extend(result)
        return parsed

    # Branches are parsed before the test, in the same order as before
    body: list[Stmt] = parse_block(node.body)
    orelse: list[Stmt] = parse_block(node.orelse)

    return IfStmt(
        location=Location.from_ast(node),
        test=self.parse_expr(node.test),
        body=body,
        orelse=orelse,
    )
def parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node)
match node:

View File

@@ -2,12 +2,20 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from midas.checker.types import BaseType, Type
from midas.checker.types import BaseType, Type, UnitType
if TYPE_CHECKING:
from midas.resolver.midas import MidasResolver
# Register a binary operation on ctx: applying `operator` (a method name such
# as "__lt__") to a left operand of type t1 and a right operand of type t2
# yields a result of type t3.
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
ctx.define_operation(
left=t1,
operator=operator,
right=t2,
result=t3,
)
def basic_op(ctx: MidasResolver, type: Type, op: str):
ctx.define_operation(
left=type,
@@ -19,21 +27,45 @@ def basic_op(ctx: MidasResolver, type: Type, op: str):
def define_builtins(ctx: MidasResolver):
    """Define builtin types and operations

    Registers the ``None``, ``bool``, ``int``, ``float`` and ``str`` types on
    ``ctx``, then declares the arithmetic operations (same type in, same type
    out) and the comparison operations (result ``bool``) available on them,
    including comparisons between ``int`` and ``float`` in both directions.

    Args:
        ctx (MidasResolver): the resolver on which types and operations
            are defined
    """
    # Local names carry a ``_t`` suffix so the Python builtins of the same
    # name are not shadowed inside this function
    ctx.define_type("None", UnitType())
    bool_t = ctx.define_type("bool", BaseType(name="bool"))
    int_t = ctx.define_type("int", BaseType(name="int"))
    float_t = ctx.define_type("float", BaseType(name="float"))
    str_t = ctx.define_type("str", BaseType(name="str"))

    # <, >, <=, >=, == : every comparison yields a bool
    comparisons = ("__lt__", "__gt__", "__le__", "__ge__", "__eq__")

    # int (+, -, *, **, %, &, |, ^) int = int
    for method in (
        "__add__",
        "__sub__",
        "__mul__",
        "__pow__",
        "__mod__",
        "__and__",
        "__or__",
        "__xor__",
    ):
        basic_op(ctx, int_t, method)
    # int <cmp> int = bool
    for method in comparisons:
        op(ctx, int_t, method, int_t, bool_t)

    # float (+, -, *, /) float = float
    for method in ("__add__", "__sub__", "__mul__", "__truediv__"):
        basic_op(ctx, float_t, method)
    # float <cmp> float = bool
    for method in comparisons:
        op(ctx, float_t, method, float_t, bool_t)

    basic_op(ctx, str_t, "__add__")  # str + str = str
    op(ctx, str_t, "__eq__", str_t, bool_t)  # str == str = bool

    # int <cmp> float = bool
    for method in comparisons:
        op(ctx, int_t, method, float_t, bool_t)
    # float <cmp> int = bool
    for method in comparisons:
        op(ctx, float_t, method, int_t, bool_t)

View File

@@ -121,6 +121,24 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
if stmt.value is not None:
self.resolve(stmt.value)
# Resolve an if statement: the test in the current scope, then the body and
# the else branch each between their own begin_scope()/end_scope() pair.
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
# Not resolved in sub-environment because assignments in the test leak out of the if
# For example:
# if (m := 1 + 1) < 2:
# ...
# print(m) # <- m is still defined
self.resolve(stmt.test)
# Body
self.begin_scope()
self.resolve(*stmt.body)
self.end_scope()
# Else
self.begin_scope()
self.resolve(*stmt.orelse)
self.end_scope()
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)

View File

@@ -1,18 +1,27 @@
from typing import Any, Optional
from typing import Any, Callable, Optional
AllowRepeat = Callable[[object], bool]
class UniversalJSONDumper:
# Public entry point: convert obj by delegating to _dump with an empty
# `visited` list.
# include_keys: defaults to an empty list when None; _dump uses it to decide
# which underscore-prefixed attributes are still included.
# allow_repeat: optional predicate forwarded to _dump; objects for which it
# returns True are not recorded in `visited`.
@classmethod
def dump(
cls, obj: Any, include_keys: Optional[list[str | tuple[str, str]]] = None
cls,
obj: Any,
include_keys: Optional[list[str | tuple[str, str]]] = None,
allow_repeat: Optional[AllowRepeat] = None,
) -> Any:
if include_keys is None:
include_keys = []
return cls._dump(obj, include_keys, [])
return cls._dump(obj, include_keys, allow_repeat, [])
@classmethod
def _dump(
cls, obj: Any, include_keys: list[str | tuple[str, str]], visited: list[Any]
cls,
obj: Any,
include_keys: list[str | tuple[str, str]],
allow_repeat: Optional[AllowRepeat],
visited: list[Any],
) -> Any:
if obj in visited:
return None
@@ -20,17 +29,22 @@ class UniversalJSONDumper:
case str() | int() | float() | None:
return obj
case list() | set() | tuple():
return [cls._dump(child, include_keys, visited) for child in obj]
return [
cls._dump(child, include_keys, allow_repeat, visited)
for child in obj
]
case dict():
return {
str(k): cls._dump(v, include_keys, visited) for k, v in obj.items()
str(k): cls._dump(v, include_keys, allow_repeat, visited)
for k, v in obj.items()
}
case object():
if allow_repeat is None or not allow_repeat(obj):
visited.append(obj)
return {
"_type": obj.__class__.__name__,
} | {
k: cls._dump(v, include_keys, visited)
k: cls._dump(v, include_keys, allow_repeat, visited)
for k, v in obj.__dict__.items()
if not k.startswith("_")
or k in include_keys

View File

@@ -0,0 +1,25 @@
def valid(a: int, b: int) -> int:  # single return of an int expression
return a + b
def with_if(a: int, b: int) -> int:  # both branches return an int expression
if a < b:
return b - a
else:
return a - b
def unreachable1():
return
a = 0  # unreachable: follows the bare return above
def unreachable2(a: int) -> int:
if a > 10:
return a - 10
else:
return a
b = 0  # unreachable: both branches of the if above return
def mixed(a: int, b: int):  # no declared return type; the branches disagree
if a < b:
return b - a
else:
return "oops"  # str, while the other branch returns an int expression

View File

@@ -0,0 +1,46 @@
{
"diagnostics": [
{
"type": "Warning",
"location": {
"start": [
12,
4
],
"end": [
12,
9
]
},
"message": "Unreachable statement"
},
{
"type": "Warning",
"location": {
"start": [
19,
4
],
"end": [
19,
9
]
},
"message": "Unreachable statement"
},
{
"type": "Error",
"location": {
"start": [
21,
0
],
"end": [
25,
21
]
},
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
}
]
}

View File

@@ -15,6 +15,7 @@ from midas.ast.python import (
FrameType,
Function,
GetExpr,
IfStmt,
LiteralExpr,
LogicalExpr,
MidasType,
@@ -164,6 +165,14 @@ class PythonAstJsonSerializer(
"value": self._serialize_optional(stmt.value),
}
# Serialize an IfStmt to a dict tagged "_type": "IfStmt", holding the
# serialized test expression and the serialized body and orelse lists.
def visit_if_stmt(self, stmt: IfStmt) -> dict:
return {
"_type": "IfStmt",
"test": stmt.test.accept(self),
"body": self._serialize_list(stmt.body),
"orelse": self._serialize_list(stmt.orelse),
}
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
return {
"_type": "BinaryExpr",