# %%
# python 3.9
# pip install llvmlite==0.36.0
from json import load
from os import name
import string
from collections import defaultdict
import llvmlite.ir as ir
import llvmlite.binding as llvm
from ctypes import CFUNCTYPE
from dataclasses import dataclass, field
from typing import Union
from enum import Enum

from llvmlite.ir.builder import IRBuilder
from llvmlite.ir.values import Function

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

# Load the program AST; the context manager closes the file handle promptly.
with open("ast.json", encoding="utf8") as _ast_file:
    ast = load(_ast_file)

# Shared LLVM IR types.
func_ty = ir.FunctionType(ir.VoidType(), [])
f64_ty = ir.DoubleType()
voidptr_ty = ir.IntType(8).as_pointer()
i64_ty = ir.IntType(64)
bool_ty = ir.IntType(1)
printf_ty = ir.FunctionType(ir.IntType(64), [voidptr_ty], var_arg=True)

# The module, its `main` entry point and the builder used for top-level code.
m = ir.Module()
func = ir.Function(m, func_ty, name="main")
builder = ir.IRBuilder(func.append_basic_block('entry'))
printf = ir.Function(m, printf_ty, name="printf")

# Format string used by the `output` builtin.
fmt = bytearray("> %.8f\n\0".encode('utf-8'))
c_fmt = ir.Constant(ir.ArrayType(ir.IntType(8), len(fmt)), fmt)
global_fmt = ir.GlobalVariable(m, c_fmt.type, name="fstr")
global_fmt.linkage = 'internal'
global_fmt.global_constant = True
global_fmt.initializer = c_fmt
# NOTE: this bitcast is an instruction inside main's entry block, so it is
# only valid inside `main`.  eval_line materialises its own pointer with
# whichever builder is active; the name is kept for backward compatibility.
fmt_arg = builder.bitcast(global_fmt, voidptr_ty)

# Format string kept around for debugging the dispatch conditions.
cond_fmt = bytearray("> cond: %d\n\0".encode('utf-8'))
cond_c_fmt = ir.Constant(ir.ArrayType(ir.IntType(8), len(cond_fmt)), cond_fmt)
cond_global_fmt = ir.GlobalVariable(m, cond_c_fmt.type, name="cond_global_fmt")
cond_global_fmt.linkage = 'internal'
cond_global_fmt.global_constant = True
cond_global_fmt.initializer = cond_c_fmt


def op_funcs(builder: ir.IRBuilder):
    """Map each binary operator symbol to the float instruction emitter of
    *builder* (the emitters are bound methods, hence the per-builder dict)."""
    return {
        "+": builder.fadd,
        "-": builder.fsub,
        "*": builder.fmul,
        "/": builder.fdiv,
    }


builtin_func = {
    'output': printf,
}

_NUMBER_CHARS = frozenset(string.digits + ".")


def is_number_str(x: str) -> bool:
    """Return True when *x* is an unsigned decimal literal.

    Only ASCII digits and at most one "." are allowed and at least one digit
    is required, so "", "." and "1.2.3" are rejected instead of crashing
    later inside ``float()``.
    """
    return (
        x.count(".") <= 1
        and any(ch in string.digits for ch in x)
        and all(ch in _NUMBER_CHARS for ch in x)
    )


def get_dict_key(x: dict) -> list:
    """Return the keys of *x* as a list."""
    return list(x.keys())


def get_dict_value(x: dict) -> list:
    """Return the values of *x* as a list."""
    return list(x.values())


class ArgumentType(Enum):
    """Kind of a formal parameter in a function definition."""
    CONST = 1  # numeric literal pattern, e.g. the 0 in `f(0) = ...`
    VAR = 2    # named variable


@dataclass
class ArgumentData:
    """One formal parameter: its source text and whether it is a literal."""
    name: str
    type: ArgumentType

    def __repr__(self) -> str:
        return self.__str__()

    def __str__(self) -> str:
        return f"{self.type.name} {self.name}"


