import collections.abc
import dataclasses
import logging
import typing
from decimal import Decimal

import clients
import requests_cache
import secret_loader
from langchain_core.tools import tool

try:
    import pycountry.db
except ImportError:
    pycountry = None

try:
    import fin_defs
except ImportError:
    fin_defs = None

logger = logging.getLogger(__name__)


@tool
def search(query: str):
    """Call to surf the web."""
    # Canned stub: the docstring above is the tool description handed to the
    # model by @tool, so it is kept verbatim. The answer depends only on
    # whether the query mentions San Francisco (case-insensitive substring).
    lowered = query.lower()
    mentions_sf = any(term in lowered for term in ('sf', 'san francisco'))
    if not mentions_sf:
        return "It's 90 degrees and sunny."
    return "It's 60 degrees and foggy."


def dataclasses_to_json(data):
    if pycountry and isinstance(data, pycountry.db.Country):
        return data.alpha_2
    if fin_defs and isinstance(data, fin_defs.AssetAmount):
        return str(data)
    if fin_defs and isinstance(data, fin_defs.Asset):
        return data.raw_short_name()
    if isinstance(data, list | tuple):
        return [dataclasses_to_json(d) for d in data]
    if isinstance(data, dict):
        return {k: dataclasses_to_json(v) for k, v in data.items() if v}
    if isinstance(data, Decimal):
        return float(data)
    if dataclasses.is_dataclass(data):
        result = {}
        for field in dataclasses.fields(data):
            result[field.name] = getattr(data, field.name)
        return dataclasses_to_json(result)
    return data


# Import-time smoke check: a flat list of ints must pass through unchanged.
assert dataclasses_to_json([1, 2, 3]) == [1, 2, 3]


# Converters applied to a raw tool result before it is turned into a string.
RETURN_FORMATS = dict(
    raw_python=lambda value: value,  # leave the result untouched
    json=dataclasses_to_json,  # reduce dataclasses etc. to builtin values
)

# Key into RETURN_FORMATS selecting the converter used by wrap_method.
RETURN_FORMAT = 'json'

# Stringified tool results longer than this many characters are truncated.
MAX_TOOL_RESULT_LEN = 1000
# When true, the docstring of a method's return type is appended to its result.
APPEND_RESULT_TYPE_DOCS = True

def _returns_iterator(return_type) -> bool:
    """Return True if ``return_type`` annotates an iterator (or generator) result.

    Keeps the string check for ``collections.abc.Iterator[...]`` annotations
    and additionally recognizes ``typing.Iterator[...]``,
    ``typing.Generator[...]`` and the bare ABCs through ``typing.get_origin``.
    """
    if str(return_type).startswith('collections.abc.Iterator'):
        return True
    origin = typing.get_origin(return_type) or return_type
    return isinstance(origin, type) and issubclass(origin, collections.abc.Iterator)


def _return_type_docs(return_type):
    """Return the docstring of a return-type annotation, or None.

    A missing annotation or ``-> None`` (both ``None``) and string annotations
    yield None. Looking up ``__doc__`` on those would return the docstring of
    ``str`` (or, on some Python versions, ``NoneType``), which says nothing
    about the method's result.
    """
    if return_type is None or isinstance(return_type, str):
        return None
    return getattr(return_type, '__doc__', None)


def wrap_method(class_, method):
    """Wrap ``method`` as a single-argument tool function for the AI.

    The returned ``wrapper(input_value)``:

    * unwraps a dict argument to its first value (logging a warning);
    * calls ``method(input_value)``, materializing iterator results to a list;
    * converts the result with ``RETURN_FORMATS[RETURN_FORMAT]`` and ``str()``;
    * truncates the text to ``MAX_TOOL_RESULT_LEN`` characters plus a marker;
    * if ``APPEND_RESULT_TYPE_DOCS`` is set, appends the return type's
      docstring on a new line.

    Exceptions raised while calling or formatting are logged and re-raised.

    Args:
        class_: Class of the client owning ``method``; only its ``__name__``
            is used (for logging and the wrapper's name).
        method: Bound method to wrap; its docstring and annotations are copied
            onto the wrapper.

    Returns:
        The wrapper function, named ``"<ClassName>.<method name>"``.
    """
    logger.info('Wrapping %s.%s', class_.__name__, method.__name__)
    annotations = getattr(method, '__annotations__', {})
    return_type = annotations.get('return')
    is_iterator = _returns_iterator(return_type)
    return_docs = _return_type_docs(return_type)

    def wrapper(input_value):
        logger.info(
            'AI called %s.%s(%s)', class_.__name__, method.__name__, repr(input_value),
        )
        try:
            if isinstance(input_value, dict):
                # NOTE(review): presumably the tool framework sometimes passes
                # {'param': value}; only the first value (insertion order) is used.
                if not input_value:
                    raise ValueError('Cannot convert an empty dict to a plain value')
                logger.warning('Silently converting from dict to plain value!')
                # Dict views are iterable but not iterators, hence iter().
                input_value = next(iter(input_value.values()))
            result = method(input_value)
            if is_iterator:
                # An unconsumed iterator would otherwise be formatted as its repr.
                result = list(result)
            result_str: str = str(RETURN_FORMATS[RETURN_FORMAT](result))
            del result  # release the raw (possibly large) result early
        except Exception:
            logger.exception(
                'AI invocation of %s.%s(%s) failed!',
                class_.__name__,
                method.__name__,
                repr(input_value),
            )
            raise
        if len(result_str) > MAX_TOOL_RESULT_LEN:
            result_str = result_str[:MAX_TOOL_RESULT_LEN] + ' (remaining tool result elided...)'
        if APPEND_RESULT_TYPE_DOCS and return_docs:
            result_str = result_str + '\n' + return_docs
        return result_str

    # NOTE(review): some tool-calling APIs restrict function names to
    # [A-Za-z0-9_-]; confirm the '.' in this name is accepted by the backend.
    wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
    wrapper.__doc__ = method.__doc__
    # Exposes the method's own annotations (not the wrapper's ``input_value``),
    # presumably so the tool framework builds its schema from them.
    wrapper.__annotations__ = annotations
    return wrapper


def wrap_all_methods_on_client(obj):
    """Lazily yield a tool wrapper for each public, documented callable on ``obj``.

    Attributes whose names start with an underscore, attributes that are not
    callable, and callables without a docstring are skipped. Each remaining
    attribute is passed to ``wrap_method`` together with ``obj``'s class.
    """
    public_names = (name for name in dir(obj) if not name.startswith('_'))
    for name in public_names:
        candidate = getattr(obj, name)
        if callable(candidate) and candidate.__doc__ is not None:
            yield wrap_method(obj.__class__, candidate)


def get_tools(*, cache_path='output/test.sqlite', user_agent='Test Test'):
    """Build tool functions for every documented public method of every client.

    One ``requests_cache.CachedSession`` and one ``secret_loader.SecretLoader``
    are created and handed to ``clients.all_clients``. Each client it returns
    is expanded into tools via ``wrap_all_methods_on_client``.

    Args:
        cache_path: Cache location passed to ``requests_cache.CachedSession``.
            Defaults to the previously hard-coded path.
        user_agent: Value for the session's ``User-Agent`` header.
            NOTE(review): the default looks like a placeholder; consider a
            real contact string for APIs that inspect it.

    Returns:
        A list of wrapped tool functions.
    """
    session = requests_cache.CachedSession(cache_path)
    session.headers['User-Agent'] = user_agent

    secrets = secret_loader.SecretLoader()

    tools = []
    for client in clients.all_clients(session, secrets):
        tools.extend(wrap_all_methods_on_client(client))
    return tools