Compare commits

...

2 Commits

Author SHA1 Message Date
bea3f399ad feat(checker): handle ternary expression 2026-06-01 15:02:12 +02:00
55060bfecd feat(parser): add ternary statement 2026-06-01 15:00:21 +02:00
8 changed files with 75 additions and 3 deletions

View File

@@ -12,3 +12,5 @@ def factorial(n: int) -> int:
if n <= 1: if n <= 1:
return 1 return 1
return n * factorial(n - 1) return n * factorial(n - 1)
category = "Category 1" if a < 10 else "Category 2"

View File

@@ -139,4 +139,10 @@ class CastExpr:
expr: Expr expr: Expr
class TernaryExpr:
test: Expr
if_true: Expr
if_false: Expr
###< ###<

View File

@@ -465,6 +465,7 @@ class PythonAstPrinter(
self._write_line("IfStmt") self._write_line("IfStmt")
with self._child_level(): with self._child_level():
self._write_line("test") self._write_line("test")
with self._child_level(single=True):
stmt.test.accept(self) stmt.test.accept(self)
self._write_line("body") self._write_line("body")
with self._child_level(): with self._child_level():
@@ -592,3 +593,18 @@ class PythonAstPrinter(
self._write_line("expr", last=True) self._write_line("expr", last=True)
with self._child_level(single=True): with self._child_level(single=True):
expr.expr.accept(self) expr.expr.accept(self)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self._write_line("TernaryExpr")
with self._child_level():
self._write_line("test")
with self._child_level(single=True):
expr.test.accept(self)
self._write_line("if_true")
with self._child_level(single=True):
expr.if_true.accept(self)
self._write_line("if_false", last=True)
with self._child_level(single=True):
expr.if_false.accept(self)

View File

@@ -220,6 +220,9 @@ class Expr(ABC):
@abstractmethod @abstractmethod
def visit_cast_expr(self, expr: CastExpr) -> T: ... def visit_cast_expr(self, expr: CastExpr) -> T: ...
@abstractmethod
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class BinaryExpr(Expr): class BinaryExpr(Expr):
@@ -312,3 +315,13 @@ class CastExpr(Expr):
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_cast_expr(self) return visitor.visit_cast_expr(self)
@dataclass(frozen=True)
class TernaryExpr(Expr):
test: Expr
if_true: Expr
if_false: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_ternary_expr(self)

View File

@@ -9,7 +9,7 @@ from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.types import Function, Type, UnitType, UnknownType from midas.checker.types import BaseType, Function, SimpleType, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token from midas.lexer.token import Token
from midas.parser.midas import MidasParser from midas.parser.midas import MidasParser
@@ -405,6 +405,22 @@ class Checker(
def visit_cast_expr(self, expr: p.CastExpr) -> Type: def visit_cast_expr(self, expr: p.CastExpr) -> Type:
return expr.type.accept(self) return expr.type.accept(self)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
test_type: Type = expr.test.accept(self)
# TODO Allow subtypes or any type
if test_type != self.ctx.get_type("bool"):
self.error(
expr.test.location, f"If test must be a boolean, got {test_type}"
)
true_type: Type = expr.if_true.accept(self)
false_type: Type = expr.if_false.accept(self)
if true_type != false_type:
self.error(expr.location, f"Type mismatch in ternary if branches: true={true_type} != false={false_type}")
return UnknownType()
return true_type
def visit_base_type(self, node: p.BaseType) -> Type: def visit_base_type(self, node: p.BaseType) -> Type:
return self.ctx.get_type(node.base) return self.ctx.get_type(node.base)

View File

@@ -203,6 +203,8 @@ class PythonHighlighter(
def visit_cast_expr(self, expr: p.CastExpr) -> None: ... def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]): class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css" EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"

View File

@@ -22,6 +22,7 @@ from midas.ast.python import (
MidasType, MidasType,
ReturnStmt, ReturnStmt,
Stmt, Stmt,
TernaryExpr,
TypeAssign, TypeAssign,
UnaryExpr, UnaryExpr,
VariableExpr, VariableExpr,
@@ -389,6 +390,9 @@ class PythonParser:
case ast.Call(): case ast.Call():
return self.parse_call(node) return self.parse_call(node)
case ast.IfExp():
return self.parse_ternary(node)
case ast.Constant(value=value): case ast.Constant(value=value):
return LiteralExpr(location=location, value=value) return LiteralExpr(location=location, value=value)
@@ -478,3 +482,11 @@ class PythonParser:
if arg.arg is not None # Should always be True, type checker happy if arg.arg is not None # Should always be True, type checker happy
}, },
) )
def parse_ternary(self, node: ast.IfExp) -> TernaryExpr:
return TernaryExpr(
location=Location.from_ast(node),
test=self.parse_expr(node.test),
if_true=self.parse_expr(node.body),
if_false=self.parse_expr(node.orelse),
)

View File

@@ -180,3 +180,8 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def visit_cast_expr(self, expr: p.CastExpr) -> None: def visit_cast_expr(self, expr: p.CastExpr) -> None:
self.resolve(expr.expr) self.resolve(expr.expr)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self.resolve(expr.test)
self.resolve(expr.if_true)
self.resolve(expr.if_false)