import csv
import datetime
import decimal
import io
import logging
import typing
import urllib.parse
from collections.abc import Callable, Iterable, Mapping, Sequence
from decimal import Decimal
from pathlib import Path

from frozendict import frozendict

from . import data

logger = logging.getLogger(__name__)

# Single CSV dialect used for every file this module reads and writes.
CSV_DIALECT = 'one_true_dialect'
csv.register_dialect(CSV_DIALECT, lineterminator='\n', skipinitialspace=True)

T = typing.TypeVar('T')


def try_value(fn: Callable[[str], T], s: str) -> T | None:
    """Apply the parser ``fn`` to ``s``; return ``None`` if it rejects the input.

    Only the parse-failure exceptions ``ValueError`` and
    ``decimal.InvalidOperation`` are swallowed; anything else propagates.
    """
    try:
        return fn(s)
    except (ValueError, decimal.InvalidOperation):
        return None


def csv_str_to_value(
    s: str,
) -> (
    str
    | Decimal
    | datetime.date
    | datetime.datetime
    | urllib.parse.ParseResult
    | bool
    | None
):
    """Parse a raw CSV cell into the most specific value it represents.

    After stripping surrounding whitespace, the following are tried in order:
    empty -> ``None``; finite decimal number -> ``Decimal``; ISO date ->
    ``datetime.date``; ISO datetime -> ``datetime.datetime``; text starting
    with ``http://``/``https://`` -> ``urllib.parse.ParseResult``;
    ``true``/``false`` (any case) -> ``bool``; ``none`` (any case) -> ``None``.
    Anything else is returned as the stripped string.
    """
    s = s.strip()
    if len(s) == 0:
        return None
    # Decimal() also accepts 'NaN', 'Infinity', 'inf', 'sNaN' (any case) and
    # PEP 515 digit separators such as '1_000'. Treating those as numbers
    # would rewrite plain text like 'Nan' or '2023_01' when the file is saved,
    # and NaN never compares equal (sNaN even raises on compare/hash), which
    # breaks deduplication. Only plain, finite numbers become Decimals.
    if '_' not in s:
        v_decimal = try_value(Decimal, s)
        if v_decimal is not None and v_decimal.is_finite():
            return v_decimal
    if (v_date := try_value(datetime.date.fromisoformat, s)) is not None:
        return v_date
    if (v_datetime := try_value(datetime.datetime.fromisoformat, s)) is not None:
        return v_datetime
    if s.startswith(('http://', 'https://')):
        return urllib.parse.urlparse(s)
    lowered = s.lower()
    if lowered == 'false':
        return False
    if lowered == 'true':
        return True
    if lowered == 'none':
        return None
    return s


def csv_safe_value(v: object) -> str:
    """Render ``v`` as the text stored in a CSV cell.

    URLs are written back as their URL text. Datetimes must be
    timezone-aware; a naive datetime trips an assertion. Everything else is
    written as ``str(v)``.
    """
    if isinstance(v, urllib.parse.ParseResult):
        return v.geturl()
    if isinstance(v, datetime.datetime):
        assert v.tzinfo is not None, v
    return str(v)


def equals_without_fields(
    a: Mapping[str, object],
    b: Mapping[str, object],
    fields: Iterable[str] = frozenset(),
) -> bool:
    """Return whether ``a`` and ``b`` are equal once the keys in ``fields`` are disregarded.

    A key listed in ``fields`` may be absent from either mapping: rows loaded
    by ``load_csv_file`` omit empty cells, so an ignored column is often
    missing from one side.
    """
    ignored = frozenset(fields)
    a_kept = {k: v for k, v in a.items() if k not in ignored}
    b_kept = {k: v for k, v in b.items() if k not in ignored}
    return a_kept == b_kept


def deduplicate_by_ignoring_certain_fields(
    dicts: list[dict],
    deduplicate_ignore_columns: Iterable[str],
) -> list[dict]:
    """Removes duplicates that occur when ignoring certain columns.

    Output order is stable: the first occurrence of each row is kept.
    ``dicts`` is modified in place and also returned.
    """
    # Materialize once: a one-shot iterator would otherwise be exhausted by
    # the very first comparison.
    ignore_columns = tuple(deduplicate_ignore_columns)
    duplicate_indices: set[int] = set()
    for idx1, first in enumerate(dicts):
        if idx1 in duplicate_indices:
            # Anything equal to `first` also equals the earlier row `first`
            # duplicates, and was therefore already flagged by that row.
            continue
        for idx2 in range(idx1 + 1, len(dicts)):
            if idx2 in duplicate_indices:
                continue
            if equals_without_fields(first, dicts[idx2], ignore_columns):
                duplicate_indices.add(idx2)

    dicts[:] = [d for idx, d in enumerate(dicts) if idx not in duplicate_indices]
    return dicts


