Improved robustness.
This commit is contained in:
parent
05930b8a2b
commit
b1c6693f1f
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
# Ignore editor files
|
||||
.*.swp
|
|
@ -65,14 +65,12 @@ local function parse (tokens)
|
|||
elseif #tokens == 1 then
|
||||
return get_value_token(tokens[1])
|
||||
else
|
||||
print(require'pretty'(tokens))
|
||||
--assert(false)
|
||||
return nil
|
||||
error 'Unknown AST structure!'
|
||||
end
|
||||
end
|
||||
|
||||
local function get_assert_body_text (call_info)
|
||||
if call_info.what == 'Lua' then
|
||||
if call_info.what == 'Lua' or call_info.what == 'main' then
|
||||
-- Find filetext
|
||||
local filetext = nil
|
||||
if call_info.source:find '^@' then
|
||||
|
@ -105,21 +103,25 @@ local function get_assert_body (call_info)
|
|||
return lexer:lex(text), text
|
||||
end
|
||||
|
||||
local function get_variable (var_name, level)
|
||||
local function get_variable (var_name, info, level)
|
||||
--
|
||||
assert(type(var_name) == 'string')
|
||||
assert(type(info) == 'table')
|
||||
assert(type(level) == 'number')
|
||||
|
||||
-- Local
|
||||
local index = 0
|
||||
local index, func = 0, info.func
|
||||
repeat
|
||||
index = index + 1
|
||||
local name, val = debug.getlocal(level + 1, index)
|
||||
if name == var_name then
|
||||
local info = debug.getinfo(level + 1)
|
||||
local is_par = index <= info.nparams
|
||||
return val, is_par and ('argument #'..index) or 'local', info.name or is_par and ''
|
||||
end
|
||||
until not name
|
||||
|
||||
-- Up-value
|
||||
local index, func = 0, debug.getinfo(level + 1).func
|
||||
local index = 0
|
||||
repeat
|
||||
index = index + 1
|
||||
local name, val = debug.getupvalue(func, index)
|
||||
|
@ -127,7 +129,7 @@ local function get_variable (var_name, level)
|
|||
until not name
|
||||
|
||||
-- Global
|
||||
return getfenv(level + 1)[var_name], 'global'
|
||||
return getfenv(func)[var_name], 'global'
|
||||
end
|
||||
|
||||
local function get_function_name (call_info)
|
||||
|
@ -162,17 +164,17 @@ local function fmt_lvalue (lvalue, var_scope)
|
|||
end
|
||||
end
|
||||
|
||||
local function get_variable_and_prefix (lvalue, level)
|
||||
local function get_variable_and_prefix (lvalue, call_info, level)
|
||||
assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE')
|
||||
assert(type(level) == 'number')
|
||||
--
|
||||
local base_value, var_scope, in_func = get_variable(lvalue[1], level + 1)
|
||||
-- Determine value of variable
|
||||
local value = base_value
|
||||
for i = 2, #lvalue do value = value[lvalue[i].value] end
|
||||
--
|
||||
local func_name = in_func and (' to '..get_function_name(debug.getinfo(level + 1))) or ''
|
||||
return value, ('assertion failed! bad %s%s'):format(fmt_lvalue(lvalue, var_scope), func_name)
|
||||
local base_value, var_scope, in_func = get_variable(lvalue[1], call_info, level + 1)
|
||||
-- Determine value of variable
|
||||
local value = base_value
|
||||
for i = 2, #lvalue do value = value[lvalue[i].value] end
|
||||
--
|
||||
local func_name = in_func and (' to '..get_function_name(call_info)) or ''
|
||||
return value, ('bad %s%s'):format(fmt_lvalue(lvalue, var_scope), func_name)
|
||||
end
|
||||
|
||||
|
||||
|
@ -198,39 +200,53 @@ end
|
|||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
return function (condition)
|
||||
if condition then return condition end
|
||||
local call_info = debug.getinfo(2)
|
||||
local function determine_error_message (call_info, msg, level, condition)
|
||||
local tokens, body_text = get_assert_body(call_info)
|
||||
local ast = parse(tokens)
|
||||
|
||||
if ast == nil then
|
||||
error(('assertion failed! expression `%s` evaluated to %s'):format(body_text, condition), 2)
|
||||
msg[1] = ('expression `%s` evaluated to %s'):format(body_text, condition)
|
||||
local var_prefix = function(token) return get_variable_and_prefix(token, call_info, level + 2) end
|
||||
|
||||
if not ast then return nil
|
||||
elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' and ast.left.exp == 'CALL' and ast.left.callee.exp == 'LVALUE' and ast.left.callee[1] == 'type' then
|
||||
local gotten_val, prefix = get_variable_and_prefix(ast.left[1], 2)
|
||||
error(('%s (%s expected, but got %s: %s)'):format(prefix, ast.right.value, type(gotten_val), fmt_val(gotten_val)), 2)
|
||||
local gotten_val, prefix = var_prefix(ast.left[1])
|
||||
msg[1] = ('%s (%s expected, but got %s: %s)'):format(prefix, ast.right.value, type(gotten_val), fmt_val(gotten_val))
|
||||
|
||||
elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' then
|
||||
local gotten_value, prefix = get_variable_and_prefix(ast.left, 2)
|
||||
local gotten_value, prefix = var_prefix(ast.left)
|
||||
local expected_value = ast.right.value
|
||||
local type_annotation = (type(expected_value) == type(gotten_value)) and '' or (' '..type(gotten_value))
|
||||
error(('%s (%s expected, but got%s: %s)'):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_value)), 2)
|
||||
msg[1] = ('%s (%s expected, but got%s: %s)'):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_value))
|
||||
|
||||
elseif ast.exp == 'COMPARE' and ast.binop == 'NEQ' then
|
||||
local gotten_val, prefix = get_variable_and_prefix(ast.left, 2)
|
||||
local gotten_val, prefix = var_prefix(ast.left)
|
||||
local expected_value = ast.right.value
|
||||
error(('%s (expected anything other than %s, but got %s)'):format(prefix, fmt_with_type(expected_value), fmt_val(gotten_val)), 2)
|
||||
msg[1] = ('%s (expected anything other than %s, but got %s)'):format(prefix, fmt_with_type(expected_value), fmt_val(gotten_val))
|
||||
|
||||
elseif ast.exp == 'LVALUE' then
|
||||
local gotten_val, prefix = get_variable_and_prefix(ast, 2)
|
||||
error(('%s (truthy expected, but got %s)'):format(prefix, fmt_val(gotten_val)), 2)
|
||||
local gotten_val, prefix = var_prefix(ast)
|
||||
msg[1] = ('%s (truthy expected, but got %s)'):format(prefix, fmt_val(gotten_val))
|
||||
|
||||
elseif CONSTANT_VALUE_TOKEN[ast.exp] then
|
||||
local func_name = get_function_name(call_info)
|
||||
error(('assertion failed! this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name), 2)
|
||||
msg[1] = ('this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name)
|
||||
|
||||
else
|
||||
error(('[assert-gooder/internal]: Unknown expression type %s'):format(ast.exp))
|
||||
end
|
||||
end
|
||||
|
||||
return function (condition)
|
||||
if condition then return condition end
|
||||
local level = 2
|
||||
local call_info = debug.getinfo(level)
|
||||
local msg_container = {''}
|
||||
local success, internal_error_msg = pcall(determine_error_message, call_info, msg_container, level, condition)
|
||||
-- Handle internal errors
|
||||
if not success then
|
||||
io.stderr:write(('[assert-gooder/internal]: Internal error occured while determining error message for calling assert:\n %s'):format(internal_error_msg))
|
||||
end
|
||||
--
|
||||
error(('assertion failed! %s'):format(msg_container[1]), 2)
|
||||
end
|
||||
|
||||
|
|
2
init.lua
2
init.lua
|
@ -1,2 +1,2 @@
|
|||
|
||||
return assert(require((... and select('1', ...):match('.+%')..'.' or '')..'assert-gooder'), '[assert-gooder]: Could not load vital library: assert-gooder')
|
||||
return assert(require((... and select('1', ...):match('.+')..'.' or '')..'assert-gooder'), '[assert-gooder]: Could not load vital library: assert-gooder')
|
||||
|
|
|
@ -174,6 +174,15 @@ SUITE:addTest('constant nil', function ()
|
|||
assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! this assert will always fail, as it\'s body is `nil`. assumingly this should be an unreachable part of the anonymous function at ./test/test_assert-gooder.lua:'..curline(-3), msg)
|
||||
end)
|
||||
|
||||
SUITE:addTest('function as type argument', function ()
|
||||
local _, msg = pcall(function ()
|
||||
local f = function() end
|
||||
assert(type(f) == 'string')
|
||||
end)
|
||||
-- TODO: How do we test this?
|
||||
assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! bad local \'f\' (expected string, but got function)', msg)
|
||||
end)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
return SUITE
|
||||
|
|
|
@ -3,4 +3,5 @@
|
|||
|
||||
local TEST_SUITE = require "TestSuite" 'assert-gooder'
|
||||
TEST_SUITE:addModules 'test/test_*.lua'
|
||||
TEST_SUITE:enableStrictGlobal()
|
||||
TEST_SUITE:runTests()
|
||||
|
|
Loading…
Reference in New Issue
Block a user