From 0aa427ac022d477e00eddee4d54e04c8d09642fd Mon Sep 17 00:00:00 2001 From: cfreksen Date: Sun, 29 Oct 2017 20:39:58 +0100 Subject: [PATCH] Add function calls. --- ll.py | 16 ++++++++++++ stepper.py | 75 ++++++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 75 insertions(+), 16 deletions(-) diff --git a/ll.py b/ll.py index ffe3077..4e4bd88 100644 --- a/ll.py +++ b/ll.py @@ -60,6 +60,8 @@ Gep = namedtuple('Gep', ['base_ty', 'oper_ty', 'oper', 'steps']) Zext = namedtuple('Zext', ['from_ty', 'oper', 'to_ty']) Ptrtoint = namedtuple('Ptrtoint', ['pointer_ty', 'oper', 'to_ty']) +CallResult = namedtuple('CallResult', ['val']) + Ret = namedtuple('Ret', ['ty', 'oper']) Br = namedtuple('Br', ['label']) Cbr = namedtuple('Cbr', ['ty', 'oper', 'then_label', 'else_label']) @@ -87,12 +89,19 @@ def oper2s(operand): return str(operand.val) elif isinstance(operand, Id): return '%' + operand.val + elif isinstance(operand, Gid): + return '@' + operand.val else: # TODO print('oper2s: Unknown operand: {}' .format(operand)) +def tyopers2s(ty_oper_list): + return ', '.join('{} {}'.format(ty2s(ty), oper2s(oper)) + for ty, oper in ty_oper_list) + + def insn2s(insn): if isinstance(insn, Binop): return ('{} {} {}, {}' @@ -102,6 +111,10 @@ def insn2s(insn): return ('icmp {} {} {}, {}' .format(insn.cnd, ty2s(insn.ty), oper2s(insn.left), oper2s(insn.right))) + elif isinstance(insn, Call): + return ('call {} {} ({})' + .format(ty2s(insn.return_ty), oper2s(insn.callee), + tyopers2s(insn.arguments))) elif isinstance(insn, Bitcast): return ('bitcast {} {} to {}' .format(ty2s(insn.from_ty), oper2s(insn.oper), @@ -114,6 +127,9 @@ def insn2s(insn): return ('ptrtoint {}* {} to {}' .format(ty2s(insn.pointer_ty), oper2s(insn.oper), ty2s(insn.to_ty))) + elif isinstance(insn, CallResult): + return ('<>: function return {}' + .format(insn.val)) else: # TODO print('insn2s: Unknown insn: {}' diff --git a/stepper.py b/stepper.py index 0dc4519..e74f166 100644 --- a/stepper.py +++ b/stepper.py @@ -20,7 +20,8 @@ def warn(msg): def step(insns, terminator, blocks, stack_frames, ssa_env, global_env, memory, tdecs, fdecs, call_res): if len(insns) == 0: - return terminate(terminator, blocks, stack_frames, ssa_env, global_env, memory) + return terminate(terminator, blocks, stack_frames, ssa_env, global_env, memory, + call_res) ssa_target, next_insn = insns[0] insns_rest = insns[1:] @@ -51,6 +52,40 @@ def step(insns, terminator, blocks, stack_frames, ssa_env, global_env, memory, # TODO print('icmp {} {}, {}' .format(cnd, left_v, right_v)) + elif isinstance(next_insn, ll.Call): + callee = next_insn.callee + arguments = next_insn.arguments + + if not isinstance(callee, ll.Gid): + err('Cannot call anything but global identifiers: {}' + .format(ll.oper2s(callee))) + return insns_rest, terminator, blocks, stack_frames, ssa_env, memory, call_res + + arguments_v = [eval_oper(oper, ssa_env, global_env) + for ty, oper in arguments] + + try: + function = fdecs[callee.val] + except KeyError: + err('Could not find function {} in environment:\n{}' + .format(callee.val, fdecs.keys())) + return insns_rest, terminator, blocks, stack_frames, ssa_env, memory, call_res + + parameters = function.parameters + print('call @{} ({})' + .format(callee.val, + ', '.join('%{} <- {}'.format(par[1], arg) + for par, arg in zip(parameters, arguments_v)))) + child_insns = function.body.first_block.insns + child_terminator = function.body.first_block.terminator + child_blocks = function.body.named_blocks + child_stack_frames = [(insns_rest, terminator, blocks, ssa_env)] + stack_frames + child_ssa_env = {par[1]: arg for par, arg in zip(parameters, arguments_v)} + child_memory = memory + child_call_res = [ssa_target] + call_res + return (child_insns, child_terminator, child_blocks, child_stack_frames, + child_ssa_env, child_memory, child_call_res) + elif isinstance(next_insn, ll.Bitcast): oper = next_insn.oper from_ty = next_insn.from_ty @@ -81,6 +116,8 @@ def step(insns, terminator, blocks, stack_frames, ssa_env, global_env, memory, # TODO print('ptrtoint {}* {} to {}' .format(ll.ty2s(pointer_ty), oper_v, ll.ty2s(to_ty))) + elif isinstance(next_insn, ll.CallResult): + res = next_insn.val else: err('Unknown LLVM instruction: {}' .format(next_insn)) @@ -95,10 +132,10 @@ def step(insns, terminator, blocks, stack_frames, ssa_env, global_env, memory, .format(ssa_target, res)) ssa_env[ssa_target] = res - return insns_rest, terminator, blocks, stack_frames, ssa_env, memory, None + return insns_rest, terminator, blocks, stack_frames, ssa_env, memory, call_res -def terminate(terminator, blocks, stack_frames, ssa_env, global_env, memory): +def terminate(terminator, blocks, stack_frames, ssa_env, global_env, memory, call_res): def clear_block_from_ssa_env(insns, ssa_env): for (id, insn) in insns: if id is not None and id in ssa_env: @@ -117,17 +154,21 @@ def terminate(terminator, blocks, stack_frames, ssa_env, global_env, memory): # TODO print('Returning {}' .format(oper_v)) + if len(stack_frames) == 0: - new_insns = [] + new_insns = [(None, ll.CallResult(oper_v))] new_terminator = None new_blocks = {} new_ssa_env = ssa_env new_stack_frames = [] + new_call_res = [] else: new_insns, new_terminator, new_blocks, new_ssa_env = stack_frames[0] + new_insns = [(call_res[0], ll.CallResult(oper_v))] + new_insns new_stack_frames = stack_frames[1:] + new_call_res = call_res[1:] return (new_insns, new_terminator, new_blocks, new_stack_frames, - new_ssa_env, memory, oper_v) + new_ssa_env, memory, new_call_res) elif isinstance(terminator, ll.Br): label = terminator.label next_block = blocks[label] @@ -144,7 +185,7 @@ def terminate(terminator, blocks, stack_frames, ssa_env, global_env, memory): .format(label)) return (new_insns, new_terminator, blocks, stack_frames, - ssa_env, memory, None) + ssa_env, memory, call_res) elif isinstance(terminator, ll.Cbr): ty = terminator.ty if ty != ll.SimpleType.I1: @@ -168,7 +209,7 @@ def terminate(terminator, blocks, stack_frames, ssa_env, global_env, memory): .format(operand_v, label)) return (new_insns, new_terminator, blocks, stack_frames, - ssa_env, memory, None) + ssa_env, memory, call_res) else: err('Unknown LLVM terminator: {}' .format(terminator)) @@ -241,15 +282,17 @@ def gogo(): data = r''' define i64 @tigermain (i64 %U_mainSL_8, i64 %U_mainDummy_9) { %a = add i64 3, 5 ; please be 8 - %c = icmp eq i64 %a, 9 - br i1 %c, label %L1, label %L2 + %c = icmp eq i64 %a, 8 + br i1 %c, label %L1, label %L1 L1: - %b = add i64 %a, %a + %b = call i64 @f (i64 7, i64 %a) ret i1 %c -L2: - %d = add i64 %a, 1 - %e = add i64 10, %d - ret i64 %e +} + +define i64 @f (i64 %x, i64 %y) { + %a = mul i64 2, %y + %b = mul i64 %x, %a + ret i64 %b } ''' @@ -269,7 +312,7 @@ L2: ssa_env = {} # TODO: memory structure has not been decided yet memory = [None] - call_res = None + call_res = [] while True: (insns, terminator, blocks, @@ -282,7 +325,7 @@ L2: print('Stepping done! Final ssa_env:\n{}' .format(ssa_env)) print('Program resulted in {}'. - format(call_res)) + format(insns[0][1].val)) break