import logging
import traceback

import pytest

from infernal_interpreter.Emulator import CodeParseException, Emulator
from infernal_interpreter.Junk import JunkComparisonException

logger = logging.getLogger(__name__)

################################################################################

# Each entry is (name, expected result, register to inspect, assembly code).
# The expected result is either the value the register must hold once the
# program has run, or an exception class that running the program must raise.
code_tests = []


def add_test(name, result, register, code):
    """Register one emulator test case.

    Args:
        name: Human-readable test name; also used as the pytest test id.
        result: Expected final value of ``register``, or an exception class
            that executing ``code`` is expected to raise.
        register: Register read after execution. Not inspected when
            ``result`` is an exception class.
        code: AT&T-syntax assembly source passed to ``Emulator``.
    """
    code_tests.append((name, result, register, code))


################################################################################
# Arithmetic Operations

add_test(
    'constant $255',
    255,
    '%rsi',
    """
    movq $255, %rsi
""",
)

add_test(
    'static addition 10+$20',
    30,
    '%rsi',
    """
    movq $10, %rsi
    addq $20, %rsi
""",
)

add_test(
    'register addition 10+20',
    30,
    '%rsi',
    """
    movq $10, %rsi
    movq $20, %rax
    addq %rax, %rsi
""",
)

add_test(
    'register subtraction 10-20',
    -10,
    '%rsi',
    """
    movq $10, %rsi
    movq $20, %rax
    subq %rax, %rsi
""",
)

################################################################################
# Branching

# (jump instruction, a, comparison symbol, b, expected %rsi).
# The program below leaves 0 in %rsi when the jump is taken, 1 otherwise.
branch_tests = [
    ('jg', 10, '>', 5, 0),
    ('jg', 5, '>', 5, 1),
    ('jg', 5, '>', 10, 1),
    ('jl', 10, '<', 5, 1),
    ('jl', 5, '<', 5, 1),
    ('jl', 5, '<', 10, 0),
    ('je', 10, '==', 5, 1),
    ('je', 5, '==', 5, 0),
    ('je', 5, '==', 10, 1),
    ('jge', 10, '>=', 5, 0),
    ('jge', 5, '>=', 5, 0),
    ('jge', 5, '>=', 10, 1),
    ('jle', 10, '<=', 5, 1),
    ('jle', 5, '<=', 5, 0),
    ('jle', 5, '<=', 10, 0),
    ('jne', 10, '!=', 5, 0),
    ('jne', 5, '!=', 5, 1),
    ('jne', 5, '!=', 10, 0),
]

for jump_instruct, a, comp, b, result in branch_tests:
    add_test(
        f'branch {a} {comp} {b}={not result}',
        result,
        '%rsi',
        f"""
start:
    cmpq ${b}, ${a}
    {jump_instruct} true
    movq $1, %rsi
    jmp return
true:
    movq $0, %rsi
return:
    ret
""",
    )

################################################################################
# Junk Comparisons

add_test(
    'invalid comparison 1',
    JunkComparisonException,
    'None',
    """
start:
    movq $100, %rsp     # Set stack pointer to a random position.
    popq %rsi           # Move a Junk value into %rsi
    cmpq $10, %rsi      # Do a Junk comparison, which triggers the `i` comp
                        # virtual register.
    jg start            # Attempt to do a jump to the start, if (Junk-10)>0,
                        # which makes no sense, and thus throws an error.
""",
)

add_test(
    'invalid addition 1',
    JunkComparisonException,
    'None',
    """
start:
    movq $100, %rsp     # Set stack pointer to a random position.
    popq %rsi           # Move a Junk value into %rsi
    addq %rsi, %rsp     # Adds Junk to 101, which produces Junk.
    cmpq $10, %rsp      # Do a Junk comparison, which triggers the `i` comp
                        # virtual register.
    jg start            # Attempt to do a jump to the start, if (Junk-10)>0,
                        # which makes no sense, and thus throws an error.
""",
)

add_test(
    'invalid addition 2',
    JunkComparisonException,
    'None',
    """
start:
    movq $100, %rsp     # Set stack pointer to a random position.
    popq %rsi           # Move a Junk value into %rsi
    addq %rsp, %rsi     # Adds 101 to Junk, which produces Junk.
    cmpq $10, %rsi      # Do a Junk comparison, which triggers the `i` comp
                        # virtual register.
    jg start            # Attempt to do a jump to the start, if (Junk-10)>0,
                        # which makes no sense, and thus throws an error.
""",
)

# NOTE(review): the inline assembly comments in the two subtraction cases
# below say "Adds", but the instruction is `subq`; the Junk operand taints
# the result either way, which is what these cases exercise.
add_test(
    'invalid subtraction 1',
    JunkComparisonException,
    'None',
    """
start:
    movq $100, %rsp     # Set stack pointer to a random position.
    popq %rsi           # Move a Junk value into %rsi
    subq %rsi, %rsp     # Adds Junk to 101, which produces Junk.
    cmpq $10, %rsp      # Do a Junk comparison, which triggers the `i` comp
                        # virtual register.
    jg start            # Attempt to do a jump to the start, if (Junk-10)>0,
                        # which makes no sense, and thus throws an error.
""",
)

add_test(
    'invalid subtraction 2',
    JunkComparisonException,
    'None',
    """
start:
    movq $100, %rsp     # Set stack pointer to a random position.
    popq %rsi           # Move a Junk value into %rsi
    subq %rsp, %rsi     # Adds 101 to Junk, which produces Junk.
    cmpq $10, %rsi      # Do a Junk comparison, which triggers the `i` comp
                        # virtual register.
    jg start            # Attempt to do a jump to the start, if (Junk-10)>0,
                        # which makes no sense, and thus throws an error.
""",
)

################################################################################


def _expects_exception(result):
    """Return True if ``result`` is an exception class rather than a value."""
    # `result` holds the class itself (e.g. JunkComparisonException), so an
    # isinstance() check against BaseException would always be False.
    return isinstance(result, type) and issubclass(result, BaseException)


def _run_to_completion(emu):
    """Step ``emu`` by iterating it until the iterator is exhausted."""
    for _ in emu:
        pass


@pytest.mark.parametrize(
    'name,result,register,code',
    code_tests,
    ids=[test[0] for test in code_tests],
)
def test_execution(name, result, register, code):
    """Run one program and check its final register value or raised error.

    Value cases fail if ``register`` does not end up equal to ``result``.
    Exception cases fail unless running the program raises ``result``.
    """
    try:
        emu = Emulator(code)
        emu.setStack('junk...', 'calling eip')
        emu.setRegs(rip=0, rbp='old bp')
    except CodeParseException as e:
        logger.exception(
            'Encountered error when parsing %s, at line %s: %s',
            name,
            e.line_nr,
            e.str,
        )
        raise

    if _expects_exception(result):
        # Fails if nothing is raised; an unrelated exception propagates.
        with pytest.raises(result):
            _run_to_completion(emu)
        return

    try:
        _run_to_completion(emu)
    except Exception:
        logger.exception(
            'Encountered error in %s, at operation %s',
            name,
            emu.getVal('%rip'),
        )
        traceback.print_exc()
        raise

    output = emu.getVal(register)
    assert output == result, (
        f'Failed in {name}. {register} was {output}, should be {result}'
    )