--- # Spritesheet
--
-- Library for managing sprite sheets of textures and animations.
--
-- Supports both individual images (sprites) and animations within a
-- spritesheet. Both are described in a Lua data file placed beside the
-- spritesheet image (`<name>.png` together with `<name>.lua` or `<name>.raw`).
--
-- ## Notes
--
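-- - When a sprite or animation is drawn while a shader is active, the library
--   may send the shader useful uniforms, notably `spritesheet_inverse_width`
--   and `spritesheet_inverse_height` (the inverse texture dimensions), but
--   only when the shader declares them.
-- - If LÖVE is not available (the global `love` is not a table), the library
--   runs in data-loading mode only: spritesheet data can still be loaded, but
--   the drawing methods are not defined.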

local _VERSION = '0.1.3'

-- Keep a handle on the builtin `error`; the rest of the file shadows it.
local error_original = error

--- Fallback for the optional 'errors' library.
--
-- Every call site in this file uses the `error(format, ...)` shape. The builtin
-- `error` would treat the first format argument as its `level`, so this
-- formats the message with `string.format` first. Non-number arguments go
-- through `tostring`, so `%s` also works for nil/tables on Lua 5.1. If
-- formatting itself fails, the raw format string and arguments are reported
-- instead.
--
-- @tparam string fmt  Message or format string.
-- @param ...          Format arguments.
local function fallback_error (fmt, ...)
    local message = tostring(fmt)
    local nr_args = select('#', ...)
    if nr_args > 0 then
        local args = {...}
        for i = 1, nr_args do
            if type(args[i]) ~= 'number' then  args[i] = tostring(args[i])  end
        end
        local unpack_args = table.unpack or unpack
        local ok, formatted = pcall(string.format, message, unpack_args(args, 1, nr_args))
        message = ok and formatted or (message..' '..table.concat(args, ', ', 1, nr_args))
    end
    -- Level 2: attribute the error to the caller of `error(...)`, not to this wrapper.
    error_original(message, 2)
end

local error, error_internal do
    error, error_internal  =  fallback_error, fallback_error
    local success, errorlib = pcall(require, 'errors')
    if success then
        error = errorlib 'spritesheet'
        error_internal  =  error.internal
    end
end

--------------------------------------------------------------------------------

-- Drawing needs the LÖVE framework (the global `love` table). Without it, only
-- spritesheet data loading is available.
local define_love = type(love) == 'table'
if not define_love then
    io.stderr:write '[Spritesheet]: Loaded in non-LÖVE environment. Can still load spritesheet data,\n               but image drawing methods will not be defined.\n'
end

--------------------------------------------------------------------------------
-- Util

--- Compute how long an animation runs, optionally only up to a given frame.
--
-- The animation table `self` holds its frames at indices 1..#self and its
-- timing in one of three forms:
--   * `time_total` (number): duration of the whole animation. The per-frame
--     time `self.time = time_total / #self` is derived and stored on first use.
--   * `time` (number): duration of every single frame.
--   * `time` (table): per-frame durations, `time[i]` for frame i.
--
-- @tparam table self            Animation data.
-- @tparam[opt] number frame_i   Count only the first `frame_i` frames; defaults
--                               to all frames.
-- @treturn number               Summed duration of the counted frames.
local function calculate_animation_duration (self, frame_i)
    frame_i = frame_i or math.huge
    assert(type(self) == 'table')
    assert(type(frame_i) == 'number')

    local nr_frames = math.min(#self, frame_i)

    if self.time_total then
        -- Derive the per-frame time. On later calls (e.g. for `contact_frame`)
        -- it is already set by us, which must not be mistaken for a conflicting
        -- user-supplied `time`.
        local per_frame = self.time_total / #self
        assert(self.time == nil or self.time == per_frame)
        self.time = per_frame
        -- For the full animation, return `time_total` itself (no rounding).
        if nr_frames >= #self then  return self.time_total  end
        return nr_frames * per_frame
    end

    -- Easy if number
    if type(self.time) == 'number' then
        return nr_frames * self.time
    end

    -- Sum if table
    local sum = 0
    for i = 1, nr_frames do  sum = sum + self.time[i]  end
    return sum
end

--- Find the frame of an animation that is shown at time `t`.
--
-- `l.time` holds frame end times (as produced by `calculate_frame_times`,
-- `l.time[i]` is when frame i ends). The first frame whose end time is at or
-- after `t` is chosen. If no entry qualifies, the last frame is returned.
--
-- TODO: Reimplement as binary search. (Maybe only use the binary search version for very long animations?)
local function get_quad_based_on_time (l, t)
    assert(type(l) == 'table')
    assert(type(t) == 'number')

    local ends = l.time
    local frame, last = 1, #ends
    while frame <= last do
        if ends[frame] >= t then  return l[frame]  end
        frame = frame + 1
    end
    return l[#l]
end

--------------------------------------------------------------------------------
---- Sprite

--- A single static image cut from a spritesheet.
--
-- Fields: `quad` (a LÖVE Quad, or a plain table outside LÖVE), `imagesheet`
-- (the owning SpriteSheet), and the marker `is_sprite = true`.
local Sprite = {}
Sprite.__index = Sprite

--- Create a Sprite showing `quad` from `imagesheet`.
function Sprite.new (quad, imagesheet)
    local sprite = { is_sprite = true }
    sprite.quad = quad
    sprite.imagesheet = imagesheet
    return setmetatable(sprite, Sprite)
end

--- Send the texture's inverse dimensions to the active shader, if any.
--
-- `spritesheet_inverse_width` and `spritesheet_inverse_height` are each sent
-- only if the active shader declares that uniform. Sending an undeclared
-- uniform raises an error in LÖVE, so a shader may use either one on its own.
--
-- @param texture  Object with a `getDimensions()` method (e.g. a LÖVE Image).
local function set_shader_texture_size(texture)
    local shader = love.graphics.getShader()
    if shader == nil then  return  end

    local wants_width  = shader:hasUniform('spritesheet_inverse_width')
    local wants_height = shader:hasUniform('spritesheet_inverse_height')
    if not (wants_width or wants_height) then  return  end

    local width, height = texture:getDimensions()
    if wants_width  then  shader:send('spritesheet_inverse_width',  {1/width,        0})  end
    if wants_height then  shader:send('spritesheet_inverse_height', {      0, 1/height})  end
end

if define_love then

    --- Build and cache (in `self.func`) the draw closure for this sprite.
    -- The closure takes `(x, y)`, rounds them down to whole pixels, and draws
    -- the quad offset by the sheet's current origin.
    function Sprite:generateImage ()
        local sheet = self.imagesheet
        local quad = self.quad
        local function draw_at (x, y)
            local image = sheet.image
            set_shader_texture_size(image)
            love.graphics.draw(image, quad, math.floor(x), math.floor(y), 0, 1, 1, sheet.origin_x, sheet.origin_y)
        end
        self.func = draw_at
    end

    --- Return the draw closure, generating it on first use.
    function Sprite:getImage ()
        local draw_fn = self.func
        if draw_fn then  return draw_fn  end
        self:generateImage()
        return self.func
    end

    --- Return the quad this sprite shows.
    function Sprite:getQuad ()
        local quad = self.quad
        return quad
    end

    --- Draw the sprite; arguments are forwarded to the draw closure.
    function Sprite:draw (...)
        local draw_fn = self:getImage()
        return draw_fn(...)
    end

    --- Return the width and height of the whole sheet image.
    function Sprite:getSheetDimensions ()
        local image = self.imagesheet.image
        return image:getDimensions()
    end

end

setmetatable(Sprite, {
    -- `Sprite(quad, imagesheet)` is shorthand for `Sprite.new(quad, imagesheet)`.
    __call = function (_, ...)  return Sprite.new(...)  end,
})

--------------------------------------------------------------------------------
---- Animation

--- An animation: frames at indices 1..n plus timing information.
local Animation = {}
Animation.__index = Animation

--- Turn a plain animation table into an Animation (in place) and return it.
--
-- Expected fields of `self`:
--   * frames at 1..#self (at least one),
--   * timing: `time` as a positive number or a table with one entry per
--     frame, or `time_total` as a number (never together with `time`),
--   * optional `wrap` (boolean): loop the animation when drawing,
--   * optional `contact_frame`: its end time is stored as `contact_time`.
-- Sets `duration`, `is_animation` and, with `contact_frame`, `contact_time`.
function Animation.new (self)
    assert(type(self) == 'table' and self ~= Animation)
    local time_type = type(self.time)
    local has_total = type(self.time_total) == 'number'
    assert(time_type == 'table' or (time_type == 'number' and self.time > 0) or has_total)
    assert(#self > 0)
    assert(time_type == 'number' or has_total or #self == #self.time)
    if self.time_total then  assert(self.time == nil)  end
    assert(self.wrap == nil or type(self.wrap) == 'boolean')

    setmetatable(self, Animation)
    self.duration = calculate_animation_duration(self)
    self.is_animation = true

    local contact = self.contact_frame
    if contact then
        self.contact_time = calculate_animation_duration(self, contact)
    end
    return self
end

if define_love then

    --- Build and cache (in `self.func`) the draw closure for this animation.
    -- The closure takes `(x, y, t)`. `t` (default 0) is the time into the
    -- animation; when `wrap` is set it is taken modulo the duration.
    function Animation:generateImage ()
        local anim = self
        self.func = function (x, y, t)
            if not t then  t = 0  end
            assert(type(t) == 'number')

            if anim.wrap then  t = t % anim.duration  end
            local quad = get_quad_based_on_time(anim, t)
            if not quad then
                error_internal('Could not determine quad when drawing animation. Time was %f.', t)
            end

            local sheet = anim.imagesheet
            set_shader_texture_size(sheet.image)
            love.graphics.draw(sheet.image, quad, x, y, 0, 1, 1, sheet.origin_x, sheet.origin_y)
        end
    end

    --- Return the draw closure, generating it on first use.
    function Animation:getImage ()
        local draw_fn = self.func
        if draw_fn then  return draw_fn  end
        self:generateImage()
        return self.func
    end

    --- Return the quad of frame `i` (default: the first frame).
    function Animation:getQuad (i)
        assert(i == nil or (type(i) == 'number' and self[i]))
        local index = i or 1
        return self[index]
    end

    --- Return the total duration of the animation.
    function Animation:getDuration ()
        local duration = self.duration
        return duration
    end

    --- Draw the animation; arguments are forwarded to the draw closure.
    function Animation:draw (...)
        local draw_fn = self:getImage()
        return draw_fn(...)
    end

    --- Return the width and height of the whole sheet image.
    function Animation:getSheetDimensions ()
        local image = self.imagesheet.image
        return image:getDimensions()
    end

end

setmetatable(Animation, {
    -- `Animation{...}` is shorthand for `Animation.new{...}`.
    __call = function (_, ...)  return Animation.new(...)  end,
})

--------------------------------------------------------------------------------
-- Constants

-- Environment in which spritesheet data files run: the only name they can see
-- is `Anim` (the Animation constructor).
local SPRITESHEET_ENV = {}
SPRITESHEET_ENV.Anim = Animation

-- Data file extensions that are tried, stored as a set (extension -> true).
local SPRITESHEET_DATA_FILETYPES = {}
for _, extension in ipairs { '.lua', '.raw' } do
    SPRITESHEET_DATA_FILETYPES[extension] = true
end

--------------------------------------------------------------------------------

--- Convert frame durations into cumulative frame end times.
--
-- @param time_list          Either one number (every frame lasts that long) or
--                           a table of per-frame durations.
-- @tparam number nr_frames  Number of frames.
-- @treturn table            `t[i]` is the time at which frame i ends and
--                           `t[0]` is `-math.huge`. Outside LÖVE the original
--                           `time_list` is also kept in `t.orig`.
local function calculate_frame_times (time_list, nr_frames)
    assert(type(time_list) == 'table' or type(time_list) == 'number')
    assert(type(nr_frames) == 'number')

    local ends = { [0] = -math.huge }
    if not define_love then  ends.orig = time_list  end

    if type(time_list) == 'number' then
        -- Multiply instead of accumulating: each end time is computed directly.
        for frame = 1, nr_frames do  ends[frame] = frame * time_list  end
        return ends
    end

    local elapsed = time_list[1]
    ends[1] = elapsed
    for frame = 2, nr_frames do
        elapsed = elapsed + time_list[frame]
        ends[frame] = elapsed
    end
    return ends
end

--- Replace the tile ids in a spritesheet's `tile_names` with drawable objects.
--
-- Tile ids count tiles row by row, starting at 0 for the tile at (0, 0).
-- `tile_names` may be:
--   * a single number: a single Sprite is returned,
--   * an Animation: its frame ids are replaced by quads in place,
--   * a (nested) table: every number becomes a Sprite and every Animation is
--     converted as above, in place.
-- A table referenced from several places is converted only once.
--
-- @param _                  Unused (image width passed by the caller).
-- @param _                  Unused (image height passed by the caller).
-- @tparam table quad_data   Data returned by the spritesheet data file.
-- @tparam table imagesheet  The SpriteSheet under construction.
-- @return The converted `tile_names` (or a single Sprite).
local function load_quads (_, _, quad_data, imagesheet)
    assert(type(imagesheet) == 'table')
    assert(type(imagesheet.tiles_per_row)    == 'number')
    assert(type(imagesheet.tiles_per_column) == 'number')

    local  tile_width,  tile_height  = imagesheet.tile_width, imagesheet.tile_height
    local tiles_per_row  =  imagesheet.tiles_per_row
    local max_quad_id    =  tiles_per_row * imagesheet.tiles_per_column - 1
    -- One quad per tile id, shared by every sprite/animation using that id.
    local quad_cache = {}

    local function quad_from_id (id)
        -- Error checking
        if type(id) ~= 'number' then  error('All quad ids must be natural numbers, but one was %s (%s)', id, type(id))  end
        if id % 1 ~= 0          then  error('All quad ids must be natural numbers, but one was %03.03f (floating point number)', id)  end
        if not (0 <= id and id <= max_quad_id) then  error('All quad ids must - for this spritesheet ("%s") - be natural numbers equal/below %i, but one was %s = 0x%X', imagesheet.filename, max_quad_id, id, id)  end

        -- Calculate
        local quad = quad_cache[id]
        if quad == nil then
            if define_love then
                quad = love.graphics.newQuad((id%tiles_per_row)*tile_width, math.floor(id/tiles_per_row)*tile_height, tile_width, tile_height, imagesheet.width, imagesheet.height)
            else
                quad = { id = id, (id%tiles_per_row)*tile_width, math.floor(id/tiles_per_row)*tile_height, tile_width, tile_height }
            end
            quad_cache[id] = quad
        end
        return quad
    end

    -- Convert an animation's frame ids to quads and its frame durations to end times.
    local function visit_animation (t, already_seen)
        assert(type(t) == 'table' and t.is_animation)
        -- A shared animation was already converted; converting again would
        -- feed quads back into `quad_from_id`.
        if already_seen[t] then  return  end
        already_seen[t] = true

        for i = 1, #t do  t[i] = quad_from_id(t[i])  end
        t.time = calculate_frame_times(t.time, #t)
        t.imagesheet = imagesheet
    end

    -- Create a Sprite for one tile id.
    local function visit_quad (n)
        assert(type(n) == 'number')
        return Sprite(quad_from_id(n), imagesheet)
    end

    -- Recursively convert a table of names: numbers -> Sprites, animations -> quads.
    local function visit_node (t, already_seen)
        assert(type(t) == 'table')
        -- A shared subtable was already converted; a second pass would walk
        -- into the Sprites it now contains (their quads and the sheet itself).
        if already_seen[t] then  return  end
        already_seen[t] = true

        for key, val in pairs(t) do
            local val_type = type(val)
            if val_type == 'number' then
                t[key] = visit_quad(val)
            elseif val_type == 'table' then
                local  visit_func = rawget(val,'is_animation') and visit_animation or visit_node
                       visit_func(val, already_seen)
            end
        end

        if type(error) == 'table' and define_love then
            error.strict_table(t)
        end
    end

    assert(type(quad_data) == 'table')
    local tile_names = quad_data.tile_names

    -- A bare tile id describes a spritesheet holding a single sprite.
    if type(tile_names) == 'number' then
        return visit_quad(tile_names)
    end
    assert(type(tile_names) == 'table')

    local  visit_func = tile_names.is_animation and visit_animation or visit_node
           visit_func(tile_names, {})
    return tile_names
end

-- Root keys that must be positive integers when present.
local INTEGER_TILESET_KEYS = {'tile_width', 'tile_height', 'tiles_per_row', 'tiles_per_column'}

--- Validate the table returned by a spritesheet data file.
--
-- Collects every problem found and raises one error listing all of them, so a
-- data file can be fixed in a single pass.
--
-- @tparam string filename  Spritesheet path without extension (for messages).
-- @param data              Value returned by the data file chunk.
local function validate_quad_data (filename, data)
    if type(data) ~= 'table' then
        error('Bad spritesheet "%s". Must return a table, but returned %s (%s)', filename, data, type(data))
    end

    local l = {'Bad spritesheet "'.. filename.. '"'}
    if data.tiles_per_row ~= nil and data.tile_width ~= nil then
        l[#l+1] = 'Root table must not contain both keys "tiles_per_row" and "tile_width"'
    elseif data.tiles_per_row == nil and data.tile_width == nil then
        l[#l+1] = 'Root table must contain either keys "tiles_per_row" or "tile_width"'
    end
    if data.tiles_per_column ~= nil and data.tile_height ~= nil then
        l[#l+1] = 'Root table must not contain both keys "tiles_per_column" and "tile_height"'
    elseif data.tiles_per_column == nil and data.tile_height == nil then
        l[#l+1] = 'Root table must contain either keys "tiles_per_column" or "tile_height"'
    end

    -- `tostring`: on Lua 5.1 `%s` rejects nil/boolean/table arguments, which
    -- would turn the validation report itself into an unrelated error.
    -- Zero/negative sizes are rejected: they lead to divisions by zero later.
    for _, integer_key in ipairs(INTEGER_TILESET_KEYS) do
        local v = data[integer_key]
        if v ~= nil and (type(v) ~= 'number' or v % 1 ~= 0 or v <= 0) then
            l[#l+1] = string.format('Key "%s" in root table must map to a positive integer value, but it was %s (%s)', integer_key, tostring(v), type(v))
        end
    end
    if not (type(data.tile_names) == 'table' or type(data.tile_names) == 'number') then
        l[#l+1] = string.format('Root table must contain key "tile_names", with either a table or a number value, but it was %s (%s)', tostring(data.tile_names), type(data.tile_names))
    end
    if data.tile_origin and type(data.tile_origin) ~= 'table' then
        l[#l+1] = string.format('Key "%s" in root table must map to a table value, but it was %s (%s)', 'tile_origin', tostring(data.tile_origin), type(data.tile_origin))
    end

    if #l > 1 then
        error(table.concat(l, '\n    '))
    end
end

--- Load and validate the data file describing a spritesheet.
--
-- Tries `<filename>.lua`, then `<filename>.raw` (sorted extension order), and
-- runs the first one that loads inside `SPRITESHEET_ENV`.
--
-- @tparam string filename  Spritesheet path without extension.
-- @treturn table           The validated data table.
local function load_quad_data (filename)
    if type(filename) ~= 'string' then  error('Bad argument #1, expected string, got %s (%s)', filename, type(filename))  end

    -- `pairs` order is unspecified; sort so the lookup order is deterministic.
    local filetypes = {}
    for filetype in pairs(SPRITESHEET_DATA_FILETYPES) do  filetypes[#filetypes+1] = filetype  end
    table.sort(filetypes)

    -- Keep the reason each candidate failed to load, for the final error.
    local load_errors = {}
    for _, filetype in ipairs(filetypes) do
        local chunk, error_msg
        if define_love then
            chunk, error_msg  =  love.filesystem.load(filename..filetype)
        else
            chunk, error_msg  =  loadfile(filename..filetype)
        end

        if chunk then
            local data  =  setfenv(chunk, SPRITESHEET_ENV)()
            validate_quad_data(filename, data)
            return data
        end
        load_errors[#load_errors+1] = tostring(error_msg)
    end

    -- Else, give up.
    error('Could not find or load file "%s.lua" or "%s.raw":\n    %s', filename, filename, table.concat(load_errors, '\n    '))
end

--------------------------------------------------------------------------------

--- A spritesheet: an image plus the sprites/animations cut from it.
--
-- Named entries from the data file's `tile_names` become fields of the sheet
-- itself. If `tile_names` describes a single sprite or animation, it is stored
-- as `only_quads` instead.
local SpriteSheet = {}
      SpriteSheet.__index = SpriteSheet
      SpriteSheet.is_spritesheet = true
      SpriteSheet._VERSION = _VERSION

--- Load a spritesheet.
--
-- Reads the data file `<filename>.lua` (or `.raw`) and the image
-- `<filename>.png`. Tile size is given either directly (`tile_width`,
-- `tile_height`) or as tile counts (`tiles_per_row`, `tiles_per_column`).
--
-- @tparam string filename  Path without extension.
-- @treturn SpriteSheet
function SpriteSheet.new (filename)
    local quad_data = load_quad_data(filename)

    -- NOTE: `force_uneven_tile_size` in quad_data can be used to
    -- ignore the image size-tile size divisibility check. Edit the
    -- spriteimage itself if you can, as it will silently ignore
    -- several errors.

    local img, width, height
    if define_love then
        img  =  love.graphics.newImage(filename..'.png')
        width, height  =  img:getDimensions()
    else
        img     =  require 'imlib2'.image.load(filename..'.png')
        width   =  img:get_width()
        height  =  img:get_height()
    end

    -- Set info
    local  self = setmetatable({}, SpriteSheet)
           self.filename     = filename
           self.image        = img
           self.width        = width
           self.height       = height
           self.origin_x     = 0
           self.origin_y     = 0

    -- TODO: Give warning/error due to rounding down.
    self.tiles_per_row    = quad_data.tiles_per_row     or math.floor(width  / quad_data.tile_width)
    self.tiles_per_column = quad_data.tiles_per_column  or math.floor(height / quad_data.tile_height)

    self.tile_width    =  quad_data.tile_width  or math.floor(width  / self.tiles_per_row)
    self.tile_height   =  quad_data.tile_height or math.floor(height / self.tiles_per_column)

    -- Error checking. Uses the `width`/`height` locals: outside LÖVE the image
    -- is an imlib2 image, which has no `getWidth`/`getHeight` methods.
    do
        local rem_width  = width  % self.tile_width
        local rem_height = height % self.tile_height
        if not quad_data.force_uneven_tile_size and (rem_width ~= 0 or rem_height ~= 0) then
            local s = ('Bad spritesheet "%s". Image size (%i, %i) must be dividable by tile size (%i, %i)')
                       :format(filename, width, height, self.tile_width, self.tile_height)
            if rem_width  ~= 0 then  s = s..('\n    Width  leaves a remainder of %i.'):format(rem_width)   end
            if rem_height ~= 0 then  s = s..('\n    Height leaves a remainder of %i.'):format(rem_height)  end
            error(s)
        end
    end

    -- Set origin
    if quad_data.tile_origin then
        self:setOrigin(unpack(quad_data.tile_origin))
    end

    -- Import quads into SpriteSheet
    self.quads = load_quads(width, height, quad_data, self)
    if rawget(self.quads, 'is_sprite') or rawget(self.quads, 'is_animation') then
        self.only_quads = self.quads
    else
        for key, value in pairs(self.quads) do
            -- Named entries become fields of the sheet. Refuse names that would
            -- shadow an existing field or method (e.g. "draw" or "filename").
            if self[key] then
                error('Bad spritesheet "%s". Tile name "%s" collides with an existing SpriteSheet field or method.', filename, tostring(key))
            end
            self[key] = value
        end
    end

    -- Return
    return self
end

--- Set the drawing origin used for every tile of this sheet.
--
-- @tparam number ox         Horizontal origin.
-- @tparam number oy         Vertical origin.
-- @tparam[opt] string mode  'absolute' (pixels) or 'relative' (fraction of the
--                           tile size, rounded down); nil means 'relative'.
-- The raw arguments are kept in `origin_orig`. The resulting origin must be a
-- whole number of pixels (asserted).
function SpriteSheet:setOrigin (ox, oy, mode)
    assert(type(ox) == 'number')
    assert(type(oy) == 'number')

    self.origin_orig = { ox, oy, mode }

    local relative = (mode == nil) or (mode == 'relative')
    if mode == 'absolute' then
        self.origin_x, self.origin_y = ox, oy
    elseif relative then
        self.origin_x = math.floor(self.tile_width  * ox)
        self.origin_y = math.floor(self.tile_height * oy)
    else
        error('Unknown origin mode %s (%s)', mode, type(mode))
    end

    assert(self.origin_x % 1 == 0)
    assert(self.origin_y % 1 == 0)
end

--- Return a list of the keys of the loaded `quads` table (order unspecified).
function SpriteSheet:getQuadKeys ()
    local result, n = {}, 0
    for name in pairs(self.quads) do
        n = n + 1
        result[n] = name
    end
    return result
end

--- Check a method's receiver.
--
-- Raises an error unless `self_v` is a table and, when `is_field` is given,
-- `self_v[is_field]` is truthy.
--
-- @param self_v                The receiver to check.
-- @tparam[opt] string is_field Marker field the receiver must have.
-- @treturn boolean             Always `true` when no error is raised.
local function assert_self (self_v, is_field)
    local value_type = type(self_v)
    if value_type ~= 'table' then
        error('Bad self value, must be indexable, but was %s (%s).', self_v, value_type)
    end
    local missing_marker = is_field and not self_v[is_field]
    if missing_marker then
        error('Bad self value, is indexable, but does not possess "%s" field.', is_field)
    end
    return true
end

-- Create redirections
-- When a SpriteSheet holds a single sprite or animation (`only_quads`), these
-- methods are forwarded to it, so the sheet can be used in its place.
for _, method_name in ipairs{ 'getImage', 'getQuad', 'draw', 'getDuration' } do
    SpriteSheet[method_name] = function (self, ...)
        assert_self (self, 'is_spritesheet')
        local target = self.only_quads
        if not target then  error('Attempting to call method "%s" on a SpriteSheet ("%s"). This is only allowed when the Spritesheet contains a single sprite or animation.', method_name, self.filename)  end

        -- Sprites have no `getDuration`, and outside LÖVE the drawing methods
        -- are not defined at all; report that rather than "attempt to call nil".
        local method = target[method_name]
        if type(method) ~= 'function' then
            error('Method "%s" is not available for the single %s in SpriteSheet ("%s").', method_name, rawget(target, 'is_animation') and 'animation' or 'sprite', self.filename)
        end
        return method(target, ...)
    end
end

setmetatable(SpriteSheet, {
    -- `SpriteSheet(filename)` is shorthand for `SpriteSheet.new(filename)`.
    __call = function (_, ...)
        return SpriteSheet.new(...)
    end,
})

--------------------------------------------------------------------------------
-- Module export

return SpriteSheet