diff --git a/datagraph/schemeld.py b/datagraph/schemeld.py index daa9474..c51a6cb 100644 --- a/datagraph/schemeld.py +++ b/datagraph/schemeld.py @@ -1,14 +1,19 @@ import urllib.parse +from collections.abc import Iterator +from typing import Any STRICT_VALIDATION = True +Key = int | str | urllib.parse.ParseResult +Context = str # TODO -def canonical_keys(base_key, context): + +def canonical_keys(base_key: Key, context: Context | None) -> list[Any]: if isinstance(base_key, urllib.parse.ParseResult): return [base_key] if not isinstance(base_key, str): return [base_key] - elif base_key.startswith('@'): + if base_key.startswith('@'): return [base_key] if context is None: return [base_key] @@ -16,7 +21,7 @@ def canonical_keys(base_key, context): class Concept: - def __init__(self, context, pairs): + def __init__(self, context: Context | None, pairs: dict[Key, str]) -> None: self.pairs = [] for k, v in pairs.items(): keys = canonical_keys(k, context) @@ -25,10 +30,10 @@ class Concept: ) self.regenerate_by_keys() - def regenerate_by_keys(self): + def regenerate_by_keys(self) -> None: self.by_keys = {k: pair for pair in self.pairs for k in pair['keys']} - def __copy__(self): + def __copy__(self) -> 'Concept': new = Concept(None, {}) for p in self.pairs: new.pairs.append( @@ -41,41 +46,49 @@ class Concept: new.regenerate_by_keys() return new - def get(self, key, default=None): + def get(self, key: Key, default=None): pairs = self.by_keys.get(key, None) return pairs['values'] if pairs is not None else default - def getlist(self, key): + def getlist(self, key: Key) -> list[Any]: result = self.get(key) if result is None: return [] - assert isinstance(result, list), 'Not a list: ' + str(result) + if not isinstance(result, list): + msg = f'Not a list: {result}' + raise TypeError(msg) return [r['value'] for r in result] - def keys(self): + def keys(self) -> Iterator[Key]: for pair in self.pairs: yield pair['canonical_key'] - def setdefault(self, key, value): + def setdefault(self, key: Key, value): if key not in self.by_keys: self[key] = value return self.by_keys[key]['values'] - def to_dict(self): + def to_dict(self) -> dict[Key, Any]: return {p['canonical_key']: p['values'] for p in self.pairs} - def __getitem__(self, key): + def __getitem__(self, key: Key): return self.by_keys[key]['values'] - def __setitem__(self, key, value): + def __setitem__(self, key: Key, value) -> None: if STRICT_VALIDATION: if not isinstance(key, str) or key != '@id': - assert isinstance(value, list), value + if not isinstance(value, list): + raise TypeError(value) for v in value: - assert isinstance(v, dict), value - assert 'value' in v, value + if not isinstance(value, dict): + raise TypeError(value) + if 'value' not in v: + raise TypeError(value) for subk in v: - assert not isinstance(v[subk], list), value + if isinstance(v[subk], list): + raise TypeError(value) + del subk + del v if key in self.by_keys: self.by_keys[key]['values'] = value @@ -84,23 +97,23 @@ class Concept: self.pairs.append(pair) self.by_keys[key] = pair - def __contains__(self, key): + def __contains__(self, key: Key) -> bool: return key in self.by_keys - def __delitem__(self, key): + def __delitem__(self, key: Key) -> None: self.pairs.remove(self.by_keys[key]) del self.by_keys[key] - def __repr__(self): - if id := self.by_keys.get('@id'): - return 'Concept {{ @id = {} }}'.format(id['values']) + def __repr__(self) -> str: + if object_id := self.by_keys.get('@id'): + return 'Concept {{ @id = {} }}'.format(object_id['values']) return 'Concept ' + str({p['canonical_key']: p['values'] for p in self.pairs}) - def __str__(self): + def __str__(self) -> str: return repr(self) - def set_canonical_key(self, new_canonical_key, key=None): + def set_canonical_key(self, new_canonical_key: Key, key: Key | None = None): if key is None: key = new_canonical_key self.by_keys[key]['canonical_key'] = new_canonical_key diff --git a/datagraph/wikidata_ext.py b/datagraph/wikidata_ext.py index 213a11c..865afe4 100644 --- a/datagraph/wikidata_ext.py +++ b/datagraph/wikidata_ext.py @@ -20,16 +20,20 @@ def concept_uri(obj: wikidata.entity.Entity) -> urllib.parse.ParseResult: raise ValueError(msg) -def fmt_triple_value(obj: Any, prefer_obj=False) -> str: +def format_value_for_triple_request(obj: Any, prefer_obj=False) -> str: if obj is None: return '' if isinstance(obj, str): return f'"{obj}"' if isinstance(obj, urllib.parse.ParseResult): - return obj.geturl() if prefer_obj else fmt_triple_value(obj.geturl()) + return ( + obj.geturl() + if prefer_obj + else format_value_for_triple_request(obj.geturl()) + ) if isinstance(obj, wikidata.entity.Entity): uri = concept_uri(obj) - return fmt_triple_value(uri, True) + return format_value_for_triple_request(uri, True) msg = f'Type cannot be formatted: {type(obj)}' raise TypeError(msg) @@ -42,7 +46,7 @@ def fetch_by_url(url: str, headers: dict[str, str]): msg = 'REQUEST_SESSION must be set, before calling fetch_by_url' raise RuntimeError(msg) response = REQUEST_SESSION.get(url, headers=headers) - if response.status_code != 200: + if not response.status_code.ok: logging.error('Got %s error message: %s', response.status_code, response.text) return None return response @@ -58,9 +62,9 @@ def fmt_params(subject: Any, predicate: Any, object_: Any) -> dict[str, str | in msg = 'There are no entities for this query!' raise RuntimeError(msg) return { - 'subject': fmt_triple_value(subject, prefer_obj=True), - 'predicate': fmt_triple_value(predicate, prefer_obj=True), - 'object': fmt_triple_value(object_, prefer_obj=True), + 'subject': format_value_for_triple_request(subject, prefer_obj=True), + 'predicate': format_value_for_triple_request(predicate, prefer_obj=True), + 'object': format_value_for_triple_request(object_, prefer_obj=True), 'page': 1, }