1
0

Improved robustness.

This commit is contained in:
Jon Michael Aanes 2017-11-05 11:31:46 +01:00
parent 05930b8a2b
commit b1c6693f1f
5 changed files with 60 additions and 32 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
# Ignore editor files
.*.swp

View File

@ -65,14 +65,12 @@ local function parse (tokens)
elseif #tokens == 1 then elseif #tokens == 1 then
return get_value_token(tokens[1]) return get_value_token(tokens[1])
else else
print(require'pretty'(tokens)) error 'Unknown AST structure!'
--assert(false)
return nil
end end
end end
local function get_assert_body_text (call_info) local function get_assert_body_text (call_info)
if call_info.what == 'Lua' then if call_info.what == 'Lua' or call_info.what == 'main' then
-- Find filetext -- Find filetext
local filetext = nil local filetext = nil
if call_info.source:find '^@' then if call_info.source:find '^@' then
@ -105,21 +103,25 @@ local function get_assert_body (call_info)
return lexer:lex(text), text return lexer:lex(text), text
end end
local function get_variable (var_name, level) local function get_variable (var_name, info, level)
--
assert(type(var_name) == 'string')
assert(type(info) == 'table')
assert(type(level) == 'number')
-- Local -- Local
local index = 0 local index, func = 0, info.func
repeat repeat
index = index + 1 index = index + 1
local name, val = debug.getlocal(level + 1, index) local name, val = debug.getlocal(level + 1, index)
if name == var_name then if name == var_name then
local info = debug.getinfo(level + 1)
local is_par = index <= info.nparams local is_par = index <= info.nparams
return val, is_par and ('argument #'..index) or 'local', info.name or is_par and '' return val, is_par and ('argument #'..index) or 'local', info.name or is_par and ''
end end
until not name until not name
-- Up-value -- Up-value
local index, func = 0, debug.getinfo(level + 1).func local index = 0
repeat repeat
index = index + 1 index = index + 1
local name, val = debug.getupvalue(func, index) local name, val = debug.getupvalue(func, index)
@ -127,7 +129,7 @@ local function get_variable (var_name, level)
until not name until not name
-- Global -- Global
return getfenv(level + 1)[var_name], 'global' return getfenv(func)[var_name], 'global'
end end
local function get_function_name (call_info) local function get_function_name (call_info)
@ -162,17 +164,17 @@ local function fmt_lvalue (lvalue, var_scope)
end end
end end
local function get_variable_and_prefix (lvalue, level) local function get_variable_and_prefix (lvalue, call_info, level)
assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE') assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE')
assert(type(level) == 'number') assert(type(level) == 'number')
-- --
local base_value, var_scope, in_func = get_variable(lvalue[1], level + 1) local base_value, var_scope, in_func = get_variable(lvalue[1], call_info, level + 1)
-- Determine value of variable -- Determine value of variable
local value = base_value local value = base_value
for i = 2, #lvalue do value = value[lvalue[i].value] end for i = 2, #lvalue do value = value[lvalue[i].value] end
-- --
local func_name = in_func and (' to '..get_function_name(debug.getinfo(level + 1))) or '' local func_name = in_func and (' to '..get_function_name(call_info)) or ''
return value, ('assertion failed! bad %s%s'):format(fmt_lvalue(lvalue, var_scope), func_name) return value, ('bad %s%s'):format(fmt_lvalue(lvalue, var_scope), func_name)
end end
@ -198,39 +200,53 @@ end
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
return function (condition) local function determine_error_message (call_info, msg, level, condition)
if condition then return condition end
local call_info = debug.getinfo(2)
local tokens, body_text = get_assert_body(call_info) local tokens, body_text = get_assert_body(call_info)
local ast = parse(tokens) local ast = parse(tokens)
if ast == nil then msg[1] = ('expression `%s` evaluated to %s'):format(body_text, condition)
error(('assertion failed! expression `%s` evaluated to %s'):format(body_text, condition), 2) local var_prefix = function(token) return get_variable_and_prefix(token, call_info, level + 2) end
if not ast then return nil
elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' and ast.left.exp == 'CALL' and ast.left.callee.exp == 'LVALUE' and ast.left.callee[1] == 'type' then elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' and ast.left.exp == 'CALL' and ast.left.callee.exp == 'LVALUE' and ast.left.callee[1] == 'type' then
local gotten_val, prefix = get_variable_and_prefix(ast.left[1], 2) local gotten_val, prefix = var_prefix(ast.left[1])
error(('%s (%s expected, but got %s: %s)'):format(prefix, ast.right.value, type(gotten_val), fmt_val(gotten_val)), 2) msg[1] = ('%s (%s expected, but got %s: %s)'):format(prefix, ast.right.value, type(gotten_val), fmt_val(gotten_val))
elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' then elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' then
local gotten_value, prefix = get_variable_and_prefix(ast.left, 2) local gotten_value, prefix = var_prefix(ast.left)
local expected_value = ast.right.value local expected_value = ast.right.value
local type_annotation = (type(expected_value) == type(gotten_value)) and '' or (' '..type(gotten_value)) local type_annotation = (type(expected_value) == type(gotten_value)) and '' or (' '..type(gotten_value))
error(('%s (%s expected, but got%s: %s)'):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_value)), 2) msg[1] = ('%s (%s expected, but got%s: %s)'):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_value))
elseif ast.exp == 'COMPARE' and ast.binop == 'NEQ' then elseif ast.exp == 'COMPARE' and ast.binop == 'NEQ' then
local gotten_val, prefix = get_variable_and_prefix(ast.left, 2) local gotten_val, prefix = var_prefix(ast.left)
local expected_value = ast.right.value local expected_value = ast.right.value
error(('%s (expected anything other than %s, but got %s)'):format(prefix, fmt_with_type(expected_value), fmt_val(gotten_val)), 2) msg[1] = ('%s (expected anything other than %s, but got %s)'):format(prefix, fmt_with_type(expected_value), fmt_val(gotten_val))
elseif ast.exp == 'LVALUE' then elseif ast.exp == 'LVALUE' then
local gotten_val, prefix = get_variable_and_prefix(ast, 2) local gotten_val, prefix = var_prefix(ast)
error(('%s (truthy expected, but got %s)'):format(prefix, fmt_val(gotten_val)), 2) msg[1] = ('%s (truthy expected, but got %s)'):format(prefix, fmt_val(gotten_val))
elseif CONSTANT_VALUE_TOKEN[ast.exp] then elseif CONSTANT_VALUE_TOKEN[ast.exp] then
local func_name = get_function_name(call_info) local func_name = get_function_name(call_info)
error(('assertion failed! this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name), 2) msg[1] = ('this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name)
else else
error(('[assert-gooder/internal]: Unknown expression type %s'):format(ast.exp)) error(('[assert-gooder/internal]: Unknown expression type %s'):format(ast.exp))
end end
end end
return function (condition)
if condition then return condition end
local level = 2
local call_info = debug.getinfo(level)
local msg_container = {''}
local success, internal_error_msg = pcall(determine_error_message, call_info, msg_container, level, condition)
-- Handle internal errors
if not success then
io.stderr:write(('[assert-gooder/internal]: Internal error occured while determining error message for calling assert:\n %s'):format(internal_error_msg))
end
--
error(('assertion failed! %s'):format(msg_container[1]), 2)
end

View File

@ -1,2 +1,2 @@
return assert(require((... and select('1', ...):match('.+%')..'.' or '')..'assert-gooder'), '[assert-gooder]: Could not load vital library: assert-gooder') return assert(require((... and select('1', ...):match('.+')..'.' or '')..'assert-gooder'), '[assert-gooder]: Could not load vital library: assert-gooder')

View File

@ -174,6 +174,15 @@ SUITE:addTest('constant nil', function ()
assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! this assert will always fail, as it\'s body is `nil`. assumingly this should be an unreachable part of the anonymous function at ./test/test_assert-gooder.lua:'..curline(-3), msg) assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! this assert will always fail, as it\'s body is `nil`. assumingly this should be an unreachable part of the anonymous function at ./test/test_assert-gooder.lua:'..curline(-3), msg)
end) end)
SUITE:addTest('function as type argument', function ()
local _, msg = pcall(function ()
local f = function() end
assert(type(f) == 'string')
end)
-- TODO: How do we test this?
assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! bad local \'f\' (expected string, but got function)', msg)
end)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
return SUITE return SUITE

View File

@ -3,4 +3,5 @@
local TEST_SUITE = require "TestSuite" 'assert-gooder' local TEST_SUITE = require "TestSuite" 'assert-gooder'
TEST_SUITE:addModules 'test/test_*.lua' TEST_SUITE:addModules 'test/test_*.lua'
TEST_SUITE:enableStrictGlobal()
TEST_SUITE:runTests() TEST_SUITE:runTests()