import csv
import dataclasses
import datetime
import decimal
import typing
import urllib.parse
from collections.abc import Callable
from decimal import Decimal
from pathlib import Path
from typing import Any

from frozendict import frozendict

# Name of the CSV dialect registered on the next line. load_csv_file reads
# with this dialect unless it is asked to sniff one from the file.
CSV_DIALECT = 'one_true_dialect'
csv.register_dialect(CSV_DIALECT, lineterminator='\n', skipinitialspace=True)

# Type variable linking try_value's return type to that of the parser it is given.
T = typing.TypeVar('T')


def csv_safe_value(v: Any) -> str:
    if isinstance(v, urllib.parse.ParseResult):
        return v.geturl()
    if isinstance(v, datetime.datetime):
        if v.tzinfo is None or v.tzinfo != datetime.UTC:
            msg = f'Timezone must be UTC: {v}'
            raise ValueError(msg)
    if isinstance(v, urllib.parse.ParseResult):
        return v.geturl()
    return str(v)


def try_value(fn: Callable[[str], T], s: str) -> T | None:
    """Apply ``fn`` to ``s``, yielding ``None`` when the conversion fails.

    Only ``ValueError`` and ``decimal.InvalidOperation`` count as a failed
    conversion; any other exception raised by ``fn`` propagates.
    """
    try:
        converted = fn(s)
    except (ValueError, decimal.InvalidOperation):
        converted = None
    return converted


def parse_timedelta(text: str) -> datetime.timedelta:
    if t := try_value(lambda t: datetime.datetime.strptime(t, '%H:%M:%S.%f'), text):
        return datetime.timedelta(hours=t.hour, minutes=t.minute, seconds=t.second)
    elif t := try_value(lambda t: datetime.datetime.strptime(t, '%H:%M:%S'), text):
        return datetime.timedelta(hours=t.hour, minutes=t.minute, seconds=t.second)
    else:
        return None


def csv_str_to_value(
    s: str,
) -> (
    str
    | Decimal
    | datetime.date
    | datetime.datetime
    | urllib.parse.ParseResult
    | bool
    | None
):
    assert not isinstance(s, list)  # TODO?

    if s is None:
        return None
    s = s.strip()
    if len(s) == 0:
        return None
    if (v_decimal := try_value(Decimal, s)) is not None:
        return v_decimal
    if (v_date := try_value(datetime.date.fromisoformat, s)) is not None:
        return v_date
    if (v_datetime := try_value(datetime.datetime.fromisoformat, s)) is not None:
        return v_datetime
    if (v_timedelta := parse_timedelta(s)) is not None:
        return v_timedelta
    if s.startswith(('http://', 'https://')):
        return urllib.parse.urlparse(s)
    if s.lower() == 'false':
        return False
    if s.lower() == 'true':
        return True
    if s.lower() == 'none':
        return None
    return s


def load_csv_file(
    csv_file: Path,
    sniff: bool = False,
) -> list[frozendict[str, typing.Any]]:
    """Read a CSV file with a header row into a list of immutable records.

    Every cell is converted with ``csv_str_to_value``. Cells that convert to
    ``None`` (missing, blank or ``none``) are left out of their record.

    Args:
        csv_file: Path of the CSV file to read.
        sniff: When true, detect the dialect from the first 1024 characters
            of the file instead of using ``CSV_DIALECT``.

    Returns:
        One ``frozendict`` per data row, keyed by column header.
    """
    # NOTE(review): a row with more cells than headers presumably reaches
    # csv_str_to_value as a list and trips its assert - confirm intent.
    records: list[frozendict] = []
    # The csv module documentation requires newline='' so that the reader
    # handles line endings (including ones inside quoted cells) itself.
    with open(csv_file, newline='') as csvfile:
        if sniff:
            dialect = csv.Sniffer().sniff(csvfile.read(1024))
            # Rewind so the sniffed sample is read again as data.
            csvfile.seek(0)
        else:
            dialect = CSV_DIALECT
        for row in csv.DictReader(csvfile, dialect=dialect):
            converted = {key: csv_str_to_value(raw) for key, raw in row.items()}
            records.append(
                frozendict(
                    {key: value for key, value in converted.items() if value is not None}
                )
            )
    return records


