1
0
This commit is contained in:
Jon Michael Aanes 2024-11-17 12:20:52 +01:00
parent a4df23e8ff
commit 6f22e8d239
Signed by: Jmaa
SSH Key Fingerprint: SHA256:Ab0GfHGCblESJx7JRE4fj4bFy/KRpeLhi41y4pF3sNA

View File

@ -1,7 +1,9 @@
import argparse
import dataclasses
import logging
from collections.abc import Iterable, Iterator
import datetime
from collections.abc import Iterable, Iterator, Mapping
from typing import Callable
from pathlib import Path
from . import (
@ -16,16 +18,16 @@ from .markdown import format_messages
logger = logging.getLogger(__name__)
def group_messages(
    messages: Iterable[Message],
    key_fn: Callable[[Message], str],
) -> dict[str, list[Message]]:
    """Partition messages into lists that share the same key.

    Args:
        messages: The messages to group. Consumed exactly once, so a
            one-shot iterator is acceptable.
        key_fn: Called once per message; its return value is the key of
            the group that the message is appended to.

    Returns:
        A dict mapping each key to the messages that produced it. Keys
        appear in first-seen order, and messages within a group keep
        their input order. Empty input yields an empty dict.
    """
    by_key: dict[str, list[Message]] = {}
    for msg in messages:
        by_key.setdefault(key_fn(msg), []).append(msg)
    # The loop variable is intentionally not deleted afterwards: a
    # trailing `del msg` raises when `messages` is empty, because the
    # name is never bound in that case.
    return by_key
def group_messages_by_chat_id(messages: Iterable[Message]) -> dict[str, list[Message]]:
    """Group messages by the value of each message's `chat_id` attribute.

    Delegates to `group_messages`, passing the key function through the
    `key_fn` keyword that `group_messages` declares.
    """
    return group_messages(messages, key_fn=lambda msg: msg.chat_id)
def year_and_month_period_key(msg: Message):
@ -43,7 +45,9 @@ def year_quarter_period_key(msg: Message):
MAX_AVERAGE_MESSAGES_PER_PERIOD = 120
PERIOD_KEYS_BY_NAME = {
TOO_FEW_MESSAGES_TO_CARE = 2
PERIOD_KEYS_BY_NAME: Mapping[str, Callable[[Message], str]] = {
'full': (lambda msg: 'full'),
'year': year_period_key,
'quarter': year_quarter_period_key,
@ -51,12 +55,16 @@ PERIOD_KEYS_BY_NAME = {
}
def group_messages_by_period(messages: Iterable[Message], period_key: str | None = None) -> dict[str, list[Message]]:
possible_period_keys = PERIOD_KEYS_BY_NAME.values()
def group_messages_by_period(messages: Iterable[Message], period_key: str | None = None) -> tuple[dict[str, list[Message]], Callable[[Message], str]]:
# Determine key function
possible_period_keys: Iterable[Callable[[Message], str]] = PERIOD_KEYS_BY_NAME.values()
if period_key is not None:
possible_period_keys = [PERIOD_KEYS_BY_NAME[period_key]]
del period_key
# Group by key
for period_key_fn in possible_period_keys:
grouped = group_messages(messages, key=period_key)
grouped = group_messages(messages, key_fn=period_key_fn)
average_num_messages = sum(len(grouped[k]) for k in grouped) / len(grouped)
if average_num_messages <= MAX_AVERAGE_MESSAGES_PER_PERIOD:
break
@ -81,12 +89,10 @@ def parse_args():
parser.add_argument('--output', type=Path)
parser.add_argument('--myself', type=str, default='Myself')
parser.add_argument('--overwrite', action='store_true', dest='overwrite_files')
parser.add_argument('--period', dest='period_key', values=list(PERIOD_KEYS_BY_NAME.keys()))
parser.add_argument('--period', dest='period_key', choices=list(PERIOD_KEYS_BY_NAME.keys()))
parser.add_argument('--skip-this-period', action='store_true', dest='skip_this_period')
return parser.parse_args()
def main():
logging.basicConfig()
logging.getLogger().setLevel('INFO')
@ -115,7 +121,7 @@ def main():
for chat_id, messages_in_chat_original in messages_by_chat_id.items():
messages_in_chat = merge_adjacent_messages(messages_in_chat_original)
if len(messages_in_chat) <= 2:
if len(messages_in_chat) <= TOO_FEW_MESSAGES_TO_CARE:
logger.info(' "%s": Skipped due to too few messages', chat_id)
continue
@ -128,7 +134,7 @@ def main():
len(messages_in_chat_original) / len(messages_by_period),
)
this_period_name = period_key_fn(datetime.datetime.now())
this_period_name = period_key_fn(Message(datetime.datetime.now(), '','',''))
for period_key_name, messages in messages_by_period.items():
file_escaped_chat_id = chat_id.replace(' ', '-')
@ -146,7 +152,7 @@ def main():
continue
# Create folders and file
output_file.parent.mkdir(exist_ok=True)
output_file.parent.mkdir(exist_ok=True, parents=True)
with open(output_file, 'w') as f:
f.write(
format_messages(