Added json return_format
This commit is contained in:
parent
d030a98ba2
commit
20688ea3b4
|
@ -32,9 +32,15 @@ def create_model():
|
||||||
llm.bind_tools(tools=available_tools )
|
llm.bind_tools(tools=available_tools )
|
||||||
return create_react_agent(llm, tools=available_tools )
|
return create_react_agent(llm, tools=available_tools )
|
||||||
|
|
||||||
|
SYSTEM_MESSAGE = '''
|
||||||
|
You are a useful assistant with access to built in system tools.
|
||||||
|
Format responses as markdown.
|
||||||
|
Provide links when available.
|
||||||
|
'''
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
logging.basicConfig(level='INFO')
|
logging.basicConfig(level='INFO')
|
||||||
messages = [SystemMessage("You are a useful assistant with access to built in system tools.")]
|
messages = [SystemMessage(SYSTEM_MESSAGE)]
|
||||||
llm = create_model()
|
llm = create_model()
|
||||||
prev_idx = 0
|
prev_idx = 0
|
||||||
while True:
|
while True:
|
||||||
|
@ -49,7 +55,7 @@ def main():
|
||||||
})
|
})
|
||||||
messages = result['messages']
|
messages = result['messages']
|
||||||
for msg in messages[prev_idx:]:
|
for msg in messages[prev_idx:]:
|
||||||
print(f'{msg.type}: {msg.content}')
|
print(msg.pretty_repr())
|
||||||
del msg
|
del msg
|
||||||
prev_idx = len(messages)
|
prev_idx = len(messages)
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,18 @@
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
import requests_cache
|
import requests_cache
|
||||||
|
import dataclasses
|
||||||
import clients
|
import clients
|
||||||
from typing import get_type_hints, List, Dict, Any
|
from typing import get_type_hints, List, Dict, Any
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
import logging
|
import logging
|
||||||
import secret_loader
|
import secret_loader
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pycountry.db
|
||||||
|
except ImportError:
|
||||||
|
pycountry = None
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,6 +23,29 @@ def search(query: str):
|
||||||
return "It's 60 degrees and foggy."
|
return "It's 60 degrees and foggy."
|
||||||
return "It's 90 degrees and sunny."
|
return "It's 90 degrees and sunny."
|
||||||
|
|
||||||
|
def dataclasses_to_json(data):
|
||||||
|
if pycountry and isinstance(data, pycountry.db.Country):
|
||||||
|
return data.alpha_3
|
||||||
|
if isinstance(data, list | tuple):
|
||||||
|
return [dataclasses_to_json(d) for d in data]
|
||||||
|
if dataclasses.is_dataclass(data):
|
||||||
|
result = {}
|
||||||
|
for field in dataclasses.fields(data):
|
||||||
|
result[field.name] = dataclasses_to_json(getattr(data,field.name))
|
||||||
|
return result
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
assert dataclasses_to_json([1,2,3]) == [1,2,3]
|
||||||
|
|
||||||
|
|
||||||
|
RETURN_FORMATS = {
|
||||||
|
'raw_python': lambda x: x,
|
||||||
|
'json': dataclasses_to_json,
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_FORMAT = 'json'
|
||||||
|
|
||||||
def wrap_method(class_, method):
|
def wrap_method(class_, method):
|
||||||
logger.info("Wrapping %s.%s", class_.__name__, method.__name__)
|
logger.info("Wrapping %s.%s", class_.__name__, method.__name__)
|
||||||
is_iterator = str(method.__annotations__.get('return', '')).startswith('collections.abc.Iterator')
|
is_iterator = str(method.__annotations__.get('return', '')).startswith('collections.abc.Iterator')
|
||||||
|
@ -26,10 +55,14 @@ def wrap_method(class_, method):
|
||||||
logger.warning("Silently converting from dict to plain value!")
|
logger.warning("Silently converting from dict to plain value!")
|
||||||
input_value = next(input_value.values())
|
input_value = next(input_value.values())
|
||||||
logger.info("AI called %s.%s(%s)", class_.__name__, method.__name__, repr(input_value))
|
logger.info("AI called %s.%s(%s)", class_.__name__, method.__name__, repr(input_value))
|
||||||
|
try:
|
||||||
result = method(input_value)
|
result = method(input_value)
|
||||||
if is_iterator:
|
if is_iterator:
|
||||||
result = list(result)
|
result = list(result)
|
||||||
return result
|
return RETURN_FORMATS[RETURN_FORMAT](result)
|
||||||
|
except:
|
||||||
|
logger.exception("AI invocation of %s.%s(%s) failed!", class_.__name__, method.__name__, repr(input_value))
|
||||||
|
raise
|
||||||
|
|
||||||
wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
|
wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
|
||||||
wrapper.__doc__ = method.__doc__
|
wrapper.__doc__ = method.__doc__
|
||||||
|
|
Loading…
Reference in New Issue
Block a user