2017-10-28 10:25:10 +00:00
2017-11-04 12:45:28 +00:00
-- Load the `lua_lang` lexer that lives next to this module.
-- When this file is loaded through `require`, `...` carries the name it was
-- required under; everything up to and including the last dot of that name is
-- reused as a prefix, so `pkg.sub.this_module` looks up `pkg.sub.lua_lang`.
-- Without a module name the lexer is required by its bare name.
local lexer = assert(
  require((... and select(1, ...):match('.+%.') or '') .. 'lua_lang'),
  ' [assert-gooder]: Could not load vital library: lua_lang '
)
2017-10-28 10:25:10 +00:00
--------------------------------------------------------------------------------
2017-11-02 10:21:00 +00:00
-- Strips the delimiters from the source text of a quoted string literal.
--
-- string_str: the literal exactly as written in source, including its
--             opening and closing quote, e.g. `"table"` or `'x'`.
-- Returns the text between the first and last character. Escape sequences
-- are NOT decoded; the text is returned as written.
-- Raises an error for any text that does not start with `"` or `'`
-- (long-bracket strings, for instance, are not handled here).
local function get_value_of_string(string_str)
  local opening_quote = string_str:sub(1, 1)
  if opening_quote == '"' or opening_quote == "'" then
    return string_str:sub(2, -2)
  end
  error('[assert-gooder/internal]: Unsupported string literal: ' .. tostring(string_str))
end
-- Converters from the source text of a constant token to the Lua value it
-- denotes, keyed by token type. Each converter is called with the token's
-- text; the keyword converters ignore it.
local CONSTANT_VALUE_TOKEN = {}
CONSTANT_VALUE_TOKEN.NUMBER = tonumber
CONSTANT_VALUE_TOKEN.STRING = get_value_of_string
CONSTANT_VALUE_TOKEN.TRUE = function() return true end
CONSTANT_VALUE_TOKEN.FALSE = function() return false end
CONSTANT_VALUE_TOKEN.NIL = function() return nil end

-- Set of token types that can stand on their own as a value: identifiers,
-- plus every constant token type registered above.
local VALUE_TOKEN = {}
VALUE_TOKEN.IDENTIFIER = true
for token_type, _ in pairs(CONSTANT_VALUE_TOKEN) do
  VALUE_TOKEN[token_type] = true
end
-- Set of token types treated as comparison operators by `parse`.
-- NOTE(review): `LE` is listed next to `GT` while `LEQ`/`GEQ` are also
-- present; presumably `LE` is the lexer's name for `<` - confirm against the
-- lexer's token table.
local COMPARE_BINOP = {}
for _, operator_token in ipairs { 'EQ', 'NEQ', 'LEQ', 'GEQ', 'LE', 'GT' } do
  COMPARE_BINOP[operator_token] = true
end
-- Turns a single value token into an expression node.
--
-- token: a lexer token; this function reads `token.token` (the token type)
--        and `token.text` (the source text).
-- Returns:
--   * for a constant token type (a key of CONSTANT_VALUE_TOKEN):
--     `{ exp = <token type>, value = <converted Lua value> }`
--   * for an identifier: `{ exp = ' LVALUE ', <identifier text> }`
-- Raises an error for any other token type.
local function get_value_token(token)
  local token_type = token.token
  local to_value = CONSTANT_VALUE_TOKEN[token_type]
  if to_value then
    return { exp = token_type, value = to_value(token.text) }
  end
  -- Token types are unpadded names, the same ones used as keys of VALUE_TOKEN.
  if token_type == 'IDENTIFIER' then
    -- The tag literal is matched verbatim by the consumers of these nodes in
    -- this file, so it has to stay exactly as written.
    return { exp = ' LVALUE ', token.text }
  end
  error('[assert-gooder/internal]: Token is not a value: ' .. tostring(token_type))
