Browse Source

ast to ir

myuan 2 years ago
parent
commit
e47df49076
1 changed files with 403 additions and 0 deletions
  1. 403 0
      4. 抽象语法树编译到 llvm/AST2IR.py

+ 403 - 0
4. 抽象语法树编译到 llvm/AST2IR.py

@@ -0,0 +1,403 @@
+# %%
+# python 3.9
+# pip install llvmlite==0.36.0
+
+from json import load
+from os import name
+import string
+from collections import defaultdict
+import llvmlite.ir as ir
+import llvmlite.binding as llvm
+from ctypes import CFUNCTYPE
+from dataclasses import dataclass, field
+from typing import Union
+from enum import Enum
+
+from llvmlite.ir.builder import IRBuilder
+from llvmlite.ir.values import Function
+
+
+llvm.initialize()
+llvm.initialize_native_target()
+llvm.initialize_native_asmprinter()
+
+
+ast = load(open("ast.json", encoding="utf8"))
+
+
+
+func_ty = ir.FunctionType(ir.VoidType(), [])
+f64_ty = ir.DoubleType()
+voidptr_ty = ir.IntType(8).as_pointer()
+i64_ty = ir.IntType(64)
+bool_ty = ir.IntType(1)
+
+printf_ty = ir.FunctionType(ir.IntType(64), [voidptr_ty], var_arg=True)
+
+m = ir.Module()
+
+func = ir.Function(m, func_ty, name="main")
+builder = ir.IRBuilder(func.append_basic_block('entry'))
+printf = ir.Function(m, printf_ty, name="printf")
+
+fmt = bytearray("> %.8f\n\0".encode('utf-8'))
+c_fmt = ir.Constant(ir.ArrayType(ir.IntType(8), len(fmt)), fmt)
+
+global_fmt = ir.GlobalVariable(m, c_fmt.type, name="fstr")
+global_fmt.linkage = 'internal'
+global_fmt.global_constant = True
+global_fmt.initializer = c_fmt
+
+fmt_arg = builder.bitcast(global_fmt, voidptr_ty)
+
+cond_fmt = bytearray("> cond: %d\n\0".encode('utf-8'))
+cond_c_fmt = ir.Constant(ir.ArrayType(ir.IntType(8), len(cond_fmt)), cond_fmt)
+
+cond_global_fmt = ir.GlobalVariable(m, cond_c_fmt.type, name="cond_global_fmt")
+cond_global_fmt.linkage = 'internal'
+cond_global_fmt.global_constant = True
+cond_global_fmt.initializer = cond_c_fmt
+
+
+
+# op_funcs = {
+#     "+": builder.fadd, "-": builder.fsub,
+#     "*": builder.fmul, "/": builder.fdiv,
+# }
+
+def op_funcs(builder: ir.IRBuilder):
+    return {
+        "+": builder.fadd, "-": builder.fsub,
+        "*": builder.fmul, "/": builder.fdiv,
+    }
+
+builtin_func = {
+    'output': printf,
+}
+
+is_number_str = lambda x: all([i in string.digits + "." for i in x])
+
+get_dict_key = lambda x: list(x.keys())
+get_dict_value = lambda x: list(x.values())
+
+class ArgumentType(Enum):
+    CONST = 1
+    VAR = 2
+
+@dataclass
+class ArgumentData:
+    name: str
+    type: ArgumentType
+
+    def __repr__(self) -> str:
+        return self.__str__()
+
+    def __str__(self) -> str:
+        return f"{self.type.name} {self.name}"
+
+def eval_line(builder: IRBuilder, line: dict, vars_envs: dict, func_envs: dict[str, ir.Function]):
+    if isinstance(line, dict):
+        name: str = get_dict_key(line)[0]
+        params: list = get_dict_value(line)[0]
+    else:
+        if isinstance(line, ArgumentData):
+            line: ArgumentData
+            if line.type == ArgumentType.CONST:
+                return f64_ty(line.name)
+            elif line.type == ArgumentType.VAR:
+                return eval_line(builder, {line.name: []}, vars_envs, func_envs)
+        else:
+            print('error line', line, type(line))
+            return line
+    if name == 'root':
+        for arg in params:
+            eval_line(builder, arg, vars_envs, func_envs)
+    elif name.startswith('#'):
+        pass
+    elif name == '括号':
+        return eval_line(builder, params[0], vars_envs, func_envs)
+    elif name == '=':
+        pass
+    elif is_number_str(name):
+        assert not params
+        var = f64_ty(float(name))
+        return var
+    elif name in op_funcs(builder):
+        assert len(params) == 2
+        l = eval_line(builder, params[0], vars_envs, func_envs)
+        r = eval_line(builder, params[1], vars_envs, func_envs)
+        # print(f'{params=}', l, r, vars_envs, func_envs)
+        res = op_funcs(builder)[name](l, r)
+        # print(f"op {l=} {name} {r=} | {res=}")
+        return res
+    elif name in builtin_func:
+        ir_params = []
+        for para in params:
+            para = eval_line(builder, para, vars_envs, func_envs)
+            para = builder.sitofp(para, f64_ty)
+            ir_params.append(para)
+            
+        func = builtin_func[name]
+        call_res = builder.call(func, [fmt_arg, *ir_params])
+        return call_res
+
+    elif name in func_envs and len(params) != 0:
+        ir_params = []
+        for para in params:
+            # print(f'{para=}')
+            para = eval_line(builder, para, vars_envs, func_envs)
+            # print(f'{para=}')
+            ir_params.append(para)
+        func = func_envs[name]
+        while len(ir_params) < len(func.args):
+            ir_params.append(f64_ty(0))
+        ir_params[0] = builder.fptosi(ir_params[0], i64_ty)
+        # print('call func', func, ir_params)
+
+        return builder.call(func, ir_params)
+    elif name in vars_envs:
+        return vars_envs.get(name)
+    else:
+        print('cannot find', name)
+
+@dataclass
+class FuncData:
+    name: str
+    args: list[ArgumentData] # 未扩展的, 扩展只影响 expend_ast_func_param
+    body: dict
+
+    arg_count: int = field(default=0, init=False)
+    const_arg_count: int = field(default=0, init=False)
+    var_arg_count: int = field(default=0, init=False)
+    @staticmethod
+    def from_ast(ast: dict):
+        assert '=' in ast, f"{ast} is not a def"
+        prototype, body = ast['=']
+        func_name = get_dict_key(prototype)[0]
+        args = [get_dict_key(i)[0] for i in prototype[func_name]]
+        args = [
+            ArgumentData(i, ArgumentType.CONST if is_number_str(i) else ArgumentType.VAR) 
+            for i in args
+        ]
+
+        return FuncData(func_name, args, body)
+
+    def __post_init__(self):
+        for arg in self.args:
+            self.arg_count += 1
+            if arg.type == ArgumentType.CONST:
+                self.const_arg_count += 1
+            elif arg.type == ArgumentType.VAR:
+                self.var_arg_count += 1
+
+    def func_type(self):
+        return ir.FunctionType(
+            f64_ty, 
+            [f64_ty for _ in range(self.arg_count)]
+        )
+
+    def _generate_ll_name(self):
+        return f"{self.name}-{self.args}"
+
+    @staticmethod
+    def expend_ast_func_param(func_name: str, root: dict):
+        name: str = get_dict_key(root)[0]
+        params: list = get_dict_value(root)[0]
+        if name == '=':
+            return root
+        
+        for i, p in enumerate(params):
+            if not isinstance(p, dict):
+                continue
+            params[i] = FuncData.expend_ast_func_param(func_name, p)
+        if name == func_name and len(params):
+            params = [i64_ty(len(params)), *params]
+        return {name: params}
+
+    def generate_ll_func(self, m: ir.Module, funcs_dict: dict[str, ir.Function]):
+        if not hasattr(self, '_current_func'):
+            current_func = ir.Function(
+                m, self.func_type(), 
+                f"{self.name}{self.args}"
+            )
+            var_envs = {}
+            # print(f"{current_func=} {current_func.args=}")
+
+            for i, arg in enumerate(self.args):
+                if arg.type == ArgumentType.VAR:
+                    var_envs[arg.name] = current_func.args[i]
+
+            bb_entry = current_func.append_basic_block("entry")
+            current_builder = ir.IRBuilder(bb_entry)
+            # print(f'{self.body=} {funcs_dict=}')
+            res = eval_line(current_builder, self.body, var_envs, funcs_dict)
+            current_builder.ret(res)
+            self._current_func = current_func
+            # print(current_func)
+        
+        return self._current_func
+
+@dataclass
+class OverloadableFunc:
+    name: str
+    funcs: list[FuncData]
+
+    arg_counts: list[int] = field(init=False)
+
+    def __post_init__(self):
+        self.arg_counts = [f.arg_count for f in self.funcs]
+
+    def add_func(self, f: FuncData):
+        self.funcs.append(f)
+        self.arg_counts = [f.arg_count for f in self.funcs]
+
+    def master_func_ty(self):
+        
+        return ir.FunctionType(
+            f64_ty, 
+            [i64_ty, *[f64_ty for _ in range(max(self.arg_counts))]]
+        )
+    def get_master_func_declear(self):
+        if not hasattr(self, "master_func_declear"):
+            self.master_func_declear = ir.Function(
+                m, self.master_func_ty(), 
+                f"{self.name}_master"
+            )
+        return self.master_func_declear
+
+    def generate_master_func(self, m: ir.Module, all_ll_master_funcs: dict[str, ir.Function]):
+        sub_funcs = [
+            fd.generate_ll_func(m, all_ll_master_funcs) 
+            for fd in self.funcs
+        ]
+        master_func = self.get_master_func_declear()
+        count_var, *master_params = master_func.args
+        
+        bb_entry = master_func.append_basic_block("entry")
+        master_entry_builder = ir.IRBuilder(bb_entry)
+
+        bb_exit = master_func.append_basic_block('exit')
+        bb_exit_builder = ir.IRBuilder(bb_exit)
+        bb_exit_builder.ret(f64_ty(-1))
+
+        bb_switch_default = master_func.append_basic_block(name='switch_default')
+        bb_switch_default_builder = ir.IRBuilder(bb_switch_default)
+        bb_switch_default_builder.branch(bb_exit)
+
+        sw = master_entry_builder.switch(count_var, bb_switch_default)
+
+        for count in set(self.arg_counts):
+            bb_curr_switch = master_func.append_basic_block(
+                name=self.name + f"_arg_{count}"
+            )
+            bb_curr_switch_builder = ir.IRBuilder(bb_curr_switch)
+            cond_fmt_arg = bb_curr_switch_builder.bitcast(cond_global_fmt, voidptr_ty)
+
+            sw.add_case(count, bb_curr_switch)
+            同参函数 = [f for f in self.funcs if f.arg_count == count]
+            同参函数 = sorted(同参函数, key=lambda x: x.const_arg_count, reverse=True)
+            
+            for f in 同参函数:
+                cond = bool_ty(1)
+
+                for arg_index, arg in enumerate(f.args):
+                    if arg.type != ArgumentType.CONST:
+                        continue
+
+                    cond = bb_curr_switch_builder.and_(
+                        cond, 
+                        bb_curr_switch_builder.fcmp_unordered(
+                            '==', 
+                            master_func.args[1 + arg_index], 
+                            ir.Constant(
+                                master_func.args[1 + arg_index].type, 
+                                float(arg.name)
+                            ), 
+                            name=f'cmp_{arg.name}'
+                        )
+                    )
+                # print(cond)
+                # bb_curr_switch_builder.call(printf, [cond_fmt_arg, bb_curr_switch_builder.zext(cond, ir.IntType(32))])
+                with bb_curr_switch_builder.if_then(cond):
+                    var_envs = {}
+                    # print(f"{current_func=} {current_func.args=}")
+
+                    for i, arg in enumerate(f.args):
+                        if arg.type == ArgumentType.VAR:
+                            var_envs[arg.name] = master_func.args[i + 1]
+                    # print(f"{f.args=}")
+                    params = [
+                        eval_line(
+                            bb_curr_switch_builder, arg, 
+                            var_envs, all_ll_master_funcs
+                        ) 
+                        for arg in f.args
+                    ]
+                    while len(params) < f.arg_count:
+                        params.append(f64_ty(0))
+                    # print(f"{f}, {params=}")
+                    res = bb_curr_switch_builder.call(
+                        f.generate_ll_func(m, all_ll_master_funcs), 
+                        params
+                    )
+                    bb_curr_switch_builder.ret(res)
+
+            bb_curr_switch_builder.branch(bb_exit)
+
+        # print(master_func)
+        return master_func
+
+
+'''
+拿到所有函数后, 按名称分组, 每个名称有一个入口函数叫做master函数, master包含比最大参数多一个参数, 第一个参数用于描述参数数量. 
+之后生成master的声明, 然后生成用于重载的函数, 之后按名称过一遍ast, 把对master的调用前面加上一个参数量, 最后拿着这些重载函数补全master的定义. 
+'''
+ast_root = ast
+# def create_func(ast_root: dict):
+
+overloadable_func_dict: dict[str, OverloadableFunc] = {}
+all_ll_master_funcs: dict[str, ir.Function] = {}
+
+for item in ast_root['root']:
+    if '=' in item:
+        f = FuncData.from_ast(item)
+        f.body = FuncData.expend_ast_func_param(f.name, f.body)
+
+        if f.name in overloadable_func_dict:
+            overloadable_func_dict[f.name].add_func(f)
+        else:
+            overloadable_func_dict[f.name] = OverloadableFunc(f.name, [f])
+for name, f in overloadable_func_dict.items():
+    ast_root = FuncData.expend_ast_func_param(name, ast_root)
+    all_ll_master_funcs[name] = f.get_master_func_declear()
+
+
+for name in overloadable_func_dict:
+    # for fd in overloadable_func_dict[name].funcs:
+    #     fd.generate_ll_func(m, all_ll_master_funcs)
+    (overloadable_func_dict[name].generate_master_func(m, all_ll_master_funcs))
+# print(overloadable_func_dict)
+# print(ast_root)
+# %%
+eval_line(builder, ast, {}, all_ll_master_funcs)
+builder.ret_void()
+asm = (str(m).replace(
+    "unknown-unknown-unknown", 
+    llvm.Target.from_default_triple().triple
+))
+print(asm)
+
+# %%
+
+
+llvm_module = llvm.parse_assembly(str(m))
+tm = llvm.Target.from_default_triple().create_target_machine()
+
+with llvm.create_mcjit_compiler(llvm_module, tm) as ee:
+    ee.finalize_object()
+    fptr = ee.get_function_address("main")
+    py_func = CFUNCTYPE(None)(fptr)
+    py_func()
+
+# %%
+