test-langgraph/test_langgraph/tools.py
Jon Michael Aanes 6ea0e48cc3
Some checks failed
Run Python tests (through Pytest) / Test (push) Failing after 23s
Verify Python project can be installed, loaded and have version checked / Test (push) Failing after 21s
Trying some stuff, including a new model
2025-05-11 02:21:09 +02:00

98 lines
2.9 KiB
Python

from langchain_core.tools import tool
import requests_cache
import dataclasses
import clients
from typing import get_type_hints, List, Dict, Any
from collections.abc import Iterator
import logging
import secret_loader
from decimal import Decimal
try:
import pycountry.db
except ImportError:
pycountry = None
logger = logging.getLogger(__name__)
@tool
def search(query: str):
    """Call to surf the web."""
    # Canned demo answers: any mention of San Francisco gets the foggy
    # forecast, every other query gets the sunny one.
    lowered = query.lower()
    mentions_sf = "sf" in lowered or "san francisco" in lowered
    return "It's 60 degrees and foggy." if mentions_sf else "It's 90 degrees and sunny."
def dataclasses_to_json(data):
    """Recursively reduce ``data`` to JSON-friendly Python primitives.

    Conversions, checked in this order:

    * ``pycountry`` country objects (only when ``pycountry`` could be
      imported) become their ``alpha_2`` attribute.
    * Lists and tuples become lists with every element converted.
    * Dicts keep their keys and get their values converted; entries whose
      original value is falsy (``None``, ``0``, ``''``, empty containers,
      ...) are dropped from the result.
    * ``Decimal`` becomes ``float``.
    * Dataclass *instances* become a dict of their fields, which is then
      processed like any other dict (so falsy fields are dropped as well).

    Anything else, including dataclass classes themselves, is returned
    unchanged.
    """
    if pycountry and isinstance(data, pycountry.db.Country):
        return data.alpha_2
    if isinstance(data, (list, tuple)):
        return [dataclasses_to_json(item) for item in data]
    if isinstance(data, dict):
        # Filtering on the *original* value prunes empty/zero entries.
        return {k: dataclasses_to_json(v) for k, v in data.items() if v}
    if isinstance(data, Decimal):
        return float(data)
    # ``is_dataclass`` is also true for dataclass classes; only instances
    # carry field values, so classes fall through and are returned as-is.
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        as_dict = {
            field.name: getattr(data, field.name)
            for field in dataclasses.fields(data)
        }
        return dataclasses_to_json(as_dict)
    return data
# Import-time sanity check: a plain list of ints must pass through unchanged.
assert dataclasses_to_json([1, 2, 3]) == [1, 2, 3]

# Converters that can be applied to a tool result, keyed by format name.
RETURN_FORMATS = {
    'raw_python': lambda value: value,  # hand the Python object back untouched
    'json': dataclasses_to_json,  # reduce to JSON-friendly primitives
}

# Key into RETURN_FORMATS selecting the converter used by wrap_method.
RETURN_FORMAT = 'json'
def wrap_method(class_, method):
    """Wrap a single-argument client method as a LangChain tool.

    The generated wrapper logs every invocation, calls ``method`` with the
    supplied value, materialises iterator results into a list, and converts
    the result with ``RETURN_FORMATS[RETURN_FORMAT]``.

    Args:
        class_: Class the method belongs to. Only its ``__name__`` is used,
            for log messages and for the wrapper's name.
        method: Callable taking one positional argument. Its ``__doc__`` and
            ``__annotations__`` are copied onto the wrapper.

    Returns:
        The result of passing the wrapper to ``tool``.

    Raises:
        Any exception raised by ``method`` or by the result conversion is
        logged with its traceback and then re-raised unchanged.
    """
    logger.info("Wrapping %s.%s", class_.__name__, method.__name__)
    # Only return annotations whose string form starts with
    # 'collections.abc.Iterator' are detected; other spellings (for example
    # ``typing.Iterator[...]``) are not treated as iterators here.
    is_iterator = str(method.__annotations__.get('return', '')).startswith('collections.abc.Iterator')

    def wrapper(input_value):
        # A non-empty dict is unwrapped to its first value. An empty dict is
        # passed through unchanged because there is nothing to unwrap.
        if isinstance(input_value, dict) and input_value:
            logger.warning("Silently converting from dict to plain value!")
            # Dict views are iterable but not iterators, so ``next`` needs
            # an explicit ``iter`` around them.
            input_value = next(iter(input_value.values()))
        logger.info("AI called %s.%s(%s)", class_.__name__, method.__name__, repr(input_value))
        try:
            result = method(input_value)
            if is_iterator:
                # Consume lazily produced items so they can be converted.
                result = list(result)
            return RETURN_FORMATS[RETURN_FORMAT](result)
        except Exception:
            logger.exception("AI invocation of %s.%s(%s) failed!", class_.__name__, method.__name__, repr(input_value))
            raise

    # Present the wrapper under the original method's identity.
    wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
    wrapper.__doc__ = method.__doc__
    wrapper.__annotations__ = method.__annotations__
    return tool(wrapper)
def wrap_all_methods_on_client(obj):
    """Yield a wrapped tool for each public, documented callable on ``obj``.

    Attribute names starting with an underscore are skipped without being
    looked up. Of the remaining attributes, only those that are callable and
    have a non-``None`` ``__doc__`` are passed to ``wrap_method`` together
    with ``obj.__class__``.
    """
    public_names = (name for name in dir(obj) if not name.startswith('_'))
    for name in public_names:
        candidate = getattr(obj, name)
        if callable(candidate) and candidate.__doc__ is not None:
            yield wrap_method(obj.__class__, candidate)
def get_tools():
    """Build the list of tools for every client from ``clients.all_clients``.

    Creates a ``requests_cache.CachedSession`` named ``'output/test.sqlite'``
    with a fixed User-Agent header, instantiates a
    ``secret_loader.SecretLoader``, and wraps the methods of each client
    yielded by ``clients.all_clients(session, secrets)``.

    Returns:
        A list with the wrapped tools of all clients, in iteration order.
    """
    cached_session = requests_cache.CachedSession('output/test.sqlite')
    cached_session.headers['User-Agent'] = 'Test Test'
    loaded_secrets = secret_loader.SecretLoader()
    return [
        wrapped
        for client in clients.all_clients(cached_session, loaded_secrets)
        for wrapped in wrap_all_methods_on_client(client)
    ]