end
-- Recognises a handful of fixed token shapes and builds an expression node
-- for them.
--
-- tokens: sequence of lexer tokens (tables with `token` and `text` fields)
--         covering the body of an assert call.
-- Returns one of the following, or nil when the shape is not recognised:
--   * `type(x) == "..."`  -> COMPARE node whose `left` is a CALL node
--   * `<value> <cmp> <value>` -> COMPARE node with `binop` = operator token type
--   * `a.b`               -> LVALUE node `{ 'a', { exp = 'STRING', value = 'b' } }`
--   * `a[<constant>]`     -> LVALUE node `{ 'a', <constant node> }`
--   * a single value token -> the node built by `get_value_token`
-- The ' COMPARE ', ' CALL ' and ' LVALUE ' tag literals are matched verbatim
-- elsewhere in this file and are therefore kept exactly as written.
local function parse(tokens)
  -- TODO: Make a more general parser
  assert(type(tokens) == 'table')

  local count = #tokens
  local function kind(i) return tokens[i].token end

  -- type(<identifier>) == <string>
  if count == 6
     and tokens[1].text == 'type'
     and kind(2) == 'LPAR' and kind(3) == 'IDENTIFIER' and kind(4) == 'RPAR'
     and kind(5) == 'EQ' and kind(6) == 'STRING' then
    return {
      exp = ' COMPARE ',
      binop = 'EQ',
      left = { exp = ' CALL ', callee = get_value_token(tokens[1]), get_value_token(tokens[3]) },
      right = get_value_token(tokens[6]),
    }
  end

  if count == 3 then
    -- <value> <comparison operator> <value>
    if VALUE_TOKEN[kind(1)] and COMPARE_BINOP[kind(2)] and VALUE_TOKEN[kind(3)] then
      return {
        exp = ' COMPARE ',
        binop = kind(2),
        left = get_value_token(tokens[1]),
        right = get_value_token(tokens[3]),
      }
    end
    -- <identifier> . <identifier>: the field name acts as a string key.
    if kind(1) == 'IDENTIFIER' and kind(2) == 'DOT' and kind(3) == 'IDENTIFIER' then
      return { exp = ' LVALUE ', tokens[1].text, { exp = 'STRING', value = tokens[3].text } }
    end
  end

  -- <identifier> [ <constant> ]: only constant keys are accepted, because the
  -- consumers of LVALUE nodes read the key from its `value` field, which an
  -- identifier node does not have.
  if count == 4
     and kind(1) == 'IDENTIFIER' and kind(2) == 'LBRACK'
     and CONSTANT_VALUE_TOKEN[kind(3)] and kind(4) == 'RBRACK' then
    return { exp = ' LVALUE ', tokens[1].text, get_value_token(tokens[3]) }
  end

  -- A lone value token.
  if count == 1 and VALUE_TOKEN[kind(1)] then
    return get_value_token(tokens[1])
  end

  -- Unrecognised shape: dump the tokens as a debugging aid when the optional
  -- `pretty` module is available, then let the caller fall back.
  local has_pretty, pretty = pcall(require, 'pretty')
  if has_pretty then print(pretty(tokens)) end
  return nil
