1
0
assert-gooder/assert-gooder.lua

280 lines
9.9 KiB
Lua

local lexer = assert(require((... and select('1', ...):match('.+%.') or '')..'lua_lang'), '[assert-gooder]: Could not load vital library: lua_lang')
--------------------------------------------------------------------------------
--- Strips the surrounding quote characters from a short string literal.
-- Only `"..."` and `'...'` literals are accepted; anything else (e.g. long
-- bracket strings) trips an assertion. Escape sequences are NOT decoded, so
-- the result is the raw source text between the quotes.
-- @tparam string string_str  the literal exactly as written in the source
-- @treturn string
local function get_value_of_string (string_str)
  local opening = string_str:sub(1, 1)
  local is_quoted = (opening == '"') or (opening == '\'')
  assert(is_quoted)
  return string_str:sub(2, -2)
end
--- Converters from a constant-literal token kind to the Lua value it denotes.
-- Each converter receives the token's source text.
local CONSTANT_VALUE_TOKEN = {
  NUMBER = tonumber,
  STRING = get_value_of_string,
  TRUE   = function () return true end,
  FALSE  = function () return false end,
  NIL    = function () return nil end,
}

--- Token kinds that can stand alone as a value: identifiers plus every
-- constant-literal kind above.
local VALUE_TOKEN = { IDENTIFIER = true }
for token_kind in pairs(CONSTANT_VALUE_TOKEN) do
  VALUE_TOKEN[token_kind] = true
end

--- Comparison operator token kinds accepted between two value tokens.
-- NOTE(review): these names come from the lua_lang lexer; confirm that
-- 'LE' really is its name for `<` (the set mixes LE with GT).
local COMPARE_BINOP = {}
for _, op in ipairs { 'EQ', 'NEQ', 'LEQ', 'GEQ', 'LE', 'GT' } do
  COMPARE_BINOP[op] = true
end
--- Turns a single value token (`.token` kind, `.text` source) into an AST leaf.
-- Identifiers become an unresolved lvalue `{ exp = 'LVALUE', <name> }`;
-- constant literals become `{ exp = <kind>, value = <converted value> }`.
-- Any other token kind trips an assertion.
local function get_value_token (token)
  local kind = token.token
  local convert = CONSTANT_VALUE_TOKEN[kind]
  if convert then
    return { exp = kind, value = convert(token.text) }
  end
  assert(kind == 'IDENTIFIER')
  return { exp = 'LVALUE', token.text }
end
--- Looks up a variable as seen from the asserting function: its locals
-- first, then its upvalues, then its global environment.
-- @tparam string var_name  name of the variable
-- @tparam table info  `debug.getinfo` record of the asserting function, with
--   an added `locals` map of name -> { value, is_argument, index }
-- @return the variable's value
-- @treturn string scope: "argument #N", "local", "upvalue" or "global"
-- @return for locals/arguments only: `info.name` when set, otherwise '' for
--   arguments and false for plain locals (a truthy value means the message
--   should name the function)
local function get_variable (var_name, info)
  assert(type(var_name) == 'string')
  assert(type(info) == 'table')
  -- Local (or argument)
  local var_info = info.locals[var_name]
  if var_info then
    local is_argument = var_info[2]
    local scope = is_argument and ('argument #'..var_info[3]) or 'local'
    return var_info[1], scope, info.name or is_argument and ''
  end
  -- Up-value. Also remember an `_ENV` upvalue: on Lua 5.2+ that is the
  -- function's global environment.
  local env = nil
  local index = 0
  repeat
    index = index + 1
    local name, val = debug.getupvalue(info.func, index)
    if name == var_name then return val, 'upvalue' end
    if name == '_ENV' then env = val end
  until not name
  -- Global. `getfenv` only exists on Lua 5.1 / LuaJIT; elsewhere fall back
  -- to the `_ENV` upvalue found above, or the global table.
  if getfenv then env = getfenv(info.func) end
  return (env or _G)[var_name], 'global'
end
--- Evaluates an LVALUE AST node to its current runtime value.
-- `lvalue[1]` names the base variable; every later entry is an index node
-- whose (already populated) `.value` is used as the key into the previous
-- result.
-- @return the value, the base variable's scope description, and the
--   function-name marker returned by `get_variable`
local function get_value_from_lvalue (lvalue, info)
  assert(type(lvalue) == 'table')
  assert(type(info) == 'table')
  local current, var_scope, in_func = get_variable(lvalue[1], info)
  local depth = 2
  while depth <= #lvalue do
    current = current[lvalue[depth].value]
    depth = depth + 1
  end
  return current, var_scope, in_func