def eval_line(builder: IRBuilder, line: dict, vars_envs: dict,
              func_envs: dict[str, ir.Function]):
    """Emit IR for one AST node through *builder* and return its value.

    Args:
        builder: IR builder positioned in the block receiving the code.
        line: AST node, a single-entry dict ``{name: params}``.  An
            ``ArgumentData`` or an already-built ``ir.Constant`` (the arity
            prefix injected by ``FuncData.expend_ast_func_param``) is also
            accepted.
        vars_envs: variable name -> IR value visible at this point.
        func_envs: function name -> master (dispatch) function.

    Returns:
        The IR value of the node, or None for nodes that produce no value
        (root, comments, definitions, unknown names).
    """
    if not isinstance(line, dict):
        if isinstance(line, ArgumentData):
            if line.type == ArgumentType.CONST:
                return f64_ty(float(line.name))
            if line.type == ArgumentType.VAR:
                return eval_line(builder, {line.name: []}, vars_envs, func_envs)
        elif not isinstance(line, ir.Constant):
            # Pre-built constants are expected here (arity prefix); anything
            # else is reported but still passed through.
            print('error line', line, type(line))
        return line

    name: str = get_dict_key(line)[0]
    params: list = get_dict_value(line)[0]

    if name == 'root':
        for arg in params:
            eval_line(builder, arg, vars_envs, func_envs)
        return None

    # Comments and definitions emit nothing here; definitions are compiled
    # separately through FuncData / OverloadableFunc.
    if name.startswith('#') or name == '=':
        return None

    if name == '括号':
        return eval_line(builder, params[0], vars_envs, func_envs)

    if is_number_str(name):
        assert not params
        return f64_ty(float(name))

    arithmetic = op_funcs(builder)
    if name in arithmetic:
        assert len(params) == 2
        lhs = eval_line(builder, params[0], vars_envs, func_envs)
        rhs = eval_line(builder, params[1], vars_envs, func_envs)
        return arithmetic[name](lhs, rhs)

    if name in builtin_func:
        ir_params = []
        for para in params:
            value = eval_line(builder, para, vars_envs, func_envs)
            ir_params.append(builder.sitofp(value, f64_ty))
        # Build the format pointer with the *current* builder: the
        # module-level `fmt_arg` is an instruction owned by `main` and
        # referencing it from another function yields invalid IR.
        fmt_ptr = builder.bitcast(global_fmt, voidptr_ty)
        return builder.call(builtin_func[name], [fmt_ptr, *ir_params])

    if name in func_envs and params:
        ir_params = [
            eval_line(builder, para, vars_envs, func_envs) for para in params
        ]
        callee = func_envs[name]
        # The master function takes the widest overload's arity; pad the rest.
        while len(ir_params) < len(callee.args):
            ir_params.append(f64_ty(0))
        # First slot is the argument count (i64); a no-op if already i64.
        ir_params[0] = builder.fptosi(ir_params[0], i64_ty)
        return builder.call(callee, ir_params)

    if name in vars_envs:
        return vars_envs.get(name)

    print('cannot find', name)
    return None


