Code quality and typing improvements
This commit is contained in:
parent
5853249b7f
commit
96a2e2bed9
@@ -21,9 +21,9 @@ FORMAT_DATE_HEADER = '%a, %d %b %Y %H:%M:%S GMT'


 def parse_duration(text: str) -> datetime.timedelta:
-    (num, unit) = text.split(' ')
-    num = int(num)
-    unit = DATETIME_UNITS[unit]
+    (num_str, unit_str) = text.split(' ')
+    num = int(num_str)
+    unit = DATETIME_UNITS[unit_str]
     return unit * num

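Note: the point of this rename is that each name now keeps a single type; the old `num = int(num)` rebinds a str name to int, which type checkers such as mypy reject. A minimal runnable sketch of the resulting function, assuming DATETIME_UNITS maps unit names to timedelta values (the real mapping is outside this diff):

import datetime

# Assumed shape of DATETIME_UNITS; the actual table is not part of this diff.
DATETIME_UNITS = {
    'days': datetime.timedelta(days=1),
    'hours': datetime.timedelta(hours=1),
}

def parse_duration(text: str) -> datetime.timedelta:
    (num_str, unit_str) = text.split(' ')
    num = int(num_str)
    unit = DATETIME_UNITS[unit_str]
    return unit * num

assert parse_duration('3 days') == datetime.timedelta(days=3)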
@@ -57,12 +57,16 @@ def parse_time(text: str) -> datetime.datetime:
     if time is None and (m := try_parse(text, '%d %b @ %I:%M%p')):
         time = m.replace(year=NOW.year)

-    assert time is not None, 'Could not parse format'
+    if time is None:
+        msg = 'Unknown format: ' + text
+        raise RuntimeError(msg)

     if time.tzinfo is None:
         time = time.replace(tzinfo=LOCAL_TIMEZONE)

-    assert time.tzinfo is not None, time
+    if time.tzinfo is None:
+        msg = 'Could not parse timezone: ' + text
+        raise RuntimeError(msg)
     return time

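Note: a likely motivation for the assert-to-raise conversions running through this commit (not stated in the commit message itself): assert statements are stripped entirely when Python runs with -O, so validation done via assert silently disappears in optimized runs. The assign-msg-then-raise shape also matches the style ruff's flake8-errmsg rules (EM101) push toward. Sketch of the difference:

def check_with_assert(time: object) -> None:
    assert time is not None, 'Could not parse format'  # removed under python -O

def check_with_raise(time: object) -> None:
    if time is None:                 # always executes, -O or not
        msg = 'Unknown format'
        raise RuntimeError(msg)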
@@ -74,4 +78,5 @@ def parse_date(text: str) -> datetime.date:
         return dt.date()
     if dt := try_parse(text, '%B %d, %Y'):
         return dt.date()
-    assert False, text
+    msg = 'Unknown format: ' + text
+    raise RuntimeError(msg)

@@ -3,9 +3,8 @@ import csv
 import datetime
 import io
 import logging
-import typing
 import urllib.parse
-from collections.abc import Iterable, Mapping, Sequence
+from collections.abc import Iterable, Mapping
 from pathlib import Path
+from typing import Any

@@ -20,13 +19,14 @@ def csv_safe_value(v: Any) -> str:
     if isinstance(v, urllib.parse.ParseResult):
         return v.geturl()
     if isinstance(v, datetime.datetime):
-        assert v.tzinfo is not None, v
+        if v.tzinfo is None:
+            raise RuntimeError(v)
     return str(v)


 def equals_without_fields(
-    a: Mapping[str, object],
-    b: Mapping[str, object],
+    a: Mapping[str, Any],
+    b: Mapping[str, Any],
     fields: Iterable[str] = frozenset(),
 ) -> bool:
     a = dict(a)

@@ -39,14 +39,13 @@ def equals_without_fields(


 def deduplicate_by_ignoring_certain_fields(
-    dicts: list[dict],
+    dicts: list[frozendict[str, Any]],
     deduplicate_ignore_columns: Iterable[str],
-) -> list[dict]:
+) -> list[frozendict[str, Any]]:
     """Removes duplicates that occur when ignoring certain columns.

     Output order is stable.
     """
-
     to_remove = set()
     for idx1, first in enumerate(dicts):
         for idx2, second in enumerate(dicts[idx1 + 1 :], idx1 + 1):

@@ -55,19 +54,21 @@ def deduplicate_by_ignoring_certain_fields(
             del idx2, second
         del idx1, first

-    to_remove = sorted(to_remove)
-    while to_remove:
-        del dicts[to_remove.pop()]
+    to_remove_ls = sorted(to_remove)
+    del to_remove
+    while to_remove_ls:
+        del dicts[to_remove_ls.pop()]

     return dicts


 def deduplicate_dicts(
-    dicts: Sequence[dict[str, typing.Any] | frozendict[str, typing.Any]],
+    dicts: list[frozendict[str, Any]],
     deduplicate_mode: data.DeduplicateMode,
     deduplicate_ignore_columns: list[str],
-) -> tuple[Sequence[dict[str, typing.Any]], list[str]]:
-    assert isinstance(deduplicate_ignore_columns, list), deduplicate_ignore_columns
+) -> tuple[list[frozendict[str, Any]], list[str]]:
+    if not isinstance(deduplicate_ignore_columns, list):
+        raise TypeError(deduplicate_ignore_columns)

     fieldnames = []
     for d in dicts:

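Note: the rename to to_remove_ls (followed by del to_remove) separates the set phase from the list phase instead of retyping one name, which is friendlier to type checkers. The delete-from-the-back pattern it preserves works because popping from a sorted list removes the highest index first, so the remaining lower indices stay valid as the list shrinks:

items = ['a', 'b', 'c', 'd']
to_remove_ls = sorted({1, 3})        # [1, 3]
while to_remove_ls:
    del items[to_remove_ls.pop()]    # deletes index 3, then index 1
assert items == ['a', 'c']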
@@ -78,7 +79,7 @@ def deduplicate_dicts(
     del d

     if deduplicate_mode == data.DeduplicateMode.ONLY_LATEST:
-        while len(dicts) >= 2 and equals_without_fields(
+        while len(dicts) > 1 and equals_without_fields(
             dicts[-1],
             dicts[-2],
             deduplicate_ignore_columns,

@@ -90,13 +91,13 @@ def deduplicate_dicts(
             deduplicate_ignore_columns,
         )
     elif deduplicate_mode != data.DeduplicateMode.NONE:
-        dicts = set(dicts)
+        dicts = list(set(dicts))

     dicts = sorted(dicts, key=lambda d: tuple(str(d.get(fn, '')) for fn in fieldnames))
     return dicts, fieldnames


-def normalize_dict(d: dict[str, typing.Any]) -> frozendict[str, typing.Any]:
+def normalize_dict(d: dict[str, Any] | frozendict[str, Any]) -> frozendict[str, Any]:
     return frozendict(
         {
             k: csv_import.csv_str_to_value(str(v))

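Note: the list(set(dicts)) fix and the frozendict annotations go together. The deduplication branch relies on hashing to collapse duplicates, which plain dicts cannot do, and wrapping the set back in list() keeps the variable consistent with its declared list type for the sorted() call below. A small illustration, assuming the frozendict package these annotations refer to:

from frozendict import frozendict

rows = [frozendict({'a': 1}), frozendict({'a': 1}), frozendict({'a': 2})]
unique = list(set(rows))    # plain dicts would raise TypeError: unhashable type
assert len(unique) == 2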
@@ -108,20 +109,23 @@ def normalize_dict(d: dict[str, typing.Any]) -> frozendict[str, typing.Any]:

 def extend_csv_file(
     csv_file: Path,
-    new_dicts: list[dict[str, typing.Any]],
+    new_dicts: list[dict[str, Any] | frozendict[str, Any]],
     deduplicate_mode: data.DeduplicateMode,
     deduplicate_ignore_columns: list[str],
 ) -> dict:
-    assert isinstance(deduplicate_ignore_columns, list), deduplicate_ignore_columns
+    if not isinstance(deduplicate_ignore_columns, list):
+        raise TypeError(deduplicate_ignore_columns)

     try:
-        dicts = csv_import.load_csv_file(csv_file)
-    except (FileNotFoundError, _csv.Error) as e:
+        original_dicts = csv_import.load_csv_file(csv_file)
+    except (FileNotFoundError, _csv.Error):
         logger.info('Creating file: %s', csv_file)
-        dicts = []
+        original_dicts = []

-    original_num_dicts = len(dicts)
-    dicts += [normalize_dict(d) for d in new_dicts]
+    original_num_dicts = len(original_dicts)
+    dicts = [normalize_dict(d) for d in original_dicts] + [
+        normalize_dict(d) for d in new_dicts
+    ]
+    del new_dicts

     dicts, fieldnames = deduplicate_dicts(

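Note: this rewrite is not purely cosmetic. The old `dicts += [...]` normalized only the incoming rows; the new code rebuilds the list so rows loaded from disk pass through normalize_dict as well, and `del new_dicts` prevents accidental reuse of the raw input. A sketch of the effect, with `normalize` as a hypothetical stand-in for normalize_dict:

def normalize(d: dict) -> dict:   # hypothetical stand-in for normalize_dict
    return {k: str(v).strip() for k, v in d.items()}

original_dicts = [{'a': ' 1 '}]   # as loaded from disk, not yet normalized
new_dicts = [{'a': '2'}]
dicts = [normalize(d) for d in original_dicts] + [normalize(d) for d in new_dicts]
assert dicts == [{'a': '1'}, {'a': '2'}]   # both sides normalized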
@@ -39,7 +39,7 @@ def test_all_fields():
 ]


-def test_all_fields():
+def test_all_fields_for_duplicated_list():
     ls, fields = deduplicate_dicts(LIST + LIST, DeduplicateMode.BY_ALL_COLUMNS, ['t'])
     assert fields == ['a', 'b', 't']
     print(ls)

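Note: the hunk context shows an earlier def test_all_fields(): above line 39, so this file defined the same test name twice. A second def rebinds the module-level name (ruff flags this as F811), meaning pytest collected only the later definition and the first test silently never ran; the rename makes both run. Sketch of the failure mode:

def test_all_fields():      # silently shadowed; pytest never collects this one
    assert True

def test_all_fields():      # rebinds the name; only this definition is collected
    assert True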
@@ -54,6 +54,6 @@ PARSE_MAPPINGS = [
 ]


-@pytest.mark.parametrize('text,parsed', PARSE_MAPPINGS)
-def test_csv_str_to_value(text, parsed):
+@pytest.mark.parametrize(('text', 'parsed'), PARSE_MAPPINGS)
+def test_csv_str_to_value(text: str, parsed: object):
     assert csv_str_to_value(text) == parsed, text

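Note: pytest accepts both a comma-separated string and a tuple of names for parametrize; the tuple form is the one ruff's flake8-pytest-style rule (PT006) prefers, and the added annotations make the parameter types explicit. A self-contained example of the same pattern (the test body here is hypothetical):

import pytest

@pytest.mark.parametrize(('text', 'parsed'), [('1', 1), ('abc', 'abc')])
def test_example(text: str, parsed: object):
    value = int(text) if text.isdigit() else text
    assert value == parsed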
@@ -10,6 +10,6 @@ URLS_AND_IDS = [
 ]


-@pytest.mark.parametrize('id, url', URLS_AND_IDS)
-def test_game_psnprofiles_id_from_url(id, url):
-    assert psnprofiles.game_psnprofiles_id_from_url(url) == id
+@pytest.mark.parametrize(('psnprofiles_id', 'url'), URLS_AND_IDS)
+def test_game_psnprofiles_id_from_url(psnprofiles_id: int, url: str):
+    assert psnprofiles.game_psnprofiles_id_from_url(url) == psnprofiles_id

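Note: besides switching to the tuple-style parametrize, this rename fixes a shadowing problem: id is a Python builtin, and using it as a parameter name (flagged by ruff's flake8-builtins rule A002) makes the bare name refer to the parameter inside the test:

def before(id: int) -> int:   # `id` here shadows the builtin id()
    return id                 # refers to the parameter, not builtins.id

def after(psnprofiles_id: int) -> int:   # no shadowing
    return psnprofiles_id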
@@ -2,5 +2,4 @@ import personal_data


 def test_version():
-    assert personal_data._version.__version__ is not None
     assert personal_data.__version__ is not None