Ruff
This commit is contained in:
parent
6ea0e48cc3
commit
e42bb3df0a
2
setup.py
2
setup.py
|
@ -14,6 +14,7 @@ PACKAGE_DESCRIPTION = """
|
|||
PACKAGE_DESCRIPTION_SHORT = """
|
||||
""".strip()
|
||||
|
||||
|
||||
def parse_version_file(text: str) -> str:
|
||||
match = re.match(r'^__version__\s*=\s*(["\'])([\d\.]+)\1$', text)
|
||||
if match is None:
|
||||
|
@ -21,6 +22,7 @@ def parse_version_file(text: str) -> str:
|
|||
raise Exception(msg)
|
||||
return match.group(2)
|
||||
|
||||
|
||||
with open(PACKAGE_NAME + '/_version.py') as f:
|
||||
version = parse_version_file(f.read())
|
||||
|
||||
|
|
|
@ -1,17 +1,11 @@
|
|||
from langgraph.prebuilt import create_react_agent
|
||||
from langchain_ollama import ChatOllama
|
||||
from typing import Annotated
|
||||
from langgraph.prebuilt import ToolNode, tools_condition
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
import logging
|
||||
|
||||
import prompt_toolkit
|
||||
import prompt_toolkit.history
|
||||
import prompt_toolkit.auto_suggest
|
||||
import prompt_toolkit.history
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_ollama import ChatOllama
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -23,21 +17,24 @@ cli_history = prompt_toolkit.history.FileHistory('output/cli_history.txt')
|
|||
# MODEL = "qwen3:latest"
|
||||
MODEL = 'hf.co/unsloth/Qwen3-30B-A3B-GGUF:Q4_K_M'
|
||||
|
||||
|
||||
def create_model():
|
||||
available_tools = tools.get_tools()
|
||||
logger.info("Available tools:")
|
||||
logger.info('Available tools:')
|
||||
for tool in available_tools:
|
||||
logger.info("- %s", tool.name)
|
||||
logger.info('- %s', tool.name)
|
||||
|
||||
llm = ChatOllama(model=MODEL)
|
||||
llm.bind_tools(tools=available_tools)
|
||||
return create_react_agent(llm, tools=available_tools)
|
||||
|
||||
SYSTEM_MESSAGE = '''
|
||||
|
||||
SYSTEM_MESSAGE = """
|
||||
You are a useful assistant with access to built in system tools.
|
||||
Format responses as markdown.
|
||||
Provide links when available.
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level='INFO')
|
||||
|
@ -45,15 +42,18 @@ def main():
|
|||
llm = create_model()
|
||||
prev_idx = 0
|
||||
while True:
|
||||
user_input = prompt_toolkit.prompt("Human: ",
|
||||
user_input = prompt_toolkit.prompt(
|
||||
'Human: ',
|
||||
history=cli_history,
|
||||
auto_suggest=prompt_toolkit.auto_suggest.AutoSuggestFromHistory(),
|
||||
)
|
||||
messages.append(HumanMessage(user_input))
|
||||
|
||||
result = llm.invoke({
|
||||
result = llm.invoke(
|
||||
{
|
||||
'messages': messages,
|
||||
})
|
||||
},
|
||||
)
|
||||
messages = result['messages']
|
||||
for msg in messages[prev_idx:]:
|
||||
print(msg.pretty_repr())
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
from langchain_core.tools import tool
|
||||
|
||||
import requests_cache
|
||||
import dataclasses
|
||||
import clients
|
||||
from typing import get_type_hints, List, Dict, Any
|
||||
from collections.abc import Iterator
|
||||
import logging
|
||||
import secret_loader
|
||||
from decimal import Decimal
|
||||
|
||||
import clients
|
||||
import requests_cache
|
||||
import secret_loader
|
||||
from langchain_core.tools import tool
|
||||
|
||||
try:
|
||||
import pycountry.db
|
||||
except ImportError:
|
||||
|
@ -20,10 +18,11 @@ logger = logging.getLogger(__name__)
|
|||
@tool
|
||||
def search(query: str):
|
||||
"""Call to surf the web."""
|
||||
if "sf" in query.lower() or "san francisco" in query.lower():
|
||||
if 'sf' in query.lower() or 'san francisco' in query.lower():
|
||||
return "It's 60 degrees and foggy."
|
||||
return "It's 90 degrees and sunny."
|
||||
|
||||
|
||||
def dataclasses_to_json(data):
|
||||
if pycountry and isinstance(data, pycountry.db.Country):
|
||||
return data.alpha_2
|
||||
|
@ -51,22 +50,32 @@ RETURN_FORMATS = {
|
|||
|
||||
RETURN_FORMAT = 'json'
|
||||
|
||||
|
||||
def wrap_method(class_, method):
|
||||
logger.info("Wrapping %s.%s", class_.__name__, method.__name__)
|
||||
is_iterator = str(method.__annotations__.get('return', '')).startswith('collections.abc.Iterator')
|
||||
logger.info('Wrapping %s.%s', class_.__name__, method.__name__)
|
||||
is_iterator = str(method.__annotations__.get('return', '')).startswith(
|
||||
'collections.abc.Iterator',
|
||||
)
|
||||
|
||||
def wrapper(input_value):
|
||||
if isinstance(input_value, dict):
|
||||
logger.warning("Silently converting from dict to plain value!")
|
||||
logger.warning('Silently converting from dict to plain value!')
|
||||
input_value = next(input_value.values())
|
||||
logger.info("AI called %s.%s(%s)", class_.__name__, method.__name__, repr(input_value))
|
||||
logger.info(
|
||||
'AI called %s.%s(%s)', class_.__name__, method.__name__, repr(input_value),
|
||||
)
|
||||
try:
|
||||
result = method(input_value)
|
||||
if is_iterator:
|
||||
result = list(result)
|
||||
return RETURN_FORMATS[RETURN_FORMAT](result)
|
||||
except:
|
||||
logger.exception("AI invocation of %s.%s(%s) failed!", class_.__name__, method.__name__, repr(input_value))
|
||||
logger.exception(
|
||||
'AI invocation of %s.%s(%s) failed!',
|
||||
class_.__name__,
|
||||
method.__name__,
|
||||
repr(input_value),
|
||||
)
|
||||
raise
|
||||
|
||||
wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
|
||||
|
@ -74,6 +83,7 @@ def wrap_method(class_, method):
|
|||
wrapper.__annotations__ = method.__annotations__
|
||||
return tool(wrapper)
|
||||
|
||||
|
||||
def wrap_all_methods_on_client(obj):
|
||||
for field_name in dir(obj):
|
||||
if field_name.startswith('_'):
|
||||
|
@ -85,6 +95,7 @@ def wrap_all_methods_on_client(obj):
|
|||
continue
|
||||
yield wrap_method(obj.__class__, method)
|
||||
|
||||
|
||||
def get_tools():
|
||||
session = requests_cache.CachedSession('output/test.sqlite')
|
||||
session.headers['User-Agent'] = 'Test Test'
|
||||
|
|
Loading…
Reference in New Issue
Block a user