@dataclass
class FuncData:
    """A single overload of a user-defined function."""
    name: str
    # Formal parameters as written (without the arity prefix); the expansion
    # only affects call sites, see expend_ast_func_param.
    args: list[ArgumentData]
    body: dict
    arg_count: int = field(default=0, init=False)
    const_arg_count: int = field(default=0, init=False)
    var_arg_count: int = field(default=0, init=False)

    @staticmethod
    def from_ast(ast: dict):
        """Build a FuncData from a ``{'=': [prototype, body]}`` AST node."""
        assert '=' in ast, f"{ast} is not a def"
        prototype, body = ast['=']
        func_name = get_dict_key(prototype)[0]
        arg_names = [get_dict_key(i)[0] for i in prototype[func_name]]
        args = [
            ArgumentData(
                arg_name,
                ArgumentType.CONST if is_number_str(arg_name)
                else ArgumentType.VAR,
            )
            for arg_name in arg_names
        ]
        return FuncData(func_name, args, body)

    def __post_init__(self):
        # Tally total / literal / variable parameters.
        for arg in self.args:
            self.arg_count += 1
            if arg.type == ArgumentType.CONST:
                self.const_arg_count += 1
            elif arg.type == ArgumentType.VAR:
                self.var_arg_count += 1

    def func_type(self):
        """IR type of this overload: ``double (double, ...)``."""
        return ir.FunctionType(f64_ty, [f64_ty] * self.arg_count)

    def _generate_ll_name(self):
        return f"{self.name}-{self.args}"

    @staticmethod
    def expend_ast_func_param(func_name: str, root: dict):
        """Prefix every call to *func_name* under *root* with its arity.

        Nested parameter lists are updated in place; the (possibly new) node
        is returned.  Definition nodes (``'='``) are returned untouched.
        """
        name: str = get_dict_key(root)[0]
        params: list = get_dict_value(root)[0]
        if name == '=':
            return root
        for i, p in enumerate(params):
            if not isinstance(p, dict):
                continue
            params[i] = FuncData.expend_ast_func_param(func_name, p)
        if name == func_name and len(params):
            params = [i64_ty(len(params)), *params]
        return {name: params}

    def generate_ll_func(self, m: ir.Module,
                         funcs_dict: dict[str, ir.Function]):
        """Emit (once) and return the IR function for this overload.

        The result is cached on the instance, so repeated calls are cheap.
        """
        if not hasattr(self, '_current_func'):
            current_func = ir.Function(
                m, self.func_type(), f"{self.name}{self.args}"
            )
            # Bind each named parameter to the matching IR argument.
            var_envs = {}
            for i, arg in enumerate(self.args):
                if arg.type == ArgumentType.VAR:
                    var_envs[arg.name] = current_func.args[i]
            bb_entry = current_func.append_basic_block("entry")
            current_builder = ir.IRBuilder(bb_entry)
            res = eval_line(current_builder, self.body, var_envs, funcs_dict)
            current_builder.ret(res)
            self._current_func = current_func
        return self._current_func


@dataclass
class OverloadableFunc:
    """All overloads sharing one name, plus their dispatching master."""
    name: str
    funcs: list[FuncData]
    arg_counts: list[int] = field(init=False)

    def __post_init__(self):
        self.arg_counts = [f.arg_count for f in self.funcs]

    def add_func(self, f: FuncData):
        """Register another overload and refresh the arity list."""
        self.funcs.append(f)
        self.arg_counts = [f.arg_count for f in self.funcs]

    def master_func_ty(self):
        """IR type of the master: ``double (i64 count, double * max_arity)``."""
        return ir.FunctionType(
            f64_ty, [i64_ty, *([f64_ty] * max(self.arg_counts))]
        )

    def get_master_func_declear(self):
        """Declare (once) the master function in the global module ``m``."""
        if not hasattr(self, "master_func_declear"):
            self.master_func_declear = ir.Function(
                m, self.master_func_ty(), f"{self.name}_master"
            )
        return self.master_func_declear

    def generate_master_func(self, m: ir.Module,
                             all_ll_master_funcs: dict[str, ir.Function]):
        """Emit the body of the master (dispatch) function.

        The master switches on the argument count; within one arity the
        overloads are tried from most to fewest literal parameters and the
        first whose literals all equal the actual arguments is called.  If
        nothing matches the master returns -1.0.
        """
        # Make sure every overload body has been emitted.
        for fd in self.funcs:
            fd.generate_ll_func(m, all_ll_master_funcs)

        master_func = self.get_master_func_declear()
        count_var = master_func.args[0]

        bb_entry = master_func.append_basic_block("entry")
        master_entry_builder = ir.IRBuilder(bb_entry)

        bb_exit = master_func.append_basic_block('exit')
        ir.IRBuilder(bb_exit).ret(f64_ty(-1))

        bb_switch_default = master_func.append_basic_block(name='switch_default')
        ir.IRBuilder(bb_switch_default).branch(bb_exit)

        sw = master_entry_builder.switch(count_var, bb_switch_default)
        for count in sorted(set(self.arg_counts)):
            bb_case = master_func.append_basic_block(
                name=self.name + f"_arg_{count}"
            )
            case_builder = ir.IRBuilder(bb_case)
            sw.add_case(count, bb_case)

            # Most specific overloads (most literal parameters) first.
            same_arity = sorted(
                (f for f in self.funcs if f.arg_count == count),
                key=lambda x: x.const_arg_count,
                reverse=True,
            )
            for f in same_arity:
                cond = bool_ty(1)
                for arg_index, arg in enumerate(f.args):
                    if arg.type != ArgumentType.CONST:
                        continue
                    actual = master_func.args[1 + arg_index]
                    # Ordered compare: a NaN argument must not match a
                    # literal pattern (unordered '==' is true for NaN).
                    cond = case_builder.and_(
                        cond,
                        case_builder.fcmp_ordered(
                            '==',
                            actual,
                            ir.Constant(actual.type, float(arg.name)),
                            name=f'cmp_{arg.name}'
                        )
                    )
                with case_builder.if_then(cond):
                    var_envs = {}
                    for i, arg in enumerate(f.args):
                        if arg.type == ArgumentType.VAR:
                            var_envs[arg.name] = master_func.args[i + 1]
                    call_args = [
                        eval_line(
                            case_builder, arg, var_envs, all_ll_master_funcs
                        )
                        for arg in f.args
                    ]
                    res = case_builder.call(
                        f.generate_ll_func(m, all_ll_master_funcs), call_args
                    )
                    case_builder.ret(res)
            # No overload of this arity matched.
            case_builder.branch(bb_exit)
        return master_func


