123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403 |
- from json import load
- from os import name
- import string
- from collections import defaultdict
- import llvmlite.ir as ir
- import llvmlite.binding as llvm
- from ctypes import CFUNCTYPE
- from dataclasses import dataclass, field
- from typing import Union
- from enum import Enum
- from llvmlite.ir.builder import IRBuilder
- from llvmlite.ir.values import Function
# Bring up LLVM's native backend so the module can be JIT-compiled at the end.
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

# Load the program AST. The `with` block closes the file once it is parsed
# (a bare `open(...)` keeps the handle alive until garbage collection).
with open("ast.json", encoding="utf8") as ast_file:
    ast = load(ast_file)

# LLVM types shared by the whole code generator.
func_ty = ir.FunctionType(ir.VoidType(), [])
f64_ty = ir.DoubleType()
voidptr_ty = ir.IntType(8).as_pointer()
i64_ty = ir.IntType(64)
bool_ty = ir.IntType(1)
printf_ty = ir.FunctionType(ir.IntType(64), [voidptr_ty], var_arg=True)

# The output module and its `main` entry point; top-level statements are
# emitted into main's entry block through `builder`.
m = ir.Module()
func = ir.Function(m, func_ty, name="main")
builder = ir.IRBuilder(func.append_basic_block('entry'))
printf = ir.Function(m, printf_ty, name="printf")


def _define_cstring_global(module, text, name):
    """Emit `text` as an internal, constant i8-array global called `name`.

    `text` should include its own trailing NUL. Returns the tuple
    (raw bytes, initializer constant, global variable).
    """
    data = bytearray(text.encode('utf-8'))
    const = ir.Constant(ir.ArrayType(ir.IntType(8), len(data)), data)
    gv = ir.GlobalVariable(module, const.type, name=name)
    gv.linkage = 'internal'
    gv.global_constant = True
    gv.initializer = const
    return data, const, gv


# printf format used when lowering `output(...)`: one double per call.
fmt, c_fmt, global_fmt = _define_cstring_global(m, "> %.8f\n\0", "fstr")
fmt_arg = builder.bitcast(global_fmt, voidptr_ty)
# printf format "> cond: %d"; presumably meant for printing i1 conditions
# while debugging.
cond_fmt, cond_c_fmt, cond_global_fmt = _define_cstring_global(
    m, "> cond: %d\n\0", "cond_global_fmt"
)
def op_funcs(builder: ir.IRBuilder):
    """Return a mapping from arithmetic operator token to the builder method emitting it.

    Only the floating-point instructions are used, so operands are expected
    to be doubles.
    """
    operators = ("+", "-", "*", "/")
    emitters = (builder.fadd, builder.fsub, builder.fmul, builder.fdiv)
    return dict(zip(operators, emitters))
# Built-in functions callable from the source language, keyed by name.
# `output` is lowered to a call of the module's `printf` declaration.
builtin_func = dict(output=printf)
def is_number_str(x):
    """Return True if `x` spells a plain decimal literal such as "3" or "2.5".

    Only digits and at most one '.' are allowed, and at least one digit is
    required. Inputs like "", "." or "1.2.3" are rejected: they used to be
    accepted and then crashed in `float()`.
    """
    return (
        any(ch in string.digits for ch in x)
        and x.count(".") <= 1
        and all(ch in string.digits + "." for ch in x)
    )


def get_dict_key(x):
    """Return the keys of mapping `x` as a list.

    Callers take element [0] because each AST node is a one-key dict.
    """
    return list(x.keys())


def get_dict_value(x):
    """Return the values of mapping `x` as a list (see `get_dict_key`)."""
    return list(x.values())
class ArgumentType(Enum):
    """How a parameter is written in a function definition's prototype."""

    # A numeric literal; call arguments are compared against it to pick
    # the matching overload.
    CONST = 1
    # A named parameter that binds the corresponding call argument.
    VAR = 2
@dataclass
class ArgumentData:
    """One parameter of a function prototype.

    `name` is the parameter's source text (an identifier or a numeric
    literal); `type` records which of the two it is.
    """

    name: str
    type: ArgumentType

    def __str__(self) -> str:
        kind = self.type.name
        return f"{kind} {self.name}"

    def __repr__(self) -> str:
        # repr intentionally mirrors str (e.g. "VAR n"), so lists of
        # arguments format compactly, including inside generated names.
        return str(self)
