import dataclasses
import datetime
from collections.abc import Iterator, Sequence

# Category name reserved for hidden labels.
# NOTE(review): not referenced in the visible part of this file -- confirm the
# intended meaning against the code that uses it.
HIDDEN_LABEL_CATEGORY = '__'
# Duration assumed by heuristically_realize_samples() for a sample that has no
# recorded start time.
DEFAULT_ESTIMATED_DURATION = datetime.timedelta(hours=1)


@dataclasses.dataclass(frozen=True, order=True)
class Label:
    """Immutable ``(category, label)`` pair.

    Instances are hashable and compare/sort by ``category`` first and
    ``label`` second.
    """

    category: str
    label: str

    def __post_init__(self):
        # The category must be present and must not contain ':'.
        category_is_valid = self.category is not None and ':' not in self.category
        assert category_is_valid
        # The label text must be present as well.
        assert self.label is not None


@dataclasses.dataclass(frozen=True, order=True)
class ActivitySample:
    """An activity observation described by labels and an optional time span.

    Either bound may be ``None`` when it is unknown. When both are known,
    ``start_at`` must not be later than ``end_at``.
    """

    labels: Sequence[Label]
    start_at: datetime.datetime | None
    end_at: datetime.datetime | None

    def __post_init__(self):
        # The ordering invariant can only be checked once both bounds are known.
        if self.start_at is not None and self.end_at is not None:
            assert self.start_at <= self.end_at

    def single_label_with_category(self, category: str) -> str | None:
        """Return the text of the first label that belongs to ``category``.

        Args:
            category: Category name to look up among ``labels``.

        Returns:
            The ``label`` text of the first matching entry, or ``None`` when
            no label belongs to ``category``.
        """
        for label in self.labels:
            if label.category == category:
                return label.label
        return None


@dataclasses.dataclass(frozen=True, order=True)
class RealizedActivitySample(ActivitySample):
    """An ``ActivitySample`` whose start and end are both known."""

    start_at: datetime.datetime
    end_at: datetime.datetime

    def __post_init__(self):
        # Unlike the base class, neither bound is optional here.
        for bound in (self.start_at, self.end_at):
            assert bound is not None
        # The span must not run backwards.
        assert self.start_at <= self.end_at


def heuristically_realize_samples(
    samples: list[ActivitySample],
) -> Iterator[RealizedActivitySample]:
    """Yield a realized copy of every sample, estimating missing start times.

    Samples are processed in order of ``end_at``. A sample without a
    ``start_at`` is assumed to have lasted ``DEFAULT_ESTIMATED_DURATION``,
    shortened where necessary so that it does not begin before the previous
    sample ended.

    Guarantees that:
    * A sample whose start time had to be estimated never overlaps the
      samples before it. Explicit ``start_at`` values are passed through
      unchanged.

    Args:
        samples: Samples to realize. Every sample must have an ``end_at``, and
            the ``end_at`` values must be mutually comparable (all naive or
            all timezone-aware). The list itself is not modified.

    Yields:
        One ``RealizedActivitySample`` per input sample, ordered by ``end_at``.
    """
    # TODO: Allow end_at is None

    previous_sample_end: datetime.datetime | None = None

    # Iterate over a sorted copy so the caller's list keeps its order.
    for sample in sorted(samples, key=lambda x: x.end_at):
        end_at = sample.end_at

        start_at = sample.start_at
        if start_at is None:
            start_at = end_at - DEFAULT_ESTIMATED_DURATION
            # Only clamp against a real predecessor. Using a fixed epoch as a
            # stand-in would reject or distort samples that end before it.
            if previous_sample_end is not None:
                start_at = max(previous_sample_end, start_at)

        yield RealizedActivitySample(
            labels=sample.labels,
            end_at=end_at,
            start_at=start_at,
        )

        previous_sample_end = end_at


def mergable_labels(a: Sequence[Label], b: Sequence[Label]) -> Sequence[Label]:
    """Return the labels that occur in both ``a`` and ``b``.

    Each shared label appears once, in the order of its first appearance in
    ``a``. Building the result straight from a ``set`` would make the order
    depend on string hash randomization and so differ between runs.

    Args:
        a: First label sequence; determines the order of the result.
        b: Second label sequence.

    Returns:
        A new list with the labels common to both inputs.
    """
    present_in_b = set(b)
    # dict.fromkeys() drops duplicates while keeping first-seen order.
    return [label for label in dict.fromkeys(a) if label in present_in_b]


def merge_adjacent_samples(
    samples: list[RealizedActivitySample],
    group_category: str,
    *,
    max_interval_between_samples: datetime.timedelta = datetime.timedelta(minutes=5),
) -> list[RealizedActivitySample]:
    """Merge consecutive samples that share a group label and are close in time.

    Samples are visited in order of ``start_at``. A sample is folded into the
    most recently emitted one when both report the same label for
    ``group_category`` (including both having none) and the gap between them
    is shorter than ``max_interval_between_samples``; overlapping samples
    therefore always satisfy the time condition. A merged sample keeps only
    the labels common to both inputs (see ``mergable_labels``).

    Args:
        samples: Samples to merge. The list itself is not modified.
        group_category: Label category whose value decides whether two
            samples belong together.
        max_interval_between_samples: Exclusive upper bound on the gap between
            the end of one sample and the start of the next for them to merge.

    Returns:
        A new list, ordered by ``start_at``, with mergeable runs collapsed.
    """

    def can_merge(
        before: RealizedActivitySample,
        after: RealizedActivitySample,
    ) -> bool:
        same_group = before.single_label_with_category(
            group_category,
        ) == after.single_label_with_category(group_category)
        if not same_group:
            return False
        return (after.start_at - before.end_at) < max_interval_between_samples

    merged: list[RealizedActivitySample] = []

    # Iterate over a sorted copy so the caller's list keeps its order.
    for sample in sorted(samples, key=lambda s: s.start_at):
        if merged and can_merge(merged[-1], sample):
            previous = merged[-1]
            # TODO: Merge/strip attributes?
            merged[-1] = RealizedActivitySample(
                labels=mergable_labels(previous.labels, sample.labels),
                start_at=previous.start_at,
                # A sample nested inside the previous one must not shorten it.
                end_at=max(previous.end_at, sample.end_at),
            )
        else:
            merged.append(sample)

    return merged