diff --git a/test_langgraph/__main__.py b/test_langgraph/__main__.py index 504ae04..2f1d7f2 100644 --- a/test_langgraph/__main__.py +++ b/test_langgraph/__main__.py @@ -32,9 +32,15 @@ def create_model(): llm.bind_tools(tools=available_tools ) return create_react_agent(llm, tools=available_tools ) +SYSTEM_MESSAGE = ''' +You are a useful assistant with access to built in system tools. +Format responses as markdown. +Provide links when available. +''' + def main(): logging.basicConfig(level='INFO') - messages = [SystemMessage("You are a useful assistant with access to built in system tools.")] + messages = [SystemMessage(SYSTEM_MESSAGE)] llm = create_model() prev_idx = 0 while True: @@ -49,7 +55,7 @@ def main(): }) messages = result['messages'] for msg in messages[prev_idx:]: - print(f'{msg.type}: {msg.content}') + print(msg.pretty_repr()) del msg prev_idx = len(messages) diff --git a/test_langgraph/tools.py b/test_langgraph/tools.py index a0837de..9c22aa1 100644 --- a/test_langgraph/tools.py +++ b/test_langgraph/tools.py @@ -1,12 +1,18 @@ from langchain_core.tools import tool import requests_cache +import dataclasses import clients from typing import get_type_hints, List, Dict, Any from collections.abc import Iterator import logging import secret_loader +try: + import pycountry.db +except ImportError: + pycountry = None + logger = logging.getLogger(__name__) @@ -17,6 +23,29 @@ def search(query: str): return "It's 60 degrees and foggy." return "It's 90 degrees and sunny." +def dataclasses_to_json(data): + if pycountry and isinstance(data, pycountry.db.Country): + return data.alpha_3 + if isinstance(data, list | tuple): + return [dataclasses_to_json(d) for d in data] + if dataclasses.is_dataclass(data): + result = {} + for field in dataclasses.fields(data): + result[field.name] = dataclasses_to_json(getattr(data,field.name)) + return result + return data + + +assert dataclasses_to_json([1,2,3]) == [1,2,3] + + +RETURN_FORMATS = { + 'raw_python': lambda x: x, + 'json': dataclasses_to_json, +} + +RETURN_FORMAT = 'json' + def wrap_method(class_, method): logger.info("Wrapping %s.%s", class_.__name__, method.__name__) is_iterator = str(method.__annotations__.get('return', '')).startswith('collections.abc.Iterator') @@ -26,10 +55,14 @@ def wrap_method(class_, method): logger.warning("Silently converting from dict to plain value!") input_value = next(input_value.values()) logger.info("AI called %s.%s(%s)", class_.__name__, method.__name__, repr(input_value)) - result = method(input_value) - if is_iterator: - result = list(result) - return result + try: + result = method(input_value) + if is_iterator: + result = list(result) + return RETURN_FORMATS[RETURN_FORMAT](result) + except: + logger.exception("AI invocation of %s.%s(%s) failed!", class_.__name__, method.__name__, repr(input_value)) + raise wrapper.__name__ = f'{class_.__name__}.{method.__name__}' wrapper.__doc__ = method.__doc__