end
2017-10-28 10:25:10 +00:00
-- Extracts the source text between the parentheses of the assert call that
-- is executing on `call_info.currentline`.
--
-- call_info: a `debug.getinfo` table; reads `what`, `source`, `short_src`
--            and `currentline`.
-- Returns the text inside the outermost parentheses of the first
-- `assert(...)` found on that line. Only that single line is inspected, so a
-- call spanning several lines is not recovered.
-- Raises an error when the caller is not a Lua function, when the chunk is
-- neither a file nor a loaded string, when the file cannot be read, or when
-- no `assert(...)` call is found on the line.
local function get_assert_body_text(call_info)
  if call_info.what ~= 'Lua' then
    error ' Not implemented yet! '
  end

  -- Find filetext
  local filetext
  if call_info.source:find('^@') then
    -- `source` is '@' followed by the file name. `short_src` is a shortened
    -- form meant for messages and may be truncated, so it is not used as a path.
    local file, open_err = io.open(call_info.source:sub(2), 'r')
    if not file then
      error('[assert-gooder/internal]: Could not read source file: ' .. tostring(open_err))
    end
    filetext = file:read('*a')
    file:close()
  elseif call_info.short_src:find('^%[string') then
    -- Chunk loaded from a string: `source` is the chunk text itself.
    filetext = call_info.source
  else
    error ' Not implemented yet! '
  end

  -- Get lines. CRLF and lone CR are folded into LF first, so each line break
  -- advances the line counter exactly once.
  filetext = filetext:gsub('\r\n', '\n'):gsub('\r', '\n')
  filetext = filetext .. '\n'
  local current_line, line_i = nil, 0
  for line in filetext:gmatch('([^\n]*)\n') do
    line_i = line_i + 1
    if line_i == call_info.currentline then
      current_line = line
      break
    end
  end

  -- Find body exclusively: %b() captures the balanced parentheses, and the
  -- final sub() drops the outer pair.
  local call_text = current_line and current_line:match('assert%s*(%b())')
  if not call_text then
    error('[assert-gooder/internal]: Could not find an assert call on line '
      .. tostring(call_info.currentline))
  end
  return call_text:sub(2, -2)
end
-- Fetches the text of the failing assert's argument and runs it through the
-- lexer.
--
-- call_info: passed straight on to `get_assert_body_text`.
-- Returns two values: the first result of `lexer:lex` on the body text, and
-- the body text itself.
local function get_assert_body(call_info)
  local body_text = get_assert_body_text(call_info)
  local tokens = lexer:lex(body_text)
  return tokens, body_text
end
-- Looks up a variable by name as seen from a function on the call stack,
-- trying locals, then upvalues, then globals.
--
-- var_name: name of the variable.
-- level:    stack level of the function to inspect, counted from the CALLER
--           of get_variable (1 = the function that called get_variable).
-- Returns: value, scope description, and - for locals only - a third value:
-- the function's name when debug info provides one, otherwise ' ' for a
-- parameter and false for a plain local. The third value is absent for
-- upvalues and globals.
--
-- The debug calls below are made directly from this function: wrapping them
-- in a helper would shift the stack levels they refer to.
local function get_variable(var_name, level)
  local info = debug.getinfo(level + 1)

  -- Local. Keep scanning after a hit: when a name is declared more than once,
  -- the later declaration shadows the earlier one, so the last active match
  -- is the variable the code actually sees.
  local found_index, found_value
  local index = 1
  while true do
    local name, val = debug.getlocal(level + 1, index)
    if not name then break end
    if name == var_name then
      found_index, found_value = index, val
    end
    index = index + 1
  end
  if found_index then
    -- `nparams` is not provided by every Lua implementation; without it the
    -- variable is simply reported as a local.
    local is_par = info.nparams ~= nil and found_index <= info.nparams
    return found_value, is_par and (' argument # ' .. found_index) or ' local ', info.name or is_par and ' '
  end

  -- Up-value
  local func = info.func
  local upvalue_index = 1
  while true do
    local name, val = debug.getupvalue(func, upvalue_index)
    if not name then break end
    if name == var_name then return val, ' upvalue ' end
    upvalue_index = upvalue_index + 1
  end

  -- Global. `getfenv` only exists on Lua 5.1 / LuaJIT; fall back to the
  -- global table elsewhere.
  local env = getfenv and getfenv(level + 1) or _G
  return env[var_name], ' global '
