--- assert-gooder: a drop-in replacement for `assert` that tries to explain
-- *why* an assertion failed.
--
-- The module returns a single function `(condition, format, ...)`. When the
-- condition is falsy it locates the source text of the failing assert call,
-- lexes and parses its first argument, looks up the values of the identifiers
-- involved (locals, upvalues, globals) and throws an error whose message
-- describes the expected and the actual value.
--
-- Relies on `getfenv` and the `debug` library.

local lexer = assert(require((... and select('1', ...):match('.+%.') or '')..'lua_lang'), '[assert-gooder]: Could not load vital library: lua_lang')
local shunting_yard = assert(require((... and select('1', ...):match('.+%.') or '')..'Parser'), '[assert-gooder]: Could not load vital library: Parser')

--------------------------------------------------------------------------------

-- Globals that are expected to hold a value of a specific type. Identifiers
-- listed here are only mentioned in the message when their type differs.
local EXPECTED_GLOBAL = {
    -- TODO: Expand with other functions.
    ['type'] = 'function',
    ['tonumber'] = 'function',
    ['math'] = 'table',
}

--------------------------------------------------------------------------------

--- Strip the surrounding quotes from the source text of a quoted string.
-- Only single- and double-quoted literals are supported; anything else
-- trips the `assert(false)`.
local function get_value_of_string (string_str)
    if string_str:sub(1, 1) == '"' or string_str:sub(1, 1) == '\'' then
        return string_str:sub(2, -2)
    end
    assert(false)
end

-- Maps token types of literal constants to a function that converts the
-- token text into the corresponding Lua value.
local CONSTANT_VALUE_TOKEN = {
    NUMBER = tonumber,
    STRING = get_value_of_string,
    TRUE = function() return true end,
    FALSE = function() return false end,
    NIL = function() return nil end
}

-- Plain indexing, wrapped in a function so that it can be `pcall`ed.
local function table_indexing (t, k)
    return t[k]
end

--- Index table `t` with `index` without letting an `__index` metamethod
-- raise an error.
-- @param t      table to index (asserted).
-- @param index  key to look up.
-- @return the value found, or nil when absent or when `__index` raised.
local function safe_index (t, index)
    assert(type(t) == 'table')

    -- Attempt rawget first
    local value = rawget(t, index)
    if value ~= nil then return value end

    -- `debug.getmetatable` returns nil for tables without a metatable, so
    -- the metatable must be checked before it is indexed.
    local mt = debug.getmetatable(t)
    if mt and mt.__index then
        -- If weird indexing, use pcall
        local success, indexed_value = pcall(table_indexing, t, index)
        if success then return indexed_value end
        return nil
    else
        -- Then attempt normal indexing, if no weirdness
        return t[index]
    end
end

--- Look up the value of the variable `var_name` as seen from the function
-- described by `info`.
-- Lookup order: locals recorded in `info.locals`, then upvalues of
-- `info.func`, then the function environment of `info.func`.
-- @param var_name  name of the variable (string).
-- @param info      table with `func` (function) and `locals`, where each
--                  entry is `{ value, is_parameter, local_index }`.
-- @return value, scope description ('argument #N', 'local', 'upvalue' or
--         'global') and, for locals only, a third value: `info.name` if set,
--         otherwise '' for parameters and a falsy value for other locals.
local function get_variable (var_name, info)
    -- Assertions
    assert(type(var_name) == 'string')
    assert(type(info) == 'table')
    assert(type(info.func) == 'function')
    assert(type(info.locals) == 'table')

    -- Local
    if info.locals[var_name] then
        local var_info = info.locals[var_name]
        return var_info[1], var_info[2] and ('argument #'..var_info[3]) or 'local', info.name or var_info[2] and ''
    end

    -- Up-value
    local index = 0
    repeat
        index = index + 1
        local name, val = debug.getupvalue(info.func, index)
        if name == var_name then return val, 'upvalue' end
    until not name

    -- Global
    local env = safe_index(getfenv(info.func), var_name)
    return env, 'global'
end

--------------------------------------------------------------------------------
-- Parsing

-- Token types that are not part of expressions. Not referenced by the code
-- visible in this file.
local NO_PARSE_TOKENS = {
    FUNCTION = true,
    SEMICOLON = true,
    BREAK = true,
    DO = true,
    ELSE = true,
    ELSEIF = true,
    END = true,
    FOR = true,
    IF = true,
    IN = true,
    LOCAL = true,
    REPEAT = true,
    RETURN = true,
    THEN = true,
    UNTIL = true,
    WHILE = true,
    ASSIGN = true,

    -- TODO: These below should not be NO-PARSE-TOKENS:
    VARARG = true,
    LBRACE = true,
    RBRACE = true,
}