def eval_line(builder: IRBuilder, line: dict, vars_envs: dict, func_envs: dict[str, ir.Function]):
    """Lower one AST node to LLVM IR at `builder`'s insertion point.

    `line` is normally a one-key dict ``{name: params}``. It may also be an
    `ArgumentData` (a prototype parameter). Any other object is reported and
    returned unchanged.

    Args:
        builder: IR builder the instructions are appended to.
        line: the AST node to lower.
        vars_envs: variable name -> IR value visible in the current function.
        func_envs: user-function name -> its dispatching "master" function,
            whose first parameter is the i64 argument count.

    Returns:
        The IR value the node evaluates to. Returns None for statement-like
        nodes (`root`, `#` comments, `=` definitions) and for unresolved names.
    """
    if isinstance(line, ArgumentData):
        if line.type == ArgumentType.CONST:
            return f64_ty(float(line.name))
        # VAR: resolve it like a bare name reference.
        return eval_line(builder, {line.name: []}, vars_envs, func_envs)
    if not isinstance(line, dict):
        print('error line', line, type(line))
        return line

    name: str = get_dict_key(line)[0]
    params: list = get_dict_value(line)[0]

    if name == 'root':
        for stmt in params:
            eval_line(builder, stmt, vars_envs, func_envs)
        return None
    if name.startswith('#') or name == '=':
        # Comments emit nothing; definitions are lowered separately through
        # FuncData / OverloadableFunc.
        return None
    if name == '括号':
        # "括号" means "parentheses": a grouping node around one expression.
        return eval_line(builder, params[0], vars_envs, func_envs)
    if is_number_str(name):
        assert not params
        return f64_ty(float(name))

    ops = op_funcs(builder)
    if name in ops:
        assert len(params) == 2
        lhs = eval_line(builder, params[0], vars_envs, func_envs)
        rhs = eval_line(builder, params[1], vars_envs, func_envs)
        return ops[name](lhs, rhs)

    if name in builtin_func:
        call_args = []
        for node in params:
            value = eval_line(builder, node, vars_envs, func_envs)
            # sitofp is only defined for integer operands; converting a value
            # that is already a double would produce invalid IR.
            if isinstance(value.type, ir.IntType):
                value = builder.sitofp(value, f64_ty)
            call_args.append(value)
        return builder.call(builtin_func[name], [fmt_arg, *call_args])

    if name in func_envs and len(params) != 0:
        callee = func_envs[name]
        # FuncData.expend_ast_func_param prefixes rewritten calls with their
        # argument count as a raw i64 constant. Calls it did not rewrite (for
        # example a call to another function inside a function body) lack the
        # prefix. In that case, derive the count instead of taking the first
        # real argument as the count.
        if isinstance(params[0], ir.Constant):
            count, arg_nodes = params[0], params[1:]
        else:
            count, arg_nodes = i64_ty(len(params)), params
        if not isinstance(count.type, ir.IntType):
            # fptosi is only valid on floating-point operands.
            count = builder.fptosi(count, i64_ty)
        call_args = [
            eval_line(builder, node, vars_envs, func_envs) for node in arg_nodes
        ]
        # The master function is as wide as the widest overload; fill the
        # trailing unused parameters with 0.0.
        while len(call_args) + 1 < len(callee.args):
            call_args.append(f64_ty(0))
        return builder.call(callee, [count, *call_args])

    if name in vars_envs:
        return vars_envs.get(name)
    print('cannot find', name)
    return None