def deduplicate_dicts(
    dicts: Sequence[dict[str, typing.Any] | frozendict[str, typing.Any]],
    deduplicate_mode: data.DeduplicateMode,
    deduplicate_ignore_columns: list[str],
) -> tuple[Sequence[dict[str, typing.Any]], list[str]]:
    """Deduplicate ``dicts`` according to ``deduplicate_mode`` and sort the result.

    Returns ``(rows, fieldnames)``. ``fieldnames`` is the union of the keys of
    all input rows, in first-seen order. The rows are sorted by the string
    form of their values, column by column in ``fieldnames`` order. The input
    sequence itself is not modified.
    """
    assert isinstance(deduplicate_ignore_columns, list), deduplicate_ignore_columns

    # Private copy: the trimming below must not mutate the caller's sequence,
    # and immutable sequences (e.g. tuples) must work too.
    dicts = list(dicts)

    fieldnames = list(dict.fromkeys(k for d in dicts for k in d))

    if deduplicate_mode == data.DeduplicateMode.ONLY_LATEST:
        # Drop trailing rows that merely repeat the row before them.
        while len(dicts) >= 2 and equals_without_fields(
            dicts[-1],
            dicts[-2],
            deduplicate_ignore_columns,
        ):
            del dicts[-1]
    elif deduplicate_mode == data.DeduplicateMode.BY_ALL_COLUMNS:
        dicts = deduplicate_by_ignoring_certain_fields(
            dicts,
            deduplicate_ignore_columns,
        )
    elif deduplicate_mode != data.DeduplicateMode.NONE:
        # Exact-duplicate removal; frozendict makes plain dicts hashable.
        # NOTE(review): this branch does not consult deduplicate_ignore_columns.
        dicts = {frozendict(d) for d in dicts}

    dicts = sorted(dicts, key=lambda d: tuple(str(d.get(fn, '')) for fn in fieldnames))
    return dicts, fieldnames


def normalize_dict(d: dict[str, typing.Any]) -> frozendict[str, typing.Any]:
    """Convert every value of ``d`` to its text form and re-parse it with ``csv_str_to_value``.

    Freshly produced rows then hold the same value types as rows parsed from
    a CSV file, so the two compare equal during deduplication. Entries whose
    text parses to ``None`` (empty or ``none``) are dropped.
    """
    normalized = {}
    for key, value in d.items():
        # str() of a ParseResult is its repr ("ParseResult(scheme=...)"),
        # which would be stored as a meaningless string; use the URL text.
        if isinstance(value, urllib.parse.ParseResult):
            text = value.geturl()
        else:
            text = str(value)
        parsed = csv_str_to_value(text)
        if parsed is not None:
            normalized[key] = parsed
    return frozendict(normalized)


def load_csv_file(csv_file: Path) -> list[frozendict]:
    """Read ``csv_file`` into one frozendict per data row.

    Each cell is parsed with ``csv_str_to_value``. Cells that parse to
    ``None``, and cells missing from rows shorter than the header, are left
    out of the row's mapping.

    Raises:
        FileNotFoundError: if ``csv_file`` does not exist.
        ValueError: if a row has more cells than the header has columns.
    """
    dicts: list[frozendict] = []
    # The csv module requires newline='' so that line breaks inside quoted
    # fields are read correctly.
    with open(csv_file, newline='') as csvfile:
        reader = csv.DictReader(csvfile, dialect=CSV_DIALECT)
        for row in reader:
            # DictReader stores surplus cells as a list under the key None.
            if None in row:
                msg = f'{csv_file}: line {reader.line_num} has more fields than the header'
                raise ValueError(msg)
            parsed_row = {}
            for key, cell in row.items():
                # DictReader pads short rows with None; treat those cells as empty.
                if cell is None:
                    continue
                value = csv_str_to_value(cell)
                if value is not None:
                    parsed_row[key] = value
            dicts.append(frozendict(parsed_row))
    return dicts


def extend_csv_file(
    csv_file: Path,
    new_dicts: list[dict],
    deduplicate_mode: data.DeduplicateMode,
    deduplicate_ignore_columns: list[str],
) -> dict:
    """Append ``new_dicts`` to ``csv_file``, deduplicate, and rewrite the file.

    The file and its parent directories are created if missing. Returns a
    summary dict with keys:
      - ``'extended'``: whether the number of rows changed,
      - ``'input_lines'``: rows in the file before extending,
      - ``'output_lines'``: rows written,
      - ``'dicts'``: the rows that were written.
    """
    assert isinstance(deduplicate_ignore_columns, list), deduplicate_ignore_columns

    try:
        dicts = load_csv_file(csv_file)
    except FileNotFoundError:
        logger.info('Creating file: %s', csv_file)
        dicts = []

    original_num_dicts = len(dicts)
    dicts += [normalize_dict(d) for d in new_dicts]

    dicts, fieldnames = deduplicate_dicts(
        dicts,
        deduplicate_mode,
        deduplicate_ignore_columns,
    )

    # Render everything in memory first. Any error raised while rendering
    # (e.g. the naive-datetime assertion in csv_safe_value) therefore happens
    # before the target file is opened for writing.
    csvfile_in_memory = io.StringIO()
    writer = csv.DictWriter(
        csvfile_in_memory,
        fieldnames=fieldnames,
        dialect=CSV_DIALECT,
    )
    writer.writeheader()
    writer.writerows({k: csv_safe_value(v) for k, v in d.items()} for d in dicts)
    output_csv = csvfile_in_memory.getvalue()

    csv_file.parent.mkdir(parents=True, exist_ok=True)
    # newline='' writes the dialect's '\n' terminators untranslated.
    with open(csv_file, 'w', newline='') as csvfile:
        csvfile.write(output_csv)

    logger.info(
        'Extended CSV "%s" from %d to %d lines',
        csv_file,
        original_num_dicts,
        len(dicts),
    )

    return {
        'extended': original_num_dicts != len(dicts),
        'input_lines': original_num_dicts,
        'output_lines': len(dicts),
        'dicts': dicts,
    }