"""LLM tool definitions built from client objects.

``get_tools`` builds a cached HTTP session and a secret loader, asks
``clients.all_clients`` for client objects, and wraps every public method
that has a docstring (see ``wrap_method``) so it can be called as a tool.
"""

import collections.abc
import dataclasses
import logging
from decimal import Decimal
import typing

import clients
import requests_cache
import secret_loader
from langchain_core.tools import tool

# Optional dependencies: when missing, their types are simply not given
# special treatment by ``dataclasses_to_json``.
try:
    import pycountry.db
except ImportError:
    pycountry = None

try:
    import fin_defs
except ImportError:
    fin_defs = None

logger = logging.getLogger(__name__)


# Placeholder tool that returns canned weather text; the only thing it looks
# at is whether the query mentions "sf" / "san francisco" (substring match).
# The docstring is kept verbatim because ``@tool`` uses it as the tool
# description.
@tool
def search(query: str):
    """Call to surf the web."""
    if 'sf' in query.lower() or 'san francisco' in query.lower():
        return "It's 60 degrees and foggy."
    return "It's 90 degrees and sunny."


def dataclasses_to_json(data):
    """Recursively convert ``data`` into plain, JSON-friendly Python values.

    - ``pycountry`` countries become their ``alpha_2`` code.
    - ``fin_defs.AssetAmount`` becomes ``str(amount)``.
    - ``fin_defs.Asset`` becomes ``raw_short_name()``.
    - Lists/tuples become lists.
    - Dicts are converted value-by-value.
    - ``Decimal`` becomes ``float``.
    - Dataclass instances become dicts of their fields.
    - Anything else is returned unchanged.

    Dict entries (including dataclass fields) whose value is falsy are dropped.
    """
    if pycountry and isinstance(data, pycountry.db.Country):
        return data.alpha_2
    # AssetAmount is checked before Asset, matching the original precedence.
    if fin_defs and isinstance(data, fin_defs.AssetAmount):
        return str(data)
    if fin_defs and isinstance(data, fin_defs.Asset):
        return data.raw_short_name()
    if isinstance(data, list | tuple):
        return [dataclasses_to_json(d) for d in data]
    if isinstance(data, dict):
        # Drops every falsy value (None, '', [], and also 0 / False).
        # NOTE(review): presumably done to keep tool output short; confirm
        # that losing legitimate 0/False values is intended.
        return {k: dataclasses_to_json(v) for k, v in data.items() if v}
    if isinstance(data, Decimal):
        return float(data)
    # is_dataclass() is also True for dataclass *classes*; only instances
    # carry field values that getattr can read.
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        fields = {f.name: getattr(data, f.name) for f in dataclasses.fields(data)}
        return dataclasses_to_json(fields)
    return data


assert dataclasses_to_json([1, 2, 3]) == [1, 2, 3]

# Formatters applied to a tool's raw result before it is stringified.
RETURN_FORMATS = {
    'raw_python': lambda x: x,
    'json': dataclasses_to_json,
}
RETURN_FORMAT = 'json'
# Tool results longer than this many characters are truncated.
MAX_TOOL_RESULT_LEN = 1000
# When True, the docstring of the method's return type is appended to results.
APPEND_RESULT_TYPE_DOCS = True


def _returns_iterator(return_type):
    """Return True if a return annotation denotes an iterator (incl. generators)."""
    # Kept from the original for string annotations such as
    # 'collections.abc.Iterator[Foo]'.
    if str(return_type).startswith('collections.abc.Iterator'):
        return True
    # Handles collections.abc.Iterator[...], typing.Iterator[...],
    # Generator[...] and the bare classes.
    origin = typing.get_origin(return_type) or return_type
    return isinstance(origin, type) and issubclass(origin, collections.abc.Iterator)