# Design overview:
# After collecting every function definition they are grouped by name.  Each
# name gets one entry function called the "master", which takes one parameter
# more than the widest overload; that first parameter carries the number of
# arguments actually passed.  The master declarations are created first, then
# the overload bodies are generated, then the AST is walked once per name to
# prefix every call to a master with its argument count, and finally the
# overloads are used to fill in each master's definition.

ast_root = ast

overloadable_func_dict: dict[str, OverloadableFunc] = {}
all_ll_master_funcs: dict[str, ir.Function] = {}

# Collect definitions and group the overloads by name.
for item in ast_root['root']:
    if '=' in item:
        f = FuncData.from_ast(item)
        f.body = FuncData.expend_ast_func_param(f.name, f.body)
        if f.name in overloadable_func_dict:
            overloadable_func_dict[f.name].add_func(f)
        else:
            overloadable_func_dict[f.name] = OverloadableFunc(f.name, [f])

for fn_name, overloads in overloadable_func_dict.items():
    ast_root = FuncData.expend_ast_func_param(fn_name, ast_root)
    # Definition nodes are skipped by the pass above, so calls to *other*
    # functions inside a body must be prefixed here; the function's own name
    # was already handled when the body was collected.
    for other in overloadable_func_dict.values():
        if other.name == fn_name:
            continue
        for fd in other.funcs:
            fd.body = FuncData.expend_ast_func_param(fn_name, fd.body)
    all_ll_master_funcs[fn_name] = overloads.get_master_func_declear()

for fn_name in overloadable_func_dict:
    overloadable_func_dict[fn_name].generate_master_func(m, all_ll_master_funcs)

# %%
# `ast_root` shares its top-level statement list with `ast` (the expansion
# mutates that list in place), so this evaluates the expanded program.
eval_line(builder, ast_root, {}, all_ll_master_funcs)
builder.ret_void()

asm = str(m).replace(
    "unknown-unknown-unknown",
    llvm.Target.from_default_triple().triple
)
print(asm)

# %%
llvm_module = llvm.parse_assembly(str(m))
tm = llvm.Target.from_default_triple().create_target_machine()

with llvm.create_mcjit_compiler(llvm_module, tm) as ee:
    ee.finalize_object()
    fptr = ee.get_function_address("main")
    py_func = CFUNCTYPE(None)(fptr)
    py_func()

# %%