diff --git a/assert-gooder.lua b/assert-gooder.lua index d060d0e..8af2c3d 100644 --- a/assert-gooder.lua +++ b/assert-gooder.lua @@ -154,7 +154,7 @@ local function get_module_filetext (module_filepath) -- Just attempt standard file open local filehandle = io.open(module_filepath, 'r') if filehandle then - filetext = filehandle:read '*all' + local filetext = filehandle:read '*all' filehandle:close() return filetext end @@ -167,6 +167,26 @@ local function get_module_filetext (module_filepath) return nil end +local function seperate_by_toplevel_commas (text) + assert(type(text) == 'string') + local section_start, index, sections = 1, 1, {} + while index < #text do + local next_comma = text:find(',', index) + local next_par_start, next_par_end = text:find('%b()', index) + if not next_comma then + break + elseif not next_par_start or next_comma < next_par_start then + sections[#sections+1] = text:sub(section_start, next_comma - 1) + index = next_comma + 1 + section_start = index + else + index = next_par_end + 1 + end + end + sections[#sections+1] = text:sub(section_start) + return sections +end + local function get_assert_body_text (call_info) if call_info.what == 'Lua' or call_info.what == 'main' then -- Find filetext @@ -190,7 +210,9 @@ local function get_assert_body_text (call_info) end end -- Find body exclusively. - return table.concat(lines_after, '\n'):match('assert%s*(%b())'):sub(2, -2) + local assert_arguments_text = table.concat(lines_after, '\n'):match('assert%s*(%b())'):sub(2, -2) + local assert_arguments = seperate_by_toplevel_commas(assert_arguments_text) + return assert_arguments[1] end error 'Not implemented yet!' @@ -365,7 +387,7 @@ local function determine_error_message (call_info, msg, condition) end end -return function (condition) +return function (condition, format, ...) if condition then return condition end -- local level = 2 @@ -383,15 +405,26 @@ return function (condition) if not success then io.stderr:write(('[assert-gooder/internal]: Internal error occured while determining error message for calling assert:\n %s\n'):format(internal_error_msg)) end - -- + + -- Format error message: assert(#msg_container <= 2 and type(msg_container[1]) == 'string') - local l = {'assertion failed! ', msg_container[1]} + local l = {} + if format ~= nil then + l[#l+1] = (type(format) == 'string') and format:format(...) or tostring(format) + l[#l+1] = ':' + else + l[#l+1] = 'assertion failed!' + end + l[#l+1] = ' ' + l[#l+1] = msg_container[1] if msg_container[2] then assert(type(msg_container[2]) == 'string') - l[3] = ' (' - l[4] = msg_container[2] - l[5] = ')' + l[#l+1] = ' (' + l[#l+1] = msg_container[2] + l[#l+1] = ')' end + + -- Throw error message error(table.concat(l, ''), 2) end diff --git a/test/test_assert-gooder.lua b/test/test_assert-gooder.lua index f20c62e..eb5ba4d 100644 --- a/test/test_assert-gooder.lua +++ b/test/test_assert-gooder.lua @@ -303,6 +303,27 @@ SUITE:addTest('Identify odd number', function () assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! bad local \'a\' (odd number expected, but got even number 5.21)', msg) end) +-------------------------------------------------------------------------------- +-- Custom error message + +SUITE:addTest('Custom error message', function () + local _, msg = pcall(function () + local a = 2 + assert(type(a) == 'string', 'expected string') + end) + assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'expected string: bad local \'a\' (string expected, but got number 2)', msg) +end) + +SUITE:addTest('Custom formatted message', function () + local _, msg = pcall(function () + local a = 2 + assert(type(a) == 'string', 'expected string not %s', type(a)) + end) + assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'expected string not number: bad local \'a\' (string expected, but got number 2)', msg) +end) + + + --------------------------------------------------------------------------------