-- Binary operators, keyed by token type. `associativity` defaults to 'left';
-- the remaining fields are filled in by the loops below.
local LUA_BINOP = {
    ['OR'] = { precedence = 1 },
    ['AND'] = { precedence = 2 },
    ['LE'] = { precedence = 3 },
    ['LEQ'] = { precedence = 3 },
    ['EQ'] = { precedence = 3 },
    ['NEQ'] = { precedence = 3 },
    ['GEQ'] = { precedence = 3 },
    ['GT'] = { precedence = 3 },
    ['CONCAT'] = { precedence = 4, associativity = 'right' },
    ['PLUS'] = { precedence = 5 },
    ['MINUS'] = { precedence = 5 },
    ['TIMES'] = { precedence = 6 },
    ['DIVIDE'] = { precedence = 6 },
    ['MODULO'] = { precedence = 6 },
    ['CARET'] = { precedence = 8, associativity = 'right' },
    ['DOT'] = { precedence = 10 },
    ['COLON'] = { precedence = 10 },
}

-- Prefix (unary) operators, keyed by token type.
local LUA_PREOP = {
    ['MINUS'] = { precedence = 7, val = 'UMINUS' },
    ['NOT'] = { precedence = 7 },
    ['HASHTAG'] = { precedence = 7 },
}

for keyword, binop in pairs (LUA_BINOP) do
    binop.type = 'op'
    binop.arity = 2
    binop.val = binop.val or keyword
    binop.associativity = binop.associativity or 'left'
end

for keyword, unop in pairs (LUA_PREOP) do
    unop.type = 'op'
    unop.arity = 1
    unop.val = unop.val or keyword
    unop.associativity = 'right'
end

-- Operator name (`val`) -> arity.
-- NOTE(review): unary minus is registered under 'UMINUS', but `parse` looks
-- operators up by `token.token`, which stays 'MINUS' — see the TODO in
-- `lua_expression_lang`.
local LUA_OP_ARITY = {}
for _, op in pairs(LUA_BINOP) do LUA_OP_ARITY[op.val] = op.arity end
for _, op in pairs(LUA_PREOP) do LUA_OP_ARITY[op.val] = op.arity end

--- Token classifier handed to the shunting-yard parser.
-- @param token            lexer token; its type is read from `token.token`.
-- @param prev_token_info  classification returned for the previous token,
--                         or nil for the first token.
-- @return a description table (`type` is one of 'op', 'bracket', 'arg-sep',
--         'imm'), or nil when the token type is not recognised.
local function lua_expression_lang (token, prev_token_info)
    local ttype = token.token

    -- A binary operator cannot appear at the very start, directly after an
    -- opening bracket, or directly after another operator.
    local binary_disallowed = prev_token_info == nil
        or prev_token_info.type == 'bracket' and prev_token_info.open == prev_token_info.val
        or prev_token_info.type == 'op'

    if LUA_PREOP[ttype] and binary_disallowed then
        -- TODO: Unary minus?
        local op_info = LUA_PREOP[ttype]
        return { type = 'op', arity = op_info.arity, associativity = op_info.associativity, precedence = op_info.precedence, val = token }
    elseif LUA_BINOP[ttype] and not binary_disallowed then
        local op_info = LUA_BINOP[ttype]
        return { type = 'op', arity = op_info.arity, associativity = op_info.associativity, precedence = op_info.precedence, val = token }
    elseif LUA_BINOP[ttype] or LUA_PREOP[ttype] then
        error('Cannot use the operation '..ttype..' here')
    elseif ttype == 'LBRACK' or ttype == 'RBRACK' then
        return { type = 'bracket', val = ttype, open = 'LBRACK', close = 'RBRACK', precedence = 10, on_end_call_op = { type = 'op', val = '?', call_op = true } }
    elseif ttype == 'LPAR' or ttype == 'RPAR' then
        return { type = 'bracket', val = ttype, open = 'LPAR', close = 'RPAR', precedence = 9, on_end_call_op = { type = 'op', val = '$', call_op = true } }
    elseif ttype == 'COMMA' then
        return { type = 'arg-sep', val = 'COMMA', open = 'LPAR' }
    elseif type (ttype) == 'number' or type (ttype) == 'string' then
        return { type = 'imm', val = token }
    end
