Ruff
This commit is contained in:
parent
6ea0e48cc3
commit
e42bb3df0a
2
setup.py
2
setup.py
|
@ -14,6 +14,7 @@ PACKAGE_DESCRIPTION = """
|
||||||
PACKAGE_DESCRIPTION_SHORT = """
|
PACKAGE_DESCRIPTION_SHORT = """
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
def parse_version_file(text: str) -> str:
|
def parse_version_file(text: str) -> str:
|
||||||
match = re.match(r'^__version__\s*=\s*(["\'])([\d\.]+)\1$', text)
|
match = re.match(r'^__version__\s*=\s*(["\'])([\d\.]+)\1$', text)
|
||||||
if match is None:
|
if match is None:
|
||||||
|
@ -21,6 +22,7 @@ def parse_version_file(text: str) -> str:
|
||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
return match.group(2)
|
return match.group(2)
|
||||||
|
|
||||||
|
|
||||||
with open(PACKAGE_NAME + '/_version.py') as f:
|
with open(PACKAGE_NAME + '/_version.py') as f:
|
||||||
version = parse_version_file(f.read())
|
version = parse_version_file(f.read())
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,11 @@
|
||||||
from langgraph.prebuilt import create_react_agent
|
|
||||||
from langchain_ollama import ChatOllama
|
|
||||||
from typing import Annotated
|
|
||||||
from langgraph.prebuilt import ToolNode, tools_condition
|
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
|
||||||
from langgraph.graph.message import add_messages
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import prompt_toolkit
|
import prompt_toolkit
|
||||||
import prompt_toolkit.history
|
|
||||||
import prompt_toolkit.auto_suggest
|
import prompt_toolkit.auto_suggest
|
||||||
|
import prompt_toolkit.history
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_ollama import ChatOllama
|
||||||
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -23,21 +17,24 @@ cli_history = prompt_toolkit.history.FileHistory('output/cli_history.txt')
|
||||||
# MODEL = "qwen3:latest"
|
# MODEL = "qwen3:latest"
|
||||||
MODEL = 'hf.co/unsloth/Qwen3-30B-A3B-GGUF:Q4_K_M'
|
MODEL = 'hf.co/unsloth/Qwen3-30B-A3B-GGUF:Q4_K_M'
|
||||||
|
|
||||||
|
|
||||||
def create_model():
|
def create_model():
|
||||||
available_tools = tools.get_tools()
|
available_tools = tools.get_tools()
|
||||||
logger.info("Available tools:")
|
logger.info('Available tools:')
|
||||||
for tool in available_tools:
|
for tool in available_tools:
|
||||||
logger.info("- %s", tool.name)
|
logger.info('- %s', tool.name)
|
||||||
|
|
||||||
llm = ChatOllama(model=MODEL)
|
llm = ChatOllama(model=MODEL)
|
||||||
llm.bind_tools(tools=available_tools)
|
llm.bind_tools(tools=available_tools)
|
||||||
return create_react_agent(llm, tools=available_tools)
|
return create_react_agent(llm, tools=available_tools)
|
||||||
|
|
||||||
SYSTEM_MESSAGE = '''
|
|
||||||
|
SYSTEM_MESSAGE = """
|
||||||
You are a useful assistant with access to built in system tools.
|
You are a useful assistant with access to built in system tools.
|
||||||
Format responses as markdown.
|
Format responses as markdown.
|
||||||
Provide links when available.
|
Provide links when available.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
logging.basicConfig(level='INFO')
|
logging.basicConfig(level='INFO')
|
||||||
|
@ -45,15 +42,18 @@ def main():
|
||||||
llm = create_model()
|
llm = create_model()
|
||||||
prev_idx = 0
|
prev_idx = 0
|
||||||
while True:
|
while True:
|
||||||
user_input = prompt_toolkit.prompt("Human: ",
|
user_input = prompt_toolkit.prompt(
|
||||||
|
'Human: ',
|
||||||
history=cli_history,
|
history=cli_history,
|
||||||
auto_suggest=prompt_toolkit.auto_suggest.AutoSuggestFromHistory(),
|
auto_suggest=prompt_toolkit.auto_suggest.AutoSuggestFromHistory(),
|
||||||
)
|
)
|
||||||
messages.append(HumanMessage(user_input))
|
messages.append(HumanMessage(user_input))
|
||||||
|
|
||||||
result = llm.invoke({
|
result = llm.invoke(
|
||||||
|
{
|
||||||
'messages': messages,
|
'messages': messages,
|
||||||
})
|
},
|
||||||
|
)
|
||||||
messages = result['messages']
|
messages = result['messages']
|
||||||
for msg in messages[prev_idx:]:
|
for msg in messages[prev_idx:]:
|
||||||
print(msg.pretty_repr())
|
print(msg.pretty_repr())
|
||||||
|
|
|
@ -1,14 +1,12 @@
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
import requests_cache
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import clients
|
|
||||||
from typing import get_type_hints, List, Dict, Any
|
|
||||||
from collections.abc import Iterator
|
|
||||||
import logging
|
import logging
|
||||||
import secret_loader
|
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
|
import clients
|
||||||
|
import requests_cache
|
||||||
|
import secret_loader
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pycountry.db
|
import pycountry.db
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -20,10 +18,11 @@ logger = logging.getLogger(__name__)
|
||||||
@tool
|
@tool
|
||||||
def search(query: str):
|
def search(query: str):
|
||||||
"""Call to surf the web."""
|
"""Call to surf the web."""
|
||||||
if "sf" in query.lower() or "san francisco" in query.lower():
|
if 'sf' in query.lower() or 'san francisco' in query.lower():
|
||||||
return "It's 60 degrees and foggy."
|
return "It's 60 degrees and foggy."
|
||||||
return "It's 90 degrees and sunny."
|
return "It's 90 degrees and sunny."
|
||||||
|
|
||||||
|
|
||||||
def dataclasses_to_json(data):
|
def dataclasses_to_json(data):
|
||||||
if pycountry and isinstance(data, pycountry.db.Country):
|
if pycountry and isinstance(data, pycountry.db.Country):
|
||||||
return data.alpha_2
|
return data.alpha_2
|
||||||
|
@ -51,22 +50,32 @@ RETURN_FORMATS = {
|
||||||
|
|
||||||
RETURN_FORMAT = 'json'
|
RETURN_FORMAT = 'json'
|
||||||
|
|
||||||
|
|
||||||
def wrap_method(class_, method):
|
def wrap_method(class_, method):
|
||||||
logger.info("Wrapping %s.%s", class_.__name__, method.__name__)
|
logger.info('Wrapping %s.%s', class_.__name__, method.__name__)
|
||||||
is_iterator = str(method.__annotations__.get('return', '')).startswith('collections.abc.Iterator')
|
is_iterator = str(method.__annotations__.get('return', '')).startswith(
|
||||||
|
'collections.abc.Iterator',
|
||||||
|
)
|
||||||
|
|
||||||
def wrapper(input_value):
|
def wrapper(input_value):
|
||||||
if isinstance(input_value, dict):
|
if isinstance(input_value, dict):
|
||||||
logger.warning("Silently converting from dict to plain value!")
|
logger.warning('Silently converting from dict to plain value!')
|
||||||
input_value = next(input_value.values())
|
input_value = next(input_value.values())
|
||||||
logger.info("AI called %s.%s(%s)", class_.__name__, method.__name__, repr(input_value))
|
logger.info(
|
||||||
|
'AI called %s.%s(%s)', class_.__name__, method.__name__, repr(input_value),
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
result = method(input_value)
|
result = method(input_value)
|
||||||
if is_iterator:
|
if is_iterator:
|
||||||
result = list(result)
|
result = list(result)
|
||||||
return RETURN_FORMATS[RETURN_FORMAT](result)
|
return RETURN_FORMATS[RETURN_FORMAT](result)
|
||||||
except:
|
except:
|
||||||
logger.exception("AI invocation of %s.%s(%s) failed!", class_.__name__, method.__name__, repr(input_value))
|
logger.exception(
|
||||||
|
'AI invocation of %s.%s(%s) failed!',
|
||||||
|
class_.__name__,
|
||||||
|
method.__name__,
|
||||||
|
repr(input_value),
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
|
wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
|
||||||
|
@ -74,6 +83,7 @@ def wrap_method(class_, method):
|
||||||
wrapper.__annotations__ = method.__annotations__
|
wrapper.__annotations__ = method.__annotations__
|
||||||
return tool(wrapper)
|
return tool(wrapper)
|
||||||
|
|
||||||
|
|
||||||
def wrap_all_methods_on_client(obj):
|
def wrap_all_methods_on_client(obj):
|
||||||
for field_name in dir(obj):
|
for field_name in dir(obj):
|
||||||
if field_name.startswith('_'):
|
if field_name.startswith('_'):
|
||||||
|
@ -85,6 +95,7 @@ def wrap_all_methods_on_client(obj):
|
||||||
continue
|
continue
|
||||||
yield wrap_method(obj.__class__, method)
|
yield wrap_method(obj.__class__, method)
|
||||||
|
|
||||||
|
|
||||||
def get_tools():
|
def get_tools():
|
||||||
session = requests_cache.CachedSession('output/test.sqlite')
|
session = requests_cache.CachedSession('output/test.sqlite')
|
||||||
session.headers['User-Agent'] = 'Test Test'
|
session.headers['User-Agent'] = 'Test Test'
|
||||||
|
|
Loading…
Reference in New Issue
Block a user