import dataclasses
import datetime
import json
import re
from decimal import Decimal
from logging import getLogger
from pathlib import Path
from typing import Any
from zoneinfo import ZoneInfo

import enforce_typing
import frontmatter
import marko
import marko.md_renderer

logger = getLogger(__name__)

# Name of a per-day statistic stored as a key in a daily note's frontmatter.
StatisticKey = str


@enforce_typing.enforce_types
@dataclasses.dataclass(frozen=True, order=True)
class Event:
    """A single timeline entry: an optional time span, verb and subject,
    plus a free-form comment. Fully-None fields (except comment) denote an
    unparsed, comment-only line."""

    start_time: datetime.datetime | None
    end_time: datetime.datetime | None
    verb: str | None
    subject: str | None
    comment: str

    def __post_init__(self):
        # Subjects are embedded into wiki links, where ':' and '/' act as
        # namespace/path separators and would corrupt the link.
        if self.subject:
            for separator in (':', '/'):
                assert separator not in self.subject


@dataclasses.dataclass(frozen=True)
class FileContents:
    """Parsed representation of one daily-note file."""

    # Parsed YAML frontmatter; statistics live here as key/value pairs.
    frontmatter: dict[str, Any]
    # Markdown AST blocks appearing before the '## Events' heading.
    blocks_pre_events: list
    # Events parsed from the list under the '## Events' heading.
    events: frozenset[Event]
    # Markdown AST blocks appearing after the events list.
    blocks_post_events: list
    # Timezone used to localize naive times parsed from the file.
    timezone: ZoneInfo


@dataclasses.dataclass(frozen=False)
class CachedFile:
    """In-memory copy of one vault file's raw contents."""

    # Raw file bytes, as last read from disk or staged for writing.
    data: bytes
    # True when `data` holds changes not yet written to disk.
    is_dirty: bool


# Shared parser/renderer instances; marko objects are reusable across calls.
MARKDOWN_PARSER = marko.Markdown()
MARKDOWN_RENDERER = marko.md_renderer.MarkdownRenderer()

# Skeleton used when re-serializing a daily note: the '## Events' section is
# regenerated between the (untouched) pre/post markdown blocks.
FILE_FORMAT = """
{blocks_pre_events}
## Events
{block_events}
{blocks_post_events}
"""


