Compare commits

...

8 Commits

15 changed files with 344 additions and 37 deletions

View File

@@ -0,0 +1,9 @@
def minimum(x: int, y: int):
if x < y:
return x
else:
return y
a = 15
b = 72
c = minimum(a, b)

View File

@@ -76,6 +76,12 @@ class ReturnStmt:
value: Optional[Expr] value: Optional[Expr]
class IfStmt:
test: Expr
body: list[Stmt]
orelse: list[Stmt]
###< ###<

View File

@@ -419,7 +419,14 @@ class PythonAstPrinter(
self._mark_last() self._mark_last()
self._print_argument(arg) self._print_argument(arg)
self._write_optional_child("returns", stmt.returns, last=True) self._write_optional_child("returns", stmt.returns)
self._write_line("body", last=True)
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
def _print_argument(self, arg: p.Function.Argument) -> None: def _print_argument(self, arg: p.Function.Argument) -> None:
self._write_line("FunctionArgument") self._write_line("FunctionArgument")
@@ -454,6 +461,26 @@ class PythonAstPrinter(
with self._child_level(): with self._child_level():
self._write_optional_child("value", stmt.value, last=True) self._write_optional_child("value", stmt.value, last=True)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self._write_line("IfStmt")
with self._child_level():
self._write_line("test")
stmt.test.accept(self)
self._write_line("body")
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
self._write_line("orelse", last=True)
with self._child_level():
for i, else_stmt in enumerate(stmt.orelse):
self._idx = i
if i == len(stmt.orelse) - 1:
self._mark_last()
else_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr") self._write_line("BinaryExpr")
with self._child_level(): with self._child_level():

View File

@@ -103,6 +103,9 @@ class Stmt(ABC):
@abstractmethod @abstractmethod
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ... def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
@abstractmethod
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class ExpressionStmt(Stmt): class ExpressionStmt(Stmt):
@@ -164,6 +167,16 @@ class ReturnStmt(Stmt):
return visitor.visit_return_stmt(self) return visitor.visit_return_stmt(self)
@dataclass(frozen=True)
class IfStmt(Stmt):
test: Expr
body: list[Stmt]
orelse: list[Stmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_if_stmt(self)
############### ###############
# Expressions # # Expressions #
############### ###############

View File

@@ -8,7 +8,7 @@ import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.operators import OPERATOR_METHODS from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.types import Function, Type, UnitType, UnknownType from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token from midas.lexer.token import Token
@@ -85,21 +85,29 @@ class Checker(
""" """
return expr.accept(self) return expr.accept(self)
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> None: def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool:
"""Evaluate a sequence of statements """Evaluate a sequence of statements
Args: Args:
block (list[p.Stmt]): the statements to evaluate block (list[p.Stmt]): the statements to evaluate
env (Environment): the environment in which to evaluate env (Environment): the environment in which to evaluate
Returns:
bool: whether a return statement is present in the block
""" """
previous_env: Environment = self.env previous_env: Environment = self.env
self.env = env self.env = env
for stmt in block: returned: bool = False
for i, stmt in enumerate(block):
try: try:
stmt.accept(self) stmt.accept(self)
except ReturnException: except ReturnException:
returned = True
if i < len(block) - 1:
self.warning(block[i + 1].location, "Unreachable statement")
break break
self.env = previous_env self.env = previous_env
return returned
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]: def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
"""Type check a sequence of statements and returns diagnostics """Type check a sequence of statements and returns diagnostics
@@ -216,11 +224,14 @@ class Checker(
for arg in pos_args + args + kw_args: for arg in pos_args + args + kw_args:
env.define(arg.name, arg.type) env.define(arg.name, arg.type)
self.evaluate_block(stmt.body, env) returned: bool = self.evaluate_block(stmt.body, env)
inferred_return: Type = UnknownType() inferred_return: Type = UnknownType()
if len(env.return_types) == 1: if not returned:
inferred_return = list(env.return_types)[0] env.return_types.append(UnitType())
elif len(env.return_types) > 1: return_types: set[Type] = set(env.return_types)
if len(return_types) == 1:
inferred_return = list(return_types)[0]
elif len(return_types) > 1:
self.error( self.error(
stmt.location, stmt.location,
f"Mixed return types: {env.return_types}", f"Mixed return types: {env.return_types}",
@@ -276,6 +287,27 @@ class Checker(
self.env.return_types.append(type) self.env.return_types.append(type)
raise ReturnException() raise ReturnException()
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
# Not evaluated in sub-environment because assignments in the test leak out of the if
# For example:
# if (m := 1 + 1) < 2:
# ...
# print(m) # <- m is still defined
test_type: Type = stmt.test.accept(self)
# TODO Allow subtypes or any type
if test_type != self.ctx.get_type("bool"):
self.error(
stmt.test.location, f"If test must be a boolean, got {test_type}"
)
env: Environment = Environment(self.env)
body_returned: bool = self.evaluate_block(stmt.body, env)
else_returned: bool = self.evaluate_block(stmt.orelse, env)
self.env.return_types.extend(env.return_types)
if body_returned and else_returned:
raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type: def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__) method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None: if method is None:
@@ -294,7 +326,23 @@ class Checker(
return UnknownType() return UnknownType()
return result return result
def visit_compare_expr(self, expr: p.CompareExpr) -> Type: ... def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType()
left: Type = self.evaluate(expr.left)
right: Type = self.evaluate(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None:
self.error(
expr.location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
return result
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...

View File

@@ -148,6 +148,8 @@ class PythonHighlighter(
self.wrap(stmt, "function") self.wrap(stmt, "function")
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs: for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
self._highlight_function_argument(arg) self._highlight_function_argument(arg)
for body_stmt in stmt.body:
body_stmt.accept(self)
def _highlight_function_argument(self, arg: p.Function.Argument) -> None: def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
self.wrap(arg, "argument") self.wrap(arg, "argument")
@@ -157,9 +159,23 @@ class PythonHighlighter(
def visit_type_assign(self, stmt: p.TypeAssign) -> None: def visit_type_assign(self, stmt: p.TypeAssign) -> None:
stmt.type.accept(self) stmt.type.accept(self)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: ... def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
for target in stmt.targets:
target.accept(self)
stmt.value.accept(self)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: ... def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
self.wrap(stmt, "return")
if stmt.value is not None:
stmt.value.accept(self)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self.wrap(stmt, "if")
stmt.test.accept(self)
for body_stmt in stmt.body:
body_stmt.accept(self)
for else_stmt in stmt.orelse:
else_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ... def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
@@ -167,7 +183,13 @@ class PythonHighlighter(
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ... def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
def visit_call_expr(self, expr: p.CallExpr) -> None: ... def visit_call_expr(self, expr: p.CallExpr) -> None:
self.wrap(expr, "call")
expr.callee.accept(self)
for arg in expr.arguments:
arg.accept(self)
for arg in expr.keywords.values():
arg.accept(self)
def visit_get_expr(self, expr: p.GetExpr) -> None: ... def visit_get_expr(self, expr: p.GetExpr) -> None: ...

View File

@@ -1,4 +1,6 @@
span { span {
--opacity: 0.4;
&.error { &.error {
--col: 255, 0, 0; --col: 255, 0, 0;
} }

View File

@@ -3,7 +3,7 @@ import json
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, TextIO from typing import Optional, TextIO, get_args
import click import click
@@ -13,7 +13,13 @@ from midas.ast.location import Location
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
from midas.checker.checker import Checker from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic from midas.checker.diagnostic import Diagnostic
from midas.cli.highlighter import DiagnosticsHighlighter, Highlighter, MidasHighlighter, PythonHighlighter from midas.checker.types import Type
from midas.cli.highlighter import (
DiagnosticsHighlighter,
Highlighter,
MidasHighlighter,
PythonHighlighter,
)
from midas.lexer.midas import MidasLexer from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token, TokenType from midas.lexer.token import Token, TokenType
from midas.parser.midas import MidasParser from midas.parser.midas import MidasParser
@@ -46,7 +52,9 @@ def compile(highlight: Optional[TextIO], file: TextIO):
print( print(
json.dumps( json.dumps(
UniversalJSONDumper.dump( UniversalJSONDumper.dump(
checker.global_env, [("Environment", "_children")] checker.global_env,
[("Environment", "_children")],
lambda obj: isinstance(obj, get_args(Type)),
), ),
indent=4, indent=4,
) )

View File

@@ -16,6 +16,7 @@ from midas.ast.python import (
FrameType, FrameType,
Function, Function,
GetExpr, GetExpr,
IfStmt,
LiteralExpr, LiteralExpr,
LogicalExpr, LogicalExpr,
MidasType, MidasType,
@@ -82,6 +83,9 @@ class PythonParser:
value=self.parse_expr(value) if value is not None else None, value=self.parse_expr(value) if value is not None else None,
) )
case ast.If():
return self.parse_if(node)
case _: case _:
print(f"Unsupported statement: {ast.unparse(node)}") print(f"Unsupported statement: {ast.unparse(node)}")
return None return None
@@ -147,6 +151,30 @@ class PythonParser:
), ),
) )
def parse_if(self, node: ast.If) -> IfStmt:
body: list[Stmt] = []
for stmt in node.body:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
body.append(stmts)
elif stmts is not None:
body.extend(stmts)
orelse: list[Stmt] = []
for stmt in node.orelse:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
orelse.append(stmts)
elif stmts is not None:
orelse.extend(stmts)
return IfStmt(
location=Location.from_ast(node),
test=self.parse_expr(node.test),
body=body,
orelse=orelse,
)
def parse_function(self, node: ast.FunctionDef) -> Function: def parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node) loc: Location = Location.from_ast(node)
match node: match node:

View File

@@ -2,12 +2,20 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from midas.checker.types import BaseType, Type from midas.checker.types import BaseType, Type, UnitType
if TYPE_CHECKING: if TYPE_CHECKING:
from midas.resolver.midas import MidasResolver from midas.resolver.midas import MidasResolver
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
ctx.define_operation(
left=t1,
operator=operator,
right=t2,
result=t3,
)
def basic_op(ctx: MidasResolver, type: Type, op: str): def basic_op(ctx: MidasResolver, type: Type, op: str):
ctx.define_operation( ctx.define_operation(
left=type, left=type,
@@ -19,21 +27,45 @@ def basic_op(ctx: MidasResolver, type: Type, op: str):
def define_builtins(ctx: MidasResolver): def define_builtins(ctx: MidasResolver):
"""Define builtin types and operations""" """Define builtin types and operations"""
unit = ctx.define_type("None", UnitType())
bool = ctx.define_type("bool", BaseType(name="bool")) bool = ctx.define_type("bool", BaseType(name="bool"))
int = ctx.define_type("int", BaseType(name="int")) int = ctx.define_type("int", BaseType(name="int"))
float = ctx.define_type("float", BaseType(name="float")) float = ctx.define_type("float", BaseType(name="float"))
str = ctx.define_type("str", BaseType(name="str")) str = ctx.define_type("str", BaseType(name="str"))
basic_op(ctx, int, "__add__") basic_op(ctx, int, "__add__") # int + int = int
basic_op(ctx, int, "__sub__") basic_op(ctx, int, "__sub__") # int - int = int
basic_op(ctx, int, "__mul__") basic_op(ctx, int, "__mul__") # int * int = int
basic_op(ctx, int, "__pow__") basic_op(ctx, int, "__pow__") # int ** int = int
basic_op(ctx, int, "__mod__") basic_op(ctx, int, "__mod__") # int % int = int
basic_op(ctx, int, "__and__") basic_op(ctx, int, "__and__") # int & int = int
basic_op(ctx, int, "__or__") basic_op(ctx, int, "__or__") # int | int = int
basic_op(ctx, int, "__xor__") basic_op(ctx, int, "__xor__") # int ^ int = int
basic_op(ctx, float, "__add__") op(ctx, int, "__lt__", int, bool) # int < int = bool
basic_op(ctx, float, "__sub__") op(ctx, int, "__gt__", int, bool) # int > int = bool
basic_op(ctx, float, "__mul__") op(ctx, int, "__le__", int, bool) # int <= int = bool
basic_op(ctx, float, "__truediv__") op(ctx, int, "__ge__", int, bool) # int >= int = bool
basic_op(ctx, str, "__add__") op(ctx, int, "__eq__", int, bool) # int == int = bool
basic_op(ctx, float, "__add__") # float + float = float
basic_op(ctx, float, "__sub__") # float - float = float
basic_op(ctx, float, "__mul__") # float * float = float
basic_op(ctx, float, "__truediv__") # float / float = float
op(ctx, float, "__lt__", float, bool) # float < float = bool
op(ctx, float, "__gt__", float, bool) # float > float = bool
op(ctx, float, "__le__", float, bool) # float <= float = bool
op(ctx, float, "__ge__", float, bool) # float >= float = bool
op(ctx, float, "__eq__", float, bool) # float == float = bool
basic_op(ctx, str, "__add__") # str + str = str
op(ctx, str, "__eq__", str, bool) # str == str = bool
op(ctx, int, "__lt__", float, bool) # int < float = bool
op(ctx, int, "__gt__", float, bool) # int > float = bool
op(ctx, int, "__le__", float, bool) # int <= float = bool
op(ctx, int, "__ge__", float, bool) # int >= float = bool
op(ctx, int, "__eq__", float, bool) # int == float = bool
op(ctx, float, "__lt__", int, bool) # float < int = bool
op(ctx, float, "__gt__", int, bool) # float > int = bool
op(ctx, float, "__le__", int, bool) # float <= int = bool
op(ctx, float, "__ge__", int, bool) # float >= int = bool
op(ctx, float, "__eq__", int, bool) # float == int = bool

View File

@@ -121,6 +121,24 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
if stmt.value is not None: if stmt.value is not None:
self.resolve(stmt.value) self.resolve(stmt.value)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
# Not resolved in sub-environment because assignments in the test leak out of the if
# For example:
# if (m := 1 + 1) < 2:
# ...
# print(m) # <- m is still defined
self.resolve(stmt.test)
# Body
self.begin_scope()
self.resolve(*stmt.body)
self.end_scope()
# Else
self.begin_scope()
self.resolve(*stmt.orelse)
self.end_scope()
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self.resolve(expr.left) self.resolve(expr.left)
self.resolve(expr.right) self.resolve(expr.right)

View File

@@ -1,18 +1,27 @@
from typing import Any, Optional from typing import Any, Callable, Optional
AllowRepeat = Callable[[object], bool]
class UniversalJSONDumper: class UniversalJSONDumper:
@classmethod @classmethod
def dump( def dump(
cls, obj: Any, include_keys: Optional[list[str | tuple[str, str]]] = None cls,
obj: Any,
include_keys: Optional[list[str | tuple[str, str]]] = None,
allow_repeat: Optional[AllowRepeat] = None,
) -> Any: ) -> Any:
if include_keys is None: if include_keys is None:
include_keys = [] include_keys = []
return cls._dump(obj, include_keys, []) return cls._dump(obj, include_keys, allow_repeat, [])
@classmethod @classmethod
def _dump( def _dump(
cls, obj: Any, include_keys: list[str | tuple[str, str]], visited: list[Any] cls,
obj: Any,
include_keys: list[str | tuple[str, str]],
allow_repeat: Optional[AllowRepeat],
visited: list[Any],
) -> Any: ) -> Any:
if obj in visited: if obj in visited:
return None return None
@@ -20,17 +29,22 @@ class UniversalJSONDumper:
case str() | int() | float() | None: case str() | int() | float() | None:
return obj return obj
case list() | set() | tuple(): case list() | set() | tuple():
return [cls._dump(child, include_keys, visited) for child in obj] return [
cls._dump(child, include_keys, allow_repeat, visited)
for child in obj
]
case dict(): case dict():
return { return {
str(k): cls._dump(v, include_keys, visited) for k, v in obj.items() str(k): cls._dump(v, include_keys, allow_repeat, visited)
for k, v in obj.items()
} }
case object(): case object():
visited.append(obj) if allow_repeat is None or not allow_repeat(obj):
visited.append(obj)
return { return {
"_type": obj.__class__.__name__, "_type": obj.__class__.__name__,
} | { } | {
k: cls._dump(v, include_keys, visited) k: cls._dump(v, include_keys, allow_repeat, visited)
for k, v in obj.__dict__.items() for k, v in obj.__dict__.items()
if not k.startswith("_") if not k.startswith("_")
or k in include_keys or k in include_keys

View File

@@ -0,0 +1,25 @@
def valid(a: int, b: int) -> int:
return a + b
def with_if(a: int, b: int) -> int:
if a < b:
return b - a
else:
return a - b
def unreachable1():
return
a = 0
def unreachable2(a: int) -> int:
if a > 10:
return a - 10
else:
return a
b = 0
def mixed(a: int, b: int):
if a < b:
return b - a
else:
return "oops"

View File

@@ -0,0 +1,46 @@
{
"diagnostics": [
{
"type": "Warning",
"location": {
"start": [
12,
4
],
"end": [
12,
9
]
},
"message": "Unreachable statement"
},
{
"type": "Warning",
"location": {
"start": [
19,
4
],
"end": [
19,
9
]
},
"message": "Unreachable statement"
},
{
"type": "Error",
"location": {
"start": [
21,
0
],
"end": [
25,
21
]
},
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
}
]
}

View File

@@ -15,6 +15,7 @@ from midas.ast.python import (
FrameType, FrameType,
Function, Function,
GetExpr, GetExpr,
IfStmt,
LiteralExpr, LiteralExpr,
LogicalExpr, LogicalExpr,
MidasType, MidasType,
@@ -164,6 +165,14 @@ class PythonAstJsonSerializer(
"value": self._serialize_optional(stmt.value), "value": self._serialize_optional(stmt.value),
} }
def visit_if_stmt(self, stmt: IfStmt) -> dict:
return {
"_type": "IfStmt",
"test": stmt.test.accept(self),
"body": self._serialize_list(stmt.body),
"orelse": self._serialize_list(stmt.orelse),
}
def visit_binary_expr(self, expr: BinaryExpr) -> dict: def visit_binary_expr(self, expr: BinaryExpr) -> dict:
return { return {
"_type": "BinaryExpr", "_type": "BinaryExpr",