--------------------------------------------------------------------------------
-- assert-gooder: a replacement for `assert` that, on failure, re-reads the
-- source text of the failing call and explains which values made it fail.

-- Keep a private reference to the builtin `assert`, so the internal sanity
-- checks below never route through the replacement this module returns, even
-- if a caller installs that replacement as the global `assert`.
local assert = assert

local lexer = assert(require((... and select('1', ...):match('.+%.') or '')..'lua_lang'), '[assert-gooder]: Could not load vital library: lua_lang')

--------------------------------------------------------------------------------
-- Constants

-- Standard globals mapped to the type they are expected to have. Such
-- identifiers are only mentioned in the error message when their value has a
-- different type (e.g. when `type` has been overwritten).
local EXPECTED_GLOBAL = {
  -- TODO: Expand with other functions.
  ['type']     = 'function',
  ['tonumber'] = 'function',
  ['math']     = 'table',
}

--- Returns the contents of a quoted string literal token, without quotes.
-- Escape sequences are not interpreted.
local function get_value_of_string (string_str)
  local first_char = string_str:sub(1, 1)
  if first_char == '"' or first_char == '\'' then
    return string_str:sub(2, -2)
  end
  assert(false)
end

-- Tokens whose value is known at parse time, mapped to a function that
-- converts the token text into that value.
local CONSTANT_VALUE_TOKEN = {
  NUMBER = tonumber,
  STRING = get_value_of_string,
  TRUE   = function() return true end,
  FALSE  = function() return false end,
  NIL    = function() return nil end,
}

-- Tokens that can stand on their own as an operand: constants and identifiers.
local VALUE_TOKEN = { IDENTIFIER = true }
for k in pairs(CONSTANT_VALUE_TOKEN) do VALUE_TOKEN[k] = true end

-- Comparison operator tokens; the names come from `lua_lang`.
-- NOTE(review): `LE` sits next to `GT`; confirm whether `lua_lang` names `<`
-- as `LE` or `LT`.
local COMPARE_BINOP = {
  EQ = true, NEQ = true, LEQ = true, GEQ = true, LE = true, GT = true,
}

--- Converts a single operand token into an AST node.
-- Constants become `{ exp = <token name>, value = <value> }`; identifiers
-- become `{ exp = 'LVALUE', <name> }`, whose `.value` is filled in later.
local function get_value_token (token)
  local convert = CONSTANT_VALUE_TOKEN[token.token]
  if convert then
    return { exp = token.token, value = convert(token.text) }
  elseif token.token == 'IDENTIFIER' then
    return { exp = 'LVALUE', token.text }
  end
  assert(false)
end

--------------------------------------------------------------------------------
-- Variable lookup

--- Looks up a variable as seen from the function that called assert.
-- Search order: locals/parameters (`info.locals`), upvalues of `info.func`,
-- then the function's environment (or `_G` where `getfenv` is unavailable).
-- @param var_name variable name
-- @param info     `debug.getinfo` table extended with a `locals` map of
--                 name -> { value, is_parameter, local_index }
-- @return value, scope description ('argument #n', 'local', 'upvalue' or
--         'global'), and - for parameters only - the calling function's name
--         (or '' when anonymous), used to append "to <function>".
local function get_variable (var_name, info)
  assert(type(info) == 'table')
  -- Local
  local var_info = info.locals[var_name]
  if var_info then
    local is_param, local_index = var_info[2], var_info[3]
    local scope = is_param and ('argument #'..local_index) or 'local'
    -- Only arguments are reported "to <function>", mirroring Lua's own
    -- "bad argument #n to 'f'" messages.
    local in_func = is_param and (info.name or '') or nil
    return var_info[1], scope, in_func
  end
  -- Up-value
  local index = 0
  repeat
    index = index + 1
    local name, val = debug.getupvalue(info.func, index)
    if name == var_name then return val, 'upvalue' end
  until not name
  -- Global (`getfenv` only exists in Lua 5.1/LuaJIT).
  local env = getfenv and getfenv(info.func) or _G
  return env[var_name], 'global'
end

-- Indexes `container` with `key`; separate so that it can be `pcall`ed.
local function index_value (container, key)
  return container[key]
end

--- Resolves an LVALUE node (`name`, `name.key` or `name[key]`) to its value.
-- Key nodes must already carry a `.value`. If an intermediate value cannot be
-- indexed (e.g. it is nil), the resolved value is nil instead of an error.
-- @return value, scope description and function-name marker (see get_variable)
local function get_value_from_lvalue (lvalue, info)
  assert(type(lvalue) == 'table')
  assert(type(info) == 'table')
  -- Base value
  local value, var_scope, in_func = get_variable(lvalue[1], info)
  -- Sub value
  for i = 2, #lvalue do
    local ok, sub_value = pcall(index_value, value, lvalue[i].value)
    if not ok then return nil, var_scope, in_func end
    value = sub_value
  end
  return value, var_scope, in_func
