local lexer = require 'lua_lang'

--------------------------------------------------------------------------------
-- Parsing of the assert body

--- Returns the contents of a quoted string-literal token, without its quotes.
-- Escape sequences inside the literal are not decoded. Token text that does
-- not start with `"` or `'` trips an internal assertion.
-- @tparam string string_str raw token text, quotes included
-- @treturn string the characters between the quotes
local function get_value_of_string (string_str)
	if string_str:sub(1, 1) == '"' or string_str:sub(1, 1) == '\'' then
		return string_str:sub(2, -2)
	end
	assert(false)
end

-- Token kinds whose value is known at parse time, mapped to a converter from
-- the token text to the Lua value it denotes.
local CONSTANT_VALUE_TOKEN = {
	NUMBER = tonumber,
	STRING = get_value_of_string,
	TRUE   = function() return true end,
	FALSE  = function() return false end,
	NIL    = function() return nil end,
}

-- Token kinds that can stand alone as an operand: every constant kind plus
-- identifiers.
local VALUE_TOKEN = { IDENTIFIER = true }
for k in pairs(CONSTANT_VALUE_TOKEN) do VALUE_TOKEN[k] = true end

-- Comparison operator token kinds accepted between two operands. The names
-- are whatever the `lua_lang` lexer emits (presumably LE/GT stand for `<`/`>`;
-- TODO confirm against the lexer).
local COMPARE_BINOP = {
	EQ = true, NEQ = true, LEQ = true, GEQ = true, LE = true, GT = true,
}

--- Converts a single operand token into an AST node.
-- Constants become `{ exp = <token kind>, value = <lua value> }`; identifiers
-- become `{ exp = 'LVALUE', <name> }`. Any other token trips an internal
-- assertion, so callers must check `VALUE_TOKEN` first.
-- @tparam table token lexer token with `token` (kind) and `text` fields
-- @treturn table the AST node
local function get_value_token (token)
	if CONSTANT_VALUE_TOKEN[token.token] then
		return { exp = token.token, value = CONSTANT_VALUE_TOKEN[token.token](token.text) }
	elseif token.token == 'IDENTIFIER' then
		return { exp = 'LVALUE', token.text }
	end
	assert(false)
end

--- Parses the tokens of an assert body into a small AST.
-- Recognised shapes:
--  * `type(IDENT) == STRING` -> COMPARE node whose `left` is a CALL node;
--  * `VALUE <cmp> VALUE`     -> COMPARE node, `binop` is the operator kind;
--  * `VALUE`                 -> the operand node itself.
-- @tparam table tokens list of lexer tokens
-- @return the AST node, or nil when the body has an unsupported shape
local function parse (tokens)
	-- TODO: Make a more general parser
	assert(type(tokens) == 'table')
	--
	if #tokens == 6
		and tokens[1].text == 'type'
		and tokens[2].token == 'LPAR'
		and tokens[3].token == 'IDENTIFIER'
		and tokens[4].token == 'RPAR'
		and tokens[5].token == 'EQ'
		and tokens[6].token == 'STRING' then
		return {
			exp = 'COMPARE',
			binop = 'EQ',
			left = { exp = 'CALL', callee = get_value_token(tokens[1]), get_value_token(tokens[3]) },
			right = get_value_token(tokens[6]),
		}
	elseif #tokens == 3
		and VALUE_TOKEN[tokens[1].token]
		and COMPARE_BINOP[tokens[2].token]
		and VALUE_TOKEN[tokens[3].token] then
		return {
			exp = 'COMPARE',
			binop = tokens[2].token,
			left = get_value_token(tokens[1]),
			right = get_value_token(tokens[3]),
		}
	elseif #tokens == 1 and VALUE_TOKEN[tokens[1].token] then
		return get_value_token(tokens[1])
	end
	-- Unsupported shape: the caller falls back to a generic message.
	return nil
end

--------------------------------------------------------------------------------
-- Locating the assert body in the caller's source