end

-- Wraps a number as an immediate for the shunting-yard parser. `parse` pops
-- such numbers from its stack as argument counts for '$' and '?'.
local function NUM_ARITY (num)
    assert(type(num) == 'number')
    return { type = 'imm', val = num }
end

--- Parse a token list into an AST.
-- Operator nodes have the shape `{ exp = 'OP', binop = <token type>, ... }`,
-- calls `{ exp = 'CALL', <callee>, <args>... }`; leaves are the lexer tokens
-- themselves. Every inner node gets `left`/`right` computed from the
-- `left`/`right` fields of its children.
-- NOTE: mutates `tokens`: an identifier directly after a DOT is rewritten
-- into a STRING token so that `a.b` is treated like `a['b']`.
-- @param tokens  sequence of lexer tokens.
-- @return the root node of the AST.
local function parse (tokens)
    assert(type(tokens) == 'table')

    --
    for i = 2, #tokens do
        if tokens[i].token == 'IDENTIFIER' and tokens[i-1].token == 'DOT' then
            tokens[i].token = 'STRING'
            tokens[i].text = '\''..tokens[i].text..'\''
        end
    end

    -- Postfix the tokens
    local postfix_tokens = shunting_yard(tokens, lua_expression_lang, NUM_ARITY)
    assert(type(postfix_tokens) == 'table')

    -- Create AST from postfix tokens
    local ast_stack = {}
    for _, token in ipairs(postfix_tokens) do
        if type(token) == 'table' and LUA_OP_ARITY[token.token] then
            -- Operation
            local node = { exp = 'OP', binop = token.token }
            local start_left, end_right = math.huge, -math.huge
            -- Operands are popped in reverse order.
            for arg_index = LUA_OP_ARITY[token.token], 1, -1 do
                local arg = table.remove(ast_stack)
                assert(type(arg) == 'table')
                start_left, end_right = math.min(start_left, arg.left), math.max(end_right, arg.right)
                node[arg_index] = arg
            end
            --
            node.left, node.right = start_left, end_right
            ast_stack[#ast_stack+1] = node
        elseif token == '$' then
            -- Operation
            local node = { exp = 'CALL' }
            local start_left, end_right = math.huge, -math.huge
            --
            local num_args = table.remove(ast_stack)
            assert(type(num_args) == 'number')
            -- `num_args + 1` entries are popped: the arguments plus the
            -- callee, which ends up in node[1].
            for arg_index = num_args + 1, 1, -1 do
                local arg = table.remove(ast_stack)
                assert(type(arg) == 'table')
                start_left, end_right = math.min(start_left, arg.left), math.max(end_right, arg.right)
                node[arg_index] = arg
            end
            --
            node.left, node.right = start_left, end_right
            ast_stack[#ast_stack+1] = node
        elseif token == '?' then
            -- Indexing
            local node = { exp = 'OP', binop = 'DOT' }
            local start_left, end_right = math.huge, -math.huge
            --
            local num_args = table.remove(ast_stack)
            assert(type(num_args) == 'number' and num_args == 1)
            --
            for arg_index = num_args + 1, 1, -1 do
                local arg = table.remove(ast_stack)
                assert(type(arg) == 'table')
                start_left, end_right = math.min(start_left, arg.left), math.max(end_right, arg.right)
                node[arg_index] = arg
            end
            --
            node.left, node.right = start_left, end_right
            ast_stack[#ast_stack+1] = node
        else
            -- Immediate
            ast_stack[#ast_stack+1] = token
        end
    end

    assert(#ast_stack == 1)
    return ast_stack[1]
end

--- Call `func` on every table node of `ast`, children before parents.
-- Only the sequence part of each node is traversed.
-- @return whatever `func` returns for the root node.
local function for_each_node_in_ast (ast, func)
    assert(type(ast) == 'table')
    assert(type(func) == 'function')
    --
    for _, node in ipairs(ast) do
        if type(node) == 'table' then
            for_each_node_in_ast(node, func)
        end
    end
    --
    return func(ast)
end

-- Evaluators for operator nodes whose operands already carry a `value`.
local CONSTANT_BINOP = {}
local CONSTANT_UNOP = {}

function CONSTANT_BINOP.DOT (node) return node[1].value[ node[2].value ] end -- TODO
function CONSTANT_BINOP.AND (node) return node[1].value and node[2].value end
function CONSTANT_BINOP.OR (node) return node[1].value or node[2].value end
function CONSTANT_BINOP.PLUS (node) return node[1].value + node[2].value end
function CONSTANT_BINOP.MINUS (node) return node[1].value - node[2].value end
function CONSTANT_BINOP.TIMES (node) return node[1].value * node[2].value end
function CONSTANT_BINOP.DIVIDE (node) return node[1].value / node[2].value end
function CONSTANT_BINOP.MODULO (node) return node[1].value % node[2].value end
function CONSTANT_BINOP.CARET (node) return node[1].value ^ node[2].value end
function CONSTANT_BINOP.EQ (node) return node[1].value == node[2].value end
function CONSTANT_BINOP.NEQ (node) return node[1].value ~= node[2].value end
function CONSTANT_BINOP.LEQ (node) return node[1].value <= node[2].value end
function CONSTANT_BINOP.GEQ (node) return node[1].value >= node[2].value end
function CONSTANT_BINOP.LE (node) return node[1].value < node[2].value end
function CONSTANT_BINOP.GT (node) return node[1].value > node[2].value end
function CONSTANT_BINOP.CONCAT (node) return node[1].value .. node[2].value end

function CONSTANT_UNOP.HASHTAG (node) return #node[1].value end

--- Annotate the nodes of `ast` with runtime information.
-- First pass: leaf tokens get their token type moved from `token` to `exp`.
-- Second pass: identifiers receive `value`, `scope` and `function_local`
-- from `get_variable`; literals receive `value` and `is_constant`; operator
-- nodes are evaluated when all their operands have a truthy `value`.
-- @param ast   AST as produced by `parse`.
-- @param info  call info accepted by `get_variable`.
local function populate_ast_with_semantics (ast, info)
    assert(type(ast) == 'table')
    assert(type(info) == 'table')

    for_each_node_in_ast(ast, function(node)
        if node.token then
            assert(not node.ast)
            node.exp, node.token = node.token, nil
        end
    end)

    --print 'Semantics!'
    return for_each_node_in_ast(ast, function(node)
        --print(require'pretty'(node))
        if node.exp == 'IDENTIFIER' then
            node.value, node.scope, node.function_local = get_variable(node.text, info)
        elseif CONSTANT_VALUE_TOKEN[node.exp] then
            node.value = CONSTANT_VALUE_TOKEN[node.exp](node.text)
            node.is_constant = true
        elseif node.exp == 'OP' and CONSTANT_UNOP[node.binop] and node[1].value then
            assert(node[1].value)
            node.value = CONSTANT_UNOP[node.binop](node)
            node.is_constant = node[1].is_constant
        elseif node.exp == 'OP' and CONSTANT_BINOP[node.binop] and node[1].value and node[2] and node[2].value then
            assert(node[1].value and (not node[2] or node[2].value))
            node.value = CONSTANT_BINOP[node.binop](node)
            node.is_constant = node[1].is_constant and (not node[2] or node[2].is_constant)
        end
    end)
end

--------------------------------------------------------------------------------

--- Read the full text of the file at `module_filepath`.
-- Tries `io.open` first and falls back to `love.filesystem.read` when a
-- `love` global with a `filesystem` field is present.
-- @return the file text, or nil when it could not be read.
local function get_module_filetext (module_filepath)
    assert(type(module_filepath) == 'string')

    -- Just attempt standard file open
    local filehandle = io.open(module_filepath, 'r')
    if filehandle then
        local filetext = filehandle:read '*all'
        filehandle:close()
        return filetext
    end

    -- What about LÖVE?
    local filetext = love and love.filesystem and love.filesystem.read(module_filepath) or nil
    if filetext then return filetext end

    -- I give up...
    return nil
end

--- Split `text` at commas that are not enclosed in balanced parentheses.
-- Only parentheses are considered; commas inside string literals or
-- braces are still treated as separators.
-- @return sequence of the sections between the separators.
local function seperate_by_toplevel_commas (text)
    assert(type(text) == 'string')
    local section_start, index, sections = 1, 1, {}
    while index < #text do
        local next_comma = text:find(',', index)
        local next_par_start, next_par_end = text:find('%b()', index)
        if not next_comma then
            break
        elseif not next_par_start or next_comma < next_par_start then
            sections[#sections+1] = text:sub(section_start, next_comma - 1)
            index = next_comma + 1
            section_start = index
        else
            -- Skip the whole parenthesised group.
            index = next_par_end + 1
        end
    end
    sections[#sections+1] = text:sub(section_start)
    return sections
end

--- Find the source text of the first argument of the failing assert call.
-- Only the line `call_info.currentline` is searched, so the call must fit
-- on that single line and be spelled `assert(...)`.
-- @param call_info  `debug.getinfo` table of the function calling assert.
-- @return the text of the first argument, or nil when the source or the
--         call could not be located.
local function get_assert_body_text (call_info)
    if call_info.what == 'Lua' or call_info.what == 'main' then
        -- Find filetext
        local filetext = nil
        if call_info.source:find '^@' then
            filetext = get_module_filetext(call_info.short_src)
        elseif call_info.short_src:find '^%[string' then
            filetext = call_info.source
        else
            error 'Not implemented yet!'
        end

        -- If cannot find
        if not filetext then return nil end

        -- Get lines
        -- A trailing newline makes sure that the last line is matched too.
        filetext = filetext .. '\n'
        local lines_after, line_i = {}, 0
        for line in filetext:gmatch '([^\r\n]*)[\r\n]' do
            line_i = line_i + 1
            if call_info.currentline == line_i then
                lines_after[#lines_after+1] = line
            end
        end

        -- Find body exclusively.
        local call_text = table.concat(lines_after, '\n'):match('assert%s*(%b())')
        -- No balanced `assert(...)` on the line (e.g. a call spanning
        -- several lines): report "not found" instead of indexing nil.
        if not call_text then return nil end
        local assert_arguments_text = call_text:sub(2, -2)
        local assert_arguments = seperate_by_toplevel_commas(assert_arguments_text)
        return assert_arguments[1]
    end

    error 'Not implemented yet!'
end

--- Describe the function given by `call_info` for use in a message:
-- its quoted name if known, otherwise where the anonymous function comes from.
local function get_function_name (call_info)
    --
    if call_info.name then
        return string.format('\'%s\'', call_info.name)
    end
    --
    local where = nil
    if call_info.source:find '^@' then
        where = 'at '..call_info.short_src..':'..call_info.linedefined
    elseif call_info.short_src:find '^%[string' then
        where = 'from loaded string'
    else
        error 'not yet implemented'
    end
    --
    return string.format('the anonymous function %s', where)
end

-- Format a value for display: strings are quoted with `%q`, everything
-- else goes through `tostring`.
local function fmt_val (val)
    if type(val) == 'string' then
        return string.format('%q', val)
    else
        return tostring(val)
    end
end

--- Describe the l-value (or length-of expression) represented by `node`.
-- @param node        annotated AST node.
-- @param with_scope  when truthy, identifiers are prefixed with their scope.
-- @return description and the `function_local` field of the base identifier.
local function fmt_lvalue (node, with_scope)
    assert(type(node) == 'table')
    if node.exp == 'IDENTIFIER' then
        local base = node.text
        if with_scope then
            base = ('%s \'%s\''):format(node.scope, base)
        end
        return base, node.function_local
    elseif node.exp == 'OP' and node.binop == 'DOT' then
        local base, function_local = fmt_lvalue(node[1], with_scope)
        return ('key %s in %s'):format(fmt_val(node[2].value), base), function_local
    end
    --
    if node.exp == 'OP' and node.binop == 'HASHTAG' and #node == 1 then
        local base, is_local = fmt_lvalue(node[1], with_scope)
        return ('length of %s'):format(base), is_local
    end

    --print(require'pretty'(node))
    error 'Not implemented yet!'
end

--- Build the "bad <thing> in <function>" prefix of an error message.
local function fmt_prefix (ast, call_info)
    assert(type(ast) == 'table')
    --
    local name, is_function_local = fmt_lvalue(ast, true)
    -- NOTE(review): nothing in this file sets a `node` field on AST nodes,
    -- so the binder is always 'in'; possibly `scope` was meant — confirm.
    local binder = ast.node == 'argument' and 'to' or 'in'
    local func_name = is_function_local and (' '..binder..' '..get_function_name(call_info)) or ''
    return ('bad %s%s'):format(name, func_name)
end

-- Types whose values are shown without a type annotation.
local PRIMITIVE_VALUES = {
    ['nil'] = true,
    ['boolean'] = true,
}

-- Types whose `tostring` already includes the type name.
local COMPLEX_TYPES = {
    ['table'] = true,
    ['userdata'] = true,
    ['cdata'] = true,
    ['function'] = true,
}

--- Describe a table as 'empty table', 'sequence of length N' or 'table',
-- followed by whether it has a metatable and the address part of its
-- `tostring`.
local function fmt_table_with_type (val)
    assert(type(val) == 'table')
    local subtype = 'table'

    -- Find "last key"
    do
        -- Walk the keys until one is found that is not an integer in 1..#val.
        local last_key, num_visited = nil, 0
        repeat
            last_key, num_visited = next(val, last_key), num_visited + 1
        until last_key == nil or type(last_key) ~= 'number' or last_key <= 0 or last_key > #val
        -- Conclude:
        if last_key == nil then
            subtype = (num_visited == 1) and 'empty table' or 'sequence of length '..(num_visited - 1)
        end
    end

    local addr = tostring(val):match '^table: (.*)$'
    local attr = debug.getmetatable(val) ~= nil and ' with metatable' or ''
    return ('%s%s: %s'):format(subtype, attr, addr)
end

--- Format a value together with its type, where the type is not obvious.
local function fmt_val_with_type (val)
    -- Primitive values ARE their type, and don't need the annotation.
    if PRIMITIVE_VALUES[type(val)] then return tostring(val) end

    -- Tables can be of many different styles
    if type(val) == 'table' then return fmt_table_with_type(val) end

    -- Complex types are already formatted with some type information.
    if COMPLEX_TYPES[type(val)] then return tostring(val) end

    -- Numbers and string should have their types with them.
    return type(val) .. ' ' .. fmt_val(val)
end

-- True for identifier nodes and indexing (DOT) nodes.
local function is_l_value (ast)
    assert(type(ast) == 'table')
    if ast.exp == 'OP' and ast.binop == 'DOT' then
        return true
    elseif ast.exp == 'IDENTIFIER' then
        return true
    end
    return false
end

--- Format a sequence as an enumeration, e.g. `"a", "b", and "c"`.
-- @param seq        sequence of values; each is formatted with `fmt_val`.
-- @param ends_with  separator before the last element (default ', and ').
local function fancy_fmt_seq (seq, ends_with)
    ends_with = ends_with or ', and '
    assert(type(seq) == 'table')
    assert(type(ends_with) == 'string')

    local sep = ', '
    local l = {}
    for i = 1, #seq do
        l[#l+1] = fmt_val(seq[i])
        l[#l+1] = sep
    end
    -- Drop the trailing separator and swap the last remaining one.
    if #seq > 0 then l[#l] = nil end
    if #seq > 1 then l[#l-1] = ends_with end
    return table.concat(l, '')
end

--- Collect up to three keys of `t`, in `next` order.
-- NOTE(review): the `key` argument is only asserted to be non-nil; no
-- similarity comparison is performed.
local function similar_keys_in_table (t, key)
    assert(type(t) == 'table')
    assert(key ~= nil)
    local keys, current_key = {}, nil
    repeat
        current_key = next(t, current_key)
        keys[#keys+1] = current_key
    until #keys >= 3 or current_key == nil
    return keys
end

--------------------------------------------------------------------------------

--- Fill `msg` with a description of why the assert in `call_info` failed.
-- `msg[1]` receives the main message and `msg[2]` an optional detail. The
-- table is filled progressively, so whatever has been written so far is
-- still available to the caller when a later step raises an error.
-- @param call_info  `debug.getinfo` table of the asserting function, with an
--                   added `locals` field as expected by `get_variable`.
-- @param msg        table whose first element is a string.
-- @param condition  the (falsy) value that was passed to assert.
local function determine_error_message (call_info, msg, condition)
    -- Error checking.
    assert(type(call_info) == 'table')
    assert(type(msg) == 'table' and type(msg[1]) == 'string')
    assert(not condition)

    -- Get assert body.
    local body_text = get_assert_body_text(call_info)

    -- If we couldn't find the body text, we give up.
    if not body_text then return end

    -- Simplest formatting.
    -- No analysis of the assert-body, just report that it failed,
    -- along with it's body.
    -- `condition` is nil or false here; Lua 5.1's `%s` only accepts strings
    -- and numbers, hence the explicit `tostring`.
    msg[1] = ('expression `%s` evaluated to %s'):format(body_text, tostring(condition))

    -- Lex text.
    local tokens = lexer:lex(body_text)
    assert(type(tokens) == 'table')

    -- Find identifiers and provide simple explanations of their
    -- values.
    do
        local l, seen_variables = {}, {}
        for i, token in ipairs(tokens) do
            local variable_name = token.text
            if token.token == 'IDENTIFIER' and (i == 1 or tokens[i-1].token ~= 'COLON' and tokens[i-1].token ~= 'DOT') and not seen_variables[variable_name] then
                seen_variables[variable_name] = true
                local value = get_variable(variable_name, call_info)
                if EXPECTED_GLOBAL[variable_name] == nil then
                    l[#l+1] = ('%s was %s'):format(token.text, fmt_val_with_type(value))
                elseif EXPECTED_GLOBAL[variable_name] ~= type(value) then
                    l[#l+1] = ('standard-%s %s was %s'):format(EXPECTED_GLOBAL[variable_name], token.text, fmt_val_with_type(value))
                end
            end
        end
        msg[2] = #l > 0 and table.concat(l,', ') or nil
    end

    local ast = parse(tokens)
    if not ast then return end
    assert(type(ast) == 'table')
    populate_ast_with_semantics(ast, call_info)

    -- Alternative more detailed formatting.
    -- Identical to last message, but now with values of each involved
    -- identifier.
    do
        local l = {}
        for_each_node_in_ast(ast, function(node)
            if is_l_value(node) then
                local name = fmt_lvalue(node)
                if not l[name] then
                    l[#l+1] = name..' was ' .. fmt_val_with_type(node.value)
                    l[name] = true
                end
            end
        end)
        msg[2] = #l > 0 and table.concat(l,', ') or nil
    end

    -- More specific types

    -- Describe a number, optionally mentioning its sign (`relevant.sign`)
    -- and its remainder modulo `relevant.remainder`.
    local function fmt_number_value (value, relevant)
        relevant = relevant or {}
        assert(type(value) == 'number')
        assert(type(relevant) == 'table')
        -- `base` is the index from which `table.concat` starts below.
        local l = { 'number', tonumber(value), base = 1 }
        if value % 1 ~= 0 then
            l[1] = 'decimal number'
        else
            l[1] = 'integer'
        end
        if relevant.sign then
            l.base, l[0] = 0, (value > 0) and 'positive' or 'negative'
        end
        if relevant.remainder then
            assert(type(relevant.remainder) == 'number')
            if relevant.remainder == 1 then
                -- Do nothing.
                -- The remainder of a decimal number is obvious.
            elseif relevant.remainder == 2 then
                l[1] = (value % 2 == 0) and 'even number' or 'odd number'
            else
                l[3], l[4] = 'with remainder', value % relevant.remainder
            end
        end
        return table.concat(l, ' ', l.base)
    end

    if not ast then
        return nil
    elseif ast.exp == 'OP' and ast.binop == 'EQ' and ast[1].exp == 'CALL' and ast[1][1].value == type then
        -- type(x) == '<typename>'
        local prefix = fmt_prefix(ast[1][2], call_info)
        msg[1], msg[2] = prefix, ('%s expected, but got %s'):format(ast[2].value, fmt_val_with_type(ast[1][2].value))
    elseif ast.exp == 'OP' and ast.binop == 'EQ' and ast[1].exp == 'OP' and ast[1].binop == 'MODULO' and ast[2].exp == 'NUMBER' then
        -- a % b == c
        local a = ast[1][1].value
        local b = ast[1][2].value
        local expect_remainder = ast[2].value
        assert(type(a) == 'number')
        assert(type(b) == 'number')
        assert(type(expect_remainder) == 'number' and expect_remainder >= 0 and expect_remainder < b, 'Nonsensical desired remainder')

        local expected_desc, relevant_attr = '???', { remainder = b }
        if b == 2 and expect_remainder == 0 then
            expected_desc = 'even number'
        elseif b == 2 and expect_remainder == 1 then
            expected_desc = 'odd number'
        elseif b == 1 and expect_remainder == 0 then
            expected_desc = 'integer'
        elseif expect_remainder == 0 then
            expected_desc = 'integer divisible by '..tostring(b)
        else
            expected_desc = 'integer with remainder '..expect_remainder..' when divided by '..tostring(b)
        end

        msg[1] = fmt_prefix(ast[1][1], call_info)
        msg[2] = ('%s expected, but got %s'):format(expected_desc, fmt_number_value(ast[1][1].value, relevant_attr))
    elseif ast.exp == 'OP' and ast.binop == 'EQ' then
        local prefix = fmt_prefix(ast[1], call_info)
        local gotten_value, expected_value = ast[1].value, ast[2].value
        -- The type is only spelled out for the actual value when it
        -- differs from the type of the expected value.
        local fmt_gotten = (type(expected_value) == type(gotten_value)) and fmt_val or fmt_val_with_type
        msg[1], msg[2] = prefix, ('%s expected, but got %s'):format(fmt_val_with_type(expected_value), fmt_gotten(gotten_value))
    elseif ast.exp == 'OP' and ast.binop == 'NEQ' then
        local prefix = fmt_prefix(ast[1], call_info)
        local gotten_value, expected_value = ast[1].value, ast[2].value
        msg[1], msg[2] = prefix, ('expected anything other than %s, but got %s'):format(fmt_val_with_type(expected_value), fmt_val(gotten_value))
    elseif ast.exp == 'OP' and ast.binop == 'DOT' and is_l_value(ast[2]) then
        -- t[key], where `key` itself is a variable or an indexing.
        local prefix = fmt_prefix(ast[2], call_info)
        local gotten_value = ast[2].value
        local similar_keys, explain = similar_keys_in_table(ast[1].value, gotten_value)
        if #similar_keys > 0 then
            explain = (' close keys in %s include %s'):format(fmt_lvalue(ast[1]), fancy_fmt_seq(similar_keys))
        else
            explain = (' value of %s was %s'):format(fmt_lvalue(ast[1]), fmt_val(ast[1].value))
        end
        msg[1], msg[2] = prefix, ('value should occur as key in %s, but was %s.%s'):format(fmt_lvalue(ast[1], true), fmt_val(gotten_value), explain)
    elseif is_l_value(ast) then
        local prefix = fmt_prefix(ast, call_info)
        local gotten_value = ast.value
        msg[1], msg[2] = prefix, ('truthy expected, but got %s'):format(fmt_val(gotten_value))
    elseif CONSTANT_VALUE_TOKEN[ast.exp] then
        local func_name = get_function_name(call_info)
        msg[1] = ('this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name)
    elseif not ast.exp then
        error(('[assert-gooder/internal]: Root node did not have expression type.'))
    else
        --print(require'pretty'(ast))
        error(('[assert-gooder/internal]: Unknown expression type %s'):format(ast.exp))
    end
end

--- The assert replacement.
-- Returns `condition` when it is truthy. Otherwise throws, at the caller's
-- position, an error consisting of the formatted user message (or
-- 'assertion failed!') followed by the generated explanation. Errors raised
-- while generating the explanation are written to stderr and do not prevent
-- the assertion error itself from being thrown.
-- @param condition  value to test.
-- @param format     optional message; strings are used as a
--                   `string.format` pattern for `...`, other values are
--                   passed through `tostring`.
return function (condition, format, ...)
    if condition then return condition end
    --
    local level = 2
    local call_info = debug.getinfo(level)
    -- `nparams` is not provided by every interpreter's `debug.getinfo`;
    -- without it no local is classified as an argument.
    local nparams = call_info.nparams or 0
    call_info.locals = {}
    for i = 1, math.huge do
        local name, value = debug.getlocal(level, i)
        if not name then break end
        call_info.locals[name] = { value, i <= nparams, i }
    end
    --
    local msg_container = {''}
    local success, internal_error_msg = pcall(determine_error_message, call_info, msg_container, condition)

    -- Handle internal errors
    if not success then
        io.stderr:write(('[assert-gooder/internal]: Internal error occured while determining error message for calling assert:\n %s\n'):format(internal_error_msg))
    end

    -- Format error message:
    assert(#msg_container <= 2 and type(msg_container[1]) == 'string')
    local l = {}
    if format ~= nil then
        l[#l+1] = (type(format) == 'string') and format:format(...) or tostring(format)
        l[#l+1] = ':'
    else
        l[#l+1] = 'assertion failed!'
    end
    l[#l+1] = ' '
    l[#l+1] = msg_container[1]
    if msg_container[2] then
        assert(type(msg_container[2]) == 'string')
        l[#l+1] = ' ('
        l[#l+1] = msg_container[2]
        l[#l+1] = ')'
    end

    -- Throw error message
    error(table.concat(l, ''), 2)
end