def _return_type_docs(return_type):
    """Return a useful docstring for a return annotation, or None.

    Only non-builtin classes are considered: the annotation itself, or one of
    its type arguments (e.g. ``Foo`` in ``Iterator[Foo]``). Missing, string,
    and builtin annotations yield None, so docstrings like ``str.__doc__``
    are never appended to tool output.
    """
    for candidate in (return_type, *typing.get_args(return_type)):
        if (
            typing.get_origin(candidate) is None
            and isinstance(candidate, type)
            and candidate.__module__ != 'builtins'
            and candidate.__doc__
        ):
            return candidate.__doc__
    return None


def wrap_method(class_, method):
    """Wrap a client method as a single-argument tool callable returning a string.

    The wrapper does the following:
    - logs each call;
    - unwraps a dict argument to its first value;
    - materializes iterator results into a list;
    - formats the result with ``RETURN_FORMATS[RETURN_FORMAT]`` and
      stringifies it;
    - truncates it to ``MAX_TOOL_RESULT_LEN`` characters;
    - optionally appends the return type's docstring.

    Failures are logged and re-raised.

    Args:
        class_: Class of the client object; used only for naming and logging.
        method: The (bound) method to wrap; called with one positional argument.

    Returns:
        The wrapper function, named ``"<Class>.<method>"``, carrying the
        method's docstring and a copy of its annotations.
    """
    logger.info('Wrapping %s.%s', class_.__name__, method.__name__)
    return_type = method.__annotations__.get('return')
    is_iterator = _returns_iterator(return_type)
    return_docs = _return_type_docs(return_type)

    def wrapper(input_value):
        logger.info(
            'AI called %s.%s(%s)',
            class_.__name__,
            method.__name__,
            repr(input_value),
        )
        try:
            if isinstance(input_value, dict):
                logger.warning('Silently converting from dict to plain value!')
                if not input_value:
                    raise ValueError('Cannot convert an empty dict argument to a plain value')
                # dict_values is iterable but not an iterator, so iter() is required.
                input_value = next(iter(input_value.values()))
            result = method(input_value)
            if is_iterator:
                result = list(result)
            result_str: str = str(RETURN_FORMATS[RETURN_FORMAT](result))
            # Release the (possibly large) raw result before post-processing.
            del result
        except Exception:
            logger.exception(
                'AI invocation of %s.%s(%s) failed!',
                class_.__name__,
                method.__name__,
                repr(input_value),
            )
            raise

        if len(result_str) > MAX_TOOL_RESULT_LEN:
            result_str = result_str[:MAX_TOOL_RESULT_LEN] + ' (remaining tool result elided...)'
        if APPEND_RESULT_TYPE_DOCS and return_docs:
            result_str = result_str + '\n' + return_docs
        return result_str

    wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
    wrapper.__doc__ = method.__doc__
    # Copy so later edits to the wrapper's annotations cannot leak into the method.
    wrapper.__annotations__ = dict(method.__annotations__)
    return wrapper


def wrap_all_methods_on_client(obj):
    """Yield a ``wrap_method`` wrapper for each public, documented callable on ``obj``.

    Attributes are skipped if their name starts with ``_``, if they are not
    callable, if they lack ``__name__``, or if they have no docstring.
    """
    for field_name in dir(obj):
        if field_name.startswith('_'):
            continue
        method = getattr(obj, field_name)
        if not callable(method):
            continue
        if not hasattr(method, '__name__'):
            continue
        if method.__doc__ is None:
            continue
        yield wrap_method(obj.__class__, method)


def get_tools():
    """Build and return the list of wrapped tool callables for all clients.

    HTTP responses are cached by ``requests_cache`` in ``output/test.sqlite``.
    """
    session = requests_cache.CachedSession('output/test.sqlite')
    session.headers['User-Agent'] = 'Test Test'
    secrets = secret_loader.SecretLoader()

    tools = []
    for client in clients.all_clients(session, secrets):
        tools.extend(wrap_all_methods_on_client(client))
    return tools