@dataclass
class FuncData:
    """One definition (one overload) of a user function, parsed from an `=` node."""

    name: str
    args: list[ArgumentData]
    body: dict
    # Derived from `args` in __post_init__; not constructor parameters.
    arg_count: int = field(default=0, init=False)
    const_arg_count: int = field(default=0, init=False)
    var_arg_count: int = field(default=0, init=False)

    @staticmethod
    def from_ast(ast: dict):
        """Build a FuncData from an ``{'=': [prototype, body]}`` AST node.

        The prototype is ``{func_name: [{arg: ...}, ...]}``. Each argument
        that is a numeric literal becomes a CONST parameter and anything else
        becomes a VAR parameter.
        """
        assert '=' in ast, f"{ast} is not a def"
        prototype, body = ast['=']
        func_name = get_dict_key(prototype)[0]
        arg_names = [get_dict_key(node)[0] for node in prototype[func_name]]
        args = [
            ArgumentData(
                arg, ArgumentType.CONST if is_number_str(arg) else ArgumentType.VAR
            )
            for arg in arg_names
        ]
        return FuncData(func_name, args, body)

    def __post_init__(self):
        self.arg_count = len(self.args)
        self.const_arg_count = sum(a.type == ArgumentType.CONST for a in self.args)
        self.var_arg_count = sum(a.type == ArgumentType.VAR for a in self.args)

    def func_type(self):
        """Return the LLVM type of this overload: doubles in, double out."""
        return ir.FunctionType(f64_ty, [f64_ty] * self.arg_count)

    def _generate_ll_name(self):
        # Currently unused; generate_ll_func builds its own name.
        return f"{self.name}-{self.args}"

    @staticmethod
    def expend_ast_func_param(func_name: str, root: dict):
        """Prefix every non-empty call to `func_name` under `root` with its argument count.

        A call ``{func_name: [a, b]}`` becomes ``{func_name: [i64 2, a, b]}``,
        which matches the leading count parameter of the master function.
        `=` (definition) nodes are returned untouched and not descended into.

        Note: nested nodes are replaced inside `root`'s existing lists (in
        place), while the node itself is returned as a new one-key dict.
        """
        name: str = get_dict_key(root)[0]
        params: list = get_dict_value(root)[0]
        if name == '=':
            return root

        for i, child in enumerate(params):
            if isinstance(child, dict):
                params[i] = FuncData.expend_ast_func_param(func_name, child)
        if name == func_name and params:
            params = [i64_ty(len(params)), *params]
        return {name: params}

    def generate_ll_func(self, m: ir.Module, funcs_dict: dict[str, ir.Function]):
        """Emit this overload's LLVM function into `m` on first use and return it.

        VAR parameters are bound to the function's arguments by position.
        `funcs_dict` maps callee names to master functions while the body is
        lowered.

        Raises:
            ValueError: if the body does not evaluate to a value. Without this
                check, a value-less `ret` would be emitted in a double-returning
                function, and LLVM would reject the module much later.
        """
        if not hasattr(self, '_current_func'):
            current_func = ir.Function(m, self.func_type(), f"{self.name}{self.args}")
            var_envs = {
                arg.name: current_func.args[i]
                for i, arg in enumerate(self.args)
                if arg.type == ArgumentType.VAR
            }
            current_builder = ir.IRBuilder(current_func.append_basic_block("entry"))

            res = eval_line(current_builder, self.body, var_envs, funcs_dict)
            if res is None:
                raise ValueError(
                    f"body of {self.name}{self.args} did not produce a value"
                )
            current_builder.ret(res)
            self._current_func = current_func

        return self._current_func