--- Returns the text between the parentheses of the failing `assert(...)` call.
-- The source is read from disk for `@file` chunks, or taken from `source`
-- itself for chunks loaded from a string. Scanning starts at the caller's
-- current line and may continue onto following lines, so a call whose
-- arguments span several lines is still found; the `assert` keyword itself
-- must be on the current line.
-- @tparam table call_info `debug.getinfo` result for the asserting function
-- @treturn ?string the body text, or nil when it cannot be determined
local function get_assert_body_text (call_info)
	-- Only Lua functions and main chunks carry inspectable source text.
	if call_info.what ~= 'Lua' and call_info.what ~= 'main' then
		return nil
	end
	if not call_info.currentline or call_info.currentline < 1 then
		return nil
	end

	-- Find filetext
	local filetext
	if call_info.source:find '^@' then
		-- `short_src` may be truncated for long paths; `source` minus its
		-- leading '@' is the full path.
		local f = io.open(call_info.source:sub(2), 'r')
		if not f then
			return nil
		end
		filetext = f:read '*a'
		f:close()
	elseif call_info.short_src:find '^%[string' then
		filetext = call_info.source
	else
		return nil
	end
	if type(filetext) ~= 'string' then
		return nil
	end

	-- Normalize CRLF / CR endings so line numbers agree with the debug info,
	-- then collect the current line and everything after it.
	filetext = (filetext:gsub('\r\n?', '\n')) .. '\n'
	local lines_after, line_i = {}, 0
	for line in filetext:gmatch '([^\n]*)\n' do
		line_i = line_i + 1
		if line_i >= call_info.currentline then
			lines_after[#lines_after + 1] = line
		end
	end
	if #lines_after == 0 then
		return nil
	end

	-- Find body exclusively. The match must begin on the current line, so an
	-- unrelated `assert` further down the file is never picked up.
	local start, _, body = table.concat(lines_after, '\n'):find 'assert%s*(%b())'
	if not start or start > #lines_after[1] then
		return nil
	end
	return body:sub(2, -2)
end

--- Lexes the failing assert's body.
-- @tparam table call_info `debug.getinfo` result for the asserting function
-- @return the token list and the raw body text, or nil, nil when the body
--   could not be located
local function get_assert_body (call_info)
	local text = get_assert_body_text(call_info)
	if text == nil then
		return nil, nil
	end
	return lexer:lex(text), text
end

--------------------------------------------------------------------------------
-- Inspecting the failing function's variables

--- Looks up the value a name resolves to inside a running function.
-- Resolution order: active locals (the innermost one wins when shadowed),
-- then upvalues, then the function's environment.
-- @tparam string var_name variable name
-- @tparam number level stack level of the inspected function, as seen by
--   the caller of `get_variable`
-- @return the value
-- @treturn string scope description: 'argument #N', 'local', 'upvalue' or 'global'
-- @treturn boolean true only when the variable is a parameter of the function
local function get_variable (var_name, level)
	local info = debug.getinfo(level + 1)

	-- Local: keep the LAST match, since later indices are inner scopes.
	local found, found_value, found_index = false, nil, nil
	local index = 1
	while true do
		local name, val = debug.getlocal(level + 1, index)
		if name == nil then break end
		if name == var_name then
			found, found_value, found_index = true, val, index
		end
		index = index + 1
	end
	if found then
		-- `nparams` is reported by Lua 5.2+ and LuaJIT only; on stock Lua 5.1
		-- every match is reported as a plain local.
		local is_par = info.nparams ~= nil and found_index <= info.nparams
		return found_value, is_par and ('argument #' .. found_index) or 'local', is_par
	end

	-- Up-value (remembering `_ENV` for the 5.2+ global lookup below)
	local env = nil
	index = 1
	while true do
		local name, val = debug.getupvalue(info.func, index)
		if name == nil then break end
		if name == var_name then
			return val, 'upvalue', false
		end
		if name == '_ENV' then
			env = val
		end
		index = index + 1
	end

	-- Global
	if getfenv then
		return getfenv(level + 1)[var_name], 'global', false
	end
	return (env or _G)[var_name], 'global', false
end

--- Describes a function for use in an error message.
-- @tparam table call_info `debug.getinfo` result for the function
-- @treturn string e.g. `'foo'`, or a description of an anonymous function
local function get_function_name (call_info)
	--
	if call_info.name then
		return string.format('\'%s\'', call_info.name)
	end
	if call_info.what == 'main' then
		return string.format('the main chunk of %s', call_info.short_src)
	end
	--
	local where
	if call_info.source:find '^@' then
		where = 'at ' .. call_info.short_src .. ':' .. call_info.linedefined
	elseif call_info.short_src:find '^%[string' then
		where = 'from loaded string'
	else
		-- e.g. chunks named with a leading '=' such as "=stdin".
		where = 'from ' .. call_info.short_src
	end
	--
	return string.format('the anonymous function %s', where)
end

--- Formats a value for display: strings are quoted with `%q`, anything else
-- goes through `tostring`.
local function fmt_val (val)
	if type(val) == 'string' then
		return string.format('%q', val)
	end
	return tostring(val)
end

--- Renders an LVALUE node as dotted source text.
local function fmt_lvalue (lvalue)
	assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE')
	return table.concat(lvalue, '.')
end

--- Resolves an LVALUE in the failing function and builds the message prefix.
-- @tparam table lvalue LVALUE AST node
-- @tparam number level stack level of the failing function, as seen by the
--   caller of this function
-- @return the variable's value
-- @treturn string e.g. "assertion failed! bad argument #1 'x' to 'f'"
local function get_variable_and_prefix (lvalue, level)
	assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE')
	assert(type(level) == 'number')
	--
	local base_value, var_scope, is_argument = get_variable(lvalue[1], level + 1)
	local value = base_value -- TODO: Generalize to any lvalue
	-- Only arguments get the " to <function>" suffix, mirroring Lua's own
	-- "bad argument #1 to 'f'" messages.
	local func_name = is_argument and (' to ' .. get_function_name(debug.getinfo(level + 1))) or ''
	return value, ('assertion failed! bad %s \'%s\'%s'):format(var_scope, fmt_lvalue(lvalue), func_name)
end

-- Types whose `tostring` already names the type (nil, true, false).
local PRIMITIVE_VALUES = {
	['nil'] = true,
	['boolean'] = true,
}

-- Types whose `tostring` output already carries type information.
local COMPLEX_TYPES = {
	['table'] = true,
	['userdata'] = true,
	['cdata'] = true,
}

--- Formats a value together with its type where that adds information.
local function fmt_with_type (val)
	-- Primitive values ARE their type, and don't need the annotation.
	if PRIMITIVE_VALUES[type(val)] then return tostring(val) end
	-- Complex types are already formatted with some type information.
	if COMPLEX_TYPES[type(val)] then return tostring(val) end
	-- Numbers and string should have their types with them.
	return type(val) .. ' ' .. fmt_val(val)
end

--- Builds the message used when the assert body cannot be explained further.
local function fmt_generic_failure (body_text, condition)
	-- tostring(): Lua 5.1's string.format rejects nil/boolean for `%s`.
	return ('assertion failed! expression `%s` evaluated to %s'):format(body_text, tostring(condition))
end

--------------------------------------------------------------------------------

--- Drop-in replacement for `assert` with descriptive failure messages.
-- On success every argument is returned unchanged. On failure an explicit
-- message (the second argument) is raised as-is, like the builtin; otherwise
-- the source text of the call is parsed to explain why it failed.
-- NOTE: the debug-level arithmetic below assumes `get_variable_and_prefix`
-- is called directly from this function (level 2 = the asserting function).
return function (condition, ...)
	if condition then
		return condition, ...
	end

	-- Honour an explicit message, e.g. the error string in `assert(io.open(p))`.
	local message = ...
	if message ~= nil then
		error(message, 0)
	end

	local call_info = debug.getinfo(2)
	local tokens, body_text = get_assert_body(call_info)
	if body_text == nil then
		error('assertion failed!', 2)
	end
	local ast = parse(tokens)

	-- Report `CONST == x` / `CONST ~= x` the same way as `x == CONST`.
	if ast ~= nil and ast.exp == 'COMPARE'
		and (ast.binop == 'EQ' or ast.binop == 'NEQ')
		and ast.left.exp ~= 'LVALUE' and ast.right.exp == 'LVALUE' then
		ast.left, ast.right = ast.right, ast.left
	end

	local is_compare = ast ~= nil and ast.exp == 'COMPARE'
	-- Variable compared against a constant: the shape the EQ/NEQ messages explain.
	local lvalue_vs_constant = is_compare
		and ast.left.exp == 'LVALUE' and ast.right.exp ~= 'LVALUE'

	if ast == nil then
		error(fmt_generic_failure(body_text, condition), 2)
	elseif is_compare and ast.binop == 'EQ' and ast.left.exp == 'CALL'
		and ast.left.callee.exp == 'LVALUE' and ast.left.callee[1] == 'type' then
		local gotten_val, prefix = get_variable_and_prefix(ast.left[1], 2)
		error(('%s (%s expected, but got %s: %s)'):format(prefix, ast.right.value, type(gotten_val), fmt_val(gotten_val)), 2)
	elseif lvalue_vs_constant and ast.binop == 'EQ' then
		local gotten_value, prefix = get_variable_and_prefix(ast.left, 2)
		local expected_value = ast.right.value
		local type_annotation = (type(expected_value) == type(gotten_value)) and '' or (' ' .. type(gotten_value))
		error(('%s (%s expected, but got%s: %s)'):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_value)), 2)
	elseif lvalue_vs_constant and ast.binop == 'NEQ' then
		local gotten_val, prefix = get_variable_and_prefix(ast.left, 2)
		local expected_value = ast.right.value
		error(('%s (expected anything other than %s, but got %s)'):format(prefix, fmt_with_type(expected_value), fmt_val(gotten_val)), 2)
	elseif ast.exp == 'LVALUE' then
		local gotten_val, prefix = get_variable_and_prefix(ast, 2)
		error(('%s (truthy expected, but got %s)'):format(prefix, fmt_val(gotten_val)), 2)
	elseif CONSTANT_VALUE_TOKEN[ast.exp] then
		local func_name = get_function_name(call_info)
		error(('assertion failed! this assert will always fail, as its body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name), 2)
	elseif is_compare then
		-- Comparisons we cannot explain yet (ordering operators, two variables,
		-- two constants) still get a proper assertion message.
		error(fmt_generic_failure(body_text, condition), 2)
	else
		error(('[assert-gooder/internal]: Unknown expression type %s'):format(ast.exp))
	end
end