# @@ -0,0 +1,403 @@  (unified-diff hunk header left over from patch extraction)
# %%
# python 3.9
# pip install llvmlite==0.36.0
import string
from collections import defaultdict
from ctypes import CFUNCTYPE
from dataclasses import dataclass, field
from enum import Enum
from json import load
from os import name
from typing import Union

import llvmlite.binding as llvm
import llvmlite.ir as ir
from llvmlite.ir.builder import IRBuilder
from llvmlite.ir.values import Function

# The program to compile, as a JSON AST. The context manager closes the file
# handle as soon as parsing is done.
with open("ast.json", encoding="utf8") as ast_file:
    ast = load(ast_file)

# LLVM types shared by the whole script.
func_ty = ir.FunctionType(ir.VoidType(), [])  # signature of the generated "main"
f64_ty = ir.DoubleType()
voidptr_ty = ir.IntType(8).as_pointer()
i64_ty = ir.IntType(64)
bool_ty = ir.IntType(1)
printf_ty = ir.FunctionType(ir.IntType(64), [voidptr_ty], var_arg=True)

# One module holds everything; `builder` appends to the entry block of "main".
m = ir.Module()
func = ir.Function(m, func_ty, name="main")
builder = ir.IRBuilder(func.append_basic_block('entry'))
printf = ir.Function(m, printf_ty, name="printf")

# NUL-terminated format string used when printing a double.
fmt = bytearray("> %.8f\n\0".encode('utf-8'))
c_fmt = ir.Constant(ir.ArrayType(ir.IntType(8), len(fmt)), fmt)
global_fmt = ir.GlobalVariable(m, c_fmt.type, name="fstr")
global_fmt.linkage = 'internal'
global_fmt.global_constant = True
global_fmt.initializer = c_fmt
# NOTE: this bitcast is an instruction emitted into the entry block of "main".
fmt_arg = builder.bitcast(global_fmt, voidptr_ty)

# NUL-terminated format string for printing an integer condition flag.
cond_fmt = bytearray("> cond: %d\n\0".encode('utf-8'))
cond_c_fmt = ir.Constant(ir.ArrayType(ir.IntType(8), len(cond_fmt)), cond_fmt)
cond_global_fmt = ir.GlobalVariable(m, cond_c_fmt.type, name="cond_global_fmt")
cond_global_fmt.linkage = 'internal'
cond_global_fmt.global_constant = True
cond_global_fmt.initializer = cond_c_fmt
def op_funcs(builder: ir.IRBuilder):
    """Return the arithmetic operator table bound to ``builder``.

    The table is built per call (instead of once at module level) because the
    entries are bound methods of one specific builder.

    Args:
        builder: The IR builder whose float instructions should be emitted.

    Returns:
        A dict mapping ``"+"``, ``"-"``, ``"*"`` and ``"/"`` to the builder's
        ``fadd``, ``fsub``, ``fmul`` and ``fdiv`` methods.
    """
    return {
        "+": builder.fadd,
        "-": builder.fsub,
        "*": builder.fmul,
        "/": builder.fdiv,
    }
# Functions the compiled language can call by name without defining them.
builtin_func = {
    'output': printf,
}
def is_number_str(x):
    """Return True when ``x`` is an unsigned decimal literal ``float()`` accepts.

    Only ASCII digits and at most one ``"."`` are allowed, and at least one
    digit is required. The empty string, ``"."`` and ``"1.2.3"`` are rejected
    so that callers never pass an unparsable string to ``float()``.
    """
    digit_count = 0
    dot_count = 0
    for ch in x:
        if ch in string.digits:
            digit_count += 1
        elif ch == ".":
            dot_count += 1
        else:
            return False
    return digit_count > 0 and dot_count <= 1


def get_dict_key(x):
    """Return the keys of dict ``x`` as a list, in insertion order."""
    return list(x.keys())


def get_dict_value(x):
    """Return the values of dict ``x`` as a list, in insertion order."""
    return list(x.values())
class ArgumentType(Enum):
    """Kind of a formal argument in a function prototype."""

    CONST = 1  # the argument is written as a numeric literal
    VAR = 2    # the argument is a named variable
@dataclass
class ArgumentData:
    """A formal argument of a function prototype.

    Declared as a dataclass so it can be built positionally as
    ``ArgumentData(name, type)``.
    """

    name: str  # the argument text: a variable name or a numeric literal
    type: ArgumentType

    def __repr__(self) -> str:
        return self.__str__()

    def __str__(self) -> str:
        return f"{self.type.name} {self.name}"
def eval_line(builder: IRBuilder, line: dict, vars_envs: dict, func_envs: dict[str, ir.Function]):
    """Emit IR for one AST node through ``builder`` and return its value.

    Args:
        builder: IR builder that receives the emitted instructions.
        line: The node to evaluate. Normally a single-entry dict
            ``{name: params}``; an ``ArgumentData`` or an already-built
            ``ir.Constant`` is accepted as well.
        vars_envs: Maps variable names to the IR values bound to them.
        func_envs: Maps function names to the ``ir.Function`` to call.

    Returns:
        The IR value of the node, or ``None`` for nodes that produce no value
        (``root``, names starting with ``#``, ``=`` and unknown names).
    """
    if not isinstance(line, dict):
        if isinstance(line, ArgumentData):
            if line.type == ArgumentType.CONST:
                return f64_ty(float(line.name))
            if line.type == ArgumentType.VAR:
                # Re-enter as a parameterless node so the name is looked up.
                return eval_line(builder, {line.name: []}, vars_envs, func_envs)
            return line
        if isinstance(line, ir.Constant):
            # FuncData.expend_ast_func_param prepends an i64 argument count to
            # call parameters; it is already IR, so pass it through silently.
            return line
        print('error line', line, type(line))
        return line

    name: str = get_dict_key(line)[0]
    params: list = get_dict_value(line)[0]

    if name == 'root':
        for statement in params:
            eval_line(builder, statement, vars_envs, func_envs)
        return None

    # '#'-prefixed nodes and '=' (definition) nodes emit nothing here.
    if name.startswith('#') or name == '=':
        return None

    # '括号' ("parentheses"): evaluate the single wrapped expression.
    if name == '括号':
        return eval_line(builder, params[0], vars_envs, func_envs)

    if is_number_str(name):
        assert not params, f"numeric literal {name!r} must not have parameters"
        return f64_ty(float(name))

    arithmetic = op_funcs(builder)
    if name in arithmetic:
        assert len(params) == 2, f"operator {name!r} needs exactly two operands"
        lhs = eval_line(builder, params[0], vars_envs, func_envs)
        rhs = eval_line(builder, params[1], vars_envs, func_envs)
        return arithmetic[name](lhs, rhs)

    if name in builtin_func:
        call_args = [
            builder.sitofp(eval_line(builder, para, vars_envs, func_envs), f64_ty)
            for para in params
        ]
        return builder.call(builtin_func[name], [fmt_arg, *call_args])

    if name in func_envs and params:
        call_args = [
            eval_line(builder, para, vars_envs, func_envs)
            for para in params
        ]
        callee = func_envs[name]
        # Pad with zeros up to the callee's full parameter list.
        missing = len(callee.args) - len(call_args)
        call_args.extend(f64_ty(0) for _ in range(missing))
        # The first parameter of the callee is the integer argument count.
        call_args[0] = builder.fptosi(call_args[0], i64_ty)
        return builder.call(callee, call_args)

    if name in vars_envs:
        return vars_envs[name]

    print('cannot find', name)
    return None
@dataclass
class FuncData:
    """One definition (a single overload) of a function from the AST.

    Declared as a dataclass: the class relies on ``field()``,
    ``__post_init__`` and positional construction.
    """

    name: str
    args: list[ArgumentData]  # unexpanded; expansion only affects expend_ast_func_param
    body: dict
    arg_count: int = field(default=0, init=False)
    const_arg_count: int = field(default=0, init=False)
    var_arg_count: int = field(default=0, init=False)

    @staticmethod
    def from_ast(ast: dict):
        """Build a ``FuncData`` from a ``{'=': [prototype, body]}`` node.

        Prototype arguments that look like numbers become ``CONST``
        arguments; all others become ``VAR`` arguments.

        Raises:
            AssertionError: If ``ast`` has no ``'='`` key.
        """
        assert '=' in ast, f"{ast} is not a def"
        prototype, body = ast['=']
        func_name = get_dict_key(prototype)[0]
        arg_names = [get_dict_key(node)[0] for node in prototype[func_name]]
        args = [
            ArgumentData(
                arg_name,
                ArgumentType.CONST if is_number_str(arg_name) else ArgumentType.VAR,
            )
            for arg_name in arg_names
        ]
        return FuncData(func_name, args, body)

    def __post_init__(self):
        """Derive the argument counters from ``args``."""
        self.arg_count = len(self.args)
        self.const_arg_count = sum(
            arg.type == ArgumentType.CONST for arg in self.args
        )
        self.var_arg_count = sum(
            arg.type == ArgumentType.VAR for arg in self.args
        )

    def func_type(self):
        """Return the IR type of this overload: ``arg_count`` doubles -> double."""
        return ir.FunctionType(f64_ty, [f64_ty] * self.arg_count)

    def _generate_ll_name(self):
        return f"{self.name}-{self.args}"

    @staticmethod
    def expend_ast_func_param(func_name: str, root: dict):
        """Prepend an argument count to every call of ``func_name`` below ``root``.

        Nested parameter lists are updated in place. A ``'='`` node is
        returned untouched. Only calls with at least one parameter receive
        the leading ``i64`` count.

        Returns:
            ``root`` itself for a ``'='`` node, otherwise a new single-entry
            dict for this node.
        """
        name: str = get_dict_key(root)[0]
        params: list = get_dict_value(root)[0]
        if name == '=':
            return root
        for index, child in enumerate(params):
            if isinstance(child, dict):
                params[index] = FuncData.expend_ast_func_param(func_name, child)
        if name == func_name and params:
            params = [i64_ty(len(params)), *params]
        return {name: params}

    def generate_ll_func(self, m: ir.Module, funcs_dict: dict[str, ir.Function]):
        """Emit (once) and return the ``ir.Function`` for this overload.

        The function is cached on the instance, so repeated calls return the
        same object without emitting anything again.
        """
        if hasattr(self, '_current_func'):
            return self._current_func

        current_func = ir.Function(
            m, self.func_type(),
            f"{self.name}{self.args}"
        )
        # Only VAR arguments are visible by name inside the body.
        var_envs = {
            arg.name: current_func.args[i]
            for i, arg in enumerate(self.args)
            if arg.type == ArgumentType.VAR
        }
        bb_entry = current_func.append_basic_block("entry")
        current_builder = ir.IRBuilder(bb_entry)
        res = eval_line(current_builder, self.body, var_envs, funcs_dict)
        current_builder.ret(res)
        self._current_func = current_func
        return current_func
@dataclass
class OverloadableFunc:
    """All overloads that share one function name, plus their dispatcher.

    The dispatcher ("master" function) takes an ``i64`` argument count
    followed by as many doubles as the widest overload needs.
    """

    name: str
    funcs: list[FuncData]
    arg_counts: list[int] = field(init=False)

    def __post_init__(self):
        self.arg_counts = [f.arg_count for f in self.funcs]

    def add_func(self, f: FuncData):
        """Register another overload and refresh ``arg_counts``."""
        self.funcs.append(f)
        self.arg_counts = [f.arg_count for f in self.funcs]

    def master_func_ty(self):
        """Return the IR type of the dispatcher: (i64, double * max arity) -> double."""
        widest = max(self.arg_counts, default=0)  # tolerate an empty overload list
        return ir.FunctionType(f64_ty, [i64_ty, *([f64_ty] * widest)])

    def get_master_func_declear(self):
        """Return the dispatcher's ``ir.Function``, creating it on first use.

        The function is added to the module-level ``m``.
        """
        if not hasattr(self, "master_func_declear"):
            self.master_func_declear = ir.Function(
                m, self.master_func_ty(),
                f"{self.name}_master"
            )
        return self.master_func_declear

    def generate_master_func(self, m: ir.Module, all_ll_master_funcs: dict[str, ir.Function]):
        """Emit the body of the dispatcher and return it.

        The dispatcher switches on the argument count. Within one count,
        overloads are tried from most to fewest constant arguments; the first
        whose constant arguments all compare equal to the actual ones is
        called and its result returned. If nothing matches, -1.0 is returned.

        Args:
            m: Module in which the overload functions are emitted.
            all_ll_master_funcs: Dispatchers by function name, used to
                resolve calls inside overload bodies.
        """
        # Emit every overload first; each result is cached on its FuncData.
        for fd in self.funcs:
            fd.generate_ll_func(m, all_ll_master_funcs)

        master_func = self.get_master_func_declear()
        count_var, *master_params = master_func.args

        bb_entry = master_func.append_basic_block("entry")
        master_entry_builder = ir.IRBuilder(bb_entry)

        # Fallback exit: no overload matched.
        bb_exit = master_func.append_basic_block('exit')
        ir.IRBuilder(bb_exit).ret(f64_ty(-1))

        bb_switch_default = master_func.append_basic_block(name='switch_default')
        ir.IRBuilder(bb_switch_default).branch(bb_exit)

        sw = master_entry_builder.switch(count_var, bb_switch_default)
        for count in sorted(set(self.arg_counts)):
            bb_case = master_func.append_basic_block(
                name=self.name + f"_arg_{count}"
            )
            case_builder = ir.IRBuilder(bb_case)
            sw.add_case(count, bb_case)

            # Overloads with this arity, most constant arguments first.
            same_arity = sorted(
                (f for f in self.funcs if f.arg_count == count),
                key=lambda x: x.const_arg_count,
                reverse=True,
            )
            for f in same_arity:
                # cond is the AND of "actual == literal" over all CONST args.
                cond = bool_ty(1)
                for arg_index, arg in enumerate(f.args):
                    if arg.type != ArgumentType.CONST:
                        continue
                    actual = master_params[arg_index]
                    matches = case_builder.fcmp_unordered(
                        '==',
                        actual,
                        ir.Constant(actual.type, float(arg.name)),
                        name=f'cmp_{arg.name}'
                    )
                    cond = case_builder.and_(cond, matches)

                with case_builder.if_then(cond):
                    var_envs = {
                        arg.name: master_params[i]
                        for i, arg in enumerate(f.args)
                        if arg.type == ArgumentType.VAR
                    }
                    params = [
                        eval_line(
                            case_builder, arg,
                            var_envs, all_ll_master_funcs
                        )
                        for arg in f.args
                    ]
                    res = case_builder.call(
                        f.generate_ll_func(m, all_ll_master_funcs),
                        params
                    )
                    case_builder.ret(res)
            # No overload of this arity matched.
            case_builder.branch(bb_exit)
        return master_func
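# Once all functions are collected, group them by name. Every name gets one
# entry function, called the "master" function. The master takes one parameter
# more than its widest overload; that first parameter carries the number of
# arguments passed.
# Next, declare the masters and generate the overload functions. Then walk the
# AST once per name and prepend the argument count to every call of that
# master. Finally, use the overload functions to fill in each master's body.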
ast_root = ast
overloadable_func_dict: dict[str, OverloadableFunc] = {}
all_ll_master_funcs: dict[str, ir.Function] = {}

# Pass 1: collect every definition and group the overloads by function name.
for item in ast_root['root']:
    if '=' not in item:
        continue
    f = FuncData.from_ast(item)
    f.body = FuncData.expend_ast_func_param(f.name, f.body)
    if f.name in overloadable_func_dict:
        overloadable_func_dict[f.name].add_func(f)
    else:
        overloadable_func_dict[f.name] = OverloadableFunc(f.name, [f])

# Pass 2: rewrite the calls in the top-level program and declare every master,
# so that overload bodies can reference any of them.
# expend_ast_func_param updates parameter lists in place, so `ast` observes
# the rewritten calls as well.
for func_name, overloads in overloadable_func_dict.items():
    ast_root = FuncData.expend_ast_func_param(func_name, ast_root)
    all_ll_master_funcs[func_name] = overloads.get_master_func_declear()

# Pass 3: emit the overload functions and fill in each master's body.
for overloads in overloadable_func_dict.values():
    overloads.generate_master_func(m, all_ll_master_funcs)
# %%
# Emit the top-level program into "main".
eval_line(builder, ast, {}, all_ll_master_funcs)
# "main" returns void; every basic block must end in a terminator.
if not builder.block.is_terminated:
    builder.ret_void()

# llvmlite 0.36 (the version pinned at the top of this file) needs the native
# target registered before a Target or an execution engine can be created.
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

# Replace the module's placeholder triple with the host triple.
asm = str(m).replace(
    "unknown-unknown-unknown",
    llvm.Target.from_default_triple().triple
)
# %%
llvm_module = llvm.parse_assembly(asm)
llvm_module.verify()
tm = llvm.Target.from_default_triple().create_target_machine()
with llvm.create_mcjit_compiler(llvm_module, tm) as ee:
    ee.finalize_object()
    fptr = ee.get_function_address("main")
    # "main" takes no arguments and returns nothing.
    py_func = CFUNCTYPE(None)(fptr)
    py_func()
# %%