class ObsidianVault:
    """Reader/writer for statistics and events in an Obsidian vault's daily notes.

    All file access goes through an in-memory cache
    (`internal_file_text_cache`); mutations only mark cache entries dirty,
    and nothing is persisted to disk until `flush_cache` is called.
    """

    def __init__(
        self,
        vault_path: Path,
        read_only: bool = 'silent',
        allow_invalid_vault=False,
    ):
        """Open the vault rooted at `vault_path`.

        Args:
            vault_path: Directory that contains the `.obsidian` folder.
            read_only: Any truthy value makes `flush_cache` raise.
                NOTE(review): the default `'silent'` is a truthy string
                despite the `bool` annotation — confirm intended semantics
                before changing it.
            allow_invalid_vault: Skip vault-layout and daily-notes
                configuration validation.
        """
        self.vault_path = vault_path
        self.read_only = read_only
        self.internal_file_text_cache: dict[Path, CachedFile] = {}

        if not allow_invalid_vault:
            assert (self.vault_path / '.obsidian').exists(), 'Not an Obsidian Vault'

        try:
            with open(self.vault_path / '.obsidian' / 'daily-notes.json') as f:
                daily_notes_config = json.load(f)
            self.daily_folder = daily_notes_config['folder']
            self.path_format = daily_notes_config['format']
            self.template_file_path = daily_notes_config['template']
        except FileNotFoundError:
            # When tolerated, the daily-note attributes stay unset; any later
            # daily-note operation will fail with AttributeError.
            if not allow_invalid_vault:
                assert False, 'Missing daily notes configuration!'

    def get_statistic(
        self,
        date: datetime.date,
        statistic_key: StatisticKey,
    ) -> Any | None:
        """Return the statistic stored in the note's frontmatter, or None."""
        if contents := self._load_date_contents(date):
            return contents.frontmatter.get(statistic_key)
        return None

    def add_statistic(
        self,
        date: datetime.date,
        statistic_key: StatisticKey,
        amount: Any,
    ) -> bool:
        """Set a frontmatter statistic, or remove it when `amount` is None.

        Returns:
            True when the note changed; False when it already held `amount`
            or the note contents could not be loaded.
        """
        # YAML frontmatter cannot represent Decimal; store a plain float.
        if isinstance(amount, Decimal):
            amount = float(amount)

        contents = self._load_date_contents(date)
        if contents is None:
            # Consistent with add_events/get_events: nothing to update.
            return False

        # Exit without writing if there is no change.
        if contents.frontmatter.get(statistic_key) == amount:
            return False

        if amount is None:
            # get() above returned a non-None value, so the key exists.
            del contents.frontmatter[statistic_key]
        else:
            contents.frontmatter[statistic_key] = amount

        self._save_date_contents(date, contents)
        return True

    def add_events(self, date: datetime.date, events: list[Event]) -> bool:
        """Merge `events` into the note for `date`.

        Returns True if any new event was added, False otherwise.
        """
        contents = self._load_date_contents(date)
        if contents is None:
            return False

        # Exit without writing if there were no changes.
        updated_events: frozenset[Event] = contents.events | set(events)
        if contents.events == updated_events:
            return False

        contents = dataclasses.replace(contents, events=updated_events)
        self._save_date_contents(date, contents)
        return True

    def get_events(self, date: datetime.date) -> frozenset[Event]:
        """Return all events recorded in the note for `date` (possibly empty)."""
        contents = self._load_date_contents(date)
        if contents is None:
            return frozenset()
        return contents.events

    def _load_date_contents(self, date: datetime.date) -> FileContents | None:
        """Load and parse the daily note for `date`, falling back to the template."""
        timezone = ZoneInfo(
            'Europe/Copenhagen',
        )  # TODO: Parameterize in an intelligent manner

        file_path = self._date_file_path(date)
        # Fall back to the daily-note template when the file doesn't exist yet.
        text = self._load_file_text(file_path) or self._load_file_text(
            self._daily_template_path(),
        )
        assert text is not None

        file_frontmatter = frontmatter.loads(text)

        ast = MARKDOWN_PARSER.parse(str(file_frontmatter))
        (pre_events, list_block_items, post_events) = find_events_list_block(ast)
        events = frozenset(
            parse_event_string(list_item, date, timezone)
            for list_item in list_block_items
        )
        return FileContents(
            file_frontmatter.metadata,
            pre_events,
            events,
            post_events,
            timezone,
        )

    def _save_date_contents(self, date: datetime.date, contents: FileContents) -> None:
        """Serialize `contents` back to markdown and stage it in the cache."""
        blocks_pre_events = ''.join(
            MARKDOWN_RENDERER.render(b) for b in contents.blocks_pre_events
        )
        blocks_post_events = ''.join(
            MARKDOWN_RENDERER.render(b) for b in contents.blocks_post_events
        )

        # Stable sorts applied minor-key first yield an overall ordering of
        # (start/end time, verb, subject, comment).
        events = list(contents.events)
        events.sort(key=lambda x: x.comment or '')
        events.sort(key=lambda x: x.subject or '')
        events.sort(key=lambda x: x.verb or '')
        date_sentinel = datetime.datetime(1900, 1, 1, 1, 1, 1, tzinfo=contents.timezone)
        events.sort(key=lambda x: x.start_time or x.end_time or date_sentinel)

        formatted_events = [
            '- ' + format_event_string(e, tz=contents.timezone) for e in events
        ]
        # Deduplicate identical lines while preserving order.
        formatted_events = list(dict.fromkeys(formatted_events))
        block_events = '\n'.join(formatted_events)

        post = frontmatter.Post(
            content=FILE_FORMAT.format(
                blocks_pre_events=blocks_pre_events,
                blocks_post_events=blocks_post_events,
                block_events=block_events,
            ).strip(),
            **contents.frontmatter,
        )

        self._save_file_text_to_cache(
            self._date_file_path(date),
            frontmatter.dumps(post).encode('utf8'),
        )

    def _save_file_text_to_cache(self, path: Path, text: bytes) -> None:
        """Stage `text` as the new (dirty) cached contents for `path`."""
        # Replacing the entry outright avoids the old two-step dance that
        # briefly created CachedFile(None, ...), violating `data: bytes`.
        self.internal_file_text_cache[path] = CachedFile(text, True)

    def _date_file_path(self, date: datetime.date) -> Path:
        """Resolve the on-disk path of the daily note for `date`."""
        path = (
            self.path_format.replace('YYYY', str(date.year))
            .replace('MM', f'{date.month:02d}')
            .replace('DD', f'{date.day:02d}')
        )
        return (self.vault_path / self.daily_folder / path).with_suffix('.md')

    def _daily_template_path(self) -> Path:
        """Resolve the configured daily-note template path."""
        return (self.vault_path / self.template_file_path).with_suffix('.md')

    def _load_file_text(self, path: Path) -> bytes | None:
        """Return the cached raw bytes of `path`, reading from disk on first
        access. Returns None for a missing file (and does not cache the miss)."""
        if path not in self.internal_file_text_cache:
            try:
                with open(path, 'rb') as f:
                    self.internal_file_text_cache[path] = CachedFile(f.read(), False)
            except FileNotFoundError:
                return None
        return self.internal_file_text_cache[path].data

    def flush_cache(self) -> None:
        """Write every dirty cached file to disk and mark it clean.

        Raises:
            RuntimeError: if the vault was opened read-only.
        """
        if self.read_only:
            msg = 'Read-only ObsidianVault cannot be flushed'
            raise RuntimeError(msg)
        for path, cached_file in self.internal_file_text_cache.items():
            if not cached_file.is_dirty:
                continue
            logger.info('Saving file "%s"', path)
            path.parent.mkdir(exist_ok=True, parents=True)
            path.write_bytes(cached_file.data)
            # Mark clean so a later flush does not rewrite unchanged files.
            cached_file.is_dirty = False


