Compare commits
8 Commits
86ad348b99
...
bd31713ab4
| Author | SHA1 | Date | |
|---|---|---|---|
|
bd31713ab4
|
|||
|
f4dc57cb96
|
|||
|
261fd47494
|
|||
|
1b66a8553d
|
|||
|
65164abadb
|
|||
|
9d45163d9c
|
|||
|
ab0fa1de1a
|
|||
|
5d4df7978b
|
9
examples/01_simple_type_checking/03_control_flow.py
Normal file
9
examples/01_simple_type_checking/03_control_flow.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
def minimum(x: int, y: int):
|
||||||
|
if x < y:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return y
|
||||||
|
|
||||||
|
a = 15
|
||||||
|
b = 72
|
||||||
|
c = minimum(a, b)
|
||||||
@@ -76,6 +76,12 @@ class ReturnStmt:
|
|||||||
value: Optional[Expr]
|
value: Optional[Expr]
|
||||||
|
|
||||||
|
|
||||||
|
class IfStmt:
|
||||||
|
test: Expr
|
||||||
|
body: list[Stmt]
|
||||||
|
orelse: list[Stmt]
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -419,7 +419,14 @@ class PythonAstPrinter(
|
|||||||
self._mark_last()
|
self._mark_last()
|
||||||
self._print_argument(arg)
|
self._print_argument(arg)
|
||||||
|
|
||||||
self._write_optional_child("returns", stmt.returns, last=True)
|
self._write_optional_child("returns", stmt.returns)
|
||||||
|
self._write_line("body", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, body_stmt in enumerate(stmt.body):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.body) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
body_stmt.accept(self)
|
||||||
|
|
||||||
def _print_argument(self, arg: p.Function.Argument) -> None:
|
def _print_argument(self, arg: p.Function.Argument) -> None:
|
||||||
self._write_line("FunctionArgument")
|
self._write_line("FunctionArgument")
|
||||||
@@ -454,6 +461,26 @@ class PythonAstPrinter(
|
|||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_optional_child("value", stmt.value, last=True)
|
self._write_optional_child("value", stmt.value, last=True)
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
self._write_line("IfStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("test")
|
||||||
|
stmt.test.accept(self)
|
||||||
|
self._write_line("body")
|
||||||
|
with self._child_level():
|
||||||
|
for i, body_stmt in enumerate(stmt.body):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.body) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
body_stmt.accept(self)
|
||||||
|
self._write_line("orelse", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, else_stmt in enumerate(stmt.orelse):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.orelse) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
else_stmt.accept(self)
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||||
self._write_line("BinaryExpr")
|
self._write_line("BinaryExpr")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
|
|||||||
@@ -103,6 +103,9 @@ class Stmt(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
|
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ExpressionStmt(Stmt):
|
class ExpressionStmt(Stmt):
|
||||||
@@ -164,6 +167,16 @@ class ReturnStmt(Stmt):
|
|||||||
return visitor.visit_return_stmt(self)
|
return visitor.visit_return_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class IfStmt(Stmt):
|
||||||
|
test: Expr
|
||||||
|
body: list[Stmt]
|
||||||
|
orelse: list[Stmt]
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_if_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
###############
|
###############
|
||||||
# Expressions #
|
# Expressions #
|
||||||
###############
|
###############
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import midas.ast.python as p
|
|||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.operators import OPERATOR_METHODS
|
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||||
from midas.checker.types import Function, Type, UnitType, UnknownType
|
from midas.checker.types import Function, Type, UnitType, UnknownType
|
||||||
from midas.lexer.midas import MidasLexer
|
from midas.lexer.midas import MidasLexer
|
||||||
from midas.lexer.token import Token
|
from midas.lexer.token import Token
|
||||||
@@ -85,21 +85,29 @@ class Checker(
|
|||||||
"""
|
"""
|
||||||
return expr.accept(self)
|
return expr.accept(self)
|
||||||
|
|
||||||
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> None:
|
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||||
"""Evaluate a sequence of statements
|
"""Evaluate a sequence of statements
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
block (list[p.Stmt]): the statements to evaluate
|
block (list[p.Stmt]): the statements to evaluate
|
||||||
env (Environment): the environment in which to evaluate
|
env (Environment): the environment in which to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether a return statement is present in the block
|
||||||
"""
|
"""
|
||||||
previous_env: Environment = self.env
|
previous_env: Environment = self.env
|
||||||
self.env = env
|
self.env = env
|
||||||
for stmt in block:
|
returned: bool = False
|
||||||
|
for i, stmt in enumerate(block):
|
||||||
try:
|
try:
|
||||||
stmt.accept(self)
|
stmt.accept(self)
|
||||||
except ReturnException:
|
except ReturnException:
|
||||||
|
returned = True
|
||||||
|
if i < len(block) - 1:
|
||||||
|
self.warning(block[i + 1].location, "Unreachable statement")
|
||||||
break
|
break
|
||||||
self.env = previous_env
|
self.env = previous_env
|
||||||
|
return returned
|
||||||
|
|
||||||
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
||||||
"""Type check a sequence of statements and returns diagnostics
|
"""Type check a sequence of statements and returns diagnostics
|
||||||
@@ -216,11 +224,14 @@ class Checker(
|
|||||||
for arg in pos_args + args + kw_args:
|
for arg in pos_args + args + kw_args:
|
||||||
env.define(arg.name, arg.type)
|
env.define(arg.name, arg.type)
|
||||||
|
|
||||||
self.evaluate_block(stmt.body, env)
|
returned: bool = self.evaluate_block(stmt.body, env)
|
||||||
inferred_return: Type = UnknownType()
|
inferred_return: Type = UnknownType()
|
||||||
if len(env.return_types) == 1:
|
if not returned:
|
||||||
inferred_return = list(env.return_types)[0]
|
env.return_types.append(UnitType())
|
||||||
elif len(env.return_types) > 1:
|
return_types: set[Type] = set(env.return_types)
|
||||||
|
if len(return_types) == 1:
|
||||||
|
inferred_return = list(return_types)[0]
|
||||||
|
elif len(return_types) > 1:
|
||||||
self.error(
|
self.error(
|
||||||
stmt.location,
|
stmt.location,
|
||||||
f"Mixed return types: {env.return_types}",
|
f"Mixed return types: {env.return_types}",
|
||||||
@@ -276,6 +287,27 @@ class Checker(
|
|||||||
self.env.return_types.append(type)
|
self.env.return_types.append(type)
|
||||||
raise ReturnException()
|
raise ReturnException()
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||||
|
# For example:
|
||||||
|
# if (m := 1 + 1) < 2:
|
||||||
|
# ...
|
||||||
|
# print(m) # <- m is still defined
|
||||||
|
test_type: Type = stmt.test.accept(self)
|
||||||
|
|
||||||
|
# TODO Allow subtypes or any type
|
||||||
|
if test_type != self.ctx.get_type("bool"):
|
||||||
|
self.error(
|
||||||
|
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
env: Environment = Environment(self.env)
|
||||||
|
body_returned: bool = self.evaluate_block(stmt.body, env)
|
||||||
|
else_returned: bool = self.evaluate_block(stmt.orelse, env)
|
||||||
|
self.env.return_types.extend(env.return_types)
|
||||||
|
if body_returned and else_returned:
|
||||||
|
raise ReturnException()
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
@@ -294,7 +326,23 @@ class Checker(
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type: ...
|
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||||
|
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||||
|
return UnknownType()
|
||||||
|
left: Type = self.evaluate(expr.left)
|
||||||
|
right: Type = self.evaluate(expr.right)
|
||||||
|
|
||||||
|
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||||
|
if result is None:
|
||||||
|
self.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return result
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
||||||
|
|
||||||
|
|||||||
@@ -148,6 +148,8 @@ class PythonHighlighter(
|
|||||||
self.wrap(stmt, "function")
|
self.wrap(stmt, "function")
|
||||||
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
|
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
|
||||||
self._highlight_function_argument(arg)
|
self._highlight_function_argument(arg)
|
||||||
|
for body_stmt in stmt.body:
|
||||||
|
body_stmt.accept(self)
|
||||||
|
|
||||||
def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
|
def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
|
||||||
self.wrap(arg, "argument")
|
self.wrap(arg, "argument")
|
||||||
@@ -157,9 +159,23 @@ class PythonHighlighter(
|
|||||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
stmt.type.accept(self)
|
stmt.type.accept(self)
|
||||||
|
|
||||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: ...
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||||
|
for target in stmt.targets:
|
||||||
|
target.accept(self)
|
||||||
|
stmt.value.accept(self)
|
||||||
|
|
||||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: ...
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
|
self.wrap(stmt, "return")
|
||||||
|
if stmt.value is not None:
|
||||||
|
stmt.value.accept(self)
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
self.wrap(stmt, "if")
|
||||||
|
stmt.test.accept(self)
|
||||||
|
for body_stmt in stmt.body:
|
||||||
|
body_stmt.accept(self)
|
||||||
|
for else_stmt in stmt.orelse:
|
||||||
|
else_stmt.accept(self)
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
|
||||||
|
|
||||||
@@ -167,7 +183,13 @@ class PythonHighlighter(
|
|||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
|
||||||
|
|
||||||
def visit_call_expr(self, expr: p.CallExpr) -> None: ...
|
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||||
|
self.wrap(expr, "call")
|
||||||
|
expr.callee.accept(self)
|
||||||
|
for arg in expr.arguments:
|
||||||
|
arg.accept(self)
|
||||||
|
for arg in expr.keywords.values():
|
||||||
|
arg.accept(self)
|
||||||
|
|
||||||
def visit_get_expr(self, expr: p.GetExpr) -> None: ...
|
def visit_get_expr(self, expr: p.GetExpr) -> None: ...
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
span {
|
span {
|
||||||
|
--opacity: 0.4;
|
||||||
|
|
||||||
&.error {
|
&.error {
|
||||||
--col: 255, 0, 0;
|
--col: 255, 0, 0;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, TextIO
|
from typing import Optional, TextIO, get_args
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
@@ -13,7 +13,13 @@ from midas.ast.location import Location
|
|||||||
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
|
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
|
||||||
from midas.checker.checker import Checker
|
from midas.checker.checker import Checker
|
||||||
from midas.checker.diagnostic import Diagnostic
|
from midas.checker.diagnostic import Diagnostic
|
||||||
from midas.cli.highlighter import DiagnosticsHighlighter, Highlighter, MidasHighlighter, PythonHighlighter
|
from midas.checker.types import Type
|
||||||
|
from midas.cli.highlighter import (
|
||||||
|
DiagnosticsHighlighter,
|
||||||
|
Highlighter,
|
||||||
|
MidasHighlighter,
|
||||||
|
PythonHighlighter,
|
||||||
|
)
|
||||||
from midas.lexer.midas import MidasLexer
|
from midas.lexer.midas import MidasLexer
|
||||||
from midas.lexer.token import Token, TokenType
|
from midas.lexer.token import Token, TokenType
|
||||||
from midas.parser.midas import MidasParser
|
from midas.parser.midas import MidasParser
|
||||||
@@ -46,7 +52,9 @@ def compile(highlight: Optional[TextIO], file: TextIO):
|
|||||||
print(
|
print(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
UniversalJSONDumper.dump(
|
UniversalJSONDumper.dump(
|
||||||
checker.global_env, [("Environment", "_children")]
|
checker.global_env,
|
||||||
|
[("Environment", "_children")],
|
||||||
|
lambda obj: isinstance(obj, get_args(Type)),
|
||||||
),
|
),
|
||||||
indent=4,
|
indent=4,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from midas.ast.python import (
|
|||||||
FrameType,
|
FrameType,
|
||||||
Function,
|
Function,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
|
IfStmt,
|
||||||
LiteralExpr,
|
LiteralExpr,
|
||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MidasType,
|
MidasType,
|
||||||
@@ -82,6 +83,9 @@ class PythonParser:
|
|||||||
value=self.parse_expr(value) if value is not None else None,
|
value=self.parse_expr(value) if value is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case ast.If():
|
||||||
|
return self.parse_if(node)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
print(f"Unsupported statement: {ast.unparse(node)}")
|
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||||
return None
|
return None
|
||||||
@@ -147,6 +151,30 @@ class PythonParser:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def parse_if(self, node: ast.If) -> IfStmt:
|
||||||
|
body: list[Stmt] = []
|
||||||
|
for stmt in node.body:
|
||||||
|
stmts = self.parse_stmt(stmt)
|
||||||
|
if isinstance(stmts, Stmt):
|
||||||
|
body.append(stmts)
|
||||||
|
elif stmts is not None:
|
||||||
|
body.extend(stmts)
|
||||||
|
|
||||||
|
orelse: list[Stmt] = []
|
||||||
|
for stmt in node.orelse:
|
||||||
|
stmts = self.parse_stmt(stmt)
|
||||||
|
if isinstance(stmts, Stmt):
|
||||||
|
orelse.append(stmts)
|
||||||
|
elif stmts is not None:
|
||||||
|
orelse.extend(stmts)
|
||||||
|
|
||||||
|
return IfStmt(
|
||||||
|
location=Location.from_ast(node),
|
||||||
|
test=self.parse_expr(node.test),
|
||||||
|
body=body,
|
||||||
|
orelse=orelse,
|
||||||
|
)
|
||||||
|
|
||||||
def parse_function(self, node: ast.FunctionDef) -> Function:
|
def parse_function(self, node: ast.FunctionDef) -> Function:
|
||||||
loc: Location = Location.from_ast(node)
|
loc: Location = Location.from_ast(node)
|
||||||
match node:
|
match node:
|
||||||
|
|||||||
@@ -2,12 +2,20 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from midas.checker.types import BaseType, Type
|
from midas.checker.types import BaseType, Type, UnitType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from midas.resolver.midas import MidasResolver
|
from midas.resolver.midas import MidasResolver
|
||||||
|
|
||||||
|
|
||||||
|
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
|
||||||
|
ctx.define_operation(
|
||||||
|
left=t1,
|
||||||
|
operator=operator,
|
||||||
|
right=t2,
|
||||||
|
result=t3,
|
||||||
|
)
|
||||||
|
|
||||||
def basic_op(ctx: MidasResolver, type: Type, op: str):
|
def basic_op(ctx: MidasResolver, type: Type, op: str):
|
||||||
ctx.define_operation(
|
ctx.define_operation(
|
||||||
left=type,
|
left=type,
|
||||||
@@ -19,21 +27,45 @@ def basic_op(ctx: MidasResolver, type: Type, op: str):
|
|||||||
|
|
||||||
def define_builtins(ctx: MidasResolver):
|
def define_builtins(ctx: MidasResolver):
|
||||||
"""Define builtin types and operations"""
|
"""Define builtin types and operations"""
|
||||||
|
unit = ctx.define_type("None", UnitType())
|
||||||
bool = ctx.define_type("bool", BaseType(name="bool"))
|
bool = ctx.define_type("bool", BaseType(name="bool"))
|
||||||
int = ctx.define_type("int", BaseType(name="int"))
|
int = ctx.define_type("int", BaseType(name="int"))
|
||||||
float = ctx.define_type("float", BaseType(name="float"))
|
float = ctx.define_type("float", BaseType(name="float"))
|
||||||
str = ctx.define_type("str", BaseType(name="str"))
|
str = ctx.define_type("str", BaseType(name="str"))
|
||||||
|
|
||||||
basic_op(ctx, int, "__add__")
|
basic_op(ctx, int, "__add__") # int + int = int
|
||||||
basic_op(ctx, int, "__sub__")
|
basic_op(ctx, int, "__sub__") # int - int = int
|
||||||
basic_op(ctx, int, "__mul__")
|
basic_op(ctx, int, "__mul__") # int * int = int
|
||||||
basic_op(ctx, int, "__pow__")
|
basic_op(ctx, int, "__pow__") # int ** int = int
|
||||||
basic_op(ctx, int, "__mod__")
|
basic_op(ctx, int, "__mod__") # int % int = int
|
||||||
basic_op(ctx, int, "__and__")
|
basic_op(ctx, int, "__and__") # int & int = int
|
||||||
basic_op(ctx, int, "__or__")
|
basic_op(ctx, int, "__or__") # int | int = int
|
||||||
basic_op(ctx, int, "__xor__")
|
basic_op(ctx, int, "__xor__") # int ^ int = int
|
||||||
basic_op(ctx, float, "__add__")
|
op(ctx, int, "__lt__", int, bool) # int < int = bool
|
||||||
basic_op(ctx, float, "__sub__")
|
op(ctx, int, "__gt__", int, bool) # int > int = bool
|
||||||
basic_op(ctx, float, "__mul__")
|
op(ctx, int, "__le__", int, bool) # int <= int = bool
|
||||||
basic_op(ctx, float, "__truediv__")
|
op(ctx, int, "__ge__", int, bool) # int >= int = bool
|
||||||
basic_op(ctx, str, "__add__")
|
op(ctx, int, "__eq__", int, bool) # int == int = bool
|
||||||
|
basic_op(ctx, float, "__add__") # float + float = float
|
||||||
|
basic_op(ctx, float, "__sub__") # float - float = float
|
||||||
|
basic_op(ctx, float, "__mul__") # float * float = float
|
||||||
|
basic_op(ctx, float, "__truediv__") # float / float = float
|
||||||
|
op(ctx, float, "__lt__", float, bool) # float < float = bool
|
||||||
|
op(ctx, float, "__gt__", float, bool) # float > float = bool
|
||||||
|
op(ctx, float, "__le__", float, bool) # float <= float = bool
|
||||||
|
op(ctx, float, "__ge__", float, bool) # float >= float = bool
|
||||||
|
op(ctx, float, "__eq__", float, bool) # float == float = bool
|
||||||
|
basic_op(ctx, str, "__add__") # str + str = str
|
||||||
|
op(ctx, str, "__eq__", str, bool) # str == str = bool
|
||||||
|
|
||||||
|
op(ctx, int, "__lt__", float, bool) # int < float = bool
|
||||||
|
op(ctx, int, "__gt__", float, bool) # int > float = bool
|
||||||
|
op(ctx, int, "__le__", float, bool) # int <= float = bool
|
||||||
|
op(ctx, int, "__ge__", float, bool) # int >= float = bool
|
||||||
|
op(ctx, int, "__eq__", float, bool) # int == float = bool
|
||||||
|
|
||||||
|
op(ctx, float, "__lt__", int, bool) # float < int = bool
|
||||||
|
op(ctx, float, "__gt__", int, bool) # float > int = bool
|
||||||
|
op(ctx, float, "__le__", int, bool) # float <= int = bool
|
||||||
|
op(ctx, float, "__ge__", int, bool) # float >= int = bool
|
||||||
|
op(ctx, float, "__eq__", int, bool) # float == int = bool
|
||||||
@@ -121,6 +121,24 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
if stmt.value is not None:
|
if stmt.value is not None:
|
||||||
self.resolve(stmt.value)
|
self.resolve(stmt.value)
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
# Not resolved in sub-environment because assignments in the test leak out of the if
|
||||||
|
# For example:
|
||||||
|
# if (m := 1 + 1) < 2:
|
||||||
|
# ...
|
||||||
|
# print(m) # <- m is still defined
|
||||||
|
self.resolve(stmt.test)
|
||||||
|
|
||||||
|
# Body
|
||||||
|
self.begin_scope()
|
||||||
|
self.resolve(*stmt.body)
|
||||||
|
self.end_scope()
|
||||||
|
|
||||||
|
# Else
|
||||||
|
self.begin_scope()
|
||||||
|
self.resolve(*stmt.orelse)
|
||||||
|
self.end_scope()
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||||
self.resolve(expr.left)
|
self.resolve(expr.left)
|
||||||
self.resolve(expr.right)
|
self.resolve(expr.right)
|
||||||
|
|||||||
@@ -1,18 +1,27 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
AllowRepeat = Callable[[object], bool]
|
||||||
|
|
||||||
|
|
||||||
class UniversalJSONDumper:
|
class UniversalJSONDumper:
|
||||||
@classmethod
|
@classmethod
|
||||||
def dump(
|
def dump(
|
||||||
cls, obj: Any, include_keys: Optional[list[str | tuple[str, str]]] = None
|
cls,
|
||||||
|
obj: Any,
|
||||||
|
include_keys: Optional[list[str | tuple[str, str]]] = None,
|
||||||
|
allow_repeat: Optional[AllowRepeat] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if include_keys is None:
|
if include_keys is None:
|
||||||
include_keys = []
|
include_keys = []
|
||||||
return cls._dump(obj, include_keys, [])
|
return cls._dump(obj, include_keys, allow_repeat, [])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _dump(
|
def _dump(
|
||||||
cls, obj: Any, include_keys: list[str | tuple[str, str]], visited: list[Any]
|
cls,
|
||||||
|
obj: Any,
|
||||||
|
include_keys: list[str | tuple[str, str]],
|
||||||
|
allow_repeat: Optional[AllowRepeat],
|
||||||
|
visited: list[Any],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if obj in visited:
|
if obj in visited:
|
||||||
return None
|
return None
|
||||||
@@ -20,17 +29,22 @@ class UniversalJSONDumper:
|
|||||||
case str() | int() | float() | None:
|
case str() | int() | float() | None:
|
||||||
return obj
|
return obj
|
||||||
case list() | set() | tuple():
|
case list() | set() | tuple():
|
||||||
return [cls._dump(child, include_keys, visited) for child in obj]
|
return [
|
||||||
|
cls._dump(child, include_keys, allow_repeat, visited)
|
||||||
|
for child in obj
|
||||||
|
]
|
||||||
case dict():
|
case dict():
|
||||||
return {
|
return {
|
||||||
str(k): cls._dump(v, include_keys, visited) for k, v in obj.items()
|
str(k): cls._dump(v, include_keys, allow_repeat, visited)
|
||||||
|
for k, v in obj.items()
|
||||||
}
|
}
|
||||||
case object():
|
case object():
|
||||||
|
if allow_repeat is None or not allow_repeat(obj):
|
||||||
visited.append(obj)
|
visited.append(obj)
|
||||||
return {
|
return {
|
||||||
"_type": obj.__class__.__name__,
|
"_type": obj.__class__.__name__,
|
||||||
} | {
|
} | {
|
||||||
k: cls._dump(v, include_keys, visited)
|
k: cls._dump(v, include_keys, allow_repeat, visited)
|
||||||
for k, v in obj.__dict__.items()
|
for k, v in obj.__dict__.items()
|
||||||
if not k.startswith("_")
|
if not k.startswith("_")
|
||||||
or k in include_keys
|
or k in include_keys
|
||||||
|
|||||||
25
tests/cases/checker/05_control_flow.py
Normal file
25
tests/cases/checker/05_control_flow.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
def valid(a: int, b: int) -> int:
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
def with_if(a: int, b: int) -> int:
|
||||||
|
if a < b:
|
||||||
|
return b - a
|
||||||
|
else:
|
||||||
|
return a - b
|
||||||
|
|
||||||
|
def unreachable1():
|
||||||
|
return
|
||||||
|
a = 0
|
||||||
|
|
||||||
|
def unreachable2(a: int) -> int:
|
||||||
|
if a > 10:
|
||||||
|
return a - 10
|
||||||
|
else:
|
||||||
|
return a
|
||||||
|
b = 0
|
||||||
|
|
||||||
|
def mixed(a: int, b: int):
|
||||||
|
if a < b:
|
||||||
|
return b - a
|
||||||
|
else:
|
||||||
|
return "oops"
|
||||||
46
tests/cases/checker/05_control_flow.py.ref.json
Normal file
46
tests/cases/checker/05_control_flow.py.ref.json
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
{
|
||||||
|
"diagnostics": [
|
||||||
|
{
|
||||||
|
"type": "Warning",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
12,
|
||||||
|
4
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
12,
|
||||||
|
9
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Unreachable statement"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Warning",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
19,
|
||||||
|
4
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
19,
|
||||||
|
9
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Unreachable statement"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
21,
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
25,
|
||||||
|
21
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ from midas.ast.python import (
|
|||||||
FrameType,
|
FrameType,
|
||||||
Function,
|
Function,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
|
IfStmt,
|
||||||
LiteralExpr,
|
LiteralExpr,
|
||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MidasType,
|
MidasType,
|
||||||
@@ -164,6 +165,14 @@ class PythonAstJsonSerializer(
|
|||||||
"value": self._serialize_optional(stmt.value),
|
"value": self._serialize_optional(stmt.value),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: IfStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "IfStmt",
|
||||||
|
"test": stmt.test.accept(self),
|
||||||
|
"body": self._serialize_list(stmt.body),
|
||||||
|
"orelse": self._serialize_list(stmt.orelse),
|
||||||
|
}
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "BinaryExpr",
|
"_type": "BinaryExpr",
|
||||||
|
|||||||
Reference in New Issue
Block a user