end

--------------------------------------------------------------------------------
-- Parsing

--- Parses the assert body's tokens into a small AST.
-- Only a handful of shapes are recognised:
--   type(x) == 'str'   -> COMPARE(EQ, CALL(type, x), STRING)
--   a <cmp> b          -> COMPARE(cmp, a, b)
--   #a <cmp> b         -> COMPARE(cmp, UNOP(a), b)
--   a.key              -> LVALUE(a, STRING key)
--   a[key]             -> LVALUE(a, key)
--   a                  -> LVALUE or constant node
-- Anything else writes a notice to stderr and returns nil.
local function parse (tokens)
  -- TODO: Make a more general parser
  assert(type(tokens) == 'table')
  local n = #tokens
  if n == 6 and tokens[1].text == 'type' and tokens[2].token == 'LPAR'
      and tokens[3].token == 'IDENTIFIER' and tokens[4].token == 'RPAR'
      and tokens[5].token == 'EQ' and tokens[6].token == 'STRING' then
    return {
      exp = 'COMPARE',
      binop = 'EQ',
      [1] = { exp = 'CALL', get_value_token(tokens[1]), get_value_token(tokens[3]) },
      [2] = get_value_token(tokens[6]),
    }
  elseif n == 3 and VALUE_TOKEN[tokens[1].token]
      and COMPARE_BINOP[tokens[2].token] and VALUE_TOKEN[tokens[3].token] then
    return {
      exp = 'COMPARE',
      binop = tokens[2].token,
      [1] = get_value_token(tokens[1]),
      [2] = get_value_token(tokens[3]),
    }
  elseif n == 4 and tokens[1].token == 'HASHTAG' and VALUE_TOKEN[tokens[2].token]
      and COMPARE_BINOP[tokens[3].token] and VALUE_TOKEN[tokens[4].token] then
    return {
      exp = 'COMPARE',
      binop = tokens[3].token,
      [1] = { exp = 'UNOP', get_value_token(tokens[2]) },
      [2] = get_value_token(tokens[4]),
    }
  elseif n == 3 and tokens[1].token == 'IDENTIFIER' and tokens[2].token == 'DOT'
      and tokens[3].token == 'IDENTIFIER' then
    return { exp = 'LVALUE', tokens[1].text, { exp = 'STRING', value = tokens[3].text } }
  elseif n == 4 and tokens[1].token == 'IDENTIFIER' and tokens[2].token == 'LBRACK'
      and VALUE_TOKEN[tokens[3].token] and tokens[4].token == 'RBRACK' then
    return { exp = 'LVALUE', tokens[1].text, get_value_token(tokens[3]) }
  elseif n == 1 then
    return get_value_token(tokens[1])
  else
    io.stderr:write '[assert-gooder/internal]: Unknown AST structure!\n'
  end
end

--- Calls `func` on every table node of `ast`, children before parents.
-- Only the array part of each node is traversed. Returns `func(ast)`.
local function for_each_node_in_ast (ast, func)
  assert(type(ast) == 'table')
  assert(type(func) == 'function')
  for _, node in ipairs(ast) do
    if type(node) == 'table' then
      for_each_node_in_ast(node, func)
    end
  end
  return func(ast)
end

--- Stores the runtime value of every LVALUE node in its `.value` field.
-- Children are resolved first, so `t[k]` can use the value of `k`.
local function populate_ast_with_semantics (ast, info)
  assert(type(ast) == 'table')
  assert(type(info) == 'table')
  return for_each_node_in_ast(ast, function(node)
    if node.exp == 'LVALUE' then
      node.value = get_value_from_lvalue(node, info)
    end
  end)
end

--------------------------------------------------------------------------------
-- Source retrieval

--- Returns the full source text of the chunk described by `call_info`.
local function read_chunk_text (call_info)
  if call_info.source:find '^@' then
    -- `short_src` may be truncated for long paths; `source` holds the full
    -- file name after the '@'.
    local filehandle = assert(io.open(call_info.source:sub(2), 'r'))
    local filetext = filehandle:read '*all'
    filehandle:close()
    return filetext
  elseif call_info.short_src:find '^%[string' then
    return call_info.source
  end
  error 'Not implemented yet!'
end

