Testing the use of langgraph for agentic behaviour
This commit is contained in:
commit
50d04ea850
50
main.py
Normal file
50
main.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
from langgraph.prebuilt import create_react_agent
|
||||
from langchain_ollama import ChatOllama
|
||||
from typing import Annotated
|
||||
from langgraph.prebuilt import ToolNode, tools_condition
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import tools
|
||||
|
||||
#MODEL = "gemma3:27b"
|
||||
MODEL = "qwen3:latest"
|
||||
|
||||
def create_model():
|
||||
available_tools = tools.get_tools()
|
||||
logger.info("Available tools:")
|
||||
for tool in available_tools:
|
||||
logger.info("- %s", tool.name)
|
||||
|
||||
llm = ChatOllama(model=MODEL)
|
||||
llm.bind_tools(tools=available_tools )
|
||||
return create_react_agent(llm, tools=available_tools )
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level='INFO')
|
||||
messages = [SystemMessage("You are a useful assistant with access to built in system tools.")]
|
||||
llm = create_model()
|
||||
prev_idx = 0
|
||||
while True:
|
||||
|
||||
user_input = input("User: ")
|
||||
messages.append(HumanMessage(user_input))
|
||||
|
||||
result = llm.invoke({
|
||||
'messages': messages,
|
||||
})
|
||||
for msg in result['messages']:
|
||||
print(f'{msg.type}: {msg.content}')
|
||||
messages.append(msg)
|
||||
del msg
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
55
tools.py
Normal file
55
tools.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
from langchain_core.tools import tool
|
||||
|
||||
import requests_cache
|
||||
import clients
|
||||
from typing import get_type_hints, List, Dict, Any
|
||||
from collections.abc import Iterator
|
||||
import logging
|
||||
import secret_loader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool
|
||||
def search(query: str):
|
||||
"""Call to surf the web."""
|
||||
if "sf" in query.lower() or "san francisco" in query.lower():
|
||||
return "It's 60 degrees and foggy."
|
||||
return "It's 90 degrees and sunny."
|
||||
|
||||
def wrap(class_, method):
|
||||
logger.debug("Wrapping %s.%s", class_.__name__, method.__name__)
|
||||
is_iterator = str(method.__annotations__.get('return', '')).startswith('collections.abc.Iterator')
|
||||
def wrapper(input):
|
||||
logger.info("AI called %s.%s", class_.__name__, method.__name__)
|
||||
result = method(input)
|
||||
if is_iterator:
|
||||
result = list(result)
|
||||
return result
|
||||
|
||||
wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
|
||||
wrapper.__doc__ = method.__doc__
|
||||
wrapper.__annotations__ = method.__annotations__
|
||||
return tool(wrapper)
|
||||
|
||||
def wrap_all_methods_on_client(obj):
|
||||
for field_name in dir(obj):
|
||||
if field_name.startswith('_'):
|
||||
continue
|
||||
method = getattr(obj, field_name)
|
||||
if not callable(method):
|
||||
continue
|
||||
if method.__doc__ is None:
|
||||
continue
|
||||
yield wrap(obj.__class__, method)
|
||||
|
||||
def get_tools():
|
||||
session = requests_cache.CachedSession('output/test.sqlite')
|
||||
session.headers['User-Agent'] = 'Test Test'
|
||||
|
||||
secrets = secret_loader.SecretLoader()
|
||||
|
||||
tools = []
|
||||
for client in clients.all_clients(session, secrets):
|
||||
tools.extend(wrap_all_methods_on_client(client))
|
||||
return tools
|
Loading…
Reference in New Issue
Block a user