import _csv
import csv
import datetime
import io
import logging
import urllib.parse
from collections.abc import Iterable, Mapping
from pathlib import Path
from typing import Any

from frozendict import frozendict

from . import csv_import, data

logger = logging.getLogger(__name__)


def csv_safe_value(v: Any) -> str:
    if isinstance(v, urllib.parse.ParseResult):
        return v.geturl()
    if isinstance(v, datetime.datetime):
        # Naive or non-UTC datetimes would serialize ambiguously.
        if v.tzinfo is None or v.tzinfo != datetime.UTC:
            msg = f'Timezone must be UTC: {v}'
            raise ValueError(msg)
    return str(v)


def equals_without_fields(
    a: Mapping[str, Any],
    b: Mapping[str, Any],
    fields: Iterable[str] = frozenset(),
) -> bool:
    a = dict(a)
    b = dict(b)
    for f in fields:
        # pop() rather than del: an ignored field may be absent from a row.
        a.pop(f, None)
        b.pop(f, None)
    return a == b


def deduplicate_by_ignoring_certain_fields(
    dicts: list[frozendict[str, Any]],
    deduplicate_ignore_columns: Iterable[str],
) -> list[frozendict[str, Any]]:
    """Removes duplicates that occur when ignoring certain columns.

    Output order is stable: the first occurrence of each row is kept.
    Runs in O(n^2) comparisons.
    """
    to_remove = set()
    for idx1, first in enumerate(dicts):
        for idx2, second in enumerate(dicts[idx1 + 1 :], idx1 + 1):
            if equals_without_fields(first, second, deduplicate_ignore_columns):
                to_remove.add(idx2)
            del idx2, second
        del idx1, first

    to_remove_ls = sorted(to_remove)
    del to_remove
    # Delete from the highest index down so earlier indices stay valid.
    while to_remove_ls:
        del dicts[to_remove_ls.pop()]
    return dicts


def deduplicate_dicts(
    dicts: list[frozendict[str, Any]],
    deduplicate_mode: data.DeduplicateMode,
    deduplicate_ignore_columns: list[str],
) -> tuple[list[frozendict[str, Any]], list[str]]:
    if not isinstance(deduplicate_ignore_columns, list):
        raise TypeError(deduplicate_ignore_columns)

    # Union of all keys across rows, in first-seen order.
    fieldnames = []
    for d in dicts:
        for k in d:
            if k not in fieldnames:
                fieldnames.append(k)
            del k
        del d

    if deduplicate_mode == data.DeduplicateMode.ONLY_LATEST:
        # Only trailing duplicates are collapsed; earlier rows are untouched.
        while len(dicts) > 1 and equals_without_fields(
            dicts[-1],
            dicts[-2],
            deduplicate_ignore_columns,
        ):
            del dicts[-1]
    elif deduplicate_mode == data.DeduplicateMode.BY_ALL_COLUMNS:
        dicts = deduplicate_by_ignoring_certain_fields(
            dicts,
            deduplicate_ignore_columns,
        )
    elif deduplicate_mode != data.DeduplicateMode.NONE:
        # Any other mode: drop exact duplicates.
        dicts = list(set(dicts))

    dicts = sorted(dicts, key=lambda d: tuple(str(d.get(fn, '')) for fn in fieldnames))

    return dicts, fieldnames


def normalize_dict(d: dict[str, Any] | frozendict[str, Any]) -> frozendict[str, Any]:
    # Round-trip every value through the CSV parser so that freshly produced
    # rows compare equal to rows re-read from disk; values that parse to None
    # are dropped.
    return frozendict(
        {
            k: parsed
            for k, v in d.items()
            if (parsed := csv_import.csv_str_to_value(str(v))) is not None
        },
    )
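

# Illustration (a minimal sketch, not part of the module's API; the column
# names and values below are hypothetical): with BY_ALL_COLUMNS and an
# ignored 'fetched_at' column, rows differing only in that column collapse
# to the first occurrence.
#
#     rows = [
#         frozendict({'id': '1', 'fetched_at': '2024-01-01'}),
#         frozendict({'id': '1', 'fetched_at': '2024-01-02'}),
#     ]
#     deduped, fieldnames = deduplicate_dicts(
#         rows, data.DeduplicateMode.BY_ALL_COLUMNS, ['fetched_at'],
#     )
#     # deduped == [frozendict({'id': '1', 'fetched_at': '2024-01-01'})]
#     # fieldnames == ['id', 'fetched_at']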


def extend_csv_file(
    csv_file: Path,
    new_dicts: list[dict[str, Any] | frozendict[str, Any]],
    deduplicate_mode: data.DeduplicateMode,
    deduplicate_ignore_columns: list[str],
) -> dict:
    if not isinstance(deduplicate_ignore_columns, list):
        raise TypeError(deduplicate_ignore_columns)

    try:
        original_dicts = csv_import.load_csv_file(csv_file)
    except (FileNotFoundError, _csv.Error):
        logger.info('Creating file: %s', csv_file)
        original_dicts = []

    original_num_dicts = len(original_dicts)

    dicts = [normalize_dict(d) for d in original_dicts] + [
        normalize_dict(d) for d in new_dicts
    ]
    del new_dicts

    dicts, fieldnames = deduplicate_dicts(
        dicts,
        deduplicate_mode,
        deduplicate_ignore_columns,
    )

    # Render the whole CSV in memory first, so the file on disk is never
    # left half-written if serialization fails.
    csvfile_in_memory = io.StringIO()
    writer = csv.DictWriter(
        csvfile_in_memory,
        fieldnames=fieldnames,
        dialect=csv_import.CSV_DIALECT,
    )
    writer.writeheader()
    for d in dicts:
        writable_d = {k: csv_safe_value(v) for k, v in d.items()}
        writer.writerow(writable_d)
        del d, writable_d
    output_csv = csvfile_in_memory.getvalue()
    del writer, csvfile_in_memory

    csv_file.parent.mkdir(parents=True, exist_ok=True)
    # newline='' prevents the platform from rewriting the line terminators
    # already produced by the csv module.
    with open(csv_file, 'w', newline='') as csvfile:
        csvfile.write(output_csv)
    del csvfile

    logger.info(
        'Extended CSV "%s" from %d to %d lines',
        csv_file,
        original_num_dicts,
        len(dicts),
    )
    return {
        'extended': original_num_dicts != len(dicts),
        'input_lines': original_num_dicts,
        'output_lines': len(dicts),
        'dicts': dicts,
    }
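

# Usage sketch (hypothetical file name and rows; assumes the surrounding
# package provides data.DeduplicateMode as used above):
#
#     result = extend_csv_file(
#         Path('output/prices.csv'),
#         [{'time': '2024-05-01 00:00:00+00:00', 'price': '9.95'}],
#         deduplicate_mode=data.DeduplicateMode.BY_ALL_COLUMNS,
#         deduplicate_ignore_columns=[],
#     )
#     # result['extended'] is True when the row count changed;
#     # result['output_lines'] counts data rows, excluding the header.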