From f8447cdd3994b6ad4b0909e17dfac063fa6711f5 Mon Sep 17 00:00:00 2001 From: Jon Michael Aanes Date: Wed, 14 Mar 2018 12:21:51 +0100 Subject: [PATCH] Added actual parser, that allows parsing more complex expressions --- Parser.lua | 163 ++++++++++++++++++++++++ assert-gooder.lua | 313 +++++++++++++++++++++++++++++++++++++--------- lua_lang.lua | 43 +++---- 3 files changed, 442 insertions(+), 77 deletions(-) create mode 100644 Parser.lua diff --git a/Parser.lua b/Parser.lua new file mode 100644 index 0000000..aad3da8 --- /dev/null +++ b/Parser.lua @@ -0,0 +1,163 @@ + +require 'errors' 'shunt' . enable_strict_globals () + +---- Algorithm + +local DEFAULT_ASSOC = 'left' +local DEFAULT_ARITY = 2 + +local function greater_than (a, b) + assert(type(a) == 'number', a) + assert(type(b) == 'number', b) + + return a > b +end + +local function greater_than_or_equal (a, b) + assert(type(a) == 'number', a) + assert(type(b) == 'number', b) + + return a >= b +end + +local function shunting_yard (tokens, lang, num_arity_to_token) + -- Implementation of the shunting yard algorithm, for transforming + -- infix notation math to postfix notation math. + -- + -- Features extension for variadic first-class functions, which + -- results in slightly complicated output. + -- Whenever a function is called, the output will be formatted + -- like: + -- + -- + -- For example: f(1, 2, 3) + -- Output: "f" 1 2 3 3 $ + -- + -- The exact behavior can be customized. + -- + assert(type(tokens) == 'table' and #tokens > 0) + assert(type(lang) == 'function') + + local pop_i = 1 + local stack, output, call_stack = {}, {}, {} + local prev_info = nil + + for _, val in ipairs(tokens) do + local t_info = lang(val, prev_info) + assert(type(t_info) == 'table' and t_info.val ~= nil) + + if t_info.type == 'imm' then + output[#output+1] = t_info + + -- Binary operators + elseif t_info.type == 'op' then + -- Determine associativity + t_info.associativity = t_info.associativity or DEFAULT_ASSOC + t_info.arity = t_info.arity or DEFAULT_ARITY + assert(type(t_info.precedence) == 'number') + assert(t_info.associativity == 'left' or t_info.associativity == 'right') + assert(t_info.arity == 1 or t_info.arity == 2) + local test = t_info.associativity == 'left' and greater_than_or_equal or greater_than + + -- Pop operations with higher precedence from the stack + if t_info.arity == 1 and t_info.associativity == 'right' then + -- Do not pop operations from the stack + elseif t_info.arity == 2 then + while #stack > 0 and stack[#stack].type == 'op' and test(stack[#stack].precedence, t_info.precedence) do + output[#output+1], stack[#stack] = stack[#stack], nil + end + else + assert(false) + end + + -- Add this operation to the stack + stack[#stack+1] = t_info + + -- Enclosures + elseif t_info.type == 'bracket' then + -- Determine opening and closing brackets + assert(t_info.open ~= nil and t_info.close ~= nil and t_info.open ~= t_info.close) + + -- Are we opening or closing? + if t_info.val == t_info.open then + while #stack > 0 and stack[#stack].type == 'op' and stack[#stack].precedence >= t_info.precedence do + output[#output+1], stack[#stack] = stack[#stack], nil + end + if t_info.on_end_call_op and prev_info and prev_info.type ~= 'op' then + call_stack[#call_stack+1] = { arity = 1, bracket = t_info } + stack[#stack+1] = t_info.on_end_call_op + end + stack[#stack+1] = t_info + + elseif t_info.val == t_info.close then + + -- If closing, pop from stack into output + while #stack > 0 and stack[#stack].val ~= t_info.open do + output[#output+1], stack[#stack] = stack[#stack], nil + end + + -- Remove remaining bracket. + -- Brackets are not needed in postfix notation. + assert(#stack > 0 and stack[#stack].val == t_info.open) + stack[#stack] = nil + + -- If top of stack is a function, we should pop it + if #stack > 0 and stack[#stack].call_op then + assert(#call_stack > 0) + if prev_info.type == 'bracket' and prev_info.open == prev_info.val and t_info.open == prev_info.val then + -- No arguments for this function + call_stack[#call_stack].arity = 0 + end + + -- Pop from stacks and construct output + output[#output+1], call_stack[#call_stack] = num_arity_to_token(call_stack[#call_stack].arity), nil + output[#output+1], stack[#stack] = stack[#stack], nil + end + else + assert(false, require'pretty'(t_info)) + end + + elseif t_info.type == 'arg-sep' then + + -- Ensure that we are in a call + assert(#call_stack > 0 and call_stack[#call_stack].bracket.open == t_info.open) + + -- Pop from stack + while #stack > 0 and stack[#stack].val ~= t_info.open do + output[#output+1], stack[#stack] = stack[#stack], nil + end + + -- Ensure open is on the stack + assert(#stack > 0 and stack[#stack].val == t_info.open) + + -- Increment arity in call-stack + call_stack[#call_stack].arity = call_stack[#call_stack].arity + 1 + + else + assert(false, 'Unknown type: '..tostring(t_info.type)) + end + + prev_info = t_info + + end + + -- Pop remaining from stack onto output + while #stack > 0 do + output[#output+1], stack[#stack] = stack[#stack], nil + end + + -- Find value for all tokens on the stack + for i = 1, #output do + assert(output[i].val and output[i].type ~= 'bracket', require'pretty'(output[i])) + output[i] = output[i].val + end + + -- Post asserts + --assert(type(output) == 'table' and #output <= #tokens) + assert(#stack == 0) + -- + return output +end + +return shunting_yard + diff --git a/assert-gooder.lua b/assert-gooder.lua index c6daa4e..51b727c 100644 --- a/assert-gooder.lua +++ b/assert-gooder.lua @@ -26,27 +26,6 @@ local CONSTANT_VALUE_TOKEN = { NIL = function() return nil end } -local VALUE_TOKEN = { IDENTIFIER = true } -for k in pairs(CONSTANT_VALUE_TOKEN) do VALUE_TOKEN[k] = true end - -local COMPARE_BINOP = { - EQ = true, - NEQ = true, - LEQ = true, - GEQ = true, - LE = true, - GT = true, -} - -local function get_value_token (token) - if CONSTANT_VALUE_TOKEN[token.token]then - return { exp = token.token, value = CONSTANT_VALUE_TOKEN[token.token](token.text) } - elseif token.token == 'IDENTIFIER' then - return { exp = 'LVALUE', token.text } - end - assert(false) -end - local function get_variable (var_name, info) -- assert(type(var_name) == 'string') @@ -71,25 +50,209 @@ local function get_variable (var_name, info) return env, 'global' end +--[[ local function get_value_from_lvalue (lvalue, info) assert(type(lvalue) == 'table') assert(type(info) == 'table') -- Base value - local value, var_scope, in_func = get_variable(lvalue[1], info) + ilocal value, var_scope, in_func = get_variable(lvalue[1], info) -- Sub value for i = 2, #lvalue do value = value[lvalue[i].value] end -- return value, var_scope, in_func end +--]] -------------------------------------------------------------------------------- -- Parsing +local shunting_yard = require 'Parser' + +local NO_PARSE_TOKENS = { + FUNCTION = true, + SEMICOLON = true, + BREAK = true, + DO = true, + ELSE = true, + ELSEIF = true, + END = true, + FOR = true, + IF = true, + IN = true, + LOCAL = true, + REPEAT = true, + RETURN = true, + THEN = true, + UNTIL = true, + WHILE = true, + ASSIGN = true, + + -- TODO: These below should not be NO-PARSE-TOKENS: + VARARG = true, + LBRACE = true, + RBRACE = true, +} + + +local LUA_BINOP = { + ['OR'] = { precedence = 1 }, + ['AND'] = { precedence = 2 }, + ['LE'] = { precedence = 3 }, + ['LEQ'] = { precedence = 3 }, + ['EQ'] = { precedence = 3 }, + ['NEQ'] = { precedence = 3 }, + ['GEQ'] = { precedence = 3 }, + ['GT'] = { precedence = 3 }, + ['CONCAT'] = { precedence = 4, associativity = 'right' }, + ['PLUS'] = { precedence = 5 }, + ['MINUS'] = { precedence = 5 }, + ['TIMES'] = { precedence = 6 }, + ['DIVIDE'] = { precedence = 6 }, + ['MODULO'] = { precedence = 6 }, + ['CARET'] = { precedence = 8, associativity = 'right' }, + + ['DOT'] = { precedence = 10 }, + ['COLON'] = { precedence = 10 }, +} + +local LUA_PREOP = { + ['MINUS'] = { precedence = 7, val = 'UMINUS' }, + ['NOT'] = { precedence = 7 }, + ['HASHTAG'] = { precedence = 7 }, +} + +for keyword, binop in pairs (LUA_BINOP) do + binop.type = 'op' + binop.arity = 2 + binop.val = binop.val or keyword + binop.associativity = binop.associativity or 'left' +end + +for keyword, unop in pairs (LUA_PREOP) do + unop.type = 'op' + unop.arity = 1 + unop.val = unop.val or keyword + unop.associativity = 'right' +end + +local LUA_OP_ARITY = {} + +for _, op in pairs(LUA_BINOP) do LUA_OP_ARITY[op.val] = op.arity end +for _, op in pairs(LUA_PREOP) do LUA_OP_ARITY[op.val] = op.arity end + +local function lua_expression_lang (token, prev_token_info) + local ttype = token.token + local binary_disallowed = prev_token_info == nil or prev_token_info.type == 'bracket' and prev_token_info.open == prev_token_info.val or prev_token_info.type == 'op' + + if LUA_PREOP[ttype] and binary_disallowed then + -- TODO: Unary minus? + local op_info = LUA_PREOP[ttype] + return { type = 'op', arity = op_info.arity, associativity = op_info.associativity, precedence = op_info.precedence, val = token } + elseif LUA_BINOP[ttype] and not binary_disallowed then + local op_info = LUA_BINOP[ttype] + return { type = 'op', arity = op_info.arity, associativity = op_info.associativity, precedence = op_info.precedence, val = token } + elseif LUA_BINOP[ttype] or LUA_PREOP[ttype] then + error('Cannot use the operation '..ttype..' here') + + elseif ttype == 'LBRACK' or ttype == 'RBRACK' then + return { type = 'bracket', val = ttype, open = 'LBRACK', close = 'RBRACK', precedence = 10, on_end_call_op = { type = 'op', val = '?', call_op = true } } + elseif ttype == 'LPAR' or ttype == 'RPAR' then + return { type = 'bracket', val = ttype, open = 'LPAR', close = 'RPAR', precedence = 9, on_end_call_op = { type = 'op', val = '$', call_op = true } } + elseif ttype == 'COMMA' then + return { type = 'arg-sep', val = 'COMMA', open = 'LPAR' } + + elseif type (ttype) == 'number' or type (ttype) == 'string' then + return { type = 'imm', val = token } + end +end + +local function NUM_ARITY (num) + assert(type(num) == 'number') + return { type = 'imm', val = num } +end + local function parse (tokens) - -- TODO: Make a more general parser assert(type(tokens) == 'table') + -- + for i = 2, #tokens do + if tokens[i].token == 'IDENTIFIER' and tokens[i-1].token == 'DOT' then + tokens[i].token = 'STRING' + tokens[i].text = '\''..tokens[i].text..'\'' + end + end + + -- Postfix the tokens + local postfix_tokens = shunting_yard(tokens, lua_expression_lang, NUM_ARITY) + assert(type(postfix_tokens) == 'table') + print(require'pretty'(postfix_tokens)) + + -- Create AST from postfix tokens + local ast_stack = {} + for _, token in ipairs(postfix_tokens) do + if type(token) == 'table' and LUA_OP_ARITY[token.token] then + -- Operation + local node = { exp = 'OP', binop = token.token } + local start_left, end_right = math.huge, -math.huge + -- + for arg_index = LUA_OP_ARITY[token.token], 1, -1 do + local arg = table.remove(ast_stack) + assert(type(arg) == 'table') + start_left, end_right = math.min(start_left, arg.left), math.max(end_right, arg.right) + node[arg_index] = arg + end + -- + node.left, node.right = start_left, end_right + ast_stack[#ast_stack+1] = node + + elseif token == '$' then + -- Operation + local node = { exp = 'CALL' } + local start_left, end_right = math.huge, -math.huge + -- + local num_args = table.remove(ast_stack) + assert(type(num_args) == 'number') + -- + for arg_index = num_args + 1, 1, -1 do + local arg = table.remove(ast_stack) + assert(type(arg) == 'table') + start_left, end_right = math.min(start_left, arg.left), math.max(end_right, arg.right) + node[arg_index] = arg + end + -- + node.left, node.right = start_left, end_right + ast_stack[#ast_stack+1] = node + + elseif token == '?' then + -- Indexing + local node = { exp = 'OP', binop = 'DOT' } + local start_left, end_right = math.huge, -math.huge + -- + local num_args = table.remove(ast_stack) + assert(type(num_args) == 'number' and num_args == 1) + -- + for arg_index = num_args + 1, 1, -1 do + local arg = table.remove(ast_stack) + assert(type(arg) == 'table') + start_left, end_right = math.min(start_left, arg.left), math.max(end_right, arg.right) + node[arg_index] = arg + end + -- + node.left, node.right = start_left, end_right + ast_stack[#ast_stack+1] = node + + else + -- Immediate + ast_stack[#ast_stack+1] = token + + end + end + assert(#ast_stack == 1) + + return ast_stack[1] + + --[[ if #tokens == 6 and tokens[1].text == 'type' and tokens[2].token == 'LPAR' and tokens[3].token == 'IDENTIFIER' and tokens[4].token == 'RPAR'and tokens[5].token == 'EQ'and tokens[6].token == 'STRING' then return { @@ -121,6 +284,7 @@ local function parse (tokens) else io.stderr:write '[assert-gooder/internal]: Unknown AST structure!\n' end + --]] end local function for_each_node_in_ast (ast, func) @@ -140,8 +304,14 @@ local function populate_ast_with_semantics (ast, info) assert(type(ast) == 'table') assert(type(info) == 'table') return for_each_node_in_ast(ast, function(node) - if node.exp == 'LVALUE' then - node.value = get_value_from_lvalue(node, info) + if node.token == 'IDENTIFIER' then + -- TODO: Variable scope, and is it in a function? + node.value, node.scope, node.function_local = get_variable(node.text, info) + elseif CONSTANT_VALUE_TOKEN[node.token] then + node.value = CONSTANT_VALUE_TOKEN[node.token](node.text) + elseif node.exp == 'OP' and node.binop == 'DOT' then + assert(node[1].value and node[2].value) + node.value = node[1].value[ node[2].value ] --TODO end end) end @@ -243,28 +413,41 @@ local function fmt_val (val) end end -local function fmt_lvalue (lvalue, var_scope) - assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE') - -- - local base_var = tostring(lvalue[1]) - if var_scope then - base_var = ('%s \'%s\''):format(var_scope, lvalue[1]) - end - -- - if #lvalue == 1 then return base_var - elseif #lvalue == 2 then return string.format('key %s in %s', fmt_val(lvalue[2].value), base_var) - else error 'Not implemented yet!' +local function fmt_lvalue (node, with_scope) + assert(type(node) == 'table') + + if node.token == 'IDENTIFIER' then + local base = node.text + if with_scope then base = ('%s \'%s\''):format(node.scope, base) end + return base, node.function_local + + elseif node.exp == 'OP' and node.binop == 'DOT' then + local base, function_local = fmt_lvalue(node[1], with_scope) + return ('key %s in %s'):format(fmt_val(node[2].value), base), function_local end + error 'Not implemented yet!' end -local function get_variable_and_prefix (lvalue, call_info) - assert(type(lvalue) == 'table' and lvalue.exp == 'LVALUE') +--[[ +local function ast_to_formal_lvalue (ast) + if ast.token == 'IDENTIFIER' then + return { ast.text, exp = 'LVALUE' } + elseif ast.exp == 'OP' and ast.binop == 'DOT' then + local prev = ast_to_formal_lvalue(ast[1]) + prev[#prev+1] = get_value_token(ast[2]).value + return prev + end + assert(false) +end +--]] + +local function fmt_prefix (ast, call_info) + assert(type(ast) == 'table') -- - local value, var_scope, in_func = get_value_from_lvalue(lvalue, call_info) - -- - local func_name = in_func and (' to '..get_function_name(call_info)) or '' - return value, ('bad %s%s'):format(fmt_lvalue(lvalue, var_scope), func_name) + local name, is_function_local = fmt_lvalue(ast, true) + local func_name = is_function_local and (' to '..get_function_name(call_info)) or '' + return ('bad %s%s'):format(name, func_name) end @@ -373,43 +556,61 @@ local function determine_error_message (call_info, msg, condition) assert(type(ast) == 'table') populate_ast_with_semantics(ast, call_info) + local function is_l_value (ast) + print(ast.exp, ast.binop, ast.exp == 'OP' and ast.binop == 'DOT' ) + if ast.exp == 'OP' and ast.binop == 'DOT' then + print 'Derp?' + return true + elseif ast.token == 'IDENTIFIER' then + return true + end + return false + end + -- Alternative more detailed formatting. -- Identical to last message, but now with values of each involved -- identifier. do local l = {} for_each_node_in_ast(ast, function(node) - if node.exp == 'LVALUE' then - l[#l+1] = fmt_lvalue(node) ..' was ' .. fmt_val_with_type(node.value) + if is_l_value(node) then + local name = fmt_lvalue(node) + if not l[name] then + l[#l+1] = name..' was ' .. fmt_val_with_type(node.value) + l[name] = true + end end end) msg[2] = #l > 0 and table.concat(l,', ') or nil end -- More specific types - local var_prefix = function(token) return get_variable_and_prefix(token, call_info) end + local var_prefix = function(token) + return get_variable_and_prefix(token, call_info) + end if not ast then return nil - elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' and ast[1].exp == 'CALL' and ast[1][1].exp == 'LVALUE' and ast[1][1][1] == 'type' then - local gotten_val, prefix = var_prefix(ast[1][2]) - msg[1], msg[2] = prefix, ('%s expected, but got %s'):format(ast[2].value, fmt_val_with_type(gotten_val)) + elseif ast.exp == 'OP' and ast.binop == 'EQ' and ast[1].exp == 'CALL' and ast[1][1].value == type then + local prefix = fmt_prefix(ast[1][2], call_info) + msg[1], msg[2] = prefix, ('%s expected, but got %s'):format(ast[2].value, fmt_val_with_type(ast[1][2].value)) - elseif ast.exp == 'COMPARE' and ast.binop == 'EQ' then - local gotten_value, prefix = var_prefix(ast[1]) - local expected_value = ast[2].value + elseif ast.exp == 'OP' and ast.binop == 'EQ' then + local prefix = fmt_prefix(ast[1], call_info) + local gotten_value, expected_value = ast[1].value, ast[2].value local fmt_gotten = (type(expected_value) == type(gotten_value)) and fmt_val or fmt_val_with_type msg[1], msg[2] = prefix, ('%s expected, but got %s'):format(fmt_val_with_type(expected_value), fmt_gotten(gotten_value)) - elseif ast.exp == 'COMPARE' and ast.binop == 'NEQ' then - local gotten_val, prefix = var_prefix(ast[1]) - local expected_value = ast[2].value + elseif ast.exp == 'OP' and ast.binop == 'NEQ' then + local prefix = fmt_prefix(ast[1], call_info) + local gotten_value, expected_value = ast[1].value, ast[2].value msg[1], msg[2] = prefix, ('expected anything other than %s, but got %s'):format(fmt_val_with_type(expected_value), fmt_val(gotten_val)) elseif ast.exp == 'LVALUE' then - local gotten_val, prefix = var_prefix(ast) + local prefix = fmt_prefix(ast[1], call_info) + local gotten_value = ast[1].value msg[1], msg[2] = prefix, ('truthy expected, but got %s'):format(fmt_val(gotten_val)) - elseif CONSTANT_VALUE_TOKEN[ast.exp] then + elseif CONSTANT_VALUE_TOKEN[ast.token] then local func_name = get_function_name(call_info) msg[1] = ('this assert will always fail, as it\'s body is `%s`. assumingly this should be an unreachable part of %s'):format(body_text, func_name) diff --git a/lua_lang.lua b/lua_lang.lua index 58100aa..40c077b 100644 --- a/lua_lang.lua +++ b/lua_lang.lua @@ -20,28 +20,28 @@ return Lexer { { 'then', 'THEN' }, { 'until', 'UNTIL' }, { 'while', 'WHILE' }, - { '%+' , 'PLUS' }, - { '%-' , 'MINUS' }, - { '%*' , 'TIMES' }, - { '%/' , 'DIVIDE' }, - { '%%' , 'MODULO' }, - { '%^' , 'CARET' }, - { '%#' , 'HASHTAG' }, - { '%==' , 'EQ' }, - { '%~=' , 'NEQ' }, - { '%<=' , 'LEQ' }, - { '%>=' , 'GEQ' }, - { '%<' , 'LE' }, - { '%>' , 'GT' }, + { '%+', 'PLUS' }, + { '%-', 'MINUS' }, + { '%*', 'TIMES' }, + { '%/', 'DIVIDE' }, + { '%%', 'MODULO' }, + { '%^', 'CARET' }, + { '%#', 'HASHTAG' }, + { '%==', 'EQ' }, + { '%~=', 'NEQ' }, + { '%<=', 'LEQ' }, + { '%>=', 'GEQ' }, + { '%<', 'LE' }, + { '%>', 'GT' }, { '%=', 'ASSIGN' }, - { '%(' , 'LPAR' }, - { '%)' , 'RPAR' }, - { '%{' , 'LBRACE' }, - { '%}' , 'RBRACE' }, - { '%;' , 'SEMICOLON' }, - { '%,' , 'COMMA' }, - { '%.%.' , 'CONCAT' }, - { '%.%.%.' , 'VARARG' }, + { '%(', 'LPAR' }, + { '%)', 'RPAR' }, + { '%{', 'LBRACE' }, + { '%}', 'RBRACE' }, + { '%;', 'SEMICOLON' }, + { '%,', 'COMMA' }, + { '%.%.', 'CONCAT' }, + { '%.%.%.', 'VARARG' }, -- { 'false', 'FALSE' }, { 'true', 'TRUE' }, @@ -62,3 +62,4 @@ return Lexer { { '%-%-%[%[.-%]%]', Lexer.CONTINUE }, { '%s+', Lexer.CONTINUE }, } +