end
2017-10-29 09:16:16 +00:00
-- Builds a human-readable description of a function for use in messages.
--
-- call_info: a `debug.getinfo` table; reads `name`, `source`, `short_src`
--            and `linedefined`.
-- Returns the quoted name when debug info supplies one. Otherwise describes
-- an anonymous function by where it was defined: file and line for chunks
-- loaded from a file, or a fixed phrase for chunks loaded from a string.
-- Raises an error for any other kind of chunk.
local function get_function_name(call_info)
  if call_info.name then
    return string.format(' \' %s \' ', call_info.name)
  end

  local where
  if call_info.source:find('^@') then
    -- File chunks have a `source` that starts with '@'.
    where = ' at ' .. call_info.short_src .. ' : ' .. call_info.linedefined
  elseif call_info.short_src:find('^%[string') then
    where = ' from loaded string '
  else
    error ' not yet implemented '
  end

  return string.format(' the anonymous function %s ', where)
end
2017-10-28 11:15:57 +00:00
-- Formats a value for inclusion in an error message.
--
-- val: any Lua value.
-- Returns strings rendered with `%q` (quoted and escaped) so they can be
-- told apart from other values; everything else goes through `tostring`.
local function fmt_val(val)
  if type(val) == 'string' then
    return string.format(' %q ', val)
  end
  return tostring(val)
end
2017-11-02 10:35:49 +00:00
-- Describes an LVALUE node for use in an error message.
--
-- lvalue:    LVALUE node; element 1 is the variable name and an optional
--            element 2 is a key node whose `value` field is the key.
-- var_scope: scope description placed in front of the variable name.
-- Returns the description for a plain variable (one element) or for a single
-- key lookup (two elements).
-- Raises an error for deeper lookups, and fails its precondition assert when
-- `lvalue` is not an LVALUE node.
local function fmt_lvalue(lvalue, var_scope)
  -- The ' LVALUE ' tag is compared verbatim against what the parser emits.
  assert(type(lvalue) == 'table' and lvalue.exp == ' LVALUE ')

  local depth = #lvalue
  if depth == 1 then
    return string.format(' %s \' %s \' ', var_scope, lvalue[1])
  elseif depth == 2 then
    return string.format(' key %s in %s \' %s \' ', fmt_val(lvalue[2].value), var_scope, lvalue[1])
  end
  error ' Not implemented yet! '
end
-- Resolves an LVALUE node to its current value and builds the leading part
-- of the failure message that names the offending variable.
--
-- lvalue: LVALUE node; element 1 is the variable name, further elements are
--         key nodes whose `value` field is applied as a lookup in turn.
-- level:  stack level of the function owning the variable, counted from the
--         CALLER of this function.
-- Returns: the resolved value, and the message prefix.
--
-- The `get_variable` and `debug.getinfo` calls are made directly from this
-- function because the levels passed to them are relative to it.
local function get_variable_and_prefix(lvalue, level)
  assert(type(lvalue) == 'table' and lvalue.exp == ' LVALUE ')
  assert(type(level) == 'number')

  local base_value, var_scope, in_func = get_variable(lvalue[1], level + 1)

  -- Determine value of variable. A lookup on something that cannot be
  -- indexed (for example a nil base) is reported as nil rather than raising
  -- from inside the assertion machinery.
  local value = base_value
  for i = 2, #lvalue do
    local key = lvalue[i].value
    local ok, field = pcall(function() return value[key] end)
    if not ok then
      value = nil
      break
    end
    value = field
  end

  local func_name = in_func and (' to ' .. get_function_name(debug.getinfo(level + 1))) or ' '
  return value, (' assertion failed! bad %s%s '):format(fmt_lvalue(lvalue, var_scope), func_name)
end
-- Type names (as returned by `type`) whose values need no type annotation.
local PRIMITIVE_VALUES = {
  ['nil'] = true,
  ['boolean'] = true,
}

-- Type names (as returned by `type`) whose `tostring` form is used as-is.
local COMPLEX_TYPES = {
  ['table'] = true,
  ['userdata'] = true,
  ['cdata'] = true,
}

