127 lines
3.6 KiB
Python
127 lines
3.6 KiB
Python
import dataclasses
|
|
import logging
|
|
from decimal import Decimal
|
|
|
|
import clients
|
|
import requests_cache
|
|
import secret_loader
|
|
from langchain_core.tools import tool
|
|
|
|
try:
|
|
import pycountry.db
|
|
except ImportError:
|
|
pycountry = None
|
|
|
|
try:
|
|
import fin_defs
|
|
except ImportError:
|
|
fin_defs = None
|
|
|
|
# Module-level logger, registered under this module's import name.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@tool
def search(query: str) -> str:
    """Call to surf the web."""
    # NOTE: the docstring above is the tool description that @tool exposes, so
    # it is kept verbatim. Despite it, this is a canned stub: it performs no
    # network access and only tells San Francisco queries apart from the rest.
    import re

    lowered = query.lower()
    # Require "sf" to be a standalone word; a plain substring test also
    # matched unrelated words such as "transfer".
    if re.search(r'\bsf\b', lowered) or 'san francisco' in lowered:
        return "It's 60 degrees and foggy."
    return "It's 90 degrees and sunny."
|
|
|
|
|
|
def dataclasses_to_json(data):
    """Recursively convert *data* into plain, JSON-friendly Python values.

    Rules, checked in this order:

    * ``pycountry`` Country objects -> their ``alpha_2`` code (only when the
      module-level ``pycountry`` is set, i.e. the optional import succeeded).
    * ``fin_defs.AssetAmount`` -> ``str(amount)``; ``fin_defs.Asset`` ->
      ``asset.raw_short_name()`` (only when ``fin_defs`` is set).
    * lists and tuples -> lists of converted elements.
    * dicts -> dicts of converted values. Entries whose *unconverted* value is
      falsy (``None``, ``0``, ``False``, ``''``, empty containers, ...) are
      dropped. Keys are not converted.
    * ``Decimal`` -> ``float`` (may lose precision).
    * dataclass *instances* -> a dict of their fields, then the dict rule
      above (so falsy fields are dropped as well).
    * anything else, including dataclass *classes*, is returned unchanged.

    Args:
        data: Arbitrary value to convert.

    Returns:
        The converted value.
    """
    if pycountry and isinstance(data, pycountry.db.Country):
        return data.alpha_2
    if fin_defs and isinstance(data, fin_defs.AssetAmount):
        return str(data)
    if fin_defs and isinstance(data, fin_defs.Asset):
        return data.raw_short_name()
    if isinstance(data, list | tuple):
        return [dataclasses_to_json(d) for d in data]
    if isinstance(data, dict):
        # NOTE(review): the truthiness filter also drops 0, False and
        # Decimal('0'); kept because callers may rely on the compact output.
        return {k: dataclasses_to_json(v) for k, v in data.items() if v}
    if isinstance(data, Decimal):
        return float(data)
    # is_dataclass() is also true for dataclass *types*. Reading field values
    # off a class yields class-level defaults or raises AttributeError, so
    # only instances are flattened.
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        return dataclasses_to_json(
            {field.name: getattr(data, field.name) for field in dataclasses.fields(data)},
        )
    return data
|
|
|
|
|
|
# Import-time smoke check: a flat list of scalars must pass through the
# serializer unchanged. (Skipped when Python runs with -O.)
assert [1, 2, 3] == dataclasses_to_json([1, 2, 3])

# Converters applied to a wrapped method's result before wrap_method turns it
# into a string with str(); RETURN_FORMAT picks the one that is used.
# NOTE(review): because of that str() call, 'json' yields Python repr text
# (single quotes, None/True), not strict JSON.
RETURN_FORMATS = dict(
    raw_python=lambda x: x,      # hand the result over untouched
    json=dataclasses_to_json,    # flatten to plain, JSON-friendly values
)

RETURN_FORMAT = 'json'

# Tool result strings longer than this many characters are truncated.
MAX_TOOL_RESULT_LEN: int = 1000
# When true, wrap_method appends the docstring of the method's return
# annotation to each tool result.
APPEND_RESULT_TYPE_DOCS: bool = True
|
|
|
|
def _unwrap_dict_input(input_value):
    """Return the first value of a dict passed where a bare value was expected.

    Presumably the tool-calling layer sometimes sends ``{param_name: value}``;
    only the first value is used and any other keys are ignored.

    Raises:
        ValueError: if the dict is empty.
    """
    logger.warning('Silently converting from dict to plain value!')
    try:
        # dict.values() is a view, not an iterator: it needs iter() before
        # next(), otherwise next() raises TypeError on every call.
        return next(iter(input_value.values()))
    except StopIteration:
        raise ValueError('Tool input dict is empty; expected exactly one value.') from None