end
--------------------------------------------------------------------------------
-- Parsing
-- Comparison operators for which `a op b` means the same as `b op a`.
local SYMMETRIC_COMPARE_BINOP = { EQ = true, NEQ = true }

--- Builds a small AST for the body of an assert call.
-- Recognised shapes:
--   * `type(x) == 'name'`
--   * `a <op> b` with single value tokens on both sides and <op> in
--     COMPARE_BINOP (for EQ/NEQ a constant on the left is moved right)
--   * `x.key` and `x[value]`
--   * a single value token
-- Any other shape is reported on stderr and yields nil.
-- @tparam table tokens  lexer output; each token has `.token` and `.text`
-- @treturn table|nil
local function parse (tokens)
  -- TODO: Make a more general parser
  assert(type(tokens) == 'table')
  local n = #tokens
  if n == 6 and tokens[1].text == 'type' and tokens[2].token == 'LPAR' and tokens[3].token == 'IDENTIFIER' and tokens[4].token == 'RPAR' and tokens[5].token == 'EQ' and tokens[6].token == 'STRING' then
    return {
      exp = 'COMPARE',
      binop = 'EQ',
      [1] = { exp = 'CALL', get_value_token(tokens[1]), get_value_token(tokens[3]) },
      [2] = get_value_token(tokens[6]),
    }
  elseif n == 3 and VALUE_TOKEN[tokens[1].token] and COMPARE_BINOP[tokens[2].token] and VALUE_TOKEN[tokens[3].token] then
    local binop, lhs, rhs = tokens[2].token, tokens[1], tokens[3]
    -- The message builder treats the left operand as the variable under
    -- test, so rewrite `constant == x` as the equivalent `x == constant`.
    if SYMMETRIC_COMPARE_BINOP[binop] and lhs.token ~= 'IDENTIFIER' and rhs.token == 'IDENTIFIER' then
      lhs, rhs = rhs, lhs
    end
    return {
      exp = 'COMPARE',
      binop = binop,
      [1] = get_value_token(lhs),
      [2] = get_value_token(rhs),
    }
  elseif n == 3 and tokens[1].token == 'IDENTIFIER' and tokens[2].token == 'DOT' and tokens[3].token == 'IDENTIFIER' then
    return { exp = 'LVALUE', tokens[1].text, { exp = 'STRING', value = tokens[3].text } }
  elseif n == 4 and tokens[1].token == 'IDENTIFIER' and tokens[2].token == 'LBRACK' and VALUE_TOKEN[tokens[3].token] and tokens[4].token == 'RBRACK' then
    return { exp = 'LVALUE', tokens[1].text, get_value_token(tokens[3]) }
  elseif n == 1 then
    return get_value_token(tokens[1])
  end
  -- Unsupported shape: report it and let the caller use its generic message.
  io.stderr:write '[assert-gooder/internal]: Unknown AST structure!\n'
  return nil
end
--- Walks an AST depth-first and stores the runtime value of every LVALUE
-- node in its `.value` field. Children are handled before their parent, so
-- the key node of `t[k]` already has a value when `t[k]` is resolved.
-- Non-table entries (such as the name strings inside an LVALUE) are skipped.
local function populate_ast_with_semantics (node, info)
  if type(node) == 'table' then
    local child_count = #node
    for idx = 1, child_count do
      populate_ast_with_semantics(node[idx], info)
    end
    if node.exp == 'LVALUE' then
      node.value = get_value_from_lvalue(node, info)
    end
  end
end
--------------------------------------------------------------------------------
--- Reads the complete source text of the chunk described by `call_info`.
-- Supports file chunks (`source` starting with '@') and chunks loaded from
-- strings; anything else raises 'Not implemented yet!'.
local function read_chunk_text (call_info)
  if call_info.source:find '^@' then
    -- `short_src` may be truncated for long paths, so open the full path
    -- stored in `source` after the leading '@'.
    local f, open_err = io.open(call_info.source:sub(2), 'r')
    if not f then error(open_err) end
    local text = f:read '*all'
    f:close()
    return text
  elseif call_info.short_src:find '^%[string' then
    -- For chunks loaded from a string, `source` is the chunk text itself.
    return call_info.source
  end
  error 'Not implemented yet!'
end

