Types
This commit is contained in:
parent
a4df23e8ff
commit
6f22e8d239
|
@ -1,7 +1,9 @@
|
|||
import argparse
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections.abc import Iterable, Iterator
|
||||
import datetime
|
||||
from collections.abc import Iterable, Iterator, Mapping
|
||||
from typing import Callable
|
||||
from pathlib import Path
|
||||
|
||||
from . import (
|
||||
|
@ -16,16 +18,16 @@ from .markdown import format_messages
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def group_messages(messages: Iterable[Message], key) -> dict[str, list[Message]]:
|
||||
def group_messages(messages: Iterable[Message], key_fn: Callable[[Message], str]) -> dict[str, list[Message]]:
|
||||
by_key: dict[str, list[Message]] = {}
|
||||
for msg in messages:
|
||||
by_key.setdefault(key(msg), []).append(msg)
|
||||
by_key.setdefault(key_fn(msg), []).append(msg)
|
||||
del msg
|
||||
return by_key
|
||||
|
||||
|
||||
def group_messages_by_chat_id(messages: Iterable[Message]) -> dict[str, list[Message]]:
|
||||
return group_messages(messages, key=lambda msg: msg.chat_id)
|
||||
return group_messages(messages, key_fn=lambda msg: msg.chat_id)
|
||||
|
||||
|
||||
def year_and_month_period_key(msg: Message):
|
||||
|
@ -43,7 +45,9 @@ def year_quarter_period_key(msg: Message):
|
|||
|
||||
MAX_AVERAGE_MESSAGES_PER_PERIOD = 120
|
||||
|
||||
PERIOD_KEYS_BY_NAME = {
|
||||
TOO_FEW_MESSAGES_TO_CARE = 2
|
||||
|
||||
PERIOD_KEYS_BY_NAME: Mapping[str, Callable[[Message], str]] = {
|
||||
'full': (lambda msg: 'full'),
|
||||
'year': year_period_key,
|
||||
'quarter': year_quarter_period_key,
|
||||
|
@ -51,12 +55,16 @@ PERIOD_KEYS_BY_NAME = {
|
|||
}
|
||||
|
||||
|
||||
def group_messages_by_period(messages: Iterable[Message], period_key: str | None = None) -> dict[str, list[Message]]:
|
||||
possible_period_keys = PERIOD_KEYS_BY_NAME.values()
|
||||
def group_messages_by_period(messages: Iterable[Message], period_key: str | None = None) -> tuple[dict[str, list[Message]], Callable[[Message], str]]:
|
||||
# Determine key function
|
||||
possible_period_keys: Iterable[Callable[[Message], str]] = PERIOD_KEYS_BY_NAME.values()
|
||||
if period_key is not None:
|
||||
possible_period_keys = [PERIOD_KEYS_BY_NAME[period_key]]
|
||||
del period_key
|
||||
|
||||
# Group by key
|
||||
for period_key_fn in possible_period_keys:
|
||||
grouped = group_messages(messages, key=period_key)
|
||||
grouped = group_messages(messages, key_fn=period_key_fn)
|
||||
average_num_messages = sum(len(grouped[k]) for k in grouped) / len(grouped)
|
||||
if average_num_messages <= MAX_AVERAGE_MESSAGES_PER_PERIOD:
|
||||
break
|
||||
|
@ -81,12 +89,10 @@ def parse_args():
|
|||
parser.add_argument('--output', type=Path)
|
||||
parser.add_argument('--myself', type=str, default='Myself')
|
||||
parser.add_argument('--overwrite', action='store_true', dest='overwrite_files')
|
||||
parser.add_argument('--period', dest='period_key', values=list(PERIOD_KEYS_BY_NAME.keys()))
|
||||
parser.add_argument('--period', dest='period_key', choices=list(PERIOD_KEYS_BY_NAME.keys()))
|
||||
parser.add_argument('--skip-this-period', action='store_true', dest='skip_this_period')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel('INFO')
|
||||
|
@ -115,7 +121,7 @@ def main():
|
|||
|
||||
for chat_id, messages_in_chat_original in messages_by_chat_id.items():
|
||||
messages_in_chat = merge_adjacent_messages(messages_in_chat_original)
|
||||
if len(messages_in_chat) <= 2:
|
||||
if len(messages_in_chat) <= TOO_FEW_MESSAGES_TO_CARE:
|
||||
logger.info(' "%s": Skipped due to too few messages', chat_id)
|
||||
continue
|
||||
|
||||
|
@ -128,7 +134,7 @@ def main():
|
|||
len(messages_in_chat_original) / len(messages_by_period),
|
||||
)
|
||||
|
||||
this_period_name = period_key_fn(datetime.datetime.now())
|
||||
this_period_name = period_key_fn(Message(datetime.datetime.now(), '','',''))
|
||||
|
||||
for period_key_name, messages in messages_by_period.items():
|
||||
file_escaped_chat_id = chat_id.replace(' ', '-')
|
||||
|
@ -146,7 +152,7 @@ def main():
|
|||
continue
|
||||
|
||||
# Create folders and file
|
||||
output_file.parent.mkdir(exist_ok=True)
|
||||
output_file.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(
|
||||
format_messages(
|
||||
|
|
Loading…
Reference in New Issue
Block a user