@dataclass
class OverloadableFunc:
    """All overloads sharing one function name, plus their dispatching "master" function.

    The master function takes an i64 argument count followed by as many
    doubles as the widest overload. It switches on the count and tries the
    overloads of that arity, those with the most CONST parameters first. It
    calls the first overload whose literal parameters all equal the
    corresponding arguments, and returns -1.0 when none matches.
    """

    name: str
    funcs: list[FuncData]
    arg_counts: list[int] = field(init=False)  # arity of each entry of `funcs`

    def __post_init__(self):
        self.arg_counts = [fd.arg_count for fd in self.funcs]

    def add_func(self, f: FuncData):
        """Register another overload of this name."""
        self.funcs.append(f)
        self.arg_counts = [fd.arg_count for fd in self.funcs]

    def master_func_ty(self):
        """Return the master's type: (i64 count, double x widest arity) -> double."""
        return ir.FunctionType(f64_ty, [i64_ty] + [f64_ty] * max(self.arg_counts))

    def get_master_func_declear(self, module=None):
        """Declare the master function on first call and return it.

        `module` defaults to the module-level `m`. Once declared, the same
        function object is returned on every later call.
        """
        if not hasattr(self, "master_func_declear"):
            target = m if module is None else module
            self.master_func_declear = ir.Function(
                target, self.master_func_ty(), f"{self.name}_master"
            )
        return self.master_func_declear

    @staticmethod
    def _emit_pattern_match(builder, fd, master_params):
        """Emit an i1 that is true when every CONST parameter of `fd` equals its argument."""
        cond = bool_ty(1)
        for i, arg in enumerate(fd.args):
            if arg.type != ArgumentType.CONST:
                continue
            param = master_params[i]
            # Use an ordered comparison: the unordered `ueq` is also true when
            # the argument is NaN, so NaN would match every literal pattern.
            is_equal = builder.fcmp_ordered(
                '==',
                param,
                ir.Constant(param.type, float(arg.name)),
                name=f'cmp_{arg.name}',
            )
            cond = builder.and_(cond, is_equal)
        return cond

    def generate_master_func(self, m: ir.Module, all_ll_master_funcs: dict[str, ir.Function]):
        """Emit every overload and fill in the master function's dispatch body.

        Returns the master function.
        """
        # Emit all overload bodies up front (FuncData memoizes them).
        for fd in self.funcs:
            fd.generate_ll_func(m, all_ll_master_funcs)

        master_func = self.get_master_func_declear(m)
        count_var, *master_params = master_func.args

        entry_builder = ir.IRBuilder(master_func.append_basic_block("entry"))
        exit_block = master_func.append_basic_block('exit')
        ir.IRBuilder(exit_block).ret(f64_ty(-1))
        default_block = master_func.append_basic_block(name='switch_default')
        ir.IRBuilder(default_block).branch(exit_block)
        sw = entry_builder.switch(count_var, default_block)

        for count in set(self.arg_counts):
            case_block = master_func.append_basic_block(name=self.name + f"_arg_{count}")
            case_builder = ir.IRBuilder(case_block)
            sw.add_case(count, case_block)

            # The most specific overloads (most literal parameters) are tried
            # first. The sort is stable, so ties keep definition order.
            same_arity = sorted(
                (fd for fd in self.funcs if fd.arg_count == count),
                key=lambda fd: fd.const_arg_count,
                reverse=True,
            )
            for fd in same_arity:
                cond = self._emit_pattern_match(case_builder, fd, master_params)
                with case_builder.if_then(cond):
                    # Forward arguments positionally; literal parameters get
                    # the literal value they were just matched against.
                    call_args = [
                        f64_ty(float(arg.name))
                        if arg.type == ArgumentType.CONST
                        else master_params[i]
                        for i, arg in enumerate(fd.args)
                    ]
                    res = case_builder.call(
                        fd.generate_ll_func(m, all_ll_master_funcs), call_args
                    )
                    case_builder.ret(res)
            # No overload of this arity matched.
            case_builder.branch(exit_block)

        return master_func
# Compilation pipeline:
# 1. Collect every `=` definition and group the overloads by function name.
#    Each name gets one entry ("master") function. Its first parameter is the
#    argument count, followed by as many doubles as the widest overload takes.
# 2. Declare every master function, and rewrite calls in the AST so they pass
#    the argument count as an extra leading argument.
# 3. Emit the overloads and complete each master function's dispatch body.
# 4. Lower the top-level statements into `main`, then JIT-compile and run it.
ast_root = ast
overloadable_func_dict: dict[str, OverloadableFunc] = {}
all_ll_master_funcs: dict[str, ir.Function] = {}

for item in ast_root['root']:
    if '=' not in item:
        continue
    func_data = FuncData.from_ast(item)
    func_data.body = FuncData.expend_ast_func_param(func_data.name, func_data.body)
    if func_data.name in overloadable_func_dict:
        overloadable_func_dict[func_data.name].add_func(func_data)
    else:
        overloadable_func_dict[func_data.name] = OverloadableFunc(func_data.name, [func_data])

for name, overloads in overloadable_func_dict.items():
    ast_root = FuncData.expend_ast_func_param(name, ast_root)
    all_ll_master_funcs[name] = overloads.get_master_func_declear()

for overloads in overloadable_func_dict.values():
    overloads.generate_master_func(m, all_ll_master_funcs)

# Lower the rewritten tree explicitly, rather than relying on `ast` sharing
# its root list with it.
eval_line(builder, ast_root, {}, all_ll_master_funcs)
builder.ret_void()

# Stamp the host target triple so the printed IR is exactly what gets compiled.
asm = str(m).replace(
    "unknown-unknown-unknown",
    llvm.Target.from_default_triple().triple,
)
print(asm)

llvm_module = llvm.parse_assembly(asm)
# Report malformed IR here with LLVM's diagnostic instead of inside the JIT.
llvm_module.verify()
tm = llvm.Target.from_default_triple().create_target_machine()
with llvm.create_mcjit_compiler(llvm_module, tm) as ee:
    ee.finalize_object()
    fptr = ee.get_function_address("main")
    py_func = CFUNCTYPE(None)(fptr)
    py_func()
|