248 lines
6.3 KiB
Lua
248 lines
6.3 KiB
Lua
|
|
||
|
--[[
Markov Chain module of the luaFortune library.

Documentation and License can be found here:
https://bitbucket.org/Jmaa/luafortune
--]]
|
||
|
|
||
|
-- Module table returned at the end of the file; the chain classes defined
-- below are attached to it as fields.
local markov = {}
markov.version = 0.5
|
||
|
|
||
|
--------------------------------------------------------------------------------
|
||
|
-- Constants
|
||
|
|
||
|
-- Default alphabet: the 26 lowercase letters "a".."z", in order, one
-- single-character string per entry.
local ENGLISH_ALPHABET = {}
for code = string.byte("a"), string.byte("z") do
    ENGLISH_ALPHABET[#ENGLISH_ALPHABET + 1] = string.char(code)
end
|
||
|
|
||
|
--------------------------------------------------------------------------------
|
||
|
-- Misc functions
|
||
|
|
||
|
--- Build the inverse lookup of an alphabet.
-- @param alphabet sequence of symbols (traversed with ipairs)
-- @return table mapping each symbol to its 1-based position in `alphabet`,
--         plus the empty string mapped to 0
local function reverse_alphabet (alphabet)
    local symbol_to_id = {}
    for id, symbol in ipairs(alphabet) do
        symbol_to_id[symbol] = id
    end
    -- Assigned last on purpose: "" always ends up as 0, even when the
    -- alphabet itself contains an empty string.
    symbol_to_id[""] = 0
    return symbol_to_id
end
|
||
|
|
||
|
--------------------------------------------------------------------------------
|
||
|
|
||
|
--- Create the plain data table shared by every chain type.
-- @param alphabet sequence of symbols; ENGLISH_ALPHABET is used when this is
--        nil or false
-- @param order value stored unchanged in the `order` field (may be nil)
-- @return table with the fields `alphabet`, `reverse_alphabet`,
--         `nr_symbols` and `order`; no metatable is attached here
local function new_chain (alphabet, order)
    local symbols = alphabet or ENGLISH_ALPHABET
    local chain = {}
    chain.alphabet = symbols
    chain.reverse_alphabet = reverse_alphabet(symbols)
    chain.nr_symbols = #symbols
    chain.order = order
    return chain
end
|
||
|
|
||
|
--------------------------------------------------------------------------------
|
||
|
-- first-order markov chains
|
||
|
|
||
|
-- Class table for first-order chains. Instances use it as their metatable,
-- so method lookups fall through to it via __index.
local MarkovChain = {}
markov.MarkovChain = MarkovChain
MarkovChain.__index = MarkovChain
|
||
|
|
||
|
--- Create a chain whose metatable is MarkovChain.
-- @param alphabet optional sequence of symbols, forwarded to new_chain
-- @return the new chain object
function MarkovChain.new (alphabet)
    local chain = new_chain(alphabet)
    return setmetatable(chain, MarkovChain)
end
|
||
|
|
||
|
--- Train on every entry of a list.
-- Each entry is handed to self:trainSingle, in ipairs order.
-- @param list_of_words sequence of words accepted by trainSingle
function MarkovChain:trainMany (list_of_words)
    for _, sample in ipairs(list_of_words) do
        self:trainSingle(sample)
    end
end
|
||
|
|
||
|
--- Sum the counters stored in one block of slots.
-- Reads self[loc + 1] .. self[loc + self.nr_symbols]; slots that hold no
-- value contribute nothing.
-- @param loc numeric offset of the block
-- @return the sum (0 when every slot is empty)
function MarkovChain:getTotalWeight (loc)
    local sum = 0
    for offset = 1, self.nr_symbols do
        local weight = self[loc + offset]
        if weight then
            sum = sum + weight
        end
    end
    return sum
end
|
||
|
|
||
|
--- Convert a sequence of symbol ids into the corresponding symbols.
-- @param id_word sequence of indices into self.alphabet
-- @return new table with result[i] = self.alphabet[id_word[i]]
function MarkovChain:translateToFinal (id_word)
    local symbols = {}
    local alphabet = self.alphabet
    for pos = 1, #id_word do
        symbols[pos] = alphabet[id_word[pos]]
    end
    return symbols
end
|
||
|
|
||
|
--- Convert a sequence of symbols into the corresponding symbol ids.
-- @param str_word value that is indexed by position (str_word[i]) and
--        measured with #
-- @return new table with result[i] = self.reverse_alphabet[str_word[i]]
function MarkovChain:translateToRaw (str_word)
    local ids = {}
    local to_id = self.reverse_alphabet
    for pos = 1, #str_word do
        ids[pos] = to_id[str_word[pos]]
    end
    return ids
end
|
||
|
|
||
|
--- Record every (previous symbol, symbol) pair of one word.
-- For position i the counter at
--   reverse_alphabet[word[i-1] or ""] * nr_symbols + reverse_alphabet[word[i]]
-- is incremented; a missing predecessor is looked up as "".
-- @param word table of symbols, indexed from 1
function MarkovChain:trainSingle (word)
    local to_id, base = self.reverse_alphabet, self.nr_symbols
    for pos = 1, #word do
        local prev_id = to_id[word[pos - 1] or ""]
        local this_id = to_id[word[pos]]
        local slot = prev_id * base + this_id
        self[slot] = (self[slot] or 0) + 1
    end
end
|
||
|
|
||
|
--- Pick the id of the next symbol, weighted by the trained counters.
-- The block of counters is selected by the id at word[index]; when index
-- is 0 the block at offset 0 is used.
-- @param word sequence of symbol ids
-- @param index position in `word` to continue from; defaults to #word
-- @return a symbol id in 1..nr_symbols, or nil when the selected block has
--         no recorded weight (nothing was trained for that context)
function MarkovChain:getNextRaw (word, index)
    local pos = index or #word
    local loc = 0
    if pos ~= 0 then
        loc = word[pos] * self.nr_symbols
    end

    -- math.random(0) is an error before Lua 5.4 and yields an arbitrary
    -- integer in 5.4, so an untrained context is reported as nil instead,
    -- matching NoMarkovChain:getNextRaw.
    local total_weight = self:getTotalWeight(loc)
    if total_weight == 0 then
        return nil
    end

    -- Walk the block, subtracting each counter from the random draw until
    -- it is used up; the slot where that happens is the chosen symbol.
    local remaining = math.random(total_weight)
    for candidate = 1, self.nr_symbols do
        remaining = remaining - (self[loc + candidate] or 0)
        if remaining <= 0 then
            return candidate
        end
    end
end
|
||
|
|
||
|
--- Pick the next symbol for a word given as symbols rather than ids.
-- The word is converted through self.reverse_alphabet, passed to
-- self:getNextRaw, and the resulting id is looked up in self.alphabet.
-- @param word table of symbols, indexed from 1
-- @param index optional position, forwarded unchanged to getNextRaw
-- @return the chosen symbol (self.alphabet entry for the id from getNextRaw)
function MarkovChain:getNext (word, index)
    local to_id = self.reverse_alphabet
    local ids = {}
    for pos = 1, #word do
        ids[pos] = to_id[word[pos]]
    end
    local next_id = self:getNextRaw(ids, index)
    return self.alphabet[next_id]
end
|
||
|
|
||
|
--------------------------------------------------------------------------------
|
||
|
-- nth-order markov chains
|
||
|
|
||
|
-- Class table for nth-order chains. The helpers that do not depend on the
-- order are shared with MarkovChain.
local NoMarkovChain = {
    trainMany        = MarkovChain.trainMany,
    getTotalWeight   = MarkovChain.getTotalWeight,
    translateToFinal = MarkovChain.translateToFinal,
    translateToRaw   = MarkovChain.translateToRaw,
    -- getNext only relies on reverse_alphabet, alphabet and self:getNextRaw,
    -- so the first-order implementation can be reused. Without this entry
    -- NoMarkovChain.getNext would be nil.
    getNext          = MarkovChain.getNext,
}
NoMarkovChain.__index = NoMarkovChain
markov.NoMarkovChain = NoMarkovChain
|
||
|
|
||
|
--- Create a chain whose metatable is NoMarkovChain.
-- Defined on NoMarkovChain: declaring it as MarkovChain.new would replace
-- the first-order constructor and leave NoMarkovChain without one.
-- @param alphabet optional sequence of symbols, forwarded to new_chain
-- @param order value stored in the chain's `order` field; trainSingle and
--        getLoc do arithmetic on it
-- @return the new chain object
function NoMarkovChain.new (alphabet, order)
    return setmetatable(new_chain(alphabet, order), NoMarkovChain)
end
|
||
|
|
||
|
--- Record the symbol runs of one word.
-- The outer loop runs from order + 1 to order + #word. For outer value i the
-- inner loop covers word positions i - order .. min(#word, i + order). An id
-- is built up as id * nr_symbols + symbol id, and the counter of every
-- intermediate id (starting with 0) and of the final id is incremented.
-- @param word table of symbols, indexed from 1
function NoMarkovChain:trainSingle (word)
    local length = #word
    local order = self.order
    local base = self.nr_symbols
    local to_id = self.reverse_alphabet
    for centre = order + 1, order + length do
        local first = centre - order
        local last = math.min(length, centre + order)
        local id = 0
        for pos = first, last do
            -- Count the prefix read so far, then extend it by one symbol.
            self[id] = (self[id] or 0) + 1
            id = id * base + to_id[word[pos]]
        end
        -- Count the complete run as well.
        self[id] = (self[id] or 0) + 1
    end
end
|
||
|
|
||
|
--- Fold the ids that end at position `this_i` into a single number.
-- Covers positions max(1, this_i - order + 1) .. this_i, accumulating
-- loc = loc * nr_symbols + word[i].
-- @param word sequence of symbol ids
-- @param this_i last position to include
-- @return the folded number (0 when the range is empty)
function NoMarkovChain:getLoc (word, this_i)
    local first = math.max(1, this_i - self.order + 1)
    local base = self.nr_symbols
    local loc = 0
    for pos = first, this_i do
        loc = loc * base + word[pos]
    end
    return loc
end
|
||
|
|
||
|
--- Pick the id of the next symbol, weighted by the trained counters.
-- The block of counters starts at self:getLoc(word, index) * nr_symbols.
-- @param word sequence of symbol ids
-- @param index position to continue from; defaults to #word
-- @return a symbol id in 1..nr_symbols, or nil when the block's total
--         weight is 0
function NoMarkovChain:getNextRaw (word, index)
    local pos = index or #word
    local offset = self:getLoc(word, pos) * self.nr_symbols
    local total = self:getTotalWeight(offset)
    if total == 0 then
        return nil
    end
    -- Subtract counters from the random draw until it is used up.
    local remaining = math.random(total)
    for candidate = 1, self.nr_symbols do
        local weight = self[offset + candidate] or 0
        remaining = remaining - weight
        if remaining <= 0 then
            return candidate
        end
    end
end
|
||
|
|
||
|
--------------------------------------------------------------------------------
|
||
|
-- Variable-order markov chains
|
||
|
|
||
|
-- Class table for variable-order chains. Training and the translation
-- helpers are shared with NoMarkovChain. getLoc is intentionally not copied
-- because VoMarkovChain defines its own three-argument getLoc further down.
local VoMarkovChain = {
    trainSingle      = NoMarkovChain.trainSingle,
    trainMany        = NoMarkovChain.trainMany,
    getTotalWeight   = NoMarkovChain.getTotalWeight,
    translateToFinal = NoMarkovChain.translateToFinal,
    translateToRaw   = NoMarkovChain.translateToRaw,
    -- Taken from MarkovChain directly: it is the only class that defines
    -- getNext, and it dispatches through self:getNextRaw.
    getNext          = MarkovChain.getNext,
}
VoMarkovChain.__index = VoMarkovChain
markov.VoMarkovChain = VoMarkovChain
|
||
|
|
||
|
--- Create a chain whose metatable is VoMarkovChain.
-- @param alphabet optional sequence of symbols, forwarded to new_chain
-- @param order value stored in the chain's `order` field
-- @return the new chain object
function VoMarkovChain.new (alphabet, order)
    local chain = new_chain(alphabet, order)
    return setmetatable(chain, VoMarkovChain)
end
|
||
|
|
||
|
--- Fold the ids that end at position `this_i` into a single number.
-- Covers positions max(1, min_i + this_i - order + 1) .. this_i, so a larger
-- `min_i` moves the start of the range forward and shortens it.
-- @param word sequence of symbol ids
-- @param this_i last position to include
-- @param min_i amount added to the start of the range
-- @return the folded number (0 when the range is empty)
function VoMarkovChain:getLoc (word, this_i, min_i)
    local first = math.max(1, min_i + this_i - self.order + 1)
    local base = self.nr_symbols
    local loc = 0
    for pos = first, this_i do
        loc = loc * base + word[pos]
    end
    return loc
end
|
||
|
|
||
|
--- Pick the id of the next symbol, trying several context lengths.
-- For i = 0 .. index the context number is taken from
-- self:getLoc(word, index, i). If a counter is stored under that number, a
-- random draw in 1..counter is made and the counters in the block at
-- context * nr_symbols are subtracted from it. The first slot that uses the
-- draw up is returned. If no slot does, or no counter exists, the next i is
-- tried.
-- @param word sequence of symbol ids
-- @param index position to continue from; defaults to #word
-- @return a symbol id in 1..nr_symbols, or nothing when no attempt succeeds
function VoMarkovChain:getNextRaw (word, index)
    local last = index or #word
    for shift = 0, last do
        local context = self:getLoc(word, last, shift)
        local total = self[context]
        if total then
            local remaining = math.random(total)
            local offset = context * self.nr_symbols
            for candidate = 1, self.nr_symbols do
                local weight = self[offset + candidate] or 0
                remaining = remaining - weight
                if remaining <= 0 then
                    return candidate
                end
            end
        end
    end
end
|
||
|
|
||
|
--------------------------------------------------------------------------------
|
||
|
-- String-based markov wrapper
|
||
|
|
||
|
-- Class table for the string wrapper. Instances use it as their metatable,
-- so method lookups fall through to it via __index.
local StringWrapper = {}
markov.StringWrapper = StringWrapper
StringWrapper.__index = StringWrapper
|
||
|
|
||
|
--- Wrap a chain so it can be trained on, and generate, plain strings.
-- @param chain chain object; train calls chain:trainSingle, generate calls
--        chain:getNextRaw and chain:translateToFinal
-- @return the new wrapper, holding the chain in its `chain` field
function StringWrapper.new (chain)
    local wrapper = { chain = chain }
    return setmetatable(wrapper, StringWrapper)
end
|
||
|
|
||
|
--- Train the wrapped chain on a collection of strings.
-- Each string is split into characters, and the resulting table is passed
-- to self.chain:trainSingle.
-- @param list_of_words either a sequence of strings, or (when `is_iterator`
--        is truthy) an iterator function usable directly in a generic `for`
-- @param is_iterator truthy when `list_of_words` is an iterator function
function StringWrapper:train (list_of_words, is_iterator)
    -- One lead byte (\01-\127 or \192-\255) followed by any number of
    -- continuation bytes (\128-\191), i.e. one UTF-8 encoded character.
    local char_match = "[\01-\127\192-\255][\128-\191]*"

    local function train_word (word)
        local word_table = {}
        for char in word:gmatch(char_match) do
            word_table[#word_table + 1] = char
        end
        self.chain:trainSingle(word_table)
    end

    -- The two cases need separate loops. Writing
    -- `cond and iter or ipairs(list)` truncates ipairs' three results to
    -- the first one, which leaves the generic `for` without its state.
    if is_iterator then
        for first, second in list_of_words do
            -- Iterators that yield a single value put the word in the
            -- first loop variable; two-value iterators put it in the second.
            train_word(second or first)
        end
    else
        for _, word in ipairs(list_of_words) do
            train_word(word)
        end
    end
end
|
||
|
|
||
|
--- Generate a string of at most `word_len` symbols.
-- Ids are requested one at a time from self.chain:getNextRaw, passing the
-- ids collected so far. Generation stops early when it returns nil.
-- @param word_len maximum number of symbols to generate
-- @return the symbols from self.chain:translateToFinal, concatenated
function StringWrapper:generate (word_len)
    local ids = {}
    for pos = 1, word_len do
        local next_id = self.chain:getNextRaw(ids)
        if next_id == nil then
            break
        end
        ids[pos] = next_id
    end
    return table.concat(self.chain:translateToFinal(ids))
end
|
||
|
|
||
|
--------------------------------------------------------------------------------
|
||
|
|
||
|
-- Export the module table (version plus the classes attached to it above).
return markov
|