diff --git a/libpurple_to_markdown/__main__.py b/libpurple_to_markdown/__main__.py index ba8f893..cfd7f91 100644 --- a/libpurple_to_markdown/__main__.py +++ b/libpurple_to_markdown/__main__.py @@ -1,7 +1,9 @@ import argparse import dataclasses import logging -from collections.abc import Iterable, Iterator +import datetime +from collections.abc import Iterable, Iterator, Mapping +from typing import Callable from pathlib import Path from . import ( @@ -16,16 +18,16 @@ from .markdown import format_messages logger = logging.getLogger(__name__) -def group_messages(messages: Iterable[Message], key) -> dict[str, list[Message]]: +def group_messages(messages: Iterable[Message], key_fn: Callable[[Message], str]) -> dict[str, list[Message]]: by_key: dict[str, list[Message]] = {} for msg in messages: - by_key.setdefault(key(msg), []).append(msg) + by_key.setdefault(key_fn(msg), []).append(msg) del msg return by_key def group_messages_by_chat_id(messages: Iterable[Message]) -> dict[str, list[Message]]: - return group_messages(messages, key=lambda msg: msg.chat_id) + return group_messages(messages, key_fn=lambda msg: msg.chat_id) def year_and_month_period_key(msg: Message): @@ -43,7 +45,9 @@ def year_quarter_period_key(msg: Message): MAX_AVERAGE_MESSAGES_PER_PERIOD = 120 -PERIOD_KEYS_BY_NAME = { +TOO_FEW_MESSAGES_TO_CARE = 2 + +PERIOD_KEYS_BY_NAME: Mapping[str, Callable[[Message], str]] = { 'full': (lambda msg: 'full'), 'year': year_period_key, 'quarter': year_quarter_period_key, @@ -51,12 +55,16 @@ PERIOD_KEYS_BY_NAME = { } -def group_messages_by_period(messages: Iterable[Message], period_key: str | None = None) -> dict[str, list[Message]]: - possible_period_keys = PERIOD_KEYS_BY_NAME.values() +def group_messages_by_period(messages: Iterable[Message], period_key: str | None = None) -> tuple[dict[str, list[Message]], Callable[[Message], str]]: + # Determine key function + possible_period_keys: Iterable[Callable[[Message], str]] = PERIOD_KEYS_BY_NAME.values() if period_key is not None: possible_period_keys = [PERIOD_KEYS_BY_NAME[period_key]] + del period_key + + # Group by key for period_key_fn in possible_period_keys: - grouped = group_messages(messages, key=period_key) + grouped = group_messages(messages, key_fn=period_key_fn) average_num_messages = sum(len(grouped[k]) for k in grouped) / len(grouped) if average_num_messages <= MAX_AVERAGE_MESSAGES_PER_PERIOD: break @@ -81,12 +89,10 @@ def parse_args(): parser.add_argument('--output', type=Path) parser.add_argument('--myself', type=str, default='Myself') parser.add_argument('--overwrite', action='store_true', dest='overwrite_files') - parser.add_argument('--period', dest='period_key', values=list(PERIOD_KEYS_BY_NAME.keys())) + parser.add_argument('--period', dest='period_key', choices=list(PERIOD_KEYS_BY_NAME.keys())) parser.add_argument('--skip-this-period', action='store_true', dest='skip_this_period') return parser.parse_args() - - def main(): logging.basicConfig() logging.getLogger().setLevel('INFO') @@ -115,7 +121,7 @@ def main(): for chat_id, messages_in_chat_original in messages_by_chat_id.items(): messages_in_chat = merge_adjacent_messages(messages_in_chat_original) - if len(messages_in_chat) <= 2: + if len(messages_in_chat) <= TOO_FEW_MESSAGES_TO_CARE: logger.info(' "%s": Skipped due to too few messages', chat_id) continue @@ -128,7 +134,7 @@ def main(): len(messages_in_chat_original) / len(messages_by_period), ) - this_period_name = period_key_fn(datetime.datetime.now()) + this_period_name = period_key_fn(Message(datetime.datetime.now(), '','','')) for period_key_name, messages in messages_by_period.items(): file_escaped_chat_id = chat_id.replace(' ', '-') @@ -146,7 +152,7 @@ def main(): continue # Create folders and file - output_file.parent.mkdir(exist_ok=True) + output_file.parent.mkdir(exist_ok=True, parents=True) with open(output_file, 'w') as f: f.write( format_messages(