Files
midas/midas/parser/python.py

189 lines
6.2 KiB
Python

import ast
from typing import Any, Optional
from midas.ast.python import (
BaseType,
ConstraintType,
FrameColumn,
FrameType,
Function,
FunctionArgument,
Location,
MidasType,
)
class InvalidSyntaxError(Exception):
    """Raised when a type expression is malformed (e.g. a column without a type)."""
class UnsupportedSyntaxError(Exception):
    """Raised when an expression cannot be interpreted by the parser.

    The message contains the expression's source position (line and column
    offset) followed by its unparsed source text. The offending node is kept
    on the ``expr`` attribute so callers can inspect it.
    """

    def __init__(self, expr: ast.expr) -> None:
        # Nodes built by hand (rather than produced by ``ast.parse``) may not
        # carry position attributes; fall back to "?" so that constructing the
        # error never fails with AttributeError and hides the real problem.
        line = getattr(expr, "lineno", "?")
        column = getattr(expr, "col_offset", "?")
        super().__init__(
            f"Unsupported syntax at L{line}:{column}: {ast.unparse(expr)}"
        )
        self.expr = expr
class PythonParser(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self.annotations: list[tuple[str, Optional[MidasType]]] = []
self.functions: list[Function] = []
def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
match node:
case ast.AnnAssign(
target=ast.Name(id=target), annotation=annotation, simple=1
):
self.annotations.append(
(target, self._parse_type(annotation, root=True))
)
case _:
print(f"Unsupported annotation: {ast.unparse(node)}")
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
self.functions.append(self._parse_function(node))
# Call visit on children to process body
# TODO: scope the resulting nodes to the function
self.generic_visit(node)
def _parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node)
match node:
case ast.FunctionDef(
name=name,
args=ast.arguments(
posonlyargs=posonlyargs,
args=args,
kwonlyargs=kwonlyargs,
),
returns=returns,
):
def parse_args(args_list: list[ast.arg]) -> list[FunctionArgument]:
return [self._parse_function_argument(arg) for arg in args_list]
return Function(
location=loc,
name=name,
posonlyargs=parse_args(posonlyargs),
args=parse_args(args),
kwonlyargs=parse_args(kwonlyargs),
returns=self._parse_type(returns) if returns is not None else None,
)
def _parse_function_argument(self, arg: ast.arg) -> FunctionArgument:
loc: Location = Location.from_ast(arg)
name: str = arg.arg
type: Optional[MidasType] = None
if arg.annotation is not None:
type = self._parse_type(arg.annotation)
return FunctionArgument(
location=loc,
name=name,
type=type,
)
def _parse_type(
self, type_expr: ast.expr, root: bool = False
) -> Optional[MidasType]:
loc: Location = Location.from_ast(type_expr)
match type_expr:
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
return self._parse_frame_type(schema)
case ast.Subscript(value=ast.Name(id=name), slice=param):
return BaseType(
location=loc,
base=name,
param=self._parse_type(param),
)
case ast.Name(id=name):
return BaseType(
location=loc,
base=name,
param=None,
)
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
left = self._parse_type(left_expr)
match left:
case None:
raise InvalidSyntaxError()
# If chained constraints, separate base type and rebuild constraint
case ConstraintType(type=left_type, constraint=left_constraint):
constraint = ast.BinOp(
left=left_constraint,
op=ast.Add(),
right=right_expr,
)
ast.copy_location(constraint, type_expr)
return ConstraintType(
location=loc,
type=left_type,
constraint=constraint,
)
case _:
return ConstraintType(
location=loc,
type=left,
constraint=right_expr,
)
case _:
if root:
return None
raise UnsupportedSyntaxError(type_expr)
def _parse_frame_type(self, schema: ast.expr) -> FrameType:
loc: Location = Location.from_ast(schema)
columns: list[FrameColumn] = []
match schema:
case ast.Tuple(elts=cols):
for col in cols:
columns.append(self._parse_frame_column(col))
case ast.Slice() | ast.Name():
columns.append(self._parse_frame_column(schema))
case _:
raise UnsupportedSyntaxError(schema)
return FrameType(location=loc, columns=columns)
def _parse_frame_column(self, column: ast.expr) -> FrameColumn:
loc: Location = Location.from_ast(column)
match column:
case ast.Name():
return FrameColumn(
location=loc,
name=None,
type=self._parse_type(column),
)
case ast.Slice(lower=ast.Name(id=name), upper=type_expr):
if name == "_":
name = None
type: Optional[MidasType] = None
match type_expr:
case None:
raise InvalidSyntaxError("Missing column type")
case ast.Name(id="_"):
type = None
case ast.expr():
type = self._parse_type(type_expr)
case _:
raise UnsupportedSyntaxError(type_expr)
return FrameColumn(location=loc, name=name, type=type)
case _:
raise UnsupportedSyntaxError(column)