feat(parser): store locations in parsed nodes
This commit is contained in:
@@ -3,13 +3,39 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, Optional, TypeVar
|
||||
from typing import Generic, Optional, Protocol, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HasLocation(Protocol):
|
||||
lineno: int
|
||||
col_offset: int
|
||||
end_lineno: Optional[int]
|
||||
end_col_offset: Optional[int]
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Location:
|
||||
lineno: int
|
||||
col_offset: int
|
||||
end_lineno: Optional[int]
|
||||
end_col_offset: Optional[int]
|
||||
|
||||
@staticmethod
|
||||
def from_ast(obj: HasLocation) -> Location:
|
||||
return Location(
|
||||
lineno=obj.lineno,
|
||||
col_offset=obj.col_offset,
|
||||
end_lineno=obj.end_lineno,
|
||||
end_col_offset=obj.end_col_offset,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Expr(ABC):
|
||||
location: Optional[Location] = None
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from midas.ast.python import (
|
||||
FrameType,
|
||||
Function,
|
||||
FunctionArgument,
|
||||
Location,
|
||||
MidasType,
|
||||
)
|
||||
|
||||
@@ -50,6 +51,7 @@ class PythonParser(ast.NodeVisitor):
|
||||
self.generic_visit(node)
|
||||
|
||||
def _parse_function(self, node: ast.FunctionDef) -> Function:
|
||||
loc: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.FunctionDef(
|
||||
name=name,
|
||||
@@ -65,6 +67,7 @@ class PythonParser(ast.NodeVisitor):
|
||||
return [self._parse_function_argument(arg) for arg in args_list]
|
||||
|
||||
return Function(
|
||||
location=loc,
|
||||
name=name,
|
||||
posonlyargs=parse_args(posonlyargs),
|
||||
args=parse_args(args),
|
||||
@@ -73,24 +76,38 @@ class PythonParser(ast.NodeVisitor):
|
||||
)
|
||||
|
||||
def _parse_function_argument(self, arg: ast.arg) -> FunctionArgument:
|
||||
loc: Location = Location.from_ast(arg)
|
||||
name: str = arg.arg
|
||||
type: Optional[MidasType] = None
|
||||
if arg.annotation is not None:
|
||||
type = self._parse_type(arg.annotation)
|
||||
return FunctionArgument(name=name, type=type)
|
||||
return FunctionArgument(
|
||||
location=loc,
|
||||
name=name,
|
||||
type=type,
|
||||
)
|
||||
|
||||
def _parse_type(
|
||||
self, type_expr: ast.expr, root: bool = False
|
||||
) -> Optional[MidasType]:
|
||||
loc: Location = Location.from_ast(type_expr)
|
||||
match type_expr:
|
||||
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
||||
return self._parse_frame_type(schema)
|
||||
|
||||
case ast.Subscript(value=ast.Name(id=name), slice=param):
|
||||
return BaseType(base=name, param=self._parse_type(param))
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
param=self._parse_type(param),
|
||||
)
|
||||
|
||||
case ast.Name(id=name):
|
||||
return BaseType(base=name, param=None)
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
param=None,
|
||||
)
|
||||
|
||||
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
||||
left = self._parse_type(left_expr)
|
||||
@@ -107,12 +124,17 @@ class PythonParser(ast.NodeVisitor):
|
||||
)
|
||||
ast.copy_location(constraint, type_expr)
|
||||
return ConstraintType(
|
||||
location=loc,
|
||||
type=left_type,
|
||||
constraint=constraint,
|
||||
)
|
||||
|
||||
case _:
|
||||
return ConstraintType(type=left, constraint=right_expr)
|
||||
return ConstraintType(
|
||||
location=loc,
|
||||
type=left,
|
||||
constraint=right_expr,
|
||||
)
|
||||
|
||||
case _:
|
||||
if root:
|
||||
@@ -120,6 +142,7 @@ class PythonParser(ast.NodeVisitor):
|
||||
raise UnsupportedSyntaxError(type_expr)
|
||||
|
||||
def _parse_frame_type(self, schema: ast.expr) -> FrameType:
|
||||
loc: Location = Location.from_ast(schema)
|
||||
columns: list[FrameColumn] = []
|
||||
|
||||
match schema:
|
||||
@@ -133,12 +156,17 @@ class PythonParser(ast.NodeVisitor):
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(schema)
|
||||
|
||||
return FrameType(columns=columns)
|
||||
return FrameType(location=loc, columns=columns)
|
||||
|
||||
def _parse_frame_column(self, column: ast.expr) -> FrameColumn:
|
||||
loc: Location = Location.from_ast(column)
|
||||
match column:
|
||||
case ast.Name():
|
||||
return FrameColumn(name=None, type=self._parse_type(column))
|
||||
return FrameColumn(
|
||||
location=loc,
|
||||
name=None,
|
||||
type=self._parse_type(column),
|
||||
)
|
||||
|
||||
case ast.Slice(lower=ast.Name(id=name), upper=type_expr):
|
||||
if name == "_":
|
||||
@@ -154,7 +182,7 @@ class PythonParser(ast.NodeVisitor):
|
||||
type = self._parse_type(type_expr)
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(type_expr)
|
||||
return FrameColumn(name=name, type=type)
|
||||
return FrameColumn(location=loc, name=name, type=type)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(column)
|
||||
|
||||
Reference in New Issue
Block a user