|
@@ -0,0 +1,403 @@
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+from json import load
|
|
|
+from os import name
|
|
|
+import string
|
|
|
+from collections import defaultdict
|
|
|
+import llvmlite.ir as ir
|
|
|
+import llvmlite.binding as llvm
|
|
|
+from ctypes import CFUNCTYPE
|
|
|
+from dataclasses import dataclass, field
|
|
|
+from typing import Union
|
|
|
+from enum import Enum
|
|
|
+
|
|
|
+from llvmlite.ir.builder import IRBuilder
|
|
|
+from llvmlite.ir.values import Function
|
|
|
+
|
|
|
+
|
|
|
+llvm.initialize()
|
|
|
+llvm.initialize_native_target()
|
|
|
+llvm.initialize_native_asmprinter()
|
|
|
+
|
|
|
+
|
|
|
+ast = load(open("ast.json", encoding="utf8"))
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+func_ty = ir.FunctionType(ir.VoidType(), [])
|
|
|
+f64_ty = ir.DoubleType()
|
|
|
+voidptr_ty = ir.IntType(8).as_pointer()
|
|
|
+i64_ty = ir.IntType(64)
|
|
|
+bool_ty = ir.IntType(1)
|
|
|
+
|
|
|
+printf_ty = ir.FunctionType(ir.IntType(64), [voidptr_ty], var_arg=True)
|
|
|
+
|
|
|
+m = ir.Module()
|
|
|
+
|
|
|
+func = ir.Function(m, func_ty, name="main")
|
|
|
+builder = ir.IRBuilder(func.append_basic_block('entry'))
|
|
|
+printf = ir.Function(m, printf_ty, name="printf")
|
|
|
+
|
|
|
+fmt = bytearray("> %.8f\n\0".encode('utf-8'))
|
|
|
+c_fmt = ir.Constant(ir.ArrayType(ir.IntType(8), len(fmt)), fmt)
|
|
|
+
|
|
|
+global_fmt = ir.GlobalVariable(m, c_fmt.type, name="fstr")
|
|
|
+global_fmt.linkage = 'internal'
|
|
|
+global_fmt.global_constant = True
|
|
|
+global_fmt.initializer = c_fmt
|
|
|
+
|
|
|
+fmt_arg = builder.bitcast(global_fmt, voidptr_ty)
|
|
|
+
|
|
|
+cond_fmt = bytearray("> cond: %d\n\0".encode('utf-8'))
|
|
|
+cond_c_fmt = ir.Constant(ir.ArrayType(ir.IntType(8), len(cond_fmt)), cond_fmt)
|
|
|
+
|
|
|
+cond_global_fmt = ir.GlobalVariable(m, cond_c_fmt.type, name="cond_global_fmt")
|
|
|
+cond_global_fmt.linkage = 'internal'
|
|
|
+cond_global_fmt.global_constant = True
|
|
|
+cond_global_fmt.initializer = cond_c_fmt
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+def op_funcs(builder: ir.IRBuilder):
|
|
|
+ return {
|
|
|
+ "+": builder.fadd, "-": builder.fsub,
|
|
|
+ "*": builder.fmul, "/": builder.fdiv,
|
|
|
+ }
|
|
|
+
|
|
|
+builtin_func = {
|
|
|
+ 'output': printf,
|
|
|
+}
|
|
|
+
|
|
|
+is_number_str = lambda x: all([i in string.digits + "." for i in x])
|
|
|
+
|
|
|
+get_dict_key = lambda x: list(x.keys())
|
|
|
+get_dict_value = lambda x: list(x.values())
|
|
|
+
|
|
|
+class ArgumentType(Enum):
|
|
|
+ CONST = 1
|
|
|
+ VAR = 2
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class ArgumentData:
|
|
|
+ name: str
|
|
|
+ type: ArgumentType
|
|
|
+
|
|
|
+ def __repr__(self) -> str:
|
|
|
+ return self.__str__()
|
|
|
+
|
|
|
+ def __str__(self) -> str:
|
|
|
+ return f"{self.type.name} {self.name}"
|
|
|
+
|
|
|
+def eval_line(builder: IRBuilder, line: dict, vars_envs: dict, func_envs: dict[str, ir.Function]):
|
|
|
+ if isinstance(line, dict):
|
|
|
+ name: str = get_dict_key(line)[0]
|
|
|
+ params: list = get_dict_value(line)[0]
|
|
|
+ else:
|
|
|
+ if isinstance(line, ArgumentData):
|
|
|
+ line: ArgumentData
|
|
|
+ if line.type == ArgumentType.CONST:
|
|
|
+ return f64_ty(line.name)
|
|
|
+ elif line.type == ArgumentType.VAR:
|
|
|
+ return eval_line(builder, {line.name: []}, vars_envs, func_envs)
|
|
|
+ else:
|
|
|
+ print('error line', line, type(line))
|
|
|
+ return line
|
|
|
+ if name == 'root':
|
|
|
+ for arg in params:
|
|
|
+ eval_line(builder, arg, vars_envs, func_envs)
|
|
|
+ elif name.startswith('#'):
|
|
|
+ pass
|
|
|
+ elif name == '括号':
|
|
|
+ return eval_line(builder, params[0], vars_envs, func_envs)
|
|
|
+ elif name == '=':
|
|
|
+ pass
|
|
|
+ elif is_number_str(name):
|
|
|
+ assert not params
|
|
|
+ var = f64_ty(float(name))
|
|
|
+ return var
|
|
|
+ elif name in op_funcs(builder):
|
|
|
+ assert len(params) == 2
|
|
|
+ l = eval_line(builder, params[0], vars_envs, func_envs)
|
|
|
+ r = eval_line(builder, params[1], vars_envs, func_envs)
|
|
|
+
|
|
|
+ res = op_funcs(builder)[name](l, r)
|
|
|
+
|
|
|
+ return res
|
|
|
+ elif name in builtin_func:
|
|
|
+ ir_params = []
|
|
|
+ for para in params:
|
|
|
+ para = eval_line(builder, para, vars_envs, func_envs)
|
|
|
+ para = builder.sitofp(para, f64_ty)
|
|
|
+ ir_params.append(para)
|
|
|
+
|
|
|
+ func = builtin_func[name]
|
|
|
+ call_res = builder.call(func, [fmt_arg, *ir_params])
|
|
|
+ return call_res
|
|
|
+
|
|
|
+ elif name in func_envs and len(params) != 0:
|
|
|
+ ir_params = []
|
|
|
+ for para in params:
|
|
|
+
|
|
|
+ para = eval_line(builder, para, vars_envs, func_envs)
|
|
|
+
|
|
|
+ ir_params.append(para)
|
|
|
+ func = func_envs[name]
|
|
|
+ while len(ir_params) < len(func.args):
|
|
|
+ ir_params.append(f64_ty(0))
|
|
|
+ ir_params[0] = builder.fptosi(ir_params[0], i64_ty)
|
|
|
+
|
|
|
+
|
|
|
+ return builder.call(func, ir_params)
|
|
|
+ elif name in vars_envs:
|
|
|
+ return vars_envs.get(name)
|
|
|
+ else:
|
|
|
+ print('cannot find', name)
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class FuncData:
|
|
|
+ name: str
|
|
|
+ args: list[ArgumentData]
|
|
|
+ body: dict
|
|
|
+
|
|
|
+ arg_count: int = field(default=0, init=False)
|
|
|
+ const_arg_count: int = field(default=0, init=False)
|
|
|
+ var_arg_count: int = field(default=0, init=False)
|
|
|
+ @staticmethod
|
|
|
+ def from_ast(ast: dict):
|
|
|
+ assert '=' in ast, f"{ast} is not a def"
|
|
|
+ prototype, body = ast['=']
|
|
|
+ func_name = get_dict_key(prototype)[0]
|
|
|
+ args = [get_dict_key(i)[0] for i in prototype[func_name]]
|
|
|
+ args = [
|
|
|
+ ArgumentData(i, ArgumentType.CONST if is_number_str(i) else ArgumentType.VAR)
|
|
|
+ for i in args
|
|
|
+ ]
|
|
|
+
|
|
|
+ return FuncData(func_name, args, body)
|
|
|
+
|
|
|
+ def __post_init__(self):
|
|
|
+ for arg in self.args:
|
|
|
+ self.arg_count += 1
|
|
|
+ if arg.type == ArgumentType.CONST:
|
|
|
+ self.const_arg_count += 1
|
|
|
+ elif arg.type == ArgumentType.VAR:
|
|
|
+ self.var_arg_count += 1
|
|
|
+
|
|
|
+ def func_type(self):
|
|
|
+ return ir.FunctionType(
|
|
|
+ f64_ty,
|
|
|
+ [f64_ty for _ in range(self.arg_count)]
|
|
|
+ )
|
|
|
+
|
|
|
+ def _generate_ll_name(self):
|
|
|
+ return f"{self.name}-{self.args}"
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def expend_ast_func_param(func_name: str, root: dict):
|
|
|
+ name: str = get_dict_key(root)[0]
|
|
|
+ params: list = get_dict_value(root)[0]
|
|
|
+ if name == '=':
|
|
|
+ return root
|
|
|
+
|
|
|
+ for i, p in enumerate(params):
|
|
|
+ if not isinstance(p, dict):
|
|
|
+ continue
|
|
|
+ params[i] = FuncData.expend_ast_func_param(func_name, p)
|
|
|
+ if name == func_name and len(params):
|
|
|
+ params = [i64_ty(len(params)), *params]
|
|
|
+ return {name: params}
|
|
|
+
|
|
|
+ def generate_ll_func(self, m: ir.Module, funcs_dict: dict[str, ir.Function]):
|
|
|
+ if not hasattr(self, '_current_func'):
|
|
|
+ current_func = ir.Function(
|
|
|
+ m, self.func_type(),
|
|
|
+ f"{self.name}{self.args}"
|
|
|
+ )
|
|
|
+ var_envs = {}
|
|
|
+
|
|
|
+
|
|
|
+ for i, arg in enumerate(self.args):
|
|
|
+ if arg.type == ArgumentType.VAR:
|
|
|
+ var_envs[arg.name] = current_func.args[i]
|
|
|
+
|
|
|
+ bb_entry = current_func.append_basic_block("entry")
|
|
|
+ current_builder = ir.IRBuilder(bb_entry)
|
|
|
+
|
|
|
+ res = eval_line(current_builder, self.body, var_envs, funcs_dict)
|
|
|
+ current_builder.ret(res)
|
|
|
+ self._current_func = current_func
|
|
|
+
|
|
|
+
|
|
|
+ return self._current_func
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class OverloadableFunc:
|
|
|
+ name: str
|
|
|
+ funcs: list[FuncData]
|
|
|
+
|
|
|
+ arg_counts: list[int] = field(init=False)
|
|
|
+
|
|
|
+ def __post_init__(self):
|
|
|
+ self.arg_counts = [f.arg_count for f in self.funcs]
|
|
|
+
|
|
|
+ def add_func(self, f: FuncData):
|
|
|
+ self.funcs.append(f)
|
|
|
+ self.arg_counts = [f.arg_count for f in self.funcs]
|
|
|
+
|
|
|
+ def master_func_ty(self):
|
|
|
+
|
|
|
+ return ir.FunctionType(
|
|
|
+ f64_ty,
|
|
|
+ [i64_ty, *[f64_ty for _ in range(max(self.arg_counts))]]
|
|
|
+ )
|
|
|
+ def get_master_func_declear(self):
|
|
|
+ if not hasattr(self, "master_func_declear"):
|
|
|
+ self.master_func_declear = ir.Function(
|
|
|
+ m, self.master_func_ty(),
|
|
|
+ f"{self.name}_master"
|
|
|
+ )
|
|
|
+ return self.master_func_declear
|
|
|
+
|
|
|
+ def generate_master_func(self, m: ir.Module, all_ll_master_funcs: dict[str, ir.Function]):
|
|
|
+ sub_funcs = [
|
|
|
+ fd.generate_ll_func(m, all_ll_master_funcs)
|
|
|
+ for fd in self.funcs
|
|
|
+ ]
|
|
|
+ master_func = self.get_master_func_declear()
|
|
|
+ count_var, *master_params = master_func.args
|
|
|
+
|
|
|
+ bb_entry = master_func.append_basic_block("entry")
|
|
|
+ master_entry_builder = ir.IRBuilder(bb_entry)
|
|
|
+
|
|
|
+ bb_exit = master_func.append_basic_block('exit')
|
|
|
+ bb_exit_builder = ir.IRBuilder(bb_exit)
|
|
|
+ bb_exit_builder.ret(f64_ty(-1))
|
|
|
+
|
|
|
+ bb_switch_default = master_func.append_basic_block(name='switch_default')
|
|
|
+ bb_switch_default_builder = ir.IRBuilder(bb_switch_default)
|
|
|
+ bb_switch_default_builder.branch(bb_exit)
|
|
|
+
|
|
|
+ sw = master_entry_builder.switch(count_var, bb_switch_default)
|
|
|
+
|
|
|
+ for count in set(self.arg_counts):
|
|
|
+ bb_curr_switch = master_func.append_basic_block(
|
|
|
+ name=self.name + f"_arg_{count}"
|
|
|
+ )
|
|
|
+ bb_curr_switch_builder = ir.IRBuilder(bb_curr_switch)
|
|
|
+ cond_fmt_arg = bb_curr_switch_builder.bitcast(cond_global_fmt, voidptr_ty)
|
|
|
+
|
|
|
+ sw.add_case(count, bb_curr_switch)
|
|
|
+ 同参函数 = [f for f in self.funcs if f.arg_count == count]
|
|
|
+ 同参函数 = sorted(同参函数, key=lambda x: x.const_arg_count, reverse=True)
|
|
|
+
|
|
|
+ for f in 同参函数:
|
|
|
+ cond = bool_ty(1)
|
|
|
+
|
|
|
+ for arg_index, arg in enumerate(f.args):
|
|
|
+ if arg.type != ArgumentType.CONST:
|
|
|
+ continue
|
|
|
+
|
|
|
+ cond = bb_curr_switch_builder.and_(
|
|
|
+ cond,
|
|
|
+ bb_curr_switch_builder.fcmp_unordered(
|
|
|
+ '==',
|
|
|
+ master_func.args[1 + arg_index],
|
|
|
+ ir.Constant(
|
|
|
+ master_func.args[1 + arg_index].type,
|
|
|
+ float(arg.name)
|
|
|
+ ),
|
|
|
+ name=f'cmp_{arg.name}'
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+ with bb_curr_switch_builder.if_then(cond):
|
|
|
+ var_envs = {}
|
|
|
+
|
|
|
+
|
|
|
+ for i, arg in enumerate(f.args):
|
|
|
+ if arg.type == ArgumentType.VAR:
|
|
|
+ var_envs[arg.name] = master_func.args[i + 1]
|
|
|
+
|
|
|
+ params = [
|
|
|
+ eval_line(
|
|
|
+ bb_curr_switch_builder, arg,
|
|
|
+ var_envs, all_ll_master_funcs
|
|
|
+ )
|
|
|
+ for arg in f.args
|
|
|
+ ]
|
|
|
+ while len(params) < f.arg_count:
|
|
|
+ params.append(f64_ty(0))
|
|
|
+
|
|
|
+ res = bb_curr_switch_builder.call(
|
|
|
+ f.generate_ll_func(m, all_ll_master_funcs),
|
|
|
+ params
|
|
|
+ )
|
|
|
+ bb_curr_switch_builder.ret(res)
|
|
|
+
|
|
|
+ bb_curr_switch_builder.branch(bb_exit)
|
|
|
+
|
|
|
+
|
|
|
+ return master_func
|
|
|
+
|
|
|
+
|
|
|
+'''
|
|
|
+拿到所有函数后, 按名称分组, 每个名称有一个入口函数叫做master函数, master包含比最大参数多一个参数, 第一个参数用于描述参数数量.
|
|
|
+之后生成master的声明, 然后生成用于重载的函数, 之后按名称过一遍ast, 把对master的调用前面加上一个参数量, 最后拿着这些重载函数补全master的定义.
|
|
|
+'''
|
|
|
+ast_root = ast
|
|
|
+
|
|
|
+
|
|
|
+overloadable_func_dict: dict[str, OverloadableFunc] = {}
|
|
|
+all_ll_master_funcs: dict[str, ir.Function] = {}
|
|
|
+
|
|
|
+for item in ast_root['root']:
|
|
|
+ if '=' in item:
|
|
|
+ f = FuncData.from_ast(item)
|
|
|
+ f.body = FuncData.expend_ast_func_param(f.name, f.body)
|
|
|
+
|
|
|
+ if f.name in overloadable_func_dict:
|
|
|
+ overloadable_func_dict[f.name].add_func(f)
|
|
|
+ else:
|
|
|
+ overloadable_func_dict[f.name] = OverloadableFunc(f.name, [f])
|
|
|
+for name, f in overloadable_func_dict.items():
|
|
|
+ ast_root = FuncData.expend_ast_func_param(name, ast_root)
|
|
|
+ all_ll_master_funcs[name] = f.get_master_func_declear()
|
|
|
+
|
|
|
+
|
|
|
+for name in overloadable_func_dict:
|
|
|
+
|
|
|
+
|
|
|
+ (overloadable_func_dict[name].generate_master_func(m, all_ll_master_funcs))
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+eval_line(builder, ast, {}, all_ll_master_funcs)
|
|
|
+builder.ret_void()
|
|
|
+asm = (str(m).replace(
|
|
|
+ "unknown-unknown-unknown",
|
|
|
+ llvm.Target.from_default_triple().triple
|
|
|
+))
|
|
|
+print(asm)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+llvm_module = llvm.parse_assembly(str(m))
|
|
|
+tm = llvm.Target.from_default_triple().create_target_machine()
|
|
|
+
|
|
|
+with llvm.create_mcjit_compiler(llvm_module, tm) as ee:
|
|
|
+ ee.finalize_object()
|
|
|
+ fptr = ee.get_function_address("main")
|
|
|
+ py_func = CFUNCTYPE(None)(fptr)
|
|
|
+ py_func()
|
|
|
+
|
|
|
+
|
|
|
+
|