434 Commits

Author SHA1 Message Date
9764484fd9 docs: add docstrings to midas parser 2026-07-04 01:30:14 +02:00
5b9e322c91 docs: add some docstrings in lexer classes 2026-07-03 22:41:21 +02:00
c18d9c18de tests: update with new parameter spec 2026-07-03 19:31:17 +02:00
9229f00375 refactor: rebrand function parameters and unify spec
rename function arguments to parameters where the term was misused, and add a ParamSpec for the Python AST, as for Midas
2026-07-03 19:24:30 +02:00
6b7a682dc5 docs: add some docstrings 2026-07-03 17:36:45 +02:00
35b97fd17b refactor(ast): restructure printers 2026-07-03 17:26:28 +02:00
03bc32400b Merge pull request 'Frame / columns in manual' (#28) from feat/frame-columns-in-manual into main
Reviewed-on: #28
2026-07-03 14:38:44 +00:00
4a93ee45d9 docs: add section about Frame type annotations 2026-07-03 16:32:35 +02:00
8197131d8d docs: add Column and Frame to manual 2026-07-03 13:31:56 +02:00
cf91187b7a fix(checker): remove bool as subtype of int 2026-07-03 12:56:47 +02:00
1b2bdf0b79 docs: add alias statements to manual 2026-07-03 12:56:20 +02:00
c6cc38bfeb Merge pull request 'Frame / column operations' (#27) from feat/simple-frame-ops into main
Reviewed-on: #27
2026-07-03 10:29:32 +00:00
4d3e3f44a1 fix(checker): correctly check length of frame/column 2026-07-03 12:28:39 +02:00
ec80b1e92e feat(checker): add head/tail methods 2026-07-03 12:13:30 +02:00
4ea15519f3 feat(checker): add some frame/column attributes 2026-07-03 12:07:36 +02:00
7a6e01cff8 fix(checker): delegate frame aggregate methods to columns 2026-07-03 11:42:35 +02:00
733c8736b8 feat(checker): add aggregation ops on column groupby 2026-07-03 11:25:06 +02:00
20173a0b07 feat(tests): add colors and run all tests in base module 2026-07-03 10:58:28 +02:00
a143972ef1 feat(checker): add aggregation ops on frame groupby 2026-07-03 02:20:51 +02:00
0c70048b62 feat(checker): add statistical ops on columns 2026-07-03 01:34:58 +02:00
1c0c917873 feat(checker): add statistical ops on frames 2026-07-03 01:27:16 +02:00
1f6189daa4 feat(checker): add comparison binary ops on columns 2026-07-03 01:05:24 +02:00
66b585c3d6 fix(checker): recursively check builtin subtypes 2026-07-03 01:04:45 +02:00
819ab3c2bf tests: add dataframe operations test 2026-07-03 00:58:29 +02:00
d8c0b17512 feat(checker): add comparison binary ops on frames 2026-07-03 00:57:27 +02:00
6e06f9078e fix(checker): improve unknown method message 2026-07-03 00:57:10 +02:00
ece2e3a6a3 feat(checker): add arithmetic binary ops on columns 2026-07-03 00:42:00 +02:00
74c07c9afb feat(checker): add arithmetic binary ops on frames 2026-07-03 00:38:56 +02:00
be2fd4c837 feat(checker): delegate element operation to inner type
delegate element-wise binary operation on columns to their inner types
2026-07-03 00:05:40 +02:00
1bc4c704c3 feat(checker): delegate element operation to columns
delegate element-wise binary operation on frames to columns
2026-07-02 23:41:08 +02:00
0288a05901 feat(checker): handle assignment to multiple columns 2026-07-02 23:29:10 +02:00
b14f46d405 feat(checker): handle calls on group-bys 2026-07-02 19:53:58 +02:00
8e8ed62266 feat(checker): add add/mean/groupby on columns 2026-07-02 19:30:43 +02:00
2fce2f4bfc feat(checker): add column method registry 2026-07-02 19:23:23 +02:00
640f2d1771 feat(checker): support unification of frames and columns 2026-07-02 19:22:28 +02:00
b48dfe5301 refactor: make MethodRegistry generic on Call 2026-07-02 18:27:26 +02:00
0d5840a4ce refactor: restructure frame method registry in submodule 2026-07-02 18:20:10 +02:00
3c92f0867d feat(types): add ColumnGroupBy 2026-07-02 18:00:25 +02:00
b5acae4078 feat(types): add FrameGroupBy type 2026-07-02 17:45:18 +02:00
5d20f8ec3e docs: mention eager evaluation in manual 2026-07-02 17:22:28 +02:00
955c2233ed feat(checker): statically evaluate casts to Any and None 2026-07-02 17:14:30 +02:00
ff69b65171 feat(checker): add same length assertion on frames
adding two dataframes is only safe if they have the same size; otherwise null values could be added dynamically to pad the shorter dataframe
2026-07-02 17:14:05 +02:00
8df01afd8c feat(gen): materialize assertions from collector 2026-07-02 17:10:27 +02:00
47b2dfdd73 feat(gen): add assertion collector to TypedAST 2026-07-02 17:09:50 +02:00
bd4d793ce0 feat(gen): add Assertion class 2026-07-02 17:08:43 +02:00
f7a36f61b6 fix(checker): pass AST expression to method registry 2026-07-01 22:34:02 +02:00
ad2fabf471 feat(checker): add assertion collector 2026-07-01 22:32:13 +02:00
a59a58d21a feat(gen): generate alias stubs 2026-07-01 14:43:30 +02:00
3260ae4a1e Merge pull request 'Call dispatcher' (#26) from feat/call-dispatcher into main
Reviewed-on: #26
2026-07-01 12:22:11 +00:00
bd1c9581c7 fix(checker): use dispatcher in frame method registry 2026-07-01 14:17:10 +02:00
663642ea6c fix(tests): serialize alias statements 2026-07-01 14:13:27 +02:00
e2abc04fe4 feat(checker): define min/max in preamble 2026-07-01 14:10:19 +02:00
a4016b55ce feat(checker): handle calls to AppliedType 2026-07-01 14:10:19 +02:00
1ea5da7024 feat(parser): parse binary operations in Midas 2026-07-01 14:10:18 +02:00
a017a8cf1f feat(checker): catch errors when evaluating constraint 2026-07-01 14:10:17 +02:00
8fc5ab623e feat(checker): evaluate literal cast to list/dict 2026-07-01 14:10:16 +02:00
14007db846 feat(checker): evaluate unary op on literals 2026-07-01 14:10:15 +02:00
6ad2ce4b68 feat(checker): improve function unwrapping 2026-07-01 14:10:15 +02:00
9a276c34c7 refactor: reuse CallDispatcher 2026-07-01 11:32:41 +02:00
6e717a3f9e refactor: use CallDispatcher in Midas typer 2026-07-01 11:24:09 +02:00
77aadfa264 refactor: extract function call methods to CallDispatcher 2026-07-01 11:14:08 +02:00
c81287df7f Merge pull request 'Initial dataframe implementation' (#25) from feat/dataframes into main
Reviewed-on: #25
2026-07-01 08:24:36 +00:00
ffccc1bedd feat(cli): generate stubs in build dir when compiling 2026-07-01 10:16:13 +02:00
d14f208897 feat(gen): add tuple expr to generator 2026-07-01 10:16:13 +02:00
293953a078 tests: update with multi-parameter generics 2026-07-01 10:16:12 +02:00
bccc96e4d0 fix: minor fixes 2026-07-01 10:16:11 +02:00
9db56adf56 feat: add Python tuple expression 2026-07-01 10:16:10 +02:00
3f99563ac8 feat: handle multi-parameter generic in Python 2026-07-01 10:16:10 +02:00
b36896cc7b feat(checker): add len() 2026-07-01 10:16:09 +02:00
cb75878ae9 fix(checker): allow some assignments to unknown 2026-07-01 10:16:08 +02:00
a5fe985eb2 feat(checker): add methods on str 2026-07-01 10:16:08 +02:00
e324f414e6 feat(checker): type check tuple instantiation in Midas 2026-07-01 10:16:07 +02:00
256536562f fix(parser): parse empty calls 2026-07-01 10:16:06 +02:00
64f4314f0d fix(gen): prevent empty loop for column asserts 2026-07-01 10:16:06 +02:00
6f6245d283 fix(checker): allow iterating on unknown 2026-07-01 10:16:05 +02:00
3392bc347d fix(checker): allow subtypes and unknown as if test 2026-07-01 10:16:04 +02:00
7e0319906a feat(gen): assertions for column values 2026-07-01 10:16:03 +02:00
75bd203d4a fix(checker): allow calling unknown method on dataframes 2026-07-01 10:15:16 +02:00
db40198357 feat(gen): generate asserts for dataframes and columns 2026-07-01 10:15:16 +02:00
d79e1dee18 fix(checker): change heterogeneous errors to warnings 2026-07-01 10:15:15 +02:00
4ea400265c feat(checker): add mean method on frames 2026-07-01 10:15:14 +02:00
24bffdabd4 fix(checker): type check None literal 2026-07-01 10:15:13 +02:00
d7bb6326de feat(checker): lookup dunders on dataframes 2026-07-01 10:15:12 +02:00
dbf6f9e2db tests: update with reordered argument typing 2026-07-01 10:15:12 +02:00
3cdc9031d3 refactor: use metaclass to collect frame methods 2026-07-01 10:15:11 +02:00
00e2ca8fe3 refactor: add MethodResolver class 2026-07-01 10:15:10 +02:00
4efb01285c feat: add dummy classes for typing frames and columns 2026-07-01 10:15:10 +02:00
f84a19159f fix(checker): improve heterogeneous error message 2026-07-01 10:15:09 +02:00
946b2e0d2e feat(checker): lookup dataframe methods 2026-07-01 10:15:08 +02:00
08dd7408ec feat(checker): define add method of dataframes 2026-07-01 10:15:07 +02:00
b33fadf768 feat(checker): add structural subtyping rule for dataframes 2026-07-01 10:15:06 +02:00
7219109e5d feat(cli): print context for multiline diagnostics 2026-07-01 10:14:48 +02:00
cdf1725c26 feat(checker): process frame type definitions 2026-07-01 10:14:48 +02:00
7074b074bc feat(cli): add frame type to highlighter 2026-07-01 10:14:17 +02:00
ede7272c09 feat(parser): add frame type to midas syntax 2026-07-01 10:14:16 +02:00
87d5e286d2 feat(gen): add support for tuples and dataframes 2026-07-01 10:14:16 +02:00
c91b206791 feat(checker): handle setting dataframe column 2026-07-01 10:13:30 +02:00
a31d295eb1 feat(checker): type check subscript on dataframes 2026-07-01 10:13:30 +02:00
0d20993f02 feat(types): add TupleType 2026-07-01 10:13:28 +02:00
5357ca8e58 fix(types): add str methods to dataframe types 2026-07-01 10:13:28 +02:00
556765fd35 feat(types): add DataFrameType and ColumnType 2026-07-01 10:13:27 +02:00
d039a8e4b3 Merge pull request 'Type aliases vs. Derived types' (#24) from feat/subtypes-and-aliases into main
Reviewed-on: #24
2026-07-01 08:09:13 +00:00
c4533421eb feat(checker): process alias definitions 2026-07-01 09:59:58 +02:00
73769b42c1 feat(parser): add alias keyword and statement 2026-07-01 09:30:09 +02:00
087f6b4669 refactor(types): rename AliasType to DerivedType 2026-06-30 16:28:16 +02:00
d582df5927 Merge pull request 'User manual' (#23) from feat/manual into main
Reviewed-on: #23
2026-06-30 14:11:45 +00:00
6a0401833c feat(manual): add strings to midas syntax def 2026-06-30 14:10:32 +02:00
e15607b763 fix(manual): end syntax highlighting of extend body 2026-06-30 14:03:42 +02:00
e28f324a85 fix(manual): typos 2026-06-28 22:30:09 +02:00
31e696c938 feat(manual): add listings outline and tweak template 2026-06-28 22:28:13 +02:00
759b416bf3 feat(manual): wrap all code in figures 2026-06-28 22:20:15 +02:00
4b2b0fe476 feat(manual): document supported Python syntax 2026-06-28 21:41:39 +02:00
4c39504750 feat(manual): document predicate and constraints 2026-06-28 14:12:41 +02:00
f9f3ade6c7 feat(manual): document type statement 2026-06-28 12:37:44 +02:00
386018b956 feat(manual): add sublime syntax for Midas 2026-06-28 12:36:02 +02:00
bd47d33355 feat(manual): complete introduction and quick start 2026-06-26 17:52:54 +02:00
93ddb28802 docs: setup user manual 2026-06-24 15:53:52 +02:00
f7c43837b5 Merge pull request 'CLI tweaks' (#22) from fix/cli-tweaks into main
Reviewed-on: #22
2026-06-24 12:18:07 +00:00
32ed62a6f1 fix(cli): show summary of diagnostic counts 2026-06-24 14:11:39 +02:00
66f39acec0 fix(cli): show all diagnostics in types command
combine type checker diagnostics with judgements info diagnostics
2026-06-24 14:11:15 +02:00
6c04e2fee4 feat(cli): add compile option to ignore errors 2026-06-24 14:10:30 +02:00
2bb2e0a684 Merge pull request 'Unsafe cast' (#21) from feat/unsafe-cast into main
Reviewed-on: #21
2026-06-24 12:00:03 +00:00
5630320d21 chore: use unsafe_cast in demo script 2026-06-24 13:57:38 +02:00
9f05ba3224 feat: handle unsafe casts 2026-06-24 13:51:14 +02:00
5fbe965919 feat(checker): add typing submodule with cast functions 2026-06-24 13:40:23 +02:00
252a5abdfd Merge pull request 'Static evaluation of casts on literals' (#20) from feat/literal-static-constraints into main
Reviewed-on: #20
2026-06-24 09:32:54 +00:00
55fba6a088 tests: update test without evaluated casts 2026-06-24 11:28:44 +02:00
70ce263ea2 feat(gen): skip assertions for evaluated casts
avoid generating a runtime assertion for a cast which has already been checked statically
2026-06-24 11:28:43 +02:00
e1d5eac8b8 feat(checker): evaluate constraints statically on literals 2026-06-24 11:10:09 +02:00
82666a4918 feat(checker): add evaluator
add an evaluator class to evaluate expressions using literal values
2026-06-24 11:08:15 +02:00
45f84a2f23 feat(checker): add debug diagnostics 2026-06-24 11:07:42 +02:00
dedfcb4dbb feat(checker): store builtin python functions in preamble 2026-06-24 11:05:36 +02:00
d9ea6365ea tests: update with cast expression judgement 2026-06-23 16:49:38 +02:00
9c7a93412c Merge pull request 'Fixes and small demo' (#19) from feat/demonstration into main
Reviewed-on: #19
2026-06-23 08:15:56 +00:00
d6b8fbfb60 chore: improve demo example 2026-06-23 10:03:24 +02:00
b290c59ac4 fix(gen): add bases for ConstraintType and TypeVar 2026-06-23 00:25:43 +02:00
093f2bc477 fix(checker): lookup member on typevar bound 2026-06-23 00:24:37 +02:00
7c771c4070 feat(checker): add input function to preamble 2026-06-23 00:22:38 +02:00
a50a207385 fix(gen): don't generate stubs for builtin types 2026-06-22 15:40:31 +02:00
7e5ea5e414 chore: add example to demonstrate some features 2026-06-22 15:29:39 +02:00
0ba0266bae fix(checker): check general subtype case for AppliedType
this adds the case where we check whether AppliedType <: Type, and delegates to the body

this may not be a legitimate rule, or may need to be refined
2026-06-22 15:27:06 +02:00
216c80f08c fix(checker): produce judgement for expression in cast 2026-06-22 15:24:51 +02:00
f75d7722a1 fix(checker): look up members on constraint type 2026-06-22 15:24:18 +02:00
2f29c47274 fix(gen): assert type var bound 2026-06-22 15:23:53 +02:00
80af2b9048 fix(checker): handle is_subtype of TypeVar 2026-06-22 14:44:51 +02:00
577454ee7e fix(checker): make UnknownType a top type for subtyping 2026-06-22 14:15:18 +02:00
878693383e feat(cli): add watch option to stubs command 2026-06-22 14:14:05 +02:00
0b91de75a8 feat(checker): handle type vars in python functions 2026-06-22 14:13:25 +02:00
739871c101 Merge pull request 'Generic call unification' (#18) from feat/unification into main
Reviewed-on: #18
2026-06-21 11:41:48 +00:00
4395e9339b fix(checker): abort unification on conflict 2026-06-21 13:36:07 +02:00
29e601128d tests: add unification test 2026-06-21 13:19:17 +02:00
b591f5508f fix(checker): make map definition generic 2026-06-21 13:17:35 +02:00
41d0c84bbe feat(checker): add unifier
add unifier class to infer type parameters from local call context
2026-06-21 13:12:27 +02:00
cccf2f8f9f Merge pull request 'Stubs generator' (#17) from feat/stubs-gen into main
Reviewed-on: #17
2026-06-20 15:44:34 +00:00
3f48c2138f chore: add stubs command to README 2026-06-20 17:44:15 +02:00
e4ab27673d fix(gen): handle TypeVar variance in stubs generator 2026-06-20 17:34:40 +02:00
b02ecc6326 fix(gen): handle ConstraintType in stubs generator 2026-06-20 17:34:22 +02:00
9e83079910 fix(cli): add missing methods to highlighter 2026-06-20 17:23:18 +02:00
ec468dd982 feat(cli): add stubs command 2026-06-20 17:10:25 +02:00
3edc25d778 feat(gen): add base for stubs generator 2026-06-20 17:10:24 +02:00
451e54b009 fix(checker): handle calls to AliasType 2026-06-20 17:10:24 +02:00
0dc14f67aa fix(checker): allow substituting type vars in GenericType and TopType 2026-06-20 17:10:23 +02:00
ff79f25628 fix(checker): store member kind in registry 2026-06-20 17:10:23 +02:00
12782dda1e Merge pull request 'Variance inference and subtyping' (#16) from feat/variance into main
Reviewed-on: #16
2026-06-20 14:55:01 +00:00
48a20b4aa0 tests: add tests for variance inference and subtyping 2026-06-20 16:48:19 +02:00
9467187313 feat(checker): use variance in subtype check 2026-06-20 16:30:30 +02:00
cd8f14153d feat(checker): infer type variables variance 2026-06-20 13:39:32 +02:00
6eea0c02e0 Merge pull request 'Constraint types' (#15) from feat/constraint-type into main
Reviewed-on: #15
2026-06-19 20:21:04 +00:00
3205e7b961 fix(checker): change back warning to errors 2026-06-19 22:13:10 +02:00
0aba134290 tests: add predicates and constraints test 2026-06-19 22:13:10 +02:00
1f0bcab2ca fix(checker): minor tweaks 2026-06-19 22:13:09 +02:00
db8d88ef35 feat(parser): parse strings in Midas files 2026-06-19 22:13:09 +02:00
7695d50537 fix(parser): correctly parse keyword arguments 2026-06-19 22:13:08 +02:00
8461d05fa6 fix(checker): handle all operations and calls in predicates 2026-06-19 22:13:08 +02:00
43d2118db7 fix(checker): lookup predicate variables in preamble 2026-06-19 22:13:07 +02:00
6a87b5396f feat(cli): print predicate with dump-registry 2026-06-19 22:13:07 +02:00
e6a581ba6e fix(checker): typo in docstring 2026-06-19 22:13:07 +02:00
2a7aac69ed fix(checker): change some diagnostics to warnings
temporarily change type errors in predicates to warnings until operations are fully type checked
2026-06-19 22:13:06 +02:00
eb5bf19c61 feat(gen): generate type hints for functions 2026-06-19 22:13:06 +02:00
657406ea01 feat(gen): handle predicate aliases
handle cases where a predicate is defined as an alias, i.e. without any parameters
2026-06-19 22:13:05 +02:00
2974386110 fix(parser): fix call expr location span 2026-06-19 22:13:05 +02:00
92ca6b6732 feat(types): detect constraint base subtyping 2026-06-19 22:13:04 +02:00
6aacdb98b7 feat(checker): type check predicate body 2026-06-19 22:13:04 +02:00
1b100b6ceb fix(gen): remove id from named predicate function 2026-06-19 22:13:03 +02:00
6b4c7d27bc fix(tests): update generator tester 2026-06-19 22:13:03 +02:00
2523d638f7 feat(gen): generate predicate functions 2026-06-19 22:13:02 +02:00
5fc7461e29 feat(gen): generate basic constraint assertion 2026-06-19 22:13:02 +02:00
c5154bde81 feat(types): add ConstraintType 2026-06-19 22:13:02 +02:00
d07e8ac0ca refactor: ensure exhaustiveness in some match/case 2026-06-19 22:13:01 +02:00
3380995082 tests: update with new predicate AST representation 2026-06-19 22:13:01 +02:00
7efc44c496 fix(tests): correctly serialize param name 2026-06-19 22:13:00 +02:00
ca94443699 feat(midas): generalize param spec of predicate and parse 2026-06-19 22:12:59 +02:00
c513a85cf2 feat(midas): add CallExpr 2026-06-19 22:12:59 +02:00
2a106c5d07 refactor: add param spec for FunctionType 2026-06-19 22:12:58 +02:00
9672dfd588 Merge pull request 'Update README' (#14) from fix/update-readme into main
Reviewed-on: #14
2026-06-19 13:25:09 +00:00
7639ccc94d chore: update README with new commands 2026-06-19 15:23:49 +02:00
a4a2ed5d64 Merge pull request 'Dictionaries' (#13) from feat/dictionaries into main
Reviewed-on: #13
2026-06-16 18:42:12 +00:00
e5cb90aff6 fix(checker): make builtin type constructor parameter optional 2026-06-16 20:40:48 +02:00
75f8e4af53 feat(checker): type check dictionaries 2026-06-16 20:40:10 +02:00
42c2d7a098 feat(parser): add dictionary expression 2026-06-16 20:35:39 +02:00
5ce3b4abed Merge pull request 'Cast assertions and generator tests' (#12) from feat/cast-assertions into main
Reviewed-on: #12
2026-06-16 12:57:49 +00:00
2a8b7d559c tests: add simple gen test 2026-06-16 14:56:59 +02:00
da38cad23d feat(tests): add generator tester 2026-06-16 14:56:59 +02:00
591012d059 fix(checker): allow calling AppliedType and UnknownType 2026-06-16 14:56:58 +02:00
4b1087d6b9 fix(cli): improve dump-registry command output 2026-06-16 14:56:57 +02:00
732f7b0796 feat(checker): add environment preamble
this adds some builtin functions such as the builtin type constructors
2026-06-16 14:56:56 +02:00
c4062c9595 fix(checker): allow inferred return to be subtype of hint 2026-06-16 14:56:47 +02:00
c3229b557c feat(gen): add basic cast assertions on base type 2026-06-16 12:49:36 +02:00
0a8e0fb6c2 feat(checker): handle raw expr/stmt 2026-06-16 10:39:26 +02:00
61514d036c feat(parser): add raw statements and expressions 2026-06-16 10:38:09 +02:00
2e5cf6f8a2 Merge pull request 'For loops' (#11) from feat/for-loops into main
Reviewed-on: #11
2026-06-15 22:51:12 +00:00
25fabdd6c3 refactor(checker): split type computation and judgement 2026-06-16 00:48:32 +02:00
af1aba41e7 feat(gen): handle for loops 2026-06-16 00:36:43 +02:00
48e13d3348 feat(checker): handle for loops 2026-06-16 00:36:03 +02:00
faa98ce0ef feat(parser): add for loop node 2026-06-16 00:35:05 +02:00
274e366561 feat(cli): add help messages to all commands 2026-06-15 18:55:23 +02:00
119c262da4 Merge pull request 'Simple code generator and CLI redesign' (#10) from feat/code-generator into main
Reviewed-on: #10
2026-06-15 12:29:22 +00:00
81181891c4 feat(gen): output compiled file in build dir 2026-06-15 14:20:17 +02:00
59c1a0c7b6 feat(cli): refactor CLI and add some commands 2026-06-15 14:17:54 +02:00
74f51f361a feat(checker): make checker return TypedAST 2026-06-15 14:16:10 +02:00
f25341b1e7 feat: add pass statements 2026-06-15 13:28:40 +02:00
3281324caf feat(gen): add simple generator 2026-06-15 02:10:22 +02:00
5b062b46e6 Merge pull request 'Refactor, generics, methods, overloads and more' (#9) from feat/generics into main
Reviewed-on: #9
2026-06-14 22:13:14 +00:00
635bf73531 feat(checker): add slice overloads on lists 2026-06-15 00:03:41 +02:00
bd0421b5d8 fix(checker): handle generic overloads 2026-06-15 00:03:40 +02:00
37a464d2bc feat(checker): type check slice expressions 2026-06-15 00:03:40 +02:00
1eedcff5aa feat(parser): add slice expression 2026-06-15 00:03:39 +02:00
35798e5752 tests: update with new subscript and call checks
invalid function calls now return UnknownType even if the function has a return type
2026-06-15 00:03:39 +02:00
0a35563aaf feat(checker): resolve overloads with subtypes
try to find the most specific overload if multiple matches are found
2026-06-15 00:03:38 +02:00
e1da87eaa0 doc(checker): add docstrings to new call checks 2026-06-15 00:03:38 +02:00
2a579c06b1 refactor(checker): unify call check for subscript 2026-06-15 00:03:37 +02:00
46a22797b6 chore: add examples for functions and overloads 2026-06-15 00:03:37 +02:00
7598681729 feat(checker): handle overloaded function calls 2026-06-15 00:03:36 +02:00
2df0380815 fix(types): remove unused operation structures 2026-06-15 00:03:36 +02:00
178e24cd02 feat(checker): type check subscripts 2026-06-15 00:03:35 +02:00
c92b6b5c18 feat(parser): add subscript expressions 2026-06-15 00:03:35 +02:00
6577241af9 feat(checker): handle unary operations 2026-06-15 00:03:34 +02:00
1c71badf24 fix(checker): report unsupported features 2026-06-15 00:03:34 +02:00
064702fe13 tests: update with newly reported judgements 2026-06-15 00:03:33 +02:00
890e2f035a refactor(checker): replace all accept calls
make visitor accept calls more explicit with type_of(), resolve_type_expr() and process_stmt()
2026-06-15 00:03:33 +02:00
0d0115534b tests: update tests 2026-06-15 00:03:33 +02:00
221b5ca926 fix(checker): adapt comparison to lookup method 2026-06-15 00:03:32 +02:00
9a227b6d4c fix(checker): remove int.to_bytes 2026-06-15 00:03:32 +02:00
df2e609c60 fix(checker): handle members on base type 2026-06-15 00:03:31 +02:00
3ee1161680 fix: remove unused op statement 2026-06-15 00:03:31 +02:00
eb223c6cb7 fix(checker): forward parsing errors as diagnostics 2026-06-15 00:03:30 +02:00
6f5d971c66 fix(checker): gracefully handle unknown type 2026-06-15 00:03:30 +02:00
109c8eb35a fix(parser): make name required for mixed and keyword args 2026-06-15 00:03:29 +02:00
99924ee6c2 feat(parser): add mixed arguments in midas functions 2026-06-15 00:03:29 +02:00
4c9cbd9faa feat(checker): add top type (Any) 2026-06-15 00:03:28 +02:00
84a5f41e62 fix: extend example of complex types 2026-06-15 00:03:27 +02:00
6d6bb66c54 feat(checker): define members on builtin types 2026-06-15 00:03:27 +02:00
50eaafc388 feat(tests): update serializer 2026-06-15 00:03:27 +02:00
2935c71366 fix(checker): give warning on unknown variable 2026-06-15 00:03:26 +02:00
52981f12f2 fix(checker): minor fix when using base type in generic 2026-06-15 00:03:26 +02:00
2e898ab1e9 fix(checker): update binary operation lookup 2026-06-15 00:03:25 +02:00
01ff5ca8d5 fix(checker): handle nested generic members 2026-06-15 00:03:25 +02:00
b5de28e291 feat(checker): implement lookup_member method 2026-06-15 00:03:24 +02:00
179b88bfed feat(checker): add members registry 2026-06-15 00:03:24 +02:00
b3665c6462 fix(cli): update highlighter 2026-06-15 00:03:23 +02:00
42284704de feat(parser): accept props and methods in extend 2026-06-15 00:03:23 +02:00
650f60e70c feat(cli): add option to show type judgements 2026-06-15 00:03:22 +02:00
efea1b29e7 fix(cli): show diagnostics from different files 2026-06-15 00:03:22 +02:00
ae0bd75f3b fix(checker): improve error for recursive type ref 2026-06-15 00:03:22 +02:00
d9100d8300 feat(checker): adapt typers to members and extension type 2026-06-15 00:03:21 +02:00
900be47d34 feat(parser): add new ast nodes to parser 2026-06-15 00:03:21 +02:00
3d5f97a0f4 feat(parser): add extension type and rename properties 2026-06-15 00:03:20 +02:00
9fde115016 feat: add function type to midas syntax 2026-06-15 00:03:20 +02:00
f8897dd075 feat(types): add type params to extend statement 2026-06-15 00:03:19 +02:00
380753ca7a refactor(types): extract TypeParams
also rename generic type params to type args (when calling a generic)
2026-06-15 00:03:19 +02:00
4715318913 feat(types): add human-friendly string rep
add `__str__` methods on type structures to improve readability of diagnostics
2026-06-15 00:03:18 +02:00
a78aee1639 fix(resolver): define variable on assignment
if a variable is not already defined when an assignment is visited, it is then defined in the current scope
2026-06-15 00:03:17 +02:00
3581b7600b fix(checker): use reduce_types to infer return type 2026-06-15 00:03:17 +02:00
32207c3d6f refactor(checker): extract reduce_types function 2026-06-15 00:03:16 +02:00
9474a7336a feat(types): WIP add AppliedType 2026-06-15 00:03:16 +02:00
5a6a279eaf feat(checker): WIP add lists 2026-06-15 00:03:15 +02:00
c1f95edc96 feat(types): add name to generic type 2026-06-15 00:03:15 +02:00
098bbc35c5 fix: avoid circular import in builtins.py 2026-06-15 00:03:15 +02:00
314d4d344b refactor(resolver): move resolver to checker module 2026-06-15 00:03:14 +02:00
7236749bd5 refactor(checker): unify builtins definitions 2026-06-15 00:03:14 +02:00
2ff1f27614 refactor(checker): restructure around shared registry
restructure the type checker with a shared TypesRegistry used by MidasTyper and PythonTyper

this commit also relocates some methods to more appropriate places, such as is_subtype and apply_generic (now in TypesRegistry)
2026-06-15 00:03:13 +02:00
111afe4dd4 feat(checker): add reporter class 2026-06-15 00:03:13 +02:00
c4c142482a feat(resolver): handle generic application 2026-06-15 00:03:12 +02:00
f9c15abaf4 refactor(checker): move is_subtype to resolver 2026-06-15 00:03:12 +02:00
d51d24f865 refactor(checker): move unfold_type to types.py 2026-06-15 00:03:11 +02:00
1d00875a8c feat(resolver): handle generics definition 2026-06-15 00:03:11 +02:00
f89722fad8 feat(checker): add generic type structure 2026-06-15 00:03:10 +02:00
27917496c1 Merge pull request 'Subtyping' (#8) from feat/subtyping into main
Reviewed-on: #8
2026-06-14 22:01:45 +00:00
e0179bc442 feat(checker): handle assignments to attributes 2026-06-07 17:50:56 +02:00
e665d03533 fix: remove unused SetExpr 2026-06-07 17:48:31 +02:00
b8cb2b4273 feat(checker): handle attribute getter 2026-06-07 15:07:24 +02:00
d278dc5f5b tests: update tests with operation overloads 2026-06-07 14:28:36 +02:00
59e73f0fd9 fix(checker): invert property subtype check 2026-06-07 14:00:02 +02:00
3e0dc60283 fix(checker): only unfold alias on subtype 2026-06-07 13:59:27 +02:00
c24eb5125e feat(checker): resolve operation overloads with subtypes 2026-06-07 13:43:43 +02:00
25bd895dde feat(cli): improve diagnostic printing 2026-06-07 13:42:15 +02:00
bccd75317e tests: add subtyping test 2026-06-06 16:59:49 +02:00
f0e3f7574f feat(tests): add judgements to test results
add type judgements to checker test results and update all tests (including the new subtyping rules)
2026-06-06 16:58:13 +02:00
5d44081847 feat(checker): implement function subtyping
the logic for checking function subtypes is a WIP and has not been fully tested; there may be some errors and unhandled edge cases
Claude helped lay out and verify the overall steps

Co-authored-by: Claude <noreply@anthropic.com>
2026-06-06 16:53:52 +02:00
2a2bb0aec7 feat(checker): store function param position 2026-06-06 16:50:42 +02:00
67c40a3909 feat(checker): add is_subtype method 2026-06-06 16:30:04 +02:00
1c30188122 feat(checker): record type judgements 2026-06-06 16:25:33 +02:00
82a0f13242 feat(cli): add verbose flag to compile 2026-06-05 14:17:24 +02:00
288d15a9bc Merge pull request 'Usage documentation' (#7) from feat/usage-documentation into main
Reviewed-on: #7
2026-06-05 10:29:42 +00:00
504703d0f7 fix(cli): remove print in main command 2026-06-05 12:26:09 +02:00
e48895d0af docs: add usage documentation in README 2026-06-05 12:25:02 +02:00
13d32d0d27 Merge pull request 'Basic type checker' (#6) from feat/basic-type-checker into main
Reviewed-on: #6
2026-06-05 09:31:53 +00:00
19b9fdd623 Merge pull request 'Improve syntax and types' (#5) from feat/improve-syntax-and-types into feat/basic-type-checker
Reviewed-on: #5
2026-06-05 09:20:56 +00:00
ddcaebb51a fix: remove outdated syntax definition 2026-06-05 11:19:29 +02:00
f182312cd2 fix: update midas syntax definitions 2026-06-05 11:14:53 +02:00
73b21789d5 fix(tests): remove custom imports 2026-06-05 10:48:46 +02:00
5d7c724bc8 fix(cli): add types files argument 2026-06-05 10:44:20 +02:00
74b297c89c feat(checker): remove custom midas import
remove custom import statement (`midas.using`) in favor of passing type definition files as arguments to the checker
2026-06-05 10:43:52 +02:00
822a74acce refactor(checker): rename methods
improve a couple of method names, namely evaluate → type_of and evaluate_block → process_block
2026-06-03 13:03:41 +02:00
9a934fabfd tests: remove union type 2026-06-02 17:22:19 +02:00
828ec9a3fa fix!: remove union type 2026-06-02 17:19:17 +02:00
63a43d79dd chore: update examples 2026-06-02 13:07:53 +02:00
029caf4526 fix(tests): update tests with new syntax 2026-06-02 13:05:38 +02:00
1c5c418f1c fix(tests): serialize ternary expressions 2026-06-02 13:05:06 +02:00
a4139d4652 feat(checker): handle logical expressions 2026-06-02 13:03:07 +02:00
2fd2071d40 feat(parser): parse pass statement and None 2026-06-02 13:02:45 +02:00
97b1ee8ab8 feat(cli): add format command 2026-06-02 13:00:43 +02:00
dee479def5 fix(checker): wrap type definitions in AliasType 2026-06-02 13:00:03 +02:00
c8536e20d2 feat(tests): update Midas serializer 2026-06-02 12:38:58 +02:00
d70137775f feat(cli): update highlighter with new nodes 2026-06-02 12:29:39 +02:00
35ceda99aa chore: tidy 2026-06-02 11:45:49 +02:00
7f3d74ee49 feat(checker)!: resolve new types 2026-06-02 11:44:31 +02:00
b9f378de6f feat(parser)!: update Midas parser with new nodes 2026-06-02 11:42:35 +02:00
ccb17c7290 feat(parser)!: add new Midas AST nodes 2026-06-02 11:41:53 +02:00
505779310a feat: add new midas syntax example 2026-06-02 11:40:42 +02:00
bea3f399ad feat(checker): handle ternary expression 2026-06-01 15:02:12 +02:00
55060bfecd feat(parser): add ternary statement 2026-06-01 15:00:21 +02:00
dd126f2559 fix(cli): improve diagnostic message popup 2026-06-01 14:48:24 +02:00
4151f5373d fix(checker): early define fully-typed function
to handle simple recursion cases where the function has an explicit return type hint, the function must be defined before evaluating its body
2026-06-01 14:40:42 +02:00
bd31713ab4 tests(checker): add control flow test 2026-06-01 14:22:03 +02:00
f4dc57cb96 chore: add control flow example 2026-06-01 14:15:10 +02:00
261fd47494 feat(cli): update highlighter 2026-06-01 14:14:10 +02:00
1b66a8553d fix(checker): handle paths with no returns in functions 2026-06-01 14:13:48 +02:00
65164abadb feat(checker): type check if statements 2026-06-01 14:13:17 +02:00
9d45163d9c feat(checker): handle comparisons 2026-06-01 14:12:22 +02:00
ab0fa1de1a feat(parser): add if statement 2026-06-01 14:11:12 +02:00
5d4df7978b fix(cli): ignore repeated visit of types 2026-06-01 14:10:07 +02:00
86ad348b99 feat(cli): add option to highlight diagnostics 2026-06-01 11:57:57 +02:00
29f691e38a fix: update vscode syntax 2026-06-01 11:30:56 +02:00
f2c61d24e2 refactor(checker): move builtins definition to separate file 2026-06-01 00:55:54 +02:00
112ed0e816 feat(parser): desugar AugAssign statements 2026-05-31 18:54:55 +02:00
7eb1e13b70 fix(cli): add cast visitor method to highlighter 2026-05-31 18:45:25 +02:00
893e1ba190 feat(cli): dump environment after compile 2026-05-31 18:44:41 +02:00
1a1b0e8e15 docs(checker): add documentation to checker, resolvers, etc. 2026-05-31 18:42:53 +02:00
4ddde364ed doc(checker): add documentation to checker methods 2026-05-31 12:56:20 +02:00
4a3363a3d6 feat(checker): add cast expression 2026-05-29 22:04:03 +02:00
0a3216e07d feat(parser):add cast expression 2026-05-29 22:03:39 +02:00
c29c0ed3ec tests: add tests for type checker 2026-05-29 19:08:58 +02:00
fa7e56cb77 tests: add checker tester 2026-05-29 19:08:41 +02:00
13c19db818 fix(checker): stabilize call error message
display missing arguments in a stable format, similar to how native Python does
2026-05-29 19:08:13 +02:00
95b218fbed tests: add tests for python parser 2026-05-29 18:45:06 +02:00
c3722c7438 tests: add python parser tester 2026-05-29 18:44:53 +02:00
9dd547d6c1 fix(tests): handle new tests with no snapshot 2026-05-29 18:44:28 +02:00
e2d5943517 tests: move midas parser tests in subfolder 2026-05-29 18:39:25 +02:00
86e4763a12 refactor(tests): make cases dir configurable in subclass 2026-05-29 17:52:10 +02:00
89ec63cb05 refactor(tests): extract tester base class 2026-05-29 17:42:39 +02:00
e6375f1aa9 chore: tidy 2026-05-29 17:25:12 +02:00
d16e192a3a feat(checker): map and check function call arguments 2026-05-29 15:49:51 +02:00
3f61f84e5a feat(parser): parse function param defaults and sinks 2026-05-29 15:47:19 +02:00
fd5399f50a feat(checker): evaluate function definitions 2026-05-29 12:10:09 +02:00
8906ac3db8 feat(parser): add return statements 2026-05-29 11:25:11 +02:00
022aebf55b fix(parser): prevent duplicate properties in complex types 2026-05-29 10:41:54 +02:00
5dc6903425 fix(cli): enable midas ast dump 2026-05-29 10:41:31 +02:00
1b078b832c chore: add some operations in the example 2026-05-28 18:32:35 +02:00
7515716864 feat(checker): add diagnostics 2026-05-28 18:32:35 +02:00
218b0c5b78 fix(parser): add location in all AST nodes 2026-05-28 18:32:34 +02:00
928901ef9c fix(checker): get literal types from context 2026-05-28 18:16:35 +02:00
4b62c78874 feat(cli): integrate checker in compile command 2026-05-28 17:35:38 +02:00
f882eebaf5 feat(checker): add basic checker
still very basic but lays out the structure and help methods
2026-05-28 17:35:00 +02:00
a872938405 feat(checker): add midas context resolver
this is still very basic and only handle a few expressions
notably, it doesn't support generics, option types, conditions, predicates nor complex types
2026-05-28 17:33:16 +02:00
146be72fd7 chore: add simple operation and type examples 2026-05-28 17:31:12 +02:00
6de54e1da1 feat(checker): add python scope resolver
adapted from Pebble
2026-05-28 17:30:16 +02:00
c82b41a4df feat(checker): add environment manager
adapted from Pebble
2026-05-28 17:29:37 +02:00
8304760fe0 fix(parser): add function body and all_args property 2026-05-28 15:26:53 +02:00
6bf91db757 feat(checker): create basic type and operation structs 2026-05-28 15:25:48 +02:00
3f6b650a4b docs: create architecture diagram 2026-05-28 15:25:12 +02:00
ec079f32ca Merge pull request 'Python parser' (#4) from feat/python-parser into main
Reviewed-on: #4
2026-05-26 08:28:42 +00:00
6524b3591a feat(cli): highlight midas keywords 2026-05-26 10:14:23 +02:00
170101aa37 fix(parser): add call keywords attribute in gen definition 2026-05-26 10:12:59 +02:00
0b3f33d7fe feat(parser): parse python expressions 2026-05-25 23:17:52 +02:00
8a9b4f3989 feat(parser): parse assignments 2026-05-25 22:43:38 +02:00
bbd0e3ae8d feat(cli): update highlighter with new nodes 2026-05-25 22:14:44 +02:00
4d23e8840e feat(parser): adapt AST printer with new nodes 2026-05-25 22:06:18 +02:00
c64d626d1c refactor(parser): remove inheritance from NodeVisitor
remove the parent NodeVisitor class from PythonParser and implement all custom recursive methods instead
2026-05-25 21:42:04 +02:00
ecab1b74a4 feat(parser): add Python AST nodes 2026-05-25 21:39:20 +02:00
0bbdf04621 feat(parser): generate python AST classes
use the generation script to create Python AST node classes, also distinguish between Midas type annotation nodes and statements
2026-05-25 20:53:36 +02:00
939e5af4ce refactor(parser): improve AST class generator
make the generation script more flexible
2026-05-25 20:38:38 +02:00
a735113466 fix(parser): update ast gen script 2026-05-25 12:46:04 +02:00
0e0a1b26f2 feat(cli): add midas highlighter 2026-05-25 12:14:55 +02:00
e94db2181f feat(parser): add location to midas AST nodes 2026-05-25 12:14:14 +02:00
9b59058881 feat(cli): add highlight command 2026-05-22 22:16:05 +02:00
d0c54db33a feat(parser): store locations in parsed nodes 2026-05-22 22:11:44 +02:00
5aedddfabb feat(parser): parse functions in python 2026-05-22 19:32:15 +02:00
8d7c115432 feat(parser): parse type constraints in python 2026-05-22 18:46:06 +02:00
832c350b61 fix: use generic Difference type in example 2026-05-22 17:38:13 +02:00
3d599b3462 feat(cli): add option to run python parser 2026-05-22 17:37:20 +02:00
4f799caaf5 feat(parser): add pretty-printer for python AST 2026-05-22 17:36:44 +02:00
f4d2be3b1b feat(parser): add simple Python parser 2026-05-22 17:36:22 +02:00
7ce2840f03 feat(parser): add AST nodes for python 2026-05-22 17:34:04 +02:00
e2f3cabe15 feat(cli): add compile command to read python AST 2026-05-22 14:06:28 +02:00
5a112332f2 chore: complete pyproject.toml 2026-05-22 11:15:46 +02:00
eb79cf6dc3 feat(cli): add basic CLI entrypoint 2026-05-22 11:09:54 +02:00
8a9bb6ef4e feat: add pyproject.toml 2026-05-22 11:09:24 +02:00
6e0190a378 refactor: move source files in subdirectory 2026-05-22 11:05:47 +02:00
b5969e9a2b Merge pull request 'Revise syntax' (#3) from feat/revise-syntax into main
Reviewed-on: #3
2026-05-22 08:00:59 +00:00
409d9f8fa6 fix(parser): update parser docstrings 2026-05-22 09:46:24 +02:00
12d762429d fix(parser): complete EBNF and railroad diagrams 2026-05-21 15:46:40 +02:00
53929ee514 test(parser): remove pytest tests 2026-05-21 15:07:19 +02:00
2f6e137f1a tests(parser): update snapshot with new syntax 2026-05-21 15:04:32 +02:00
5224e79d9f fix(parser): update pretty printer 2026-05-21 14:45:52 +02:00
bdcb12c58a fix(parser): update AST printer 2026-05-21 14:27:38 +02:00
5cb4d587e3 feat(parser)!: adapt parser for revised syntax 2026-05-21 13:57:38 +02:00
8f9ec8d73b feat(parser): add more nodes for constraint parsing 2026-05-21 13:54:58 +02:00
c1c50a448e fix(parser): allow underscores in identifier
modify the lexer to allow underscores in an identifier, but keep scanning single underscores as a specific underscore token
2026-05-21 13:54:19 +02:00
19229db0b1 feat(parser)!: adjust AST node classes for new syntax 2026-05-21 12:25:47 +02:00
f3b6bd146f tool: add AST class generator script 2026-05-21 12:24:43 +02:00
98c3510bd4 feat(parser): update lexer with new tokens 2026-05-21 09:15:14 +02:00
429d0d98fe feat: update railroad diagrams with revised syntax 2026-05-21 07:53:56 +02:00
db8fe5d3ff feat: update EBNF with revised syntax 2026-05-21 07:53:40 +02:00
7477ec8d70 fix: change syntax definition to W3C EBNF 2026-05-20 15:47:34 +02:00
adf7f4e7a2 tests(parser): use new MidasSyntaxError 2026-05-20 15:46:25 +02:00
abf6787946 fix(parser)!: remove annotation lexer and parser 2026-05-20 15:45:55 +02:00
e282b08597 fix: tweak syntax examples
- move operation definitions outside GeoLocation type
- add nullable type
- list syntax choices for complex refinement
2026-05-20 14:14:01 +02:00
0a02b9d3d9 feat: revise syntax (example)
improve the syntax to better fit the principle of least surprise and Python syntax
2026-05-20 13:20:53 +02:00
875ca589e4 Merge pull request 'Improve testing framework' (#2) from feat/test-framework into main
Reviewed-on: #2
2026-05-20 11:17:00 +00:00
88f92d6e1f tests(parser): add simple types snapshot test 2026-05-19 14:12:12 +02:00
db4ed74365 tests(parser): add snapshot test runner
the diff printing function was suggested by Gemini

Co-authored-by: Gemini <noreply@gemini.google.com>
2026-05-19 14:11:32 +02:00
7cbf4fdece feat(tests): add AST JSON serializer 2026-05-19 14:00:32 +02:00
1fa9a09bfe feat(parser): use custom syntax error class 2026-05-19 13:57:00 +02:00
152 changed files with 29444 additions and 1914 deletions

5
.gitignore vendored

@@ -3,4 +3,7 @@ __pycache__
.env
venv
.venv
*.pyc
*.pyc
uv.lock
.python-version
/out

154
README.md

@@ -1,7 +1,159 @@
# Midas
<h1>Midas</h1>
*Midas* is a type system to _Maintain Integrity of Data with Annotated Structures_. In Greek mythology, [Midas](https://en.wikipedia.org/wiki/Midas) was a Phrygian king who was blessed with the gift of turning everything he touched into gold.
*Midas* aims at providing Python developers with a simple annotation system to enable compile-time integrity and data type checks, as well as generating runtime assertions.
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
<details>
<summary><strong>Table of Contents</strong></summary>
- [Requirements](#requirements)
- [Installation](#installation)
- [Commands](#commands)
- [Type Checking](#type-checking)
- [Compiling](#compiling)
- [Formatting](#formatting)
- [Highlighting](#highlighting)
- [Dumping the AST](#dumping-the-ast)
- [Dumping the Registry](#dumping-the-registry)
- [Generating Stubs](#generating-stubs)
- [Showing Type Judgements](#showing-type-judgements)
- [Validating Definitions](#validating-definitions)
- [Tests](#tests)
</details>
## Requirements
- Python 3.11+
- [uv](https://docs.astral.sh/uv/getting-started/installation/)
## Installation
1. Clone the repository
```shell
git clone https://git.kb28.ch/HEL/midas.git
```
2. Go into the project directory
```shell
cd midas
```
3. Install the CLI as a user-wide tool
```shell
uv tool install .
```
4. You can now run the `midas` command from anywhere
```shell
midas --help
```
## Commands
<!--
check
compile
format
highlight
parse
dump_registry
types
validate
-->
### Type Checking
```shell
midas check -t types.midas source.py
```
This command parses the given files and runs the type checkers against the Midas definitions and the Python program. Diagnostics are then printed, showing warnings and errors.
### Compiling
```shell
midas compile -t types.midas source.py
```
With the `compile` command, you can process a Python source file together with any number of custom type definition files (`-t FILE` option). The type checker verifies the coherence of your program and generates runnable code with valid syntax and runtime assertions.
### Formatting
```shell
midas format types.midas
midas format types.midas -o formatted.midas
```
This command parses the given Midas file and outputs a pretty-printed version rebuilt from the AST.
### Highlighting
```shell
midas highlight source.py
midas highlight source.py -o highlighted.html
midas highlight types.midas
midas highlight types.midas -o highlighted.html
```
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.).
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed to stdout (equivalent to `-o -`).
### Dumping the AST
```shell
midas parse source.py
midas parse types.midas
```
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `--raw` flag lets you toggle the custom AST parsing. With `--raw`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
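As a rough illustration of the kind of tree the builtin `ast` module produces, the snippet below dumps the AST of a one-line program using only the standard library. It is a sketch, not Midas's own code: the helper name `dump_raw_ast` and the sample source string are assumptions made for the example.

```python
import ast


def dump_raw_ast(source: str) -> str:
    """Parse Python source and return an indented dump of its raw AST."""
    tree = ast.parse(source)  # returns an ast.Module node
    return ast.dump(tree, indent=2)  # the `indent` argument needs Python 3.9+


if __name__ == "__main__":
    # Hypothetical sample input, chosen to include a type annotation
    print(dump_raw_ast("x: int = 1 + 2"))
```

The dump starts at a `Module` node and contains an `AnnAssign` node for the annotated assignment.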
### Dumping the Registry
```shell
midas dump-registry -t types.midas
```
This command processes the given Midas definitions and dumps the contents of the types registry.
### Generating Stubs
```shell
midas stubs types.midas -o stubs.pyi
```
This command generates Python stubs from a Midas definition file.
### Showing Type Judgements
```shell
midas types -t types.midas source.py
```
This command type checks the given Python source file and logs all typing judgements made by the type checker.
### Validating Definitions
```shell
midas validate types.midas
```
This command lets you validate a Midas definition file by running the parser and type checker, verifying syntax and references.
## Tests
Several snapshot tests are available to verify the correct behaviour of the parsers and type checker. They can be run as follows:
```shell
uv run -m tests.midas run -a
uv run -m tests.python run -a
uv run -m tests.checker run -a
uv run -m tests.generator run -a
```
**Available subcommands:**
- Run all tests: `run -a`
- Run specific tests: `run tests/cases/test1.py tests/cases/test2.py ...`
- Update all tests: `update -a`
- Update specific tests: `update tests/cases/test1.py tests/cases/test2.py ...`
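
For readers unfamiliar with snapshot testing, here is a minimal sketch of the general idea behind `run` and `update`: the output of a tool is compared against a stored file, and updating rewrites that file. This does not reproduce the Midas test runner. The function names, the `.snap` suffix and the `produce` callback (a stand-in for a parser or checker) are all assumptions made for illustration.

```python
from pathlib import Path
from typing import Callable


def snapshot_path(case: Path) -> Path:
    """Hypothetical convention: store the snapshot next to the test case."""
    return case.with_suffix(case.suffix + ".snap")


def update(case: Path, produce: Callable[[str], str]) -> None:
    """Record the current output as the new expected snapshot."""
    snapshot_path(case).write_text(produce(case.read_text()))


def run(case: Path, produce: Callable[[str], str]) -> bool:
    """Return True if the current output matches the stored snapshot."""
    snap = snapshot_path(case)
    if not snap.exists():
        return False  # in this sketch, a case without a snapshot fails
    return produce(case.read_text()) == snap.read_text()
```

With this shape, a change in behaviour shows up as a failing `run`, and an intended change is accepted by calling `update` and reviewing the resulting diff of the snapshot file.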

117
assets/icon.svg Normal file

@@ -0,0 +1,117 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!-- Created with Inkscape (http://www.inkscape.org/) -->
<svg
width="128"
height="128"
viewBox="0 0 128 128"
version="1.1"
id="svg1"
inkscape:export-filename="logo.png"
inkscape:export-xdpi="96"
inkscape:export-ydpi="96"
inkscape:version="1.4.4 (1:1.4.4+202605061436+dcaf3e7d9e)"
sodipodi:docname="logo.svg"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns:xlink="http://www.w3.org/1999/xlink"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:document-units="mm"
showgrid="true"
inkscape:zoom="1.9332778"
inkscape:cx="-8.2760999"
inkscape:cy="112.2446"
inkscape:window-width="2584"
inkscape:window-height="1028"
inkscape:window-x="0"
inkscape:window-y="24"
inkscape:window-maximized="1"
inkscape:current-layer="layer1">
<inkscape:grid
id="grid1"
units="px"
originx="0"
originy="0"
spacingx="4"
spacingy="4"
empcolor="#0099e5"
empopacity="0.30196078"
color="#0099e5"
opacity="0.14901961"
empspacing="4"
enabled="true"
visible="true" />
</sodipodi:namedview>
<defs
id="defs1">
<linearGradient
inkscape:collect="always"
xlink:href="#linearGradient4689"
id="linearGradient1478"
gradientUnits="userSpaceOnUse"
gradientTransform="matrix(0.562541,0,0,0.567972,-9.399749,-5.305317)"
x1="26.648937"
y1="20.603781"
x2="135.66525"
y2="114.39767" />
<linearGradient
id="linearGradient4689">
<stop
style="stop-color:#e1be1e;stop-opacity:1;"
offset="0"
id="stop4691" />
<stop
style="stop-color:#ffeb82;stop-opacity:1;"
offset="1"
id="stop4693" />
</linearGradient>
<linearGradient
inkscape:collect="always"
xlink:href="#linearGradient4671"
id="linearGradient1475"
gradientUnits="userSpaceOnUse"
gradientTransform="matrix(0.562541,0,0,0.567972,-9.399749,-5.305317)"
x1="150.96111"
y1="192.35176"
x2="112.03144"
y2="137.27299" />
<linearGradient
id="linearGradient4671">
<stop
style="stop-color:#ffdc21;stop-opacity:1;"
offset="0"
id="stop4673" />
<stop
style="stop-color:#ffeb82;stop-opacity:1;"
offset="1"
id="stop4675" />
</linearGradient>
</defs>
<g
inkscape:label="Calque 1"
inkscape:groupmode="layer"
id="layer1">
<g
id="g1"
transform="translate(2.911719,3.414527)">
<path
style="fill:url(#linearGradient1478);fill-opacity:1"
d="m 60.510156,6.3979729 c -4.583653,0.021298 -8.960939,0.4122177 -12.8125,1.09375 C 36.35144,9.4962267 34.291407,13.691825 34.291406,21.429223 v 10.21875 h 26.8125 v 3.40625 h -26.8125 -10.0625 c -7.792459,0 -14.6157592,4.683717 -16.7500002,13.59375 -2.46182,10.212966 -2.5710151,16.586023 0,27.25 1.9059283,7.937852 6.4575432,13.593748 14.2500002,13.59375 h 9.21875 v -12.25 c 0,-8.849902 7.657144,-16.656248 16.75,-16.65625 h 26.78125 c 7.454951,0 13.406253,-6.138164 13.40625,-13.625 v -25.53125 c 0,-7.266339 -6.12998,-12.7247775 -13.40625,-13.9375001 -4.605987,-0.7667253 -9.385097,-1.1150483 -13.96875,-1.09375 z m -14.5,8.2187501 c 2.769547,0 5.03125,2.298646 5.03125,5.125 -2e-6,2.816336 -2.261703,5.09375 -5.03125,5.09375 -2.779476,-1e-6 -5.03125,-2.277415 -5.03125,-5.09375 -1e-6,-2.826353 2.251774,-5.125 5.03125,-5.125 z"
id="path1948" />
<path
style="fill:url(#linearGradient1475);fill-opacity:1"
d="m 91.228906,35.054223 v 11.90625 c 0,9.230755 -7.825895,16.999999 -16.75,17 h -26.78125 c -7.335833,0 -13.406249,6.278483 -13.40625,13.625 v 25.531247 c 0,7.26634 6.318588,11.54032 13.40625,13.625 8.487331,2.49561 16.626237,2.94663 26.78125,0 6.750155,-1.95439 13.406253,-5.88761 13.40625,-13.625 V 92.897973 h -26.78125 v -3.40625 h 26.78125 13.406254 c 7.79246,0 10.69625,-5.435408 13.40624,-13.59375 2.79933,-8.398886 2.68022,-16.475776 0,-27.25 -1.92578,-7.757441 -5.60387,-13.59375 -13.40624,-13.59375 z m -15.0625,64.65625 c 2.779478,3e-6 5.03125,2.277417 5.03125,5.093747 -2e-6,2.82635 -2.251775,5.125 -5.03125,5.125 -2.76955,0 -5.03125,-2.29865 -5.03125,-5.125 2e-6,-2.81633 2.261697,-5.093747 5.03125,-5.093747 z"
id="path1950" />
</g>
</g>
</svg>



@@ -1,107 +0,0 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar

from lexer.token import Token

T = TypeVar("T")

@dataclass(frozen=True)
class Stmt(ABC):
    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...

    class Visitor(ABC, Generic[T]):
        @abstractmethod
        def visit_annotation_stmt(self, stmt: AnnotationStmt) -> T: ...

@dataclass(frozen=True)
class AnnotationStmt(Stmt):
    name: Token
    schema: Optional[SchemaExpr]

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        return visitor.visit_annotation_stmt(self)

@dataclass(frozen=True)
class Expr(ABC):
    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...

    class Visitor(ABC, Generic[T]):
        @abstractmethod
        def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...

        @abstractmethod
        def visit_literal_expr(self, expr: LiteralExpr) -> T: ...

        @abstractmethod
        def visit_type_expr(self, expr: TypeExpr) -> T: ...

        @abstractmethod
        def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...

        @abstractmethod
        def visit_schema_expr(self, expr: SchemaExpr) -> T: ...

        @abstractmethod
        def visit_schema_element_expr(self, expr: SchemaElementExpr) -> T: ...

@dataclass(frozen=True)
class WildcardExpr(Expr):
    token: Token

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        return visitor.visit_wildcard_expr(self)

@dataclass(frozen=True)
class LiteralExpr(Expr):
    value: Any

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        return visitor.visit_literal_expr(self)

@dataclass(frozen=True)
class TypeExpr(Expr):
    name: Token
    constraints: list[ConstraintExpr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        return visitor.visit_type_expr(self)

@dataclass(frozen=True)
class ConstraintExpr(Expr):
    left: Expr
    op: Token
    right: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        return visitor.visit_constraint_expr(self)

@dataclass(frozen=True)
class SchemaExpr(Expr):
    left: Token
    elements: list[Expr]
    right: Token

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        return visitor.visit_schema_expr(self)

@dataclass(frozen=True)
class SchemaElementExpr(Expr):
    name: Optional[Token]
    type: Optional[Expr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        return visitor.visit_schema_element_expr(self)


@@ -1,138 +0,0 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar

from lexer.token import Token

T = TypeVar("T")

# Statements
@dataclass(frozen=True)
class Stmt(ABC):
    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...

    class Visitor(ABC, Generic[T]):
        @abstractmethod
        def visit_type_stmt(self, stmt: TypeStmt) -> T: ...

        @abstractmethod
        def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...

        @abstractmethod
        def visit_op_stmt(self, stmt: OpStmt) -> T: ...

        @abstractmethod
        def visit_constraint_stmt(self, stmt: ConstraintStmt) -> T: ...

@dataclass(frozen=True)
class TypeStmt(Stmt):
    name: Token
    bases: list[TypeExpr]
    body: Optional[TypeBodyExpr]

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        return visitor.visit_type_stmt(self)

@dataclass(frozen=True)
class PropertyStmt(Stmt):
    name: Token
    type: TypeExpr

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        return visitor.visit_property_stmt(self)

@dataclass(frozen=True)
class OpStmt(Stmt):
    left: TypeExpr
    op: Token
    right: TypeExpr
    result: TypeExpr

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        return visitor.visit_op_stmt(self)

@dataclass(frozen=True)
class ConstraintStmt(Stmt):
    name: Token
    constraint: ConstraintExpr

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        return visitor.visit_constraint_stmt(self)

# Expressions
@dataclass(frozen=True)
class Expr(ABC):
    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...

    class Visitor(ABC, Generic[T]):
        @abstractmethod
        def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...

        @abstractmethod
        def visit_literal_expr(self, expr: LiteralExpr) -> T: ...

        @abstractmethod
        def visit_type_expr(self, expr: TypeExpr) -> T: ...

        @abstractmethod
        def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...

        @abstractmethod
        def visit_type_body_expr(self, expr: TypeBodyExpr) -> T: ...

@dataclass(frozen=True)
class WildcardExpr(Expr):
    token: Token

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        return visitor.visit_wildcard_expr(self)

@dataclass(frozen=True)
class LiteralExpr(Expr):
    value: Any

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        return visitor.visit_literal_expr(self)

@dataclass(frozen=True)
class TypeExpr(Expr):
    name: Token
    constraints: list[ConstraintExpr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        return visitor.visit_type_expr(self)

@dataclass(frozen=True)
class ConstraintExpr(Expr):
    left: Expr
    op: Token
    right: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        return visitor.visit_constraint_expr(self)

@dataclass(frozen=True)
class TypeBodyExpr(Expr):
    properties: list[PropertyStmt]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        return visitor.visit_type_body_expr(self)


@@ -1,360 +0,0 @@
from __future__ import annotations

from contextlib import contextmanager
from enum import Enum, auto
import io
from typing import Generator, Generic, Optional, Protocol, TypeVar

import core.ast.annotations as a
import core.ast.midas as m

class _Level(Enum):
    EMPTY = auto()
    ACTIVE = auto()
    LAST = auto()

class Expr(Protocol):
    def accept(self, printer: AstPrinter) -> None: ...

T = TypeVar("T", bound=Expr)

class AstPrinter(Generic[T]):
    LAST_CHILD = "└── "
    CHILD = "├── "
    VERTICAL = "│   "
    EMPTY = "    "

    def __init__(self):
        self._levels: list[_Level] = []
        self._idx: Optional[int] = None
        self._buf: io.StringIO = io.StringIO()

    def print(self, expr: T):
        self._buf = io.StringIO()
        expr.accept(self)
        return self._buf.getvalue()

    @contextmanager
    def _child_level(self, last: bool = False) -> Generator[None, None, None]:
        self._levels.append(_Level.LAST if last else _Level.ACTIVE)
        try:
            yield
        finally:
            self._levels.pop()

    def _mark_last(self):
        if self._levels:
            self._levels[-1] = _Level.LAST

    def _write_line(self, text: str, *, last: bool = False):
        if last:
            self._mark_last()
        indent: str = self._build_indent()
        if self._idx is not None:
            text = f"[{self._idx}] {text}"
            self._idx = None
        self._buf.write(indent + text + "\n")

    def _build_indent(self) -> str:
        parts: list[str] = []
        for level in self._levels[:-1]:
            parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
        if self._levels:
            if self._levels[-1] == _Level.LAST:
                parts.append(self.LAST_CHILD)
                self._levels[-1] = _Level.EMPTY
            else:
                parts.append(self.CHILD)
        return "".join(parts)

    def _write_optional_child(
        self, label: str, child: Optional[T], *, last: bool = False
    ):
        if last:
            self._mark_last()
        if child is None:
            self._write_line(f"{label}: None")
        else:
            self._write_line(label)
            with self._child_level(last=True):
                child.accept(self)

class AnnotationAstPrinter(AstPrinter, a.Expr.Visitor[None], a.Stmt.Visitor[None]):
    def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> None:
        self._write_line("AnnotationStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_optional_child("schema", stmt.schema, last=True)

    def visit_type_expr(self, expr: a.TypeExpr):
        self._write_line("TypeExpr")
        with self._child_level():
            self._write_line(f'name: "{expr.name.lexeme}"')
            self._write_line("constraints", last=True)
            with self._child_level():
                for i, constraint in enumerate(expr.constraints):
                    self._idx = i
                    if i == len(expr.constraints) - 1:
                        self._mark_last()
                    constraint.accept(self)

    def visit_constraint_expr(self, expr: a.ConstraintExpr) -> None:
        self._write_line("ConstraintExpr")
        with self._child_level():
            self._write_line("left")
            with self._child_level():
                self._mark_last()
                expr.left.accept(self)
            self._write_line(f"operator: {expr.op.lexeme}")
            self._write_line("right", last=True)
            with self._child_level():
                self._mark_last()
                expr.right.accept(self)

    def visit_schema_expr(self, expr: a.SchemaExpr):
        self._write_line("SchemaExpr")
        with self._child_level():
            for i, elmt in enumerate(expr.elements):
                self._idx = i
                if i == len(expr.elements) - 1:
                    self._mark_last()
                elmt.accept(self)

    def visit_schema_element_expr(self, expr: a.SchemaElementExpr):
        self._write_line("SchemaElementExpr")
        with self._child_level():
            name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"'
            self._write_line(f"name: {name_text}")
            self._write_optional_child("type", expr.type, last=True)

    def visit_wildcard_expr(self, expr: a.WildcardExpr) -> None:
        self._write_line("WildcardExpr")

    def visit_literal_expr(self, expr: a.LiteralExpr) -> None:
        self._write_line("LiteralExpr")
        with self._child_level():
            self._write_line(f"value: {expr.value}", last=True)

class AnnotationPrinter(a.Expr.Visitor[str], a.Stmt.Visitor[str]):
    def print(self, expr: a.Expr | a.Stmt):
        return expr.accept(self)

    def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> str:
        schema: str = ""
        if stmt.schema is not None:
            schema = stmt.schema.accept(self)
        return f"{stmt.name.lexeme}{schema}"

    def visit_type_expr(self, expr: a.TypeExpr) -> str:
        parts: list[str] = [expr.name.lexeme]
        for constraint in expr.constraints:
            parts.append("(" + constraint.accept(self) + ")")
        return " + ".join(parts)

    def visit_constraint_expr(self, expr: a.ConstraintExpr) -> str:
        parts: list[str] = [
            expr.left.accept(self),
            expr.op.lexeme,
            expr.right.accept(self),
        ]
        return " ".join(parts)

    def visit_schema_expr(self, expr: a.SchemaExpr) -> str:
        res: str = expr.left.lexeme
        res += ", ".join(elmt.accept(self) for elmt in expr.elements)
        res += expr.right.lexeme
        return res

    def visit_schema_element_expr(self, expr: a.SchemaElementExpr) -> str:
        parts: list[str] = []
        if expr.name is not None:
            parts.append(expr.name.lexeme)
        if expr.type is None:
            parts.append("_")
        else:
            parts.append(expr.type.accept(self))
        return ": ".join(parts)

    def visit_wildcard_expr(self, expr: a.WildcardExpr) -> str:
        return "_"

    def visit_literal_expr(self, expr: a.LiteralExpr) -> str:
        return str(expr.value)

class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
    def visit_type_stmt(self, stmt: m.TypeStmt):
        self._write_line("TypeStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_line("bases")
            with self._child_level():
                for i, base in enumerate(stmt.bases):
                    self._idx = i
                    if i == len(stmt.bases) - 1:
                        self._mark_last()
                    base.accept(self)
            self._write_optional_child("body", stmt.body, last=True)

    def visit_property_stmt(self, stmt: m.PropertyStmt):
        self._write_line("PropertyStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_line("type", last=True)
            with self._child_level():
                self._mark_last()
                stmt.type.accept(self)

    def visit_op_stmt(self, stmt: m.OpStmt) -> None:
        self._write_line("OpStmt")
        with self._child_level():
            self._write_line("left")
            with self._child_level():
                self._mark_last()
                stmt.left.accept(self)
            self._write_line(f'op: "{stmt.op.lexeme}"')
            self._write_line("right")
            with self._child_level():
                self._mark_last()
                stmt.right.accept(self)
            self._write_line("result", last=True)
            with self._child_level():
                self._mark_last()
                stmt.result.accept(self)

    def visit_constraint_stmt(self, stmt: m.ConstraintStmt):
        self._write_line("ConstraintStmt")
        with self._child_level():
            self._write_line(f'name: "{stmt.name.lexeme}"')
            self._write_line("constraint", last=True)
            with self._child_level():
                self._mark_last()
                stmt.constraint.accept(self)

    def visit_type_expr(self, expr: m.TypeExpr):
        self._write_line("TypeExpr")
        with self._child_level():
            self._write_line(f'name: "{expr.name.lexeme}"')
            self._write_line("constraints", last=True)
            with self._child_level():
                for i, constraint in enumerate(expr.constraints):
                    self._idx = i
                    if i == len(expr.constraints) - 1:
                        self._mark_last()
                    constraint.accept(self)

    def visit_constraint_expr(self, expr: m.ConstraintExpr):
        self._write_line("ConstraintExpr")
        with self._child_level():
            self._write_line("left")
            with self._child_level():
                self._mark_last()
                expr.left.accept(self)
            self._write_line(f"operator: {expr.op.lexeme}")
            self._write_line("right", last=True)
            with self._child_level():
                self._mark_last()
                expr.right.accept(self)

    def visit_type_body_expr(self, expr: m.TypeBodyExpr):
        self._write_line("TypeBodyExpr")
        with self._child_level():
            self._write_line("properties", last=True)
            with self._child_level():
                for i, property in enumerate(expr.properties):
                    self._idx = i
                    if i == len(expr.properties) - 1:
                        self._mark_last()
                    property.accept(self)

    def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
        self._write_line("WildcardExpr")

    def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
        self._write_line("LiteralExpr")
        with self._child_level():
            self._write_line(f"value: {expr.value}", last=True)

class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
    def __init__(self, indent: int = 4):
        self.indent: int = indent
        self.level: int = 0

    def indented(self, text: str) -> str:
        return " " * (self.level * self.indent) + text

    def print(self, expr: m.Expr | m.Stmt):
        self.level = 0
        return expr.accept(self)

    def visit_type_stmt(self, stmt: m.TypeStmt):
        bases: list[str] = [
            b.accept(self)
            for b in stmt.bases
        ]
        res: str = self.indented(f"type {stmt.name.lexeme}<{', '.join(bases)}>")
        if stmt.body is not None:
            res += " {\n"
            self.level += 1
            res += stmt.body.accept(self)
            self.level -= 1
            res += "\n" + self.indented("}")
        return res

    def visit_property_stmt(self, stmt: m.PropertyStmt):
        return f"{stmt.name.lexeme}: {stmt.type.accept(self)}"

    def visit_op_stmt(self, stmt: m.OpStmt):
        left: str = stmt.left.accept(self)
        op: str = stmt.op.lexeme
        right: str = stmt.right.accept(self)
        result: str = stmt.result.accept(self)
        return self.indented(f"op <{left}> {op} <{right}> = <{result}>")

    def visit_constraint_stmt(self, stmt: m.ConstraintStmt):
        name: str = stmt.name.lexeme
        constraint: str = stmt.constraint.accept(self)
        return self.indented(f"constraint {name} = {constraint}")

    def visit_type_expr(self, expr: m.TypeExpr):
        parts: list[str] = [expr.name.lexeme]
        for constraint in expr.constraints:
            parts.append("(" + constraint.accept(self) + ")")
        return " + ".join(parts)

    def visit_constraint_expr(self, expr: m.ConstraintExpr):
        parts: list[str] = [
            expr.left.accept(self),
            expr.op.lexeme,
            expr.right.accept(self),
        ]
        return " ".join(parts)

    def visit_type_body_expr(self, expr: m.TypeBodyExpr):
        properties: list[str] = [
            self.indented(prop.accept(self))
            for prop in expr.properties
        ]
        return "\n".join(properties)

    def visit_wildcard_expr(self, expr: m.WildcardExpr):
        return "_"

    def visit_literal_expr(self, expr: m.LiteralExpr):
        return str(expr.value)

150
docs/architecture.typ Normal file

@@ -0,0 +1,150 @@
#import "@preview/cetz:0.5.2": canvas, draw
#let diagram-only = false
#set document(
title: [Midas Architecture],
//author: "Louis Heredero",
)
#set text(
font: "Source Sans 3",
)
#let diagram = canvas({
let framed = draw.content.with(
padding: (x: .8em, y: 1em),
frame: "rect",
stroke: black,
)
let arrow = draw.line.with(mark: (end: ">", fill: black))
framed(
(0, 0),
name: "python-parser",
)[Python parser]
draw.content(
(rel: (0, 1), to: "python-parser.north"),
padding: 5pt,
anchor: "south",
name: "source-py",
)[_`source.py`_]
arrow("source-py", "python-parser")
framed(
(rel: (3, 0), to: "python-parser.east"),
anchor: "west",
name: "custom-parser",
align(center)[Custom python\ parser],
)
arrow("python-parser", "custom-parser", name: "arrow-python-ast")
draw.content(
"arrow-python-ast",
anchor: "south",
padding: 5pt,
)[`ast.Module`]
framed(
(rel: (-3, -2), to: "custom-parser.south"),
anchor: "east",
name: "python-resolver",
)[Python Resolver]
arrow(
"custom-parser",
((), "|-", "python-resolver.east"),
"python-resolver",
name: "arrow-python-custom-ast",
)
draw.content(
(rel: (1.5, 0), to: "arrow-python-custom-ast.end"),
padding: 5pt,
anchor: "south",
)[P-AST#footnote[#strong[P]ython *AST*]<fn-past>]
draw.content(
"python-resolver.west",
padding: 5pt,
anchor: "south-east",
)[Resolved P-AST@fn-past]
draw.circle(
(rel: (1, -2), to: "custom-parser.south-east"),
radius: .4,
name: "midas-loader",
)
arrow(
"custom-parser",
"midas-loader",
name: "arrow-load-midas",
mark: (end: (symbol: ">", fill: black), start: "o"),
)
draw.content(
"arrow-load-midas",
anchor: "west",
padding: 5pt,
)[```python midas.using("types.midas")```]
framed(
(rel: (0, -2), to: "midas-loader.south"),
name: "midas-parser",
)[Midas lexer/parser]
arrow("midas-loader", "midas-parser", name: "arrow-midas-source")
draw.content(
"arrow-midas-source",
anchor: "west",
padding: 5pt,
)[_`types.midas`_]
framed(
(rel: (-2, 0), to: "midas-parser.west"),
anchor: "east",
name: "midas-resolver",
)[Midas Resolver]
arrow("midas-parser", "midas-resolver", name: "arrow-midas-ast")
draw.content(
"arrow-midas-ast",
anchor: "south",
padding: 5pt,
)[M-AST#footnote[#strong[M]idas *AST*]<fn-mast>]
framed(
(rel: (-3, 0), to: "midas-resolver.west"),
anchor: "east",
name: "checker",
)[Checker]
arrow("midas-resolver", "checker", name: "arrow-type-ctx")
arrow(
"python-resolver",
((), "-|", "checker.north"),
"checker",
)
draw.content(
"arrow-type-ctx",
anchor: "south",
padding: 5pt,
)[Types context]
})
#show: doc => if diagram-only {
set page(width: auto, height: auto, margin: .5cm)
diagram
} else { doc }
#align(center, title())
#v(1cm)
#figure(
diagram,
caption: [Midas type-checker architecture],
)
== Components
- *Python parser*: builtin Python AST parser, extracts abstract syntax from the raw Python source (```python ast.parse(...)```)
- *Custom python parser*: converts the raw Python AST into custom, more suitable constructs, especially for type annotations
- *Python resolver*: resolves bindings and references, tracks binding scopes
- *Midas lexer/parser*: parses a Midas type definition file and extracts its AST
- *Midas resolver*: walks the AST and fills the environment with the defined types and operations
- *Checker*: evaluates expressions and checks type coherence
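As an illustration of the first stage only, the following minimal sketch uses Python's standard `ast` module to parse a source string and collect the paths passed to `midas.using(...)` calls. The helper name `find_midas_files` is hypothetical and is not taken from the Midas code base; the real custom parser may detect these calls differently.

```python
import ast


def find_midas_files(source: str) -> list[str]:
    """Return the string arguments of every `midas.using("...")` call."""
    tree = ast.parse(source)  # builtin Python parser, yields an ast.Module
    paths: list[str] = []
    for node in ast.walk(tree):
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == "using"
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "midas"
            and node.args
            and isinstance(node.args[0], ast.Constant)
            and isinstance(node.args[0].value, str)
        ):
            paths.append(node.args[0].value)
    return paths


example_source = 'import midas\nmidas.using("types.midas")\nx = 1\n'
found = find_midas_files(example_source)
```

Each collected path would then be handed to the Midas lexer/parser, which corresponds to the arrow leaving the loader circle in the diagram.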

809
docs/manual.typ Normal file

@@ -0,0 +1,809 @@
//#import "@preview/codly:1.3.0": codly, codly-init
// Fix unaligned highlights in v0.15.0 ()
// See https://github.com/Dherse/codly/pull/132
#import "@local/codly:1.3.1": codly, codly-init
#import "@preview/codly-languages:0.1.10": codly-languages
#import "template.typ": TODO, project
#import "@preview/gentle-clues:1.3.1" as gc
#let midas-version = toml("../pyproject.toml").project.version
#let head-ref = read("../.git/HEAD").split(":").at(1).trim()
#let commit-hash = read("../.git/" + head-ref).slice(0, 8)
#show: project.with(
title: [Midas User Manual],
author: "Louis Heredero",
version: midas-version,
hash: commit-hash,
icon-path: path("../assets/icon.svg"),
)
#show: codly-init
#codly(
languages: codly-languages
+ (
midas: (
name: "Midas",
color: rgb("#eedd47"),
icon: box(
image(
"../assets/icon.svg",
height: 130%,
fit: "contain",
),
),
),
),
)
= Introduction
Python is a very popular programming language, especially in data science.
However, it has been designed for simplicity, distancing itself from typed languages such as Java or C to embrace dynamic typing.
What this means is that in Python, type checks are deferred to runtime when operations are concretely executed.
For developers, it might seem like a great way of simplifying the language and making it very flexible, but it does come with a cost.
Indeed, type errors are very easy to make in Python. While passing an integer where a string is expected might not be an issue in some cases, this is the sort of mistake that can cause crashes or incorrect results without a clear diagnostic to help the user fix it.
Fortunately, developers using IDEs or properly configured text editors can benefit from external type checkers such as MyPy which will perform static type analysis of their Python code. Some can also be configured to be very strict, forcing the user to make the whole code typeable statically, thus avoiding any runtime type errors.
This is not the end of the problem though. Some parts of a program, especially in data related fields, may not be available at "compile-time". For example, a dataset can be loaded from an external file, or data can be fetched from an API, with no guarantees of having the expected format when analyzing the code statically.
In turn, that can cause a range of loud and silent errors at runtime. A malformed number will probably crash the program when trying to convert it, but a NaN in a series of values might just produce wrong results without raising any exception. Combine this with often long-running data-processing pipelines, and developers can waste hours of precious computation time.
Midas is a type system which can be used on top of Python to provide better type checking capabilities and gradual typing.
It aims at providing optional but strict type annotations and casting operations which can produce runtime assertions. It also allows the user to define dependent types with value constraints that are translated into runtime checks.
= Installation
Midas comes as a very light Python package that you can install on your system in a few simple steps.
== Requirements
The requirements for installing Midas are listed below. All Python dependencies will be installed by `uv` during the installation process described in @install-steps.
- Python 3.11+
- `uv`
== Steps <install-steps>
1. Clone the repository
```bash
git clone https://git.kb28.ch/HEL/midas.git
```
2. Navigate inside the directory
```bash
cd midas
```
3. Install Midas as a tool in your local user space
```bash
uv tool install .
```
And that's it! You can now use Midas commands anywhere, like this:
```bash
midas --help
```
= Quick Start
This chapter will give you the keys to quickly start using Midas in your project.
== Defining custom types
To begin with, you might want to define some custom types for your project, to avoid handling anonymous float values everywhere. To do so, create a `*.midas` file in your project, and write some definitions for your types. See @midas-ref for more information on syntax and features.
@qs-midas shows a simple example of what it might look like.
#codly(header: [types.midas])
#figure(
```midas
type Meter = float
extend Meter {
def __add__: fn(Meter, /) -> Meter
def __sub__: fn(Meter, /) -> Meter
}
type Coordinate = object
extend Coordinate {
prop x: Meter
prop y: Meter
}
```,
caption: [Example Midas type definitions],
) <qs-midas>
You can check for any syntax error using the following command:
```bash
midas validate types.midas
```
When you are happy with your definitions, you can generate Python stubs to use in your source code. This allows other type checkers like MyPy to recognize your custom types and avoid reporting them as undefined. It can also help catch some type errors in your IDE.
```bash
midas stubs types.midas -o stubs.pyi
```
This command will generate a file as shown in @qs-stubs, providing stub classes to represent the type lattice including methods and properties.
#codly(header: [stubs.pyi])
#figure(
```pyi
from __future__ import annotations
class Meter(float):
def __add__(self, _0: Meter, /) -> Meter: ...
def __sub__(self, _0: Meter, /) -> Meter: ...
class Coordinate(object):
x: Meter
y: Meter
```,
caption: [Generated stubs from example definitions of @qs-midas],
) <qs-stubs>
== Using Midas in Python
You can now write your Python program as you would normally. You can import your custom types from the generated stubs file and use them in type annotations.
You can also import the `cast` and `unsafe_cast` functions from `midas.typing` to explicitly cast a value to a specific type (see @cast for more information).
An example Python script is shown in @qs-python, demonstrating how you can use custom types in type annotations. Notice the comments describing errors that will be caught by the type checker in @qs-type-checking.
#codly(header: [script.py])
#figure(
```python
from lib import load_coordinate
from midas.typing import cast
from stubs import Coordinate, Meter
p1 = cast(Coordinate, load_coordinate(0))
p2 = cast(Coordinate, load_coordinate(1))
diff_x = p2.x - p1.x
diff_y = p2.y - p1.y
dist = diff_x + diff_y
p2.x += cast(Meter, 1)
p2.y = True # invalid, wrong type
p2.z = 3 # invalid, no property 'z' on Coordinate
p2.x.a = 3 # invalid, no properties on Meter
```,
caption: [Example Python script],
) <qs-python>
== Type checking <qs-type-checking>
Now that you have defined some types and written a script, you can run the type checker with the following command. You can also skip this step and directly run the compilation command in @qs-compilation.
```bash
midas check -t types.midas script.py
```
== Compiling <qs-compilation>
The final step is to compile your code. This step will produce a runnable Python script, including runtime assertions generated by `cast` expressions.
```bash
midas compile -t types.midas script.py
python3 build/midas/script.py
```
= Midas Language Reference <midas-ref>
In this chapter, you will find a complete reference for the Midas definition language.
A `*.midas` file contains a number of statements, which can be:
- *`alias`* statements (see @alias-stmt): to define a new type alias
- *`type`* statements (see @type-stmt): to define a new type
- *`extend`* statements (see @extend-stmt): to define members of a type
- *`predicate`* statements (see @predicate-stmt): to define named predicates that can be used in constraint types
== Alias Statement <alias-stmt>
An *`alias`* statement lets you define a new type alias. It requires a unique name and base type.
While a `type` statement (see @type-stmt) allows generic definitions, aliases are purely for giving an alternative name to a type.
#figure(
```midas
alias MyType = float
```,
caption: [Simple `alias` statement declaring a new type "`MyType`" equivalent to `float`],
) <midas-simple-alias>
This statement defines a new type called `MyType` which is equivalent to `float`. `MyType` and `float` can be used interchangeably.
== Type Statement <type-stmt>
A *`type`* statement lets you define a new type. It requires a unique name and base type.
The simplest form of a *`type`* statement is:
#figure(
```midas
type MyType = float
```,
caption: [Simple `type` statement declaring a new type "`MyType`" as a subtype of `float`],
) <midas-simple-type>
This statement defines a new type called `MyType` which is a subtype of `float`. `MyType` is a `float` but a `float` is not necessarily `MyType`.
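The same one-directional relation can be observed in plain Python with a subclass, which is also how the stubs shown in the Quick Start chapter represent a derived type. The class below is a hand-written stand-in used only for illustration, not output produced by Midas.

```python
class MyType(float):
    """Stand-in for a stub class derived from float."""


value = MyType(1.5)

# A MyType value is also a float...
mytype_is_float = isinstance(value, float)
# ...but a plain float is not a MyType.
float_is_mytype = isinstance(1.5, MyType)
```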
=== Builtin / base types
A number of base types are provided out of the box, which can be used to derive other types.
They correspond to Python's builtin types:
```py object```,
```py str```,
```py float```,
```py int```,
```py bool```,
```py list```,
```py dict```,
```py None```.
Some differences are to be noted however.
1. ```py bool``` is not a subtype of ```py int```
2. ```py list``` are homogeneous, i.e. all items must be of the same type
3. ```py dict``` keys and values are homogeneous, i.e. all keys must be of the same type and all values must be of the same type (can be different from keys).
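For comparison, standard Python treats `bool` as a subclass of `int`, so the expressions below run without error. Under the first difference listed above, Midas does not consider a `bool` to be an `int`; how the checker reports a specific expression such as `True + 1` is not described here.

```python
# Plain Python behaviour, shown only to contrast with Midas' rule
bool_is_int = issubclass(bool, int)  # bool derives from int in Python
total = True + 1                     # True behaves like the integer 1
```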
=== Function types
A function type is written in a similar notation to Python function definitions:
#figure(
```midas
type Repeater = fn(text: str, count: int) -> str
```,
caption: [Simple function type definition],
)
Midas supports positional-only, keyword-only and mixed arguments (using the `/` and `*` separators). You may omit the name of positional-only arguments. The return type is required.
Optional parameters can be indicated by adding a question mark (`?`) after their type:
#figure(
```midas
type Repeater = fn(text: str, count: int, *, sep: str?) -> str
```,
caption: [Function type definition with an optional keyword-only parameter],
)
#gc.warning[
Sink arguments (`*args`, `**kwargs`) are not currently supported.
]
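As a rough illustration, a Python function that would fit the second `Repeater` definition could look like the sketch below. It assumes that an optional parameter corresponds to a Python parameter with a default value; the function name `repeat` and the default `""` are made up for the example.

```python
def repeat(text: str, count: int, *, sep: str = "") -> str:
    """Repeat `text` `count` times, joined by the keyword-only `sep`."""
    return sep.join([text] * count)


plain = repeat("ab", 3)             # `sep` omitted, default is used
dashed = repeat("ab", 3, sep="-")   # `sep` must be passed by keyword
```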
=== Constraint types
A useful feature provided by Midas is the possibility to combine types with custom value constraints. For example, you might want to define a type for positive amounts of money:
#figure(
```midas
type Money = float
type Income = Money where _ >= 0
```,
caption: [Simple constraint type definition],
)
Constraints can be combined with any type using the `where` keyword, followed by a constraint expression (see @constraint-expr).
=== Generic types
For more complex types, you might want to use type parameters. For example, to define a container, we might write:
#figure(
```midas
type Container[T] = object
```,
caption: [Simple generic container type definition],
)
To better refine a generic type, you can also bound type parameters using the following syntax:
#figure(
```midas
type Container[T <: float] = object
```,
caption: [Generic container type definition with a bound],
)
This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`".\
You can use a generic type, i.e. instantiate it, by using a similar syntax with concrete types as arguments:
#figure(
```midas
type MyContainer = Container[MyType]
```,
caption: [Application of a generic type],
)
Generic types can also take multiple parameters, which are then separated by commas:
#figure(
```midas
type ZipCodeRegistry = dict[int, str]
```,
caption: [Application of a multi-parameter generic type],
)
The _body_ of a generic type, i.e. the right-hand side of the definition, can contain or even be equal to any number of its parameters.#footnote[The latter is not something that is expressible in standard Python, yet it brings a semantic distinction on top of structurally equivalent values.] For example, the following is a valid type statement:
#figure(
```midas
type Price[T <: Currency] = T where _ > 0
```,
caption: [Type parameters in a generic type's body],
)
=== `Column` / `Frame` types
To provide useful type-checking for data engineers, Midas offers two special types: `Column` and `Frame`.
Their goal is to help type check Pandas' `Series` and `DataFrame` respectively.
==== `Column`
The `Column` type is a generic type used to represent a `pandas.Series` object.
You can use it like any other generic type and it will provide type checking for some common methods and attributes offered by Pandas.
#figure(
```midas
type Temperature = float
alias Temperatures = Column[Temperature]
```,
caption: [Simple column type definition],
)
==== `Frame` <frame-type>
The `Frame` type is a super-powered generic type used to represent a `pandas.DataFrame` object.
In place of type arguments, `Frame` accepts a schema, i.e. a series of column definitions.
@simple-frame shows how you can define a simple frame type with 3 columns:
- `name`: a column of `Name` values
- `age`: a column of `int` values
- `height`: a column of `float where _ >= 0` values
Notice that you don't need to specify `Column` types.
#figure(
```midas
type Name = str where len(_) != 0
alias Data = Frame[
name: Name,
age: int,
height: float where _ >= 0
]
```,
) <simple-frame>
#pagebreak()
== Extend Statement <extend-stmt>
Type statements allow you to define new types, kind of like type aliases. However, a type might have properties or methods of its own. These might override those of the parent type or be brand new members.
This is where the `extend` statement comes into play. It allows defining members on a given type. Members can either be properties (`prop`) or methods (`def`). The only difference between the two is that methods must be functions and can be overloaded.
Here is a simple example showing how to define a property and a method on a custom type:
#figure(
```midas
type MyType = float
extend MyType {
prop norm: float
def double: fn() -> MyType
}
```,
caption: [Simple `extend` statement defining a property and a method],
)
An `extend` statement can appear anywhere after the type it extends has been defined.
You may want to override Python's dunder methods to implement type checking for some basic operators, like `__add__` for the `+` operator.
#figure(
```midas
type Money = float
extend Money {
def __add__: fn(Money, /) -> Money
def __mul__: fn(float, /) -> Money
}
```,
caption: [Simple `extend` statement overriding some dunder methods],
)
When extending a generic type, you must specify the whole type, including its parameter(s):
#figure(
```midas
type Container[T <: float] = object
extend Container[T <: float] {
prop content: T
def set_content: fn(content: T) -> None
}
```,
caption: [Generic `extend` statement using type parameters in the declared members],
)
#pagebreak()
== Predicate Statement <predicate-stmt>
A *`predicate`* statement lets you define a named constraint expression, like a function, which can then be used in other constraint expressions (either in other predicate statements or in constraint types). See @constraint-expr for more information about the syntax of constraint expressions.
The left-hand side of a predicate statement is written as a function signature, without a return type. The right-hand side is a constraint expression. For example:
#figure(
```midas
predicate is_positive(v: float) = v >= 0
```,
caption: [Simple `predicate` statement defining an `is_positive` predicate],
)
The left-hand side can also be curried to allow partial application. For example:
#figure(
```midas
predicate in_range(mn: float, mx: float)(v: float) = mn <= v & v <= mx
predicate is_ratio = in_range(0.0, 1.0)
```,
caption: [Curried `predicate` statement and partial application],
) <midas-predicate-partial>
Notice that the second predicate statement doesn't take any parameters. This is simply a partial application of another predicate, kind of like an alias. You can use it in other expressions to finalize the call:
#figure(
```midas
type Efficiency = float where is_ratio(_)
```,
caption: [Constraint type definition using the partially applied predicate from @midas-predicate-partial],
)
Of course you can also directly call `in_range`:
#figure(
```midas
type Efficiency = float where in_range(0.0, 1.0)(_)
```,
caption: [Full call of curried predicate from @midas-predicate-partial],
)
When compiled, named predicates are translated to Python functions which are used in runtime assertions. Only predicates that are referenced are compiled.
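To give an idea of what such a translation could look like, here is a hand-written Python sketch of the curried `in_range` predicate and its partial application. It is an assumption about the general shape only: the code Midas actually generates, including the name `assert_efficiency`, may be entirely different.

```python
from typing import Callable


def in_range(mn: float, mx: float) -> Callable[[float], bool]:
    """First call fixes the bounds, second call checks a value."""
    def check(v: float) -> bool:
        return mn <= v and v <= mx
    return check


# Partial application, mirroring `predicate is_ratio = in_range(0.0, 1.0)`
is_ratio = in_range(0.0, 1.0)


def assert_efficiency(value: float) -> float:
    """Hypothetical runtime check for `float where is_ratio(_)`."""
    assert isinstance(value, float), "expected a float"
    assert is_ratio(value), "value must lie between 0.0 and 1.0"
    return value


inside = is_ratio(0.5)
outside = is_ratio(1.5)
```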
#pagebreak()
== Constraint Expressions <constraint-expr>
*Constraint expressions* are Python-like expressions which can appear in *`predicate`* statements or in constraint types.
They can contain comparisons, simple computations, logical operations and must evaluate to a boolean value.
Context is quite restricted inside these expressions. You can only reference some builtin functions, such as type constructors (`float(...)`, `str(...)`, etc.), parameters of predicate statements, and named predicates. In constraint types, the special variable `_` can be used to reference the value targeted by the type. For example:
#figure(
```midas
predicate not_nan(v: float) = v != float("nan")
type RealFloat = float where not_nan(_)
```,
caption: [Example constraint expressions],
) <ex-constraint-expr>
In the predicate statement (@ex-constraint-expr:1), we reference the parameter `v` and the builtin `float` function.
In the constraint type definition (@ex-constraint-expr:2), we then reference the named predicate `not_nan`, passing the value targeted by the type itself (`_`).
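One caveat is worth checking when writing NaN predicates. In plain Python, NaN compares unequal to every value, including itself. If constraint expressions are compiled to Python comparisons with the same semantics (an assumption that is not confirmed here), a comparison against `float("nan")` would hold for every input, and a self-comparison or `math.isnan` would be needed instead. The sketch below only demonstrates the Python behaviour.

```python
import math

nan = float("nan")

# `!=` against NaN is true for ordinary numbers and for NaN itself
number_differs = 1.0 != float("nan")
nan_differs = nan != float("nan")

# Two ways to actually detect NaN in plain Python
detected_by_isnan = math.isnan(nan)
equals_itself = nan == nan
```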
= Supported Python Syntax <python-ref>
Midas integrates naturally in Python via type annotations. Through generated stubs, other type checkers can even detect your custom types (see @cmd-stubs).
It has been designed to leave the user free of typing any amount of their code but be strict about the parts that are annotated. By default, any untyped Python expression is assigned `UnknownType`.
Any operation is permitted on `UnknownType` and will result in `UnknownType` values.
The moment an expression can be typed, be it thanks to an annotation or a literal value, the type checker kicks in and will validate your statements.
Because Python is a very flexible language with many features, some expressions and statements might be more complex to properly type check, thus only a subset of the Python language is fully supported. This chapter lists all supported features of Python and how they affect type checking.
Some examples are presented in the following sections in the form of code blocks. Highlights in the code blocks indicate the type assigned to each expression by the type checker. Some types may be omitted for readability. For example:
#codly(
highlights: (
(
line: 1,
start: 5,
fill: green,
tag: [_int_],
),
(
line: 2,
start: 7,
end: 7,
fill: green,
tag: [_int_],
),
),
)
```python
v = 3
print(v)
```
== Literals
Literal Python values are type checked using builtin types. Lists and dictionaries of literals are also typed like literals. This does not include list/dict comprehensions (```py [. for . in .]```), nor formatted strings (```py f"..."```). @supported-literals shows the list of supported literal values and their type.
#let supported-literals = table(
columns: 2,
table.header[*Example value*][*Judged Type*],
```py 42```, ```py int```,
```py 3.14```, ```py float```,
```py True```, ```py bool```,
```py "Midas"```, ```py str```,
```py None```, ```py None```,
```py [1, 2, 3]```, ```py list[int]```,
```py {1: "One", 2: "Two"}```, ```py dict[int, str]```,
```py ("1", 1, True)```, ```py tuple[str, int, bool]```,
)
#figure(
supported-literals,
caption: [Supported literal values and their judged types],
) <supported-literals>
== Assignments
Variable assignments allow assigning a new value to a variable. For the type checker, this implies two things:
1. If the variable was not already declared in the current scope, it is declared at that point with the type of the right-hand side expression
2. If the variable was already declared, the type of the right-hand side expression is checked against the declared type of the variable. Only a subtype of the variable's type can be assigned to it
Once a variable has been given a type, it cannot be changed in the same scope.
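The two rules can be illustrated with the short script below. Python itself runs all three lines; the comments describe what the type checker is expected to conclude under the rules above, and are an interpretation rather than actual checker output.

```python
count = 1          # first assignment: `count` is declared with type int
count = 2          # accepted: the new value is also an int
count = "three"    # expected to be rejected: str is not a subtype of int
```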
The walrus operator (```py :=```) is not currently supported.
A simple annotation declaration, without assigning a value, is enough to declare a variable. For example:
#figure(
```python
var: float
```,
caption: [Bare Python variable annotation without assignment],
)
Because unpacking is not supported, assigning to multiple values is also not handled by the type checker.
For more information about type annotations, see @type-annotations.
== Arithmetic
- All basic binary operators are supported, through dunder methods.
- All comparison operators except ```py in``` are supported.
- All unary operators are supported (`+`, `-`, `~`).
- All logical operators are supported (```py and```, ```py or```, ```py not```).
== Ternary operator
The ternary operator ```py . if . else .``` is supported. As for `if` statements (see @if-else), the test expression must be a boolean. Additionally, both branches must be of the same type.
For example:
#codly(
highlights: (
(
line: 1,
start: 10,
end: 44,
tag: [_str_],
fill: blue,
),
(
line: 1,
start: 11,
end: 16,
tag: [_str_],
fill: green,
),
(
line: 1,
start: 39,
end: 43,
tag: [_str_],
fill: green,
),
(
line: 1,
start: 21,
end: 32,
tag: [_bool_],
fill: green,
),
),
)
#figure(
```python
parity = ("even" if num % 2 == 0 else "odd")
```,
caption: [Typing of ternary operator],
)
== Control flow
Given the limited scope of this project, only some control flow constructs are supported. The following are those currently handled and type checked by Midas.
=== `if` / `elif` / `else` <if-else>
Conditional statements are checked relatively strictly by Midas. The test expression, i.e. what comes after the ```py if``` keyword, must be a boolean. While Python allows introducing and leaking new variables from inside an ```py if``` statement, Midas will strictly forbid leaks by restraining bindings to the scope they are defined in. For example, the following Python code will not compile with Midas:
#figure(
```python
age = 22
if age >= 18:
msg = "You're an adult"
else:
msg = "You're still a child"
print(msg) # -> unknown variable 'msg'
```,
caption: [`if`/`else` statement cannot leak variables],
)
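A possible way to keep this logic while respecting the scoping rule is to declare the variable in the enclosing scope first, using a bare annotation. This sketch assumes that assignments inside the branches then resolve to the outer binding; that behaviour is inferred from the rules in this manual, not verified against the checker.

```python
age = 22
msg: str  # bare annotation: declares `msg` in the enclosing scope

if age >= 18:
    msg = "You're an adult"
else:
    msg = "You're still a child"

print(msg)  # `msg` was declared outside the `if`, so nothing leaks
```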
=== `for` loops
Simple forms of `for` loops can be used, that is using a single variable and iterating over an object implementing the `__getitem__` method. Like above in @if-else, leaking variables from inside the loop is ignored.
`for`-`else` statements are not supported. `while` loops are also not supported.
== Functions
You can define functions as usual and the type checker will do its best to type them. Apart from argument sinks (`*args`, `**kwargs`), all forms of parameter specifications are supported (positional-only, keyword-only, mixed, optional).
As for the rest of your code, type annotations are optional, but recommended. If you omit the return type hint, the type checker will try to infer it from the function body and its return statements. If you did specify a return type, all return paths must return values that are subtypes of the type hint.
#codly(
highlights: (
(
line: 2,
start: 12,
end: 16,
tag: [_float_],
fill: green,
),
(
line: 2,
start: 12,
tag: [_float_],
fill: blue,
),
(
line: 3,
start: 10,
end: 15,
tag: [_(value: float) -> float_],
fill: green,
),
(
line: 3,
start: 17,
end: 19,
tag: [_float_],
fill: green,
),
(
line: 3,
start: 10,
tag: [_float_],
fill: blue,
),
),
)
#figure(
```python
def double(value: float) -> float:
return value * 2
result = double(4.0)
```,
caption: [Typing of function's body and call],
)
Anonymous functions (```py lambda```) are not yet supported.
== Casts <cast>
#gc.info[
The functions discussed in this section are provided by the `midas.typing` submodule. You can import them in your script like so:
#figure(
```python
from midas.typing import cast, unsafe_cast
```,
caption: [Importing cast functions],
)
]
Sometimes, you may want to use a value whose type is not known to the type checker in a place where it expects a particular type. In that case, if you do know that the runtime type will correspond to what is expected, you can use a `cast` expression.
Similar to the `cast` function from the `typing` package of Python's Standard Library, it allows telling the type checker that a value has a given type. While `typing`'s function doesn't have any runtime side-effect, Midas' will generate runtime assertions, ensuring that your statement is true when running the code. What cannot be checked statically is checked at runtime.
In the following example, a runtime check would be generated to ensure that the value is indeed a `float` and that it satisfies the type's constraint (i.e. `>= 0`):
#codly(
highlights: (
(
line: 1,
start: 35,
end: 47,
tag: [_UnknownType_],
fill: red,
),
(
line: 2,
start: 7,
end: 17,
tag: [_PositiveFloat_],
fill: green,
),
),
)
#figure(
```python
typed_value = cast(PositiveFloat, unknown_value)
print(typed_value)
```,
caption: [Typing of `cast` expression],
)
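To make the idea of a generated check more concrete, the following hand-written sketch shows the kind of assertions such a cast could translate to, assuming `PositiveFloat` is defined as `float where _ >= 0`. The helper name `check_positive_float` is hypothetical and the code Midas really emits may look different.

```python
def check_positive_float(value: object) -> float:
    """Hypothetical runtime check for a cast to PositiveFloat."""
    # Base type check: the value must be a float
    assert isinstance(value, float), f"expected float, got {type(value).__name__}"
    # Constraint check: the value must satisfy `_ >= 0`
    assert value >= 0, "PositiveFloat requires a value >= 0"
    return value


typed_value = check_positive_float(2.5)
```

Note that Python skips `assert` statements when run with the `-O` flag, which matters for any check written this way.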
#gc.warning[
Assertions are statements inserted just before a statement using a `cast` expression. This means that the expression is evaluated _before_ its actual intended usage location, which might cause issues if you rely on logical operator short-circuiting. See @eager-eval for more information.
]
There may be some cases where the cost of checking a value at runtime is simply not worth the safety, for example when dealing with a big dataset. If you wish, you can use `unsafe_cast`, which will only tell the type checker the type of the value, without generating a runtime assertion. This maps to the default behavior of `typing`'s own `cast` function.
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
== Annotations / Type Hints <type-annotations>
Vanilla Python already lets you use type hints to specify the type of variables and function parameters.
Midas uses them to type check your code. Additionally, it allows you to use a special syntax to define `Frame` types directly in these annotations.
Because these annotations are not interpretable by Python, your integrated type checker might complain loudly about them being invalid.
A workaround is to silence it by adding a type comment at the end of the line, as shown in @silence-errors.
#figure(
```python
var: Frame[name: str, age: float] # type: ignore # noqa: F821
```,
caption: [MyPy's and Pylance's complaints about custom type annotation can be silenced with type comments],
) <silence-errors>
=== Frame type annotation
The syntax is similar to how you can define frame types in the Midas language (see @frame-type). The only difference is that types can only be name references; you cannot inline constraint types.
The example of @python-frame-type shows how you can annotate a dataframe with some columns directly in Python.
#figure(
```python
df: Frame[name: Name, age: float, height: Length[Meter]] = ...
```,
caption: [Frame type annotation in Python],
) <python-frame-type>
= Commands <commands>
#TODO
== Type Checking (`check`) <cmd-check>
== Compiling (`compile`) <cmd-compile>
== Formatting (`format`) <cmd-format>
== Highlighting (`highlight`) <cmd-highlight>
== Dumping the AST (`parse`) <cmd-parse>
== Dumping the Registry (`dump-registry`) <cmd-registry>
== Generating Stubs (`stubs`) <cmd-stubs>
== Showing Type Judgements (`types`) <cmd-types>
== Validating Definitions (`validate`) <cmd-validate>
= Known limitations <limitations>
== Eager evaluation in runtime assertions <eager-eval>
The process of generating assertions to ensure safety at runtime, mainly for `cast` expressions, leads to the creation of aliases for the expressions being cast. These alias definitions are eagerly evaluated before the assertion, and most importantly before the real usage location. This means that you should avoid using `cast` expressions inside logical expressions like `and` or `or`, because the operands are evaluated regardless of the normal "short-circuit" behavior.
For example:
#figure(
```py
def foo():
print("Foo")
return True
def bar():
print("Bar")
return True
result = foo() or bar()
# Foo
# Bar
```,
caption: [Runtime assertions may eagerly evaluate expressions and bypass logical operator's short-circuit],
)
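The effect can be reproduced by hand. The sketch below imagines an expression of the form `cast(bool, foo()) or cast(bool, bar())` and writes out the kind of hoisted aliases and assertions described above. The temporary names `_tmp0` and `_tmp1` are invented for the illustration and the real generated code may differ.

```python
calls: list[str] = []


def foo() -> bool:
    calls.append("foo")
    return True


def bar() -> bool:
    calls.append("bar")
    return True


# Hoisted aliases: both operands are evaluated before the `or`
_tmp0 = foo()
assert isinstance(_tmp0, bool)
_tmp1 = bar()
assert isinstance(_tmp1, bool)

# The short-circuit now only applies to already computed values
result = _tmp0 or _tmp1
```

With a plain `foo() or bar()`, Python would stop after `foo()` returns `True`, whereas here `calls` records both functions.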

211
docs/midas.sublime-syntax Normal file

@@ -0,0 +1,211 @@
%YAML 1.2
---
name: Midas
file_extensions:
- midas
scope: source.midas
variables:
identifier: "[a-zA-Z_][a-zA-Z0-9_]*"
contexts:
prototype:
- include: comments
main:
- include: keywords
- include: types
comments:
- match: "//"
scope: punctuation.definition.comment.midas
push:
- meta_scope: comment.line.midas
- match: $
pop: true
- match: /\*
scope: punctuation.definition.comment.midas
push:
- meta_scope: comment.block.midas
- match: \*/
pop: true
string:
- meta_include_prototype: false
- meta_scope: string.quoted.double.midas
- match: '"'
pop: true
keywords:
- match: \balias\b
scope: keyword.declaration.midas
push: alias-stmt
- match: \btype\b
scope: keyword.declaration.midas
push: type-stmt
- match: \bextend\b
scope: keyword.declaration.midas
push: extend-stmt
- match: \bpredicate\b
scope: keyword.declaration.midas
push: predicate-stmt
alias-stmt:
- match: "{{identifier}}"
scope: entity.name.type
- match: "="
scope: keyword.operator.equal.midas
push: type-expr
- match: $
pop: true
type-stmt:
- match: "{{identifier}}"
scope: entity.name.type
- match: \[
push: type-params
- match: "="
scope: keyword.operator.equal.midas
push: type-expr
- match: $
pop: true
type-expr:
- match: \b(fn)\s*(\()
captures:
1: keyword.other.midas
2: punctuation.section.group.begin
push: fn-params
- match: \b(where)\b
scope: keyword.other.midas
set: constraint
- match: "Frame"
scope: entity.name.type
push:
- match: \[
push: frame-schema
- match: $
pop: true
- match: "{{identifier}}"
scope: entity.name.type
- match: $
pop: 2
fn-params:
- match: "({{identifier}})(:)"
captures:
1: variable.parameter.midas
2: punctuation.separator.annotation.midas
push:
- include: type-expr
- match: \?
scope: keyword.operator.qmark.midas
- match: "(?=,)"
scope: punctuation.separator.midas
pop: true
- match: '(?=\))'
pop: true
- include: type-expr
- match: '\)'
set:
- match: "->"
scope: keyword.operator.arrow.midas
set: type-expr
constraint:
- match: $
pop: 2
- match: \d+(\.\d+)?
scope: constant.numeric.midas
- match: \b(true|false|none)\b
scope: constant.language.midas
- match: '"'
push: string
- match: (<=|>=|<|>|==|!=|&)
scope: keyword.operator
- match: _
scope: variable.language.midas
- match: '{{identifier}}(?=\s*\()'
scope: variable.function.midas
- match: "{{identifier}}"
scope: variable.other.readwrite.midas
type-params:
- match: "<:"
scope: keyword.operator.subtype.midas
- match: "[a-zA-Z][a-zA-Z_0-9]*"
scope: entity.name.type
- match: "]"
pop: true
extend-stmt:
- match: "{{identifier}}"
scope: entity.name.type
- match: \[
push: type-params
- match: \{
scope: punctuation.section.block.begin
set: extend-body
extend-body:
- include: member-stmt
- match: \}
scope: punctuation.section.block.end
pop: true
member-stmt:
- match: \b(prop|def)\b
scope: keyword.other.midas
push:
- match: "{{identifier}}"
scope: variable.other.member
- match: ":"
push: type-expr
- match: $
pop: true
predicate-stmt:
- match: "{{identifier}}"
scope: entity.name.function.midas
- match: '\('
push: predicate-params
- match: "="
scope: keyword.operator.equal.midas
set: constraint
- match: $
pop: true
predicate-params:
- match: "({{identifier}})(:)"
captures:
1: variable.parameter.midas
2: punctuation.separator.annotation.midas
push:
- include: type-expr
- match: "(?=,)"
scope: punctuation.separator.midas
pop: true
- match: '(?=\))'
pop: true
- match: '\)'
pop: true
frame-schema:
- include: frame-column
- match: \]
# scope: punctuation.section.block.end
pop: true
frame-column:
- match: "{{identifier}}"
scope: variable.other.member
- match: ":"
push: type-expr

143
docs/template.typ Normal file

@@ -0,0 +1,143 @@
#import "@preview/modpattern:0.2.0": modpattern
#let TODO = block(
width: 6em,
height: 3em,
stroke: red,
fill: modpattern(
size: (10pt, 10pt),
line(
start: (0%, 0%),
end: (100%, 100%),
stroke: gray.transparentize(60%) + 2pt,
),
),
align(
center + horizon,
text(fill: red, size: 1.5em)[*TODO*],
),
)
#let _render-header(version, hash) = {
let last-heading = query(heading.where(level: 1).before(here())).last(default: none)
let next-heading = query(heading.where(level: 1).after(here())).first(default: none)
let current-heading = if next-heading != none and next-heading.location().page() == here().page() {
next-heading
} else if last-heading != none {
last-heading
} else { none }
let chapter = if current-heading != none {
let body = current-heading.body
if current-heading.numbering != none {
let num = counter(heading).display(current-heading.numbering, at: current-heading.location())
body = [#num #body]
}
body
} else []
grid(
columns: (1fr, auto, 1fr),
align: (left, center, right),
document.title, [v#version - #hash], chapter,
)
}
#let _unshift-prefix(prefix, content) = context {
pad(left: -measure(prefix).width, prefix + content)
}
#let project(
title: none,
author: none,
version: "0.0.1",
hash: "abcdefgh",
icon-path: none,
doc,
) = {
assert(title != none, message: "Please provide a title")
set document(
title: title,
author: author,
)
set text(
font: "Source Sans 3",
)
set raw(syntaxes: path("midas.sublime-syntax"))
let front-page() = {
align(center)[
#{
set text(size: 1.5em)
std.title()
}
v#version - #hash
#if icon-path != none {
v(1cm)
image(icon-path)
}
]
pagebreak()
}
let outlines() = {
outline()
pagebreak()
outline(
title: [List of Listings],
target: figure.where(kind: raw),
)
outline(
title: [List of Tables],
target: figure.where(kind: table),
)
}
let main() = {
// Adapted from https://github.com/hei-templates/hei-synd-thesis/blob/7d2b941197babae0bf3afd4e5914754e09a64001/lib/template-thesis.typ#L242-L261
show heading.where(level: 1): it => {
pagebreak()
set text(size: 1.5em)
set block(above: 1.2em, below: 1.2em)
if it.numbering != none {
let num = numbering(it.numbering, ..counter(heading).at(it.location()))
let prefix = num + h(1em)
_unshift-prefix(prefix, it.body)
} else {
it
}
}
show heading.where(level: 2): it => {
if it.numbering != none {
let num = numbering(it.numbering, ..counter(heading).at(it.location()))
_unshift-prefix(num + h(0.8em), it.body)
} else {
it
}
}
set page(
header: context _render-header(version, hash),
footer: context if page.numbering != none {
align(center, counter(page).display(page.numbering, both: true))
},
numbering: "1 / 1",
)
show heading: set heading(numbering: "I.1.")
counter(page).update(1)
doc
}
front-page()
outlines()
main()
}

View File

@@ -2,10 +2,6 @@
# ruff: disable[F821]
from __future__ import annotations
# Prototype of custom type import to use valid Python syntax
import midas
midas.using("02_custom_types.midas")
# A data-frame using a custom type
df: Frame[
location: GeoLocation
@@ -21,7 +17,7 @@ lat + lon # Invalid operation
# Registered operations are permitted
lat1: Latitude = lat[0]
lat2: Latitude = lat[1]
lat_diff: LatitudeDiff = lat2 - lat1 # Valid operation
lat_diff: Difference[Latitude] = lat2 - lat1 # Valid operation
# In addition to the type, a column can have one or more constraints, either defined inline or in a separate file
df2: Frame[

View File

@@ -0,0 +1,73 @@
// Simple custom type derived from float
type Custom(float)
// Simple custom types with constraints
type Latitude(float) where (-90 <= _ <= 90)
type Longitude(float) where (-180 <= _ <= 180)
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float)
type Difference[T](T)
// Complex custom type, containing two values accessible through properties
type GeoLocation {
lat: Latitude
lon: Longitude
}
// Define operations on our custom type
extend GeoLocation {
// This type is compatible with the `-` operation with another GeoLocation
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
// in a Difference of GeoLocations
op __sub__(GeoLocation) -> Difference[GeoLocation]
}
// For complex generics, you need to specify how the generic parameter
// applies to each property
type Difference[GeoLocation] {
lat: Difference[Latitude]
lon: Difference[Longitude]
}
// Simple operation defined on our custom types
extend Latitude {
op __sub__(Latitude) -> Difference[Latitude]
}
extend Longitude {
op __sub__(Longitude) -> Difference[Longitude]
}
// Predefined custom predicates that can be referenced in other definitions
predicate Positive(v: float) = v >= 0
predicate StrictlyPositive(v: float) = v > 0
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
type Person {
name: str
// Property with an inline constraint
age: int? where (0 <= _ < 150)
// Property referencing a predicate
height: float where StrictlyPositive
home: GeoLocation
}
// Custom complex type derived from another complex type, with a constraint
// on a property
// Multiple proposed syntaxes, not yet defined
// Explicit, but new keyword
type EquatorialPerson refines Person where Equatorial(_.home)
// Explicit with an existing keyword, but might be confusing given existing expectations about 'is'
type EquatorialPerson is Person where Equatorial(_.home)
// Consistent and Python-friendly but can be confused with structural extension
type EquatorialPerson(Person) where Equatorial(_.home)
// Allow new properties, probably not useful
type EquatorialPerson extends Person where Equatorial(_.home)

View File

@@ -0,0 +1,15 @@
# type: ignore
# ruff: disable[F821]
from __future__ import annotations
def func(
col1: Column[float + (0 <= _ <= 1)],
col2: Column[float + (0 <= _ <= 1)],
) -> Column[float + (0 <= _ <= 2)]:
result: Column[float + (0 <= _ <= 2)] = col1 + col2
return result
def func2(a: int, /, b: float, *, c: str):
pass

View File

@@ -0,0 +1,33 @@
type Foo1 = float
type Foo2 = float where (_ > 3)
type Foo3 = int | float
type Foo4 = int where (_ > 3) | float where (_ > 3)
type Foo5 = (int | float) where (_ > 3)
type Foo6 = {
foo: float
bar: float where (_ > 3)
}
type Foo7[T] = T where (_ > 3)
type Foo8[A, B<:int] = {
a: A
b: B
}
type Complex = {
a: int
b: int
}
type Complex2 = Complex where (_.a > 3 & _.b < 5)
predicate Positive(n: int) = n >= 0
extend Foo1 {
op __add__(Foo1) -> Foo1
}
extend Foo7[T] {
op __add__(Foo7[T]) -> Foo7[T]
}
type Optional[T] = None | T

View File

@@ -0,0 +1,13 @@
a: int = 3
b: int = 4
c = a + b # -> int
c = "invalid" # -> can't assign str to int variable
d = True
e = d + d
f: float = a
f = -f

View File

@@ -0,0 +1,14 @@
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
def __add__: fn(Meter, /) -> Meter
def __sub__: fn(Meter, /) -> Meter
def __truediv__: fn(Second, /) -> MeterPerSecond
}
extend Second {
def __add__: fn(Second, /) -> Second
def __sub__: fn(Second, /) -> Second
}

View File

@@ -0,0 +1,6 @@
# type: ignore
# ruff: disable [F821]
distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7)
speed = distance / time

View File

@@ -0,0 +1,23 @@
def minimum(x: int, y: int):
if x < y:
return x
else:
return y
a = 15
b = 72
c = minimum(a, b)
def factorial(n: int) -> int:
if n <= 1:
return 1
return n * factorial(n - 1)
category = "Category 1" if a < 10 else "Category 2"
def foo() -> None:
pass

View File

@@ -0,0 +1,21 @@
type Meter = float
extend Meter {
def __add__: fn(Meter, /) -> Meter
def __sub__: fn(Meter, /) -> Meter
}
type Coordinate = object
extend Coordinate {
prop x: Meter
prop y: Meter
}
type Difference[T <: float] = T
type MeterDifference = Difference[Meter]
type CompDiff[T <: float] = {
prop d1: Difference[T]
prop d2: Difference[T]
}

View File

@@ -0,0 +1,37 @@
# type: ignore
# ruff: disable [F821]
p1: Coordinate
p2: Coordinate
diff_x = p2.x - p1.x
diff_y = p2.y - p1.y
dist = diff_x + diff_y
p2.x += cast(Meter, 1)
p2.y = True # invalid, wrong type
p2.z = 3 # invalid, no property 'z' on Coordinate
p2.x.a = 3 # invalid, no properties on Meter
foo: list[float] = []
append = foo.append
foo.append("") # invalid, must be float
foo.append(2)
append(True) # invalid, must be float
append(2)
bar: list[list[Meter]]
bar.append([p2.x])
foo2 = foo + foo
a = foo[0]
b = bar[0][1]
c = bar[0][1][2] # invalid, no method __getitem__ on Meter
c = bar[""] # invalid, wrong index type
d = foo[1:2]

View File

@@ -0,0 +1,28 @@
def incr(value: int):
return value + 1
def decr(value: int):
return value - 1
def foo(a: int, /, b: float, *, c: str):
return True
r1 = foo() # foo() missing 2 required positional arguments: 'a' and 'b'
r2 = foo(1) # foo() missing 1 required positional argument: 'b'
r3 = foo(1, 2.0) # foo() missing 1 required keyword-only argument: 'c'
r4 = foo(1, b=2.0) # foo() missing 1 required keyword-only argument: 'c'
r5 = foo(1, 2.0, "test") # foo() takes 2 positional arguments but 3 were given
r6 = foo(1, 2.0, b=3.0) # foo() got multiple values for argument 'b'
r7 = foo(
a=1
) # foo() got some positional-only arguments passed as keyword arguments: 'a'
r8 = foo(g="test") # foo() got an unexpected keyword argument 'g'
r9a = foo(1, 2.0, c="test")
r9b = foo(1, b=2.0, c="test")
r9c = foo(1, c="test", b=2.0)
r10 = foo("a", 3, c=False) # wrong argument types

View File

@@ -0,0 +1,10 @@
type T1 = object
type T2 = object
type Foo = object
type T2b = T2
extend Foo {
def bar: fn(T1, /) -> int
def bar: fn(T2, /) -> float
def bar: fn(T2b, /) -> int
}

View File

@@ -0,0 +1,18 @@
# type: ignore
# ruff: disable [F821]
foo: Foo
t1: T1
t2: T2
a = foo.bar(t1)
b = foo.bar(t2)
func = foo.bar
c = func(t1)
d = func(t2)
t2b: T2b
e = foo.bar(t2b)

View File

@@ -0,0 +1,15 @@
predicate in_range(min: float, max: float)(v: float) = min <= v & v <= max
predicate is_ratio = in_range(0, 1)
type Currency = float
type Price[T <: Currency] = T where _ >= 0
extend Price[T <: Currency] {
def __add__: fn(Price[T], /) -> Price[T]
}
type EUR = Currency
type USD = Currency
type CHF = Currency
type Discount = float where is_ratio(_)

View File

@@ -0,0 +1,35 @@
from typing import TypeVar
from demo_stubs import CHF, EUR, USD, Currency, Discount, Price
from midas.typing import cast, unsafe_cast
T = TypeVar("T", bound=Currency)
def apply_discount(amount: Price[T], discount: Discount) -> Price[T]:
return cast(Price[T], (1.0 - discount) * amount)
a1 = cast(Price[EUR], 3.2)
a2 = cast(Price[USD], 10.4)
r1 = cast(Discount, 0.2)
print(apply_discount(a1, r1))
print(apply_discount(a2, r1))
a3 = a1 + a1
a4 = a1 + a2 # cannot add euros and dollars
a3 = a2 # cannot change variable type
dyn_price = float(input("Price (CHF): "))
dyn_discount = float(input("Discount (0.0-1.0): "))
discounted = apply_discount(
cast(Price[CHF], dyn_price),
cast(Discount, dyn_discount),
)
print(f"Discounted: CHF {discounted}")
large_data = [i * 10 for i in range(100)]
prices = unsafe_cast(list[Price[EUR]], large_data)

View File

@@ -0,0 +1,14 @@
from __future__ import annotations
from typing import Generic, TypeVar
class Currency(float): ...
_T0 = TypeVar("_T0", bound=Currency, covariant=True)
class Price(Currency, Generic[_T0]):
def __add__(self, _0: Price[_T0], /) -> Price[_T0]: ...
class EUR(Currency): ...
class USD(Currency): ...
class CHF(Currency): ...
class Discount(float): ...

165
gen/gen.py Normal file

@@ -0,0 +1,165 @@
"""
Helper script to generate AST nodes for Midas and Python.
Takes in simple templates and generates full dataclasses and a visitor interface
"""
import re
from pathlib import Path
HEADER = '''"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify {defs_path} instead and run {gen_path}
"""'''
SECTION_TEMPLATE = """{banner}
@dataclass(frozen=True, kw_only=True)
class {base}(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
{visitor_methods}
{classes}"""
TEMPLATE = """{header}
from __future__ import annotations
{imports}
T = TypeVar("T")
{preamble}
{sections}
"""
VISITOR_METHOD_TEMPLATE = """
@abstractmethod
def visit_{func_name}(self, {param}: {cls}) -> T: ...
"""
CLASS_TEMPLATE = """
@dataclass(frozen=True)
class {cls}({base}):
{body}
def accept(self, visitor: {base}.Visitor[T]) -> T:
return visitor.visit_{func_name}(self)
"""
SECTION_REGEX = re.compile(
r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)(\s*\|\s*(?P<param>[^\n]*?))?\s*?\n(?P<body>.*?)\n###<$",
re.MULTILINE | re.DOTALL,
)
IMPORTS_REGEX = re.compile(
r"^###>\s*Imports\s*?\n(?P<body>.*?)\n###<$",
re.MULTILINE | re.DOTALL,
)
PREAMBLE_REGEX = re.compile(
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
re.MULTILINE | re.DOTALL,
)
def snake_case(text: str) -> str:
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
def make_visitor_method(cls: str, param: str):
method: str = VISITOR_METHOD_TEMPLATE.format(
func_name=snake_case(cls), param=param, cls=cls
)
return method.strip("\n")
def make_class(name: str, cls: str, base: str):
body: str = cls.split("\n", 1)[1]
func_name: str = snake_case(name)
cls_def: str = CLASS_TEMPLATE.format(
cls=name,
base=base,
body=body,
func_name=func_name,
)
return cls_def.strip("\n")
def make_banner(text: str) -> str:
middle: str = f"# {text} #"
rule: str = "#" * len(middle)
return "\n".join((rule, middle, rule))
def make_section(full_name: str, base: str, param: str, body: str) -> str:
print(f" Generating {full_name}")
visitor_methods: list[str] = []
classes: list[str] = []
definitions: list[str] = body.strip("\n").split("\n\n\n")
for cls in definitions:
cls = cls.strip("\n")
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
print(f" Processing {name}")
visitor_methods.append(make_visitor_method(name, param))
classes.append(make_class(name, cls, base))
return SECTION_TEMPLATE.format(
banner=make_banner(full_name),
base=base,
visitor_methods="\n\n".join(visitor_methods),
classes="\n\n\n".join(classes),
)
def generate(definitions_path: Path, out_path: Path):
print(f"Generating {out_path} from {definitions_path}")
root_dir: Path = Path(__file__).parent.parent
rel_path: Path = definitions_path.relative_to(root_dir)
src: str = definitions_path.read_text()
sections: list[str] = []
imports: str = ""
if m := IMPORTS_REGEX.search(src):
imports = m.group("body").strip("\n")
preamble: str = ""
if m := PREAMBLE_REGEX.search(src):
preamble = m.group("body")
for section_m in SECTION_REGEX.finditer(src):
full_name: str = section_m.group("name")
base: str = section_m.group("base")
param: str = section_m.group("param") or base.lower()
body: str = section_m.group("body")
sections.append(make_section(full_name, base, param, body))
result: str = TEMPLATE.format(
header=HEADER.format(
defs_path=rel_path,
gen_path=Path(__file__).relative_to(root_dir),
),
imports=imports,
preamble=preamble,
sections="\n\n\n".join(sections),
)
out_path.write_text(result)
def main():
root: Path = Path(__file__).parent.parent
defs_dir: Path = root / "gen"
ast_dir: Path = root / "midas" / "ast"
generate(defs_dir / "midas.py", ast_dir / "midas.py")
generate(defs_dir / "python.py", ast_dir / "python.py")
if __name__ == "__main__":
main()

170
gen/midas.py Normal file

@@ -0,0 +1,170 @@
# type: ignore
# ruff: disable[F821, F401]
###> Imports
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
from midas.lexer.token import Token
###<
###> Preamble
@dataclass(frozen=True, kw_only=True)
class TypeParam:
location: Location
name: Token
bound: Optional[Type]
class MemberKind(Enum):
PROPERTY = auto()
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Parameter]
mixed: list[FunctionType.Parameter]
kw: list[FunctionType.Parameter]
###<
###> Stmt | Statements
class TypeStmt:
name: Token
params: list[TypeParam]
type: Type
class AliasStmt:
name: Token
type: Type
class MemberStmt:
name: Token
type: Type
kind: MemberKind
class ExtendStmt:
name: Token
params: list[TypeParam]
members: list[MemberStmt]
class PredicateStmt:
name: Token
params: list[ParamSpec]
body: Expr
###<
###> Expr | Expressions
class LogicalExpr:
left: Expr
operator: Token
right: Expr
class BinaryExpr:
left: Expr
operator: Token
right: Expr
class UnaryExpr:
operator: Token
right: Expr
class CallExpr:
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
class GetExpr:
expr: Expr
name: Token
class VariableExpr:
name: Token
class GroupingExpr:
expr: Expr
class LiteralExpr:
value: Any
class WildcardExpr:
token: Token
###<
###> Type | Types
class NamedType:
name: Token
class GenericType:
type: Type
args: list[Type]
class ConstraintType:
type: Type
constraint: Expr
class ComplexType:
members: list[MemberStmt]
class ExtensionType:
base: Type
extension: ComplexType
class FunctionType:
params: ParamSpec
returns: Type
@dataclass(frozen=True, kw_only=True)
class Parameter:
location: Optional[Location] = None
name: Optional[Token]
type: Type
required: bool
class FrameType:
columns: list[Column]
@dataclass(frozen=True, kw_only=True)
class Column:
location: Optional[Location] = None
name: Token
type: Type
###<

192
gen/python.py Normal file

@@ -0,0 +1,192 @@
# type: ignore
# ruff: disable[F821, F401]
###> Imports
import ast
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
###<
###> Preamble
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
pos: list[Function.Parameter]
mixed: list[Function.Parameter]
kw: list[Function.Parameter]
@property
def all(self) -> list[Function.Parameter]:
return self.pos + self.mixed + self.kw
###<
###> MidasType | Type annotations | node
class BaseType:
base: str
args: tuple[MidasType, ...]
class ConstraintType:
type: MidasType
constraint: ast.expr
class FrameColumn:
name: Optional[str]
type: Optional[MidasType]
class FrameType:
columns: list[FrameColumn]
###<
###> Stmt | Statements
class ExpressionStmt:
expr: Expr
class Function:
name: str
params: ParamSpec
returns: Optional[MidasType]
body: list[Stmt]
@dataclass(frozen=True, kw_only=True)
class Parameter:
location: Optional[Location] = None
name: str
type: Optional[MidasType]
default: Optional[Expr]
class TypeAssign:
name: str
type: MidasType
class AssignStmt:
targets: list[Expr]
value: Expr
class ReturnStmt:
value: Optional[Expr]
class IfStmt:
test: Expr
body: list[Stmt]
orelse: list[Stmt]
class Pass:
pass
class ForStmt:
target: Expr
iterator: Expr
body: list[Stmt]
class RawStmt:
stmt: ast.stmt
###<
###> Expr | Expressions
class BinaryExpr:
left: Expr
operator: ast.operator
right: Expr
class CompareExpr:
left: Expr
operator: ast.cmpop
right: Expr
class UnaryExpr:
operator: ast.unaryop
right: Expr
class CallExpr:
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
class GetExpr:
object: Expr
name: str
class LiteralExpr:
value: Any
class VariableExpr:
name: str
class LogicalExpr:
left: Expr
operator: ast.boolop
right: Expr
class CastExpr:
type: MidasType
expr: Expr
unsafe: bool
class TernaryExpr:
test: Expr
if_true: Expr
if_false: Expr
class ListExpr:
items: list[Expr]
class DictExpr:
keys: list[Optional[Expr]]
values: list[Expr]
class SubscriptExpr:
object: Expr
index: Expr
class SliceExpr:
lower: Optional[Expr]
upper: Optional[Expr]
step: Optional[Expr]
class TupleExpr:
items: tuple[Expr, ...]
class RawExpr:
expr: ast.expr
###<

View File

@@ -1,102 +0,0 @@
from lexer.base import Lexer
from lexer.keyword import ANNOTATION_KEYWORDS
from lexer.token import TokenType
class AnnotationLexer(Lexer):
def scan_token(self) -> None:
char: str = self.advance()
match char:
case "(":
self.add_token(TokenType.LEFT_PAREN)
case ")":
self.add_token(TokenType.RIGHT_PAREN)
case "[":
self.add_token(TokenType.LEFT_BRACKET)
case "]":
self.add_token(TokenType.RIGHT_BRACKET)
case "<":
self.add_token(
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
)
case ">":
self.add_token(
TokenType.GREATER_EQUAL if self.match("=") else TokenType.GREATER
)
case "=":
self.add_token(
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
)
case "!":
if self.match("="):
self.add_token(TokenType.BANG_EQUAL)
else:
self.error("Unexpected single bang. Did you mean '!=' ?")
case ":":
self.add_token(TokenType.COLON)
case ",":
self.add_token(TokenType.COMMA)
case "_":
self.add_token(TokenType.UNDERSCORE)
case "+":
self.add_token(TokenType.PLUS)
case "#":
self.scan_comment()
case "\n":
self.add_token(TokenType.NEWLINE)
case " " | "\r" | "\t":
# Consume all whitespace characters until EOL or EOF
while (
self.peek().isspace()
and self.peek() != "\n"
and not self.is_at_end()
):
self.advance()
self.add_token(TokenType.WHITESPACE)
case _:
if char.isdigit():
self.scan_number()
elif char.isalpha():
self.scan_identifier()
else:
self.error("Unexpected character")
return None
def scan_number(self):
"""Scan the rest of number and add it as a token
This method handles both simple integers and floats. Scientific notation
and base prefixes (0x, 0b, 0o) are not supported
"""
while self.peek().isdigit():
self.advance()
if self.peek() == "." and self.peek_next().isdigit():
self.advance()
while self.peek().isdigit():
self.advance()
value: float = float(self.source[self.start : self.idx])
self.add_token(TokenType.NUMBER, value)
def scan_identifier(self):
"""Scan the rest of an identifier and add it as a token
An identifier starts with a letter, followed by any number of
alphanumerical characters or underscores
"""
while self.peek().isalnum() or self.peek() == "_":
self.advance()
lexeme: str = self.source[self.start : self.idx]
token_type: TokenType = ANNOTATION_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
self.add_token(token_type)
def scan_comment(self):
"""Scan the rest of a comment and add it as a token
A comment starts with a `#` character and ends at the EOL/EOF
"""
while self.peek() != "\n" and not self.is_at_end():
self.advance()
self.add_token(TokenType.COMMENT)

View File

@@ -1,16 +0,0 @@
from lexer.token import TokenType
ANNOTATION_KEYWORDS: dict[str, TokenType] = {
"True": TokenType.TRUE,
"False": TokenType.FALSE,
"None": TokenType.NONE,
}
MIDAS_KEYWORDS: dict[str, TokenType] = {
"type": TokenType.TYPE,
"op": TokenType.OP,
"constraint": TokenType.CONSTRAINT,
"true": TokenType.TRUE,
"false": TokenType.FALSE,
"none": TokenType.NONE,
}

View File

@@ -1,59 +0,0 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any
from lexer.position import Position
class TokenType(Enum):
# Punctuation
LEFT_PAREN = auto()
RIGHT_PAREN = auto()
LEFT_BRACKET = auto()
RIGHT_BRACKET = auto()
LEFT_BRACE = auto()
RIGHT_BRACE = auto()
COLON = auto()
COMMA = auto()
UNDERSCORE = auto()
# Operators
PLUS = auto()
MINUS = auto()
STAR = auto()
SLASH = auto()
GREATER = auto()
GREATER_EQUAL = auto()
LESS = auto()
LESS_EQUAL = auto()
EQUAL = auto()
EQUAL_EQUAL = auto()
BANG_EQUAL = auto()
# Literals
IDENTIFIER = auto()
NUMBER = auto()
TRUE = auto()
FALSE = auto()
NONE = auto()
# Keywords
TYPE = auto()
OP = auto()
CONSTRAINT = auto()
# Misc
COMMENT = auto()
WHITESPACE = auto()
EOF = auto()
NEWLINE = auto()
@dataclass(frozen=True)
class Token:
"""A scanned token"""
type: TokenType
lexeme: str
value: Any
position: Position

49
midas/ast/location.py Normal file

@@ -0,0 +1,49 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Protocol
class HasLocation(Protocol):
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@dataclass(frozen=True, kw_only=True)
class Location:
"""Information about the location of an AST node"""
lineno: int
col_offset: int
end_lineno: Optional[int]
end_col_offset: Optional[int]
@staticmethod
def from_ast(obj: HasLocation) -> Location:
return Location(
lineno=obj.lineno,
col_offset=obj.col_offset,
end_lineno=obj.end_lineno,
end_col_offset=obj.end_col_offset,
)
@staticmethod
def span(start: Location, end: Location) -> Location:
"""Create a new location spanning from one location to another
Args:
start (Location): the starting location
end (Location): the end location
Returns:
Location: a new location spanning from the start of `start`
to the end of `end`
"""
return Location(
lineno=start.lineno,
col_offset=start.col_offset,
end_lineno=end.end_lineno,
end_col_offset=end.end_col_offset,
)

342
midas/ast/midas.py Normal file

@@ -0,0 +1,342 @@
"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/midas.py instead and run gen/gen.py
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
from midas.lexer.token import Token
T = TypeVar("T")
@dataclass(frozen=True, kw_only=True)
class TypeParam:
location: Location
name: Token
bound: Optional[Type]
class MemberKind(Enum):
PROPERTY = auto()
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Parameter]
mixed: list[FunctionType.Parameter]
kw: list[FunctionType.Parameter]
##############
# Statements #
##############
@dataclass(frozen=True, kw_only=True)
class Stmt(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
@abstractmethod
def visit_alias_stmt(self, stmt: AliasStmt) -> T: ...
@abstractmethod
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
@abstractmethod
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
@abstractmethod
def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ...
@dataclass(frozen=True)
class TypeStmt(Stmt):
name: Token
params: list[TypeParam]
type: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_type_stmt(self)
@dataclass(frozen=True)
class AliasStmt(Stmt):
name: Token
type: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_alias_stmt(self)
@dataclass(frozen=True)
class MemberStmt(Stmt):
name: Token
type: Type
kind: MemberKind
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_member_stmt(self)
@dataclass(frozen=True)
class ExtendStmt(Stmt):
name: Token
params: list[TypeParam]
members: list[MemberStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_extend_stmt(self)
@dataclass(frozen=True)
class PredicateStmt(Stmt):
name: Token
params: list[ParamSpec]
body: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_predicate_stmt(self)
###############
# Expressions #
###############
@dataclass(frozen=True, kw_only=True)
class Expr(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@abstractmethod
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
@abstractmethod
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
@abstractmethod
def visit_call_expr(self, expr: CallExpr) -> T: ...
@abstractmethod
def visit_get_expr(self, expr: GetExpr) -> T: ...
@abstractmethod
def visit_variable_expr(self, expr: VariableExpr) -> T: ...
@abstractmethod
def visit_grouping_expr(self, expr: GroupingExpr) -> T: ...
@abstractmethod
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
@abstractmethod
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
@dataclass(frozen=True)
class LogicalExpr(Expr):
left: Expr
operator: Token
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_logical_expr(self)
@dataclass(frozen=True)
class BinaryExpr(Expr):
left: Expr
operator: Token
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_binary_expr(self)
@dataclass(frozen=True)
class UnaryExpr(Expr):
operator: Token
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_unary_expr(self)
@dataclass(frozen=True)
class CallExpr(Expr):
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_call_expr(self)
@dataclass(frozen=True)
class GetExpr(Expr):
expr: Expr
name: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_get_expr(self)
@dataclass(frozen=True)
class VariableExpr(Expr):
name: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_variable_expr(self)
@dataclass(frozen=True)
class GroupingExpr(Expr):
expr: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_grouping_expr(self)
@dataclass(frozen=True)
class LiteralExpr(Expr):
value: Any
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_literal_expr(self)
@dataclass(frozen=True)
class WildcardExpr(Expr):
token: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_wildcard_expr(self)
#########
# Types #
#########
@dataclass(frozen=True, kw_only=True)
class Type(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_named_type(self, type: NamedType) -> T: ...
@abstractmethod
def visit_generic_type(self, type: GenericType) -> T: ...
@abstractmethod
def visit_constraint_type(self, type: ConstraintType) -> T: ...
@abstractmethod
def visit_complex_type(self, type: ComplexType) -> T: ...
@abstractmethod
def visit_extension_type(self, type: ExtensionType) -> T: ...
@abstractmethod
def visit_function_type(self, type: FunctionType) -> T: ...
@abstractmethod
def visit_frame_type(self, type: FrameType) -> T: ...
@dataclass(frozen=True)
class NamedType(Type):
name: Token
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_named_type(self)
@dataclass(frozen=True)
class GenericType(Type):
type: Type
args: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_generic_type(self)
@dataclass(frozen=True)
class ConstraintType(Type):
type: Type
constraint: Expr
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_constraint_type(self)
@dataclass(frozen=True)
class ComplexType(Type):
members: list[MemberStmt]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_complex_type(self)
@dataclass(frozen=True)
class ExtensionType(Type):
base: Type
extension: ComplexType
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_extension_type(self)
@dataclass(frozen=True)
class FunctionType(Type):
params: ParamSpec
returns: Type
@dataclass(frozen=True, kw_only=True)
class Parameter:
location: Optional[Location] = None
name: Optional[Token]
type: Type
required: bool
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_function_type(self)
@dataclass(frozen=True)
class FrameType(Type):
columns: list[Column]
@dataclass(frozen=True, kw_only=True)
class Column:
location: Optional[Location] = None
name: Token
type: Type
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_frame_type(self)


@@ -0,0 +1,3 @@
from .midas import MidasPrinter as MidasPrinter
from .midas_ast import MidasAstPrinter as MidasAstPrinter
from .python_ast import PythonAstPrinter as PythonAstPrinter

103
midas/ast/printer/base.py Normal file

@@ -0,0 +1,103 @@
from __future__ import annotations

import io
from contextlib import contextmanager
from enum import Enum, auto
from typing import Callable, Generator, Generic, Optional, Protocol, Sequence, TypeVar


class _Level(Enum):
    EMPTY = auto()  # last child already written at this depth: pad with spaces
    ACTIVE = auto()  # more siblings may follow: draw a vertical bar
    LAST = auto()  # the next line written at this depth is the last child


class Expr(Protocol):
    def accept(self, printer: AstPrinter) -> None: ...


T = TypeVar("T", bound=Expr)


class AstPrinter(Generic[T]):
    # All four prefixes are four columns wide so nested levels line up.
    LAST_CHILD = "└── "
    CHILD = "├── "
    VERTICAL = "│   "
    EMPTY = "    "

    def __init__(self):
        self._levels: list[_Level] = []
        self._idx: Optional[int] = None
        self._buf: io.StringIO = io.StringIO()

    def print(self, expr: T) -> str:
        self._buf = io.StringIO()
        expr.accept(self)
        return self._buf.getvalue()

    @contextmanager
    def _child_level(self, single: bool = False) -> Generator[None, None, None]:
        self._levels.append(_Level.LAST if single else _Level.ACTIVE)
        try:
            yield
        finally:
            self._levels.pop()

    def _mark_last(self):
        if self._levels:
            self._levels[-1] = _Level.LAST

    def _write_line(self, text: str, *, last: bool = False):
        if last:
            self._mark_last()
        indent: str = self._build_indent()
        if self._idx is not None:
            text = f"[{self._idx}] {text}"
            self._idx = None
        self._buf.write(indent + text + "\n")

    def _build_indent(self) -> str:
        parts: list[str] = []
        for level in self._levels[:-1]:
            parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
        if self._levels:
            if self._levels[-1] == _Level.LAST:
                parts.append(self.LAST_CHILD)
                self._levels[-1] = _Level.EMPTY
            else:
                parts.append(self.CHILD)
        return "".join(parts)

    def _write_optional_child(
        self, label: str, child: Optional[T], *, last: bool = False
    ):
        if last:
            self._mark_last()
        if child is None:
            self._write_line(f"{label}: None")
        else:
            self._write_line(label)
            with self._child_level(single=True):
                child.accept(self)

    def _write_sequence(
        self,
        label: str,
        list_: Sequence[T],
        *,
        last: bool = False,
        print_func: Optional[Callable[[T], None]] = None,
    ):
        if last:
            self._mark_last()
        self._write_line(label)
        with self._child_level():
            for i, item in enumerate(list_):
                self._idx = i
                if i == len(list_) - 1:
                    self._mark_last()
                if print_func is not None:
                    print_func(item)
                else:
                    item.accept(self)
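
The level stack in `AstPrinter` can be tried in isolation. The sketch below is a reduced, standalone version of the same idea and not the midas class itself: `TreeWriter`, `Level` and `render_demo` are hypothetical names, and the four-column `VERTICAL` and `EMPTY` prefixes are an assumption chosen to match the width of `CHILD` and `LAST_CHILD`.

```python
import io
from contextlib import contextmanager
from enum import Enum, auto


class Level(Enum):
    EMPTY = auto()   # last child already written: pad with spaces
    ACTIVE = auto()  # more siblings may follow: draw a vertical bar
    LAST = auto()    # next line written at this depth is the last child


class TreeWriter:
    LAST_CHILD = "└── "
    CHILD = "├── "
    VERTICAL = "│   "  # assumed width, see lead-in
    EMPTY = "    "     # assumed width, see lead-in

    def __init__(self):
        self.levels = []
        self.buf = io.StringIO()

    @contextmanager
    def child_level(self, single=False):
        # A single child is by definition also the last one.
        self.levels.append(Level.LAST if single else Level.ACTIVE)
        try:
            yield
        finally:
            self.levels.pop()

    def write_line(self, text, *, last=False):
        if last and self.levels:
            self.levels[-1] = Level.LAST
        # Ancestors contribute a bar or padding; the innermost level the branch.
        parts = [
            self.EMPTY if level == Level.EMPTY else self.VERTICAL
            for level in self.levels[:-1]
        ]
        if self.levels:
            if self.levels[-1] == Level.LAST:
                parts.append(self.LAST_CHILD)
                self.levels[-1] = Level.EMPTY  # nothing more to draw below
            else:
                parts.append(self.CHILD)
        self.buf.write("".join(parts) + text + "\n")


def render_demo():
    w = TreeWriter()
    w.write_line("BinaryExpr")
    with w.child_level():
        w.write_line("left")
        with w.child_level(single=True):
            w.write_line("LiteralExpr")
        w.write_line("operator: +")
        w.write_line("right", last=True)
        with w.child_level(single=True):
            w.write_line("VariableExpr")
    return w.buf.getvalue()


print(render_demo())
# BinaryExpr
# ├── left
# │   └── LiteralExpr
# ├── operator: +
# └── right
#     └── VariableExpr
```

Flipping a level from `LAST` to `EMPTY` once its last line is written is what lets deeper lines under `right` be padded with spaces instead of a dangling bar.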

183
midas/ast/printer/midas.py Normal file

@@ -0,0 +1,183 @@
import midas.ast.midas as m
class MidasPrinter(
m.Expr.Visitor[str],
m.Stmt.Visitor[str],
m.Type.Visitor[str],
):
def __init__(self, indent: int = 4):
self.indent: int = indent
self.level: int = 0
def indented(self, text: str) -> str:
return " " * (self.level * self.indent) + text
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
self.level = 0
return expr.accept(self)
# Statements
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
template: str = ""
if len(stmt.params) != 0:
params: list[str] = [self._print_type_param(param) for param in stmt.params]
template = f"[{', '.join(params)}]"
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")
def _print_type_param(self, param: m.TypeParam) -> str:
res: str = param.name.lexeme
if param.bound is not None:
res += "<:" + param.bound.accept(self)
return res
def visit_member_stmt(self, stmt: m.MemberStmt):
keyword: str = {
m.MemberKind.PROPERTY: "prop",
m.MemberKind.METHOD: "def",
}.get(stmt.kind, "")
res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
return self.indented(res)
def visit_extend_stmt(self, stmt: m.ExtendStmt):
template: str = ""
if len(stmt.params) != 0:
params: list[str] = [self._print_type_param(param) for param in stmt.params]
template = f"[{', '.join(params)}]"
res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
res += " {\n"
self.level += 1
for member in stmt.members:
res += member.accept(self) + "\n"
self.level -= 1
res += self.indented("}")
return res
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
body: str = stmt.body.accept(self)
return self.indented(f"predicate {name}{sig} = {body}")
# Expressions
def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self)
operator: str = expr.operator.lexeme
right: str = expr.right.accept(self)
return f"{left} {operator} {right}"
def visit_binary_expr(self, expr: m.BinaryExpr):
left: str = expr.left.accept(self)
operator: str = expr.operator.lexeme
right: str = expr.right.accept(self)
return f"{left} {operator} {right}"
def visit_unary_expr(self, expr: m.UnaryExpr):
operator: str = expr.operator.lexeme
right: str = expr.right.accept(self)
return f"{operator}{right}"
def visit_call_expr(self, expr: m.CallExpr) -> str:
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
]
return f"{expr.callee.accept(self)}({', '.join(args)})"
def visit_get_expr(self, expr: m.GetExpr):
expr_: str = expr.expr.accept(self)
name: str = expr.name.lexeme
return f"{expr_}.{name}"
def visit_variable_expr(self, expr: m.VariableExpr):
return expr.name.lexeme
def visit_grouping_expr(self, expr: m.GroupingExpr):
expr_: str = expr.expr.accept(self)
return f"({expr_})"
def visit_literal_expr(self, expr: m.LiteralExpr):
return str(expr.value)
def visit_wildcard_expr(self, expr: m.WildcardExpr):
return "_"
# Types
def visit_named_type(self, type: m.NamedType) -> str:
return type.name.lexeme
def visit_generic_type(self, type: m.GenericType) -> str:
res: str = type.type.accept(self)
if len(type.args) != 0:
args: list[str] = [param.accept(self) for param in type.args]
res += f"[{', '.join(args)}]"
return res
def visit_constraint_type(self, type: m.ConstraintType) -> str:
res: str = type.type.accept(self)
res += " where " + type.constraint.accept(self)
return res
def visit_complex_type(self, type: m.ComplexType) -> str:
res: str = "{\n"
self.level += 1
for member in type.members:
res += member.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
return res
def visit_extension_type(self, type: m.ExtensionType) -> str:
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
def visit_function_type(self, type: m.FunctionType) -> str:
spec: str = self._visit_param_spec(type.params)
return f"fn {spec} -> {type.returns.accept(self)}"
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
pos: list[str] = [self._print_param(param) for param in spec.pos]
mixed: list[str] = [self._print_param(param) for param in spec.mixed]
kw: list[str] = [self._print_param(param) for param in spec.kw]
params: list[str] = pos
if len(pos) != 0:
params.append("/")
params += mixed
if len(kw) != 0:
params.append("*")
params += kw
return f"({', '.join(params)})"
def _print_param(self, param: m.FunctionType.Parameter) -> str:
res: str = ""
if param.name is not None:
res += param.name.lexeme
res += ": "
res += param.type.accept(self)
if not param.required:
res += "?"
return res
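    def visit_frame_type(self, type: m.FrameType) -> str:
        # The opening "Frame[" continues the current line, so it is not
        # indented; only the closing bracket starts a new line and needs it
        # (same convention as visit_complex_type).
        res: str = "Frame["
        if len(type.columns) != 0:
            res += "\n"
            self.level += 1
            columns: list[str] = []
            for column in type.columns:
                columns.append(self.indented(self._print_frame_column(column)))
            res += ",\n".join(columns)
            self.level -= 1
            res += "\n" + self.indented("]")
        else:
            res += "]"
        return res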
def _print_frame_column(self, column: m.FrameType.Column) -> str:
return f"{column.name.lexeme}: {column.type.accept(self)}"
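
The `/` and `*` markers emitted by `_visit_param_spec` mirror Python's own signature syntax: `/` closes the positional-only group and `*` opens the keyword-only group. A small standalone sketch of that assembly step, working on already-rendered parameter strings, is shown below. `render_param_spec` is a hypothetical helper and not part of the midas package.

```python
def render_param_spec(pos: list[str], mixed: list[str], kw: list[str]) -> str:
    """Join three parameter groups into one signature string."""
    params = list(pos)  # copy so the caller's list is not mutated
    if pos:
        params.append("/")  # everything before "/" is positional-only
    params += mixed
    if kw:
        params.append("*")  # everything after "*" is keyword-only
    params += kw
    return f"({', '.join(params)})"


print(render_param_spec(["a: int"], ["b: str"], ["c: bool?"]))
# (a: int, /, b: str, *, c: bool?)
print(render_param_spec([], ["x: float"], []))
# (x: float)
```

The markers are only emitted when the corresponding group is non-empty, so a signature with only mixed parameters prints without either of them.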


@@ -0,0 +1,253 @@
import midas.ast.midas as m
from midas.ast.printer.base import AstPrinter
class MidasAstPrinter(
AstPrinter,
m.Expr.Visitor[None],
m.Stmt.Visitor[None],
m.Type.Visitor[None],
):
# Statements
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self._write_line("TypeStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_sequence(
"params",
stmt.params,
print_func=self._print_type_param,
)
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
self._write_line("AliasStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_param(self, param: m.TypeParam) -> None:
self._write_line("Param")
with self._child_level():
self._write_line(f'name: "{param.name.lexeme}"')
self._write_optional_child("bound", param.bound, last=True)
def visit_member_stmt(self, stmt: m.MemberStmt):
self._write_line("MemberStmt")
with self._child_level():
self._write_line(f"kind: {stmt.kind.name}")
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._write_line("ExtendStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_sequence(
"params",
stmt.params,
print_func=self._print_type_param,
)
self._write_sequence("members", stmt.members, last=True)
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
self._write_line("PredicateStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_sequence(
"params",
stmt.params,
print_func=self._visit_param_spec,
)
self._write_line("body", last=True)
with self._child_level(single=True):
stmt.body.accept(self)
# Expressions
def visit_logical_expr(self, expr: m.LogicalExpr):
self._write_line("LogicalExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_binary_expr(self, expr: m.BinaryExpr):
self._write_line("BinaryExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_unary_expr(self, expr: m.UnaryExpr):
self._write_line("UnaryExpr")
with self._child_level():
self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self._write_line("CallExpr")
with self._child_level():
self._write_line("callee")
with self._child_level(single=True):
expr.callee.accept(self)
self._write_sequence("arguments", expr.arguments)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr):
self._write_line("GetExpr")
with self._child_level():
self._write_line("expr")
with self._child_level(single=True):
expr.expr.accept(self)
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
def visit_variable_expr(self, expr: m.VariableExpr):
self._write_line("VariableExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
def visit_grouping_expr(self, expr: m.GroupingExpr):
self._write_line("GroupingExpr")
with self._child_level():
self._write_line("expr", last=True)
with self._child_level(single=True):
expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
self._write_line("LiteralExpr")
with self._child_level():
self._write_line(f"value: {expr.value}", last=True)
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self._write_line("WildcardExpr")
# Types
def visit_named_type(self, type: m.NamedType) -> None:
self._write_line("NamedType")
with self._child_level():
self._write_line(f'name: "{type.name.lexeme}"', last=True)
def visit_generic_type(self, type: m.GenericType) -> None:
self._write_line("GenericType")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
type.type.accept(self)
self._write_sequence("args", type.args, last=True)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self._write_line("ConstraintType")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
type.type.accept(self)
self._write_line("constraint", last=True)
with self._child_level(single=True):
type.constraint.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None:
self._write_line("ComplexType")
with self._child_level():
self._write_sequence("members", type.members, last=True)
def visit_extension_type(self, type: m.ExtensionType) -> None:
self._write_line("ExtensionType")
with self._child_level():
self._write_line("base")
with self._child_level(single=True):
type.base.accept(self)
self._write_line("extension", last=True)
with self._child_level(single=True):
type.extension.accept(self)
def visit_function_type(self, type: m.FunctionType) -> None:
self._write_line("FunctionType")
with self._child_level():
self._write_line("params")
with self._child_level(single=True):
self._visit_param_spec(type.params)
self._write_line("returns", last=True)
with self._child_level(single=True):
type.returns.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
self._write_line("ParamSpec")
with self._child_level():
self._write_sequence(
"pos",
spec.pos,
print_func=self._print_param,
)
self._write_sequence(
"mixed",
spec.mixed,
print_func=self._print_param,
)
self._write_sequence(
"kw",
spec.kw,
print_func=self._print_param,
last=True,
)
def _print_param(self, param: m.FunctionType.Parameter) -> None:
self._write_line("Parameter")
with self._child_level():
name: str = "None"
if param.name is not None:
name = f'"{param.name.lexeme}"'
self._write_line(f"name: {name}")
self._write_line("type")
with self._child_level(single=True):
param.type.accept(self)
self._write_line(f"required: {param.required}", last=True)
def visit_frame_type(self, type: m.FrameType) -> None:
self._write_line("FrameType")
with self._child_level(single=True):
self._write_sequence(
"columns",
type.columns,
print_func=self._print_frame_column,
)
def _print_frame_column(self, column: m.FrameType.Column) -> None:
self._write_line("Column")
with self._child_level():
self._write_line(f'name: "{column.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
column.type.accept(self)


@@ -0,0 +1,285 @@
import ast
import midas.ast.python as p
from midas.ast.printer.base import AstPrinter
class PythonAstPrinter(
AstPrinter,
p.MidasType.Visitor[None],
p.Stmt.Visitor[None],
p.Expr.Visitor[None],
):
# Types
def visit_base_type(self, node: p.BaseType) -> None:
self._write_line("BaseType")
with self._child_level():
self._write_line(f"base: {node.base}")
self._write_sequence("args", node.args, last=True)
def visit_constraint_type(self, node: p.ConstraintType) -> None:
self._write_line("ConstraintType")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
node.type.accept(self)
self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)
def visit_frame_column(self, node: p.FrameColumn) -> None:
self._write_line("FrameColumn")
with self._child_level():
self._write_line(f"name: {node.name}")
self._write_optional_child("type", node.type, last=True)
def visit_frame_type(self, node: p.FrameType) -> None:
self._write_line("FrameType")
with self._child_level(single=True):
self._write_sequence("columns", node.columns)
# Statements
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
stmt.expr.accept(self)
def visit_function(self, stmt: p.Function) -> None:
self._write_line("Function")
with self._child_level():
self._write_line(f"name: {stmt.name}")
self._write_line("params")
with self._child_level(single=True):
self._print_param_spec(stmt.params)
self._write_optional_child("returns", stmt.returns)
self._write_sequence("body", stmt.body, last=True)
def _print_param_spec(self, spec: p.ParamSpec) -> None:
self._write_line("ParamSpec")
with self._child_level():
self._write_sequence(
"pos",
spec.pos,
print_func=self._print_param,
)
self._write_sequence(
"mixed",
spec.mixed,
print_func=self._print_param,
)
self._write_sequence(
"kw",
spec.kw,
print_func=self._print_param,
last=True,
)
def _print_param(self, param: p.Function.Parameter) -> None:
self._write_line("Parameter")
with self._child_level():
self._write_line(f"name: {param.name}")
self._write_optional_child("type", param.type, last=True)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
self._write_line("TypeAssign")
with self._child_level():
self._write_line(f"name: {stmt.name}")
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
self._write_line("AssignStmt")
with self._child_level():
self._write_sequence("targets", stmt.targets)
self._write_line("value", last=True)
with self._child_level(single=True):
stmt.value.accept(self)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
self._write_line("ReturnStmt")
with self._child_level():
self._write_optional_child("value", stmt.value, last=True)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self._write_line("IfStmt")
with self._child_level():
self._write_line("test")
with self._child_level(single=True):
stmt.test.accept(self)
self._write_sequence("body", stmt.body)
self._write_sequence("orelse", stmt.orelse, last=True)
def visit_pass(self, stmt: p.Pass) -> None:
self._write_line("Pass")
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self._write_line("ForStmt")
with self._child_level():
self._write_line("target")
with self._child_level(single=True):
stmt.target.accept(self)
self._write_line("iterator")
with self._child_level(single=True):
stmt.iterator.accept(self)
self._write_sequence("body", stmt.body, last=True)
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
self._write_line("RawStmt")
with self._child_level(single=True):
self._write_line(f"stmt: {ast.unparse(stmt.stmt)}")
# Expressions
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
self._write_line("CompareExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
self._write_line("UnaryExpr")
with self._child_level():
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_call_expr(self, expr: p.CallExpr) -> None:
self._write_line("CallExpr")
with self._child_level():
self._write_line("callee")
with self._child_level(single=True):
expr.callee.accept(self)
self._write_sequence("arguments", expr.arguments)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: p.GetExpr) -> None:
self._write_line("GetExpr")
with self._child_level():
self._write_line("object")
with self._child_level(single=True):
expr.object.accept(self)
self._write_line(f"name: {expr.name}", last=True)
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
self._write_line("LiteralExpr")
with self._child_level(single=True):
self._write_line(f"value: {expr.value!r}")
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
self._write_line("VariableExpr")
with self._child_level(single=True):
self._write_line(f"name: {expr.name}")
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
self._write_line("LogicalExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_cast_expr(self, expr: p.CastExpr) -> None:
self._write_line("CastExpr")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
expr.type.accept(self)
self._write_line("expr")
with self._child_level(single=True):
expr.expr.accept(self)
self._write_line(f"unsafe: {expr.unsafe}", last=True)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self._write_line("TernaryExpr")
with self._child_level():
self._write_line("test")
with self._child_level(single=True):
expr.test.accept(self)
self._write_line("if_true")
with self._child_level(single=True):
expr.if_true.accept(self)
self._write_line("if_false", last=True)
with self._child_level(single=True):
expr.if_false.accept(self)
def visit_list_expr(self, expr: p.ListExpr) -> None:
self._write_line("ListExpr")
with self._child_level():
self._write_sequence("items", expr.items, last=True)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
self._write_line("DictExpr")
with self._child_level():
self._write_sequence(
"keys",
expr.keys,
print_func=lambda k: (
self._write_line("None") if k is None else k.accept(self)
),
)
self._write_sequence("values", expr.values, last=True)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
self._write_line("SubscriptExpr")
with self._child_level():
self._write_line("object")
with self._child_level(single=True):
expr.object.accept(self)
self._write_line("index", last=True)
with self._child_level(single=True):
expr.index.accept(self)
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
self._write_line("SliceExpr")
with self._child_level():
self._write_optional_child("lower", expr.lower)
self._write_optional_child("upper", expr.upper)
self._write_optional_child("step", expr.step, last=True)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
self._write_line("TupleExpr")
with self._child_level():
self._write_sequence("items", expr.items, last=True)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
self._write_line("RawExpr")
with self._child_level(single=True):
self._write_line(f"expr: {ast.unparse(expr.expr)}")

423
midas/ast/python.py Normal file

@@ -0,0 +1,423 @@
"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/python.py instead and run gen/gen.py
"""
from __future__ import annotations
import ast
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
T = TypeVar("T")
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
pos: list[Function.Parameter]
mixed: list[Function.Parameter]
kw: list[Function.Parameter]
@property
def all(self) -> list[Function.Parameter]:
return self.pos + self.mixed + self.kw
####################
# Type annotations #
####################
@dataclass(frozen=True, kw_only=True)
class MidasType(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_base_type(self, node: BaseType) -> T: ...
@abstractmethod
def visit_constraint_type(self, node: ConstraintType) -> T: ...
@abstractmethod
def visit_frame_column(self, node: FrameColumn) -> T: ...
@abstractmethod
def visit_frame_type(self, node: FrameType) -> T: ...
@dataclass(frozen=True)
class BaseType(MidasType):
base: str
args: tuple[MidasType, ...]
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_base_type(self)
@dataclass(frozen=True)
class ConstraintType(MidasType):
type: MidasType
constraint: ast.expr
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_constraint_type(self)
@dataclass(frozen=True)
class FrameColumn(MidasType):
name: Optional[str]
type: Optional[MidasType]
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_frame_column(self)
@dataclass(frozen=True)
class FrameType(MidasType):
columns: list[FrameColumn]
def accept(self, visitor: MidasType.Visitor[T]) -> T:
return visitor.visit_frame_type(self)
##############
# Statements #
##############
@dataclass(frozen=True, kw_only=True)
class Stmt(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T: ...
@abstractmethod
def visit_function(self, stmt: Function) -> T: ...
@abstractmethod
def visit_type_assign(self, stmt: TypeAssign) -> T: ...
@abstractmethod
def visit_assign_stmt(self, stmt: AssignStmt) -> T: ...
@abstractmethod
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
@abstractmethod
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
@abstractmethod
def visit_pass(self, stmt: Pass) -> T: ...
@abstractmethod
def visit_for_stmt(self, stmt: ForStmt) -> T: ...
@abstractmethod
def visit_raw_stmt(self, stmt: RawStmt) -> T: ...
@dataclass(frozen=True)
class ExpressionStmt(Stmt):
expr: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_expression_stmt(self)
@dataclass(frozen=True)
class Function(Stmt):
name: str
params: ParamSpec
returns: Optional[MidasType]
body: list[Stmt]
@dataclass(frozen=True, kw_only=True)
class Parameter:
location: Optional[Location] = None
name: str
type: Optional[MidasType]
default: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_function(self)
@dataclass(frozen=True)
class TypeAssign(Stmt):
name: str
type: MidasType
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_type_assign(self)
@dataclass(frozen=True)
class AssignStmt(Stmt):
targets: list[Expr]
value: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_assign_stmt(self)
@dataclass(frozen=True)
class ReturnStmt(Stmt):
value: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_return_stmt(self)
@dataclass(frozen=True)
class IfStmt(Stmt):
test: Expr
body: list[Stmt]
orelse: list[Stmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_if_stmt(self)
@dataclass(frozen=True)
class Pass(Stmt):
pass
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_pass(self)
@dataclass(frozen=True)
class ForStmt(Stmt):
target: Expr
iterator: Expr
body: list[Stmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_for_stmt(self)
@dataclass(frozen=True)
class RawStmt(Stmt):
stmt: ast.stmt
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_raw_stmt(self)
###############
# Expressions #
###############
@dataclass(frozen=True, kw_only=True)
class Expr(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
@abstractmethod
def visit_compare_expr(self, expr: CompareExpr) -> T: ...
@abstractmethod
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
@abstractmethod
def visit_call_expr(self, expr: CallExpr) -> T: ...
@abstractmethod
def visit_get_expr(self, expr: GetExpr) -> T: ...
@abstractmethod
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
@abstractmethod
def visit_variable_expr(self, expr: VariableExpr) -> T: ...
@abstractmethod
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@abstractmethod
def visit_cast_expr(self, expr: CastExpr) -> T: ...
@abstractmethod
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
@abstractmethod
def visit_list_expr(self, expr: ListExpr) -> T: ...
@abstractmethod
def visit_dict_expr(self, expr: DictExpr) -> T: ...
@abstractmethod
def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ...
@abstractmethod
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
@abstractmethod
def visit_tuple_expr(self, expr: TupleExpr) -> T: ...
@abstractmethod
def visit_raw_expr(self, expr: RawExpr) -> T: ...
@dataclass(frozen=True)
class BinaryExpr(Expr):
left: Expr
operator: ast.operator
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_binary_expr(self)
@dataclass(frozen=True)
class CompareExpr(Expr):
left: Expr
operator: ast.cmpop
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_compare_expr(self)
@dataclass(frozen=True)
class UnaryExpr(Expr):
operator: ast.unaryop
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_unary_expr(self)
@dataclass(frozen=True)
class CallExpr(Expr):
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_call_expr(self)
@dataclass(frozen=True)
class GetExpr(Expr):
object: Expr
name: str
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_get_expr(self)
@dataclass(frozen=True)
class LiteralExpr(Expr):
value: Any
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_literal_expr(self)
@dataclass(frozen=True)
class VariableExpr(Expr):
name: str
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_variable_expr(self)
@dataclass(frozen=True)
class LogicalExpr(Expr):
left: Expr
operator: ast.boolop
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_logical_expr(self)
@dataclass(frozen=True)
class CastExpr(Expr):
type: MidasType
expr: Expr
unsafe: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_cast_expr(self)
@dataclass(frozen=True)
class TernaryExpr(Expr):
test: Expr
if_true: Expr
if_false: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_ternary_expr(self)
@dataclass(frozen=True)
class ListExpr(Expr):
items: list[Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_list_expr(self)
@dataclass(frozen=True)
class DictExpr(Expr):
keys: list[Optional[Expr]]
values: list[Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_dict_expr(self)
@dataclass(frozen=True)
class SubscriptExpr(Expr):
object: Expr
index: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_subscript_expr(self)
@dataclass(frozen=True)
class SliceExpr(Expr):
lower: Optional[Expr]
upper: Optional[Expr]
step: Optional[Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_slice_expr(self)
@dataclass(frozen=True)
class TupleExpr(Expr):
items: tuple[Expr, ...]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_tuple_expr(self)
@dataclass(frozen=True)
class RawExpr(Expr):
expr: ast.expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_raw_expr(self)
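
The expression nodes above all follow the same visitor pattern: each frozen dataclass forwards `accept` to the matching `visit_*` method. Below is a minimal, self-contained sketch of that pattern. `Node`, `Num`, `Add` and `Printer` are hypothetical stand-ins for illustration, not classes from this codebase.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")


class Visitor(ABC, Generic[T]):
    """One abstract visit method per node class (hypothetical, two nodes only)."""

    @abstractmethod
    def visit_num(self, expr: Num) -> T: ...

    @abstractmethod
    def visit_add(self, expr: Add) -> T: ...


class Node(ABC):
    @abstractmethod
    def accept(self, visitor: Visitor[T]) -> T: ...


@dataclass(frozen=True)
class Num(Node):
    value: int

    def accept(self, visitor: Visitor[T]) -> T:
        return visitor.visit_num(self)


@dataclass(frozen=True)
class Add(Node):
    left: Node
    right: Node

    def accept(self, visitor: Visitor[T]) -> T:
        return visitor.visit_add(self)


class Printer(Visitor[str]):
    """A visitor that renders the tree as a string."""

    def visit_num(self, expr: Num) -> str:
        return str(expr.value)

    def visit_add(self, expr: Add) -> str:
        return f"({expr.left.accept(self)} + {expr.right.accept(self)})"


tree = Add(Num(1), Add(Num(2), Num(3)))
rendered = tree.accept(Printer())
```

Adding a new operation (for example an evaluator) then only requires a new `Visitor` subclass; the node classes stay untouched.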


@@ -0,0 +1,277 @@
extend float {
def hex: fn() -> str
def is_integer: fn() -> bool
prop real: float
prop imag: float
def conjugate: fn() -> float
def __add__: fn(value: float, /) -> float
def __sub__: fn(value: float, /) -> float
def __mul__: fn(value: float, /) -> float
def __floordiv__: fn(value: float, /) -> float
def __truediv__: fn(value: float, /) -> float
def __mod__: fn(value: float, /) -> float
// def __divmod__: fn(value: float, /) -> tuple[float, float]
def __pow__: fn(value: int, /) -> float
// positive __value -> float; negative __value -> complex
// return type must be Any as `float | complex` causes too many false-positive errors
def __pow__: fn(value: float, /) -> Any
def __radd__: fn(value: float, /) -> float
def __rsub__: fn(value: float, /) -> float
def __rmul__: fn(value: float, /) -> float
def __rfloordiv__: fn(value: float, /) -> float
def __rtruediv__: fn(value: float, /) -> float
def __rmod__: fn(value: float, /) -> float
// def __rdivmod__: fn(value: float, /) -> tuple[float, float]
// def __rpow__: fn(value: _PositiveInteger, mod: None = None, /) -> float
// def __rpow__: fn(value: _NegativeInteger, mod: None = None, /) -> complex
// Returning `complex` for the general case gives too many false-positive errors.
// def __rpow__: fn(value: float, mod: None = None, /) -> Any
// def __getnewargs__: fn() -> tuple[float]
def __trunc__: fn() -> int
def __ceil__: fn() -> int
def __floor__: fn() -> int
def __round__: fn(ndigits: None?, /) -> int
def __round__: fn(ndigits: int, /) -> float
def __eq__: fn(value: object, /) -> bool
def __ne__: fn(value: object, /) -> bool
def __lt__: fn(value: float, /) -> bool
def __le__: fn(value: float, /) -> bool
def __gt__: fn(value: float, /) -> bool
def __ge__: fn(value: float, /) -> bool
def __neg__: fn() -> float
def __pos__: fn() -> float
def __int__: fn() -> int
def __float__: fn() -> float
def __abs__: fn() -> float
def __hash__: fn() -> int
def __bool__: fn() -> bool
def __format__: fn(format_spec: str, /) -> str
}
extend int {
prop real: int
prop imag: int
prop numerator: int
prop denominator: int
def conjugate: fn() -> int
def bit_length: fn() -> int
def bit_count: fn() -> int
// def to_bytes: fn(length: int?, byteorder: str?, *, signed: bool?) -> bytes
def __add__: fn(value: int, /) -> int
def __sub__: fn(value: int, /) -> int
def __mul__: fn(value: int, /) -> int
def __floordiv__: fn(value: int, /) -> int
def __truediv__: fn(value: int, /) -> float
def __mod__: fn(value: int, /) -> int
// def __divmod__: fn(value: int, /) -> tuple[int, int]
def __radd__: fn(value: int, /) -> int
def __rsub__: fn(value: int, /) -> int
def __rmul__: fn(value: int, /) -> int
def __rfloordiv__: fn(value: int, /) -> int
def __rtruediv__: fn(value: int, /) -> float
def __rmod__: fn(value: int, /) -> int
// def __rdivmod__: fn(value: int, /) -> tuple[int, int]
def __pow__: fn(value: int, /) -> int
// def __pow__: fn(value: _PositiveInteger, mod: None = None, /) -> int
// def __pow__: fn(value: _NegativeInteger, mod: None = None, /) -> float
// positive __value -> int; negative __value -> float
// return type must be Any as `int | float` causes too many false-positive errors
// def __pow__: fn(value: int, mod: None = None, /) -> Any
// def __pow__: fn(value: int, mod: int, /) -> int
def __rpow__: fn(value: int, /) -> Any
def __and__: fn(value: int, /) -> int
def __or__: fn(value: int, /) -> int
def __xor__: fn(value: int, /) -> int
def __lshift__: fn(value: int, /) -> int
def __rshift__: fn(value: int, /) -> int
def __rand__: fn(value: int, /) -> int
def __ror__: fn(value: int, /) -> int
def __rxor__: fn(value: int, /) -> int
def __rlshift__: fn(value: int, /) -> int
def __rrshift__: fn(value: int, /) -> int
def __neg__: fn() -> int
def __pos__: fn() -> int
def __invert__: fn() -> int
def __trunc__: fn() -> int
def __ceil__: fn() -> int
def __floor__: fn() -> int
def __round__: fn(ndigits: None?, /) -> int
def __round__: fn(ndigits: int, /) -> int
// def __getnewargs__: fn() -> tuple[int]
def __eq__: fn(value: object, /) -> bool
def __ne__: fn(value: object, /) -> bool
def __lt__: fn(value: int, /) -> bool
def __le__: fn(value: int, /) -> bool
def __gt__: fn(value: int, /) -> bool
def __ge__: fn(value: int, /) -> bool
def __float__: fn() -> float
def __int__: fn() -> int
def __abs__: fn() -> int
def __hash__: fn() -> int
def __bool__: fn() -> bool
def __index__: fn() -> int
def __format__: fn(format_spec: str, /) -> str
}
extend list[T] {
def copy: fn () -> list[T]
def append: fn (object: T, /) -> None
def extend: fn (iterable: list[T], /) -> None
def pop: fn (index: int?, /) -> T
def index: fn (value: T, start: int?, stop: int?, /) -> int
def count: fn (value: T, /) -> int
def insert: fn (index: int, object: T, /) -> None
def remove: fn (value: T, /) -> None
def sort: fn (*, reverse: bool?) -> None
def __len__: fn () -> int
// def __iter__: fn () -> Iterator[T]
def __getitem__: fn (i: int, /) -> T
def __getitem__: fn (s: slice, /) -> list[T]
def __setitem__: fn (key: int, value: T, /) -> None
def __setitem__: fn (key: slice, value: list[T], /) -> None
def __delitem__: fn (key: int, /) -> None
def __delitem__: fn (key: slice, /) -> None
// def __add__: fn[S <: T] (value: list[S], /) -> list[T]
def __add__: fn (value: list[T], /) -> list[T]
def __iadd__: fn (value: list[T], /) -> list[T]
def __mul__: fn (value: int, /) -> list[T]
def __rmul__: fn (value: int, /) -> list[T]
def __imul__: fn (value: int, /) -> list[T]
def __contains__: fn (key: object, /) -> bool
// def __reversed__: fn (self) -> Iterator[_T]
def __gt__: fn (value: list[T], /) -> bool
def __ge__: fn (value: list[T], /) -> bool
def __lt__: fn (value: list[T], /) -> bool
def __le__: fn (value: list[T], /) -> bool
def __eq__: fn (value: object, /) -> bool
prop __doc__: str
}
extend dict[K, V] {
def copy: fn() -> dict[K, V]
def keys: fn() -> list[K] // TODO: use builtin types
def values: fn() -> list[V] // TODO: use builtin types
// def items: fn() -> list[tuple[K, V]] // TODO: use builtin types
// def get: fn(key: K, default: None = None, /) -> V | None
def get: fn(key: K, default: V, /) -> V
// def get: fn[T](key: K, default: T, /) -> V | T
def pop: fn(key: K, /) -> V
def pop: fn(key: K, default: V, /) -> V
// def pop: fn[T](key: K, default: T, /) -> V | T
def __len__: fn() -> int
def __getitem__: fn(key: K, /) -> V
def __setitem__: fn(key: K, value: V, /) -> None
def __delitem__: fn(key: K, /) -> None
// def __iter__: fn() -> Iterator[K]
def __eq__: fn(value: object, /) -> bool
// def __reversed__: fn() -> Iterator[K]
def __or__: fn(value: dict[K, V], /) -> dict[K, V]
// def __or__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2]
def __ror__: fn(value: dict[K, V], /) -> dict[K, V]
// def __ror__: fn[K2, V2](value: dict[K2, V2], /) -> dict[K | K2, V | V2]
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
}
extend str {
def capitalize: fn() -> str
def casefold: fn() -> str
def center: fn(width: int, fillchar: str?, /) -> str
def count: fn(sub: str, start: None?, end: None?, /) -> int
def count: fn(sub: str, start: int, end: None?, /) -> int
def count: fn(sub: str, start: None, end: int, /) -> int
def count: fn(sub: str, start: int, end: int, /) -> int
def encode: fn(encoding: str?, errors: str?) -> bytes
def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
def endswith: fn(suffix: str, start: None, end: int, /) -> bool
def endswith: fn(suffix: str, start: int, end: int, /) -> bool
def expandtabs: fn(tabsize: int?) -> str
def find: fn(sub: str, start: None?, end: None?, /) -> int
def find: fn(sub: str, start: int, end: None?, /) -> int
def find: fn(sub: str, start: None, end: int, /) -> int
def find: fn(sub: str, start: int, end: int, /) -> int
// def format: fn(*args: object, **kwargs: object) -> str
// def format_map: fn(mapping: _FormatMapMapping, /) -> str
def index: fn(sub: str, start: None?, end: None?, /) -> int
def index: fn(sub: str, start: int, end: None?, /) -> int
def index: fn(sub: str, start: None, end: int, /) -> int
def index: fn(sub: str, start: int, end: int, /) -> int
def isalnum: fn() -> bool
def isalpha: fn() -> bool
def isascii: fn() -> bool
def isdecimal: fn() -> bool
def isdigit: fn() -> bool
def isidentifier: fn() -> bool
def islower: fn() -> bool
def isnumeric: fn() -> bool
def isprintable: fn() -> bool
def isspace: fn() -> bool
def istitle: fn() -> bool
def isupper: fn() -> bool
def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
def ljust: fn(width: int, fillchar: str?, /) -> str
def lower: fn() -> str
def lstrip: fn(chars: None?, /) -> str
def lstrip: fn(chars: str, /) -> str
def partition: fn(sep: str, /) -> tuple[str, str, str]
def replace: fn(old: str, new: str, count: int?, /) -> str
def removeprefix: fn(prefix: str, /) -> str
def removesuffix: fn(suffix: str, /) -> str
def rfind: fn(sub: str, start: None?, end: None?, /) -> int
def rfind: fn(sub: str, start: int, end: None?, /) -> int
def rfind: fn(sub: str, start: None, end: int, /) -> int
def rfind: fn(sub: str, start: int, end: int, /) -> int
def rindex: fn(sub: str, start: None?, end: None?, /) -> int
def rindex: fn(sub: str, start: int, end: None?, /) -> int
def rindex: fn(sub: str, start: None, end: int, /) -> int
def rindex: fn(sub: str, start: int, end: int, /) -> int
def rjust: fn(width: int, fillchar: str?, /) -> str
def rpartition: fn(sep: str, /) -> tuple[str, str, str]
def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
def rstrip: fn(chars: None?, /) -> str
def rstrip: fn(chars: str, /) -> str
def split: fn(sep: None?, maxsplit: int?) -> list[str]
def split: fn(sep: str, maxsplit: int?) -> list[str]
def splitlines: fn(keepends: bool?) -> list[str]
def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
def startswith: fn(prefix: str, start: None, end: int, /) -> bool
def startswith: fn(prefix: str, start: int, end: int, /) -> bool
def strip: fn(chars: None?, /) -> str
def strip: fn(chars: str, /) -> str
def swapcase: fn() -> str
def title: fn() -> str
// def translate: fn(table: _TranslateTable, /) -> str
def upper: fn() -> str
def zfill: fn(width: int, /) -> str
def __add__: fn(value: str, /) -> str
// Incompatible with Sequence.__contains__
def __contains__: fn(key: str, /) -> bool
def __eq__: fn(value: object, /) -> bool
def __ge__: fn(value: str, /) -> bool
def __getitem__: fn(key: slice, /) -> str
def __getitem__: fn(key: int, /) -> str
def __gt__: fn(value: str, /) -> bool
def __hash__: fn() -> int
// def __iter__: fn() -> Iterator[str]
def __le__: fn(value: str, /) -> bool
def __len__: fn() -> int
def __lt__: fn(value: str, /) -> bool
def __mod__: fn(value: Any, /) -> str
def __mul__: fn(value: int, /) -> str
def __ne__: fn(value: object, /) -> bool
def __rmul__: fn(value: int, /) -> str
def __getnewargs__: fn() -> tuple[str]
def __format__: fn(format_spec: str, /) -> str
}
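
The comments on the `__pow__` declarations above note that the result type depends on the operand values, which is why some overloads fall back to `Any`. The short Python sketch below shows the runtime behaviour those comments refer to; it only uses plain Python and makes no assumption about how Midas resolves these overloads.

```python
# int ** non-negative int stays an int
int_result = 2 ** 3

# int ** negative int produces a float
neg_exp_result = 2 ** -1

# negative float ** fractional float produces a complex number
complex_result = (-8.0) ** 0.5

result_types = [
    type(r).__name__ for r in (int_result, neg_exp_result, complex_result)
]
```

Since a static signature cannot see the sign of the operands, a precise return type would have to be a union, which the stub comments describe as causing too many false positives.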

midas/checker/builtins.py Normal file

@@ -0,0 +1,60 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from midas.checker.types import (
BaseType,
GenericType,
TopType,
TypeVar,
UnitType,
)
if TYPE_CHECKING:
from midas.checker.registry import TypesRegistry
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
"float": {"int"},
}
"""
Hard-coded subtype relationships between builtin types
Circular dependencies and diamond inheritance MUST be avoided
"""
def define_builtins(reg: TypesRegistry):
"""Define builtin types and operations"""
any = reg.define_type("Any", TopType())
unit = reg.define_type("None", UnitType())
object = reg.define_type("object", BaseType(name="object"))
bytes = reg.define_type("bytes", BaseType(name="bytes"))
bool = reg.define_type("bool", BaseType(name="bool"))
int = reg.define_type("int", BaseType(name="int"))
float = reg.define_type("float", BaseType(name="float"))
str = reg.define_type("str", BaseType(name="str"))
slice = reg.define_type("slice", BaseType(name="slice"))
tuple = reg.define_type("tuple", BaseType(name="tuple"))
list = reg.define_type(
"list",
GenericType(
name="list",
params=[TypeVar(name="T", bound=None)],
body=BaseType(name="list"),
),
)
dict = reg.define_type(
"dict",
GenericType(
name="dict",
params=[
TypeVar(name="K", bound=None),
TypeVar(name="V", bound=None),
],
body=BaseType(name="dict"),
),
)
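
As an illustration of how a table like `BUILTIN_SUBTYPES` can be queried, here is a minimal sketch that follows the direct subtype edges transitively. `is_builtin_subtype` is a hypothetical helper written for this example; how `TypesRegistry` actually consumes the table is not shown in this diff.

```python
BUILTIN_SUBTYPES: dict[str, set[str]] = {
    "object": {"float", "list", "dict", "str", "bytes", "tuple"},
    "float": {"int"},
}


def is_builtin_subtype(sub: str, sup: str) -> bool:
    """Hypothetical helper: True if `sub` is `sup` or reachable from it.

    The recursion terminates only because the table has no cycles,
    which is the constraint stated in the table's docstring.
    """
    if sub == sup:
        return True
    return any(
        is_builtin_subtype(sub, child)
        for child in BUILTIN_SUBTYPES.get(sup, set())
    )


int_is_object = is_builtin_subtype("int", "object")  # via float
bool_is_int = is_builtin_subtype("bool", "int")  # no such edge in the table
```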

midas/checker/checker.py Normal file

@@ -0,0 +1,41 @@
from pathlib import Path
from typing import Optional
from midas.checker.diagnostic import Diagnostic
from midas.checker.midas import MidasTyper
from midas.checker.python import PythonTyper
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import Reporter
from midas.utils import TypedAST
class TypeChecker:
"""Type checking dispatcher
Contains a typer for Midas and one for Python, as well as the types registry
"""
def __init__(self):
self.types: TypesRegistry = TypesRegistry()
self.reporter: Reporter = Reporter()
self.midas_typer = MidasTyper(self.types, self.reporter)
self.python_typer = PythonTyper(self.types, self.reporter)
def import_midas(self, path: Path):
source: str = path.read_text()
return self.import_midas_source(source, path=str(path))
def import_midas_source(self, source: str, path: Optional[str] = None):
self.midas_typer.process(source, path)
def type_check(self, path: Path) -> TypedAST:
source: str = path.read_text()
return self.type_check_source(source, path=str(path))
def type_check_source(self, source: str, path: Optional[str] = None) -> TypedAST:
return self.python_typer.process(source, path)
@property
def diagnostics(self) -> list[Diagnostic]:
return self.reporter.diagnostics


@@ -0,0 +1,64 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Optional
from midas.ast.location import Location
class DiagnosticType(StrEnum):
ERROR = "Error"
WARNING = "Warning"
INFO = "Info"
DEBUG = "Debug"
@dataclass(frozen=True)
class Diagnostic:
"""Information about a diagnostic (warning, error, etc.)
Holds a location, a diagnostic type and a message.
Optionally bound to a file path.
"""
file_path: Optional[str]
location: Location
type: DiagnosticType
message: str
@property
def location_str(self) -> str:
"""The diagnostic type and location as a human-readable string
The location is formatted as "<Type> in <file> from L<start_line>:<start_col> to L<end_line>:<end_col>",
for example: "Error in /home/user/Desktop/script.py from L12:5 to L12:8"
If the file is `None`, the "in ..." section is excluded from the result.
If the location's end is not specified, the formulation "at L<start_line>:<start_col>" is used.
Returns:
str: the formatted diagnostic type and location
"""
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
end_loc: Optional[str] = None
if (
self.location.end_lineno is not None
and self.location.end_col_offset is not None
):
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
loc: str = ""
if self.file_path is not None:
loc += f" in {self.file_path}"
if end_loc is None:
loc += f" at {start_loc}"
else:
loc += f" from {start_loc} to {end_loc}"
return f"{self.type}{loc}"
def __str__(self) -> str:
return f"{self.location_str}: {self.message}"
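
The formatting rules described in the `location_str` docstring can be sketched in isolation as follows. `Loc` and `describe` are hypothetical stand-ins for `Location` and `Diagnostic.location_str`, and the sketch assumes that column offsets are 0-based, as the `+1` in the original suggests.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Loc:
    """Hypothetical stand-in for midas.ast.location.Location."""

    lineno: int
    col_offset: int
    end_lineno: Optional[int] = None
    end_col_offset: Optional[int] = None


def describe(kind: str, loc: Loc, file_path: Optional[str] = None) -> str:
    # Column offsets are assumed 0-based, so 1 is added for display
    start = f"L{loc.lineno}:{loc.col_offset + 1}"
    end: Optional[str] = None
    if loc.end_lineno is not None and loc.end_col_offset is not None:
        end = f"L{loc.end_lineno}:{loc.end_col_offset + 1}"

    where = f" in {file_path}" if file_path is not None else ""
    if end is None:
        where += f" at {start}"
    else:
        where += f" from {start} to {end}"
    return f"{kind}{where}"


with_range = describe("Error", Loc(12, 4, 12, 7), "script.py")
without_end = describe("Warning", Loc(3, 0))
```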

midas/checker/dispatcher.py Normal file

@@ -0,0 +1,486 @@
import logging
from dataclasses import dataclass
from enum import StrEnum
from typing import Generic, Optional, Protocol, TypeVar, Union
from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
AppliedType,
DerivedType,
Function,
GenericType,
OverloadedFunction,
Type,
UnknownType,
)
from midas.checker.unifier import Unifier
class HasLocation(Protocol):
@property
def location(self) -> Location: ...
E = TypeVar("E", bound=HasLocation)
TypedExpr = tuple[E, Type]
@dataclass(frozen=True, kw_only=True)
class MappedArgument(Generic[E]):
arg_expr: E
arg_type: Type
parameter: Function.Parameter
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
function: Function
mapped: list[MappedArgument]
class CallError(StrEnum):
INVALID_ARGS = "Invalid arguments"
NO_MATCHING_OVERLOAD = "No matching overload"
IMPOSSIBLE_UNIFICATION = "Parameters unification failed"
NOT_CALLABLE = "Not callable"
@dataclass(frozen=True, kw_only=True)
class CallResult:
error: Optional[CallError] = None
result: Type = UnknownType()
message: Optional[str] = None
@property
def is_valid(self) -> bool:
return self.error is None
@property
def error_message(self) -> str:
if self.message is not None:
return self.message
if self.error is not None:
return str(self.error)
return ""
class CallDispatcher(Generic[E]):
def __init__(self, types: TypesRegistry, reporter: FileReporter) -> None:
self.types: TypesRegistry = types
self.reporter: FileReporter = reporter
self.logger: logging.Logger = logging.getLogger("CallDispatcher")
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
def get_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
report_errors: bool = True,
) -> CallResult:
"""Get the result type of a function call
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter, to accommodate
desugared calls such as operators.
Args:
location (Location): the call location
callee (Type): the called function
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
CallResult: the return type of the call, or an error (with an optional message)
if the call is invalid or no overload matched the arguments uniquely
"""
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument[E]]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords, report_errors
)
valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid:
return CallResult(error=CallError.INVALID_ARGS)
return CallResult(result=function.returns)
case OverloadedFunction(overloads=overloads):
res = self._match_overload(
overloads, location, positional, keywords, report_errors
)
if res[0] is None:
return CallResult(
error=CallError.NO_MATCHING_OVERLOAD,
message=res[1],
)
return CallResult(result=res[0].returns)
case AppliedType(body=body):
return self.get_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return CallResult(result=UnknownType())
case DerivedType(type=base):
return self.get_result(
location, base, positional, keywords, report_errors
)
case GenericType():
unifier: Unifier = Unifier(self.types)
pos: list[Type] = [a[1] for a in positional]
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
if unified is None:
pos_str: str = ", ".join(str(t) for t in pos)
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
message: str = (
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}"
)
if report_errors:
self.reporter.error(location, message)
return CallResult(
error=CallError.IMPOSSIBLE_UNIFICATION,
message=message,
)
return self.get_result(
location,
unified,
positional,
keywords,
report_errors,
)
case _:
message: str = f"{callee} ({callee.__class__.__name__}) is not callable"
if report_errors:
self.reporter.error(location, message)
return CallResult(
error=CallError.NOT_CALLABLE,
message=message,
)
def _unwrap_function(
self,
callee: Type,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
) -> Union[tuple[Function, None], tuple[None, CallError]]:
match callee:
case DerivedType(type=base):
return self._unwrap_function(base, positional, keywords)
case GenericType():
unifier: Unifier = Unifier(self.types)
unified: Optional[Type] = unifier.unify_call(
callee,
[a[1] for a in positional],
{k: v[1] for k, v in keywords.items()},
)
if unified is None:
return None, CallError.IMPOSSIBLE_UNIFICATION
return self._unwrap_function(unified, positional, keywords)
case Function():
return callee, None
case AppliedType(body=body):
return self._unwrap_function(body, positional, keywords)
case _:
return None, CallError.NOT_CALLABLE
def _are_arguments_valid(
self,
arguments: list[MappedArgument[E]],
report_errors: bool = True,
) -> bool:
"""Check whether the passed argument types correspond to their matched parameter definitions
Args:
arguments (list[MappedArgument]): the list of argument/parameter pairs
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
bool: True if all arguments fit the matching parameter definitions, False otherwise
"""
valid: bool = True
for arg in arguments:
if not self.types.is_subtype(arg.arg_type, arg.parameter.type):
if report_errors:
self.reporter.error(
arg.arg_expr.location,
f"Wrong type for argument '{arg.parameter.name}', expected {arg.parameter.type}, got {arg.arg_type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
report_errors: bool = True,
) -> Union[tuple[Function, None], tuple[None, str]]:
"""Try and resolve the appropriate overload for the given arguments
Args:
overloads (list[Type]): the list of possible overloads
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
tuple: `(function, None)` with the resolved function signature if it can be
determined unambiguously, or `(None, message)` describing the failure.
"""
candidates: list[OverloadCandidate] = []
errors: list[CallError] = []
for overload in overloads:
function, unwrap_error = self._unwrap_function(
overload, positional, keywords
)
if function is None:
errors.append(unwrap_error) # type: ignore
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(
OverloadCandidate(
function=function,
mapped=mapped,
)
)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
n_candidates: int = len(candidates)
# Exactly 1 match -> return it
if n_candidates == 1:
return candidates[0].function, None
# No match -> invalid call
if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads))
errors_str: str = ", ".join(errors)
message: str = (
f"No matching overload in [{overloads_str}] {for_args} (errors: {errors_str})"
)
if report_errors:
self.reporter.error(location, message)
return None, message
# Multiple matches -> see if one <: all others (more specific)
for i1, c1 in enumerate(candidates):
mapped1: list[MappedArgument[E]] = c1.mapped
best_match: bool = True
for i2, c2 in enumerate(candidates):
if i1 == i2:
continue
mapped2: list[MappedArgument[E]] = c2.mapped
if not self._are_mapped_subtypes(mapped1, mapped2):
best_match = False
break
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
if best_match:
return c1.function, None
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
message: str = f"Multiple matching overloads {for_args}: {candidates_str}"
if report_errors:
self.reporter.error(location, message)
return None, message
def map_call_arguments(
self,
function: Function,
location: Location,
positional: list[TypedExpr[E]],
keywords: dict[str, TypedExpr[E]],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to a function's parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions
with the arguments passed at the call site
Any mismatched, missing or unexpected argument is reported as a diagnostic,
unless `report_errors` is set to `False`
Args:
function (Function): the function definition
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
tuple[bool, list[MappedArgument]]: a boolean reporting whether
the call is valid and the list of mapped arguments
"""
set_params: set[str] = set()
required_positional: list[str] = [
param.name
for param in function.params.pos + function.params.mixed
if param.required
]
required_keyword: list[str] = [
param.name for param in function.params.kw if param.required
]
mapped: list[MappedArgument[E]] = []
pos_params: list[Function.Parameter] = list(function.params.pos)
mixed_params: list[Function.Parameter] = list(function.params.mixed)
kw_params: dict[str, Function.Parameter] = {
param.name: param for param in function.params.kw
}
valid_call: bool = True
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Parameter
if len(pos_params) != 0:
param = pos_params.pop(0)
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break
name: str = param.name
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_params.add(name)
mapped.append(
MappedArgument(
arg_expr=arg[0],
arg_type=arg[1],
parameter=param,
)
)
kw_params.update({param.name: param for param in mixed_params})
for name, arg in keywords.items():
param: Function.Parameter
if name not in kw_params:
if report_errors:
if name in set_params:
self.reporter.error(
arg[0].location, f"Multiple values for parameter '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword parameter '{name}'"
)
valid_call = False
continue
param = kw_params.pop(name)
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_params.add(name)
mapped.append(
MappedArgument(
arg_expr=arg[0],
arg_type=arg[1],
parameter=param,
)
)
def join_params(params: list[str]) -> str:
params = list(map(lambda p: f"'{p}'", params))
if len(params) == 0:
return ""
if len(params) == 1:
return params[0]
return ", ".join(params[:-1]) + " and " + params[-1]
if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s"
params: str = join_params(required_positional)
if report_errors:
self.reporter.error(
location,
f"Missing required positional argument{plural}: {params}",
)
valid_call = False
if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s"
params: str = join_params(required_keyword)
if report_errors:
self.reporter.error(
location,
f"Missing required keyword argument{plural}: {params}",
)
valid_call = False
return valid_call, mapped
def _are_mapped_subtypes(
self, mapped1: list[MappedArgument[E]], mapped2: list[MappedArgument[E]]
) -> bool:
"""Check whether one set of argument mappings is more specific than another
This function checks whether the parameter types in `mapped1` are subtypes
of those in `mapped2`. If any parameter type in `mapped1` is not a subtype
of the corresponding parameter type in `mapped2`, `False` is returned.
This is used to check whether a given overload is
more specific than another.
Args:
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
Returns:
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
"""
by_expr: dict[E, Type] = {}
for arg in mapped1:
by_expr[arg.arg_expr] = arg.parameter.type
for arg in mapped2:
type2: Type = arg.parameter.type
type1: Type = by_expr[arg.arg_expr]
if not self.types.is_subtype(type1, type2):
return False
return True
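
The overload resolution strategy used by `_match_overload` (exactly one match wins, otherwise prefer the candidate whose parameters are subtypes of every other candidate's) can be sketched on plain type names. Everything below is a simplified, hypothetical model: `SUPERTYPE` stands in for `TypesRegistry.is_subtype`, and overloads are reduced to tuples of positional parameter types.

```python
from typing import Optional

# Hypothetical direct-supertype table standing in for TypesRegistry.is_subtype
SUPERTYPE = {"int": "float", "float": "object", "str": "object"}


def is_subtype(sub: str, sup: str) -> bool:
    current: Optional[str] = sub
    while current is not None:
        if current == sup:
            return True
        current = SUPERTYPE.get(current)
    return False


def accepts(params: tuple[str, ...], args: tuple[str, ...]) -> bool:
    """True if every argument type fits the parameter at the same position."""
    return len(params) == len(args) and all(
        is_subtype(a, p) for a, p in zip(args, params)
    )


def pick_overload(
    overloads: list[tuple[str, ...]], args: tuple[str, ...]
) -> Optional[tuple[str, ...]]:
    candidates = [params for params in overloads if accepts(params, args)]
    if len(candidates) == 1:
        return candidates[0]
    # Several matches: keep the one at least as specific as every other
    for i, first in enumerate(candidates):
        if all(
            accepts(other, first)
            for j, other in enumerate(candidates)
            if i != j
        ):
            return first
    return None  # no match, or ambiguous


overloads = [("float",), ("int",), ("str",)]
chosen_for_int = pick_overload(overloads, ("int",))
chosen_for_float = pick_overload(overloads, ("float",))
chosen_for_bytes = pick_overload(overloads, ("bytes",))
```

With an `int` argument both the `float` and `int` overloads match, and the `int` one is chosen because its parameter is a subtype of the other's.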


@@ -0,0 +1,142 @@
from __future__ import annotations
from typing import Optional
from midas.checker.types import Type
class Environment:
"""
A scoped environment in which variables are defined
Each environment can inherit from a parent/enclosing environment.
"""
def __init__(self, enclosing: Optional[Environment] = None) -> None:
self.enclosing: Optional[Environment] = enclosing
self.values: dict[str, Type] = {}
self.return_types: list[Type] = []
self._children: list[Environment] = []
if enclosing is not None:
enclosing._children.append(self)
def define(self, name: str, value: Type) -> None:
"""Define a variable in this environment
Args:
name (str): the name of the variable
value (Type): the value
"""
self.values[name] = value
def get(self, name: str) -> Optional[Type]:
"""Get a variable in the closest environment which has a definition for it
Args:
name (str): the name of the variable
Returns:
Optional[Type]: the value of the variable, or None if it was not found
"""
if name in self.values:
return self.values[name]
if self.enclosing is not None:
return self.enclosing.get(name)
# raise NameError(f"Undefined variable '{name}'")
return None
def assign(self, name: str, value: Type) -> bool:
"""Assign a new value to a variable in the environment it was defined in
Args:
name (str): the name of the variable
value (Type): the new value
Returns:
bool: True if the variable was assigned in this environment or an ancestor, False otherwise
"""
if name not in self.values:
if self.enclosing is None:
return False
if self.enclosing.assign(name, value):
return True
self.values[name] = value
return True
def clear(self):
"""Clear all definitions in this environment"""
self.values = {}
def get_at(self, distance: int, name: str) -> Optional[Type]:
"""Get the value of a variable at a given distance
A distance of 0 looks up in this environment, 1 in the parent environment, etc.
This method expects `distance` to be valid. An error will be raised if
the stack does not extend far enough to reach `distance`
Args:
distance (int): the scope distance
name (str): the name of the variable
Returns:
Optional[Type]: the value at the given distance, or None if it is not defined in that environment
Raises:
AssertionError: if the stack does not extend far enough to reach `distance`
"""
return self.ancestor(distance).values.get(name)
def assign_at(self, distance: int, name: str, value: Type) -> None:
"""Assign a new value to a variable at a given distance
A distance of 0 assigns in this environment, 1 in the parent environment, etc.
Args:
distance (int): the scope distance
name (str): the name of the variable
value (Type): the new value
Raises:
AssertionError: if the stack does not extend far enough to reach `distance`
"""
self.ancestor(distance).values[name] = value
def ancestor(self, distance: int) -> Environment:
"""Get the ancestor at a given distance
A distance of 0 references this environment, 1 the parent environment, etc.
Args:
distance (int): the scope distance
Returns:
Environment: the environment
Raises:
AssertionError: if the stack does not extend far enough to reach `distance`
"""
env: Environment = self
for _ in range(distance):
assert env.enclosing is not None
env = env.enclosing
return env
def flat_dict(self) -> dict[str, Type]:
"""Get the current environment, including definitions in its ancestors, as a flat dictionary
This method recursively combines this environment's definitions with those of its ancestors
Returns:
dict: the combined environment
"""
if self.enclosing is None:
return self.values
return self.enclosing.flat_dict() | self.values
def dump(self) -> dict:
return {
"values": self.values,
"return_types": self.return_types,
"children": [child.dump() for child in self._children],
}

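The lookup rules of `Environment` can be illustrated with a reduced stand-in. This is only a sketch: `Scope` is a hypothetical class that stores plain strings instead of `Type` objects, and it omits `return_types` and child tracking. It shows that `get` resolves to the closest definition, that `ancestor(1)` reaches the parent, and that `flat_dict` lets inner definitions override outer ones.

```python
from __future__ import annotations

from typing import Optional


class Scope:
    """Hypothetical, simplified stand-in for Environment (values are strings)."""

    def __init__(self, enclosing: Optional[Scope] = None) -> None:
        self.enclosing = enclosing
        self.values: dict[str, str] = {}

    def get(self, name: str) -> Optional[str]:
        # Walk outwards until some scope defines the name
        scope: Optional[Scope] = self
        while scope is not None:
            if name in scope.values:
                return scope.values[name]
            scope = scope.enclosing
        return None

    def ancestor(self, distance: int) -> Scope:
        # distance 0 is this scope, 1 is the parent, and so on
        scope = self
        for _ in range(distance):
            assert scope.enclosing is not None
            scope = scope.enclosing
        return scope

    def flat_dict(self) -> dict[str, str]:
        # Inner definitions win over outer ones
        if self.enclosing is None:
            return dict(self.values)
        return self.enclosing.flat_dict() | self.values


module = Scope()
module.values["x"] = "int"
function = Scope(module)
function.values["x"] = "str"  # shadows the outer definition
function.values["y"] = "float"
```

Here `function.get("x")` resolves to the inner definition, while `function.ancestor(1).values["x"]` still sees the outer one.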
174
midas/checker/evaluator.py Normal file

@@ -0,0 +1,174 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional
import midas.ast.midas as m
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import Function, Predicate
from midas.lexer.token import TokenType
@dataclass(frozen=True, kw_only=True)
class PartialPredicate(Predicate):
scope: dict[str, Any]
class Evaluator(m.Expr.Visitor[Any]):
def __init__(self, types: TypesRegistry, reporter: Optional[FileReporter] = None):
self.types: TypesRegistry = types
self.reporter: Optional[FileReporter] = reporter
self.preamble: Preamble = Preamble(self.types)
self.scopes: list[dict[str, Any]] = [{}]
def evaluate(self, expr: m.Expr) -> Any:
value: Any = expr.accept(self)
if self.reporter is not None:
self.reporter.debug(expr.location, f"Value: {value}")
return value
def get_value(self, name: str) -> Any:
scope: dict[str, Any] = self.scopes[-1]
return scope[name]
def set_value(self, name: str, value: Any, force_declare: bool = False):
if not force_declare:
for scope in reversed(self.scopes):
if name in scope:
scope[name] = value
return
self.scopes[-1][name] = value
def visit_logical_expr(self, expr: m.LogicalExpr) -> Any:
def left():
return self.evaluate(expr.left)
def right():
return self.evaluate(expr.right)
match expr.operator.type:
case TokenType.AND:
return left() and right()
case _:
raise NotImplementedError
def visit_binary_expr(self, expr: m.BinaryExpr) -> Any:
left: Any = self.evaluate(expr.left)
right: Any = self.evaluate(expr.right)
match expr.operator.type:
case TokenType.MINUS:
return left - right
case TokenType.STAR:
return left * right
case TokenType.SLASH:
return left / right
case TokenType.GREATER:
return left > right
case TokenType.GREATER_EQUAL:
return left >= right
case TokenType.LESS:
return left < right
case TokenType.LESS_EQUAL:
return left <= right
case TokenType.EQUAL_EQUAL:
return left == right
case TokenType.BANG_EQUAL:
return left != right
case _:
raise NotImplementedError
def visit_unary_expr(self, expr: m.UnaryExpr) -> Any:
right: Any = self.evaluate(expr.right)
match expr.operator.type:
case TokenType.MINUS:
return -right
case _:
raise NotImplementedError
def visit_call_expr(self, expr: m.CallExpr) -> Any:
callee: Any = self.evaluate(expr.callee)
args: list[Any] = [self.evaluate(arg) for arg in expr.arguments]
kwargs: dict[str, Any] = {
name: self.evaluate(arg) for name, arg in expr.keywords.items()
}
match callee:
case Predicate():
return self._evaluate_predicate(callee, args, kwargs)
case _ if callable(callee):
return callee(*args, **kwargs)
case _:
raise NotImplementedError
def visit_get_expr(self, expr: m.GetExpr) -> Any:
obj: Any = self.evaluate(expr.expr)
return getattr(obj, expr.name.lexeme)
def visit_variable_expr(self, expr: m.VariableExpr) -> Any:
name: str = expr.name.lexeme
for scope in reversed(self.scopes):
if name in scope:
return scope[name]
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is not None:
if predicate.alias:
return self.evaluate(predicate.body)
return predicate
glob: Optional[Callable] = self.preamble.get_py_func(name)
if glob is not None:
return glob
raise NameError(f"Unknown variable '{name}'")
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Any:
return self.evaluate(expr.expr)
def visit_literal_expr(self, expr: m.LiteralExpr) -> Any:
return expr.value
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Any:
return self.get_value("_")
def _evaluate_predicate(
self, predicate: Predicate, args: list[Any], kwargs: dict[str, Any]
) -> Any:
res: Any = None
if isinstance(predicate, PartialPredicate):
self.scopes.append(predicate.scope)
else:
self.scopes.append({})
match predicate.type:
case Function(returns=Function() as inner):
self._map_args(predicate.type, args, kwargs)
res = PartialPredicate(
type=inner,
body=predicate.body,
alias=False,
scope=self.scopes[-1],
)
case Function():
self._map_args(predicate.type, args, kwargs)
res = self.evaluate(predicate.body)
case _:
raise NotImplementedError
self.scopes.pop()
return res
def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]):
positional: list[Function.Parameter] = (
function.params.pos + function.params.mixed
)
keywords: dict[str, Function.Parameter] = {
param.name: param for param in function.params.mixed + function.params.kw
}
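# Parameters are always declared in the innermost scope so that they
# shadow, rather than overwrite, same-named values in outer scopes
for i, arg in enumerate(args):
param: Function.Parameter = positional[i]
self.set_value(param.name, arg, force_declare=True)
for name, arg in kwargs.items():
param: Function.Parameter = keywords[name]
self.set_value(param.name, arg, force_declare=True)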

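The two modes of `set_value` are easier to see on a bare scope stack. The sketch below copies the lookup logic into a free function over a module-level `scopes` list (both are illustrative assumptions, not part of the evaluator). Without `force_declare`, the closest existing binding is updated. With it, the name is declared in the innermost scope and shadows the outer one.

```python
from typing import Any

# Hypothetical scope stack: index 0 is the outermost scope
scopes: list[dict[str, Any]] = [{"limit": 10}]


def set_value(name: str, value: Any, force_declare: bool = False) -> None:
    if not force_declare:
        # Update the closest scope that already binds the name
        for scope in reversed(scopes):
            if name in scope:
                scope[name] = value
                return
    # Otherwise declare the name in the innermost scope
    scopes[-1][name] = value


scopes.append({})                          # enter a new, empty scope
set_value("limit", 3)                      # updates the outer binding
set_value("limit", 5, force_declare=True)  # shadows it in the inner scope
```

After these calls the outer scope holds `3` and the inner scope holds its own `limit` of `5`.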

@@ -0,0 +1,210 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import (
ColumnGroupBy,
ColumnType,
Function,
ParamSpec,
TopType,
Type,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
groupby: ColumnGroupBy
groupby_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.groupby_expr, self.groupby)
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
NAMED_ARGS: dict[str, str] = {
"numeric_only": "bool",
"skipna": "bool",
"engine": "str",
"engine_kwargs": "dict",
}
def _aggregate(
self,
call: Call,
params: list[str | tuple[str, str, bool]] = [],
*,
preserve_inner_type: bool = False,
) -> Type:
real_params: list[Function.Parameter] = []
for i, param in enumerate(params):
match param:
case str() as name:
param = Function.Parameter(
pos=i,
name=name,
type=self.types.get_type(self.NAMED_ARGS[name]),
required=False,
)
case (name, type, required):
param = Function.Parameter(
pos=i,
name=name,
type=self.types.get_type(type),
required=required,
)
real_params.append(param)
signature = Function(
params=ParamSpec(mixed=real_params),
returns=(
call.groupby.column
if preserve_inner_type
else ColumnType(type=TopType())
),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def kurt(self, call: Call) -> Type:
return self._aggregate(
call,
["skipna", "numeric_only"],
)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
preserve_inner_type=True,
)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna", "engine", "engine_kwargs"],
)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna"],
preserve_inner_type=True,
)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
preserve_inner_type=True,
)
@method()
def prod(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
],
)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
(
"ddof",
"int",
False,
),
"engine",
"engine_kwargs",
"numeric_only",
"skipna",
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
(
"ddof",
"int",
False,
),
"engine",
"engine_kwargs",
"numeric_only",
"skipna",
],
)

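`_aggregate` accepts a shorthand for its parameters: a bare name is looked up in `NAMED_ARGS` and treated as optional, while a `(name, type, required)` tuple is spelled out in full. A minimal sketch of that expansion follows. `Param` and `expand` are hypothetical names, and type names stay as strings instead of being resolved through the types registry.

```python
from dataclasses import dataclass
from typing import Union

NAMED_ARGS: dict[str, str] = {
    "numeric_only": "bool",
    "skipna": "bool",
    "engine": "str",
    "engine_kwargs": "dict",
}


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for Function.Parameter."""

    pos: int
    name: str
    type: str
    required: bool


Shorthand = Union[str, tuple[str, str, bool]]


def expand(params: list[Shorthand]) -> list[Param]:
    expanded: list[Param] = []
    for i, param in enumerate(params):
        if isinstance(param, str):
            # Well-known name: type comes from the table, never required
            expanded.append(Param(i, param, NAMED_ARGS[param], False))
        else:
            name, type_name, required = param
            expanded.append(Param(i, name, type_name, required))
    return expanded


std_params = expand([("ddof", "int", False), "engine", "numeric_only"])
```

Positions follow the order of the list, so `std_params` carries `ddof` at position 0 and `numeric_only` at position 2.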

@@ -0,0 +1,78 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
from midas.checker.registry import TypesRegistry
from midas.checker.types import ColumnGroupBy, ColumnType, Type
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
class ColumnManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(self.typer)
self.groupby_method_resolver: ColumnGroupByMethodRegistry = (
ColumnGroupByMethodRegistry(self.typer)
)
def call(
self,
method: str,
location: Location,
call_expr: p.Expr,
column: ColumnType,
column_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
call_expr=call_expr,
column=column,
column_expr=column_expr,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)
def groupby_call(
self,
method: str,
location: Location,
call_expr: p.Expr,
groupby: ColumnGroupBy,
groupby_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: GroupByCall = GroupByCall(
location=location,
call_expr=call_expr,
groupby=groupby,
groupby_expr=groupby_expr,
positional=positional,
keywords=keywords,
)
return self.groupby_method_resolver.call(method, call)
def get_attribute(self, column: ColumnType, name: str) -> Optional[Type]:
types: TypesRegistry = self.typer.types
match name:
case "ndim" | "size":
return types.get_type("int")
case "shape":
return types.tuple_of("int")
case "T":
return column
case _:
return None


@@ -0,0 +1,400 @@
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import (
ColumnGroupBy,
ColumnType,
Function,
GenericType,
ParamSpec,
TopType,
Type,
TypeVar,
UnknownType,
unfold_type,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
column: ColumnType
column_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.column_expr, self.column)
class ColumnMethodRegistry(MethodRegistry[Call]):
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
"""Compute the result of an element-wise binary operation
This function delegates to the inner types for computing the resulting
type.
Args:
call (Call): the call that triggered this resolution
method (str): the method name
Returns:
ColumnType: the resulting column type
"""
column2: Optional[ColumnType] = None
col_type1: Type = call.column.type
new_column: Type = ColumnType(type=UnknownType())
if len(call.positional) != 0:
other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other)
if isinstance(unfolded_other, ColumnType):
column2 = unfolded_other
col_type2: Type = column2.type
new_inner_type = self.typer.result_of_binary_op(
location=call.location,
expr=call.call_expr,
left=(call.column_expr, col_type1),
right=(call.positional[0][0], col_type2),
method=method,
)
new_column = ColumnType(type=new_inner_type)
return new_column
def _element_wise(self, call: Call, method: str) -> Type:
# TODO: support add with scalar
# Build signature with new column type and generic operand
param_type: TypeVar = TypeVar(name="T", bound=None)
signature = GenericType(
name=method,
params=[param_type],
body=Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="other",
type=ColumnType(type=param_type),
required=True,
),
],
),
returns=self._element_binary_op(call, method),
),
)
# Map arguments and compute result type
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
if result.is_valid:
self._assert_same_length(
call.call_expr, call.column_expr, call.positional[0][0]
)
return result.result
@method("add", "__add__")
def add(self, call: Call) -> Type:
return self._element_wise(call, "__add__")
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
return self._element_wise(call, "__sub__")
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
return self._element_wise(call, "__mul__")
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
return self._element_wise(call, "__truediv__")
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
return self._element_wise(call, "__floordiv__")
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
return self._element_wise(call, "__mod__")
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
return self._element_wise(call, "__pow__")
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
return self._element_wise(call, "__lt__")
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
return self._element_wise(call, "__gt__")
@method("le", "__le__")
def le(self, call: Call) -> Type:
return self._element_wise(call, "__le__")
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
return self._element_wise(call, "__ge__")
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
return self._element_wise(call, "__ne__")
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
return self._element_wise(call, "__eq__")
def _aggregate(
self,
call: Call,
kwargs: list[Function.Parameter] = [],
*,
preserve_inner_type: bool = False,
) -> Type:
signature = Function(
params=ParamSpec(
kw=[
Function.Parameter(
pos=0,
name="axis",
type=TopType(),
required=False,
),
*kwargs,
],
),
returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def mode(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method("product", "prod")
def product(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Parameter(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Parameter(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def head(self, call: Call) -> Type:
signature = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
),
returns=call.column,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def tail(self, call: Call) -> Type:
signature = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
),
returns=call.column,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def groupby(self, call: Call) -> Type:
bool_: Type = self.types.get_type("bool")
function: Function = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="by",
type=TopType(),
required=False,
),
Function.Parameter(
pos=1,
name="level",
type=TopType(),
required=False,
),
],
kw=[
Function.Parameter(
pos=i + 2,
name=name,
type=bool_,
required=False,
)
for i, name in enumerate(
["as_index", "sort", "group_keys", "observed", "dropna"]
)
],
),
returns=ColumnGroupBy(column=call.column),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=function,
positional=call.positional,
keywords=call.keywords,
)
return result.result
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
func_name: str = "__midas_column_same_length__"
# Efficiently compute length
# https://stackoverflow.com/a/15943975/11109181
def len_of_col(col: ast.expr) -> ast.expr:
return ast.Call(
func=ast.Name(id="len"),
args=[
ast.Attribute(
value=col,
attr="index",
)
],
keywords=[],
)
self.assertions.define(
func_name,
ast.FunctionDef(
name=func_name,
args=ast.arguments(
posonlyargs=[],
args=[
ast.arg(arg="column1"),
ast.arg(arg="column2"),
],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Return(
value=ast.Compare(
left=len_of_col(ast.Name(id="column1")),
ops=[ast.Eq()],
comparators=[
len_of_col(ast.Name(id="column2")),
],
)
)
],
decorator_list=[],
),
)
self.assertions.add(
bound_expr=call_expr,
inputs=[column1, column2],
builder=lambda c1, c2: ast.Call(
func=ast.Name(id=func_name),
args=[c1, c2],
keywords=[],
),
message="Columns must have the same length",
)

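The runtime check generated by `_assert_same_length` compares `len(column1.index)` with `len(column2.index)`. The sketch below builds the same comparison with the standard `ast` module and evaluates it directly. It is a simplified assumption-laden version: it compiles a bare expression instead of defining a helper function, and `SimpleNamespace` objects stand in for real columns.

```python
import ast
from types import SimpleNamespace


def len_of_index(name: str) -> ast.expr:
    # Builds the expression `len(<name>.index)`
    return ast.Call(
        func=ast.Name(id="len", ctx=ast.Load()),
        args=[
            ast.Attribute(
                value=ast.Name(id=name, ctx=ast.Load()),
                attr="index",
                ctx=ast.Load(),
            )
        ],
        keywords=[],
    )


# len(column1.index) == len(column2.index)
tree = ast.Expression(
    body=ast.Compare(
        left=len_of_index("column1"),
        ops=[ast.Eq()],
        comparators=[len_of_index("column2")],
    )
)
ast.fix_missing_locations(tree)  # hand-built nodes carry no line numbers
check = compile(tree, "<same_length>", "eval")


def same_length(column1: object, column2: object) -> bool:
    return eval(check, {"column1": column1, "column2": column2})


short = SimpleNamespace(index=range(3))
also_short = SimpleNamespace(index=range(3))
long = SimpleNamespace(index=range(5))
```

`same_length(short, also_short)` holds, while `same_length(short, long)` does not.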

@@ -0,0 +1,103 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import (
ColumnGroupBy,
ColumnType,
DataFrameType,
FrameGroupBy,
Type,
UnknownType,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
groupby: FrameGroupBy
groupby_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.groupby_expr, self.groupby)
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
NAMED_ARGS: dict[str, str] = {
"numeric_only": "bool",
"skipna": "bool",
"engine": "str",
"engine_kwargs": "dict",
}
def _aggregate(self, call: Call, method: str) -> Type:
new_columns: list[DataFrameType.Column] = []
for column in call.groupby.frame.columns:
column_groupby: ColumnGroupBy = ColumnGroupBy(column=column.type)
result_type: Type = self.typer.call_method(
location=call.location,
call_expr=call.call_expr,
obj=(call.groupby_expr, column_groupby),
method_name=method,
positional=call.positional,
keywords=call.keywords,
)
if not isinstance(result_type, ColumnType):
result_type = ColumnType(type=UnknownType())
new_columns.append(
DataFrameType.Column(
index=column.index,
name=column.name,
type=result_type,
)
)
return DataFrameType(columns=new_columns)
@method()
def kurt(self, call: Call) -> Type:
return self._aggregate(call, "kurt")
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call, "max")
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call, "mean")
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call, "median")
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call, "min")
@method()
def prod(self, call: Call) -> Type:
return self._aggregate(call, "prod")
@method()
def std(self, call: Call) -> Type:
return self._aggregate(call, "std")
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call, "sum")
@method()
def var(self, call: Call) -> Type:
return self._aggregate(call, "var")

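Frame-level group-by aggregation is resolved column by column: each column is wrapped in its own group-by and the column-level rule decides the resulting type. A rough sketch of that delegation, under the assumption that a schema is a list of `(name, inner type)` pairs and that `"Any"` stands in for `TopType`. The set of type-preserving methods mirrors the column group-by registry (`max`, `min`, `median`).

```python
# Methods whose result keeps the column's inner type
PRESERVING: set[str] = {"max", "min", "median"}


def aggregate_column(inner: str, method: str) -> str:
    # Hypothetical column-level rule: keep the type or widen it to "Any"
    return inner if method in PRESERVING else "Any"


def aggregate_frame(
    schema: list[tuple[str, str]], method: str
) -> list[tuple[str, str]]:
    # Delegate to every column, keeping names and order unchanged
    return [(name, aggregate_column(inner, method)) for name, inner in schema]


people = [("age", "int"), ("name", "str")]
maxima = aggregate_frame(people, "max")
means = aggregate_frame(people, "mean")
```

Column names and order are untouched; only the inner types change.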

@@ -0,0 +1,255 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
ColumnGroupBy,
ColumnType,
DataFrameType,
FrameGroupBy,
TupleType,
Type,
UnknownType,
)
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
return all(isinstance(expr, p.LiteralExpr) for expr in exprs)
class FrameManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
self.groupby_method_resolver: FrameGroupByMethodRegistry = (
FrameGroupByMethodRegistry(self.typer)
)
def assign(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
value_type: Type,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
return self.assign_column(reporter, location, frame, name, value_type)
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(index.value, str) for index in indices
):
names: list[str] = [cast(str, index.value) for index in indices]
if not isinstance(value_type, TupleType):
reporter.error(
location,
f"Cannot assign {value_type} to dataframe columns. Must be a tuple of columns",
)
return UnknownType()
if len(names) != len(value_type.items):
reporter.error(
location,
f"Wrong number of columns. Cannot assign {len(value_type.items)} to {len(names)} targets",
)
return UnknownType()
new_frame: Type = frame
for name, value in zip(names, value_type.items):
new_frame = self.assign_column(
reporter,
location,
new_frame,
name,
value,
)
if not isinstance(new_frame, DataFrameType):
return new_frame
return new_frame
case _:
reporter.error(
location, f"Invalid index type {index} on {frame} (assignment)"
)
return UnknownType()
def assign_column(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
name: str,
type: Type,
) -> Type:
if not isinstance(type, ColumnType):
reporter.error(
location,
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
)
return frame
return self._set_column(frame, name, type)
def get(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
return column
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(index.value, str) for index in indices
):
names: list[str] = [cast(str, index.value) for index in indices]
columns: list[ColumnType] = []
for name in names:
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
columns.append(column)
return TupleType(items=tuple(columns))
case _:
reporter.error(
location, f"Invalid index type {index} on {frame} (access)"
)
return UnknownType()
def groupby_get(
self,
reporter: FileReporter,
location: Location,
groupby: FrameGroupBy,
index: p.Expr,
) -> Type:
result: Type = self.get(reporter, location, groupby.frame, index)
match result:
case ColumnType():
result = ColumnGroupBy(column=result)
case TupleType(items=columns):
result = TupleType(
items=tuple(
ColumnGroupBy(column=cast(ColumnType, column))
for column in columns
)
)
return result
@classmethod
def _set_column(
cls, frame: DataFrameType, name: str, column: ColumnType
) -> DataFrameType:
new_columns: list[DataFrameType.Column] = []
index: int = len(frame.columns)
replace: bool = False
for i, col in enumerate(frame.columns):
if col.name == name:
index = i
replace = True
# TODO: check column type here to prevent changing it
new_columns.append(col)
new_col: DataFrameType.Column = DataFrameType.Column(
index=index,
name=name,
type=column,
)
if replace:
new_columns[index] = new_col
else:
new_columns.append(new_col)
return DataFrameType(columns=new_columns)
@classmethod
def _set_columns(
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
) -> DataFrameType:
for name, col in zip(names, columns):
frame = cls._set_column(frame, name, col)
return frame
@classmethod
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
for col in frame.columns:
if col.name == name:
return col.type
return None
@classmethod
def _get_columns(
cls, frame: DataFrameType, names: list[str]
) -> list[Optional[ColumnType]]:
return [cls._get_column(frame, name) for name in names]
def call(
self,
method: str,
location: Location,
call_expr: p.Expr,
frame: DataFrameType,
frame_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
call_expr=call_expr,
frame=frame,
frame_expr=frame_expr,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)
def groupby_call(
self,
method: str,
location: Location,
call_expr: p.Expr,
groupby: FrameGroupBy,
groupby_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: GroupByCall = GroupByCall(
location=location,
call_expr=call_expr,
groupby=groupby,
groupby_expr=groupby_expr,
positional=positional,
keywords=keywords,
)
return self.groupby_method_resolver.call(method, call)
def get_attribute(self, frame: DataFrameType, name: str) -> Optional[Type]:
types: TypesRegistry = self.typer.types
match name:
case "ndim" | "size":
return types.get_type("int")
case "shape":
return types.tuple_of("int", "int")
case _:
return None

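`_set_column` never mutates the frame type. It returns a new schema in which the named column is either replaced in place or appended at the end. A compact sketch of that replace-or-append behaviour, with a hypothetical `Column` record whose type is a plain string:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Column:
    """Hypothetical stand-in for DataFrameType.Column."""

    index: int
    name: str
    type: str


def set_column(
    columns: tuple[Column, ...], name: str, type_: str
) -> tuple[Column, ...]:
    for i, col in enumerate(columns):
        if col.name == name:
            # Replace in place, keeping the column's position
            return columns[:i] + (Column(i, name, type_),) + columns[i + 1:]
    # Unknown name: append a new column at the end
    return columns + (Column(len(columns), name, type_),)


schema = (Column(0, "age", "int"), Column(1, "name", "str"))
replaced = set_column(schema, "age", "float")
appended = set_column(schema, "city", "str")
```

The original `schema` tuple is left unchanged by both calls.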

@@ -0,0 +1,479 @@
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import (
ColumnType,
DataFrameType,
FrameGroupBy,
Function,
OverloadedFunction,
ParamSpec,
TopType,
Type,
UnknownType,
unfold_type,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
frame: DataFrameType
frame_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.frame_expr, self.frame)
class FrameMethodRegistry(MethodRegistry[Call]):
def _get_method_result(
self,
call: Call,
column1: ColumnType,
column2: ColumnType,
method: str,
) -> ColumnType:
"""Get the result of calling a method on a column, passing a second column as operand
This function delegates to the main typer the resolution of the method
member, as well as computing the result type. Because we don't have any
AST expression for the individual columns, the frame expressions are
used instead.
Args:
call (Call): the call that triggered this resolution
column1 (ColumnType): the first column, i.e. left operand
column2 (ColumnType): the second column, i.e. right operand
method (str): the method name
Returns:
ColumnType: the resulting column.
If the operation is invalid / doesn't exist,
`ColumnType(type=UnknownType())` is returned
"""
result: Type = self.typer.result_of_binary_op(
location=call.location,
expr=call.call_expr,
left=(call.frame_expr, column1),
right=(call.positional[0][0], column2),
method=method,
)
if not isinstance(result, ColumnType):
return ColumnType(type=UnknownType())
return result
def _element_binary_op(self, call: Call, method: str) -> DataFrameType:
"""Compute the result of an element-wise binary operation
This function delegates to the matching columns for computing resulting
types. Any column only present in one of the frames is forwarded as a
generic `ColumnType(type=UnknownType())`. Columns only in the second
frame are appended at the end of the schema.
Args:
call (Call): the call that triggered this resolution
method (str): the method name
Returns:
DataFrameType: the resulting frame type
"""
new_columns: list[DataFrameType.Column] = []
by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
# Map the operand's columns by name, provided there is at least one operand and it is a dataframe
if len(call.positional) != 0:
operand: TypedExpr = call.positional[0]
unfolded_other: Type = unfold_type(operand[1])
if isinstance(unfolded_other, DataFrameType):
frame2 = unfolded_other
by_name = {
col.name: col for col in frame2.columns if col.name is not None
}
# Compute new schema:
# Step 1: for all columns in frame1:
# - if present in frame2 -> delegate operation to columns
# - if not -> add to schema as unknown
in_frame1: set[str] = set()
for column in call.frame.columns:
if column.name is not None:
in_frame1.add(column.name)
col_type1: ColumnType = column.type
col_type: ColumnType = ColumnType(type=UnknownType())
if column.name in by_name:
column2 = by_name[column.name]
col_type2: ColumnType = column2.type
col_type = self._get_method_result(call, col_type1, col_type2, method)
new_column = DataFrameType.Column(
index=column.index,
name=column.name,
type=col_type,
)
new_columns.append(new_column)
# Step 2: for all columns in frame2
# - if not in frame1 -> add to schema as unknown
if frame2 is not None:
for column in frame2.columns:
if column.name in in_frame1:
continue
new_columns.append(
DataFrameType.Column(
index=len(new_columns),
name=column.name,
type=ColumnType(type=UnknownType()),
)
)
return DataFrameType(columns=new_columns)
def _element_wise(self, call: Call, method: str) -> Type:
# TODO: support scalar, sequence, Series, dict operand
# Build signature with new schema and generic operand
signature = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="other",
type=DataFrameType(columns=[]),
required=True,
),
],
),
returns=self._element_binary_op(call, method),
)
# Map arguments and compute result type
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
if result.is_valid:
self._assert_same_length(
call.call_expr, call.frame_expr, call.positional[0][0]
)
return result.result
@method("add", "__add__")
def add(self, call: Call) -> Type:
return self._element_wise(call, "__add__")
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
return self._element_wise(call, "__sub__")
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
return self._element_wise(call, "__mul__")
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
return self._element_wise(call, "__truediv__")
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
return self._element_wise(call, "__floordiv__")
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
return self._element_wise(call, "__mod__")
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
return self._element_wise(call, "__pow__")
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
return self._element_wise(call, "__lt__")
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
return self._element_wise(call, "__gt__")
@method("le", "__le__")
def le(self, call: Call) -> Type:
return self._element_wise(call, "__le__")
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
return self._element_wise(call, "__ge__")
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
return self._element_wise(call, "__ne__")
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
return self._element_wise(call, "__eq__")
def _aggregate(self, call: Call, kwargs: list[Function.Parameter] = []) -> Type:
with_axis = Function(
params=ParamSpec(
kw=[
Function.Parameter(
pos=0,
name="axis",
type=self.types.get_type("int"),
required=False,
),
*kwargs,
],
),
returns=ColumnType(type=TopType()),
)
without_axis = Function(
params=ParamSpec(
kw=[
Function.Parameter(
pos=0,
name="axis",
type=self.types.get_type("None"),
required=True,
),
*kwargs,
],
),
returns=TopType(),
)
overload = OverloadedFunction(
overloads=[
with_axis,
without_axis,
]
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=overload,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def mode(self, call: Call) -> Type:
return self._aggregate(call)
@method("product", "prod")
def product(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Parameter(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Parameter(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def head(self, call: Call) -> Type:
signature = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
),
returns=call.frame,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def tail(self, call: Call) -> Type:
signature = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
),
returns=call.frame,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def groupby(self, call: Call) -> Type:
bool_: Type = self.types.get_type("bool")
function: Function = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="by",
type=TopType(),
required=False,
),
Function.Parameter(
pos=1,
name="level",
type=TopType(),
required=False,
),
],
kw=[
Function.Parameter(
pos=i + 2,
name=name,
type=bool_,
required=False,
)
for i, name in enumerate(
["as_index", "sort", "group_keys", "observed", "dropna"]
)
],
),
returns=FrameGroupBy(frame=call.frame),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=function,
positional=call.positional,
keywords=call.keywords,
)
return result.result
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
func_name: str = "__midas_frame_same_length__"
# Efficiently compute length
# https://stackoverflow.com/a/15943975/11109181
def len_of_df(df: ast.expr) -> ast.expr:
return ast.Call(
func=ast.Name(id="len"),
args=[
ast.Attribute(
value=df,
attr="index",
)
],
keywords=[],
)
self.assertions.define(
func_name,
ast.FunctionDef(
name=func_name,
args=ast.arguments(
posonlyargs=[],
args=[
ast.arg(arg="frame1"),
ast.arg(arg="frame2"),
],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Return(
value=ast.Compare(
left=len_of_df(ast.Name(id="frame1")),
ops=[ast.Eq()],
comparators=[len_of_df(ast.Name(id="frame2"))],
)
)
],
decorator_list=[],
),
)
self.assertions.add(
bound_expr=call_expr,
inputs=[frame1, frame2],
builder=lambda f1, f2: ast.Call(
func=ast.Name(id=func_name),
args=[f1, f2],
keywords=[],
),
message="DataFrames must have the same length",
)
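The runtime check assembled by `_assert_same_length` can be tried in isolation. The following is a minimal sketch, not the generator's actual output: it builds the same comparison with the standard `ast` module, compiles it, and calls it. The function name `same_length` and the `SimpleNamespace` stand-ins for dataframes are assumptions made for illustration; the explicit `ctx=ast.Load()` arguments and the `fix_missing_locations` call are added here because the tree is compiled directly.

```python
import ast
from types import SimpleNamespace


def len_of_df(df: ast.expr) -> ast.expr:
    # Build the expression `len(<df>.index)`
    return ast.Call(
        func=ast.Name(id="len", ctx=ast.Load()),
        args=[ast.Attribute(value=df, attr="index", ctx=ast.Load())],
        keywords=[],
    )


# Hypothetical name, standing in for "__midas_frame_same_length__"
func = ast.FunctionDef(
    name="same_length",
    args=ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg="frame1"), ast.arg(arg="frame2")],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
    ),
    body=[
        ast.Return(
            value=ast.Compare(
                left=len_of_df(ast.Name(id="frame1", ctx=ast.Load())),
                ops=[ast.Eq()],
                comparators=[len_of_df(ast.Name(id="frame2", ctx=ast.Load()))],
            )
        )
    ],
    decorator_list=[],
)

# Line/column information is required before compiling a hand-built tree
module = ast.fix_missing_locations(ast.Module(body=[func], type_ignores=[]))
namespace: dict = {}
exec(compile(module, "<generated>", "exec"), namespace)
same_length = namespace["same_length"]

# Stand-ins for dataframes: any object with an `index` attribute works here
short = SimpleNamespace(index=[0, 1])
also_short = SimpleNamespace(index=[5, 6])
long = SimpleNamespace(index=[0, 1, 2])
equal_lengths = same_length(short, also_short)
different_lengths = same_length(short, long)
```

Going through `len(df.index)` rather than `len(df)` follows the Stack Overflow answer linked in the source comment.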


@@ -0,0 +1,100 @@
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
Protocol,
Self,
TypeVar,
)
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallDispatcher
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import Type, UnknownType
from midas.generator.collector import AssertionCollector
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
class _MethodRegistryMeta(type):
_methods: dict[str, Callable[..., Type]] = {}
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
):
new_class = super().__new__(cls, name, bases, namespace)
new_class._methods = {}
for attr in namespace.values():
if callable(attr) and hasattr(attr, "__method_names__"):
for method_name in attr.__method_names__: # type: ignore
new_class._methods[method_name] = attr # type: ignore
return new_class
class MethodCall(Protocol):
@property
def location(self) -> Location: ...
@property
def call_expr(self) -> p.Expr: ...
@property
def subject(self) -> TypedExpr: ...
T = TypeVar("T", bound=MethodCall)
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@property
def reporter(self) -> FileReporter:
return self.typer.reporter
@property
def types(self) -> TypesRegistry:
return self.typer.types
@property
def dispatcher(self) -> CallDispatcher[p.Expr]:
return self.typer.dispatcher
@property
def assertions(self) -> AssertionCollector:
return self.typer.assertions
def call(self, method: str, call: T) -> Type:
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
if func is None:
self.reporter.warning(
call.location, f"Unknown method {method} on {call.subject[1]}"
)
return UnknownType()
return func(self, call)
_Self = TypeVar("_Self", bound=MethodRegistry[Any])
Method = Callable[[_Self, T], Type]
def method(*names: str) -> Callable[[Method[_Self, T]], Method[_Self, T]]:
def wrapper(func: Method[_Self, T]) -> Method[_Self, T]:
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper
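The interplay between the `method` decorator and the registry metaclass can be reproduced in a few lines. The sketch below is a simplified, self-contained imitation, not the project's classes: `RegistryMeta`, `Registry`, and `DemoMethods` are hypothetical names, and plain strings stand in for `Type` and call objects.

```python
from typing import Any, Callable


def method(*names: str) -> Callable[[Callable[..., str]], Callable[..., str]]:
    """Tag a function with the public names it should be registered under."""

    def wrapper(func: Callable[..., str]) -> Callable[..., str]:
        # Fall back to the function's own name when no alias is given
        func.__method_names__ = names or (func.__name__,)
        return func

    return wrapper


class RegistryMeta(type):
    def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any]):
        new_class = super().__new__(cls, name, bases, namespace)
        # Each class gets its own table, built from the tagged functions in its body
        new_class._methods = {
            alias: attr
            for attr in namespace.values()
            if callable(attr) and hasattr(attr, "__method_names__")
            for alias in attr.__method_names__
        }
        return new_class


class Registry(metaclass=RegistryMeta):
    def call(self, method_name: str, arg: str) -> str:
        func = self._methods.get(method_name)
        if func is None:
            return "unknown"
        return func(self, arg)


class DemoMethods(Registry):
    @method("product", "prod")
    def product(self, arg: str) -> str:
        return f"product({arg})"

    @method()
    def head(self, arg: str) -> str:
        return f"head({arg})"


demo = DemoMethods()
results = [demo.call(n, "df") for n in ("prod", "product", "head", "tail")]
```

As in the original, only functions defined in the class body are collected, so aliases registered on a parent class are not copied into a subclass's table.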

431
midas/checker/midas.py Normal file

@@ -0,0 +1,431 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import midas.ast.midas as m
from midas.ast.location import Location
from midas.checker.builtins import define_builtins
from midas.checker.dispatcher import CallDispatcher, CallResult
from midas.checker.environment import Environment
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
Function,
GenericType,
ParamSpec,
Predicate,
Type,
TypeVar,
UnknownType,
)
from midas.checker.variance import VarianceInferrer
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
class ReturnException(Exception):
pass
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
expr: m.Expr
type: Type
argument: Function.Parameter
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
function: Function
mapped: list[MappedArgument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and builds a registry"""
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
self.logger: logging.Logger = logging.getLogger("MidasTyper")
self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types
self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr](
self.types, self.reporter
)
self._local_variables: dict[str, TypeVar] = {}
self._predicate_params: dict[str, Type] = {}
self._current_name: Optional[str] = None
define_builtins(self.types)
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
self.process(builtins_path.read_text(), str(builtins_path))
self._bool: Type = self.get_type("bool")
self._preamble: Environment = Preamble(self.types)
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
self.dispatcher.set_reporter(reporter)
def process(self, source: str, path: Optional[str]):
reporter: FileReporter = self.reporter.for_file(path)
self.set_reporter(reporter)
lexer: MidasLexer = MidasLexer(source)
tokens: list[Token] = lexer.process()
parser: MidasParser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
for error in parser.errors:
self.reporter.error(error.token.get_location(), error.message)
self.resolve(stmts)
def type_of(self, expr: m.Expr) -> Type:
type: Type = expr.accept(self)
return type
def get_type(self, name: str) -> Type:
"""Get a type from its name
Args:
name (str): the name of the type
Raises:
NameError: if the type is not defined
Returns:
Type: the type
"""
if name in self._local_variables:
return self._local_variables[name]
return self.types.get_type(name)
def get_variable(self, name: str) -> Type:
if name in self._predicate_params:
return self._predicate_params[name]
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is not None:
return predicate.type
global_: Optional[Type] = self._preamble.get(name)
if global_ is not None:
return global_
raise NameError(f"Unknown variable '{name}'")
def resolve(self, stmts: list[m.Stmt]):
"""Process a sequence of statements
Args:
stmts (list[m.Stmt]): the statements
"""
for stmt in stmts:
stmt.accept(self)
for name, type in self.types._types.items():
if isinstance(type, GenericType):
inferrer = VarianceInferrer(self.types)
self.types._types[name] = inferrer.infer(type)
def assert_bool(self, expr: m.Expr):
type: Type = self.type_of(expr)
if not self.types.is_subtype(type, self._bool):
self.reporter.error(expr.location, f"Must be a boolean but is {type}")
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
name: str = stmt.name.lexeme
self._current_name = name
params: list[TypeVar] = self._resolve_type_params(stmt.params)
type: Type = stmt.type.accept(self)
if len(params) != 0:
type = GenericType(name=name, params=params, body=type)
else:
type = DerivedType(name=name, type=type)
self.types.define_type(name, type)
self._local_variables.clear()
self._current_name = None
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
name: str = stmt.name.lexeme
self._current_name = name
type: Type = stmt.type.accept(self)
self.types.define_type(name, type)
self._current_name = None
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._resolve_type_params(stmt.params)
base_name: str = stmt.name.lexeme
try:
_ = self.get_type(base_name)
except NameError:
self.reporter.error(stmt.name.get_location(), f"Unknown type '{base_name}'")
for member in stmt.members:
member_type: Type = member.type.accept(self)
self.types.define_member(
base_name,
member.name.lexeme,
member_type,
member.kind,
)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
for spec in stmt.params:
for param in spec.mixed:
assert param.name is not None
self._predicate_params[param.name.lexeme] = param.type.accept(self)
type: Type = self.type_of(stmt.body)
params: list[ParamSpec] = [self._visit_param_spec(spec) for spec in stmt.params]
if not self._is_valid_predicate(type):
self.reporter.error(
stmt.body.location,
f"Predicate function body must evaluate to a boolean, got {type}",
)
if len(params) != 0:
type = self._bool
for spec in reversed(params):
type = Function(
params=spec,
returns=type,
)
self._predicate_params = {}
self.types.define_predicate(
stmt.name.lexeme,
Predicate(
type=type,
body=stmt.body,
alias=len(params) == 0,
),
)
def _is_valid_predicate(self, body: Type) -> bool:
match body:
case Function(returns=returns):
return self._is_valid_predicate(returns)
case _ if self.types.is_subtype(body, self._bool):
return True
case _:
return False
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
self.assert_bool(expr.left)
self.assert_bool(expr.right)
return self._bool
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
self.reporter.warning(
expr.location, f"Unsupported operator {expr.operator.lexeme}"
)
return UnknownType()
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
def _visit_binary_expr(
self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str
) -> Type:
left: Type = self.type_of(left_expr)
right: Type = self.type_of(right_expr)
operation: Optional[Type] = self.types.lookup_member(left, method)
if operation is None:
self.reporter.error(
location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
result: CallResult = self.dispatcher.get_result(
location=location,
callee=operation,
positional=[(right_expr, right)],
keywords={},
)
return result.result
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
self.reporter.warning(
expr.location, f"Unsupported operator {expr.operator.lexeme}"
)
return UnknownType()
operand: Type = self.type_of(expr.right)
operation: Optional[Type] = self.types.lookup_member(operand, method)
if operation is None:
self.reporter.error(
expr.location,
f"Undefined operation {method} for {operand}",
)
return UnknownType()
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=operation,
positional=[],
keywords={},
)
return result.result
def visit_call_expr(self, expr: m.CallExpr) -> Type:
callee: Type = expr.callee.accept(self)
positional: list[tuple[m.Expr, Type]] = [
(arg, self.type_of(arg)) for arg in expr.arguments
]
keywords: dict[str, tuple[m.Expr, Type]] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=callee,
positional=positional,
keywords=keywords,
)
return result.result
def visit_get_expr(self, expr: m.GetExpr) -> Type:
object: Type = expr.expr.accept(self)
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
if member is None:
self.reporter.error(
expr.location, f"Unknown member '{expr.name.lexeme}' of {object}"
)
return UnknownType()
return member
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
return self.get_variable(expr.name.lexeme)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
return expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
match expr.value:
case bool(): # Must come before int: bool is a subclass of int in Python
return self.types.get_type("bool")
case int():
return self.types.get_type("int")
case float():
return self.types.get_type("float")
case str():
return self.types.get_type("str")
case _:
self.reporter.warning(expr.location, f"Unknown literal {expr}")
return UnknownType()
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
return self.get_variable("_")
def visit_named_type(self, type: m.NamedType) -> Type:
name: str = type.name.lexeme
try:
return self.get_type(name)
except NameError:
msg: str = f"Undefined type {name}"
if self._current_name == name:
msg += ". Recursive types are not supported, use an extend block"
self.reporter.error(type.name.get_location(), msg)
return UnknownType()
def visit_generic_type(self, type: m.GenericType) -> Type:
type_: Type = type.type.accept(self)
args: list[Type] = [arg.accept(self) for arg in type.args]
try:
return self.types.apply_generic(type_, args)
except Exception as e:
self.reporter.error(type.location, f"Cannot apply generic type: {e}")
return UnknownType()
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
return ConstraintType(
type=type.type.accept(self),
constraint=type.constraint,
)
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
return ComplexType(
members={
member.name.lexeme: member.type.accept(self) for member in type.members
}
)
def visit_extension_type(self, type: m.ExtensionType) -> Type:
return ExtensionType(
base=type.base.accept(self),
extension=self.visit_complex_type(type.extension),
)
def visit_function_type(self, type: m.FunctionType) -> Type:
return Function(
params=self._visit_param_spec(type.params),
returns=type.returns.accept(self),
)
def _visit_param_spec(self, spec: m.ParamSpec) -> ParamSpec:
n_pos: int = len(spec.pos)
n_mixed: int = len(spec.mixed)
def process_param(
param: m.FunctionType.Parameter, i: int
) -> Function.Parameter:
return Function.Parameter(
pos=i,
name=param.name.lexeme if param.name is not None else str(i),
type=param.type.accept(self),
required=param.required,
)
return ParamSpec(
pos=[process_param(param, i) for i, param in enumerate(spec.pos)],
mixed=[
process_param(param, i + n_pos) for i, param in enumerate(spec.mixed)
],
kw=[
process_param(param, i + n_pos + n_mixed)
for i, param in enumerate(spec.kw)
],
)
def visit_frame_type(self, type: m.FrameType) -> Type:
def process_column(i: int, col: m.FrameType.Column) -> DataFrameType.Column:
return DataFrameType.Column(
index=i,
name=col.name.lexeme,
type=ColumnType(type=col.type.accept(self)),
)
return DataFrameType(
columns=[process_column(i, col) for i, col in enumerate(type.columns)]
)
def _resolve_type_params(self, params: list[m.TypeParam]):
vars: list[TypeVar] = []
for param in params:
name: str = param.name.lexeme
bound: Optional[Type] = None
if param.bound is not None:
bound = param.bound.accept(self)
var = TypeVar(name=name, bound=bound)
self._local_variables[name] = var
vars.append(var)
return vars
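`visit_predicate_stmt` turns a predicate with several parameter lists into nested function types by folding the lists from right to left. Below is a minimal sketch of that folding, assuming a simplified `Fn` dataclass in place of the checker's `Function` and plain strings in place of its types; the parameter names are made up for the example.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Fn:
    """Simplified stand-in for a function type."""

    params: tuple[str, ...]
    returns: Union["Fn", str]


def curry(param_lists: list[list[str]], result: str = "bool") -> Union[Fn, str]:
    # Start from the innermost result and wrap it once per parameter list,
    # last list first, so that the first list ends up outermost
    type_: Union[Fn, str] = result
    for params in reversed(param_lists):
        type_ = Fn(params=tuple(params), returns=type_)
    return type_


# Two parameter lists: (low, high) then (value)
curried = curry([["low", "high"], ["value"]])
# No parameter list at all: the predicate is just a boolean alias
alias = curry([])
```

With no parameter list the result is the bare result type, which matches the `alias=len(params) == 0` flag passed to `Predicate` in the source.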


@@ -0,0 +1,71 @@
import ast
from typing import Type
from midas.lexer.token import TokenType
PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = {
ast.Add: "__add__",
ast.Sub: "__sub__",
ast.Mult: "__mul__",
ast.MatMult: "__matmul__",
ast.Div: "__truediv__",
ast.Mod: "__mod__",
ast.Pow: "__pow__",
ast.LShift: "__lshift__",
ast.RShift: "__rshift__",
ast.BitOr: "__or__",
ast.BitXor: "__xor__",
ast.BitAnd: "__and__",
ast.FloorDiv: "__floordiv__",
}
PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
ast.Eq: "__eq__",
ast.NotEq: "__ne__",
ast.Lt: "__lt__",
ast.LtE: "__le__",
ast.Gt: "__gt__",
ast.GtE: "__ge__",
# ast.Is: "__is__",
# ast.IsNot: "__isnot__",
# ast.In: "__in__",
# ast.NotIn: "__notin__",
}
PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
ast.Invert: "__invert__",
# ast.Not: "",
ast.UAdd: "__pos__",
ast.USub: "__neg__",
}
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
TokenType.PLUS: "__add__",
TokenType.MINUS: "__sub__",
TokenType.STAR: "__mul__",
TokenType.SLASH: "__truediv__",
# TokenType.MODULO: "__mod__",
# TokenType.POW: "__pow__",
# ast.BitOr: "__or__",
# ast.BitXor: "__xor__",
# ast.BitAnd: "__and__",
# ast.FloorDiv: "__floordiv__",
TokenType.EQUAL_EQUAL: "__eq__",
TokenType.BANG_EQUAL: "__ne__",
TokenType.LESS: "__lt__",
TokenType.LESS_EQUAL: "__le__",
TokenType.GREATER: "__gt__",
TokenType.GREATER_EQUAL: "__ge__",
# ast.Is: "__is__",
# ast.IsNot: "__isnot__",
# ast.In: "__in__",
# ast.NotIn: "__notin__",
}
MIDAS_UNARY_METHODS: dict[TokenType, str] = {
# ast.Invert: "__invert__",
# ast.Not: "",
# TokenType.PLUS: "__pos__",
TokenType.MINUS: "__neg__",
}
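Tables such as `PY_OPERATOR_METHODS` are keyed by the AST node class of the operator, so resolving an operator to its dunder method is a single dictionary lookup on `type(node.op)`. A small sketch with a reduced table follows; `OPERATOR_METHODS` and `method_for` are hypothetical names used only for this example.

```python
import ast
from typing import Optional

# Reduced copy of the idea: operator node class -> dunder method name
OPERATOR_METHODS: dict[type[ast.operator], str] = {
    ast.Add: "__add__",
    ast.Sub: "__sub__",
    ast.Mult: "__mul__",
    ast.Div: "__truediv__",
    ast.FloorDiv: "__floordiv__",
}


def method_for(source: str) -> Optional[str]:
    """Return the dunder method behind a binary expression, if it is in the table."""
    node = ast.parse(source, mode="eval").body
    if not isinstance(node, ast.BinOp):
        return None
    # The operator is an instance (e.g. ast.Add()), the table is keyed by its class
    return OPERATOR_METHODS.get(type(node.op))


floor_division = method_for("a // b")
not_in_table = method_for("a @ b")
not_binary = method_for("a")
```

Returning `None` for a missing entry mirrors how the checker uses `.get(...)` and then reports an unsupported operator.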

213
midas/checker/preamble.py Normal file

@@ -0,0 +1,213 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional
from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
Function,
GenericType,
OverloadedFunction,
ParamSpec,
TopType,
Type,
TypeVar,
UnitType,
)
@dataclass(frozen=True)
class Param:
name: str
type: Type
required: bool = True
class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None:
super().__init__()
self._types: TypesRegistry = types
self._python_funcs: dict[str, Callable[..., Any]] = {}
self._def_type_constructor("object", object)
self._def_type_constructor("float", float)
self._def_type_constructor("int", int)
self._def_type_constructor("bool", bool)
self._def_type_constructor("str", str)
self._def_function(
name="list",
pos=[Param("object", TopType())],
returns=self._list_of(TopType()),
py_function=list,
)
# TODO: use sink
self._def_function(
name="print",
pos=[Param("object", TopType(), required=False)],
returns=UnitType(),
py_function=print,
)
map_in = TypeVar(name="T", bound=None)
map_out = TypeVar(name="U", bound=None)
mapper = self._make_function(
name="MapTransform",
pos=[Param("v", map_in)],
returns=map_out,
)
self._def_function(
name="map",
pos=[
Param("transform", mapper),
Param(
"iterable",
self._list_of(map_in), # TODO: replace with Iterable[T]
),
],
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
type_vars=[map_in, map_out],
py_function=map,
)
self._def_function(
name="input",
pos=[Param("prompt", TopType(), required=False)],
returns=self._types.get_type("str"),
)
self._def_function(
name="len",
pos=[Param("object", TopType())],
returns=self._types.get_type("int"),
)
T = TypeVar(name="T", bound=None)
self._def_overloads(
name="max",
py_function=max,
signatures=[
(
[Param("arg1", T), Param("arg2", T)],
[],
[],
T,
[T],
),
([Param("iterable", self._list_of(T))], [], [], T, [T]),
],
)
self._def_overloads(
name="min",
py_function=min,
signatures=[
(
[Param("arg1", T), Param("arg2", T)],
[],
[],
T,
[T],
),
([Param("iterable", self._list_of(T))], [], [], T, [T]),
],
)
def _list_of(self, item_type: str | Type) -> Type:
return self._types.list_of(item_type)
def _def_type_constructor(
self, name: str, py_function: Optional[Callable[..., Any]] = None
):
# TODO: more specific arg types
self._def_function(
name=name,
pos=[Param("object", TopType(), required=False)],
returns=self._types.get_type(name),
py_function=py_function,
)
def _make_function(
self,
*,
name: str,
pos: list[Param] = [],
mixed: list[Param] = [],
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
) -> Type:
def map_params(params: list[Param], offset: int) -> list[Function.Parameter]:
return [
Function.Parameter(
pos=i + offset,
name=param.name,
type=param.type,
required=param.required,
)
for i, param in enumerate(params)
]
function = Function(
params=ParamSpec(
pos=map_params(pos, 0),
mixed=map_params(mixed, len(pos)),
kw=map_params(kw, len(pos) + len(mixed)),
),
returns=returns,
)
if len(type_vars) != 0:
function = GenericType(
name=name,
params=type_vars,
body=function,
)
return function
def _def_function(
self,
*,
name: str,
pos: list[Param] = [],
mixed: list[Param] = [],
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
py_function: Optional[Callable[..., Any]] = None,
):
function: Type = self._make_function(
name=name,
pos=pos,
mixed=mixed,
kw=kw,
returns=returns,
type_vars=type_vars,
)
self.define(name, function)
if py_function is not None:
self._python_funcs[name] = py_function
def _def_overloads(
self,
*,
name: str,
signatures: list[
tuple[list[Param], list[Param], list[Param], Type, list[TypeVar]]
],
py_function: Optional[Callable[..., Any]] = None,
):
overloads: list[Type] = []
for pos, mixed, kw, returns, type_vars in signatures:
overloads.append(
self._make_function(
name=name,
pos=pos,
mixed=mixed,
kw=kw,
returns=returns,
type_vars=type_vars,
)
)
function: Type = OverloadedFunction(overloads=overloads)
self.define(name, function)
if py_function is not None:
self._python_funcs[name] = py_function
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
return self._python_funcs.get(name)
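`_make_function` numbers parameters consecutively across the three groups: positional-only first, then mixed, then keyword-only, each group offset by the size of the groups before it. The sketch below isolates that numbering, assuming simplified `Param` and `Numbered` dataclasses instead of the checker's `Function.Parameter`; `number_params` is a hypothetical helper.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Param:
    name: str
    required: bool = True


@dataclass(frozen=True)
class Numbered:
    pos: int
    name: str
    required: bool


def number_params(
    pos: list[Param], mixed: list[Param], kw: list[Param]
) -> list[Numbered]:
    def numbered(params: list[Param], offset: int) -> list[Numbered]:
        # Positions continue where the previous group stopped
        return [
            Numbered(pos=i + offset, name=p.name, required=p.required)
            for i, p in enumerate(params)
        ]

    return (
        numbered(pos, 0)
        + numbered(mixed, len(pos))
        + numbered(kw, len(pos) + len(mixed))
    )


numbered = number_params(
    pos=[Param("transform")],
    mixed=[Param("iterable")],
    kw=[Param("strict", required=False)],
)
positions = [(p.pos, p.name) for p in numbered]
```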

1153
midas/checker/python.py Normal file

File diff suppressed because it is too large

488
midas/checker/registry.py Normal file

@@ -0,0 +1,488 @@
import logging
from dataclasses import dataclass
from typing import Optional
from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import (
AppliedType,
BaseType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
Predicate,
TopType,
TupleType,
Type,
TypeVar,
UnknownType,
Variance,
substitute_typevars,
)
@dataclass
class Member:
kind: MemberKind
type: Type
class TypesRegistry:
def __init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Member]] = {}
self._predicates: dict[str, Predicate] = {}
def get_type(self, name: str) -> Type:
"""Get a type from its name
Args:
name (str): the name of the type
Raises:
NameError: if the type is not defined
Returns:
Type: the type
"""
if name in self._types:
return self._types[name]
raise NameError(f"Undefined type {name}")
def define_type(self, name: str, type: Type) -> Type:
"""Define a type in the registry
Args:
name (str): the name of the type
type (Type): the type to define
Raises:
ValueError: if a type is already defined with that name
Returns:
Type: the defined type
"""
if name in self._types:
raise ValueError(f"Type {name} already defined")
self._types[name] = type
return type
def define_member(
self,
type_name: str,
member_name: str,
member_type: Type,
kind: MemberKind,
):
members: dict[str, Member] = self._members.setdefault(type_name, {})
if member_name in members:
current: Member = members[member_name]
if current.kind != kind:
self.logger.error(
f"Member '{member_name}' is already defined as a {current.kind},"
+ f" cannot define a {kind} with the same name"
)
return
if kind != MemberKind.METHOD:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name},"
+ " only methods can be overloaded"
)
return
combined: Type
match current.type:
case OverloadedFunction(overloads=overloads):
combined = OverloadedFunction(overloads=overloads + [member_type])
case _:
combined = OverloadedFunction(overloads=[current.type, member_type])
members[member_name] = Member(kind=current.kind, type=combined)
else:
members[member_name] = Member(kind=kind, type=member_type)
def define_predicate(self, name: str, predicate: Predicate):
if name in self._predicates:
raise ValueError(f"Predicate {name} already defined")
self._predicates[name] = predicate
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
subtypes: set[str] = BUILTIN_SUBTYPES.get(name2, set())
if name1 in subtypes:
return True
for subtype in subtypes:
if self.is_builtin_subtype(name1, subtype):
return True
return False
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
For more details on the rules checked here, see TAPL, chapters 15 to 17
Args:
type1 (Type): the potential subtype
type2 (Type): the potential supertype
Returns:
bool: whether `type1` is a subtype of `type2`
"""
if type1 == type2:
return True
match (type1, type2):
case (_, TopType()):
return True
case (_, UnknownType()):
return True
case (TypeVar(bound=bound), _):
if bound is None:
return False
return self.is_subtype(bound, type2)
case (_, TypeVar(bound=bound)):
if bound is None:
return True
return self.is_subtype(type1, bound)
case (DerivedType(type=base1), _):
return self.is_subtype(base1, type2)
case (BaseType(name=name1), BaseType(name=name2)):
return self.is_builtin_subtype(name1, name2)
case (ComplexType(properties=props1), ComplexType(properties=props2)):
for k, t in props2.items():
if k not in props1:
return False
if not self.is_subtype(props1[k], t):
return False
return True
case (DataFrameType(columns=columns1), DataFrameType(columns=columns2)):
# TODO: check order?
by_name1: dict[str, DataFrameType.Column] = {
col.name: col for col in columns1 if col.name is not None
}
for col2 in columns2:
if col2.name not in by_name1:
return False
if not self.is_subtype(by_name1[col2.name].type, col2.type):
return False
return True
case (ColumnType(type=inner1), ColumnType(type=inner2)):
# TODO: invariant, replace ColumnType with simple GenericType
if not self.are_equivalent(inner1, inner2):
return False
return True
case (Function(), Function()):
return self.is_func_subtype(type1, type2)
case (ConstraintType(type=base1), _):
return self.is_subtype(base1, type2)
case (
AppliedType(name=name1, args=args1),
AppliedType(name=name2, args=args2),
) if (
name1 == name2
):
generic: Type = self.get_type(name1)
assert isinstance(generic, GenericType)
for param, arg1, arg2 in zip(generic.params, args1, args2):
variance: Variance = param.variance
if variance in {Variance.INVARIANT, Variance.COVARIANT}:
if not self.is_subtype(arg1, arg2):
return False
if variance in {Variance.INVARIANT, Variance.CONTRAVARIANT}:
if not self.is_subtype(arg2, arg1):
return False
return True
# TODO: verify legitimacy
case (AppliedType(body=body), _):
return self.is_subtype(body, type2)
return False
def are_equivalent(self, type1: Type, type2: Type) -> bool:
return self.is_subtype(type1, type2) and self.is_subtype(type2, type1)
# TODO: verify the logic in here
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
"""Check whether a function is a subtype of another
Args:
func1 (Function): the potential function subtype
func2 (Function): the potential function supertype
Returns:
bool: whether `func1` is a subtype of `func2`
"""
if not self.is_subtype(func1.returns, func2.returns):
return False
pos1: list[Function.Parameter] = func1.params.pos
mixed1: list[Function.Parameter] = func1.params.mixed
kw1: dict[str, Function.Parameter] = {
param.name: param for param in func1.params.kw
}
pos2: list[Function.Parameter] = func2.params.pos
mixed2: list[Function.Parameter] = func2.params.mixed
kw2: dict[str, Function.Parameter] = {
param.name: param for param in func2.params.kw
}
mixed_by_pos: dict[int, Function.Parameter] = {
param.pos: param for param in mixed2
}
mixed_by_name: dict[str, Function.Parameter] = {
param.name: param for param in mixed2
}
def is_arg_subtype(sub: Function.Parameter, sup: Function.Parameter) -> bool:
if not self.is_subtype(sub.type, sup.type):
return False
if not sup.required and sub.required:
return False
return True
for param1 in pos1:
param2: Function.Parameter
if param1.pos < len(pos2):
param2 = pos2[param1.pos]
elif param1.pos in mixed_by_pos:
param2 = mixed_by_pos[param1.pos]
elif not param1.required:
continue
else:
return False
if not is_arg_subtype(param2, param1):
return False
for name, param1 in kw1.items():
param2: Function.Parameter
if name in kw2:
param2 = kw2[name]
elif name in mixed_by_name:
param2 = mixed_by_name[name]
elif not param1.required:
continue
else:
return False
if not is_arg_subtype(param2, param1):
return False
for param1 in mixed1:
pos_param2: Optional[Function.Parameter] = None
kw_param2: Optional[Function.Parameter] = None
if param1.name in kw2:
kw_param2 = kw2[param1.name]
elif param1.name in mixed_by_name:
kw_param2 = mixed_by_name[param1.name]
if param1.pos < len(pos2):
pos_param2 = pos2[param1.pos]
elif param1.pos in mixed_by_pos:
pos_param2 = mixed_by_pos[param1.pos]
# No match in func2 and arg is required
if pos_param2 is None and kw_param2 is None and param1.required:
return False
# Matching keyword argument
if kw_param2 is not None and not is_arg_subtype(kw_param2, param1):
return False
# Matching positional argument
if pos_param2 is not None and not is_arg_subtype(pos_param2, param1):
return False
mixed_positions: set[int] = {param.pos for param in mixed1}
mixed_names: set[str] = {param.name for param in mixed1}
for param2 in pos2:
if not param2.required:
continue
if param2.pos >= len(pos1) and param2.pos not in mixed_positions:
return False
for name, param2 in kw2.items():
if not param2.required:
continue
if name not in kw1 and name not in mixed_names:
return False
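# Required mixed parameters of func2 must be accepted by func1,
# both positionally and by name
for param2 in mixed2:
if not param2.required:
continue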
pos_match: bool = param2.pos < len(pos1) or param2.pos in mixed_positions
kw_match: bool = param2.name in kw1 or param2.name in mixed_names
if not pos_match or not kw_match:
return False
return True
def apply_generic(self, type: Type, args: list[Type]) -> Type:
match type:
case DerivedType(name=name, type=base):
return DerivedType(name=name, type=self.apply_generic(base, args))
case GenericType(name=name, params=type_vars, body=body):
n_args: int = len(args)
n_type_vars: int = len(type_vars)
if n_args < n_type_vars:
raise ValueError(
f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
)
if n_args > n_type_vars:
raise ValueError(
f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
)
substitutions: dict[str, Type] = {}
for arg, type_var in zip(args, type_vars):
if type_var.bound is not None and not self.is_subtype(
arg, type_var.bound
):
raise ValueError(
f"Type argument {arg} is not a subtype of {type_var.bound}"
)
substitutions[type_var.name] = arg
return AppliedType(
name=name,
args=args,
body=substitute_typevars(body, substitutions),
)
case BaseType(name="tuple"):
return TupleType(items=tuple(args))
case _:
raise ValueError(f"{type} is not a generic type")
def reduce_types(self, types: list[Type]) -> list[Type]:
"""Reduce a list of types to remove subtypes and only keep the highest types
Args:
types (list[Type]): the types to reduce
Returns:
list[Type]: the reduced list of types
"""
reduced: bool = True
keep: list[int] = list(range(len(types)))
while reduced:
reduced = False
for i, i1 in enumerate(keep):
type1: Type = types[i1]
for i2 in keep[i + 1 :]:
type2 = types[i2]
if self.is_subtype(type1, type2):
keep.remove(i1)
elif self.is_subtype(type2, type1):
keep.remove(i2)
else:
continue
reduced = True
break
return [types[i] for i in keep]
def lookup_member(self, type: Type, member_name: str) -> Optional[Type]:
match type:
case BaseType(name=name):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name].type
return None
case DerivedType(name=name, type=base):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name].type
return self.lookup_member(base, member_name)
case AppliedType(name=name, body=body, args=args):
generic: Type = self.get_type(name)
if not isinstance(generic, GenericType):
raise ValueError("AppliedType not derived from a GenericType")
substitutions = {
type_var.name: arg for arg, type_var in zip(args, generic.params)
}
if name in self._members:
if member_name in self._members[name]:
member_type: Type = self._members[name][member_name].type
return substitute_typevars(member_type, substitutions)
member_type2: Optional[Type] = self.lookup_member(body, member_name)
if member_type2 is not None:
member_type2 = substitute_typevars(member_type2, substitutions)
return member_type2
case ComplexType(members=members):
if member_name in members:
return members[member_name]
self.logger.debug(f"No member '{member_name}' in {type}")
return None
case ExtensionType(base=base, extension=ComplexType(members=members)):
if member_name in members:
return members[member_name]
self.logger.debug(
f"No member '{member_name}' on {type}, looking up in base"
)
return self.lookup_member(base, member_name)
case ConstraintType(type=base):
return self.lookup_member(base, member_name)
case TypeVar(bound=bound) if bound is not None:
return self.lookup_member(bound, member_name)
case UnknownType():
return UnknownType()
case _:
self.logger.debug(f"Can't get member on {type}")
return None
def lookup_predicate(self, name: str) -> Optional[Predicate]:
return self._predicates.get(name)
def _by_name_or_type(self, name_or_type: str | Type) -> Type:
if isinstance(name_or_type, str):
return self.get_type(name_or_type)
return name_or_type
def list_of(self, item_type: str | Type) -> Type:
list_ = self.get_type("list")
return self.apply_generic(list_, [self._by_name_or_type(item_type)])
def tuple_of(self, *item_types: str | Type) -> Type:
tuple_ = self.get_type("tuple")
return self.apply_generic(
tuple_,
[self._by_name_or_type(item_type) for item_type in item_types],
)
def dict_of(self, key_type: str | Type, value_type: str | Type) -> Type:
dict_ = self.get_type("dict")
return self.apply_generic(
dict_,
[
self._by_name_or_type(key_type),
self._by_name_or_type(value_type),
],
)
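
The idea behind `reduce_types` (drop every type that is already covered by another type in the list) can be illustrated on a toy hierarchy. This is only a sketch under stated assumptions: the `PARENTS` table, the string-based types and the single-pass strategy are hypothetical and are not how the registry above stores or walks types.

```python
from typing import Optional

# Hypothetical toy hierarchy mapping a type name to its parent (not the Midas registry)
PARENTS: dict[str, str] = {
    "PositiveInt": "int",
    "int": "number",
    "float": "number",
}


def is_subtype(sub: str, sup: str) -> bool:
    """Walk the parent chain of `sub` until `sup` is found or the chain ends."""
    current: Optional[str] = sub
    while current is not None:
        if current == sup:
            return True
        current = PARENTS.get(current)
    return False


def reduce_types(types: list[str]) -> list[str]:
    """Keep only the types that are not subtypes of another kept type."""
    keep: list[str] = []
    for candidate in types:
        # Skip the candidate if a type we already keep covers it
        if any(is_subtype(candidate, kept) for kept in keep):
            continue
        # Drop kept types that the candidate covers, then keep the candidate
        keep = [kept for kept in keep if not is_subtype(kept, candidate)]
        keep.append(candidate)
    return keep


reduced = reduce_types(["PositiveInt", "int", "str"])
```

Here `PositiveInt` is dropped because `int` covers it, while `str` is unrelated to both and is kept.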

midas/checker/reporter.py

@@ -0,0 +1,70 @@
from __future__ import annotations
from typing import Optional
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
class Reporter:
def __init__(self):
self.diagnostics: list[Diagnostic] = []
def report(
self,
path: Optional[str],
type: DiagnosticType,
location: Location,
message: str,
):
self.diagnostics.append(
Diagnostic(
file_path=path,
location=location,
type=type,
message=message,
)
)
def for_file(self, path: Optional[str]) -> FileReporter:
return FileReporter(self, path)
class FileReporter:
def __init__(self, base_reporter: Reporter, path: Optional[str]) -> None:
self.base_reporter: Reporter = base_reporter
self.path: Optional[str] = path
def for_file(self, path: Optional[str]) -> FileReporter:
return FileReporter(self.base_reporter, path)
def report(self, type: DiagnosticType, location: Location, message: str):
self.base_reporter.report(self.path, type, location, message)
def error(self, location: Location, message: str):
self.report(
type=DiagnosticType.ERROR,
location=location,
message=message,
)
def warning(self, location: Location, message: str):
self.report(
type=DiagnosticType.WARNING,
location=location,
message=message,
)
def info(self, location: Location, message: str):
self.report(
type=DiagnosticType.INFO,
location=location,
message=message,
)
def debug(self, location: Location, message: str):
self.report(
type=DiagnosticType.DEBUG,
location=location,
message=message,
)
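
The split between `Reporter` and `FileReporter` lets several files share one diagnostics list while each call site stays free of path bookkeeping. The sketch below reproduces that pattern in a self-contained form; its `Diagnostic` is a simplified stand-in (a plain `severity` string, no `Location`), so it is an assumption for illustration and not the project's actual class.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Diagnostic:
    # Simplified stand-in: the real Diagnostic also carries a Location
    # and a DiagnosticType instead of a plain string
    file_path: Optional[str]
    severity: str
    message: str


class Reporter:
    """Collects the diagnostics of every file in a single list."""

    def __init__(self) -> None:
        self.diagnostics: list[Diagnostic] = []

    def report(self, path: Optional[str], severity: str, message: str) -> None:
        self.diagnostics.append(Diagnostic(path, severity, message))

    def for_file(self, path: Optional[str]) -> "FileReporter":
        return FileReporter(self, path)


class FileReporter:
    """Thin view over a Reporter that fills in the file path."""

    def __init__(self, base: Reporter, path: Optional[str]) -> None:
        self.base = base
        self.path = path

    def error(self, message: str) -> None:
        self.base.report(self.path, "error", message)

    def warning(self, message: str) -> None:
        self.base.report(self.path, "warning", message)


reporter = Reporter()
main_file = reporter.for_file("main.py")
types_file = reporter.for_file("types.midas")
main_file.error("Incompatible types")
types_file.warning("Unused alias")
```

Both file reporters append to the same `reporter.diagnostics` list, in the order the diagnostics were raised.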

midas/checker/resolver.py

@@ -0,0 +1,244 @@
import midas.ast.python as p
class ResolverError(Exception): ...
class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
"""A variable assignment and reference resolver
This class keeps track of which scope a variable is defined in and which
scope is referred to when a variable is referenced
"""
def __init__(self):
self.locals: dict[p.Expr, int] = {}
self.scopes: list[dict[str, bool]] = [{}]
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
"""Resolve the given statements or expressions"""
for obj in objects:
obj.accept(self)
def begin_scope(self):
"""Begin a new scope inside the current one"""
self.scopes.append({})
def end_scope(self):
"""Close the current scope"""
self.scopes.pop()
def declare(self, name: str) -> None:
"""Declare a variable in the current scope
This method must be called *before* evaluating the variable initializer
Args:
name (str): the name of the variable
Raises:
ResolverError: if the variable has already been declared in the current scope
"""
if len(self.scopes) == 0:
return
scope: dict[str, bool] = self.scopes[-1]
if name in scope:
raise ResolverError(
f"A variable with the name {name} is already declared in this scope"
)
scope[name] = False
def define(self, name: str) -> None:
"""Define a variable in the current scope
This method must be called *after* evaluating the variable initializer
Args:
name (str): the name of the variable
"""
if len(self.scopes) == 0:
return
self.scopes[-1][name] = True
def resolve_local(self, expr: p.Expr, name: str) -> None:
"""Resolve a variable reference and store the scope distance
This method associates to the variable expression a number representing
the "distance" of the variable declaration, i.e. the number of scope
levels to go "up" to find the closest declaration for that variable.
Args:
expr (p.Expr): the variable expression
name (str): the name of the variable
"""
for i, scope in enumerate(reversed(self.scopes)):
if name in scope:
self.locals[expr] = i
return
def is_defined(self, name: str) -> bool:
for scope in self.scopes:
if name in scope:
return True
return False
def resolve_function(self, function: p.Function) -> None:
"""Resolve a function definition
This method creates a new scope for the function, resolves all the
parameter declarations and then the body.
Args:
function (p.Function): the function to resolve
"""
self.begin_scope()
for param in function.params.all:
self.declare(param.name)
self.define(param.name)
self.resolve(*function.body)
self.end_scope()
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
stmt.expr.accept(self)
def visit_function(self, stmt: p.Function) -> None:
# Declare before resolving body to allow recursion
self.declare(stmt.name)
self.define(stmt.name)
self.resolve_function(stmt)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
self.declare(stmt.name)
# NOTE: resolve type here?
self.define(stmt.name)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
self.resolve(stmt.value)
for target in stmt.targets:
self._visit_assign(target)
def _visit_assign(self, target: p.Expr):
match target:
case p.VariableExpr(name=name):
if not self.is_defined(name):
self.declare(name)
self.define(name)
target.accept(self)
case p.GetExpr():
target.accept(self)
case p.SubscriptExpr():
target.accept(self)
case _:
raise Exception(f"Unsupported assignment to {target}")
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
if stmt.value is not None:
self.resolve(stmt.value)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
# Not resolved in sub-environment because assignments in the test leak out of the if
# For example:
# if (m := 1 + 1) < 2:
# ...
# print(m) # <- m is still defined
self.resolve(stmt.test)
# Body
self.begin_scope()
self.resolve(*stmt.body)
self.end_scope()
# Else
self.begin_scope()
self.resolve(*stmt.orelse)
self.end_scope()
def visit_pass(self, stmt: p.Pass) -> None:
pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self.resolve(stmt.iterator)
self._visit_assign(stmt.target)
self.begin_scope()
self.resolve(*stmt.body)
self.end_scope()
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
pass
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
self.resolve(expr.right)
def visit_call_expr(self, expr: p.CallExpr) -> None:
self.resolve(expr.callee)
for arg in expr.arguments:
self.resolve(arg)
for arg in expr.keywords.values():
self.resolve(arg)
def visit_get_expr(self, expr: p.GetExpr) -> None:
self.resolve(expr.object)
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
pass
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
if len(self.scopes) != 0 and self.scopes[-1].get(expr.name) is False:
raise ResolverError(
f"Cannot use local variable '{expr.name}' in its own initializer"
) # aka. UnboundLocalError
self.resolve_local(expr, expr.name)
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)
def visit_cast_expr(self, expr: p.CastExpr) -> None:
self.resolve(expr.expr)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self.resolve(expr.test)
self.resolve(expr.if_true)
self.resolve(expr.if_false)
def visit_list_expr(self, expr: p.ListExpr) -> None:
for item in expr.items:
self.resolve(item)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
for key in expr.keys:
if key is not None:
self.resolve(key)
for value in expr.values:
self.resolve(value)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
self.resolve(expr.object)
self.resolve(expr.index)
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
if expr.lower is not None:
self.resolve(expr.lower)
if expr.upper is not None:
self.resolve(expr.upper)
if expr.step is not None:
self.resolve(expr.step)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
for item in expr.items:
self.resolve(item)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
pass
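
The "distance" stored by `resolve_local` is the number of scopes to walk up from the innermost one before the name is found. A minimal sketch of that bookkeeping follows; `ScopeStack` and its `distance` method are hypothetical names used only for this illustration and work without the AST visitor.

```python
from typing import Optional


class ScopeStack:
    """Minimal stand-in for the resolver's scope bookkeeping."""

    def __init__(self) -> None:
        # One dict per scope: name -> whether the initializer was evaluated
        self.scopes: list[dict[str, bool]] = [{}]

    def begin_scope(self) -> None:
        self.scopes.append({})

    def end_scope(self) -> None:
        self.scopes.pop()

    def declare(self, name: str) -> None:
        self.scopes[-1][name] = False

    def define(self, name: str) -> None:
        self.scopes[-1][name] = True

    def distance(self, name: str) -> Optional[int]:
        """Number of scopes to go up to reach the closest declaration."""
        for i, scope in enumerate(reversed(self.scopes)):
            if name in scope:
                return i
        return None


stack = ScopeStack()
stack.declare("x")
stack.define("x")
stack.begin_scope()  # e.g. a function body
stack.declare("y")
stack.define("y")
stack.begin_scope()  # e.g. the body of an if statement
distances = {name: stack.distance(name) for name in ("x", "y", "z")}
```

From the innermost scope, `y` is one level up, `x` two levels up, and `z` is not declared anywhere, so no distance is recorded for it.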

midas/checker/types.py

@@ -0,0 +1,457 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Optional, assert_never, cast
import midas.ast.midas as m
from midas.ast.printer import MidasPrinter
@dataclass(frozen=True, kw_only=True)
class TopType:
def __str__(self) -> str:
return "Any"
@dataclass(frozen=True, kw_only=True)
class BaseType:
name: str
def __str__(self) -> str:
return self.name
@dataclass(frozen=True, kw_only=True)
class DerivedType:
name: str
type: Type
def __str__(self) -> str:
return self.name
@dataclass(frozen=True, kw_only=True)
class UnknownType:
def __str__(self) -> str:
return "<Unknown>"
@dataclass(frozen=True, kw_only=True)
class UnitType:
def __str__(self) -> str:
return "None"
@dataclass(frozen=True, kw_only=True)
class Function:
params: ParamSpec
returns: Type
def __str__(self) -> str:
return f"{self.params} -> {self.returns}"
@dataclass(frozen=True, kw_only=True)
class Parameter:
pos: int
name: str
type: Type
required: bool
def __str__(self) -> str:
opt: str = "" if self.required else "?"
return f"{self.name}: {self.type}{opt}"
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
pos: list[Function.Parameter] = field(default_factory=list)
mixed: list[Function.Parameter] = field(default_factory=list)
kw: list[Function.Parameter] = field(default_factory=list)
def __str__(self) -> str:
params: list[str] = []
if len(self.pos) != 0:
params += list(map(str, self.pos))
params.append("/")
if len(self.mixed) != 0:
params += list(map(str, self.mixed))
if len(self.kw) != 0:
params.append("*")
params += list(map(str, self.kw))
return f"({', '.join(params)})"
@dataclass(frozen=True, kw_only=True)
class OverloadedFunction:
overloads: list[Type]
def __str__(self) -> str:
return "<overloaded function>"
@dataclass(frozen=True, kw_only=True)
class ComplexType:
members: dict[str, Type]
def __str__(self) -> str:
props: list[str] = [f"{name}: {type}" for name, type in self.members.items()]
return f"{{{', '.join(props)}}}"
@dataclass(frozen=True, kw_only=True)
class ExtensionType:
base: Type
extension: ComplexType
def __str__(self) -> str:
return f"{self.base} & {self.extension}"
class Variance(StrEnum):
INVARIANT = "INVARIANT"
COVARIANT = "COVARIANT"
CONTRAVARIANT = "CONTRAVARIANT"
@dataclass(frozen=True, kw_only=True)
class TypeVar:
name: str
bound: Optional[Type]
variance: Variance = Variance.INVARIANT
def __str__(self) -> str:
variance: str = {
Variance.COVARIANT: "+",
Variance.CONTRAVARIANT: "-",
}.get(self.variance, "")
res: str = f"{variance}{self.name}"
if self.bound is not None:
res = f"{res} <: {self.bound}"
return res
@dataclass(frozen=True, kw_only=True)
class GenericType:
name: str
params: list[TypeVar]
body: Type
def __str__(self) -> str:
return f"{self.name}[{', '.join(map(str, self.params))}]"
@dataclass(frozen=True, kw_only=True)
class AppliedType:
name: str
args: list[Type]
body: Type
def __str__(self) -> str:
return f"{self.name}[{', '.join(map(str, self.args))}]"
@dataclass(frozen=True, kw_only=True)
class ConstraintType:
type: Type
constraint: m.Expr
def __str__(self) -> str:
printer = MidasPrinter()
return f"{self.type} where {printer.print(self.constraint)}"
@dataclass(frozen=True, kw_only=True)
class TupleType:
items: tuple[Type, ...]
def __str__(self) -> str:
return f"({', '.join(map(str, self.items))})"
@dataclass(frozen=True, kw_only=True)
class ColumnType:
type: Type
def __str__(self) -> str:
return f"Column[{self.type}]"
@dataclass(frozen=True, kw_only=True)
class DataFrameType:
columns: list[Column]
def __str__(self) -> str:
schema: list[str] = [f"{col.name}: {col.type}" for col in self.columns]
return f"Frame[{', '.join(schema)}]"
@dataclass(frozen=True, kw_only=True)
class Column:
index: int
name: Optional[str]
type: ColumnType
@dataclass(frozen=True, kw_only=True)
class FrameGroupBy:
frame: DataFrameType
def __str__(self) -> str:
return f"FrameGroupBy[{self.frame}]"
@dataclass(frozen=True, kw_only=True)
class ColumnGroupBy:
column: ColumnType
def __str__(self) -> str:
return f"ColumnGroupBy[{self.column}]"
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_parameter(param: Function.Parameter):
return Function.Parameter(
pos=param.pos,
name=param.name,
type=substitute_typevars(param.type, substitutions),
required=param.required,
)
def sub_param_spec(spec: ParamSpec):
return ParamSpec(
pos=list(map(sub_parameter, spec.pos)),
mixed=list(map(sub_parameter, spec.mixed)),
kw=list(map(sub_parameter, spec.kw)),
)
def sub_column(col: DataFrameType.Column):
return DataFrameType.Column(
index=col.index,
name=col.name,
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
)
match type:
case TopType():
return type
case BaseType(name=name) if name in substitutions:
return substitutions[name]
case BaseType():
return type
case DerivedType(name=name, type=type2):
return DerivedType(
name=name, type=substitute_typevars(type2, substitutions)
)
case Function(
params=params,
returns=returns,
):
return Function(
params=sub_param_spec(params),
returns=substitute_typevars(returns, substitutions),
)
case OverloadedFunction(overloads=overloads):
return OverloadedFunction(
overloads=[
substitute_typevars(overload, substitutions)
for overload in overloads
]
)
case ComplexType(members=members):
members2: dict[str, Type] = {
name: substitute_typevars(prop, substitutions)
for name, prop in members.items()
}
return ComplexType(members=members2)
case ExtensionType(base=base, extension=ComplexType(members=members)):
return ExtensionType(
base=substitute_typevars(base, substitutions),
extension=ComplexType(
members={
name: substitute_typevars(prop, substitutions)
for name, prop in members.items()
}
),
)
case AppliedType(name=name, args=args, body=body):
return AppliedType(
name=name,
args=[substitute_typevars(arg, substitutions) for arg in args],
body=substitute_typevars(body, substitutions),
)
case ConstraintType():
return ConstraintType(
type=substitute_typevars(type.type, substitutions),
constraint=type.constraint,
)
case TypeVar(name=name):
if name in substitutions:
return substitutions[name]
raise ValueError(f"Missing TypeVar substitution for {name}")
case GenericType(name=name, params=params, body=body):
params2: list[TypeVar] = []
for param in params:
param2: Type = substitute_typevars(param, substitutions)
if not isinstance(param2, TypeVar):
raise ValueError(
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
)
params2.append(param2)
return GenericType(
name=name,
params=params2,
body=substitute_typevars(body, substitutions),
)
case TupleType(items=items):
return TupleType(
items=tuple(substitute_typevars(item, substitutions) for item in items),
)
case ColumnType(type=items_type):
return ColumnType(
type=substitute_typevars(items_type, substitutions),
)
case DataFrameType(columns=columns):
return DataFrameType(
columns=list(map(sub_column, columns)),
)
case FrameGroupBy(frame=frame):
return FrameGroupBy(
frame=cast(DataFrameType, substitute_typevars(frame, substitutions))
)
case ColumnGroupBy(column=column):
return ColumnGroupBy(
column=cast(ColumnType, substitute_typevars(column, substitutions))
)
case UnknownType() | UnitType():
return type
# Ensure exhaustiveness
case _:
assert_never(type)
def unfold_type(type: Type) -> Type:
match type:
case DerivedType(type=ref_type):
return unfold_type(ref_type)
case _:
return type
def to_annotation(type: Type) -> str:
def _params_annotation(spec: ParamSpec) -> str:
if len(spec.kw) != 0:
return "..."
params: str = ", ".join(
to_annotation(param.type) for param in spec.pos + spec.mixed
)
return f"[{params}]"
match type:
case TopType():
return "Any"
case BaseType(name=name):
return name
case DerivedType(name=name):
return name
case UnknownType():
return "Any"
case UnitType():
return "None"
case Function(params=params, returns=returns):
params_annot: str = _params_annotation(params)
return f"Callable[{params_annot}, {to_annotation(returns)}]"
case OverloadedFunction():
return "Callable"
case ComplexType() | ExtensionType():
raise NotImplementedError
case TypeVar(name=name):
return name
case GenericType(name=name, params=params):
return f"{name}[{', '.join(map(to_annotation, params))}]"
case AppliedType(name=name, args=args):
return f"{name}[{', '.join(map(to_annotation, args))}]"
case ConstraintType():
return str(type)
case TupleType(items=items):
return f"Tuple[{', '.join(map(to_annotation, items))}]"
case ColumnType():
return "pd.Series"
case DataFrameType():
return "pd.DataFrame"
case FrameGroupBy():
return "pd.api.typing.DataFrameGroupBy"
case ColumnGroupBy():
return "pd.api.typing.SeriesGroupBy"
case _:
assert_never(type)
@dataclass(frozen=True, kw_only=True)
class Predicate:
type: Type
body: m.Expr
alias: bool
Type = (
TopType
| BaseType
| DerivedType
| UnknownType
| UnitType
| Function
| OverloadedFunction
| ComplexType
| ExtensionType
| TypeVar
| GenericType
| AppliedType
| ConstraintType
| TupleType
| ColumnType
| DataFrameType
| FrameGroupBy
| ColumnGroupBy
)

midas/checker/unifier.py

@@ -0,0 +1,201 @@
import logging
from typing import Optional
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AppliedType,
ColumnType,
DataFrameType,
Function,
GenericType,
ParamSpec,
TopType,
Type,
TypeVar,
)
class UnificationError(Exception): ...
class Unifier:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.logger: logging.Logger = logging.getLogger("Unifier")
def unify_call(
self,
type: GenericType,
positional: list[Type],
keywords: dict[str, Type],
) -> Optional[Type]:
concrete_func: Function = Function(
params=ParamSpec(
pos=[
Function.Parameter(
pos=i,
name=str(i),
type=arg,
required=True,
)
for i, arg in enumerate(positional)
],
kw=[
Function.Parameter(
pos=len(positional) + i,
name=name,
type=arg,
required=True,
)
for i, (name, arg) in enumerate(keywords.items())
],
),
returns=TopType(), # TODO: use expected type
)
return self.unify_generic(type, concrete_func, match_return=False)
def unify_generic(
self,
template: GenericType,
concrete: Type,
match_return: bool = True,
) -> Optional[Type]:
substitutions: dict[str, Type]
try:
substitutions = self.match(template.body, concrete, match_return)
except UnificationError:
return None
args: list[Type] = []
for param in template.params:
if param.name not in substitutions:
return None
args.append(substitutions[param.name])
applied: Type = self.types.apply_generic(template, args)
return applied
def match(
self,
template: Type,
concrete: Type,
match_return: bool = True,
) -> dict[str, Type]:
# TODO: if concrete is Generic, record bound TypeVar. Then when merging
# substitutions, check that the constraint is respected
match (template, concrete):
case (TypeVar(name=name), _):
return {name: concrete}
case (
AppliedType(name=template_name, args=template_args),
AppliedType(name=concrete_name, args=concrete_args),
) if template_name == concrete_name and len(template_args) == len(
concrete_args
):
substitutions: dict[str, Type] = {}
for template_arg, concrete_arg in zip(template_args, concrete_args):
new_substitutions: dict[str, Type] = self.match(
template_arg, concrete_arg
)
substitutions = self.merge(substitutions, new_substitutions)
return substitutions
case (
DataFrameType(columns=template_columns),
DataFrameType(columns=concrete_columns),
) if len(template_columns) == len(concrete_columns):
substitutions: dict[str, Type] = {}
for template_column, concrete_column in zip(
template_columns, concrete_columns
):
if template_column.index != concrete_column.index or (
template_column.name != concrete_column.name
):
self.logger.debug(
f"Column mismatch: template={template_column}, concrete={concrete_column}"
)
raise UnificationError
new_substitutions: dict[str, Type] = self.match(
template_column.type, concrete_column.type
)
substitutions = self.merge(substitutions, new_substitutions)
return substitutions
case (ColumnType(type=template_column), ColumnType(type=concrete_column)):
return self.match(template_column, concrete_column)
case (Function(), Function()):
mapped: list[tuple[Function.Parameter, Function.Parameter]] = (
self.map_params(template, concrete)
)
substitutions: dict[str, Type] = {}
for template_arg, concrete_arg in mapped:
arg_subs: dict[str, Type] = self.match(
template_arg.type, concrete_arg.type
)
substitutions = self.merge(substitutions, arg_subs)
if match_return:
return_subs: dict[str, Type] = self.match(
template.returns, concrete.returns
)
substitutions = self.merge(substitutions, return_subs)
return substitutions
case _:
self.logger.debug(f"Can't match {concrete!r} with {template!r}")
return {}
def merge(self, subs1: dict[str, Type], subs2: dict[str, Type]) -> dict[str, Type]:
merged: dict[str, Type] = subs1.copy()
for k, v in subs2.items():
if k in merged and merged[k] != v:
self.logger.debug(
f"Substitution already defined for {k} with type {merged[k]}, got {v}"
)
raise UnificationError
merged[k] = v
return merged
def map_params(
self, func1: Function, func2: Function
) -> list[tuple[Function.Parameter, Function.Parameter]]:
pos1: list[Function.Parameter] = func1.params.pos
mixed1: list[Function.Parameter] = func1.params.mixed
kw1: list[Function.Parameter] = func1.params.kw
pos2: list[Function.Parameter] = func2.params.pos
mixed2: list[Function.Parameter] = func2.params.mixed
kw2: list[Function.Parameter] = func2.params.kw
mapped: list[tuple[Function.Parameter, Function.Parameter]] = []
by_pos2: dict[int, Function.Parameter] = {
param.pos: param for param in pos2 + mixed2
}
by_name2: dict[str, Function.Parameter] = {
param.name: param for param in mixed2 + kw2
}
for arg1 in pos1:
if (arg2 := by_pos2.get(arg1.pos)) is not None:
mapped.append((arg1, arg2))
for arg1 in mixed1:
# Match both positionally and by name, conflicts are caught
# when merging substitutions
if (arg2 := by_pos2.get(arg1.pos)) is not None:
mapped.append((arg1, arg2))
if (arg2 := by_name2.get(arg1.name)) is not None:
mapped.append((arg1, arg2))
for arg1 in kw1:
if (arg2 := by_name2.get(arg1.name)) is not None:
mapped.append((arg1, arg2))
return mapped
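
Matching a template against a concrete type yields one substitution per type variable, and merging rejects two different bindings for the same variable. The following sketch shows both steps on a reduced model where base types are plain strings; `Var`, `Applied`, `match_types` and this `merge` are hypothetical simplifications, not the `Unifier` methods themselves.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Applied:
    name: str
    args: tuple


class UnificationError(Exception): ...


def merge(subs1: dict, subs2: dict) -> dict:
    """Combine two substitution maps, refusing conflicting bindings."""
    merged = dict(subs1)
    for name, type_ in subs2.items():
        if name in merged and merged[name] != type_:
            raise UnificationError(f"{name} bound to {merged[name]} and {type_}")
        merged[name] = type_
    return merged


def match_types(template: object, concrete: object) -> dict:
    """Collect the bindings that turn `template` into `concrete`."""
    if isinstance(template, Var):
        return {template.name: concrete}
    if (
        isinstance(template, Applied)
        and isinstance(concrete, Applied)
        and template.name == concrete.name
        and len(template.args) == len(concrete.args)
    ):
        subs: dict = {}
        for template_arg, concrete_arg in zip(template.args, concrete.args):
            subs = merge(subs, match_types(template_arg, concrete_arg))
        return subs
    # Nothing to learn from this pair
    return {}


# dict[K, list[V]] matched against dict[str, list[int]]
template = Applied("dict", (Var("K"), Applied("list", (Var("V"),))))
concrete = Applied("dict", ("str", Applied("list", ("int",))))
bindings = match_types(template, concrete)
```

Matching `pair[T, T]` against `pair[int, str]` in this model raises `UnificationError`, because `T` would need two different bindings.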

midas/checker/variance.py

@@ -0,0 +1,129 @@
from typing import Literal, Optional, cast
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AppliedType,
ConstraintType,
Function,
GenericType,
OverloadedFunction,
Type,
TypeVar,
Variance,
)
Polarity = Literal[-1, 0, 1]
class Tracker:
def __init__(self, vars: list[TypeVar]) -> None:
self.vars: list[TypeVar] = vars
self.refs: dict[str, set[Polarity]] = {var.name: set() for var in self.vars}
def record(self, var: TypeVar, polarity: Polarity):
self.refs[var.name].add(polarity)
def get_updated_vars(self) -> list[TypeVar]:
return [
TypeVar(
name=var.name, bound=var.bound, variance=self.get_variance(var.name)
)
for var in self.vars
]
def get_variance(self, name: str) -> Variance:
refs: set[Polarity] = self.refs[name]
if refs == {-1}:
return Variance.CONTRAVARIANT
if refs == {1}:
return Variance.COVARIANT
return Variance.INVARIANT
def __contains__(self, item: TypeVar | str):
if isinstance(item, TypeVar):
return item.name in self
return item in self.refs
class VarianceInferrer:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.tracker: Tracker = Tracker([])
def infer(self, type: GenericType) -> GenericType:
self.tracker = Tracker(type.params)
self.walk(type.body, 1, type.name)
members: dict[str, Member] = self.types._members.get(type.name, {})
for name, member in members.items():
self.walk(member.type, 1, type.name, [f"member:'{name}'"])
return GenericType(
name=type.name,
params=self.tracker.get_updated_vars(),
body=type.body,
)
def walk(
self,
type: Type,
polarity: Polarity,
base_name: str,
path: Optional[list[str]] = None,
):
if path is None:
path = []
match type:
# Arguments are negative positions -> flip polarity
# Return is positive position -> keep polarity
case Function(params=spec):
all_params: list[Function.Parameter] = spec.pos + spec.mixed + spec.kw
for param in all_params:
self.walk(
param.type,
-polarity,
base_name,
path + [f"param:'{param.name}'"],
)
self.walk(type.returns, polarity, base_name, path + ["return"])
# Walk all overloads
case OverloadedFunction(overloads=overloads):
for overload in overloads:
self.walk(overload, polarity, base_name, path)
# If same name as root generic -> skip
# Get inferred variance of parameters and multiply with current
# polarity to recurse through arguments
case AppliedType(name=name, args=args):
# TODO: handle mutually recursive types
if name == base_name:
return
generic: Type = self.types.get_type(name)
assert isinstance(generic, GenericType)
params: list[TypeVar] = generic.params
polarities: dict[Variance, Polarity] = {
Variance.INVARIANT: 0,
Variance.COVARIANT: 1,
Variance.CONTRAVARIANT: -1,
}
for arg, param in zip(args, params):
param_polarity: Polarity = polarities[param.variance]
self.walk(
arg,
cast(Polarity, polarity * param_polarity),
base_name,
path + [f"applied:'{name}'"],
)
# Walk base type
case ConstraintType(type=base):
self.walk(base, polarity, base_name, path + ["constraint"])
# Reached end
# If tracked, record polarity
case TypeVar():
if type in self.tracker:
self.tracker.record(type, polarity)
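
Polarity tracking can be tried out without the registry: walk each member type, flip the sign when entering a parameter, and classify each variable by the set of signs it was seen with. The sketch below assumes a reduced model with only `Var` and `Func`; these names, `collect` and `infer_variance` are hypothetical and leave out applied types, overloads and constraints.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Func:
    params: tuple
    returns: object


def collect(type_: object, polarity: int, refs: dict[str, set[int]]) -> None:
    """Record the polarity of every tracked type variable reachable from `type_`."""
    if isinstance(type_, Var):
        if type_.name in refs:
            refs[type_.name].add(polarity)
    elif isinstance(type_, Func):
        # Parameters are negative positions: flip the polarity
        for param in type_.params:
            collect(param, -polarity, refs)
        # The return type is a positive position: keep the polarity
        collect(type_.returns, polarity, refs)


def infer_variance(members: list[object], var_names: list[str]) -> dict[str, str]:
    refs: dict[str, set[int]] = {name: set() for name in var_names}
    for member in members:
        collect(member, 1, refs)
    result: dict[str, str] = {}
    for name, seen in refs.items():
        if seen == {1}:
            result[name] = "covariant"
        elif seen == {-1}:
            result[name] = "contravariant"
        else:
            # Seen in both positions, or never seen at all
            result[name] = "invariant"
    return result


T = Var("T")
getter = Func(params=(), returns=T)  # get() -> T
setter = Func(params=(T,), returns="None")  # put(T) -> None
producer = infer_variance([getter], ["T"])
consumer = infer_variance([setter], ["T"])
box = infer_variance([getter, setter], ["T"])
```

A type that only returns `T` comes out covariant, one that only accepts `T` contravariant, and one that does both invariant. A parameter that is itself a function taking `T` flips the sign twice, so `T` is covariant there.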

midas/cli/__init__.py

midas/cli/ansi.py

@@ -0,0 +1,41 @@
class Ansi:
CTRL = "\x1b["
RESET = CTRL + "0m"
BOLD = CTRL + "1m"
DIM = CTRL + "2m"
ITALIC = CTRL + "3m"
UNDERLINE = CTRL + "4m"
BLACK = 0
RED = 1
GREEN = 2
YELLOW = 3
BLUE = 4
MAGENTA = 5
CYAN = 6
WHITE = 7
BRIGHT_BLACK = 60
BRIGHT_RED = 61
BRIGHT_GREEN = 62
BRIGHT_YELLOW = 63
BRIGHT_BLUE = 64
BRIGHT_MAGENTA = 65
BRIGHT_CYAN = 66
BRIGHT_WHITE = 67
@classmethod
def FG(cls, col: int) -> str:
return f"{cls.CTRL}{30 + col}m"
@classmethod
def BG(cls, col: int) -> str:
return f"{cls.CTRL}{40 + col}m"
@classmethod
def FG_RGB(cls, r: int, g: int, b: int) -> str:
return f"{cls.CTRL}38;2;{r};{g};{b}m"
@classmethod
def BG_RGB(cls, r: int, g: int, b: int) -> str:
return f"{cls.CTRL}48;2;{r};{g};{b}m"
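
The color constants are offsets added to 30 (foreground) or 40 (background), so the bright variants at 60 to 67 land on the SGR codes 90 to 97 and 100 to 107. A short self-contained sketch of how such escape sequences compose is given below; `fg` mirrors `Ansi.FG`, while `colorize` is a hypothetical helper that is not part of the module above.

```python
CTRL = "\x1b["
RESET = CTRL + "0m"
BOLD = CTRL + "1m"
RED = 1
BRIGHT_RED = 61


def fg(col: int) -> str:
    """Foreground color sequence, same arithmetic as Ansi.FG."""
    return f"{CTRL}{30 + col}m"


def colorize(text: str, col: int, bold: bool = False) -> str:
    """Wrap `text` in a color sequence and reset the style afterwards."""
    prefix = (BOLD if bold else "") + fg(col)
    return f"{prefix}{text}{RESET}"


message = colorize("error", RED, bold=True)
```

Always ending with `RESET` keeps the style from leaking into whatever is printed next.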


@@ -0,0 +1,9 @@
from .check import check as check
from .compile import compile as compile
from .format import format as format
from .highlight import highlight as highlight
from .parse import parse as parse
from .registry import dump_registry as dump_registry
from .stubs import stubs as stubs
from .types import types as types
from .validate import validate as validate


@@ -0,0 +1,41 @@
# **Run type checker and report diagnostics**
# ```shell
# midas check <file.py> [--types <file.midas>]
# ```
from pathlib import Path
from typing import Optional, TextIO
import click
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic
from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter
@click.command(help="Run type checker and report diagnostics")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-l", "--highlight", type=click.File("w"))
def check(
file: TextIO,
types: tuple[TextIO, ...],
highlight: Optional[TextIO],
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
checker.type_check(source_path)
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
printer = DiagnosticPrinter()
printer.print_all(diagnostics)
if highlight is not None:
source: str = file.read()
highlighter = DiagnosticsHighlighter(source)
highlighter.highlight(diagnostics)
highlighter.dump(highlight)


@@ -0,0 +1,51 @@
# **Compile source**
# ```shell
# midas compile <file.py> [--types <file.midas>] [--stubs <stubs>] [--ignore-errors]
# ```
import sys
from pathlib import Path
from typing import Optional, TextIO
import click
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.cli.utils import DiagnosticPrinter
from midas.generator.generator import Generator
from midas.utils import TypedAST
@click.command(help="Compile source")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-s", "--stubs", type=str, multiple=True)
@click.option("--ignore-errors", is_flag=True)
def compile(
file: TextIO,
types: tuple[TextIO, ...],
stubs: tuple[str, ...],
ignore_errors: bool,
):
source: str = file.read()
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
type_files: list[tuple[Path, Optional[str]]] = []
for i, types_file in enumerate(types):
in_path: Path = Path(types_file.name).resolve()
checker.import_midas(in_path)
type_files.append((in_path, stubs[i] if i < len(stubs) else None))
typed_ast: TypedAST = checker.type_check_source(source, str(source_path))
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
printer = DiagnosticPrinter()
printer.print_all(diagnostics)
if not ignore_errors and any(
map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)
):
sys.exit(1)
generator = Generator(workdir=source_path.parent, types=checker.types)
generator.generate(typed_ast, source_path, type_files=type_files)


@@ -0,0 +1,25 @@
from typing import TextIO
import click
import midas.ast.midas as m
from midas.ast.printer import MidasPrinter
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
@click.command(help="Parse and pretty print a Midas file")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def format(file: TextIO, output: TextIO):
source: str = file.read()
printer = MidasPrinter()
lexer = MidasLexer(source, file=file.name)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
for err in parser.errors:
print(err.get_report())
for stmt in stmts:
output.write(printer.print(stmt) + "\n")


@@ -0,0 +1,66 @@
import ast
from typing import TextIO
import click
import midas.ast.midas as m
import midas.ast.python as p
from midas.cli.highlighter import (
Highlighter,
LocatableToken,
MidasHighlighter,
PythonHighlighter,
)
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token, TokenType
from midas.parser.midas import MidasParser
from midas.parser.python import PythonParser
def highlight_python(source: str, path: str) -> Highlighter:
tree: ast.Module = ast.parse(source, filename=path)
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
highlighter = PythonHighlighter(source)
for stmt in stmts:
highlighter.highlight(stmt)
return highlighter
def highlight_midas(source: str, path: str) -> Highlighter:
lexer = MidasLexer(source, file=path)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
highlighter = MidasHighlighter(source)
for err in parser.errors:
print(err.get_report())
for stmt in stmts:
highlighter.highlight(stmt)
for token in tokens:
if token.type == TokenType.COMMENT:
highlighter.wrap(LocatableToken(token), "comment")
elif token.is_keyword:
highlighter.wrap(LocatableToken(token), "keyword")
return highlighter
@click.command(
help="Parse a Python or Midas file and produce a highlighted version showing AST node types inline",
short_help="Parse and highlight a Python or Midas file",
)
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def highlight(output: TextIO, file: TextIO):
source: str = file.read()
highlighter: Highlighter
if file.name.endswith(".py"):
highlighter = highlight_python(source, file.name)
elif file.name.endswith(".midas"):
highlighter = highlight_midas(source, file.name)
else:
raise ValueError("Unsupported file type")
highlighter.dump(output)


@@ -0,0 +1,66 @@
# **Parse and pretty-print AST**
# ```shell
# midas parse <file.midas / file.py>
# ```
import ast
from typing import TextIO
import click
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
from midas.parser.python import PythonParser
def dump_python_ast(tree: ast.Module) -> str:
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
printer = PythonAstPrinter()
dump: str = ""
for stmt in stmts:
dump += printer.print(stmt)
dump += "\n"
return dump
def dump_midas_ast(source: str, filename: str) -> str:
lexer = MidasLexer(source, file=filename)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
if len(parser.errors) != 0:
for err in parser.errors:
print(err.get_report())
raise RuntimeError("A parsing error occurred")
printer = MidasAstPrinter()
dump: str = ""
for stmt in stmts:
dump += printer.print(stmt)
dump += "\n"
return dump
@click.command(help="Parse a Python or Midas file and pretty-print its AST")
@click.argument("file", type=click.File("r"))
@click.option("--raw", is_flag=True)
def parse(file: TextIO, raw: bool):
source: str = file.read()
dump: str
if file.name.endswith(".py"):
tree: ast.Module = ast.parse(source, filename=file.name)
if raw:
dump = ast.dump(tree, indent=4)
else:
dump = dump_python_ast(tree)
elif file.name.endswith(".midas"):
dump = dump_midas_ast(source, file.name)
else:
raise ValueError("Unsupported file type")
click.echo(dump)


@@ -0,0 +1,66 @@
# **Dump types registry**
# ```shell
# midas dump-registry [--types <file.midas>]
# ```
from pathlib import Path
from typing import TextIO
import click
from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.registry import Member
from midas.checker.types import AppliedType, BaseType, DerivedType, GenericType, Type
def base_type(type: Type) -> Type:
match type:
case BaseType():
return type
case DerivedType(type=base):
return base
case AppliedType(body=body):
return body
case GenericType(body=body):
return body
case _:
return type
@click.command(help="Dump types registry")
@click.option("-t", "--types", type=click.File("r"), multiple=True)
def dump_registry(
types: tuple[TextIO, ...],
):
checker = TypeChecker()
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
print("##### Types #####")
for name, type in checker.types._types.items():
members: dict[str, Member] = checker.types._members.get(name, {})
params: str = ""
if isinstance(type, GenericType):
params = ", ".join(map(str, type.params))
params = f"[{params}]"
print(f"{name}{params} = {base_type(type)}")
if len(members) != 0:
print(" " * 4 + "Members:")
for member_name, member in members.items():
kind: str = member.kind.name
print(" " * 8 + f"({kind:8}) {member_name}: {member.type}")
print("##### Predicates #####")
printer = MidasPrinter()
for name, predicate in checker.types._predicates.items():
body: str = printer.print(predicate.body)
if predicate.alias:
print(f"{name}: {predicate.type} = {body}")
else:
print(f"{name}{predicate.type}:")
body = "\n".join(
" " + ("return " if i == 0 else "") + line
for i, line in enumerate(body.split("\n"))
)
print(body)


@@ -0,0 +1,66 @@
import ast
import time
from pathlib import Path
from typing import Optional, TextIO
import black
import click
from watchdog.events import DirModifiedEvent, FileModifiedEvent, FileSystemEventHandler
from watchdog.observers import Observer
from midas.checker.checker import TypeChecker
from midas.generator.stubs import StubsGenerator
def generate_stubs(in_path: Path, out_path: Path):
checker = TypeChecker()
checker.import_midas(in_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output: str = ast.unparse(module)
output = black.format_str(output, mode=black.Mode(is_pyi=True))
out_path.write_text(output)
class Handler(FileSystemEventHandler):
def __init__(self, in_path: Path, out_path: Path) -> None:
super().__init__()
self.in_path: Path = in_path
self.out_path: Path = out_path
def on_modified(self, event: DirModifiedEvent | FileModifiedEvent) -> None:
generate_stubs(self.in_path, self.out_path)
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"))
@click.option("-w", "--watch", is_flag=True)
def stubs(
file: TextIO,
output: Optional[TextIO],
watch: bool,
):
source_path: Path = Path(file.name).resolve()
out_path: Path = source_path.with_suffix(".pyi")
if output is not None:
out_path = Path(output.name).resolve()
generate_stubs(source_path, out_path)
if watch:
print(f"Watching {source_path}...")
print("Press CTRL+C to stop")
handler = Handler(source_path, out_path)
observer = Observer()
observer.schedule(handler, str(source_path))
observer.start()
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
observer.stop()
observer.join()


@@ -0,0 +1,52 @@
# **Print judgements**
# ```shell
# midas types <file.py> [--types <file.midas>]
# ```
from pathlib import Path
from typing import Optional, TextIO
import click
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter
@click.command(help="Print typing judgements")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-l", "--highlight", type=click.File("w"))
def types(
file: TextIO,
types: tuple[TextIO, ...],
highlight: Optional[TextIO],
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
checker.type_check(source_path)
diagnostics: list[Diagnostic] = []
for expr, type in checker.python_typer.judgements:
diagnostics.append(
Diagnostic(
file_path=str(source_path),
location=expr.location,
type=DiagnosticType.INFO,
message=f"Type: {type}",
)
)
diagnostics.extend(checker.diagnostics)
printer = DiagnosticPrinter()
printer.print_all(diagnostics)
if highlight is not None:
source: str = file.read()
highlighter = DiagnosticsHighlighter(source)
highlighter.highlight(diagnostics)
highlighter.dump(highlight)


@@ -0,0 +1,37 @@
# **Validate midas definitions**
# ```shell
# midas validate <file.midas>
# ```
from pathlib import Path
from typing import Optional, TextIO
import click
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic
from midas.cli.highlighter import DiagnosticsHighlighter
from midas.cli.utils import DiagnosticPrinter
@click.command(help="Validate Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-l", "--highlight", type=click.File("w"))
def validate(
file: TextIO,
highlight: Optional[TextIO],
):
source_path: Path = Path(file.name).resolve()
checker = TypeChecker()
checker.import_midas(source_path)
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
printer = DiagnosticPrinter()
printer.print_all(diagnostics)
if highlight is not None:
source: str = file.read()
highlighter = DiagnosticsHighlighter(source)
highlighter.highlight(diagnostics)
highlighter.dump(highlight)

58
midas/cli/highlight.css Normal file

@@ -0,0 +1,58 @@
html,
body {
margin: 0;
font-size: 14pt;
}
* {
box-sizing: border-box;
}
#code {
display: flex;
flex-direction: column;
font-family: monospace;
white-space: pre-wrap;
}
.line {
display: flex;
&:nth-child(odd) {
background-color: rgb(247, 247, 247);
}
.no {
width: 4em;
text-align: right;
padding: 0.2em 0.4em;
border-right: solid black 1px;
flex-shrink: 0;
}
.txt {
flex-grow: 1;
padding: 0.2em 0.8em;
}
}
span {
--col: transparent;
--opacity: 0.1;
--border: 0px;
background-color: rgba(var(--col), var(--opacity));
outline: solid rgb(var(--col)) var(--border);
outline-offset: 2px;
border-radius: 2px;
&:hover:not(:has(*:hover)) {
--opacity: 0.8;
--border: 2px;
z-index: 10;
}
&.keyword {
color: rgb(211, 72, 9);
pointer-events: none;
}
}

374
midas/cli/highlighter.py Normal file

@@ -0,0 +1,374 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Generic, Optional, Protocol, TextIO, TypeVar
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic
from midas.lexer.token import Token
H = TypeVar("H", bound="Highlighter", contravariant=True)
class Highlightable(Protocol, Generic[H]):
def accept(self, visitor: H): ...
class Locatable(Protocol):
@property
@abstractmethod
def location(self) -> Optional[Location]: ...
@dataclass(frozen=True)
class LocatableToken:
token: Token
@property
def location(self) -> Location:
return self.token.get_location()
class Highlighter(ABC):
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
EXTRA_CSS_PATH: Optional[Path] = None
def __init__(self, source: str) -> None:
self.source: str = source
self.lines: list[str] = self.source.splitlines()
self.openings: dict[tuple[int, int], list[str]] = {}
self.closings: dict[tuple[int, int], list[str]] = {}
def format_css(self, path: Path) -> list[str]:
css: str = path.read_text()
css = "\n".join((" " + line).rstrip() for line in css.splitlines())
return [
" <style>",
css,
" </style>",
]
def dump(self, buf: TextIO):
base_css: list[str] = self.format_css(self.BASE_CSS_PATH)
extra_css: list[str] = (
self.format_css(self.EXTRA_CSS_PATH)
if self.EXTRA_CSS_PATH is not None
else []
)
lines: list[str] = [
"<!DOCTYPE html>",
'<html lang="en">',
"<head>",
' <meta charset="UTF-8">',
' <meta name="viewport" content="width=device-width, initial-scale=1.0">',
" <title>Highlighted file</title>",
*base_css,
*extra_css,
"</head>",
"<body>",
' <div id="code">',
]
for l, line in enumerate(self.lines):
lineno: int = l + 1
line_buf: str = (
f'<div class="line" id="l{lineno}"><div class="no">{lineno}</div><div class="txt">'
)
for c, char in enumerate(line):
pos: tuple[int, int] = (lineno, c)
closings: list[str] = self.closings.get(pos, [])
openings: list[str] = self.openings.get(pos, [])
line_buf += "".join(closings + openings)
line_buf += char
line_buf += "".join(self.closings.get((lineno, len(line)), []))
line_buf += "</div></div>"
lines.append(" " + line_buf)
lines.extend(
[
" </div>",
"</body>",
"</html>",
]
)
buf.write("\n".join(lines))
def wrap(self, node: Locatable, cls: str, message: Optional[str] = None):
if node.location is None:
return
if node.location.end_lineno is None or node.location.end_col_offset is None:
return
start_pos: tuple[int, int] = (node.location.lineno, node.location.col_offset)
end_pos: tuple[int, int] = (
node.location.end_lineno,
node.location.end_col_offset,
)
opening: str = f'<span class="{cls}" title="{cls}">'
closing: str = "</span>"
if message is not None:
opening = f'<span class="with-msg">{opening}'
closing = f'{closing}<span class="message">{message}</span></span>'
self.openings.setdefault(start_pos, []).append(opening)
self.closings.setdefault(end_pos, []).insert(0, closing)
if start_pos[0] != end_pos[0]:
for l in range(start_pos[0], end_pos[0]):
c: int = len(self.lines[l - 1])
self.closings.setdefault((l, c), []).insert(0, closing)
self.openings.setdefault((l + 1, 0), []).append(opening)
class PythonHighlighter(
Highlighter,
p.MidasType.Visitor[None],
p.Stmt.Visitor[None],
p.Expr.Visitor[None],
):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_python.css"
def highlight(self, node: Highlightable[PythonHighlighter]):
node.accept(self)
def visit_base_type(self, node: p.BaseType) -> None:
self.wrap(node, "base-type")
for arg in node.args:
self.wrap(arg, "arg")
arg.accept(self)
def visit_constraint_type(self, node: p.ConstraintType) -> None:
self.wrap(node, "constraint-type")
node.type.accept(self)
def visit_frame_column(self, node: p.FrameColumn) -> None:
self.wrap(node, "frame-column")
if node.type is not None:
node.type.accept(self)
def visit_frame_type(self, node: p.FrameType) -> None:
self.wrap(node, "frame-type")
for column in node.columns:
column.accept(self)
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
stmt.expr.accept(self)
def visit_function(self, stmt: p.Function) -> None:
self.wrap(stmt, "function")
self._highlight_param_spec(stmt.params)
for body_stmt in stmt.body:
body_stmt.accept(self)
def _highlight_param_spec(self, spec: p.ParamSpec) -> None:
for param in spec.all:
self._highlight_function_param(param)
def _highlight_function_param(self, param: p.Function.Parameter) -> None:
self.wrap(param, "parameter")
if param.type is not None:
param.type.accept(self)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
stmt.type.accept(self)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
for target in stmt.targets:
target.accept(self)
stmt.value.accept(self)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
self.wrap(stmt, "return")
if stmt.value is not None:
stmt.value.accept(self)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self.wrap(stmt, "if")
stmt.test.accept(self)
for body_stmt in stmt.body:
body_stmt.accept(self)
for else_stmt in stmt.orelse:
else_stmt.accept(self)
def visit_pass(self, stmt: p.Pass) -> None:
pass
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self.wrap(stmt, "for")
stmt.iterator.accept(self)
stmt.target.accept(self)
for body_stmt in stmt.body:
body_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
def visit_call_expr(self, expr: p.CallExpr) -> None:
self.wrap(expr, "call")
expr.callee.accept(self)
for arg in expr.arguments:
arg.accept(self)
for arg in expr.keywords.values():
arg.accept(self)
def visit_get_expr(self, expr: p.GetExpr) -> None: ...
def visit_literal_expr(self, expr: p.LiteralExpr) -> None: ...
def visit_variable_expr(self, expr: p.VariableExpr) -> None: ...
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
def visit_list_expr(self, expr: p.ListExpr) -> None:
for item in expr.items:
item.accept(self)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
for key in expr.keys:
if key is not None:
key.accept(self)
for value in expr.values:
value.accept(self)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
expr.object.accept(self)
expr.index.accept(self)
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
if expr.lower is not None:
expr.lower.accept(self)
if expr.upper is not None:
expr.upper.accept(self)
if expr.step is not None:
expr.step.accept(self)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
for item in expr.items:
item.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
def highlight(self, node: Highlightable[MidasHighlighter]):
node.accept(self)
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self.wrap(stmt, "type-stmt")
self.wrap(LocatableToken(stmt.name), "type-name")
stmt.type.accept(self)
def visit_member_stmt(self, stmt: m.MemberStmt) -> None:
self.wrap(stmt, "member")
stmt.type.accept(self)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self.wrap(stmt, "extend")
for member in stmt.members:
member.accept(self)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name")
for spec in stmt.params:
self._visit_param_spec(spec)
stmt.body.accept(self)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr")
expr.left.accept(self)
expr.right.accept(self)
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
self.wrap(expr, "binary-expr")
expr.left.accept(self)
expr.right.accept(self)
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
self.wrap(expr, "unary-expr")
expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self.wrap(expr, "call-expr")
expr.callee.accept(self)
for arg in expr.arguments:
arg.accept(self)
for arg in expr.keywords.values():
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr) -> None:
self.wrap(expr, "get-expr")
expr.expr.accept(self)
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
self.wrap(expr, "variable")
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
def visit_named_type(self, type: m.NamedType) -> None:
self.wrap(type, "named-type")
def visit_generic_type(self, type: m.GenericType) -> None:
self.wrap(type, "generic-type")
type.type.accept(self)
for arg in type.args:
arg.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self.wrap(type, "constraint-type")
type.type.accept(self)
type.constraint.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None:
self.wrap(type, "complex-type")
for member in type.members:
member.accept(self)
def visit_function_type(self, type: m.FunctionType) -> None:
self.wrap(type, "function")
self._visit_param_spec(type.params)
type.returns.accept(self)
def visit_extension_type(self, type: m.ExtensionType) -> None:
self.wrap(type, "extension")
type.base.accept(self)
type.extension.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
for param in spec.pos + spec.mixed + spec.kw:
param.type.accept(self)
def visit_frame_type(self, type: m.FrameType) -> None:
self.wrap(type, "frame")
for column in type.columns:
self._visit_frame_column(column)
def _visit_frame_column(self, column: m.FrameType.Column) -> None:
self.wrap(column, "column")
class DiagnosticsHighlighter(Highlighter):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
def highlight(self, diagnostics: list[Diagnostic]):
for diagnostic in diagnostics:
self.wrap(diagnostic, str(diagnostic.type).lower(), diagnostic.message)
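
The `wrap`/`dump` pair in `Highlighter` relies on two position-keyed maps: opening tags are appended so outer spans open first, and closing tags are inserted at the front so inner spans close first. The sketch below reduces that idea to a single line of text. `render` and its `(start, end, cls)` span tuples are illustrative names, not part of the commit, and the `html.escape` call is an addition made here for safety; the code above writes characters unescaped.

```python
import html


def render(line: str, spans: list[tuple[int, int, str]]) -> str:
    """Wrap column ranges of one line in nested <span> tags."""
    openings: dict[int, list[str]] = {}
    closings: dict[int, list[str]] = {}
    for start, end, cls in spans:
        # Outer spans are registered first, so they open first...
        openings.setdefault(start, []).append(f'<span class="{cls}">')
        # ...and close last, hence insertion at the front
        closings.setdefault(end, []).insert(0, "</span>")

    out: list[str] = []
    for col, char in enumerate(line):
        # Close spans ending here before opening the ones starting here
        out.append("".join(closings.get(col, []) + openings.get(col, [])))
        out.append(html.escape(char))
    # Spans that end exactly at the end of the line
    out.append("".join(closings.get(len(line), [])))
    return "".join(out)
```

For example, `render("f(x)", [(0, 4, "call"), (2, 3, "arg")])` nests the `arg` span inside the `call` span.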


@@ -0,0 +1,39 @@
span {
--opacity: 0.4;
&.error {
--col: 255, 0, 0;
}
&.warning {
--col: 250, 160, 0;
}
&.info {
--col: 150, 190, 250;
}
&.with-msg {
position: relative;
.message {
display: none;
}
&:hover:not(:has(.with-msg:hover)) {
.message {
display: inline-block;
}
}
.message {
position: absolute;
top: calc(100% + 0.2em);
left: -.2em;
background-color: black;
color: white;
padding: 0.2em 0.4em;
border-radius: .2em;
z-index: 10;
width: 300%;
}
}
}

52
midas/cli/hl_midas.css Normal file

@@ -0,0 +1,52 @@
span {
&.comment {
--col: 200, 200, 200;
color: rgb(110, 110, 110);
font-style: italic;
}
&.named-type,
&.generic-type,
&.constraint-type,
&.complex-type {
--col: 150, 150, 150;
}
&.constraint {
--col: 233, 108, 108;
}
&.property {
--col: 233, 108, 176;
}
&.extend {
--col: 108, 197, 233;
}
&.op {
--col: 108, 148, 233;
}
&.predicate {
--col: 193, 108, 233;
}
&.logical-expr,
&.binary-expr,
&.unary-expr,
&.get-expr {
--col: 123, 215, 193;
}
&.template {
--col: 163, 117, 71;
}
&.type-name,
&.op-name,
&.predicate-name {
--col: 200, 200, 200;
font-weight: bold;
}
}

29
midas/cli/hl_python.css Normal file

@@ -0,0 +1,29 @@
span {
&.base-type {
--col: 108, 233, 108;
}
&.arg {
--col: 103, 192, 224;
}
&.constraint-type {
--col: 174, 200, 195;
}
&.frame-column {
--col: 216, 231, 81;
}
&.frame-type {
--col: 231, 46, 40;
}
&.function {
--col: 215, 103, 224;
}
&.parameter {
--col: 103, 192, 224;
}
}

26
midas/cli/main.py Normal file

@@ -0,0 +1,26 @@
import logging
import click
from midas.cli import commands
@click.group()
@click.option("-v", "--verbose", is_flag=True)
def midas(verbose: bool):
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARNING)
midas.add_command(commands.check)
midas.add_command(commands.compile)
midas.add_command(commands.format)
midas.add_command(commands.highlight)
midas.add_command(commands.parse)
midas.add_command(commands.dump_registry)
midas.add_command(commands.types)
midas.add_command(commands.stubs)
midas.add_command(commands.validate)
if __name__ == "__main__":
midas()

121
midas/cli/utils.py Normal file

@@ -0,0 +1,121 @@
from collections import defaultdict
from pathlib import Path
from typing import Optional
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.cli.ansi import Ansi
class DiagnosticPrinter:
COLORS: dict[DiagnosticType, int] = {
DiagnosticType.ERROR: Ansi.RED,
DiagnosticType.WARNING: Ansi.YELLOW,
DiagnosticType.INFO: Ansi.CYAN,
DiagnosticType.DEBUG: Ansi.MAGENTA,
}
def __init__(self) -> None:
self.files: dict[Optional[str], list[str]] = {}
def get_lines(self, filename: Optional[str]) -> list[str]:
if filename is None:
return []
if filename not in self.files:
path: Path = Path(filename)
if path.exists() and path.is_file():
self.files[filename] = path.read_text().split("\n")
else:
self.files[filename] = []
return self.files[filename]
def print_all(self, diagnostics: list[Diagnostic], indent: int = 4):
by_type: dict[DiagnosticType, int] = defaultdict(int)
for diagnostic in diagnostics:
filename: Optional[str] = diagnostic.file_path
lines = self.get_lines(filename)
self.print(lines, diagnostic, indent=indent)
by_type[diagnostic.type] += 1
if len(diagnostics) == 0:
return
counts: list[str] = []
for type in DiagnosticType:
if type not in by_type:
continue
count: int = by_type[type]
color: int = self.COLORS.get(type, Ansi.WHITE)
counts.append(f"{Ansi.FG(color)}{type.value}s{Ansi.RESET}: {count}")
print(" ".join(counts))
def print(self, lines: list[str], diagnostic: Diagnostic, indent: int = 4):
"""Pretty-print a diagnostic, showing some context if possible
If the diagnostic concerns a specific part of one line, the line is shown
with the affected part highlighted. The message is clearly printed under the
line with an underline further indicating the target expression.
If multiple lines are concerned, no context is shown, only the
diagnostic type, location and message
Args:
lines (list[str]): source code lines
diagnostic (Diagnostic): the diagnostic to print
indent (int, optional): the number of spaces added before the target line to indent it from the location header. Defaults to 4.
"""
loc: Location = diagnostic.location
if loc.lineno != loc.end_lineno:
self.print_multiline(lines, diagnostic, indent)
return
start_offset: int = loc.col_offset
end_offset: int = loc.end_col_offset or (start_offset + 1)
line: str = lines[loc.lineno - 1]
before: str = line[:start_offset]
after: str = line[end_offset:]
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
cursor: str = (
" " * start_offset
+ Ansi.FG(color)
+ "~" * (end_offset - start_offset)
+ "> "
+ diagnostic.message
+ Ansi.RESET
)
indent_str: str = " " * indent
print(diagnostic.location_str + ":")
print(indent_str + before + subject + after)
print(indent_str + cursor)
print()
def print_multiline(
self, all_lines: list[str], diagnostic: Diagnostic, indent: int = 4
):
loc: Location = diagnostic.location
lines: list[str] = all_lines[loc.lineno - 1 : loc.end_lineno]
start_offset: int = loc.col_offset
end_offset: int = loc.end_col_offset or (start_offset + 1)
indent_str: str = " " * indent
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
res: str = indent_str + lines[0][:start_offset]
res += Ansi.FG(color) + lines[0][start_offset:]
for line in lines[1:-1]:
res += "\n" + indent_str + line
res += "\n" + indent_str + lines[-1][:end_offset]
res += Ansi.RESET + lines[-1][end_offset:]
print(diagnostic.location_str + ":")
print(res)
print()
print(Ansi.FG(color) + diagnostic.message + Ansi.RESET)
print()
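
The single-line layout produced by `DiagnosticPrinter.print` can be sketched without colors. This is a simplified illustration under stated assumptions: `underline` is a hypothetical function name, the ANSI codes are dropped, and only the padding arithmetic (spaces up to the start column, one `~` per highlighted column, then `> ` and the message) mirrors the method above.

```python
from typing import Optional


def underline(
    line: str, start: int, end: Optional[int], message: str, indent: int = 4
) -> str:
    """Return the source line with a '~~~> message' marker underneath."""
    # Same fallback as the printer: highlight one column if the end is unknown
    end_offset: int = end or (start + 1)
    pad: str = " " * indent
    cursor: str = " " * start + "~" * (end_offset - start) + "> " + message
    return pad + line + "\n" + pad + cursor


if __name__ == "__main__":
    print(underline("x = foo(1)", 4, 10, "bad call"))
```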


@@ -0,0 +1,59 @@
import ast
from dataclasses import dataclass
from typing import Callable
import midas.ast.python as p
AssertionBuilder = Callable[..., ast.expr]
@dataclass
class Assertion:
bound_expr: p.Expr
inputs: list[p.Expr]
builder: AssertionBuilder
message: str
def is_bound_to(self, expr: p.Expr) -> bool:
return expr == self.bound_expr
class AssertionCollector:
def __init__(self):
self.assertions: list[Assertion] = []
self.definitions: dict[str, ast.stmt] = {}
def add(
self,
bound_expr: p.Expr,
inputs: list[p.Expr],
builder: AssertionBuilder,
message: str,
):
self.assertions.append(
Assertion(
bound_expr=bound_expr,
inputs=inputs,
builder=builder,
message=message,
)
)
def remove(self, assertion: Assertion):
try:
self.assertions.remove(assertion)
except ValueError:
pass
def define(self, name: str, stmt: ast.stmt):
if name not in self.definitions:
self.definitions[name] = stmt
def get_definitions(self) -> list[ast.stmt]:
return list(self.definitions.values())
def get_assertions(self) -> list[Assertion]:
return self.assertions
def get_assertions_for(self, expr: p.Expr) -> list[Assertion]:
return list(filter(lambda a: a.is_bound_to(expr), self.assertions))


@@ -0,0 +1,225 @@
import ast
from typing import Optional
import midas.ast.midas as m
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
Function,
ParamSpec,
Predicate,
Type,
to_annotation,
)
from midas.lexer.token import TokenType
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
TokenType.AND: ast.And,
# TokenType.OR: ast.Or,
}
BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = {
# TokenType.PLUS: ast.Add,
TokenType.MINUS: ast.Sub,
TokenType.STAR: ast.Mult,
TokenType.SLASH: ast.Div,
}
UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = {
# TokenType.PLUS: ast.UAdd,
TokenType.MINUS: ast.USub,
}
COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
TokenType.GREATER: ast.Gt,
TokenType.GREATER_EQUAL: ast.GtE,
TokenType.LESS: ast.Lt,
TokenType.LESS_EQUAL: ast.LtE,
TokenType.EQUAL_EQUAL: ast.Eq,
TokenType.BANG_EQUAL: ast.NotEq,
}
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
def __init__(self, types: TypesRegistry):
self.types: TypesRegistry = types
self._id: int = 0
self._definitions: list[ast.stmt] = []
self._aliases: dict[str, str] = {}
def get_definitions(self) -> list[ast.stmt]:
return self._definitions
def generate(self, expr: m.Expr) -> ast.expr:
match expr:
case m.VariableExpr():
return expr.accept(self)
case _:
func = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="_",
type=self.types.get_type("Any"),
required=True,
)
],
),
returns=self.types.get_type("bool"),
)
alias: str = self.make_alias(None)
definition: ast.stmt = self.make_definition(
alias, Predicate(type=func, body=expr, alias=False)
)
self._definitions.append(definition)
return ast.Name(id=alias)
def make_alias(self, name: Optional[str]) -> str:
suffix: str
if name is None:
suffix = f"p{self._id}"
self._id += 1
else:
suffix = name
alias: str = f"__midas_{suffix}__"
return alias
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
body: ast.expr = predicate.body.accept(self)
if predicate.alias:
return ast.Assign(
targets=[
ast.Name(id=name),
],
value=body,
)
return self.make_func(name, [ast.Return(value=body)], predicate.type)
def make_args(self, params: ParamSpec) -> ast.arguments:
return ast.arguments(
posonlyargs=[
ast.arg(
arg=param.name,
annotation=ast.Constant(value=to_annotation(param.type)),
)
for param in params.pos
],
args=[
ast.arg(
arg=param.name,
annotation=ast.Constant(value=to_annotation(param.type)),
)
for param in params.mixed
],
kwonlyargs=[
ast.arg(
arg=param.name,
annotation=ast.Constant(value=to_annotation(param.type)),
)
for param in params.kw
],
defaults=[],
kw_defaults=[],
)
def make_func(
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
) -> ast.stmt:
match type:
case Function(params=params, returns=Function()):
inner_name: str = f"inner{level}"
return ast.FunctionDef(
name=name,
args=self.make_args(params),
body=[
self.make_func(inner_name, inner_body, type.returns, level + 1),
ast.Return(value=ast.Name(id=inner_name)),
],
returns=ast.Constant(value=to_annotation(type.returns)),
decorator_list=[],
)
case Function(params=params):
return ast.FunctionDef(
name=name,
args=self.make_args(params),
body=inner_body,
returns=ast.Constant(value=to_annotation(type.returns)),
decorator_list=[],
)
case _:
raise ValueError(f"Expected function, got {type!r}")
def get_predicate(self, name: str) -> Optional[ast.expr]:
if name not in self._aliases:
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is None:
return None
alias: str = self.make_alias(name)
self._aliases[name] = alias
self._definitions.append(self.make_definition(alias, predicate))
return ast.Name(id=self._aliases[name])
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
return ast.BoolOp(
op=LOGICAL_OPERATORS[expr.operator.type](),
values=[
expr.left.accept(self),
expr.right.accept(self),
],
)
def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr:
op: TokenType = expr.operator.type
if op in BINARY_OPERATORS:
return ast.BinOp(
left=expr.left.accept(self),
op=BINARY_OPERATORS[op](),
right=expr.right.accept(self),
)
if op in COMPARISON_OPERATORS:
return ast.Compare(
left=expr.left.accept(self),
ops=[COMPARISON_OPERATORS[op]()],
comparators=[expr.right.accept(self)],
)
raise ValueError(f"Unexpected binary operator {op}")
def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr:
return ast.UnaryOp(
op=UNARY_OPERATORS[expr.operator.type](),
operand=expr.right.accept(self),
)
def visit_call_expr(self, expr: m.CallExpr) -> ast.expr:
return ast.Call(
func=expr.callee.accept(self),
args=[arg.accept(self) for arg in expr.arguments],
keywords=[
ast.keyword(arg=name, value=arg.accept(self))
for name, arg in expr.keywords.items()
],
)
def visit_get_expr(self, expr: m.GetExpr) -> ast.expr:
return ast.Attribute(
value=expr.expr.accept(self),
attr=expr.name.lexeme,
)
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
name: str = expr.name.lexeme
if (p := self.get_predicate(name)) is not None:
return p
return ast.Name(id=name)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
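# Visit the wrapped expression: calling expr.accept(self) here would dispatch
# back to this method and recurse forever.
# NOTE: assumes GroupingExpr stores its operand in `expr`, as GetExpr does.
return expr.expr.accept(self)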
def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr:
return ast.Constant(value=expr.value)
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr:
return ast.Name(id="_")
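
Sketch of the nesting performed by `make_func` above when a predicate type returns another function. This is a simplified stand-in, assuming only the standard `ast` module: `make_curried`, `render` and the `between` predicate are hypothetical names, and parameters are plain strings rather than `ParamSpec` objects.

```python
import ast


def make_curried(
    name: str, param_levels: list[list[str]], result: ast.expr, level: int = 0
) -> ast.FunctionDef:
    """Build one `def` per parameter group; only the innermost returns `result`."""
    args = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg=param) for param in param_levels[0]],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
    )
    body: list[ast.stmt]
    if len(param_levels) == 1:
        body = [ast.Return(value=result)]
    else:
        inner_name = f"inner{level}"
        body = [
            make_curried(inner_name, param_levels[1:], result, level + 1),
            ast.Return(value=ast.Name(id=inner_name, ctx=ast.Load())),
        ]
    return ast.FunctionDef(
        name=name, args=args, body=body, decorator_list=[], returns=None
    )


def render(func: ast.FunctionDef) -> str:
    """Wrap the function in a module and turn it back into source text."""
    module = ast.fix_missing_locations(ast.Module(body=[func], type_ignores=[]))
    return ast.unparse(module)


# Hypothetical predicate `between(lo, hi)(v)` whose body is `lo <= v <= hi`.
BETWEEN_BODY = ast.parse("lo <= v <= hi", mode="eval").body
BETWEEN_SOURCE = render(
    make_curried("__midas_between__", [["lo", "hi"], ["v"]], BETWEEN_BODY)
)

if __name__ == "__main__":
    print(BETWEEN_SOURCE)
```

The inner function closes over the outer parameters, so calling the outer function with the predicate arguments yields a one-argument test that can be applied to the value being checked.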


@@ -0,0 +1,690 @@
import ast
import logging
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, assert_never
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AppliedType,
BaseType,
ColumnGroupBy,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
FrameGroupBy,
Function,
GenericType,
OverloadedFunction,
TopType,
TupleType,
Type,
TypeVar,
UnitType,
UnknownType,
)
from midas.generator.collector import Assertion, AssertionCollector
from midas.generator.constraints import ConstraintGenerator
from midas.generator.stubs import StubsGenerator
from midas.utils import TypedAST
@dataclass
class Scope:
pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt])
aliases: list[str] = field(default_factory=list[str])
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
IS_COLUMN_FUNC = "__midas_is_column__"
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
self.rel_src_path: Path = Path()
self.logger: logging.Logger = logging.getLogger("Generator")
self._typed_ast: TypedAST = TypedAST(
stmts=[],
judgements=[],
evaluated_casts=[],
assertions=AssertionCollector(),
)
self._alias_count: int = 0
self._predicate_count: int = 0
self._scopes: list[Scope] = []
self._aliases: list[tuple[p.Expr, ast.expr]] = []
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = []
self.define_is_dataframe: bool = False
self.define_is_column: bool = False
def set_src_path(self, path: Path):
self.rel_src_path = path.resolve().relative_to(self.workdir)
def generate_ast(self, typed_ast: TypedAST) -> ast.AST:
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts, can_be_empty=True)
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
body = predicates + body
if self.define_is_dataframe:
body = [self._is_dataframe_definition()] + body
if self.define_is_column:
body = [self._is_column_definition()] + body
module = ast.Module(body=body, type_ignores=[])
module = ast.fix_missing_locations(module)
return module
def generate(
self,
typed_ast: TypedAST,
src_path: Path,
out_path: Optional[Path] = None,
type_files: Optional[list[tuple[Path, Optional[str]]]] = None,
) -> Path:
self.set_src_path(src_path)
if out_path is None:
if self.build_dir.exists():
shutil.rmtree(self.build_dir)
self.build_dir.mkdir(parents=True, exist_ok=True)
out_path = (self.build_dir / self.rel_src_path).resolve()
try:
_ = out_path.relative_to(self.build_dir)
except ValueError:
raise ValueError(
f"Directory traversal, {self.rel_src_path} points outside of parent directory"
)
out_dir: Path = out_path.parent
# Create the output directory itself (and any missing parents), not only its parent.
out_dir.mkdir(parents=True, exist_ok=True)
if type_files is not None:
for in_path, out_name in type_files:
if out_name is None:
out_name = in_path.stem
self.generate_stubs(in_path, out_dir / f"{out_name}.py")
module: ast.AST = self.generate_ast(typed_ast)
compiled: str = ast.unparse(module)
out_path.write_text(compiled)
return out_path
def generate_stubs(self, in_path: Path, out_path: Path):
checker = TypeChecker()
checker.import_midas(in_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output: str = ast.unparse(module)
out_path.write_text(output)
def convert(self, expr: p.Expr) -> ast.expr:
for expr2, alias in self._aliases:
if expr2 == expr:
return alias
assertions = self._typed_ast.assertions.get_assertions_for(expr)
if len(assertions) != 0:
return self._apply_assertions(expr, assertions)
return expr.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> ast.expr:
return ast.BinOp(
left=self.convert(expr.left),
op=expr.operator,
right=self.convert(expr.right),
)
def visit_compare_expr(self, expr: p.CompareExpr) -> ast.expr:
return ast.Compare(
left=self.convert(expr.left),
ops=[expr.operator],
comparators=[self.convert(expr.right)],
)
def visit_unary_expr(self, expr: p.UnaryExpr) -> ast.expr:
return ast.UnaryOp(
op=expr.operator,
operand=self.convert(expr.right),
)
def visit_call_expr(self, expr: p.CallExpr) -> ast.expr:
return ast.Call(
func=self.convert(expr.callee),
args=[self.convert(arg) for arg in expr.arguments],
keywords=[
ast.keyword(arg=name, value=self.convert(arg))
for name, arg in expr.keywords.items()
],
)
def visit_get_expr(self, expr: p.GetExpr) -> ast.expr:
return ast.Attribute(
value=self.convert(expr.object),
attr=expr.name,
)
def visit_literal_expr(self, expr: p.LiteralExpr) -> ast.expr:
return ast.Constant(value=expr.value)
def visit_variable_expr(self, expr: p.VariableExpr) -> ast.expr:
return ast.Name(id=expr.name)
def visit_logical_expr(self, expr: p.LogicalExpr) -> ast.expr:
return ast.BoolOp(
op=expr.operator,
values=[self.convert(expr.left), self.convert(expr.right)],
)
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
expr2: ast.expr = self.convert(expr.expr)
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
return expr2
alias: ast.expr = self._make_alias(expr.expr, expr2)
type: Type = self._get_expr_type(expr)
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
for assert_ in asserts:
self._add_assert(assert_)
return alias
def visit_ternary_expr(self, expr: p.TernaryExpr) -> ast.expr:
return ast.IfExp(
test=self.convert(expr.test),
body=self.convert(expr.if_true),
orelse=self.convert(expr.if_false),
)
def visit_list_expr(self, expr: p.ListExpr) -> ast.expr:
return ast.List(
elts=[self.convert(item) for item in expr.items],
)
def visit_dict_expr(self, expr: p.DictExpr) -> ast.expr:
return ast.Dict(
keys=[self.convert(key) if key is not None else None for key in expr.keys],
values=[self.convert(value) for value in expr.values],
)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> ast.expr:
return ast.Subscript(
value=self.convert(expr.object),
slice=self.convert(expr.index),
)
def visit_slice_expr(self, expr: p.SliceExpr) -> ast.expr:
return ast.Slice(
lower=self.convert(expr.lower) if expr.lower is not None else None,
upper=self.convert(expr.upper) if expr.upper is not None else None,
step=self.convert(expr.step) if expr.step is not None else None,
)
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
return ast.Tuple(
elts=[self.convert(item) for item in expr.items],
)
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
return expr.expr
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> ast.stmt:
return ast.Expr(
value=self.convert(stmt.expr),
)
def make_args(self, params: p.ParamSpec) -> ast.arguments:
return ast.arguments(
posonlyargs=[ast.arg(arg=param.name) for param in params.pos],
args=[ast.arg(arg=param.name) for param in params.mixed],
kwonlyargs=[ast.arg(arg=param.name) for param in params.kw],
defaults=[
self.convert(param.default)
for param in params.pos + params.mixed
if param.default is not None
],
kw_defaults=[
self.convert(param.default) if param.default is not None else None
for param in params.kw
],
)
def visit_function(self, stmt: p.Function) -> ast.stmt:
return ast.FunctionDef(
name=stmt.name,
args=self.make_args(stmt.params),
body=self._visit_body(stmt.body),
decorator_list=[],
)
def visit_type_assign(self, stmt: p.TypeAssign) -> ast.stmt:
# TODO: is that ok?
return ast.Pass()
def visit_assign_stmt(self, stmt: p.AssignStmt) -> ast.stmt:
return ast.Assign(
targets=[self.convert(target) for target in stmt.targets],
value=self.convert(stmt.value),
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> ast.stmt:
return ast.Return(
value=self.convert(stmt.value) if stmt.value is not None else None,
)
def visit_if_stmt(self, stmt: p.IfStmt) -> ast.stmt:
return ast.If(
test=self.convert(stmt.test),
body=self._visit_body(stmt.body),
orelse=self._visit_body(stmt.orelse, can_be_empty=True),
)
def visit_pass(self, stmt: p.Pass) -> ast.stmt:
return ast.Pass()
def visit_for_stmt(self, stmt: p.ForStmt) -> ast.stmt:
return ast.For(
target=self.convert(stmt.target),
iter=self.convert(stmt.iterator),
body=self._visit_body(stmt.body),
orelse=[],
)
def visit_raw_stmt(self, stmt: p.RawStmt) -> ast.stmt:
return stmt.stmt
def _visit_body(
self, stmts: list[p.Stmt], can_be_empty: bool = False
) -> list[ast.stmt]:
generated: list[ast.stmt] = []
for stmt in stmts:
scope = Scope()
self._scopes.append(scope)
stmt2 = stmt.accept(self)
generated.extend(scope.pre_assertions)
generated.append(stmt2)
if len(scope.aliases) != 0:
generated.append(
ast.Delete(targets=[ast.Name(id=alias) for alias in scope.aliases])
)
self._scopes.pop()
# Remove redundant pass statements
if len(generated) > 1:
generated = [stmt for stmt in generated if not isinstance(stmt, ast.Pass)]
if len(generated) == 0 and not can_be_empty:
generated = [ast.Pass()]
return generated
def _make_alias(self, node: p.Expr, expr: ast.expr) -> ast.expr:
name: str = f"__midas_a{self._alias_count}__"
alias = ast.Name(id=name)
self._alias_count += 1
self._scopes[-1].aliases.append(name)
self._scopes[-1].pre_assertions.append(
ast.Assign(
targets=[alias],
value=expr,
)
)
self._aliases.append((node, alias))
return alias
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
if isinstance(message, str):
message = ast.Constant(value=message)
return ast.Assert(
test=expr,
msg=message,
)
def _add_assert(self, assertion: ast.stmt):
self._scopes[-1].pre_assertions.append(assertion)
def _get_expr_type(self, query: p.Expr) -> Type:
for expr, type in self._typed_ast.judgements:
if expr == query:
return type
raise RuntimeError(f"Cannot get type judgement for {query}")
def _make_cast_asserts(
self, src_location: Location, expr: ast.expr, type: Type
) -> list[ast.stmt]:
match type:
case UnknownType() | TopType():
return []
case BaseType(name=name):
return [
self._build_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id=name)],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
]
case DerivedType(type=base):
return self._make_cast_asserts(src_location, expr, base)
case UnitType():
return [
self._build_assert(
ast.Compare(
left=expr,
ops=[ast.Is()],
comparators=[
ast.Constant(value=None),
],
),
self._make_cast_assert_message(src_location, expr, type),
),
]
case AppliedType(body=body):
return self._make_cast_asserts(src_location, expr, body)
case ConstraintType(type=base, constraint=constraint):
asserts: list[ast.stmt] = self._make_cast_asserts(
src_location, expr, base
)
asserts.append(
self._make_constraint_assert(src_location, expr, constraint)
)
return asserts
case TypeVar(bound=bound):
# TODO: check with type from arguments / use call-site context
if bound is None:
return []
return self._make_cast_asserts(src_location, expr, bound)
case TupleType(items=items):
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
),
]
assert isinstance(expr, ast.Tuple)
for item, item_type in zip(expr.elts, items):
asserts.extend(
self._make_cast_asserts(src_location, item, item_type)
)
return asserts
case DataFrameType(columns=columns):
self.define_is_dataframe = True
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a dataframe"
),
),
]
for column in columns:
asserts.append(
self._build_assert(
ast.Compare(
left=ast.Constant(value=column.name),
ops=[ast.In()],
comparators=[expr],
),
self._make_cast_assert_message(
src_location,
expr,
type,
f": Missing column {column.name}",
),
)
)
asserts.extend(
self._make_cast_asserts(
src_location,
ast.Subscript(
value=expr, slice=ast.Constant(value=column.name)
),
column.type,
)
)
return asserts
case ColumnType():
self.define_is_column = True
asserts: list[ast.stmt] = [
self._build_assert(
ast.Call(
func=ast.Name(id=self.IS_COLUMN_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a column"
),
),
]
inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
src_location, expr, type
)
if inner_assert is not None:
asserts.append(inner_assert)
return asserts
case (
Function()
| OverloadedFunction()
| ComplexType()
| ExtensionType()
| GenericType()
| FrameGroupBy()
| ColumnGroupBy()
):
self.logger.warning(f"Can't make assertion for type {type}")
return []
# Ensure exhaustiveness
case _:
assert_never(type)
def _make_cast_assert_message(
self,
location: Location,
expr: ast.expr,
type: Type,
extra: Optional[str] = None,
) -> ast.expr:
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
return ast.JoinedStr(
values=[
ast.Constant(f"{loc_str}: CastError: Cannot cast "),
ast.FormattedValue(
value=ast.Attribute(
value=ast.Call(
func=ast.Name(id="type"),
args=[expr],
keywords=[],
),
attr="__name__",
),
conversion=-1,
),
ast.Constant(f" to {type}{extra or ''}"),
]
)
def _make_constraint_assert(
self, src_location: Location, expr: ast.expr, constraint: m.Expr
) -> ast.stmt:
test_func: ast.expr = self._get_constraint(constraint)
return self._build_assert(
ast.Call(
func=test_func,
args=[expr],
keywords=[],
),
self._make_constraint_assert_message(src_location, expr, constraint),
)
def _make_constraint_assert_message(
self, location: Location, expr: ast.expr, constraint: m.Expr
) -> ast.expr:
printer = MidasPrinter()
constraint_str: str = printer.print(constraint)
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'"
return ast.Constant(
f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'"
)
def _get_constraint(self, expr: m.Expr) -> ast.expr:
for expr2, constraint in self._constraints:
if expr2 == expr:
return constraint
constraint: ast.expr = self._constraint_generator.generate(expr)
self._constraints.append((expr, constraint))
return constraint
def _is_dataframe_definition(self) -> ast.stmt:
"""
def IS_DATAFRAME_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.DataFrame)
"""
return ast.FunctionDef(
name=self.IS_DATAFRAME_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _is_column_definition(self) -> ast.stmt:
"""
def IS_COLUMN_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.Series)
"""
return ast.FunctionDef(
name=self.IS_COLUMN_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _make_column_inner_assert(
self, src_location: Location, column: ast.expr, type: ColumnType
) -> Optional[ast.stmt]:
# TODO: improve message, maybe chain contexts
col: ast.expr = ast.Name(id="col")
body: list[ast.stmt] = self._make_cast_asserts(src_location, col, type.type)
if len(body) == 0:
return None
return ast.For(
target=col,
iter=column,
body=body,
orelse=[],
)
def _convert_assertion(self, assertion: Assertion) -> ast.stmt:
inputs: list[ast.expr] = []
for input in assertion.inputs:
converted: ast.expr = self.convert(input)
alias: ast.expr = self._make_alias(input, converted)
inputs.append(alias)
test: ast.expr = assertion.builder(*inputs)
location: Location = assertion.bound_expr.location
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
return self._build_assert(
test, f"{loc_str}: AssertionError: {assertion.message}"
)
def _apply_assertions(self, expr: p.Expr, assertions: list[Assertion]) -> ast.expr:
for assertion in assertions:
assert_stmt: ast.stmt
assert_stmt = self._convert_assertion(assertion)
self._add_assert(assert_stmt)
# Mutating list in frozen dataclass
# Not ideal but easiest way to avoid duplicate assertions
self._typed_ast.assertions.remove(assertion)
return expr.accept(self)
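
Sketch of the statement sequence that `visit_cast_expr`, `_make_alias` and `_visit_body` above appear to produce around a checked cast: alias assignment, assertion, original statement, alias cleanup. This is an illustration built with the standard `ast` module only. `wrap_cast` is a hypothetical helper, and the message is shortened compared to the `CastError` f-string built by `_make_cast_assert_message`.

```python
import ast


def wrap_cast(
    target: str, value_src: str, type_name: str, alias: str = "__midas_a0__"
) -> str:
    """Render `target = <value>` guarded by an isinstance assert on an alias."""
    load = ast.Load()
    value = ast.parse(value_src, mode="eval").body
    stmts: list[ast.stmt] = [
        # 1. Evaluate the cast operand once and keep it in a temporary alias.
        ast.Assign(targets=[ast.Name(id=alias, ctx=ast.Store())], value=value),
        # 2. Check the alias before the original statement runs.
        ast.Assert(
            test=ast.Call(
                func=ast.Name(id="isinstance", ctx=load),
                args=[ast.Name(id=alias, ctx=load), ast.Name(id=type_name, ctx=load)],
                keywords=[],
            ),
            msg=ast.Constant(value=f"CastError: Cannot cast to {type_name}"),
        ),
        # 3. The original statement, reading the alias instead of the operand.
        ast.Assign(
            targets=[ast.Name(id=target, ctx=ast.Store())],
            value=ast.Name(id=alias, ctx=load),
        ),
        # 4. Delete the alias so it does not leak into the user's namespace.
        ast.Delete(targets=[ast.Name(id=alias, ctx=ast.Del())]),
    ]
    module = ast.fix_missing_locations(ast.Module(body=stmts, type_ignores=[]))
    return ast.unparse(module)


if __name__ == "__main__":
    print(wrap_cast("x", "raw + 1", "int"))
```

Binding the operand to an alias first means the checked expression is evaluated exactly once, even when it has side effects.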

480
midas/generator/stubs.py Normal file

@@ -0,0 +1,480 @@
import ast
from typing import Optional, assert_never
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AppliedType,
BaseType,
ColumnGroupBy,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
DerivedType,
ExtensionType,
FrameGroupBy,
Function,
GenericType,
OverloadedFunction,
ParamSpec,
TopType,
TupleType,
Type,
TypeVar,
UnitType,
UnknownType,
Variance,
substitute_typevars,
)
Empty = ast.Constant(value=...)
class StubsGenerator:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set()
self.import_pandas: bool = False
self.protocol_idx: int = 0
self.stub_idx: int = 0
self.type_var_idx: int = 0
self.substitutions: dict[str, dict[str, Type]] = {}
def generate_stubs(self) -> ast.Module:
self.stubs = []
self.typing_imports = set()
self.import_pandas = False
for name, type in self.types._types.items():
# Skip builtin types, not just based on name so the user can override
# TODO: check if added members on builtin type
match type:
case BaseType(name=name_) if name == name_:
continue
case GenericType(
name=name1,
body=BaseType(name=name2),
) if (
name == name1 == name2
):
continue
self.generate_stub(name, type)
imports: list[ast.stmt] = [
ast.ImportFrom(
module="__future__",
names=[ast.alias(name="annotations")],
level=0,
)
]
if len(self.typing_imports) != 0:
imports.append(
ast.ImportFrom(
module="typing",
names=[
ast.alias(name=name) for name in sorted(self.typing_imports)
],
level=0,
)
)
if self.import_pandas:
imports.append(
ast.Import(
names=[
ast.alias(
name="pandas",
asname="pd",
)
],
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type):
base_type: Type = type
# TODO: improve
match type:
case DerivedType(name=name_) | GenericType(name=name_) if name_ == name:
pass
case UnitType() if name == "None":
pass
case TopType() if name == "Any":
pass
case _:
alias = ast.Assign(
targets=[ast.Name(id=name)], value=self.dump_type(type)
)
self.add_stub(alias)
return
members: dict[str, Member] = self.types._members.get(name, {})
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
return
bases: list[ast.expr] = []
substitutions: dict[str, Type] = {}
bases, substitutions = self.get_bases(type)
self.substitutions[name] = substitutions
body = self.generate_body(members, substitutions)
stub = ast.ClassDef(
name=name,
bases=bases,
body=body,
keywords=[],
decorator_list=[],
)
self.add_stub(stub)
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
match type:
case DerivedType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
self.add_typing_import("Generic")
type_vars: ast.expr
params2: list[TypeVar] = self.define_type_vars(params)
if len(params) == 1:
type_vars = ast.Name(id=params2[0].name)
else:
type_vars = ast.Tuple(
elts=[ast.Name(id=param.name) for param in params2]
)
substitutions: dict[str, TypeVar] = {
param.name: param2 for param, param2 in zip(params, params2)
}
body_bases, body_substitutions = self.get_bases(body)
return (
body_bases
+ [
ast.Subscript(
value=ast.Name(id="Generic"),
slice=type_vars,
)
],
body_substitutions | substitutions,
)
case ConstraintType(type=base):
return self.get_bases(base)
case TypeVar(bound=bound) if bound is not None:
return [self.dump_type(bound)], {}
case _:
return [], {}
def generate_body(
self, members: dict[str, Member], substitutions: dict[str, Type]
) -> list[ast.stmt]:
if len(members) == 0:
return [ast.Expr(value=Empty)]
body: list[ast.stmt] = []
for name, member in members.items():
type: Type = member.type
type = substitute_typevars(type, substitutions)
match member.kind:
case m.MemberKind.PROPERTY:
body.append(
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(type),
simple=1,
)
)
case m.MemberKind.METHOD:
body.extend(self.dump_method(name, type))
return body
def dump_type(self, type: Type) -> ast.expr:
match type:
case DerivedType(name=name) | GenericType(name=name) if (
name in self.substitutions
):
type = substitute_typevars(type, self.substitutions[name])
match type:
case TopType() | UnknownType():
self.add_typing_import("Any")
return ast.Name(id="Any")
case BaseType(name=name):
return ast.Name(id=name)
case DerivedType(name=name):
return ast.Name(id=name)
case UnitType():
return ast.Constant(value=None)
case Function():
name: str = self.define_protocol(type)
return ast.Name(id=name)
case OverloadedFunction(overloads=overloads):
if len(overloads) == 1:
return self.dump_type(overloads[0])
return ast.BinOp(
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
op=ast.BitOr(),
right=self.dump_type(overloads[-1]),
)
case ComplexType():
name: str = self.new_stub_name()
self.generate_stub(name, type)
return ast.Name(id=name)
case ExtensionType():
raise NotImplementedError
case TypeVar():
return ast.Name(id=type.name)
case GenericType(name=name):
params: ast.expr
if len(type.params) == 1:
params = self.dump_type(type.params[0])
else:
params = ast.Tuple(
elts=[self.dump_type(param) for param in type.params]
)
return ast.Subscript(
value=ast.Name(id=type.name),
slice=params,
)
case AppliedType():
args: ast.expr
if len(type.args) == 1:
args = self.dump_type(type.args[0])
else:
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
return ast.Subscript(
value=ast.Name(id=type.name),
slice=args,
)
case ConstraintType():
return self.dump_type(type.type)
case TupleType(items=items):
return ast.Subscript(
value=ast.Name(id="tuple"),
slice=ast.Tuple(
elts=[self.dump_type(item) for item in items],
),
)
case ColumnType(type=inner):
self.import_pandas = True
return ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
slice=self.dump_type(inner),
)
case DataFrameType():
self.import_pandas = True
return ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
)
case FrameGroupBy():
self.import_pandas = True
return ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="api",
),
attr="typing",
),
attr="DataFrameGroupBy",
)
case ColumnGroupBy():
self.import_pandas = True
return ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="api",
),
attr="typing",
),
attr="SeriesGroupBy",
)
case _:
assert_never(type)
def dump_method(
self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]:
match method:
case Function():
if overloaded:
self.add_typing_import("overload")
return [
ast.FunctionDef(
name=name,
args=self.dump_params(method.params, with_self=True),
returns=self.dump_type(method.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[ast.Name(id="overload")] if overloaded else [],
)
]
case OverloadedFunction(overloads=overloads):
stmts: list[ast.stmt] = []
for overload in overloads:
stmts.extend(self.dump_method(name, overload, True))
return stmts
case _:
return [
ast.AnnAssign(
target=ast.Name(id=name),
annotation=self.dump_type(method),
simple=1,
)
]
def dump_params(self, params: ParamSpec, with_self: bool = False) -> ast.arguments:
pos: list[ast.arg] = [
ast.arg(
arg=f"_{param.pos}",
annotation=self.dump_type(param.type),
)
for param in params.pos
]
mixed: list[ast.arg] = [
ast.arg(
arg=param.name,
annotation=self.dump_type(param.type),
)
for param in params.mixed
]
kw: list[ast.arg] = [
ast.arg(
arg=param.name,
annotation=self.dump_type(param.type),
)
for param in params.kw
]
defaults: list[ast.expr] = [
Empty for param in params.pos + params.mixed if not param.required
]
kw_defaults: list[Optional[ast.expr]] = [
None if param.required else Empty for param in params.kw
]
if with_self:
arg = ast.arg(arg="self", annotation=None)
if len(pos) != 0:
pos.insert(0, arg)
else:
mixed.insert(0, arg)
return ast.arguments(
posonlyargs=pos,
args=mixed,
kwonlyargs=kw,
defaults=defaults,
kw_defaults=kw_defaults,
)
def define_protocol(self, func: Function) -> str:
self.add_typing_import("Protocol")
name: str = self.new_protocol_name()
protocol = ast.ClassDef(
name=name,
bases=[ast.Name(id="Protocol")],
keywords=[],
body=[
ast.FunctionDef(
name="__call__",
args=self.dump_params(func.params, with_self=True),
returns=self.dump_type(func.returns),
body=[ast.Expr(value=Empty)],
decorator_list=[],
),
],
decorator_list=[],
)
self.add_stub(protocol)
return name
def new_protocol_name(self) -> str:
name: str = f"_Protocol{self.protocol_idx}"
self.protocol_idx += 1
return name
def new_stub_name(self) -> str:
name: str = f"_Stub_{self.stub_idx}"
self.stub_idx += 1
return name
def new_type_var_name(self) -> str:
name: str = f"_T{self.type_var_idx}"
self.type_var_idx += 1
return name
def add_stub(self, stub: ast.stmt):
self.stubs.append(stub)
def add_typing_import(self, name: str):
self.typing_imports.add(name)
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
vars2: list[TypeVar] = []
for var in vars:
vars2.append(self.define_type_var(var))
return vars2
def define_type_var(self, var: TypeVar) -> TypeVar:
name: str = self.new_type_var_name()
self.add_typing_import("TypeVar")
kwargs: list[ast.keyword] = []
if var.bound is not None:
kwargs.append(
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
)
if var.variance == Variance.COVARIANT:
kwargs.append(
ast.keyword(
arg="covariant",
value=ast.Constant(value=True),
)
)
elif var.variance == Variance.CONTRAVARIANT:
kwargs.append(
ast.keyword(
arg="contravariant",
value=ast.Constant(value=True),
)
)
self.add_stub(
ast.Assign(
targets=[ast.Name(id=name)],
value=ast.Call(
func=ast.Name(id="TypeVar"),
args=[
ast.Constant(value=name),
],
keywords=kwargs,
),
)
)
return TypeVar(name=name, bound=None)
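
Sketch of the kind of module that `define_type_var` and `get_bases` above seem to aim for with a single-parameter generic type: a `TypeVar` assignment followed by a class deriving from `Generic[...]`. This uses the standard `ast` and `typing` modules only. `generic_stub` and the `Box` example are hypothetical, and annotations are passed as plain names rather than midas `Type` objects.

```python
import ast


def generic_stub(class_name: str, fields: dict[str, str], type_var: str = "_T0") -> str:
    """Render a one-parameter generic class stub with annotated attributes."""
    load = ast.Load()
    body: list[ast.stmt] = [
        ast.AnnAssign(
            target=ast.Name(id=field, ctx=ast.Store()),
            annotation=ast.Name(id=annotation, ctx=load),
            simple=1,
        )
        for field, annotation in fields.items()
    ]
    if not body:
        # An empty class body is rendered as `...`, like `Empty` in the generator.
        body = [ast.Expr(value=ast.Constant(value=...))]
    module = ast.Module(
        body=[
            ast.ImportFrom(
                module="typing",
                names=[ast.alias(name="Generic"), ast.alias(name="TypeVar")],
                level=0,
            ),
            ast.Assign(
                targets=[ast.Name(id=type_var, ctx=ast.Store())],
                value=ast.Call(
                    func=ast.Name(id="TypeVar", ctx=load),
                    args=[ast.Constant(value=type_var)],
                    keywords=[],
                ),
            ),
            ast.ClassDef(
                name=class_name,
                bases=[
                    ast.Subscript(
                        value=ast.Name(id="Generic", ctx=load),
                        slice=ast.Name(id=type_var, ctx=load),
                        ctx=load,
                    )
                ],
                keywords=[],
                body=body,
                decorator_list=[],
            ),
        ],
        type_ignores=[],
    )
    return ast.unparse(ast.fix_missing_locations(module))


if __name__ == "__main__":
    print(generic_stub("Box", {"value": "_T0"}))
```

Renaming user type variables to generated names such as `_T0` avoids clashes between stubs, at the cost of less readable output.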

0
midas/lexer/__init__.py Normal file

@@ -1,17 +1,25 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional
from lexer.position import Position
from lexer.token import Token, TokenType
from midas.lexer.position import Position
from midas.lexer.token import Token, TokenType
class MidasSyntaxError(Exception):
def __init__(self, pos: Position, message: str):
super().__init__(f"[ERROR] Error at {pos}: {message}")
self.pos: Position = pos
self.message: str = message
class Lexer(ABC):
"""An abstract lexer which provides methods to easily extend it into a concrete one
This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom,
more specifically on my [previous Python implementation](https://git.kb28.ch/HEL/pebble)
more specifically on my [previous Python implementation][2]
[1]: https://craftinginterpreters.com/
[2]: https://git.kb28.ch/HEL/pebble
"""
def __init__(self, source: str, file: Optional[str] = None) -> None:
@@ -38,9 +46,9 @@ class Lexer(ABC):
msg (str): the error message
Raises:
SyntaxError
MidasSyntaxError
"""
raise SyntaxError(f"[ERROR] Error at {self.start_pos}: {msg}")
raise MidasSyntaxError(self.start_pos, msg)
def process(self) -> list[Token]:
"""Scan tokens out of the source text
@@ -49,7 +57,7 @@ class Lexer(ABC):
list[Token]: all the tokens that could be scanned
Raises:
SyntaxError: if a syntax error is found
MidasSyntaxError: if a syntax error is found
"""
self.scan_tokens()
self.tokens.append(Token(TokenType.EOF, "", None, self.get_position()))
@@ -161,6 +169,6 @@ class Lexer(ABC):
def scan_token(self) -> None:
"""Scan a token
This function should (at least) consume the current character and produce the appropriate token(s), using `add_token`
This function should (at least) consume the current character and produce the appropriate token(s), using :func:`add_token`
"""
pass
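
The switch from the built-in `SyntaxError` to `MidasSyntaxError` means callers can read the position and the message separately instead of parsing the formatted string. A minimal sketch of that idea follows. It assumes `Position` has only the three fields shown in this diff, and `describe` is a hypothetical helper that is not part of the change.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Position:
    """Simplified stand-in for midas.lexer.position.Position (assumed fields)."""

    file: Optional[str]
    line: int
    column: int


class MidasSyntaxError(Exception):
    """Same shape as the exception added in this diff."""

    def __init__(self, pos: Position, message: str):
        super().__init__(f"[ERROR] Error at {pos}: {message}")
        self.pos: Position = pos
        self.message: str = message


def describe(error: MidasSyntaxError) -> str:
    """Hypothetical helper: build a compact 'file:line:column: message' report."""
    file: str = error.pos.file or "<input>"
    return f"{file}:{error.pos.line}:{error.pos.column}: {error.message}"


def failing_scan() -> None:
    """Hypothetical scanner step that always fails, for demonstration only."""
    raise MidasSyntaxError(Position(None, 3, 7), "Unterminated string")
```

A caller can then catch `MidasSyntaxError` specifically and pass it to `describe`, without also catching Python's own `SyntaxError`.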

View File

@@ -1,6 +1,5 @@
from lexer.base import Lexer
from lexer.keyword import MIDAS_KEYWORDS
from lexer.token import TokenType
from midas.lexer.base import Lexer
from midas.lexer.token import KEYWORDS, TokenType
class MidasLexer(Lexer):
@@ -31,30 +30,34 @@ class MidasLexer(Lexer):
self.add_token(
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
)
case "!":
if self.match("="):
self.add_token(TokenType.BANG_EQUAL)
else:
self.error("Unexpected single bang. Did you mean '!=' ?")
case "!" if self.match("="):
self.add_token(TokenType.BANG_EQUAL)
case ":":
self.add_token(TokenType.COLON)
case ".":
self.add_token(TokenType.DOT)
case "&":
self.add_token(TokenType.AND)
case "?":
self.add_token(TokenType.QMARK)
case ",":
self.add_token(TokenType.COMMA)
case "_":
case "_" if not self.is_identifier_char(self.peek_next(), start=False):
self.add_token(TokenType.UNDERSCORE)
case "-" if self.match(">"):
self.add_token(TokenType.ARROW)
case "+":
self.add_token(TokenType.PLUS)
case "-":
self.add_token(TokenType.MINUS)
case "*":
self.add_token(TokenType.STAR)
case "/" if self.match("/"):
self.scan_comment()
case "/" if self.match("*"):
self.scan_comment_multiline()
case "/":
if self.match("/"):
self.scan_comment()
elif self.match("*"):
self.scan_comment_multiline()
else:
self.add_token(TokenType.SLASH)
self.add_token(TokenType.SLASH)
case "\n":
self.add_token(TokenType.NEWLINE)
case " " | "\r" | "\t":
@@ -66,15 +69,34 @@ class MidasLexer(Lexer):
):
self.advance()
self.add_token(TokenType.WHITESPACE)
case '"' | "'":
self.scan_string(char)
case _:
if char.isdigit():
self.scan_number()
elif char.isalpha():
elif self.is_identifier_char(char, start=True):
self.scan_identifier()
else:
self.error("Unexpected character")
return None
def scan_string(self, opening: str):
"""Scan the rest of a string and add it as a token
Args:
opening (str): the opening quote or double quote, to be matched
at the end of the string
"""
while self.peek() != opening and not self.is_at_end():
self.advance()
if self.is_at_end():
self.error("Unterminated string")
self.advance()
value: str = self.source[self.start + 1 : self.idx - 1]
self.add_token(TokenType.STRING, value)
def scan_number(self):
"""Scan the rest of a number and add it as a token
@@ -98,11 +120,11 @@ class MidasLexer(Lexer):
An identifier starts with a letter or an underscore, followed by any number of
alphanumerical characters or underscores
"""
while self.peek().isalnum() or self.peek() == "_":
while self.is_identifier_char(self.peek(), start=False):
self.advance()
lexeme: str = self.source[self.start : self.idx]
token_type: TokenType = MIDAS_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
token_type: TokenType = KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
self.add_token(token_type)
def scan_comment(self):
@@ -129,3 +151,24 @@ class MidasLexer(Lexer):
if not self.is_at_end():
self.advance()
self.add_token(TokenType.COMMENT)
def is_identifier_char(self, char: str, *, start: bool) -> bool:
"""Check whether a character is valid as part of an identifier
Identifiers can contain any alphanumerical character or underscore.
They cannot start with a digit.
Args:
char (str): the character to check
start (bool): whether this is the first character of the identifier
Returns:
bool: `True` if the character is valid, `False` otherwise
"""
if char == "_":
return True
if char.isalpha():
return True
if not start and char.isdigit():
return True
return False
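
As a rough standalone illustration of the rule implemented by `is_identifier_char`, the same check can be written as a free function and applied to a whole string. `is_valid_identifier` is a hypothetical helper that is not part of this diff. It only checks characters, and ignores the fact that the lexer scans a lone `_` as an `UNDERSCORE` token.

```python
def is_identifier_char(char: str, *, start: bool) -> bool:
    """Free-function mirror of the method above."""
    if char == "_":
        return True
    if char.isalpha():
        return True
    # Digits are accepted anywhere except in first position
    if not start and char.isdigit():
        return True
    return False


def is_valid_identifier(text: str) -> bool:
    """Hypothetical helper: apply the character rule to every character."""
    if not text:
        return False
    return all(
        is_identifier_char(char, start=(i == 0)) for i, char in enumerate(text)
    )
```

Note that `str.isalpha` also accepts non-ASCII letters, so under this rule identifiers are not limited to ASCII.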

View File

@@ -5,6 +5,7 @@ from typing import Optional
@dataclass(frozen=True)
class Position:
"""A simple structure to store the position of a token"""
file: Optional[str]
line: int
column: int

120
midas/lexer/token.py Normal file

@@ -0,0 +1,120 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any
from midas.ast.location import Location
from midas.lexer.position import Position
class TokenType(Enum):
# Punctuation
LEFT_PAREN = auto()
RIGHT_PAREN = auto()
LEFT_BRACKET = auto()
RIGHT_BRACKET = auto()
LEFT_BRACE = auto()
RIGHT_BRACE = auto()
COLON = auto()
COMMA = auto()
UNDERSCORE = auto()
ARROW = auto()
AND = auto()
QMARK = auto()
DOT = auto()
# Operators
PLUS = auto()
MINUS = auto()
STAR = auto()
SLASH = auto()
GREATER = auto()
GREATER_EQUAL = auto()
LESS = auto()
LESS_EQUAL = auto()
EQUAL = auto()
EQUAL_EQUAL = auto()
BANG_EQUAL = auto()
# Literals
IDENTIFIER = auto()
NUMBER = auto()
TRUE = auto()
FALSE = auto()
NONE = auto()
STRING = auto()
# Keywords
TYPE = auto()
ALIAS = auto()
PREDICATE = auto()
EXTEND = auto()
WHERE = auto()
PROP = auto()
DEF = auto()
FUNC = auto()
# Misc
COMMENT = auto()
WHITESPACE = auto()
EOF = auto()
NEWLINE = auto()
KEYWORDS: dict[str, TokenType] = {
"type": TokenType.TYPE,
"alias": TokenType.ALIAS,
"predicate": TokenType.PREDICATE,
"extend": TokenType.EXTEND,
"where": TokenType.WHERE,
"true": TokenType.TRUE,
"false": TokenType.FALSE,
"none": TokenType.NONE,
"prop": TokenType.PROP,
"def": TokenType.DEF,
"fn": TokenType.FUNC,
}
@dataclass(frozen=True)
class Token:
"""A scanned token"""
type: TokenType
lexeme: str
value: Any
position: Position
def get_location(self) -> Location:
lineno: int = self.position.line
col_offset: int = self.position.column - 1
end_lineno = lineno
end_col_offset = col_offset
for c in self.lexeme:
end_col_offset += 1
if c == "\n":
end_lineno += 1
end_col_offset = 0
return Location(
lineno=lineno,
col_offset=col_offset,
end_lineno=end_lineno,
end_col_offset=end_col_offset,
)
def location_to(self, to: Token) -> Location:
"""Create a new :class:`Location` spanning from this token to another
Args:
to (Token): the end token
Returns:
Location: a new :class:`Location` starting at this token and ending
at `to`, both included
"""
return Location.span(self.get_location(), to.get_location())
@property
def is_keyword(self) -> bool:
return self.lexeme in KEYWORDS
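
The end position computed by `get_location` can be tried in isolation with the sketch below. `Span` is a hypothetical stand-in for `midas.ast.location.Location`, and the 1-based column is an assumption inferred from the `column - 1` conversion above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Span:
    """Hypothetical stand-in for Location, with the same four fields."""

    lineno: int
    col_offset: int
    end_lineno: int
    end_col_offset: int


def lexeme_span(line: int, column: int, lexeme: str) -> Span:
    """Compute the span covered by a lexeme starting at (line, column)."""
    col_offset = column - 1  # assumed: 1-based column to 0-based offset
    end_lineno = line
    end_col_offset = col_offset
    for char in lexeme:
        end_col_offset += 1
        if char == "\n":
            # A newline moves the end to the start of the next line
            end_lineno += 1
            end_col_offset = 0
    return Span(line, col_offset, end_lineno, end_col_offset)
```

For a single-line lexeme the end offset is simply the start offset plus the lexeme length. Only multi-line lexemes such as block comments change `end_lineno`.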

View File

@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar
from lexer.token import Token, TokenType
from parser.errors import ParsingError
from midas.lexer.token import Token, TokenType
from midas.parser.errors import ParsingError
@dataclass(frozen=True)
@@ -16,6 +16,9 @@ class TokenError:
def get_report(self) -> str:
"""Get a detailed error message
The error message is formatted as "(<position>) Error at <token>: <message>".
For example: "(L2:5) Error at '3': Expected ')' after arguments."
Returns:
str: the complete error message
"""
@@ -32,9 +35,10 @@ class Parser(ABC, Generic[T]):
"""An abstract parser which provides methods to easily extend it into a concrete one
This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom,
more specifically on my [previous Python implementation](https://git.kb28.ch/HEL/pebble)
more specifically on my [previous Python implementation][2]
[1]: https://craftinginterpreters.com/
[2]: https://git.kb28.ch/HEL/pebble
"""
IGNORE: set[TokenType] = {
@@ -173,7 +177,7 @@ class Parser(ABC, Generic[T]):
error_msg (str): the error message if the token doesn't match
Raises:
SyntaxError: if the current token doesn't match the given type
ParsingError: if the current token doesn't match the given type
Returns:
Token: the current token which matched the given type

874
midas/parser/midas.py Normal file

@@ -0,0 +1,874 @@
from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import (
AliasStmt,
BinaryExpr,
CallExpr,
ComplexType,
ConstraintType,
Expr,
ExtendStmt,
ExtensionType,
FrameType,
FunctionType,
GenericType,
GetExpr,
GroupingExpr,
LiteralExpr,
LogicalExpr,
MemberKind,
MemberStmt,
NamedType,
ParamSpec,
PredicateStmt,
Stmt,
Type,
TypeParam,
TypeStmt,
UnaryExpr,
VariableExpr,
WildcardExpr,
)
from midas.lexer.token import KEYWORDS, Token, TokenType
from midas.parser.base import Parser
from midas.parser.errors import ParsingError
class MidasParser(Parser[list[Stmt]]):
"""A simple parser for midas type definitions"""
SYNC_BOUNDARY: set[TokenType] = {
TokenType.ALIAS,
TokenType.TYPE,
TokenType.EXTEND,
TokenType.PREDICATE,
TokenType.PROP,
TokenType.FUNC,
}
def parse(self) -> list[Stmt]:
statements: list[Stmt] = []
while not self.is_at_end():
stmt: Optional[Stmt] = self.declaration()
if stmt is None:
print("Early stop")
break
statements.append(stmt)
return statements
def synchronize(self):
"""Skip tokens until a synchronization boundary is found
This method allows gracefully recovering from a parse error
to a safe place and continue parsing
"""
self.advance()
while not self.is_at_end():
if self.previous().type == TokenType.NEWLINE:
return
if self.peek().type in self.SYNC_BOUNDARY:
return
self.advance()
def declaration(self) -> Optional[Stmt]:
"""Try and parse a declaration
Any parsing error is caught and `None` is returned
Returns:
Optional[Stmt]: the parsed Midas statement, or `None` if a ParsingError was raised
"""
try:
if self.match(TokenType.TYPE):
return self.type_declaration()
if self.match(TokenType.ALIAS):
return self.alias_declaration()
if self.match(TokenType.EXTEND):
return self.extend_declaration()
if self.match(TokenType.PREDICATE):
return self.predicate_declaration()
raise self.error(self.peek(), "Unexpected token")
except ParsingError:
self.synchronize()
return None
def type_declaration(self) -> TypeStmt:
"""Parse a type declaration
A type declaration creates a named subtype of a type expression.
It can have an optional template expression after its name, wrapped in brackets, to handle type parameters.
A type statement consists of:
- the `type` keyword
- a name (identifier)
- (optional) type parameters
- a body, a type expression (see :func:`type_expr`)
Returns:
TypeStmt: the parsed type declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected type name")
params: list[TypeParam] = self.type_params()
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
type: Type = self.type_expr()
return TypeStmt(
location=keyword.location_to(self.previous()),
name=name,
params=params,
type=type,
)
def type_params(self) -> list[TypeParam]:
"""Parse a list of type parameters
Type parameters are a comma-separated list of type variables wrapped in brackets.
Each type variable is either a simple variable, or a bounded variable written `S <: T`
Returns:
list[TypeParam]: the list of type parameters, if any, or an empty list
"""
if not self.match(TokenType.LEFT_BRACKET):
return []
params: list[TypeParam] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
name: Token = self.consume_identifier("Expected type variable")
bound: Optional[Type] = None
if self.match(TokenType.LESS):
self.consume(TokenType.COLON, "Expected ':' after '<'")
bound = self.type_expr()
params.append(
TypeParam(
location=name.location_to(self.previous()),
name=name,
bound=bound,
)
)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
return params
def alias_declaration(self) -> AliasStmt:
"""Parse an alias declaration
An alias statement consists of:
- the `alias` keyword
- a name (identifier)
- a body, a type expression (see :func:`type_expr`)
Returns:
AliasStmt: the parsed alias declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected alias name")
self.consume(TokenType.EQUAL, "Expected '=' before alias definition")
type: Type = self.type_expr()
return AliasStmt(
location=keyword.location_to(self.previous()),
name=name,
type=type,
)
def type_expr(self) -> Type:
"""Parse a type expression
A type expression can either be a function type (see :func:`function`)
or a constraint type (see :func:`constraint_type`), optionally followed by
`&` and a complex type (see :func:`complex_type`) to form an extension type
Returns:
Type: the parsed type expression
"""
base: Type
if self.match(TokenType.FUNC):
base = self.function()
else:
base = self.constraint_type()
if self.match(TokenType.AND):
extension: ComplexType = self.complex_type()
return ExtensionType(
location=Location.span(base.location, extension.location),
base=base,
extension=extension,
)
return base
def constraint_type(self) -> Type:
"""Parse a constraint type expression
A constraint type consists of a base type (see :func:`base_type`),
optionally followed by the `where` keyword and a constraint
expression (see :func:`constraint`)
Returns:
Type: the parsed constraint type expression
"""
type: Type = self.base_type()
if self.match(TokenType.WHERE):
constraint: Expr = self.constraint()
return ConstraintType(
location=Location.span(type.location, constraint.location),
type=type,
constraint=constraint,
)
return type
def base_type(self) -> Type:
"""Parse a base type expression
A base type is either a parenthesized type expression (see :func:`type_expr`),
a complex type (see :func:`complex_type`) or a generic type (see :func:`generic_type`)
Returns:
Type: the parsed base type expression
"""
if self.match(TokenType.LEFT_PAREN):
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return type
if self.check(TokenType.LEFT_BRACE):
return self.complex_type()
return self.generic_type()
def generic_type(self) -> Type:
"""Parse a generic type expression
A generic type consists of a named type (see :func:`named_type`),
optionally followed by type arguments in brackets.
The special `Frame` type accepts a frame schema instead of type
arguments (see :func:`frame_type`).
Returns:
Type: the parsed generic type
"""
type: NamedType = self.named_type()
if self.check(TokenType.LEFT_BRACKET):
if type.name.lexeme == "Frame":
return self.frame_type()
args: list[Type] = self.type_args()
return GenericType(
location=Location.span(type.location, self.previous().get_location()),
type=type,
args=args,
)
return type
def type_args(self) -> list[Type]:
"""Parse a list of type arguments
Type arguments are a comma-separated list of type expressions wrapped in brackets.
Returns:
list[Type]: the list of type arguments, if any, or an empty list
"""
args: list[Type] = []
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
args.append(self.type_expr())
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
return args
def named_type(self) -> NamedType:
"""Parse a named type expression
A named type is an identifier token
Returns:
NamedType: the parsed named type expression
"""
name: Token = self.consume_identifier("Expected type name")
return NamedType(
location=name.get_location(),
name=name,
)
def complex_type(self) -> ComplexType:
"""Parse a complex type expression
A complex type consists of zero or more member statements enclosed in
curly braces
Returns:
ComplexType: the parsed complex type expression
"""
left: Token = self.consume(
TokenType.LEFT_BRACE, "Expected '{' to start type body"
)
members: list[MemberStmt] = []
# TODO: add keyword to differentiate properties and methods,
# and allow multiple methods with the same name but not properties
names: set[str] = set()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
member: MemberStmt = self.member_stmt()
# if member.name.lexeme in names:
# raise self.error(member.name, "Duplicate property")
# names.add(member.name.lexeme)
members.append(member)
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return ComplexType(
location=left.location_to(right),
members=members,
)
def frame_type(self) -> FrameType:
"""Parse a frame type expression
A frame type consists of:
- the `Frame` identifier
- an opening bracket `[`
- a comma-separated list of column expressions, each consisting of:
- a name (token)
- a colon `:`
- a type expression (see :func:`type_expr`)
- a closing bracket `]`
Returns:
FrameType: the parsed frame type
"""
keyword: Token = self.previous()
self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")
columns: list[FrameType.Column] = []
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
name: Token = self.advance()
self.consume(TokenType.COLON, "Expected ':' between column name and type")
type: Type = self.type_expr()
columns.append(
FrameType.Column(
location=name.location_to(self.previous()),
name=name,
type=type,
)
)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Unclosed frame schema")
return FrameType(
location=keyword.location_to(self.previous()),
columns=columns,
)
def constraint(self) -> Expr:
"""Parse a constraint expression
A constraint is an expression (see :func:`expression`)
Returns:
Expr: the parsed constraint expression
"""
return self.expression()
def expression(self) -> Expr:
"""Parse an expression
An expression consists of a logical AND expression (see :func:`and_`)
Returns:
Expr: the parsed expression
"""
return self.and_()
def and_(self) -> Expr:
"""Parse a logical AND expression
An AND consists of one or more equality expressions (see :func:`equality`)
separated by logical AND operators (`&`)
Returns:
Expr: the parsed expression
"""
expr: Expr = self.equality()
while self.match(TokenType.AND):
operator: Token = self.previous()
right: Expr = self.equality()
location: Location = Location.span(expr.location, right.location)
expr = LogicalExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def equality(self) -> Expr:
"""Parse an equality expression
An equality consists of one or more comparison expressions (see :func:`comparison`)
separated by equality operators (`==`, `!=`)
Returns:
Expr: the parsed expression
"""
expr: Expr = self.comparison()
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
operator: Token = self.previous()
right: Expr = self.comparison()
location: Location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def comparison(self) -> Expr:
"""Parse a comparison expression
A comparison consists of one or more term expressions (see :func:`term`)
separated by comparison operators (`<`, `<=`, `>`, `>=`)
Returns:
Expr: the parsed expression
"""
expr: Expr = self.term()
while self.match(
TokenType.LESS,
TokenType.LESS_EQUAL,
TokenType.GREATER,
TokenType.GREATER_EQUAL,
):
operator: Token = self.previous()
right: Expr = self.term()
location: Location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def term(self) -> Expr:
"""Parse a term expression
A term consists of one or more factor expressions (see :func:`factor`)
separated by weak arithmetic operators (`+`, `-`)
Returns:
Expr: the parsed expression
"""
expr: Expr = self.factor()
while self.match(TokenType.PLUS, TokenType.MINUS):
operator: Token = self.previous()
right: Expr = self.factor()
location: Location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def factor(self) -> Expr:
"""Parse a factor expression
A factor consists of one or more unary expressions (see :func:`unary`)
separated by strong arithmetic operators (`*`, `/`)
Returns:
Expr: the parsed expression
"""
expr: Expr = self.unary()
while self.match(TokenType.STAR, TokenType.SLASH):
operator: Token = self.previous()
right: Expr = self.unary()
location: Location = Location.span(expr.location, right.location)
expr = BinaryExpr(
location=location, left=expr, operator=operator, right=right
)
return expr
def unary(self) -> Expr:
"""Parse a unary expression
A unary consists of a call expression (see :func:`call`) optionally
preceded by zero or more unary operators (`+`, `-`)
Returns:
Expr: the parsed expression
"""
if self.match(TokenType.PLUS, TokenType.MINUS):
operator: Token = self.previous()
right: Expr = self.unary()
location: Location = Location.span(operator.get_location(), right.location)
return UnaryExpr(location=location, operator=operator, right=right)
return self.call()
def call(self) -> Expr:
"""Parse a call expression
A call consists of a reference expression (see :func:`reference`)
followed by zero or more argument groups.
An argument group is a parenthesized, comma-separated list of arguments (see :func:`finish_call`)
Returns:
Expr: the parsed expression
"""
expr: Expr = self.reference()
while self.match(TokenType.LEFT_PAREN):
expr = self.finish_call(expr)
return expr
def finish_call(self, callee: Expr) -> Expr:
"""Parse an argument group, i.e. the arguments of a call
Arguments are either passed positionally or by name (keyword argument).
All positional arguments must come before any keyword argument.
Arguments are separated by commas.
A positional argument simply consists of an expression (see :func:`expression`)
A keyword argument consists of an identifier, followed by the equal `=`
token and an expression (see :func:`expression`).
Args:
callee (Expr): the callee expression
Raises:
ParsingError: if a positional argument is passed after a keyword
argument or if a keyword argument's name is invalid (i.e. not
an identifier)
Returns:
Expr: the parsed call expression
"""
pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {}
keywords: bool = False
while not self.check(TokenType.RIGHT_PAREN):
if self.check_identifier() and self.check_next(TokenType.EQUAL):
keywords = True
keyword: Token = self.advance()
self.advance()
value: Expr = self.expression()
name: str = keyword.lexeme
if name in kw_args:
self.error(
self.peek(),
f"Multiple values passed for '{name}', only the last occurrence will be used",
)
kw_args[name] = value
else:
value = self.expression()
if self.check(TokenType.EQUAL):
error_msg: str
if keywords:
error_msg = "Invalid keyword argument name"
else:
error_msg = (
"Cannot pass positional arguments after a keyword argument"
)
raise self.error(self.peek(), error_msg)
pos_args.append(value)
if not self.match(TokenType.COMMA):
break
r_paren: Token = self.consume(
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
)
return CallExpr(
location=Location.span(callee.location, r_paren.get_location()),
callee=callee,
arguments=pos_args,
keywords=kw_args,
)
def reference(self) -> Expr:
"""Parse a reference expression
A reference consists of a primary expression (see :func:`primary`)
optionally followed by zero or more attribute accesses.
An attribute access consists of a dot `.` token followed by an identifier
Returns:
Expr: the parsed expression
"""
expr: Expr = self.primary()
while self.match(TokenType.DOT):
name: Token = self.consume_identifier("Expected property name after '.'")
location: Location = Location.span(expr.location, name.get_location())
expr = GetExpr(location=location, expr=expr, name=name)
return expr
def primary(self) -> Expr:
"""Parse a primary expression
This includes literals (booleans, numbers, etc.), wildcards, identifiers
and grouped expressions
Raises:
ParsingError: if a primary expression cannot be parsed from the
following tokens
Returns:
Expr: the parsed expression
"""
token: Token = self.peek()
if self.match(TokenType.FALSE):
return LiteralExpr(location=token.get_location(), value=False)
if self.match(TokenType.TRUE):
return LiteralExpr(location=token.get_location(), value=True)
if self.match(TokenType.NONE):
return LiteralExpr(location=token.get_location(), value=None)
if self.match(TokenType.NUMBER):
return LiteralExpr(location=token.get_location(), value=token.value)
if self.match(TokenType.STRING):
return LiteralExpr(location=token.get_location(), value=token.value)
if self.match_identifier():
return VariableExpr(location=token.get_location(), name=token)
if self.match(TokenType.UNDERSCORE):
return WildcardExpr(location=token.get_location(), token=token)
if self.match(TokenType.LEFT_PAREN):
expr: Expr = self.constraint()
right: Token = self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return GroupingExpr(location=token.location_to(right), expr=expr)
raise self.error(self.peek(), "Expected expression")
def consume_identifier(self, message: str = "Expected identifier") -> Token:
"""Consume the current token if it is a valid identifier or raise an error (see :func:`check_identifier`)
If the current token is not a valid identifier, an error is raised
with the provided message
Args:
message (str, optional): the error message. Defaults to "Expected identifier".
Raises:
ParsingError: if the current token is not a valid identifier
Returns:
Token: the current token which is a valid identifier
"""
if not self.match_identifier():
raise self.error(self.peek(), message)
return self.previous()
def match_identifier(self) -> bool:
"""Consume the next token if it is a valid identifier (see :func:`check_identifier`)
Returns:
bool: whether a token was matched and consumed
"""
return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
def check_identifier(self) -> bool:
"""Check whether the current token is a valid identifier
A valid identifier is either an identifier token or a keyword token.
This function always returns False if the parser is at the EOF token
Returns:
bool: True if the current token is a valid identifier and not EOF
"""
for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
if self.check(tt):
return True
return False
def member_stmt(self) -> MemberStmt:
"""Parse a member statement
A member statement consists of:
- the `prop` (for a property) or `def` (for a method) keyword
- a name (identifier)
- a colon `:`
- a type expression (see :func:`type_expr`)
Raises:
ParsingError: if the first token is neither `prop` nor `def`
Returns:
MemberStmt: the parsed member statement
"""
kind: MemberKind
if self.match(TokenType.PROP):
kind = MemberKind.PROPERTY
elif self.match(TokenType.DEF):
kind = MemberKind.METHOD
else:
raise self.error(self.peek(), "Expected 'prop' or 'def'")
name: Token = self.consume_identifier("Expected member name")
self.consume(TokenType.COLON, "Expected ':' after member name")
type: Type = self.type_expr()
return MemberStmt(
location=name.location_to(self.previous()),
name=name,
type=type,
kind=kind,
)
def extend_declaration(self) -> ExtendStmt:
"""Parse an extension definition
An extension statement consists of:
- the `extend` keyword
- a type name (identifier)
- (optional) type parameters (see :func:`type_params`)
- an opening brace `{`
- zero or more member statements (see :func:`member_stmt`)
- a closing brace `}`
Returns:
ExtendStmt: the parsed extension statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected type name")
params: list[TypeParam] = self.type_params()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
members: list[MemberStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
members.append(self.member_stmt())
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
location: Location = keyword.location_to(self.previous())
return ExtendStmt(
location=location,
name=name,
params=params,
members=members,
)
def predicate_declaration(self) -> PredicateStmt:
"""Parse a predicate declaration
A predicate statement consists of:
- the `predicate` keyword
- a name (identifier)
- (optional) zero or more parameter specs (see :func:`function_params`)
- an equal sign `=`
- a body, a constraint expression (see :func:`constraint`)
Returns:
PredicateStmt: the parsed predicate declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected predicate name")
params: list[ParamSpec] = []
while self.check(TokenType.LEFT_PAREN):
params.append(self.function_params())
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
body: Expr = self.constraint()
return PredicateStmt(
location=keyword.location_to(self.previous()),
name=name,
params=params,
body=body,
)
def function(self) -> FunctionType:
"""Parse a function type expression
A function consists of:
- the `fn` keyword
- a parameter spec (see :func:`function_params`)
- the arrow token `->`
- a result type expression (see :func:`type_expr`)
Returns:
FunctionType: the parsed function type expression
"""
params: ParamSpec = self.function_params()
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=params.l_paren.location_to(self.previous()),
params=params,
returns=result,
)
def function_params(self) -> ParamSpec:
"""Parse a parameter spec
A parameter spec consists of zero or more comma-separated parameters,
wrapped in parentheses.
Like in Python, it can contain positional-only, mixed and keyword-only
parameters (separated by `/` and `*`).
Each parameter has a type (see :func:`type_expr`),
preceded by a name (identifier) and a colon `:` (not required for
positional-only parameters).
Returns:
ParamSpec: the parsed parameter spec
"""
l_paren: Token = self.consume(
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
)
pos: list[FunctionType.Parameter] = []
mixed: list[FunctionType.Parameter] = []
kw: list[FunctionType.Parameter] = []
mixed_first_tokens: list[Token] = []
section: int = 0
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
match section:
case 0 if self.match(TokenType.SLASH):
pos = mixed
mixed = []
mixed_first_tokens = []
section = 1
case 0 | 1 if self.match(TokenType.STAR):
section = 2
case _:
# Record first token of mixed parameters for errors if unnamed
if section != 2:
mixed_first_tokens.append(self.peek())
name: Optional[Token] = None
if section == 2:
name = self.consume_identifier(
"Expected keyword parameter name"
)
self.consume(
TokenType.COLON, "Expected ':' after parameter name"
)
elif self.check_identifier() and self.check_next(TokenType.COLON):
name = self.advance()
self.advance()
type: Type = self.type_expr()
optional: bool = self.match(TokenType.QMARK)
param = FunctionType.Parameter(
location=None,
name=name,
type=type,
required=not optional,
)
if section == 2:
kw.append(param)
else:
mixed.append(param)
if not self.match(TokenType.COMMA):
break
for param, token in zip(mixed, mixed_first_tokens):
if param.name is None:
# Not raised because we can keep parsing
self.error(token, "Unnamed mixed parameter")
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
return ParamSpec(l_paren=l_paren, pos=pos, mixed=mixed, kw=kw)
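
The section tracking in `function_params` can be sketched on plain strings as shown below. This is a simplified illustration, not the parser itself: `split_param_sections` is a hypothetical function, parameters are reduced to bare names, and a misplaced `/` or `*` is kept as an ordinary item instead of producing a parse error.

```python
def split_param_sections(
    items: list[str],
) -> tuple[list[str], list[str], list[str]]:
    """Split parameters into positional-only, mixed and keyword-only lists."""
    pos: list[str] = []
    mixed: list[str] = []
    kw: list[str] = []
    section = 0  # 0: before '/', 1: after '/', 2: after '*'
    for item in items:
        if item == "/" and section == 0:
            # Everything collected so far was positional-only
            pos, mixed = mixed, []
            section = 1
        elif item == "*" and section in (0, 1):
            section = 2
        elif section == 2:
            kw.append(item)
        else:
            mixed.append(item)
    return pos, mixed, kw
```

As in the parser, parameters are first collected as mixed and only reclassified as positional-only once a `/` is seen, which is why unnamed mixed parameters can only be reported after the whole list has been read.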

566
midas/parser/python.py Normal file

@@ -0,0 +1,566 @@
import ast
from typing import Optional
from midas.ast.location import Location
from midas.ast.python import (
AssignStmt,
BaseType,
BinaryExpr,
CallExpr,
CastExpr,
CompareExpr,
ConstraintType,
DictExpr,
Expr,
ExpressionStmt,
ForStmt,
FrameColumn,
FrameType,
Function,
GetExpr,
IfStmt,
ListExpr,
LiteralExpr,
LogicalExpr,
MidasType,
ParamSpec,
RawExpr,
RawStmt,
ReturnStmt,
SliceExpr,
Stmt,
SubscriptExpr,
TernaryExpr,
TupleExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
)
class InvalidSyntaxError(Exception):
pass
class UnsupportedSyntaxError(Exception):
def __init__(self, expr: ast.expr) -> None:
super().__init__(
f"Unsupported syntax at L{expr.lineno}:{expr.col_offset}: {ast.unparse(expr)}"
)
class PythonParser:
"""A parser to convert raw Python `ast` nodes into custom IR nodes"""
CAST_FUNCTION = "cast"
UNSAFE_CAST_FUNCTION = "unsafe_cast"
def parse_module(self, node: ast.Module) -> list[Stmt]:
statements: list[Stmt] = []
for stmt in node.body:
try:
parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt)
if isinstance(parsed, Stmt):
statements.append(parsed)
elif parsed is not None:
statements.extend(parsed)
except UnsupportedSyntaxError as e:
print(f"{e}, skipping")
continue
return statements
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
location: Location = Location.from_ast(node)
match node:
case ast.AnnAssign():
return self.parse_annotation_assign(node)
case ast.Assign():
return self.parse_assign(node)
case ast.AugAssign():
return self.parse_aug_assign(node)
case ast.FunctionDef():
return self.parse_function(node)
case ast.Expr(value=expr):
return ExpressionStmt(
location=location,
expr=self.parse_expr(expr),
)
case ast.Return(value=value):
return ReturnStmt(
location=location,
value=self.parse_expr(value) if value is not None else None,
)
case ast.If():
return self.parse_if(node)
case ast.Pass():
return None
case ast.For(orelse=[]):
return self.parse_for(node)
case _:
print(f"Unsupported statement: {ast.unparse(node)}")
return RawStmt(location=location, stmt=node)
def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]:
statements: list[Stmt] = []
loc: Location = Location.from_ast(node)
match node:
case ast.AnnAssign(
target=ast.Name(id=target),
annotation=annotation,
value=value,
simple=1,
):
type = self._parse_type(annotation)
statements.append(
TypeAssign(
location=loc,
name=target,
type=type,
)
)
if value is not None:
statements.append(
AssignStmt(
location=loc,
targets=[
VariableExpr(
location=Location.from_ast(node.target), name=target
),
],
value=self.parse_expr(value),
),
)
case _:
print(f"Unsupported annotation: {ast.unparse(node)}")
return statements
def parse_assign(self, node: ast.Assign) -> AssignStmt:
targets: list[Expr] = []
for target in node.targets:
targets.append(self.parse_expr(target))
value: Expr = self.parse_expr(node.value)
return AssignStmt(
location=Location.from_ast(node),
targets=targets,
value=value,
)
def parse_aug_assign(self, node: ast.AugAssign) -> AssignStmt:
location: Location = Location.from_ast(node)
target: Expr = self.parse_expr(node.target)
value: Expr = self.parse_expr(node.value)
return AssignStmt(
location=location,
targets=[target],
value=BinaryExpr(
location=location,
left=target,
operator=node.op,
right=value,
),
)
def parse_if(self, node: ast.If) -> IfStmt:
body: list[Stmt] = []
for stmt in node.body:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
body.append(stmts)
elif stmts is not None:
body.extend(stmts)
orelse: list[Stmt] = []
for stmt in node.orelse:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
orelse.append(stmts)
elif stmts is not None:
orelse.extend(stmts)
return IfStmt(
location=Location.from_ast(node),
test=self.parse_expr(node.test),
body=body,
orelse=orelse,
)
def parse_for(self, node: ast.For) -> ForStmt:
body: list[Stmt] = []
for stmt in node.body:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
body.append(stmts)
elif stmts is not None:
body.extend(stmts)
return ForStmt(
location=Location.from_ast(node),
target=self.parse_expr(node.target),
iterator=self.parse_expr(node.iter),
body=body,
)
def parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node)
match node:
case ast.FunctionDef(
name=name,
args=args,
returns=returns,
body=raw_body,
):
body: list[Stmt] = []
for stmt in raw_body:
stmts = self.parse_stmt(stmt)
if isinstance(stmts, Stmt):
body.append(stmts)
elif stmts is not None:
body.extend(stmts)
return Function(
location=loc,
name=name,
params=self._parse_param_spec(args),
returns=self._parse_type(returns) if returns is not None else None,
body=body,
)
case _:
print(f"Unsupported function definition: {ast.unparse(node)}")
def _parse_param_spec(self, args: ast.arguments) -> ParamSpec:
def parse_params(
args_list: list[ast.arg], defaults: list[Optional[Expr]]
) -> list[Function.Parameter]:
return [
self._parse_function_parameter(arg, default)
for arg, default in zip(args_list, defaults)
]
defaults: list[ast.expr] = args.defaults
parsed_defaults: list[Optional[Expr]] = [
self.parse_expr(default) for default in defaults
]
n_pos: int = len(args.posonlyargs)
n_mixed: int = len(args.args)
n_all_pos = n_pos + n_mixed
parsed_defaults = [
None,
] * (n_all_pos - len(defaults)) + parsed_defaults
pos_defaults: list[Optional[Expr]] = parsed_defaults[:n_pos]
mixed_defaults: list[Optional[Expr]] = parsed_defaults[n_pos:]
kw_defaults: list[Optional[Expr]] = [
self.parse_expr(default) if default is not None else None
for default in args.kw_defaults
]
return ParamSpec(
pos=parse_params(args.posonlyargs, pos_defaults),
mixed=parse_params(args.args, mixed_defaults),
kw=parse_params(args.kwonlyargs, kw_defaults),
)
def _parse_function_parameter(
self, arg: ast.arg, default: Optional[Expr]
) -> Function.Parameter:
loc: Location = Location.from_ast(arg)
name: str = arg.arg
type: Optional[MidasType] = None
if arg.annotation is not None:
type = self._parse_type(arg.annotation)
return Function.Parameter(
location=loc,
name=name,
type=type,
default=default,
)
def _parse_type(self, type_expr: ast.expr) -> MidasType:
loc: Location = Location.from_ast(type_expr)
match type_expr:
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
return self._parse_frame_type(schema)
case ast.Subscript(value=ast.Name(id=name), slice=arg):
args: tuple[MidasType, ...] = (
tuple(self._parse_type(a) for a in arg.elts)
if isinstance(arg, ast.Tuple)
else (self._parse_type(arg),)
)
return BaseType(
location=loc,
base=name,
args=args,
)
case ast.Name(id=name):
return BaseType(
location=loc,
base=name,
args=(),
)
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
left = self._parse_type(left_expr)
match left:
# If chained constraints, separate base type and rebuild constraint
case ConstraintType(type=left_type, constraint=left_constraint):
constraint = ast.BinOp(
left=left_constraint,
op=ast.Add(),
right=right_expr,
)
ast.copy_location(constraint, type_expr)
return ConstraintType(
location=loc,
type=left_type,
constraint=constraint,
)
case _:
return ConstraintType(
location=loc,
type=left,
constraint=right_expr,
)
case ast.Constant(value=None):
return BaseType(
location=loc,
base="None",
args=(),
)
case _:
raise UnsupportedSyntaxError(type_expr)
def _parse_frame_type(self, schema: ast.expr) -> FrameType:
loc: Location = Location.from_ast(schema)
columns: list[FrameColumn] = []
match schema:
case ast.Tuple(elts=cols):
for col in cols:
columns.append(self._parse_frame_column(col))
case ast.Slice() | ast.Name():
columns.append(self._parse_frame_column(schema))
case _:
raise UnsupportedSyntaxError(schema)
return FrameType(location=loc, columns=columns)
def _parse_frame_column(self, column: ast.expr) -> FrameColumn:
loc: Location = Location.from_ast(column)
match column:
case ast.Name():
return FrameColumn(
location=loc,
name=None,
type=self._parse_type(column),
)
case ast.Slice(lower=ast.Name(id=name), upper=type_expr):
if name == "_":
name = None
type: Optional[MidasType] = None
match type_expr:
case None:
raise InvalidSyntaxError("Missing column type")
case ast.Name(id="_"):
type = None
case ast.expr():
type = self._parse_type(type_expr)
case _:
raise UnsupportedSyntaxError(type_expr)
return FrameColumn(location=loc, name=name, type=type)
case _:
raise UnsupportedSyntaxError(column)
def parse_expr(self, node: ast.expr) -> Expr:
location: Location = Location.from_ast(node)
match node:
case ast.BoolOp():
return self.parse_bool_op(node)
case ast.BinOp(left=left, op=op, right=right):
return BinaryExpr(
location=location,
left=self.parse_expr(left),
operator=op,
right=self.parse_expr(right),
)
case ast.UnaryOp(op=op, operand=right):
return UnaryExpr(
location=location,
operator=op,
right=self.parse_expr(right),
)
case ast.Compare():
return self.parse_compare(node)
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
return self.parse_cast(node)
case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)):
return self.parse_cast(node)
case ast.Call():
return self.parse_call(node)
case ast.IfExp():
return self.parse_ternary(node)
case ast.Constant(value=value):
return LiteralExpr(location=location, value=value)
case ast.Attribute(value=object, attr=name):
return GetExpr(
location=location,
object=self.parse_expr(object),
name=name,
)
case ast.Name(id=name):
return VariableExpr(location=location, name=name)
case ast.List(elts=items):
return ListExpr(
location=location,
items=[self.parse_expr(item) for item in items],
)
case ast.Dict(keys=keys, values=values):
return DictExpr(
location=location,
keys=[
self.parse_expr(key) if key is not None else None
for key in keys
],
values=[self.parse_expr(value) for value in values],
)
case ast.Subscript(value=value, slice=index):
return SubscriptExpr(
location=location,
object=self.parse_expr(value),
index=self.parse_expr(index),
)
case ast.Slice(lower=lower, upper=upper, step=step):
return SliceExpr(
location=location,
lower=self.parse_expr(lower) if lower is not None else None,
upper=self.parse_expr(upper) if upper is not None else None,
step=self.parse_expr(step) if step is not None else None,
)
case ast.Tuple(elts=items):
return TupleExpr(
location=location,
items=tuple(self.parse_expr(item) for item in items),
)
case _:
print(f"Unsupported expression: {ast.unparse(node)}")
return RawExpr(location=location, expr=node)
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
op: ast.boolop = node.op
rights: list[Expr] = [self.parse_expr(expr) for expr in node.values]
expr: LogicalExpr = LogicalExpr(
location=Location.span(
rights[0].location,
rights[1].location,
),
left=rights[0],
operator=op,
right=rights[1],
)
for right in rights[2:]:
expr = LogicalExpr(
location=Location.span(expr.location, right.location),
left=expr,
operator=op,
right=right,
)
return expr
def parse_compare(self, node: ast.Compare) -> Expr:
ops: list[ast.cmpop] = node.ops
left: Expr = self.parse_expr(node.left)
rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators]
expr: Expr = CompareExpr(
location=Location.span(
left.location,
rights[0].location,
),
left=left,
operator=ops[0],
right=rights[0],
)
for i, right in enumerate(rights[1:]):
comparison = CompareExpr(
location=Location.span(rights[i].location, right.location),
left=rights[i],
operator=ops[i + 1],
right=right,
)
expr = LogicalExpr(
location=Location.span(expr.location, comparison.location),
left=expr,
operator=ast.And(),
right=comparison,
)
return expr
def parse_cast(self, node: ast.Call) -> CastExpr:
assert isinstance(node.func, ast.Name)
func: str = node.func.id
match node:
case ast.Call(args=[type, expr], keywords=[]):
return CastExpr(
location=Location.from_ast(node),
type=self._parse_type(type),
expr=self.parse_expr(expr),
unsafe=func == self.UNSAFE_CAST_FUNCTION,
)
case _:
raise InvalidSyntaxError(
f"Invalid call to {func}, expected type and expression"
)
def parse_call(self, node: ast.Call) -> CallExpr:
return CallExpr(
location=Location.from_ast(node),
callee=self.parse_expr(node.func),
arguments=[self.parse_expr(arg) for arg in node.args],
keywords={
arg.arg: self.parse_expr(arg.value)
for arg in node.keywords
if arg.arg is not None # arg.arg is None for `**kwargs` unpacking, which is skipped here
},
)
def parse_ternary(self, node: ast.IfExp) -> TernaryExpr:
return TernaryExpr(
location=Location.from_ast(node),
test=self.parse_expr(node.test),
if_true=self.parse_expr(node.body),
if_false=self.parse_expr(node.orelse),
)
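
`parse_compare` lowers a chained comparison such as `a < b <= c` into pairwise comparisons joined by `and`, where the k-th pair must use the k-th operator. The sketch below reproduces that lowering on plain `ast` nodes so it can be run without the Midas IR; `desugar_compare` and `evaluate` are hypothetical names introduced for illustration.

```python
import ast


def desugar_compare(node: ast.Compare) -> ast.expr:
    """Rewrite `a < b <= c` as `(a < b) and (b <= c)` (hypothetical helper)."""
    operands = [node.left, *node.comparators]
    # Operator i sits between operand i and operand i + 1.
    pairs = [
        ast.Compare(left=operands[i], ops=[op], comparators=[operands[i + 1]])
        for i, op in enumerate(node.ops)
    ]
    if len(pairs) == 1:
        return pairs[0]
    # Note: middle operands appear in two pairs, so unlike a real chained
    # comparison they are evaluated twice in this simplified form.
    return ast.BoolOp(op=ast.And(), values=pairs)


def evaluate(source: str) -> bool:
    """Desugar a comparison expression, then compile and evaluate it."""
    tree = ast.parse(source, mode="eval")
    assert isinstance(tree.body, ast.Compare)
    tree.body = desugar_compare(tree.body)
    ast.fix_missing_locations(tree)
    return eval(compile(tree, "<desugared>", "eval"))
```

An expression with two different operators, such as `3 > 2 < 5`, is a useful check: pairing the second comparison with the first operator would evaluate `2 > 5` and give the wrong answer.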

52
midas/typing.py Normal file

@@ -0,0 +1,52 @@
from typing import Generic, TypeVar
from typing import cast as typing_cast
cast = typing_cast
"""### Midas documentation
Cast a value to a type.
- **Compile-time**: tells the type checker that the return value has the designated type.
- **Run-time**: generates assertions to ensure the value can be interpreted as the given type.
---
<br>
<br>
<br>
_**Internal Python documentation**_
"""
unsafe_cast = typing_cast
"""### Midas documentation
Cast a value to a type.
- **Compile-time**: tells the type checker that the return value has the designated type.
- **Run-time**: no assertion is generated.
This operation is unsound; use at your own risk!
---
<br>
<br>
<br>
_**Internal Python documentation**_
"""
T = TypeVar("T")
class Frame(Generic[T]):
"""A `Frame` is the abstract type implemented by `DataFrame`
A frame contains any number of named columns (see :class:`Column`)
"""
class Column(Generic[T]):
"""A `Column` is the abstract type implemented by `Series`
A column contains any number of values of the same type
"""
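
A `Frame[...]` annotation with `name: Type` columns is valid Python syntax because each column parses as a slice, which is what `_parse_frame_type` and `_parse_frame_column` in `midas/parser/python.py` match on. The sketch below uses only the standard `ast` module to show the shapes involved; `describe_frame_annotation` is a hypothetical helper, and it returns plain strings rather than Midas `FrameColumn` nodes.

```python
import ast
from typing import Optional


def describe_frame_annotation(source: str) -> list[tuple[Optional[str], Optional[str]]]:
    """Return (name, type) pairs for each column of a `Frame[...]` annotation."""
    expr = ast.parse(source, mode="eval").body
    assert isinstance(expr, ast.Subscript)
    schema = expr.slice
    # Several columns arrive as a Tuple, a single column as a bare node.
    columns = schema.elts if isinstance(schema, ast.Tuple) else [schema]

    described: list[tuple[Optional[str], Optional[str]]] = []
    for col in columns:
        if isinstance(col, ast.Slice):
            # `name: Type` parses as Slice(lower=name, upper=Type).
            name = ast.unparse(col.lower) if col.lower is not None else None
            type_ = ast.unparse(col.upper) if col.upper is not None else None
        else:
            # A bare type such as `int` is an unnamed column.
            name, type_ = None, ast.unparse(col)
        # `_` is treated as a wildcard for either the name or the type.
        described.append(
            (None if name == "_" else name, None if type_ == "_" else type_)
        )
    return described


columns = describe_frame_annotation("Frame[name: str, age: int, _: float, tag: _]")
```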

67
midas/utils.py Normal file

@@ -0,0 +1,67 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional
import midas.ast.python as p
from midas.checker.types import Type
from midas.generator.collector import AssertionCollector
AllowRepeat = Callable[[object], bool]
class UniversalJSONDumper:
@classmethod
def dump(
cls,
obj: Any,
include_keys: Optional[list[str | tuple[str, str]]] = None,
allow_repeat: Optional[AllowRepeat] = None,
) -> Any:
if include_keys is None:
include_keys = []
return cls._dump(obj, include_keys, allow_repeat, [])
@classmethod
def _dump(
cls,
obj: Any,
include_keys: list[str | tuple[str, str]],
allow_repeat: Optional[AllowRepeat],
visited: list[Any],
) -> Any:
if obj in visited:
return None
match obj:
case str() | int() | float() | None:
return obj
case list() | set() | tuple():
return [
cls._dump(child, include_keys, allow_repeat, visited)
for child in obj
]
case dict():
return {
str(k): cls._dump(v, include_keys, allow_repeat, visited)
for k, v in obj.items()
}
case object():
if allow_repeat is None or not allow_repeat(obj):
visited.append(obj)
return {
"_type": obj.__class__.__name__,
} | {
k: cls._dump(v, include_keys, allow_repeat, visited)
for k, v in obj.__dict__.items()
if not k.startswith("_")
or k in include_keys
or (obj.__class__.__name__, k) in include_keys
}
case _:
raise ValueError(f"Unsupported value: {obj}")
@dataclass(frozen=True, kw_only=True)
class TypedAST:
stmts: list[p.Stmt]
judgements: list[tuple[p.Expr, Type]]
evaluated_casts: list[p.CastExpr]
assertions: AssertionCollector
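
`UniversalJSONDumper._dump` keeps a `visited` list so that an object graph with back-references does not recurse forever: an object that has already been dumped is emitted as `None`. A reduced standalone sketch of that guard follows. It is an illustration under stated assumptions, not the Midas implementation: `dump_graph` and `Node` are hypothetical names, the `include_keys` and `allow_repeat` options are left out, and the check uses identity (`is`) where the original uses `in`.

```python
from typing import Any, Optional


def dump_graph(obj: Any, visited: Optional[list[Any]] = None) -> Any:
    """Convert an object graph to JSON-compatible data, cutting cycles (hypothetical helper)."""
    if visited is None:
        visited = []
    # Assumption: identity check instead of the equality-based `in` test.
    if any(obj is seen for seen in visited):
        return None
    if obj is None or isinstance(obj, (str, int, float)):
        return obj
    if isinstance(obj, (list, set, tuple)):
        return [dump_graph(child, visited) for child in obj]
    if isinstance(obj, dict):
        return {str(k): dump_graph(v, visited) for k, v in obj.items()}
    # Any other object: remember it, then dump its public attributes.
    visited.append(obj)
    return {"_type": type(obj).__name__} | {
        k: dump_graph(v, visited)
        for k, v in vars(obj).items()
        if not k.startswith("_")
    }


class Node:
    def __init__(self, name: str) -> None:
        self.name = name
        self.next: Optional["Node"] = None


a, b = Node("a"), Node("b")
a.next, b.next = b, a  # two-node cycle
result = dump_graph(a)
```

Because `visited` is shared across the whole traversal, an object reachable through two different paths is also dumped only once, which is the situation the `allow_repeat` callback in the original is there to relax.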


@@ -1,152 +0,0 @@
from typing import Optional
from core.ast.annotations import (
AnnotationStmt,
ConstraintExpr,
Expr,
LiteralExpr,
SchemaElementExpr,
SchemaExpr,
Stmt,
TypeExpr,
WildcardExpr,
)
from lexer.token import Token, TokenType
from parser.base import Parser
from parser.errors import ParsingError
class AnnotationParser(Parser):
"""A simple parser for custom type annotations"""
SYNC_BOUNDARY: set[TokenType] = set()
def parse(self) -> Optional[Stmt]:
stmt: Optional[Stmt] = None
try:
stmt = self.annotation()
except ParsingError:
self.synchronize()
if not self.is_at_end():
self.error(self.peek(), "Extra tokens")
return stmt
def synchronize(self):
"""Skip tokens until a synchronization boundary is found
This method allows gracefully recovering from a parse error
to a safe place and continue parsing
"""
self.advance()
while not self.is_at_end():
if self.peek().type in self.SYNC_BOUNDARY:
return
self.advance()
def annotation(self) -> AnnotationStmt:
"""Parse an annotation
An annotation is written as `Type` or `Type[Schema]`
Returns:
AnnotationStmt: the parsed annotation statement
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type identifier")
schema: Optional[SchemaExpr] = None
if self.match(TokenType.LEFT_BRACKET):
schema = self.schema()
return AnnotationStmt(name=name, schema=schema)
def type_expr(self) -> TypeExpr:
"""Parse a type expression
Returns:
TypeExpr: the parsed type expression
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
constraints: list[ConstraintExpr] = []
while not self.is_at_end() and self.match(TokenType.PLUS):
self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint")
constraints.append(self.constraint_expr())
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint")
return TypeExpr(name=name, constraints=constraints)
def constraint_expr(self) -> ConstraintExpr:
"""Parse a type constraint
Returns:
ConstraintExpr: the parsed type constraint expression
"""
left: Expr = self.constraint_value()
op: Token = self.constraint_operator()
right: Expr = self.constraint_value()
return ConstraintExpr(left=left, op=op, right=right)
def constraint_value(self) -> Expr:
if self.match(TokenType.UNDERSCORE):
return WildcardExpr(self.previous())
return self.literal()
def literal(self) -> LiteralExpr:
if self.match(TokenType.FALSE):
return LiteralExpr(False)
if self.match(TokenType.TRUE):
return LiteralExpr(True)
if self.match(TokenType.NONE):
return LiteralExpr(None)
if self.match(TokenType.NUMBER):
return LiteralExpr(self.previous().value)
raise self.error(self.peek(), "Expected literal")
def constraint_operator(self) -> Token:
if self.match(TokenType.LESS, TokenType.LESS_EQUAL, TokenType.GREATER, TokenType.GREATER_EQUAL, TokenType.EQUAL_EQUAL, TokenType.BANG_EQUAL):
return self.previous()
raise self.error(self.peek(), "Expected constraint operator")
def schema(self) -> SchemaExpr:
"""Parse a schema definition
A comma separated list of schema elements
Returns:
SchemaExpr: the parsed schema expression
"""
left: Token = self.previous()
elements: list[Expr] = []
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
elements.append(self.schema_element())
if not self.check(TokenType.RIGHT_BRACKET):
self.consume(TokenType.COMMA, "Expected ',' between schema elements")
right: Token = self.consume(TokenType.RIGHT_BRACKET, "Unclosed schema")
return SchemaExpr(left=left, elements=elements, right=right)
def schema_element(self) -> SchemaElementExpr:
"""Parse a schema element
An anonymous element (`_`), a type, an untyped named column (`name: _`),
or a named column (`name: Type`)
Returns:
SchemaElementExpr: the parsed schema element expression
"""
if self.match(TokenType.UNDERSCORE):
return SchemaElementExpr(name=None, type=None)
if not self.check(TokenType.IDENTIFIER):
raise self.error(self.peek(), "Expected schema element")
name: Optional[Token] = None
type: Optional[TypeExpr] = None
if self.check_next(TokenType.COLON):
name = self.advance()
self.advance()
if not self.match(TokenType.UNDERSCORE):
type = self.type_expr()
return SchemaElementExpr(name=name, type=type)

Some files were not shown because too many files have changed in this diff