From 6f3da373f934639174d5dee725895b606b4b90cf Mon Sep 17 00:00:00 2001 From: Jon Michael Aanes Date: Thu, 2 Nov 2017 11:21:00 +0100 Subject: [PATCH] Moved to using an AST based approach. The AST construction is not very advanced yet, but no regressions have occured. --- assert-gooder.lua | 146 ++++++++++++++++++++++++------------ test/test_assert-gooder.lua | 36 ++++++++- 2 files changed, 134 insertions(+), 48 deletions(-) diff --git a/assert-gooder.lua b/assert-gooder.lua index 7405e37..7712233 100644 --- a/assert-gooder.lua +++ b/assert-gooder.lua @@ -3,6 +3,71 @@ local lexer = require 'lua_lang' -------------------------------------------------------------------------------- + +local function get_value_of_string (string_str) + if string_str:sub(1, 1) == '"' or string_str:sub(1, 1) == '\'' then + return string_str:sub(2, -2) + end + assert(false) +end + +local CONSTANT_VALUE_TOKEN = { + NUMBER = tonumber, + STRING = get_value_of_string, + TRUE = function() return true end, + FALSE = function() return false end, + NIL = function() return nil end +} + +local VALUE_TOKEN = { IDENTIFIER = true } +for k in pairs(CONSTANT_VALUE_TOKEN) do VALUE_TOKEN[k] = true end + +local COMPARE_BINOP = { + EQ = true, + NEQ = true, + LEQ = true, + GEQ = true, + LE = true, + GT = true, +} + +local function get_value_token (token) + if CONSTANT_VALUE_TOKEN[token.token]then + return { exp = token.token, value = CONSTANT_VALUE_TOKEN[token.token](token.text) } + elseif token.token == 'IDENTIFIER' then + return { exp = 'LVALUE', token.text } + end + assert(false) +end + +local function parse (tokens) + -- TODO: Make a more general parser + assert(type(tokens) == 'table') + -- + + if #tokens == 6 and tokens[1].text == 'type' and tokens[2].token == 'LPAR' and tokens[3].token == 'IDENTIFIER' and tokens[4].token == 'RPAR'and tokens[5].token == 'EQ'and tokens[6].token == 'STRING' then + return { + exp = 'COMPARE', + binop = 'EQ', + left = { exp = 'CALL', callee = get_value_token(tokens[1]), get_value_token(tokens[3]) }, + right = get_value_token(tokens[6]), + } + elseif #tokens == 3 and VALUE_TOKEN[tokens[1].token] and COMPARE_BINOP[tokens[2].token] and VALUE_TOKEN[tokens[3].token] then + return { + exp = 'COMPARE', + binop = tokens[2].token, + left = get_value_token(tokens[1]), + right = get_value_token(tokens[3]) + } + elseif #tokens == 1 then + return get_value_token(tokens[1]) + else + print(require'pretty'(tokens)) + --assert(false) + return nil + end +end + local function get_assert_body_text (call_info) if call_info.what == 'Lua' then -- Find filetext @@ -78,13 +143,6 @@ local function get_function_name (call_info) return string.format('the anonymous function %s', where) end -local function get_value_of_string (string_str) - if string_str:sub(1, 1) == '"' or string_str:sub(1, 1) == '\'' then - return string_str:sub(2, -2) - end - assert(false) -end - local function fmt_val (val) if type(val) == 'string' then return string.format('%q', val) @@ -93,23 +151,21 @@ local function fmt_val (val) end end -local function get_variable_and_prefix (gotten_name, level) - assert(type(gotten_name) == 'string') - assert(type(level) == 'number') - -- - local gotten_val, var_scope, in_func = get_variable(gotten_name, level + 1) - local func_name = in_func and (' to '..get_function_name(debug.getinfo(level + 1))) or '' - return gotten_val, ('assertion failed! bad %s \'%s\'%s'):format(var_scope, gotten_name, func_name) +local function fmt_lvalue (lvalue) + assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE') + return table.concat(lvalue, '.') end +local function get_variable_and_prefix (lvalue, level) + assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE') + assert(type(level) == 'number') + -- + local base_value, var_scope, in_func = get_variable(lvalue[1], level + 1) + local value = base_value -- TODO: Generalize to any lvalue + local func_name = in_func and (' to '..get_function_name(debug.getinfo(level + 1))) or '' + return value, ('assertion failed! bad %s \'%s\'%s'):format(var_scope, fmt_lvalue(lvalue), func_name) +end -local CONSTANT_VALUE_TOKEN = { - NUMBER = tonumber, - STRING = get_value_of_string, - TRUE = function() return true end, - FALSE = function() return false end, - NIL = function() return nil end -} local PRIMITIVE_VALUES = { ['nil'] = true, @@ -122,11 +178,6 @@ local COMPLEX_TYPES = { ['cdata'] = true, } -local function get_value_of_const_token (token) - assert(CONSTANT_VALUE_TOKEN[token.token]) - return CONSTANT_VALUE_TOKEN[token.token](token.text) -end - local function fmt_with_type (val) -- Primitive values ARE their type, and don't need the annotation. if PRIMITIVE_VALUES[type(val)] then return tostring(val) end @@ -142,32 +193,35 @@ return function (condition) if condition then return condition end local call_info = debug.getinfo(2) local tokens, body_text = get_assert_body(call_info) + local ast = parse(tokens) - if #tokens == 6 and tokens[1].text == 'type' and tokens[2].token == 'LPAR' and tokens[3].token == 'IDENTIFIER' and tokens[4].token == 'RPAR'and tokens[5].token == 'EQ'and tokens[6].token == 'STRING' then - local gotten_val, prefix = get_variable_and_prefix(tokens[3].text, 2) - error(('%s (%s expected, but got %s: %s)'):format(prefix, get_value_of_string(tokens[6].text), type(gotten_val), fmt_val(gotten_val)), 2) - elseif #tokens == 3 and tokens[1].token == 'IDENTIFIER' and tokens[2].token == 'EQ' and CONSTANT_VALUE_TOKEN[tokens[3].token] then - local gotten_val, prefix = get_variable_and_prefix(tokens[1].text, 2) + if ast == nil then + error(('assertion failed! expression `%s` evaluated to %s'):format(body_text, condition), 2) - local expected_value = get_value_of_const_token(tokens[3]) - local type_annotation = (type(expected_value) == type(gotten_val)) and '' or (' '..type(gotten_val)) + elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' and ast.left.exp == 'CALL' and ast.left.callee.exp == 'LVALUE' and ast.left.callee[1] == 'type' then + local gotten_val, prefix = get_variable_and_prefix(ast.left[1], 2) + error(('%s (%s expected, but got %s: %s)'):format(prefix, ast.right.value, type(gotten_val), fmt_val(gotten_val)), 2) - error(('%s (%s expected, but got%s: %s)'):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_val)), 2) - elseif #tokens == 3 and tokens[1].token == 'IDENTIFIER' and tokens[2].token == 'NEQ' and CONSTANT_VALUE_TOKEN[tokens[3].token] then - local gotten_val, prefix = get_variable_and_prefix(tokens[1].text, 2) - - local expected_value = get_value_of_const_token(tokens[3]) + elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' then + local gotten_value, prefix = get_variable_and_prefix(ast.left, 2) + local expected_value = ast.right.value + local type_annotation = (type(expected_value) == type(gotten_value)) and '' or (' '..type(gotten_value)) + error(('%s (%s expected, but got%s: %s)'):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_value)), 2) + elseif ast.exp == 'COMPARE' and ast.binop == 'NEQ' then + local gotten_val, prefix = get_variable_and_prefix(ast.left, 2) + local expected_value = ast.right.value error(('%s (expected anything other than %s, but got %s)'):format(prefix, fmt_with_type(expected_value), fmt_val(gotten_val)), 2) - elseif #tokens == 1 and tokens[1].token == 'IDENTIFIER' then - local gotten_val, prefix = get_variable_and_prefix(tokens[1].text, 2) + + elseif ast.exp == 'LVALUE' then + local gotten_val, prefix = get_variable_and_prefix(ast, 2) error(('%s (truthy expected, but got %s)'):format(prefix, fmt_val(gotten_val)), 2) - elseif #tokens == 1 and (tokens[1].token == 'NIL' or tokens[1].token == 'FALSE') then + + elseif CONSTANT_VALUE_TOKEN[ast.exp] then local func_name = get_function_name(call_info) error(('assertion failed! this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name), 2) - elseif #tokens == 1 then - error 'should be unreachable!' - else - error(('assertion failed! expression `%s` evaluated to %s'):format(body_text, condition), 2) - end + + else + error(('[assert-gooder/internal]: Unknown expression type %s'):format(ast.exp)) + end end diff --git a/test/test_assert-gooder.lua b/test/test_assert-gooder.lua index 002b61b..8652f6c 100644 --- a/test/test_assert-gooder.lua +++ b/test/test_assert-gooder.lua @@ -47,6 +47,31 @@ SUITE:addTest('argument to named function', function () assert_equal('./test/test_assert-gooder.lua:'..curline(-3)..': '..'assertion failed! bad argument #1 \'a\' to \'f\' (string expected, but got number: 2)', msg) end) + +SUITE:addTest('indexing', function () + local _, msg = pcall(function () + local a = { b = 39 } + assert(type(a.b) == 'string') + end) + assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! bad key "b" in local \'a\' (string expected, but got number: 39)', msg) +end) + +SUITE:addTest('subscript constant', function () + local _, msg = pcall(function () + local a = { 4, 2, 3, 6 } + assert(type(a.b) == 'string') + end) + assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! bad key 2 in local \'a\' (string expected, but got number: 2)', msg) +end) + +SUITE:addTest('subscript variable', function () + local _, msg = pcall(function () + local a, i = { 4, 2, 3, 6 }, 2 + assert(type(a.b) == 'string') + end) + assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! bad key 2 in local \'a\' (string expected, but got number: 2)', msg) +end) + -------------------------------------------------------------------------------- SUITE:addTest('can improve asserts in loaded strings too', function () @@ -98,6 +123,15 @@ SUITE:addTest('truthy', function () assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! bad local \'a\' (truthy expected, but got false)', msg) end) +SUITE:addTest('truthy indexing', function () + local _, msg = pcall(function () + local a = { b = false } + assert(a.b) + end) + assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! bad key "b" in local \'a\' (truthy expected, but got false)', msg) +end) + + SUITE:addTest('not equal', function (constant_value, msg_in_pars) local func = loadstring (("return function() local a = %s; assert(a ~= %s) end"):format(constant_value, constant_value)) () local _, msg = pcall(setfenv(func, getfenv())) @@ -124,8 +158,6 @@ SUITE:addTest('constant nil', function () assert_equal('./test/test_assert-gooder.lua:'..curline(-2)..': '..'assertion failed! this assert will always fail, as it\'s body is `nil`. assumingly this should be an unreachable part of the anonymous function at ./test/test_assert-gooder.lua:'..curline(-3), msg) end) - - -------------------------------------------------------------------------------- return SUITE