def _result_type_docs(return_type):
    """Return the docstring to append for *return_type*, or None.

    Only classes defined outside ``builtins`` qualify. Without this guard a
    missing annotation (the ``''`` default) or a builtin such as ``int`` or
    ``list[...]`` would append that builtin's generic docstring to every
    tool result.
    """
    if not isinstance(return_type, type) or return_type.__module__ == 'builtins':
        return None
    return return_type.__doc__ or None


def wrap_method(class_, method):
    """Wrap a client method as a single-argument, string-returning tool.

    The returned ``wrapper(input_value)`` performs these steps:

    * unwraps a dict argument to its first value (logging a warning);
    * calls ``method`` and materializes iterator results into a list;
    * converts the result with ``RETURN_FORMATS[RETURN_FORMAT]`` and ``str()``;
    * truncates it to ``MAX_TOOL_RESULT_LEN`` characters plus a marker;
    * appends the return type's docstring when ``APPEND_RESULT_TYPE_DOCS`` is
      set and the annotation is a non-builtin class with a docstring.

    Failures are logged together with the call argument and re-raised.

    Args:
        class_: Class whose name prefixes the tool name and log messages.
        method: Callable to wrap; it is called with exactly one positional
            argument.

    Returns:
        The wrapper, named ``'<Class>.<method>'`` and carrying ``method``'s
        docstring and annotations.
    """
    import collections.abc

    logger.info('Wrapping %s.%s', class_.__name__, method.__name__)
    annotations = getattr(method, '__annotations__', {})
    return_type = annotations.get('return', '')
    is_iterator = str(return_type).startswith(
        'collections.abc.Iterator',
    )
    return_docs = _result_type_docs(return_type)

    def wrapper(input_value):
        logger.info(
            'AI called %s.%s(%s)', class_.__name__, method.__name__, repr(input_value),
        )
        try:
            if isinstance(input_value, dict):
                input_value = _unwrap_dict_input(input_value)
            result = method(input_value)
            # The annotation check misses bare or typing.* Iterator hints and
            # unannotated generators; str() of an unconsumed iterator is
            # useless, so any iterator result is materialized.
            if is_iterator or isinstance(result, collections.abc.Iterator):
                result = list(result)
            result_str: str = str(RETURN_FORMATS[RETURN_FORMAT](result))
            del result
        except Exception:
            logger.exception(
                'AI invocation of %s.%s(%s) failed!',
                class_.__name__,
                method.__name__,
                repr(input_value),
            )
            raise
        if len(result_str) > MAX_TOOL_RESULT_LEN:
            result_str = (
                result_str[:MAX_TOOL_RESULT_LEN]
                + ' (remaining tool result elided...)'
            )
        if APPEND_RESULT_TYPE_DOCS and return_docs:
            result_str = result_str + '\n' + return_docs
        return result_str

    wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
    wrapper.__doc__ = method.__doc__
    wrapper.__annotations__ = annotations
    return wrapper
|
|
|
|
|
|
def wrap_all_methods_on_client(obj):
    """Yield a tool wrapper for each public, documented callable on *obj*.

    Names starting with ``_`` are skipped, as are non-callable attributes and
    callables whose ``__doc__`` is None. Every public name is fetched with
    ``getattr``, so properties are evaluated during the scan.

    Args:
        obj: Client instance to inspect.

    Yields:
        ``wrap_method(type of obj, callable)`` for each qualifying attribute,
        in ``dir()`` order.
    """
    owner = obj.__class__
    for attr_name in dir(obj):
        if attr_name.startswith('_'):
            continue
        candidate = getattr(obj, attr_name)
        if callable(candidate) and candidate.__doc__ is not None:
            yield wrap_method(owner, candidate)
|
|
|
|
|
|
def get_tools():
    """Build tool wrappers for every documented public method of every client.

    All clients share one ``requests_cache.CachedSession`` backed by
    ``output/test.sqlite`` (with User-Agent ``Test Test``) and one
    ``secret_loader.SecretLoader``.

    Returns:
        A list of wrapper functions, grouped by client in the order
        ``clients.all_clients`` yields them.
    """
    cached_session = requests_cache.CachedSession('output/test.sqlite')
    cached_session.headers.update({'User-Agent': 'Test Test'})

    secret_store = secret_loader.SecretLoader()

    return [
        wrapped
        for client in clients.all_clients(cached_session, secret_store)
        for wrapped in wrap_all_methods_on_client(client)
    ]
|