Ruff

This commit is contained in:
Jon Michael Aanes 2025-05-11 16:06:53 +02:00
parent 6ea0e48cc3
commit e42bb3df0a
3 changed files with 57 additions and 44 deletions

View File

@@ -14,6 +14,7 @@ PACKAGE_DESCRIPTION = """
PACKAGE_DESCRIPTION_SHORT = """
""".strip()
def parse_version_file(text: str) -> str:
    match = re.match(r'^__version__\s*=\s*(["\'])([\d\.]+)\1$', text)
    if match is None:
@@ -21,6 +22,7 @@ def parse_version_file(text: str) -> str:
        raise Exception(msg)
    return match.group(2)
with open(PACKAGE_NAME + '/_version.py') as f:
    version = parse_version_file(f.read())
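For context, a minimal standalone sketch (not part of this commit) of what the version regex above accepts; the text of the error message is an assumption, since the line defining msg sits between the two hunks:

import re

def parse_version_file(text: str) -> str:
    # Accepts a _version.py whose entire content is one __version__ = '<digits and dots>' line.
    match = re.match(r'^__version__\s*=\s*(["\'])([\d\.]+)\1$', text)
    if match is None:
        msg = 'Malformed _version.py'  # assumed message; the real one is outside the shown hunks
        raise Exception(msg)
    return match.group(2)

assert parse_version_file("__version__ = '1.2.3'") == '1.2.3'
assert parse_version_file('__version__ = "0.10"\n') == '0.10'  # a single trailing newline still matches $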

View File

@@ -1,17 +1,11 @@
from langgraph.prebuilt import create_react_agent
from langchain_ollama import ChatOllama
from typing import Annotated
from langgraph.prebuilt import ToolNode, tools_condition
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
import logging
import prompt_toolkit
import prompt_toolkit.history
import prompt_toolkit.auto_suggest
import prompt_toolkit.history
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_ollama import ChatOllama
from langgraph.prebuilt import create_react_agent
logger = logging.getLogger(__name__)
@@ -19,25 +13,28 @@ from . import tools
cli_history = prompt_toolkit.history.FileHistory('output/cli_history.txt')
#MODEL = "gemma3:27b"
#MODEL = "qwen3:latest"
# MODEL = "gemma3:27b"
# MODEL = "qwen3:latest"
MODEL = 'hf.co/unsloth/Qwen3-30B-A3B-GGUF:Q4_K_M'
def create_model():
    available_tools = tools.get_tools()
    logger.info("Available tools:")
    logger.info('Available tools:')
    for tool in available_tools:
        logger.info("- %s", tool.name)
        logger.info('- %s', tool.name)
    llm = ChatOllama(model=MODEL)
    llm.bind_tools(tools=available_tools )
    return create_react_agent(llm, tools=available_tools )
    llm.bind_tools(tools=available_tools)
    return create_react_agent(llm, tools=available_tools)
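One API detail worth knowing when reading create_model above: bind_tools on a LangChain chat model returns a new runnable and leaves the original model untouched, and create_react_agent binds the tool list to the model on its own. A minimal sketch of the explicit form, assuming the same MODEL constant and tool list as above:

from langchain_ollama import ChatOllama

def bound_model(available_tools):
    # Hypothetical helper, not in the commit: keep the runnable returned by bind_tools
    # if the binding is meant to take effect outside create_react_agent.
    llm = ChatOllama(model=MODEL)
    return llm.bind_tools(available_tools)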
SYSTEM_MESSAGE = '''
SYSTEM_MESSAGE = """
You are a useful assistant with access to built in system tools.
Format responses as markdown.
Provide links when available.
'''
"""
def main():
    logging.basicConfig(level='INFO')
@@ -45,15 +42,18 @@ def main():
    llm = create_model()
    prev_idx = 0
    while True:
        user_input = prompt_toolkit.prompt("Human: ",
        user_input = prompt_toolkit.prompt(
            'Human: ',
            history=cli_history,
            auto_suggest=prompt_toolkit.auto_suggest.AutoSuggestFromHistory(),
        )
        messages.append(HumanMessage(user_input))
        result = llm.invoke({
        result = llm.invoke(
            {
                'messages': messages,
        })
            },
        )
        messages = result['messages']
        for msg in messages[prev_idx:]:
            print(msg.pretty_repr())
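A condensed sketch of the CLI loop pattern above, for readers outside the diff context; how messages is seeded and how prev_idx advances are assumptions, since those lines fall outside the shown hunks:

import prompt_toolkit
import prompt_toolkit.auto_suggest
import prompt_toolkit.history
from langchain_core.messages import HumanMessage, SystemMessage

def chat_loop(llm, system_prompt: str, history_path: str = 'output/cli_history.txt'):
    cli_history = prompt_toolkit.history.FileHistory(history_path)
    messages = [SystemMessage(system_prompt)]  # assumed seeding of the conversation
    prev_idx = 0
    while True:
        user_input = prompt_toolkit.prompt(
            'Human: ',
            history=cli_history,
            auto_suggest=prompt_toolkit.auto_suggest.AutoSuggestFromHistory(),
        )
        messages.append(HumanMessage(user_input))
        result = llm.invoke({'messages': messages})
        messages = result['messages']
        for msg in messages[prev_idx:]:
            print(msg.pretty_repr())
        prev_idx = len(messages)  # assumed: only print messages added since the previous turn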

View File

@@ -1,14 +1,12 @@
from langchain_core.tools import tool
import requests_cache
import dataclasses
import clients
from typing import get_type_hints, List, Dict, Any
from collections.abc import Iterator
import logging
import secret_loader
from decimal import Decimal
import clients
import requests_cache
import secret_loader
from langchain_core.tools import tool
try:
    import pycountry.db
except ImportError:
@@ -20,28 +18,29 @@ logger = logging.getLogger(__name__)
@tool
def search(query: str):
    """Call to surf the web."""
    if "sf" in query.lower() or "san francisco" in query.lower():
    if 'sf' in query.lower() or 'san francisco' in query.lower():
        return "It's 60 degrees and foggy."
    return "It's 90 degrees and sunny."
def dataclasses_to_json(data):
    if pycountry and isinstance(data, pycountry.db.Country):
        return data.alpha_2
    if isinstance(data, list | tuple):
        return [dataclasses_to_json(d) for d in data]
    if isinstance(data, dict):
        return {k:dataclasses_to_json(v) for k,v in data.items() if v}
        return {k: dataclasses_to_json(v) for k, v in data.items() if v}
    if isinstance(data, Decimal):
        return float(data)
    if dataclasses.is_dataclass(data):
        result = {}
        for field in dataclasses.fields(data):
            result[field.name] = getattr(data,field.name)
            result[field.name] = getattr(data, field.name)
        return dataclasses_to_json(result)
    return data
assert dataclasses_to_json([1,2,3]) == [1,2,3]
assert dataclasses_to_json([1, 2, 3]) == [1, 2, 3]
RETURN_FORMATS = {
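The dataclasses_to_json helper above recursively turns dataclasses into JSON-friendly values: pycountry countries become alpha-2 codes, Decimals become floats, and falsy dict values are dropped. A small usage example assuming that helper is in scope; the Transaction dataclass is purely illustrative:

import dataclasses
from decimal import Decimal

@dataclasses.dataclass(frozen=True)
class Transaction:
    amount: Decimal
    currency: str
    note: str | None

tx = Transaction(amount=Decimal('9.95'), currency='DKK', note=None)
# Dataclass -> dict, Decimal -> float; the falsy note=None is dropped by the dict branch.
assert dataclasses_to_json(tx) == {'amount': 9.95, 'currency': 'DKK'}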
@@ -51,22 +50,32 @@ RETURN_FORMATS = {
RETURN_FORMAT = 'json'
def wrap_method(class_, method):
    logger.info("Wrapping %s.%s", class_.__name__, method.__name__)
    is_iterator = str(method.__annotations__.get('return', '')).startswith('collections.abc.Iterator')
    logger.info('Wrapping %s.%s', class_.__name__, method.__name__)
    is_iterator = str(method.__annotations__.get('return', '')).startswith(
        'collections.abc.Iterator',
    )
    def wrapper(input_value):
        if isinstance(input_value, dict):
            logger.warning("Silently converting from dict to plain value!")
            logger.warning('Silently converting from dict to plain value!')
            input_value = next(input_value.values())
        logger.info("AI called %s.%s(%s)", class_.__name__, method.__name__, repr(input_value))
        logger.info(
            'AI called %s.%s(%s)', class_.__name__, method.__name__, repr(input_value),
        )
        try:
            result = method(input_value)
            if is_iterator:
                result = list(result)
            return RETURN_FORMATS[RETURN_FORMAT](result)
        except:
            logger.exception("AI invocation of %s.%s(%s) failed!", class_.__name__, method.__name__, repr(input_value))
            logger.exception(
                'AI invocation of %s.%s(%s) failed!',
                class_.__name__,
                method.__name__,
                repr(input_value),
            )
            raise
    wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
@@ -74,6 +83,7 @@ def wrap_method(class_, method):
    wrapper.__annotations__ = method.__annotations__
    return tool(wrapper)
def wrap_all_methods_on_client(obj):
    for field_name in dir(obj):
        if field_name.startswith('_'):
@@ -85,6 +95,7 @@ def wrap_all_methods_on_client(obj):
            continue
        yield wrap_method(obj.__class__, method)
def get_tools():
    session = requests_cache.CachedSession('output/test.sqlite')
    session.headers['User-Agent'] = 'Test Test'
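The CachedSession created in get_tools above caches HTTP traffic on disk. A short sketch of how such a session behaves, using a hypothetical URL; the cache path and User-Agent value are the ones from the hunk:

import requests_cache

session = requests_cache.CachedSession('output/test.sqlite')
session.headers['User-Agent'] = 'Test Test'
# Repeated requests within the cache lifetime are answered from output/test.sqlite
# instead of hitting the network; from_cache is added to responses by requests_cache.
response = session.get('https://example.org/')  # hypothetical URL
print(response.from_cache)  # False on the first call, True on subsequent calls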