def find_events_list_block(ast) -> tuple[list, list[str], list]:
    """Split a markdown AST around its 'Events' section.

    Returns a tuple of (blocks before the heading, rendered text of each
    event list item, blocks after the event list). When no 'Events'
    heading is found, every block is returned in the first element.
    """
    blocks = ast.children
    for block_i, block in enumerate(blocks):
        if not isinstance(block, marko.block.Heading):
            continue
        # Guard against headings with no children or with a non-text first
        # child (e.g. '## **Events**'), which would previously crash on
        # `.lower()`.
        title = block.children[0].children if block.children else None
        if not isinstance(title, str) or title.lower() != 'events':
            continue

        events_block = blocks[block_i + 1] if block_i + 1 < len(blocks) else None
        if isinstance(events_block, marko.block.List):
            # Heading plus list are consumed; everything after is "post".
            offset = 2
            event_texts = [
                MARKDOWN_RENDERER.render_children(li).strip()
                for li in events_block.children
            ]
        else:
            # Heading with no list directly beneath it: no events.
            offset = 1
            event_texts = []

        return (blocks[:block_i], event_texts, blocks[block_i + offset :])
    return (blocks, [], [])


def format_event_string(event: Event, tz: ZoneInfo) -> str:
    """Render *event* as a single event-list line.

    A fully unstructured event (all fields None except the comment) is
    returned verbatim; otherwise the line is formatted as
    'HH:MM[-HH:MM] | verb [[subject]]. comment' with times in `tz`.
    """
    assert event is not None

    structured = (
        event.start_time is not None
        or event.end_time is not None
        or event.subject is not None
        or event.verb is not None
    )
    if not structured:
        return event.comment

    parts = [f'{event.start_time.astimezone(tz):%H:%M}']
    # Only show a range when the end differs from the start.
    if event.end_time and event.end_time != event.start_time:
        parts.append(f'-{event.end_time.astimezone(tz):%H:%M}')
    parts.append(' | ')
    parts.append(event.verb)
    parts.append(' [[')
    parts.append(event.subject)
    parts.append((']]. ' + event.comment).strip())
    return ''.join(parts)


# HH:MM with optional :SS and optional fractional seconds. The original
# pattern had the '?' misplaced — `(?:\.\d+?)` made the fraction mandatory
# whenever seconds were present, so '10:00:30' never matched its seconds.
RE_TIME = r'(\d\d:\d\d(?::\d\d(?:\.\d+)?)?)'
# A verb in past tense (English '-ed' / Danish '-te').
RE_VERB = r'(\w+(?:ed|te))'
# Inline markdown link '[text](target)', capturing the text. The original
# `\]\(?:` made '(' optional followed by a literal ':', so normal links
# like '[x](url)' could never match.
RE_LINK_MD = r'\[([^\]:/]*)\]\([^)]*\)'
# Wiki link '[[dir/Subject]]', capturing only the final path component.
RE_LINK_WIKI = r'\[\[(?:[^\]:]*\/)?([^\]:/]*)\]\]'

# A single time or a 'start - end' range (second time optional).
RE_TIME_FORMAT = RE_TIME + r'(?:\s*\-\s*' + RE_TIME + r')?'


def parse_event_string(
    event_str: str,
    date: datetime.date,
    timezone: ZoneInfo,
) -> Event:
    """Parse one event line for the given date.

    Tries the markdown-link form first, then the wiki-link form. Lines
    matching neither become comment-only Events. Parsed times are local to
    `timezone` and converted to UTC.
    """
    verb_part = r'[ :\|-]*' + RE_VERB + r'\s+'

    match = None
    # Same precedence as before: markdown link, then wiki link.
    for link_pattern in (RE_LINK_MD, RE_LINK_WIKI):
        match = re.match(
            r'^\s*' + RE_TIME_FORMAT + verb_part + link_pattern + r'\.?\s*(.*)$',
            event_str,
        )
        if match:
            break

    if match is None:
        logger.debug('Could not parse format: %s', event_str)
        return Event(None, None, None, None, event_str)

    start_time = datetime.time.fromisoformat(match.group(1))
    end_time = (
        datetime.time.fromisoformat(match.group(2)) if match.group(2) else start_time
    )

    # Times in the file are wall-clock local; normalize to UTC for storage.
    start = datetime.datetime.combine(date, start_time, timezone).astimezone(
        datetime.UTC,
    )
    end = datetime.datetime.combine(date, end_time, timezone).astimezone(datetime.UTC)

    return Event(start, end, match.group(3), match.group(4), match.group(5))