"""Obsidian Import.

Sub-module for importing time-based data into Obsidian.
"""

import dataclasses
import datetime
from collections.abc import Iterable, Iterator
from logging import getLogger
from pathlib import Path
from typing import Any
from zoneinfo import ZoneInfo

from personal_data.activity import (
    ActivitySample,
    Label,
    RealizedActivitySample,
    heuristically_realize_samples,
    merge_adjacent_samples,
)
from personal_data.csv_import import determine_possible_keys, load_csv_file, start_end

from .obsidian import Event, ObsidianVault

logger = getLogger(name=__name__)

# A single loaded CSV row: column name -> (already parsed) cell value.
Row = dict[str, Any]
# All rows of a loaded CSV file, in file order.
Rows = list[dict[str, Any]]


def iterate_samples_from_rows(rows: Rows) -> Iterator[ActivitySample]:
    """Turn loaded CSV rows into (unrealized) activity samples.

    The column layout is inferred once, via ``determine_possible_keys``, from a
    single row picked from the middle of ``rows``; that layout is then used for
    every row.

    Args:
        rows: Loaded CSV rows. Must be non-empty.

    Yields:
        One ``ActivitySample`` per row, in input order. Its start/end come from
        ``start_end`` and it carries one ``Label`` per "misc" column that is
        present in that particular row.

    Raises:
        AssertionError: On first iteration (this is a generator), if ``rows``
            is empty or the inferred layout has no start/end time column.
    """
    assert len(rows) > 0

    # A middle row is hopefully more representative than the first/last one.
    representative_row = rows[len(rows) // 2]
    possible_keys = determine_possible_keys(representative_row)
    logger.info('Found possible keys: %s', possible_keys)

    assert len(possible_keys.time_start) + len(possible_keys.time_end) >= 1
    assert len(possible_keys.image) >= 0

    for row in rows:
        start_at, end_at = start_end(row, possible_keys)
        row_labels = tuple(
            Label(key, row.get(key))
            for key in possible_keys.misc
            if key in row
        )
        yield ActivitySample(
            labels=row_labels,
            start_at=start_at,
            end_at=end_at,
        )


def import_workout_csv(vault: ObsidianVault, rows: Rows) -> int:
    """Copy workout statistics from CSV rows into the vault.

    For every row, each known input column is written with
    ``vault.add_statistic`` under its output statistic name, dated by the
    row's ``'Date'`` column. A unit suffix is appended to the value when one
    is configured.

    Args:
        vault: Vault receiving the statistics.
        rows: Loaded workout CSV rows; each must have a ``'Date'`` column.

    Returns:
        Number of rows for which at least one ``add_statistic`` call reported
        a change.
    """
    # Input CSV column -> (statistic name in the vault, unit suffix or '').
    # Independent of the row, so it is built once instead of once per row.
    mapping = {
        'Cycling (mins)': ('Cycling (Duration)', 'minutes'),
        'Cycling (kcals)': ('Cycling (kcals)', ''),
        'Weight (Kg)': ('Weight (Kg)', ''),
    }

    num_updated = 0
    for row in rows:
        date = row['Date']
        was_updated = False

        for input_key, (output_key, unit) in mapping.items():
            value = row.get(input_key)
            if value is not None and unit:
                value = str(value) + ' ' + unit
            # NOTE(review): a missing column is still forwarded as None; what
            # add_statistic does with None is not visible here — confirm.
            was_updated |= vault.add_statistic(date, output_key, value)
            if input_key != output_key:
                # Also write None under the raw CSV column name, presumably to
                # clear a statistic stored under the old name — confirm.
                was_updated |= vault.add_statistic(date, input_key, None)

        if was_updated:
            num_updated += 1
    return num_updated


def import_step_counts_csv(
    vault: ObsidianVault,
    rows: Rows,
    minimum_steps: int = 300,
) -> int:
    """Add per-day step totals to the vault.

    Step counts from all rows that start on the same date are summed. Days
    whose total is below ``minimum_steps`` are skipped.

    Args:
        vault: Vault receiving the ``'Steps'`` statistic.
        rows: Loaded CSV rows with a datetime ``'Start'`` column and a numeric
            ``'Steps'`` column.
        minimum_steps: Days with fewer total steps than this are not written.
            Defaults to 300, the previously hard-coded threshold.

    Returns:
        Number of days for which ``add_statistic`` reported a change.
    """
    # Single pass: group by date and sum at once (insertion order = first
    # occurrence of each date, same as grouping then summing).
    steps_per_date: dict[datetime.date, int] = {}
    for row in rows:
        # NOTE(review): the date is taken from 'Start' as-is, without the
        # timezone conversion used for activity samples — confirm intended.
        date = row['Start'].date()
        steps_per_date[date] = steps_per_date.get(date, 0) + row['Steps']

    num_updated = 0
    for date, steps in steps_per_date.items():
        if steps < minimum_steps:
            continue
        if vault.add_statistic(date, 'Steps', steps):
            num_updated += 1
    return num_updated


def escape_for_obsidian_link(link: str) -> str:
    """Make ``link`` usable as an Obsidian link target.

    Colons and slashes are replaced by spaces, and every resulting run of
    consecutive spaces is collapsed into a single space.

    Example:
        >>> escape_for_obsidian_link('A / B')
        'A B'
    """
    escaped = link.replace(':', ' ').replace('/', ' ')
    # str.replace is non-overlapping, so one pass only halves a run of spaces
    # (e.g. three spaces -> two). Repeat until no double space remains.
    while '  ' in escaped:
        escaped = escaped.replace('  ', ' ')
    return escaped


@dataclasses.dataclass(eq=True, frozen=True)
class EventContent:
    """Immutable text parts of an event, produced by a content mapper.

    Consumed when building an ``Event``: ``subject`` is escaped for use as an
    Obsidian link, while ``verb`` and ``comment`` are used as-is.
    """

    # Action word, e.g. 'Watched' or 'Played'.
    verb: str
    # What the action was about (series name, game name, ...).
    subject: str
    # Free-form extra text; may be empty.
    comment: str


def import_activity_sample_csv(
    vault: ObsidianVault,
    rows: Rows,
    content_mapper,
    group_category: str | None = None,
    *,
    timezone: datetime.tzinfo | None = None,
) -> int:
    """Import activity rows as events into the vault, grouped per day.

    Rows are turned into samples, realized with
    ``heuristically_realize_samples``, and optionally merged. Each sample is
    then filed under the date of its start time in ``timezone``.

    Args:
        vault: Vault receiving the events via ``add_events``.
        rows: Loaded CSV rows; must be non-empty.
        content_mapper: Callable taking a ``RealizedActivitySample`` and
            returning an object with ``verb``, ``subject`` and ``comment``
            attributes (e.g. ``EventContent``).
        group_category: If given, adjacent samples are merged with
            ``merge_adjacent_samples`` using this label category.
        timezone: Timezone used to decide which calendar day a sample belongs
            to. Defaults to Europe/Copenhagen (the previous hard-coded value).

    Returns:
        Number of days for which ``vault.add_events`` reported a change.
    """
    samples = heuristically_realize_samples(list(iterate_samples_from_rows(rows)))

    if group_category is not None:
        samples = merge_adjacent_samples(list(samples), group_category)

    if timezone is None:
        timezone = ZoneInfo('Europe/Copenhagen')

    samples_per_date: dict[datetime.date, list[RealizedActivitySample]] = {}
    for sample in samples:
        # NOTE(review): astimezone() treats a naive start_at as local time;
        # presumably start_at is timezone-aware — confirm.
        day = sample.start_at.astimezone(timezone).date()
        samples_per_date.setdefault(day, []).append(sample)

    def map_to_event(sample: RealizedActivitySample) -> Event:
        # Only the subject is escaped, because it is used as a link target.
        content = content_mapper(sample)
        return Event(
            sample.start_at,
            sample.end_at,
            verb=content.verb,
            subject=escape_for_obsidian_link(content.subject),
            comment=content.comment,
        )

    num_updated = 0
    for day, day_samples in samples_per_date.items():
        events = [map_to_event(sample) for sample in day_samples]
        if vault.add_events(day, events):
            num_updated += 1
    return num_updated


def import_activity_sample_csv_from_file(
    vault: ObsidianVault,
    data_path: Path,
    content_mapper,
    **kwargs,
) -> int:
    """Load an activity CSV file and import its rows into ``vault``.

    Args:
        vault: Vault receiving the events.
        data_path: CSV file read with ``load_csv_file``.
        content_mapper: Forwarded to ``import_activity_sample_csv``.
        **kwargs: Extra keyword arguments forwarded to
            ``import_activity_sample_csv`` (e.g. ``group_category``).

    Returns:
        The number of updated files reported by
        ``import_activity_sample_csv``. Without this ``return`` the function
        yielded ``None`` despite its ``-> int`` annotation.
    """
    rows = load_csv_file(data_path)
    logger.info('Loaded CSV with %d lines (%s)', len(rows), data_path)
    num_updated = import_activity_sample_csv(vault, rows, content_mapper, **kwargs)
    logger.info('Updated %d files', num_updated)
    return num_updated


def map_watched_series_content(sample: RealizedActivitySample) -> EventContent:
    """Describe a watched-episode sample as event content.

    The series name becomes the subject, and the comment has the form
    ``'<season> Episode <index>: *<episode name>*'``.
    """
    series = sample.single_label_with_category('series.name')
    season = sample.single_label_with_category('season.name')
    episode_index = sample.single_label_with_category('episode.index')
    episode_name = sample.single_label_with_category('episode.name')
    return EventContent(
        verb='Watched',
        subject=series,
        comment=f'{season} Episode {episode_index}: *{episode_name}*',
    )


def map_games_played_content(sample: RealizedActivitySample) -> EventContent:
    """Describe a game-playing sample as event content with an empty comment.

    The game name becomes the subject.
    """
    return EventContent(
        verb='Played',
        subject=sample.single_label_with_category('game.name'),
        comment='',
    )


def import_watched_series_csv_from_file(vault: ObsidianVault) -> int:
    """Import watched TV episodes from ``output/show_episodes_watched.csv``.

    Matches the played-games importer: a missing file is logged as a warning
    and skipped (returning 0) instead of aborting the whole import run.

    Returns:
        0 when the file is missing; otherwise the result of
        ``import_activity_sample_csv_from_file``.
    """
    data_path = Path('output/show_episodes_watched.csv')
    if not data_path.exists():
        logger.warning('Skipping import of watched series: %s is missing', data_path)
        return 0
    return import_activity_sample_csv_from_file(
        vault,
        data_path,
        map_watched_series_content,
    )


def import_played_games_csv_from_file(vault: ObsidianVault) -> int:
    """Import played-game sessions from ``output/games_played.csv``.

    Adjacent samples of the same game are merged (``group_category='game.name'``).
    A missing file is logged as a warning and skipped.

    Returns:
        0 when the file is missing; otherwise the result of
        ``import_activity_sample_csv_from_file``.
    """
    data_path = Path('output/games_played.csv')
    if data_path.exists():
        return import_activity_sample_csv_from_file(
            vault,
            data_path,
            map_games_played_content,
            group_category='game.name',
        )
    logger.warning('Skipping import of played games: %s is missing', data_path)
    return 0


def import_data(obsidian_path: Path, dry_run=True):
    """Run the enabled importers against the vault at ``obsidian_path``.

    When ``dry_run`` is true, the vault is opened with ``read_only='silent'``
    and the cache is never flushed. Otherwise it is opened with
    ``read_only=None`` and ``flush_cache()`` is called at the end. In both
    cases the number of dirty and clean cached files is logged.
    """
    read_only = 'silent' if dry_run else None
    vault = ObsidianVault(obsidian_path, read_only=read_only)

    # Disabled one-off importers with hard-coded paths; flip to True to rerun.
    if False:
        data_path = Path('/home/jmaa/Notes/workout.csv')
        rows = load_csv_file(data_path)
        logger.info('Loaded CSV with %d lines', len(rows))
        num_updated = import_workout_csv(vault, rows)
        logger.info('Updated %d files', num_updated)

    if False:
        data_path = Path(
            '/home/jmaa/personal-archive/misc-data/step_counts_2023-07-26_to_2024-09-21.csv',
        )
        rows = load_csv_file(data_path)
        logger.info('Loaded CSV with %d lines', len(rows))
        num_updated = import_step_counts_csv(vault, rows)
        logger.info('Updated %d files', num_updated)

    import_played_games_csv_from_file(vault)
    import_watched_series_csv_from_file(vault)

    file_cache = vault.internal_file_text_cache
    num_dirty = sum(1 for cached_file in file_cache.values() if cached_file.is_dirty)
    num_clean = len(file_cache) - num_dirty
    logger.info('dirty files in cache: %d', num_dirty)
    logger.info('clean files in cache: %d', num_clean)

    if not dry_run:
        vault.flush_cache()