Last active
          July 18, 2022 22:45 
        
      - 
      
- 
        Save dhilst/24bfd7904ccefb542abf7fa099e7e516 to your computer and use it in GitHub Desktop. 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | from typing import * | |
| import ast | |
| from dataclasses import dataclass | |
@dataclass(frozen=True)
class FuncSig:
    """Declared signature of a function: name, parameter types, return type.

    Types are plain strings such as ``"int"``. Note: when ``args`` is a list,
    instances are not hashable even though the dataclass is frozen
    (``hash()`` raises TypeError because lists are unhashable).
    """

    name: str  # function name
    args: list[str]  # parameter type names, in declaration order
    ret: str  # return type name

    def __repr__(self):
        # Curried-arrow rendering, e.g. "inc : int -> int -> int".
        parts = self.args + [self.ret]
        return "{} : {}".format(self.name, " -> ".join(parts))
class Typechecker(ast.NodeVisitor):
    """Minimal annotation-driven checker for calls to user-defined functions.

    Walks a module AST. Every function definition whose positional parameters
    and return value are all annotated is recorded as a ``FuncSig``. Each
    later call to such a function by name has the types of its positional
    arguments compared against the declared parameter types.

    ``typeenv`` maps a name in scope to either a type string (an annotated
    parameter) or a ``FuncSig`` (a known function). Type strings are the
    annotation source text (``"int"``, ``"list[int]"``) for declarations and
    ``type(value).__name__`` for literal arguments.
    """

    def __init__(self):
        super().__init__()
        self.typeenv = {}

    def _bind(self, name, value):
        """Bind *name* to *value* in ``typeenv``, or unbind it when None.

        Unbinding lets a shadowing definition hide any outer binding, so a
        stale type or signature is never used for the inner name.
        """
        if value is None:
            self.typeenv.pop(name, None)
        else:
            self.typeenv[name] = value

    @staticmethod
    def _annotation_type(annotation):
        """Return the type string for an annotation node, or None if absent."""
        if annotation is None:
            return None
        if isinstance(annotation, ast.Name):
            return annotation.id
        if isinstance(annotation, ast.Constant) and isinstance(annotation.value, str):
            # String (forward-reference) annotation such as "int".
            return annotation.value
        # Any other annotation (e.g. ``list[int]``) is kept as source text
        # instead of crashing on a missing ``.id`` attribute.
        return ast.unparse(annotation)

    def _expr_type(self, expr):
        """Return the type string of an argument expression, or None if unknown."""
        if isinstance(expr, ast.Constant):
            return type(expr.value).__name__
        if isinstance(expr, ast.Name):
            bound = self.typeenv.get(expr.id)
            # A name bound to a FuncSig is a function object, not a value of
            # a simple named type, so its type is treated as unknown.
            return bound if isinstance(bound, str) else None
        if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name):
            # A call to a known function has that function's return type.
            callee = self.typeenv.get(expr.func.id)
            if isinstance(callee, FuncSig):
                return callee.ret
        return None

    def visit_FunctionDef(self, node):
        """Check the body with parameters in scope, then record the signature.

        A signature is recorded only when every positional parameter and the
        return value are annotated. Otherwise the name is unbound, so calls to
        it are left unchecked rather than checked against a stale signature.
        """
        params = node.args.posonlyargs + node.args.args
        param_types = [self._annotation_type(p.annotation) for p in params]
        ret_type = self._annotation_type(node.returns)
        signature = None
        if ret_type is not None and None not in param_types:
            signature = FuncSig(node.name, param_types, ret_type)

        oldenv = self.typeenv.copy()
        # Inside the body the function's own name refers to this definition,
        # so recursive calls are checked. Parameters shadow it and any outer
        # bindings; unannotated parameters become unknown.
        self._bind(node.name, signature)
        for param in params + node.args.kwonlyargs:
            self._bind(param.arg, self._annotation_type(param.annotation))
        for packed in (node.args.vararg, node.args.kwarg):
            # *args / **kwargs are a tuple / dict inside the body, so their
            # annotation is not the type of the name itself.
            if packed is not None:
                self._bind(packed.arg, None)
        self.generic_visit(node)
        self.typeenv = oldenv
        # After the definition the name refers to the new function.
        self._bind(node.name, signature)

    # ``async def`` introduces the same parameter scope as ``def``.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Call(self, node):
        """Check the positional argument types of a call to a known function.

        The check is skipped (not failed) in three cases: the callee is not a
        known function, keyword arguments are used, or some argument's type
        cannot be inferred. Nested expressions are always visited, so calls
        inside the arguments are still checked.

        Raises:
            TypeError: if the inferred argument types differ from the
                callee's declared parameter types (including arity).
        """
        signature = None
        if isinstance(node.func, ast.Name):
            signature = self.typeenv.get(node.func.id)
        if isinstance(signature, FuncSig) and not node.keywords:
            actual_args = [self._expr_type(arg) for arg in node.args]
            expected_args = list(signature.args)
            # dumb typechecking: exact match of the type-name lists
            if None not in actual_args and actual_args != expected_args:
                raise TypeError(f"Type error in call for {node.func.id}, "
                                f"expected : {expected_args}, found : {actual_args}")
        self.generic_visit(node)
def typecheck(text, typeenv=None):
    """Parse *text* as Python source and typecheck its function calls.

    Args:
        text: Python source code to check.
        typeenv: optional initial type environment mapping a name to a type
            string or a ``FuncSig`` (e.g. signatures of functions defined
            elsewhere). Its entries are copied into the checker; the mapping
            passed in is never mutated.

    Raises:
        SyntaxError: if *text* is not valid Python.
        TypeError: from the checker, on the first call whose argument types
            do not match the callee's declared parameter types.
    """
    tree = ast.parse(text)
    checker = Typechecker()
    if typeenv:
        # Copy rather than alias, so the caller's dict (and no shared
        # default) is never mutated by the walk.
        checker.typeenv.update(typeenv)
    checker.visit(tree)
# Demo: ``foo`` calls ``inc`` with a str where an int is declared, so the
# checker raises TypeError and its message is printed on one line:
#   Type error in call for inc, expected : ['int', 'int'], found : ['int', 'str']
_DEMO_SOURCE = """
def inc(a: int, b: int) -> int:
    return a + 1
def foo(a: int) -> float:
    return inc(a, "a") # type error here
"""

try:
    typecheck(_DEMO_SOURCE)
except TypeError as exc:
    print(exc)
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment