Added json return_format
This commit is contained in:
parent
d030a98ba2
commit
20688ea3b4
|
@ -32,9 +32,15 @@ def create_model():
|
|||
llm.bind_tools(tools=available_tools )
|
||||
return create_react_agent(llm, tools=available_tools )
|
||||
|
||||
SYSTEM_MESSAGE = '''
|
||||
You are a useful assistant with access to built in system tools.
|
||||
Format responses as markdown.
|
||||
Provide links when available.
|
||||
'''
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level='INFO')
|
||||
messages = [SystemMessage("You are a useful assistant with access to built in system tools.")]
|
||||
messages = [SystemMessage(SYSTEM_MESSAGE)]
|
||||
llm = create_model()
|
||||
prev_idx = 0
|
||||
while True:
|
||||
|
@ -49,7 +55,7 @@ def main():
|
|||
})
|
||||
messages = result['messages']
|
||||
for msg in messages[prev_idx:]:
|
||||
print(f'{msg.type}: {msg.content}')
|
||||
print(msg.pretty_repr())
|
||||
del msg
|
||||
prev_idx = len(messages)
|
||||
|
||||
|
|
|
@ -1,12 +1,18 @@
|
|||
from langchain_core.tools import tool
|
||||
|
||||
import requests_cache
|
||||
import dataclasses
|
||||
import clients
|
||||
from typing import get_type_hints, List, Dict, Any
|
||||
from collections.abc import Iterator
|
||||
import logging
|
||||
import secret_loader
|
||||
|
||||
try:
|
||||
import pycountry.db
|
||||
except ImportError:
|
||||
pycountry = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -17,6 +23,29 @@ def search(query: str):
|
|||
return "It's 60 degrees and foggy."
|
||||
return "It's 90 degrees and sunny."
|
||||
|
||||
def dataclasses_to_json(data):
|
||||
if pycountry and isinstance(data, pycountry.db.Country):
|
||||
return data.alpha_3
|
||||
if isinstance(data, list | tuple):
|
||||
return [dataclasses_to_json(d) for d in data]
|
||||
if dataclasses.is_dataclass(data):
|
||||
result = {}
|
||||
for field in dataclasses.fields(data):
|
||||
result[field.name] = dataclasses_to_json(getattr(data,field.name))
|
||||
return result
|
||||
return data
|
||||
|
||||
|
||||
assert dataclasses_to_json([1,2,3]) == [1,2,3]
|
||||
|
||||
|
||||
RETURN_FORMATS = {
|
||||
'raw_python': lambda x: x,
|
||||
'json': dataclasses_to_json,
|
||||
}
|
||||
|
||||
RETURN_FORMAT = 'json'
|
||||
|
||||
def wrap_method(class_, method):
|
||||
logger.info("Wrapping %s.%s", class_.__name__, method.__name__)
|
||||
is_iterator = str(method.__annotations__.get('return', '')).startswith('collections.abc.Iterator')
|
||||
|
@ -26,10 +55,14 @@ def wrap_method(class_, method):
|
|||
logger.warning("Silently converting from dict to plain value!")
|
||||
input_value = next(input_value.values())
|
||||
logger.info("AI called %s.%s(%s)", class_.__name__, method.__name__, repr(input_value))
|
||||
result = method(input_value)
|
||||
if is_iterator:
|
||||
result = list(result)
|
||||
return result
|
||||
try:
|
||||
result = method(input_value)
|
||||
if is_iterator:
|
||||
result = list(result)
|
||||
return RETURN_FORMATS[RETURN_FORMAT](result)
|
||||
except:
|
||||
logger.exception("AI invocation of %s.%s(%s) failed!", class_.__name__, method.__name__, repr(input_value))
|
||||
raise
|
||||
|
||||
wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
|
||||
wrapper.__doc__ = method.__doc__
|
||||
|
|
Loading…
Reference in New Issue
Block a user