1
0
Fork 0

Improved robustness.

This commit is contained in:
Jon Michael Aanes 2017-11-05 11:31:46 +01:00
parent 05930b8a2b
commit b1c6693f1f
5 changed files with 60 additions and 32 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
# Ignore editor files
.*.swp

View File

@ -65,14 +65,12 @@ local function parse (tokens)
elseif #tokens == 1 then
return get_value_token(tokens[1])
else
print(require'pretty'(tokens))
--assert(false)
return nil
error 'Unknown AST structure!'
end
end
local function get_assert_body_text (call_info)
if call_info.what == 'Lua' then
if call_info.what == 'Lua' or call_info.what == 'main' then
-- Find filetext
local filetext = nil
if call_info.source:find '^@' then
@ -105,21 +103,25 @@ local function get_assert_body (call_info)
return lexer:lex(text), text
end
local function get_variable (var_name, level)
local function get_variable (var_name, info, level)
--
assert(type(var_name) == 'string')
assert(type(info) == 'table')
assert(type(level) == 'number')
-- Local
local index = 0
local index, func = 0, info.func
repeat
index = index + 1
local name, val = debug.getlocal(level + 1, index)
if name == var_name then
local info = debug.getinfo(level + 1)
local is_par = index <= info.nparams
return val, is_par and ('argument #'..index) or 'local', info.name or is_par and ''
end
until not name
-- Up-value
local index, func = 0, debug.getinfo(level + 1).func
local index = 0
repeat
index = index + 1
local name, val = debug.getupvalue(func, index)
@ -127,7 +129,7 @@ local function get_variable (var_name, level)
until not name
-- Global
return getfenv(level + 1)[var_name], 'global'
return getfenv(func)[var_name], 'global'
end
local function get_function_name (call_info)
@ -162,17 +164,17 @@ local function fmt_lvalue (lvalue, var_scope)
end
end
local function get_variable_and_prefix (lvalue, level)
local function get_variable_and_prefix (lvalue, call_info, level)
assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE')
assert(type(level) == 'number')
--
local base_value, var_scope, in_func = get_variable(lvalue[1], level + 1)
-- Determine value of variable
local value = base_value
for i = 2, #lvalue do value = value[lvalue[i].value] end
--
local func_name = in_func and (' to '..get_function_name(debug.getinfo(level + 1))) or ''
return value, ('assertion failed! bad %s%s'):format(fmt_lvalue(lvalue, var_scope), func_name)
local base_value, var_scope, in_func = get_variable(lvalue[1], call_info, level + 1)
-- Determine value of variable
local value = base_value
for i = 2, #lvalue do value = value[lvalue[i].value] end
--
local func_name = in_func and (' to '..get_function_name(call_info)) or ''
return value, ('bad %s%s'):format(fmt_lvalue(lvalue, var_scope), func_name)
end
@ -198,39 +200,53 @@ end
--------------------------------------------------------------------------------
return function (condition)
if condition then return condition end
local call_info = debug.getinfo(2)
local function determine_error_message (call_info, msg, level, condition)
local tokens, body_text = get_assert_body(call_info)
local ast = parse(tokens)
if ast == nil then
error(('assertion failed! expression `%s` evaluated to %s'):format(body_text, condition), 2)
msg[1] = ('expression `%s` evaluated to %s'):format(body_text, condition)
local var_prefix = function(token) return get_variable_and_prefix(token, call_info, level + 2) end
if not ast then return nil
elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' and ast.left.exp == 'CALL' and ast.left.callee.exp == 'LVALUE' and ast.left.callee[1] == 'type' then
local gotten_val, prefix = get_variable_and_prefix(ast.left[1], 2)
error(('%s (%s expected, but got %s: %s)'):format(prefix, ast.right.value, type(gotten_val), fmt_val(gotten_val)), 2)
local gotten_val, prefix = var_prefix(ast.left[1])
msg[1] = ('%s (%s expected, but got %s: %s)'):format(prefix, ast.right.value, type(gotten_val), fmt_val(gotten_val))
elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' then
local gotten_value, prefix = get_variable_and_prefix(ast.left, 2)
local gotten_value, prefix = var_prefix(ast.left)
local expected_value = ast.right.value
local type_annotation = (type(expected_value) == type(gotten_value)) and '' or (' '..type(gotten_value))
error(('%s (%s expected, but got%s: %s)'):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_value)), 2)
msg[1] = ('%s (%s expected, but got%s: %s)'):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_value))
elseif ast.exp == 'COMPARE' and ast.binop == 'NEQ' then
local gotten_val, prefix = get_variable_and_prefix(ast.left, 2)
local gotten_val, prefix = var_prefix(ast.left)
local expected_value = ast.right.value
error(('%s (expected anything other than %s, but got %s)'):format(prefix, fmt_with_type(expected_value), fmt_val(gotten_val)), 2)
msg[1] = ('%s (expected anything other than %s, but got %s)'):format(prefix, fmt_with_type(expected_value), fmt_val(gotten_val))
elseif ast.exp == 'LVALUE' then
local gotten_val, prefix = get_variable_and_prefix(ast, 2)
error(('%s (truthy expected, but got %s)'):format(prefix, fmt_val(gotten_val)), 2)
local gotten_val, prefix = var_prefix(ast)
msg[1] = ('%s (truthy expected, but got %s)'):format(prefix, fmt_val(gotten_val))
elseif CONSTANT_VALUE_TOKEN[ast.exp] then
local func_name = get_function_name(call_info)
error(('assertion failed! this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name), 2)
msg[1] = ('this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name)
else
error(('[assert-gooder/internal]: Unknown expression type %s'):format(ast.exp))
end
end
return function (condition)
if condition then return condition end
local level = 2
local call_info = debug.getinfo(level)
local msg_container = {''}
local success, internal_error_msg = pcall(determine_error_message, call_info, msg_container, level, condition)
-- Handle internal errors
if not success then
io.stderr:write(('[assert-gooder/internal]: Internal error occured while determining error message for calling assert:\n %s'):format(internal_error_msg))
end
--
error(('assertion failed! %s'):format(msg_container[1]), 2)
end

View File

@ -1,2 +1,2 @@
return assert(require((... and select('1', ...):match('.+%')..'.' or '')..'assert-gooder'), '[assert-gooder]: Could not load vital library: assert-gooder')
return assert(require((... and select('1', ...):match('.+')..'.' or '')..'assert-gooder'), '[assert-gooder]: Could not load vital library: assert-gooder')

View File

@ -174,6 +174,15 @@ SUITE:addTest('constant nil', function ()
assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! this assert will always fail, as it\'s body is `nil`. assumingly this should be an unreachable part of the anonymous function at ./test/test_assert-gooder.lua:'..curline(-3), msg)
end)
SUITE:addTest('function as type argument', function ()
local _, msg = pcall(function ()
local f = function() end
assert(type(f) == 'string')
end)
-- TODO: How do we test this?
assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! bad local \'f\' (expected string, but got function)', msg)
end)
--------------------------------------------------------------------------------
return SUITE

View File

@ -3,4 +3,5 @@
local TEST_SUITE = require "TestSuite" 'assert-gooder'
TEST_SUITE:addModules 'test/test_*.lua'
TEST_SUITE:enableStrictGlobal()
TEST_SUITE:runTests()