56 lines
1.6 KiB
Python
56 lines
1.6 KiB
Python
from langchain_core.tools import tool
|
|
|
|
import requests_cache
|
|
import clients
|
|
from typing import get_type_hints, List, Dict, Any
|
|
from collections.abc import Iterator
|
|
import logging
|
|
import secret_loader
|
|
|
|
# Module-level logger; wrap() uses it to log each method it wraps (debug)
# and each time the AI invokes a wrapped method (info).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@tool
def search(query: str) -> str:
    """Call to surf the web."""
    # NOTE: the docstring above is what the @tool decorator exposes to the
    # model as this tool's description, so it is kept verbatim. The body is a
    # canned stub: it performs no network access and only distinguishes
    # San Francisco queries from everything else.
    import re  # local import keeps the stub self-contained

    normalized = query.lower()
    # Match "sf" only as a standalone word. A plain substring test also fired
    # on unrelated words such as "transfer" or "msft".
    if re.search(r"\bsf\b", normalized) or "san francisco" in normalized:
        return "It's 60 degrees and foggy."
    return "It's 90 degrees and sunny."
|
|
|
|
def wrap(class_, method):
    """Expose a bound client method as a LangChain tool.

    Args:
        class_: Class that owns *method*. It is used for log messages and for
            the tool name, which is ``"<ClassName>.<method_name>"``.
        method: Bound method to expose. The generated wrapper calls it with
            exactly one positional argument.

    Returns:
        The result of ``tool(wrapper)``. Whenever *method* returns an iterator
        (including a generator), the wrapper materializes it into a list first.
    """
    logger.debug("Wrapping %s.%s", class_.__name__, method.__name__)

    def wrapper(input):
        logger.info("AI called %s.%s", class_.__name__, method.__name__)
        result = method(input)
        # Materialize lazy results so the model receives the data, not a
        # generator object. Checking the runtime value instead of the string
        # form of the return annotation also covers typing.Iterator[...],
        # Generator[...], bare Iterator, string (PEP 563) annotations and
        # un-annotated generators, all of which the annotation check missed.
        if isinstance(result, Iterator):
            result = list(result)
        return result

    # Give `tool` the method's identity: its docstring and annotations replace
    # the wrapper's own.
    # NOTE(review): the wrapper's only parameter is named `input`, so only an
    # annotation keyed 'input' in method.__annotations__ attaches to it.
    # Methods whose parameter has another name lose that type information.
    # NOTE(review): the '.' in the tool name may be rejected by providers that
    # restrict function names to [a-zA-Z0-9_-] — confirm for the model in use.
    wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
    wrapper.__doc__ = method.__doc__
    wrapper.__annotations__ = method.__annotations__
    return tool(wrapper)
|
|
|
|
def wrap_all_methods_on_client(obj):
|
|
for field_name in dir(obj):
|
|
if field_name.startswith('_'):
|
|
continue
|
|
method = getattr(obj, field_name)
|
|
if not callable(method):
|
|
continue
|
|
if method.__doc__ is None:
|
|
continue
|
|
yield wrap(obj.__class__, method)
|
|
|
|
def get_tools(*, cache_path='output/test.sqlite', user_agent='Test Test'):
    """Build LangChain tools for every client returned by ``clients.all_clients``.

    All clients share a single ``requests_cache.CachedSession``, so their HTTP
    responses go through one cache stored at *cache_path*.

    Args:
        cache_path: Cache name/path passed to ``requests_cache.CachedSession``.
            Defaults to the previously hard-coded ``'output/test.sqlite'``.
        user_agent: ``User-Agent`` header set on the shared session.
            Defaults to the previously hard-coded ``'Test Test'``.

    Returns:
        A list with one tool per public, documented callable on each client,
        as produced by ``wrap_all_methods_on_client``.
    """
    session = requests_cache.CachedSession(cache_path)
    session.headers['User-Agent'] = user_agent

    secrets = secret_loader.SecretLoader()

    # The session is not closed here: it is handed to the clients, which
    # presumably still need it when their tools are invoked later.
    return [
        wrapped
        for client in clients.all_clients(session, secrets)
        for wrapped in wrap_all_methods_on_client(client)
    ]
|