From 96a2e2bed96e922a87612385fe52bba97799be98 Mon Sep 17 00:00:00 2001 From: Jon Michael Aanes Date: Fri, 25 Oct 2024 22:09:27 +0200 Subject: [PATCH] Code quality and typing improvements --- personal_data/parse_util.py | 17 +++++++----- personal_data/util.py | 52 ++++++++++++++++++++----------------- test/test_deduplicate.py | 2 +- test/test_parsing.py | 4 +-- test/test_psnprofiles.py | 6 ++--- test/test_util.py | 1 - 6 files changed, 45 insertions(+), 37 deletions(-) diff --git a/personal_data/parse_util.py b/personal_data/parse_util.py index 0e0bf76..9a89557 100644 --- a/personal_data/parse_util.py +++ b/personal_data/parse_util.py @@ -21,9 +21,9 @@ FORMAT_DATE_HEADER = '%a, %d %b %Y %H:%M:%S GMT' def parse_duration(text: str) -> datetime.timedelta: - (num, unit) = text.split(' ') - num = int(num) - unit = DATETIME_UNITS[unit] + (num_str, unit_str) = text.split(' ') + num = int(num_str) + unit = DATETIME_UNITS[unit_str] return unit * num @@ -57,12 +57,16 @@ def parse_time(text: str) -> datetime.datetime: if time is None and (m := try_parse(text, '%d %b @ %I:%M%p')): time = m.replace(year=NOW.year) - assert time is not None, 'Could not parse format' + if time is None: + msg = 'Unknown format: ' + text + raise RuntimeError(msg) if time.tzinfo is None: time = time.replace(tzinfo=LOCAL_TIMEZONE) - assert time.tzinfo is not None, time + if time.tzinfo is None: + msg = 'Could not parse timezone: ' + text + raise RuntimeError(msg) return time @@ -74,4 +78,5 @@ def parse_date(text: str) -> datetime.date: return dt.date() if dt := try_parse(text, '%B %d, %Y'): return dt.date() - assert False, text + msg = 'Unknown format: ' + text + raise RuntimeError(msg) diff --git a/personal_data/util.py b/personal_data/util.py index b8eb018..e43289e 100644 --- a/personal_data/util.py +++ b/personal_data/util.py @@ -3,9 +3,8 @@ import csv import datetime import io import logging -import typing import urllib.parse -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Iterable, Mapping from pathlib import Path from typing import Any @@ -20,13 +19,14 @@ def csv_safe_value(v: Any) -> str: if isinstance(v, urllib.parse.ParseResult): return v.geturl() if isinstance(v, datetime.datetime): - assert v.tzinfo is not None, v + if v.tzinfo is None: + raise RuntimeError(v) return str(v) def equals_without_fields( - a: Mapping[str, object], - b: Mapping[str, object], + a: Mapping[str, Any], + b: Mapping[str, Any], fields: Iterable[str] = frozenset(), ) -> bool: a = dict(a) @@ -39,14 +39,13 @@ def equals_without_fields( def deduplicate_by_ignoring_certain_fields( - dicts: list[dict], + dicts: list[frozendict[str, Any]], deduplicate_ignore_columns: Iterable[str], -) -> list[dict]: +) -> list[frozendict[str, Any]]: """Removes duplicates that occur when ignoring certain columns. Output order is stable. """ - to_remove = set() for idx1, first in enumerate(dicts): for idx2, second in enumerate(dicts[idx1 + 1 :], idx1 + 1): @@ -55,19 +54,21 @@ def deduplicate_by_ignoring_certain_fields( del idx2, second del idx1, first - to_remove = sorted(to_remove) - while to_remove: - del dicts[to_remove.pop()] + to_remove_ls = sorted(to_remove) + del to_remove + while to_remove_ls: + del dicts[to_remove_ls.pop()] return dicts def deduplicate_dicts( - dicts: Sequence[dict[str, typing.Any] | frozendict[str, typing.Any]], + dicts: list[frozendict[str, Any]], deduplicate_mode: data.DeduplicateMode, deduplicate_ignore_columns: list[str], -) -> tuple[Sequence[dict[str, typing.Any]], list[str]]: - assert isinstance(deduplicate_ignore_columns, list), deduplicate_ignore_columns +) -> tuple[list[frozendict[str, Any]], list[str]]: + if not isinstance(deduplicate_ignore_columns, list): + raise TypeError(deduplicate_ignore_columns) fieldnames = [] for d in dicts: @@ -78,7 +79,7 @@ def deduplicate_dicts( del d if deduplicate_mode == data.DeduplicateMode.ONLY_LATEST: - while len(dicts) >= 2 and equals_without_fields( + while len(dicts) > 1 and equals_without_fields( dicts[-1], dicts[-2], deduplicate_ignore_columns, @@ -90,13 +91,13 @@ def deduplicate_dicts( deduplicate_ignore_columns, ) elif deduplicate_mode != data.DeduplicateMode.NONE: - dicts = set(dicts) + dicts = list(set(dicts)) dicts = sorted(dicts, key=lambda d: tuple(str(d.get(fn, '')) for fn in fieldnames)) return dicts, fieldnames -def normalize_dict(d: dict[str, typing.Any]) -> frozendict[str, typing.Any]: +def normalize_dict(d: dict[str, Any] | frozendict[str, Any]) -> frozendict[str, Any]: return frozendict( { k: csv_import.csv_str_to_value(str(v)) @@ -108,20 +109,23 @@ def normalize_dict(d: dict[str, typing.Any]) -> frozendict[str, typing.Any]: def extend_csv_file( csv_file: Path, - new_dicts: list[dict[str, typing.Any]], + new_dicts: list[dict[str, Any] | frozendict[str, Any]], deduplicate_mode: data.DeduplicateMode, deduplicate_ignore_columns: list[str], ) -> dict: - assert isinstance(deduplicate_ignore_columns, list), deduplicate_ignore_columns + if not isinstance(deduplicate_ignore_columns, list): + raise TypeError(deduplicate_ignore_columns) try: - dicts = csv_import.load_csv_file(csv_file) - except (FileNotFoundError, _csv.Error) as e: + original_dicts = csv_import.load_csv_file(csv_file) + except (FileNotFoundError, _csv.Error): logger.info('Creating file: %s', csv_file) - dicts = [] + original_dicts = [] - original_num_dicts = len(dicts) - dicts += [normalize_dict(d) for d in new_dicts] + original_num_dicts = len(original_dicts) + dicts = [normalize_dict(d) for d in original_dicts] + [ + normalize_dict(d) for d in new_dicts + ] del new_dicts dicts, fieldnames = deduplicate_dicts( diff --git a/test/test_deduplicate.py b/test/test_deduplicate.py index 10250f9..4ca4506 100644 --- a/test/test_deduplicate.py +++ b/test/test_deduplicate.py @@ -39,7 +39,7 @@ def test_all_fields(): ] -def test_all_fields(): +def test_all_fields_for_duplicated_list(): ls, fields = deduplicate_dicts(LIST + LIST, DeduplicateMode.BY_ALL_COLUMNS, ['t']) assert fields == ['a', 'b', 't'] print(ls) diff --git a/test/test_parsing.py b/test/test_parsing.py index e9b08ad..747f9b7 100644 --- a/test/test_parsing.py +++ b/test/test_parsing.py @@ -54,6 +54,6 @@ PARSE_MAPPINGS = [ ] -@pytest.mark.parametrize('text,parsed', PARSE_MAPPINGS) -def test_csv_str_to_value(text, parsed): +@pytest.mark.parametrize(('text', 'parsed'), PARSE_MAPPINGS) +def test_csv_str_to_value(text: str, parsed: object): assert csv_str_to_value(text) == parsed, text diff --git a/test/test_psnprofiles.py b/test/test_psnprofiles.py index bc90da3..c818cae 100644 --- a/test/test_psnprofiles.py +++ b/test/test_psnprofiles.py @@ -10,6 +10,6 @@ URLS_AND_IDS = [ ] -@pytest.mark.parametrize('id, url', URLS_AND_IDS) -def test_game_psnprofiles_id_from_url(id, url): - assert psnprofiles.game_psnprofiles_id_from_url(url) == id +@pytest.mark.parametrize(('psnprofiles_id', 'url'), URLS_AND_IDS) +def test_game_psnprofiles_id_from_url(psnprofiles_id: int, url: str): + assert psnprofiles.game_psnprofiles_id_from_url(url) == psnprofiles_id diff --git a/test/test_util.py b/test/test_util.py index d083a8a..7d51335 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -2,5 +2,4 @@ import personal_data def test_version(): - assert personal_data._version.__version__ is not None assert personal_data.__version__ is not None