From b1c6693f1fd2cb532d5963deb702eddc8837369e Mon Sep 17 00:00:00 2001 From: Jon Michael Aanes Date: Sun, 5 Nov 2017 11:31:46 +0100 Subject: [PATCH] Improved robustness. --- .gitignore | 2 + assert-gooder.lua | 78 ++++++++++++++++++++++--------------- init.lua | 2 +- test/test_assert-gooder.lua | 9 +++++ test/tests.lua | 1 + 5 files changed, 60 insertions(+), 32 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2f0cada --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +# Ignore editor files +.*.swp diff --git a/assert-gooder.lua b/assert-gooder.lua index fb8ffa3..a153539 100644 --- a/assert-gooder.lua +++ b/assert-gooder.lua @@ -65,14 +65,12 @@ local function parse (tokens) elseif #tokens == 1 then return get_value_token(tokens[1]) else - print(require'pretty'(tokens)) - --assert(false) - return nil + error 'Unknown AST structure!' end end local function get_assert_body_text (call_info) - if call_info.what == 'Lua' then + if call_info.what == 'Lua' or call_info.what == 'main' then -- Find filetext local filetext = nil if call_info.source:find '^@' then @@ -105,21 +103,25 @@ local function get_assert_body (call_info) return lexer:lex(text), text end -local function get_variable (var_name, level) +local function get_variable (var_name, info, level) + -- + assert(type(var_name) == 'string') + assert(type(info) == 'table') + assert(type(level) == 'number') + -- Local - local index = 0 + local index, func = 0, info.func repeat index = index + 1 local name, val = debug.getlocal(level + 1, index) if name == var_name then - local info = debug.getinfo(level + 1) local is_par = index <= info.nparams return val, is_par and ('argument #'..index) or 'local', info.name or is_par and '' end until not name -- Up-value - local index, func = 0, debug.getinfo(level + 1).func + local index = 0 repeat index = index + 1 local name, val = debug.getupvalue(func, index) @@ -127,7 +129,7 @@ local function get_variable (var_name, level) until not name -- Global - return getfenv(level + 1)[var_name], 'global' + return getfenv(func)[var_name], 'global' end local function get_function_name (call_info) @@ -162,17 +164,17 @@ local function fmt_lvalue (lvalue, var_scope) end end -local function get_variable_and_prefix (lvalue, level) +local function get_variable_and_prefix (lvalue, call_info, level) assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE') assert(type(level) == 'number') -- - local base_value, var_scope, in_func = get_variable(lvalue[1], level + 1) - -- Determine value of variable - local value = base_value - for i = 2, #lvalue do value = value[lvalue[i].value] end - -- - local func_name = in_func and (' to '..get_function_name(debug.getinfo(level + 1))) or '' - return value, ('assertion failed! bad %s%s'):format(fmt_lvalue(lvalue, var_scope), func_name) + local base_value, var_scope, in_func = get_variable(lvalue[1], call_info, level + 1) + -- Determine value of variable + local value = base_value + for i = 2, #lvalue do value = value[lvalue[i].value] end + -- + local func_name = in_func and (' to '..get_function_name(call_info)) or '' + return value, ('bad %s%s'):format(fmt_lvalue(lvalue, var_scope), func_name) end @@ -198,39 +200,53 @@ end -------------------------------------------------------------------------------- -return function (condition) - if condition then return condition end - local call_info = debug.getinfo(2) +local function determine_error_message (call_info, msg, level, condition) local tokens, body_text = get_assert_body(call_info) local ast = parse(tokens) - if ast == nil then - error(('assertion failed! expression `%s` evaluated to %s'):format(body_text, condition), 2) + msg[1] = ('expression `%s` evaluated to %s'):format(body_text, condition) + local var_prefix = function(token) return get_variable_and_prefix(token, call_info, level + 2) end + if not ast then return nil elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' and ast.left.exp == 'CALL' and ast.left.callee.exp == 'LVALUE' and ast.left.callee[1] == 'type' then - local gotten_val, prefix = get_variable_and_prefix(ast.left[1], 2) - error(('%s (%s expected, but got %s: %s)'):format(prefix, ast.right.value, type(gotten_val), fmt_val(gotten_val)), 2) + local gotten_val, prefix = var_prefix(ast.left[1]) + msg[1] = ('%s (%s expected, but got %s: %s)'):format(prefix, ast.right.value, type(gotten_val), fmt_val(gotten_val)) elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' then - local gotten_value, prefix = get_variable_and_prefix(ast.left, 2) + local gotten_value, prefix = var_prefix(ast.left) local expected_value = ast.right.value local type_annotation = (type(expected_value) == type(gotten_value)) and '' or (' '..type(gotten_value)) - error(('%s (%s expected, but got%s: %s)'):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_value)), 2) + msg[1] = ('%s (%s expected, but got%s: %s)'):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_value)) elseif ast.exp == 'COMPARE' and ast.binop == 'NEQ' then - local gotten_val, prefix = get_variable_and_prefix(ast.left, 2) + local gotten_val, prefix = var_prefix(ast.left) local expected_value = ast.right.value - error(('%s (expected anything other than %s, but got %s)'):format(prefix, fmt_with_type(expected_value), fmt_val(gotten_val)), 2) + msg[1] = ('%s (expected anything other than %s, but got %s)'):format(prefix, fmt_with_type(expected_value), fmt_val(gotten_val)) elseif ast.exp == 'LVALUE' then - local gotten_val, prefix = get_variable_and_prefix(ast, 2) - error(('%s (truthy expected, but got %s)'):format(prefix, fmt_val(gotten_val)), 2) + local gotten_val, prefix = var_prefix(ast) + msg[1] = ('%s (truthy expected, but got %s)'):format(prefix, fmt_val(gotten_val)) elseif CONSTANT_VALUE_TOKEN[ast.exp] then local func_name = get_function_name(call_info) - error(('assertion failed! this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name), 2) + msg[1] = ('this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name) else error(('[assert-gooder/internal]: Unknown expression type %s'):format(ast.exp)) end end + +return function (condition) + if condition then return condition end + local level = 2 + local call_info = debug.getinfo(level) + local msg_container = {''} + local success, internal_error_msg = pcall(determine_error_message, call_info, msg_container, level, condition) + -- Handle internal errors + if not success then + io.stderr:write(('[assert-gooder/internal]: Internal error occured while determining error message for calling assert:\n %s'):format(internal_error_msg)) + end + -- + error(('assertion failed! %s'):format(msg_container[1]), 2) +end + diff --git a/init.lua b/init.lua index c4bab98..a124f7d 100644 --- a/init.lua +++ b/init.lua @@ -1,2 +1,2 @@ -return assert(require((... and select('1', ...):match('.+%')..'.' or '')..'assert-gooder'), '[assert-gooder]: Could not load vital library: assert-gooder') +return assert(require((... and select('1', ...):match('.+')..'.' or '')..'assert-gooder'), '[assert-gooder]: Could not load vital library: assert-gooder') diff --git a/test/test_assert-gooder.lua b/test/test_assert-gooder.lua index 4138677..b97d772 100644 --- a/test/test_assert-gooder.lua +++ b/test/test_assert-gooder.lua @@ -174,6 +174,15 @@ SUITE:addTest('constant nil', function () assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! this assert will always fail, as it\'s body is `nil`. assumingly this should be an unreachable part of the anonymous function at ./test/test_assert-gooder.lua:'..curline(-3), msg) end) +SUITE:addTest('function as type argument', function () + local _, msg = pcall(function () + local f = function() end + assert(type(f) == 'string') + end) + -- TODO: How do we test this? + assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! bad local \'f\' (expected string, but got function)', msg) +end) + -------------------------------------------------------------------------------- return SUITE diff --git a/test/tests.lua b/test/tests.lua index ea4acd9..17bce2c 100644 --- a/test/tests.lua +++ b/test/tests.lua @@ -3,4 +3,5 @@ local TEST_SUITE = require "TestSuite" 'assert-gooder' TEST_SUITE:addModules 'test/test_*.lua' + TEST_SUITE:enableStrictGlobal() TEST_SUITE:runTests()