--- Splits text into lines, counting "\r\n", "\n\r", "\r" and "\n" each as
-- one line break (mirroring Lua's own line counting), so indices line up
-- with `debug.getinfo(...).currentline` whatever the file's line endings.
local function split_lines (text)
  local lines, pos = {}, 1
  while pos <= #text do
    local brk = text:find('[\r\n]', pos)
    if not brk then
      lines[#lines + 1] = text:sub(pos)
      break
    end
    lines[#lines + 1] = text:sub(pos, brk - 1)
    local this_char, next_char = text:sub(brk, brk), text:sub(brk + 1, brk + 1)
    if next_char ~= this_char and (next_char == '\r' or next_char == '\n') then
      brk = brk + 1
    end
    pos = brk + 1
  end
  return lines
end

--- Extracts the text between the parentheses of the `assert(...)` call on
-- `call_info.currentline`. The call may continue onto later lines, but it
-- must start on the current line.
-- @tparam table call_info  `debug.getinfo` record of the asserting function
-- @treturn string  the assert body without the enclosing parentheses
local function get_assert_body_text (call_info)
  if call_info.what ~= 'Lua' and call_info.what ~= 'main' then
    error 'Not implemented yet!'
  end
  local lines = split_lines(read_chunk_text(call_info))
  local first_line = lines[call_info.currentline]
  if not first_line then error 'Could not locate the assert call' end
  -- Search from the current line onward so multi-line calls are captured,
  -- but only accept a match that begins on the current line.
  local text_after = table.concat(lines, '\n', call_info.currentline)
  local start, _, body = text_after:find 'assert%s*(%b())'
  if not start or start > #first_line then
    error 'Could not locate the assert call'
  end
  -- Find body exclusively.
  return body:sub(2, -2)
end
--- Lexes the body of the failing assert call.
-- @treturn table  the token list produced by `lexer:lex`
-- @treturn string the raw body text, without the enclosing parentheses
local function get_assert_body (call_info)
  local body_text = get_assert_body_text(call_info)
  local tokens = lexer:lex(body_text)
  return tokens, body_text
end
--- Describes the asserting function for use in error messages.
-- Named functions are returned quoted (`'name'`); anonymous ones are
-- described by where they were defined (file and line, or "loaded string").
-- Other chunk kinds raise 'not yet implemented'.
local function get_function_name (call_info)
  if call_info.name then
    return ("'%s'"):format(call_info.name)
  end
  local is_file = call_info.source:find '^@'
  local is_string = not is_file and call_info.short_src:find '^%[string'
  if not (is_file or is_string) then
    error 'not yet implemented'
  end
  local where = is_file
    and ('at ' .. call_info.short_src .. ':' .. call_info.linedefined)
    or 'from loaded string'
  return 'the anonymous function ' .. where
end
--- Formats a value for an error message: strings are quoted and escaped
-- with `%q`; every other value goes through `tostring`.
local function fmt_val (val)
  if type(val) ~= 'string' then
    return tostring(val)
  end
  return ('%q'):format(val)
end
--- Describes an LVALUE node for an error message, e.g. `local 'x'` or
-- `key "k" in local 't'`. Only a single level of indexing is supported;
-- deeper lvalues raise 'Not implemented yet!'.
local function fmt_lvalue (lvalue, var_scope)
  assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE')
  local depth = #lvalue
  if depth == 1 then
    return ("%s '%s'"):format(var_scope, lvalue[1])
  end
  if depth == 2 then
    return ("key %s in %s '%s'"):format(fmt_val(lvalue[2].value), var_scope, lvalue[1])
  end
  error 'Not implemented yet!'
end
--- Resolves an LVALUE and builds the "bad ..." prefix of the error message,
-- e.g. `bad argument #1 'x' to 'f'`. The " to <function>" part is added
-- only when the variable lookup reports a truthy function marker.
-- @return the variable's value and the message prefix
local function get_variable_and_prefix (lvalue, call_info)
  assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE')
  local value, var_scope, in_func = get_value_from_lvalue(lvalue, call_info)
  local suffix = ''
  if in_func then
    suffix = ' to ' .. get_function_name(call_info)
  end
  local prefix = ('bad %s%s'):format(fmt_lvalue(lvalue, var_scope), suffix)
  return value, prefix
end
-- Types whose values print as themselves (`nil`, `true`, `false`), so no
-- type annotation is needed.
local PRIMITIVE_VALUES = {
  ['nil'] = true,
  ['boolean'] = true,
}
-- Types whose default `tostring` output already names the type
-- (e.g. "table: 0x...", "thread: 0x..."); 'cdata' is LuaJIT's FFI type.
local COMPLEX_TYPES = {
  ['table'] = true,
  ['userdata'] = true,
  ['cdata'] = true,
  ['function'] = true,
  ['thread'] = true,
}
--- Formats a value together with its type where the type is not already
-- evident, e.g. `number 5` or `string "x"`, but plain `nil`/`true` and
-- `table: 0x...`.
local function fmt_val_with_type (val)
  -- Primitive values ARE their type, and don't need the annotation.
  if PRIMITIVE_VALUES[type(val)] then return tostring(val) end
  -- Complex types are already formatted with some type information.
  if COMPLEX_TYPES[type(val)] then return tostring(val) end
  -- Numbers and string should have their types with them.
  return type(val) .. ' ' .. fmt_val(val)
end
--------------------------------------------------------------------------------
--- Builds a descriptive assertion message and stores it in `msg[1]`.
-- A generic "expression ... evaluated to ..." message is written first, so
-- it remains if the body's shape is unsupported or a later step raises;
-- recognised shapes then overwrite it with a more specific message.
-- @tparam table call_info  `debug.getinfo` record of the asserting function,
--   extended with a `locals` map
-- @tparam table msg  single-slot container that receives the message
-- @param condition  the asserted value (expected to be false or nil here)
local function determine_error_message (call_info, msg, condition)
  local tokens, body_text = get_assert_body(call_info)
  local ast = parse(tokens)
  populate_ast_with_semantics(ast, call_info)
  -- `tostring` is required: Lua 5.1's `string.format('%s', ...)` rejects
  -- booleans and nil, which would abort the whole message.
  msg[1] = ('expression `%s` evaluated to %s'):format(body_text, tostring(condition))
  local var_prefix = function(token) return get_variable_and_prefix(token, call_info) end
  if not ast then return nil
  elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' and ast[1].exp == 'CALL' and ast[1][1].exp == 'LVALUE' and ast[1][1][1] == 'type' then
    -- `type(x) == 'name'`
    local gotten_val, prefix = var_prefix(ast[1][2])
    msg[1] = ('%s (%s expected, but got %s)'):format(prefix, ast[2].value, fmt_val_with_type(gotten_val))
  elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' then
    local gotten_value, prefix = var_prefix(ast[1])
    local expected_value = ast[2].value
    -- Only annotate the actual value with its type when the types differ.
    local fmt_gotten = (type(expected_value) == type(gotten_value)) and fmt_val or fmt_val_with_type
    msg[1] = ('%s (%s expected, but got %s)'):format(prefix, fmt_val_with_type(expected_value), fmt_gotten(gotten_value))
  elseif ast.exp == 'COMPARE' and ast.binop == 'NEQ' then
    local gotten_val, prefix = var_prefix(ast[1])
    local expected_value = ast[2].value
    msg[1] = ('%s (expected anything other than %s, but got %s)'):format(prefix, fmt_val_with_type(expected_value), fmt_val(gotten_val))
  elseif ast.exp == 'LVALUE' then
    local gotten_val, prefix = var_prefix(ast)
    msg[1] = ('%s (truthy expected, but got %s)'):format(prefix, fmt_val(gotten_val))
  elseif CONSTANT_VALUE_TOKEN[ast.exp] then
    -- A bare constant body such as `assert(false)`.
    local func_name = get_function_name(call_info)
    msg[1] = ('this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name)
  else
    error(('[assert-gooder/internal]: Unknown expression type %s'):format(ast.exp))
  end
end
--- Drop-in replacement for Lua's `assert` with explanatory failure messages.
-- On success it returns all of its arguments, like the built-in `assert`.
-- On failure it snapshots the caller's locals, asks `determine_error_message`
-- for an explanation, and raises "assertion failed! <explanation>" at the
-- caller's level. Internal errors while explaining are reported on stderr
-- and the generic/partial explanation is used instead.
return function (condition, ...)
  if condition then return condition, ... end
  -- Level 2 is the function that called this assert.
  local level = 2
  local call_info = debug.getinfo(level)
  -- `nparams` is absent on PUC Lua 5.1; then every local is reported as a
  -- plain local instead of crashing on `i <= nil`.
  local nparams = call_info.nparams or 0
  call_info.locals = {}
  for i = 1, math.huge do
    local name, value = debug.getlocal(level, i)
    if not name then break end
    -- Later (inner) locals with the same name overwrite earlier ones.
    call_info.locals[name] = { value, i <= nparams, i }
  end
  --
  local msg_container = {''}
  local success, internal_error_msg = pcall(determine_error_message, call_info, msg_container, condition)
  -- Handle internal errors
  if not success then
    io.stderr:write(('[assert-gooder/internal]: Internal error occured while determining error message for calling assert:\n %s\n'):format(tostring(internal_error_msg)))
  end
  --
  error(('assertion failed! %s'):format(msg_container[1]), 2)
end