--- Returns the text between the parentheses of the failing assert call.
-- The call must start on `call_info.currentline`; its arguments may continue
-- onto the following lines.
local function get_assert_body_text (call_info)
  if call_info.what ~= 'Lua' and call_info.what ~= 'main' then
    error 'Not implemented yet!'
  end
  -- Normalize CRLF / CR line endings so that line numbers match the ones Lua
  -- reports in `currentline`.
  local filetext = (read_chunk_text(call_info):gsub('\r\n?', '\n')) .. '\n'
  -- Get the lines starting at the line of the call.
  local lines_after, line_i = {}, 0
  for line in filetext:gmatch '([^\n]*)\n' do
    line_i = line_i + 1
    if line_i >= call_info.currentline then
      lines_after[#lines_after+1] = line
    end
  end
  -- Find body exclusively.
  local call_start = (lines_after[1] or ''):find 'assert%s*%('
  if not call_start then
    error 'Could not locate the assert call in the source'
  end
  local body = table.concat(lines_after, '\n'):match('assert%s*(%b())', call_start)
  if not body then
    error 'Could not locate the assert call in the source'
  end
  return body:sub(2, -2)
end

--------------------------------------------------------------------------------
-- Formatting

--- Describes the function that called assert: its quoted name if known,
-- otherwise where the anonymous function was defined.
local function get_function_name (call_info)
  if call_info.name then
    return string.format('\'%s\'', call_info.name)
  end
  local where = nil
  if call_info.source:find '^@' then
    where = 'at '..call_info.short_src..':'..call_info.linedefined
  elseif call_info.short_src:find '^%[string' then
    where = 'from loaded string'
  else
    error 'not yet implemented'
  end
  return string.format('the anonymous function %s', where)
end

--- Formats a value; strings are quoted with `%q`, others use `tostring`.
local function fmt_val (val)
  if type(val) == 'string' then
    return string.format('%q', val)
  else
    return tostring(val)
  end
end

--- Describes an LVALUE node, e.g. `local 'x'` or `key "k" in global 't'`.
-- Without `var_scope` the bare variable name is used.
local function fmt_lvalue (lvalue, var_scope)
  assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE')
  local base_var = tostring(lvalue[1])
  if var_scope then
    base_var = ('%s \'%s\''):format(var_scope, lvalue[1])
  end
  if #lvalue == 1 then
    return base_var
  elseif #lvalue == 2 then
    return string.format('key %s in %s', fmt_val(lvalue[2].value), base_var)
  else
    error 'Not implemented yet!'
  end
end

--- Resolves an LVALUE and builds the "bad <what> [to <function>]" prefix.
-- @return value, prefix string
local function get_variable_and_prefix (lvalue, call_info)
  assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE')
  local value, var_scope, in_func = get_value_from_lvalue(lvalue, call_info)
  local func_name = in_func and (' to '..get_function_name(call_info)) or ''
  return value, ('bad %s%s'):format(fmt_lvalue(lvalue, var_scope), func_name)
end

local PRIMITIVE_VALUES = {
  ['nil'] = true,
  ['boolean'] = true,
}

local COMPLEX_TYPES = {
  ['table'] = true,
  ['userdata'] = true,
  ['cdata'] = true,
  ['function'] = true,
}

--- Formats a value together with its type where the type is not obvious.
local function fmt_val_with_type (val)
  -- Primitive values ARE their type, and don't need the annotation.
  if PRIMITIVE_VALUES[type(val)] then return tostring(val) end
  -- Complex types are already formatted with some type information.
  if COMPLEX_TYPES[type(val)] then return tostring(val) end
  -- Numbers and string should have their types with them.
  return type(val) .. ' ' .. fmt_val(val)
end

--------------------------------------------------------------------------------
-- Message construction

--- Fills `msg[1]` (main message) and optionally `msg[2]` (details) for a
-- failed assert. `msg[1]` is set to a generic message as early as possible,
-- so a later failure (caught by the caller's pcall) still leaves something
-- useful behind.
local function determine_error_message (call_info, msg, condition)
  -- Error checking.
  assert(type(call_info) == 'table')
  assert(type(msg) == 'table' and type(msg[1]) == 'string')
  assert(not condition)

  -- Get assert body.
  local body_text = get_assert_body_text(call_info)

  -- Simplest formatting: no analysis of the assert-body, just report that it
  -- failed, along with its body. `tostring` is needed because Lua 5.1's
  -- `string.format` rejects nil/booleans for `%s`.
  msg[1] = ('expression `%s` evaluated to %s'):format(body_text, tostring(condition))

  -- Lex text.
  local tokens = lexer:lex(body_text)
  assert(type(tokens) == 'table')

  -- Find identifiers and provide simple explanations of their values.
  do
    local l, seen_variables = {}, {}
    for i, token in ipairs(tokens) do
      local variable_name = token.text
      -- Skip field/method names (`a.b`, `a:b`): they are not variables.
      local is_variable = token.token == 'IDENTIFIER'
        and (i == 1 or tokens[i-1].token ~= 'COLON' and tokens[i-1].token ~= 'DOT')
      if is_variable and not seen_variables[variable_name] then
        seen_variables[variable_name] = true
        local value = get_variable(variable_name, call_info)
        if EXPECTED_GLOBAL[variable_name] == nil then
          l[#l+1] = ('%s was %s'):format(token.text, fmt_val_with_type(value))
        elseif EXPECTED_GLOBAL[variable_name] ~= type(value) then
          l[#l+1] = ('standard-%s %s was %s'):format(EXPECTED_GLOBAL[variable_name], token.text, fmt_val_with_type(value))
        end
      end
    end
    msg[2] = #l > 0 and table.concat(l, ', ') or nil
  end

  local ast = parse(tokens)
  if not ast then return end
  assert(type(ast) == 'table')
  populate_ast_with_semantics(ast, call_info)

  -- Alternative more detailed formatting: identical to the last message, but
  -- now with values of each involved lvalue.
  do
    local l = {}
    for_each_node_in_ast(ast, function(node)
      if node.exp == 'LVALUE' then
        l[#l+1] = fmt_lvalue(node) .. ' was ' .. fmt_val_with_type(node.value)
      end
    end)
    msg[2] = #l > 0 and table.concat(l, ', ') or nil
  end

  -- More specific messages for recognised shapes.
  local function var_prefix (lvalue)
    return get_variable_and_prefix(lvalue, call_info)
  end
  local lhs = ast[1]
  if ast.exp == 'COMPARE' and ast.binop == 'EQ' and lhs.exp == 'CALL'
      and lhs[1].exp == 'LVALUE' and lhs[1][1] == 'type' then
    local gotten_val, prefix = var_prefix(lhs[2])
    msg[1], msg[2] = prefix, ('%s expected, but got %s'):format(ast[2].value, fmt_val_with_type(gotten_val))
  elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' and lhs.exp == 'LVALUE' then
    local gotten_value, prefix = var_prefix(lhs)
    local expected_value = ast[2].value
    local fmt_gotten = (type(expected_value) == type(gotten_value)) and fmt_val or fmt_val_with_type
    msg[1], msg[2] = prefix, ('%s expected, but got %s'):format(fmt_val_with_type(expected_value), fmt_gotten(gotten_value))
  elseif ast.exp == 'COMPARE' and ast.binop == 'NEQ' and lhs.exp == 'LVALUE' then
    local gotten_val, prefix = var_prefix(lhs)
    local expected_value = ast[2].value
    msg[1], msg[2] = prefix, ('expected anything other than %s, but got %s'):format(fmt_val_with_type(expected_value), fmt_val(gotten_val))
  elseif ast.exp == 'LVALUE' then
    local gotten_val, prefix = var_prefix(ast)
    msg[1], msg[2] = prefix, ('truthy expected, but got %s'):format(fmt_val(gotten_val))
  elseif CONSTANT_VALUE_TOKEN[ast.exp] then
    local func_name = get_function_name(call_info)
    msg[1] = ('this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name)
  elseif ast.exp ~= 'COMPARE' then
    error(('[assert-gooder/internal]: Unknown expression type %s'):format(tostring(ast.exp)))
  end
  -- Other comparisons (ordering operators, non-variable left-hand sides)
  -- keep the generic message built above.
end

--------------------------------------------------------------------------------

--- Drop-in replacement for `assert`.
-- On success returns all of its arguments, like the builtin. On failure with
-- an explicit message, raises that message; otherwise raises a message
-- derived from the source text of the failing call, at the caller's level.
return function (condition, ...)
  if condition then return condition, ... end

  -- An explicit message takes precedence, as with the builtin `assert`.
  local message = ...
  if message ~= nil then error(message, 2) end

  -- Collect information about the calling function and its locals.
  local level = 2
  local call_info = debug.getinfo(level)
  call_info.locals = {}
  -- `nparams` is not provided by `debug.getinfo` in plain Lua 5.1.
  local nparams = call_info.nparams or 0
  for i = 1, math.huge do
    local name, value = debug.getlocal(level, i)
    if not name then break end
    call_info.locals[name] = { value, i <= nparams, i }
  end

  local msg_container = {''}
  local success, internal_error_msg = pcall(determine_error_message, call_info, msg_container, condition)

  -- Handle internal errors
  if not success then
    io.stderr:write(('[assert-gooder/internal]: Internal error occured while determining error message for calling assert:\n %s\n'):format(tostring(internal_error_msg)))
  end

  assert(#msg_container <= 2 and type(msg_container[1]) == 'string')
  local l = {'assertion failed! ', msg_container[1]}
  if msg_container[2] then
    assert(type(msg_container[2]) == 'string')
    l[3] = ' ('
    l[4] = msg_container[2]
    l[5] = ')'
  end
  error(table.concat(l, ''), 2)
end