import abc
import dataclasses
import datetime
from collections.abc import Iterable, Mapping
from decimal import Decimal

import enforce_typing

from fin_defs import Asset


@enforce_typing.enforce_types
@dataclasses.dataclass
class Depo(abc.ABC):
    """Abstract base for a depo that holds assets.

    Subclasses must report which assets they hold and how much of each.
    """

    # Name of the depo.
    name: str
    # Timestamp attached to this depo (presumably when its holdings were last
    # updated — NOTE(review): confirm meaning / timezone with callers).
    updated_time: datetime.datetime

    @abc.abstractmethod
    def assets(self) -> Iterable[Asset]:
        """Returns the different assets managed by this depo."""

    @abc.abstractmethod
    def get_amount_of_asset(self, asset: Asset) -> Decimal:
        """Returns the amount of owned assets for all nested depos.

        Must return 0 if depo does not contain the given asset.
        """


@enforce_typing.enforce_types
@dataclasses.dataclass
class DepoSingle(Depo):
    """A leaf depo storing its holdings directly as an asset -> amount mapping."""

    # Amount held per asset; an asset missing from the mapping counts as 0.
    _assets: Mapping[Asset, Decimal]

    def assets(self) -> Iterable[Asset]:
        """Returns the assets present in this depo's holdings mapping."""
        # Return the keys view rather than the mapping itself, so callers get
        # exactly the assets and not a handle on the amounts.
        return self._assets.keys()

    def get_amount_of_asset(self, asset: Asset) -> Decimal:
        """Returns the held amount of `asset`, or Decimal(0) if not held."""
        return self._assets.get(asset, Decimal(0))


@enforce_typing.enforce_types
@dataclasses.dataclass
class DepoGroup(Depo):
    """A depo that aggregates other depos (which may themselves be groups)."""

    # Child depos whose holdings this group combines.
    nested: list[Depo]

    def assets(self) -> Iterable[Asset]:
        """Returns the distinct assets held by any nested depo.

        An asset held by several nested depos is listed only once, in order
        of first appearance. Assets are assumed hashable (DepoSingle uses
        them as mapping keys).
        """
        # dict.fromkeys de-duplicates while preserving first-seen order, so an
        # asset shared by several children is not reported (and later summed)
        # more than once.
        distinct: dict[Asset, None] = dict.fromkeys(
            asset
            for nested_depo in self.nested
            for asset in nested_depo.assets()
        )
        return list(distinct)

    def get_amount_of_asset(self, asset: Asset) -> Decimal:
        """Returns the total amount of `asset` across all nested depos.

        Returns Decimal(0) when the group is empty or no nested depo holds
        the asset.
        """
        # start=Decimal(0) keeps the result a Decimal and makes an empty
        # group sum to 0 (the previous `del` of the loop variable raised
        # UnboundLocalError when `nested` was empty).
        return sum(
            (nested_depo.get_amount_of_asset(asset) for nested_depo in self.nested),
            start=Decimal(0),
        )