From e42bb3df0af0a227d3aff94d502db5714504ca90 Mon Sep 17 00:00:00 2001 From: Jon Michael Aanes Date: Sun, 11 May 2025 16:06:53 +0200 Subject: [PATCH] Ruff --- setup.py | 2 ++ test_langgraph/__main__.py | 52 +++++++++++++++++++------------------- test_langgraph/tools.py | 47 +++++++++++++++++++++------------- 3 files changed, 57 insertions(+), 44 deletions(-) diff --git a/setup.py b/setup.py index 302fe61..f35b7e9 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ PACKAGE_DESCRIPTION = """ PACKAGE_DESCRIPTION_SHORT = """ """.strip() + def parse_version_file(text: str) -> str: match = re.match(r'^__version__\s*=\s*(["\'])([\d\.]+)\1$', text) if match is None: @@ -21,6 +22,7 @@ def parse_version_file(text: str) -> str: raise Exception(msg) return match.group(2) + with open(PACKAGE_NAME + '/_version.py') as f: version = parse_version_file(f.read()) diff --git a/test_langgraph/__main__.py b/test_langgraph/__main__.py index 615bbb1..9b58a74 100644 --- a/test_langgraph/__main__.py +++ b/test_langgraph/__main__.py @@ -1,17 +1,11 @@ -from langgraph.prebuilt import create_react_agent -from langchain_ollama import ChatOllama -from typing import Annotated -from langgraph.prebuilt import ToolNode, tools_condition - -from typing_extensions import TypedDict - -from langchain_core.messages import HumanMessage, SystemMessage -from langgraph.graph import StateGraph, START, END -from langgraph.graph.message import add_messages import logging + import prompt_toolkit -import prompt_toolkit.history import prompt_toolkit.auto_suggest +import prompt_toolkit.history +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_ollama import ChatOllama +from langgraph.prebuilt import create_react_agent logger = logging.getLogger(__name__) @@ -19,25 +13,28 @@ from . import tools cli_history = prompt_toolkit.history.FileHistory('output/cli_history.txt') -#MODEL = "gemma3:27b" -#MODEL = "qwen3:latest" +# MODEL = "gemma3:27b" +# MODEL = "qwen3:latest" MODEL = 'hf.co/unsloth/Qwen3-30B-A3B-GGUF:Q4_K_M' + def create_model(): available_tools = tools.get_tools() - logger.info("Available tools:") + logger.info('Available tools:') for tool in available_tools: - logger.info("- %s", tool.name) + logger.info('- %s', tool.name) llm = ChatOllama(model=MODEL) - llm.bind_tools(tools=available_tools ) - return create_react_agent(llm, tools=available_tools ) + llm.bind_tools(tools=available_tools) + return create_react_agent(llm, tools=available_tools) -SYSTEM_MESSAGE = ''' + +SYSTEM_MESSAGE = """ You are a useful assistant with access to built in system tools. Format responses as markdown. Provide links when available. -''' +""" + def main(): logging.basicConfig(level='INFO') @@ -45,15 +42,18 @@ def main(): llm = create_model() prev_idx = 0 while True: - user_input = prompt_toolkit.prompt("Human: ", - history=cli_history, - auto_suggest=prompt_toolkit.auto_suggest.AutoSuggestFromHistory(), - ) + user_input = prompt_toolkit.prompt( + 'Human: ', + history=cli_history, + auto_suggest=prompt_toolkit.auto_suggest.AutoSuggestFromHistory(), + ) messages.append(HumanMessage(user_input)) - result = llm.invoke({ - 'messages': messages, - }) + result = llm.invoke( + { + 'messages': messages, + }, + ) messages = result['messages'] for msg in messages[prev_idx:]: print(msg.pretty_repr()) diff --git a/test_langgraph/tools.py b/test_langgraph/tools.py index 3377888..bcc0b92 100644 --- a/test_langgraph/tools.py +++ b/test_langgraph/tools.py @@ -1,14 +1,12 @@ -from langchain_core.tools import tool - -import requests_cache import dataclasses -import clients -from typing import get_type_hints, List, Dict, Any -from collections.abc import Iterator import logging -import secret_loader from decimal import Decimal +import clients +import requests_cache +import secret_loader +from langchain_core.tools import tool + try: import pycountry.db except ImportError: @@ -20,53 +18,64 @@ logger = logging.getLogger(__name__) @tool def search(query: str): """Call to surf the web.""" - if "sf" in query.lower() or "san francisco" in query.lower(): + if 'sf' in query.lower() or 'san francisco' in query.lower(): return "It's 60 degrees and foggy." return "It's 90 degrees and sunny." + def dataclasses_to_json(data): if pycountry and isinstance(data, pycountry.db.Country): return data.alpha_2 if isinstance(data, list | tuple): return [dataclasses_to_json(d) for d in data] if isinstance(data, dict): - return {k:dataclasses_to_json(v) for k,v in data.items() if v} + return {k: dataclasses_to_json(v) for k, v in data.items() if v} if isinstance(data, Decimal): return float(data) if dataclasses.is_dataclass(data): result = {} for field in dataclasses.fields(data): - result[field.name] = getattr(data,field.name) + result[field.name] = getattr(data, field.name) return dataclasses_to_json(result) return data -assert dataclasses_to_json([1,2,3]) == [1,2,3] +assert dataclasses_to_json([1, 2, 3]) == [1, 2, 3] RETURN_FORMATS = { - 'raw_python': lambda x: x, - 'json': dataclasses_to_json, + 'raw_python': lambda x: x, + 'json': dataclasses_to_json, } RETURN_FORMAT = 'json' + def wrap_method(class_, method): - logger.info("Wrapping %s.%s", class_.__name__, method.__name__) - is_iterator = str(method.__annotations__.get('return', '')).startswith('collections.abc.Iterator') + logger.info('Wrapping %s.%s', class_.__name__, method.__name__) + is_iterator = str(method.__annotations__.get('return', '')).startswith( + 'collections.abc.Iterator', + ) def wrapper(input_value): if isinstance(input_value, dict): - logger.warning("Silently converting from dict to plain value!") + logger.warning('Silently converting from dict to plain value!') input_value = next(input_value.values()) - logger.info("AI called %s.%s(%s)", class_.__name__, method.__name__, repr(input_value)) + logger.info( + 'AI called %s.%s(%s)', class_.__name__, method.__name__, repr(input_value), + ) try: result = method(input_value) if is_iterator: result = list(result) return RETURN_FORMATS[RETURN_FORMAT](result) except: - logger.exception("AI invocation of %s.%s(%s) failed!", class_.__name__, method.__name__, repr(input_value)) + logger.exception( + 'AI invocation of %s.%s(%s) failed!', + class_.__name__, + method.__name__, + repr(input_value), + ) raise wrapper.__name__ = f'{class_.__name__}.{method.__name__}' @@ -74,6 +83,7 @@ def wrap_method(class_, method): wrapper.__annotations__ = method.__annotations__ return tool(wrapper) + def wrap_all_methods_on_client(obj): for field_name in dir(obj): if field_name.startswith('_'): @@ -85,6 +95,7 @@ def wrap_all_methods_on_client(obj): continue yield wrap_method(obj.__class__, method) + def get_tools(): session = requests_cache.CachedSession('output/test.sqlite') session.headers['User-Agent'] = 'Test Test'