feat(parser): add location to midas AST nodes

This commit is contained in:
2026-05-25 12:14:14 +02:00
parent 9b59058881
commit e94db2181f
6 changed files with 161 additions and 53 deletions

37
midas/ast/location.py Normal file
View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Protocol
class HasLocation(Protocol):
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@dataclass(frozen=True, kw_only=True)
class Location:
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@staticmethod
def from_ast(obj: HasLocation) -> Location:
return Location(
lineno=obj.lineno,
col_offset=obj.col_offset,
end_lineno=obj.end_lineno,
end_col_offset=obj.end_col_offset,
)
@staticmethod
def span(start: Location, end: Location) -> Location:
return Location(
lineno=start.lineno,
col_offset=start.col_offset,
end_lineno=end.lineno,
end_col_offset=end.end_col_offset,
)

View File

@@ -9,6 +9,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
from midas.lexer.token import Token
T = TypeVar("T")
@@ -18,8 +19,10 @@ T = TypeVar("T")
##############
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class Stmt(ABC):
location: Optional[Location] = None
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
@@ -109,8 +112,10 @@ class PredicateStmt(Stmt):
###############
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class Expr(ABC):
location: Optional[Location] = None
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...

View File

@@ -3,35 +3,13 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import ast
from dataclasses import dataclass
from typing import Generic, Optional, Protocol, TypeVar
from typing import Generic, Optional, TypeVar
from midas.ast.location import Location
T = TypeVar("T")
class HasLocation(Protocol):
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@dataclass(frozen=True, kw_only=True)
class Location:
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@staticmethod
def from_ast(obj: HasLocation) -> Location:
return Location(
lineno=obj.lineno,
col_offset=obj.col_offset,
end_lineno=obj.end_lineno,
end_col_offset=obj.end_col_offset,
)
@dataclass(frozen=True, kw_only=True)
class Expr(ABC):
location: Optional[Location] = None

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any
from midas.ast.location import Location
from midas.lexer.position import Position
@@ -63,3 +66,23 @@ class Token:
lexeme: str
value: Any
position: Position
def get_location(self) -> Location:
lineno: int = self.position.line
col_offset: int = self.position.column - 1
end_lineno = lineno
end_col_offset = col_offset
for c in self.lexeme:
end_col_offset += 1
if c == "\n":
end_lineno += 1
end_col_offset = 0
return Location(
lineno=lineno,
col_offset=col_offset,
end_lineno=end_lineno,
end_col_offset=end_col_offset,
)
def location_to(self, to: Token) -> Location:
return Location.span(self.get_location(), to.get_location())

View File