@dataclasses.dataclass
class PossibleKeys:
    """Keys of an event record, grouped by the role their values could play.

    ``determine_possible_keys`` builds instances of this class and
    ``start_end`` consumes them, using the first entry of each list.
    """

    # Keys with date/datetime values whose lowercased name contains 'start'.
    time_start: list[str]
    # Keys with date/datetime values whose lowercased name contains
    # 'end', 'stop' or 'last'.
    time_end: list[str]
    # Keys accepted by is_duration_key.
    duration: list[str]
    # Keys whose value is a plain string.
    name: list[str]
    # Keys whose value is a parsed URL (urllib.parse.ParseResult).
    image: list[str]
    # Every key that is neither an image key nor a date/datetime key.
    misc: list[str]


def is_duration_key(k, v):
    """Tell whether the pair ``(k, v)`` looks like a duration field.

    A ``Decimal`` value qualifies when the key contains 'duration_seconds';
    a ``timedelta`` value qualifies when the key contains 'duration'.
    Values of any other type never qualify.
    """
    # Pick the substring the key must contain, based on the value's type.
    if isinstance(v, Decimal):
        required = 'duration_seconds'
    elif isinstance(v, datetime.timedelta):
        required = 'duration'
    else:
        return False
    return required in k


def determine_possible_keys(event_data: dict[str, Any]) -> PossibleKeys:
    """Classify the keys of ``event_data`` by the type of their values.

    Date/datetime keys are further split into start and end candidates by
    substrings of the lowercased key name. ``misc`` holds every key that is
    neither a date/datetime key nor a parsed-URL key.
    """
    dated: list[str] = []
    durations: list[str] = []
    names: list[str] = []
    images: list[str] = []

    # One pass over the record; a key may land in more than one group.
    for key, value in event_data.items():
        if isinstance(value, datetime.date):
            dated.append(key)
        if is_duration_key(key, value):
            durations.append(key)
        if isinstance(value, str):
            names.append(key)
        if isinstance(value, urllib.parse.ParseResult):
            images.append(key)

    # Everything that is not an image or a date/datetime stays in misc.
    claimed = set(images) | set(dated)
    leftovers = [key for key in event_data if key not in claimed]

    starts = [key for key in dated if 'start' in key.lower()]
    end_markers = ('end', 'stop', 'last')
    ends = [
        key
        for key in dated
        if any(marker in key.lower() for marker in end_markers)
    ]

    return PossibleKeys(
        time_start=starts,
        time_end=ends,
        duration=durations,
        name=names,
        image=images,
        misc=leftovers,
    )


def start_end(
    sample: dict[str, Any],
    keys: PossibleKeys,
) -> tuple[datetime.datetime | None, datetime.datetime | None]:
    """Derive the ``(start, end)`` pair of ``sample`` from its candidate keys.

    Only the first key of each candidate list is used. Preference order:
    explicit start and end; start plus duration; end minus duration; start
    only; end only. A side that cannot be determined is ``None``.
    """

    def read_duration() -> datetime.timedelta:
        # Values that are not already timedeltas are treated as seconds.
        raw = sample[keys.duration[0]]
        if isinstance(raw, datetime.timedelta):
            return raw
        return datetime.timedelta(seconds=float(raw))

    has_start = bool(keys.time_start)
    has_end = bool(keys.time_end)
    has_duration = bool(keys.duration)

    if has_start and has_end:
        return (sample[keys.time_start[0]], sample[keys.time_end[0]])

    if has_start and has_duration:
        begin = sample[keys.time_start[0]]
        return (begin, begin + read_duration())

    if has_end and has_duration:
        finish = sample[keys.time_end[0]]
        return (finish - read_duration(), finish)

    if has_start:
        return (sample[keys.time_start[0]], None)
    if has_end:
        return (None, sample[keys.time_end[0]])
    return (None, None)