-- Formats a value together with its type name where that adds information.
--
-- val: any Lua value.
-- Returns `tostring(val)` for the types listed in PRIMITIVE_VALUES and
-- COMPLEX_TYPES, and `<type name> <formatted value>` for everything else.
local function fmt_with_type(val)
  local type_name = type(val)
  -- Primitive values ARE their type, and don't need the annotation.
  if PRIMITIVE_VALUES[type_name] then return tostring(val) end
  -- Complex types are already formatted with some type information.
  if COMPLEX_TYPES[type_name] then return tostring(val) end
  -- Numbers and string should have their types with them.
  return type_name .. ' ' .. fmt_val(val)
end
2017-10-28 10:25:10 +00:00
--------------------------------------------------------------------------------
-- The module's export: a replacement for `assert` that inspects the source
-- of the failing call to produce a more specific error message.
--
-- condition: the value being asserted.
-- Returns `condition` unchanged when it is truthy.
-- Otherwise raises an error blamed on the caller (level 2). The message is
-- specialised for these shapes of the asserted expression:
--   * `type(x) == "<name>"`
--   * `<variable> == <constant>` and `<variable> ~= <constant>`
--   * a bare variable, `a.b` or `a[<constant>]`
--   * a bare constant
-- Any other expression gets a generic message quoting the expression text.
--
-- `debug.getinfo(2)`, the `2` passed to `get_variable_and_prefix` and the `2`
-- passed to `error` all refer to the function that called this one, so those
-- calls must stay directly in this function body.
return function(condition)
  if condition then return condition end

  local call_info = debug.getinfo(2)
  local tokens, body_text = get_assert_body(call_info)
  local ast = parse(tokens)

  -- tostring() keeps the formatting valid on implementations whose `%s`
  -- does not accept nil or false.
  local generic_message = (' assertion failed! expression `%s` evaluated to %s '):format(body_text, tostring(condition))
  local message

  if ast == nil then
    message = generic_message

  elseif ast.exp == ' COMPARE ' then
    local left, right, binop = ast.left, ast.right, ast.binop
    -- Operator names are the unpadded token types used as keys of COMPARE_BINOP.
    local variable_vs_constant = left.exp == ' LVALUE ' and CONSTANT_VALUE_TOKEN[right.exp] ~= nil

    if binop == 'EQ' and left.exp == ' CALL ' and left.callee.exp == ' LVALUE ' and left.callee[1] == 'type' then
      local gotten_val, prefix = get_variable_and_prefix(left[1], 2)
      message = (' %s (%s expected, but got %s: %s) '):format(prefix, right.value, type(gotten_val), fmt_val(gotten_val))

    elseif binop == 'EQ' and variable_vs_constant then
      local gotten_value, prefix = get_variable_and_prefix(left, 2)
      local expected_value = right.value
      -- Only mention the actual type when it differs from the expected one.
      local type_annotation = (type(expected_value) == type(gotten_value)) and ' ' or (' ' .. type(gotten_value))
      message = (' %s (%s expected, but got%s: %s) '):format(prefix, fmt_with_type(expected_value), type_annotation, fmt_val(gotten_value))

    elseif binop == 'NEQ' and variable_vs_constant then
      local gotten_val, prefix = get_variable_and_prefix(left, 2)
      local expected_value = right.value
      message = (' %s (expected anything other than %s, but got %s) '):format(prefix, fmt_with_type(expected_value), fmt_val(gotten_val))

    else
      -- Other operators, or operands that are not "variable vs constant".
      message = generic_message
    end

  elseif ast.exp == ' LVALUE ' then
    local gotten_val, prefix = get_variable_and_prefix(ast, 2)
    message = (' %s (truthy expected, but got %s) '):format(prefix, fmt_val(gotten_val))

  elseif CONSTANT_VALUE_TOKEN[ast.exp] then
    local func_name = get_function_name(call_info)
    message = (' assertion failed! this assert will always fail, as it \' s body is `%s`. assumingly this should be an unreachable part of %s '):format(body_text, func_name)

  else
    error((' [assert-gooder/internal]: Unknown expression type %s '):format(tostring(ast.exp)))
  end

  error(message, 2)
end