@@ -1,5 +1,6 @@
from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import (
BinaryExpr,
ComplexTypeStmt,
@@ -104,6 +105,7 @@ class MidasParser(Parser):
Returns:
TypeStmt: the parsed type declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
template: Optional[TemplateExpr] = None
if self.check(TokenType.LEFT_BRACKET):
@@ -116,11 +118,20 @@ class MidasParser(Parser):
if self.match(TokenType.WHERE):
constraint = self.constraint()
return SimpleTypeStmt(
name=name, template=template, base=base, constraint=constraint
location=keyword.location_to(self.previous()),
name=name,
template=template,
base=base,
constraint=constraint,
)
else:
properties: list[PropertyStmt] = self.type_properties()
return ComplexTypeStmt(name=name, template=template, properties=properties)
return ComplexTypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
properties=properties,
)
def template_expr(self) -> TemplateExpr:
"""Parse a generic template expression
@@ -130,10 +141,14 @@ class MidasParser(Parser):
Returns:
TemplateExpr: the parsed template expression
"""
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
left: Token = self.consume(
TokenType.LEFT_BRACKET, "Missing '[' before template expression"
)
type: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
return TemplateExpr(type=type)
right: Token = self.consume(
TokenType.RIGHT_BRACKET, "Missing ']' after template expression"
)
return TemplateExpr(location=left.location_to(right), type=type)
def type_expr(self) -> TypeExpr:
"""Parse a type expression
@@ -149,7 +164,12 @@ class MidasParser(Parser):
if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr()
optional: bool = self.match(TokenType.QMARK)
return TypeExpr(name=name, template=template, optional=optional)
return TypeExpr(
location=name.location_to(self.previous()),
name=name,
template=template,
optional=optional,
)
def simple_type_expr(self) -> SimpleTypeExpr:
"""Parse a simple type expression
@@ -161,7 +181,9 @@ class MidasParser(Parser):
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
optional: bool = self.match(TokenType.QMARK)
return SimpleTypeExpr(name=name, optional=optional)
return SimpleTypeExpr(
location=name.location_to(self.previous()), name=name, optional=optional
)
def constraint(self) -> Expr:
"""Parse a constraint
@@ -183,7 +205,12 @@ class MidasParser(Parser):
while self.match(TokenType.AND):
operator: Token = self.previous()
right: Expr = self.equality()
expr = LogicalExpr(left=expr, operator=operator, right=right)
location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = LogicalExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def equality(self) -> Expr:
@@ -196,7 +223,12 @@ class MidasParser(Parser):
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
operator: Token = self.previous()
right: Expr = self.comparison()
expr = BinaryExpr(left=expr, operator=operator, right=right)
location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def comparison(self) -> Expr:
@@ -214,7 +246,12 @@ class MidasParser(Parser):
):
operator: Token = self.previous()
right: Expr = self.unary()
expr = BinaryExpr(left=expr, operator=operator, right=right)
location: Optional[Location] = None
if expr.location and right.location:
location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def unary(self) -> Expr:
@@ -226,7 +263,10 @@ class MidasParser(Parser):
if self.match(TokenType.MINUS):
operator: Token = self.previous()
right: Expr = self.unary()
return UnaryExpr(operator=operator, right=right)
location: Optional[Location] = None
if right.location:
location = Location.span(operator.get_location(), right.location)
return UnaryExpr(location=location, operator=operator, right=right)
return self.reference()
def reference(self) -> Expr:
@@ -240,7 +280,10 @@ class MidasParser(Parser):
name: Token = self.consume(
TokenType.IDENTIFIER, "Expected property name after '.'"
)
expr = GetExpr(expr=expr, name=name)
location: Optional[Location] = None
if expr.location:
location = Location.span(expr.location, name.get_location())
expr = GetExpr(location=location, expr=expr, name=name)
return expr
def primary(self) -> Expr:
@@ -251,26 +294,27 @@ class MidasParser(Parser):
Returns:
Expr: the parsed expression
"""
token: Token = self.peek()
if self.match(TokenType.FALSE):
return LiteralExpr(False)
return LiteralExpr(location=token.get_location(), value=False)
if self.match(TokenType.TRUE):
return LiteralExpr(True)
return LiteralExpr(location=token.get_location(), value=True)
if self.match(TokenType.NONE):
return LiteralExpr(None)
return LiteralExpr(location=token.get_location(), value=None)
if self.match(TokenType.NUMBER):
return LiteralExpr(self.previous().value)
return LiteralExpr(location=token.get_location(), value=token.value)
if self.match(TokenType.IDENTIFIER):
return VariableExpr(self.previous())
return VariableExpr(location=token.get_location(), name=token)
if self.match(TokenType.UNDERSCORE):
return WildcardExpr(self.previous())
return WildcardExpr(location=token.get_location(), token=token)
if self.match(TokenType.LEFT_PAREN):
expr: Expr = self.constraint()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return GroupingExpr(expr)
right: Token = self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return GroupingExpr(location=token.location_to(right), expr=expr)
raise self.error(self.peek(), "Expected expression")
@@ -304,7 +348,12 @@ class MidasParser(Parser):
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
return PropertyStmt(name=name, type=type, constraint=constraint)
return PropertyStmt(
location=name.location_to(self.previous()),
name=name,
type=type,
constraint=constraint,
)
def extend_declaration(self) -> ExtendStmt:
"""Parse an extension definition
@@ -314,13 +363,17 @@ class MidasParser(Parser):
Returns:
ExtendStmt: the parsed extension statement
"""
keyword: Token = self.previous()
type: TypeExpr = self.type_expr()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
operations: list[OpStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
operations.append(self.op_declaration())
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
return ExtendStmt(type=type, operations=operations)
location: Optional[Location] = None
if type.location:
location = keyword.location_to(self.previous())
return ExtendStmt(location=location, type=type, operations=operations)
def op_declaration(self) -> OpStmt:
"""Parse an operation definition
@@ -330,7 +383,7 @@ class MidasParser(Parser):
Returns:
OpStmt: the parsed operation statement
"""
self.consume(TokenType.OP, "Expected 'op' keyword")
keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword")
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
@@ -340,7 +393,12 @@ class MidasParser(Parser):
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: TypeExpr = self.type_expr()
return OpStmt(name=name, operand=operand, result=result)
return OpStmt(
location=keyword.location_to(self.previous()),
name=name,
operand=operand,
result=result,
)
def predicate_declaration(self) -> PredicateStmt:
"""Parse a predicate declaration
@@ -350,6 +408,7 @@ class MidasParser(Parser):
Returns:
PredicateStmt: the parsed predicate declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
@@ -358,4 +417,10 @@ class MidasParser(Parser):
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint()
return PredicateStmt(name=name, subject=subject, type=type, condition=condition)
return PredicateStmt(
location=keyword.location_to(self.previous()),
name=name,
subject=subject,
type=type,
condition=condition,
)

View File

@@ -1,6 +1,7 @@
import ast
from typing import Any, Optional
from midas.ast.location import Location
from midas.ast.python import (
BaseType,
ConstraintType,
@@ -8,7 +9,6 @@ from midas.ast.python import (
FrameType,
Function,
FunctionArgument,
Location,
MidasType,
)