diff --git a/test_langgraph/__main__.py b/test_langgraph/__main__.py
index bf99ecf..4cd2659 100644
--- a/test_langgraph/__main__.py
+++ b/test_langgraph/__main__.py
@@ -1,49 +1,11 @@
-import logging
-
-import json
-import prompt_toolkit
-import prompt_toolkit.auto_suggest
-import prompt_toolkit.history
-from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage
-from langchain_ollama import ChatOllama
-from langgraph.prebuilt import create_react_agent
-from langmem import create_memory_manager
-import dataclasses
-
-logger = logging.getLogger(__name__)
-
-from . import tools
-
-cli_history = prompt_toolkit.history.FileHistory('output/cli_history.txt')
-
-MODEL = 'hf.co/unsloth/Qwen3-30B-A3B-GGUF:Q4_K_M'
-
-
-def create_raw_model():
-    return ChatOllama(model=MODEL)
-
-def create_model():
-    available_tools = tools.get_tools()
-    logger.info('Available tools:')
-    for tool in available_tools:
-        logger.info('- %s', tool.name)
-
-    llm = create_raw_model()
-    llm.bind_tools(tools=available_tools)
-    return create_react_agent(llm, tools=available_tools)
-
-
-SYSTEM_MESSAGE = """
-You are a useful assistant with access to built in system tools.
-Format responses as markdown.
-Provide links when available.
-"""
-
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.encoders import jsonable_encoder
+from . import tools
+
+
 app = FastAPI()
 
 origins = [
@@ -61,115 +23,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-@dataclasses.dataclass(frozen=True)
-class OpenAIMessage:
-    role: str
-    content: str
-
-@dataclasses.dataclass(frozen=True)
-class OpenAIRequest:
-    model: str
-    messages: list[OpenAIMessage]
-    stream: bool
-
-@dataclasses.dataclass(frozen=True)
-class OpenAIUsage:
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-@dataclasses.dataclass(frozen=True)
-class OpenAIMessageSeq:
-    index: int
-    message: OpenAIMessage
-
-@dataclasses.dataclass(frozen=True)
-class OpenAIResponse:
-    id: str
-    object: str
-    created: int
-    model: str
-    system_fingerprint: str
-    choices: list[OpenAIMessageSeq]
-    usage: OpenAIUsage
-
-memory_manager = create_memory_manager(
-    create_raw_model(),
-    instructions="Extract all noteworthy facts, events, and relationships. Indicate their importance.",
-    enable_inserts=True,
-)
-
-llm = create_model()
-
-def invoke_model(messages_input: list[OpenAIMessage]):
-    messages = [{'role': m.role, 'content': m.content} for m in messages_input]
-    return llm.invoke(
-        {
-            'messages': messages,
-        },
-    )
-
-@app.post('/v1/chat/completions')
-async def chat_completions(
-    request: OpenAIRequest
-) -> OpenAIResponse:
-    print(request)
-    def fjerp():
-        derp = invoke_model(request.messages)['messages']
-        choices = [OpenAIMessageSeq(idx,OpenAIMessage(m.type, m.content)) for idx,m in enumerate(derp)]
-        return OpenAIResponse(
-            id = 'test1',
-            object='chat.completion',
-            created=1746999397,
-            model = request.model,
-            system_fingerprint=request.model,
-            choices=choices,
-            usage = OpenAIUsage(0,0,0)
-        )
-
-    async def response_stream():
-        yield json.dumps(jsonable_encoder(fjerp()))
-    if request.stream:
-        return StreamingResponse(response_stream())
-    return fjerp()
-
-@app.get('/v1/models')
-async def models():
-    return {"object":"list","data":[
-        {"id":"test_langgraph","object":"model","created":1746919302,"owned_by":"jmaa"},
-    ]}
-
-
-
-def main_cli():
-    messages = [SystemMessage(SYSTEM_MESSAGE)]
-    prev_idx = 0
-    while True:
-        user_input = prompt_toolkit.prompt(
-            'Human: ',
-            history=cli_history,
-            auto_suggest=prompt_toolkit.auto_suggest.AutoSuggestFromHistory(),
-        )
-        if user_input == '/memories':
-            memories = memory_manager.invoke({"messages": messages})
-            print(memories)
-        else:
-            messages.append(HumanMessage(user_input))
-
-            result = invoke_model(messages)
-            messages = result['messages']
-            for msg in messages[prev_idx:]:
-                print(msg.pretty_repr())
-                del msg
-            prev_idx = len(messages)
-
-
-def main_server():
-    pass
-
-def main():
-    logging.basicConfig(level='INFO')
-    main_server()
-
-if __name__ == '__main__':
-    main()
+# Expose each wrapped client method as its own GET endpoint at /<component>/<method>.
+for tool in tools.get_tools():
+    component, method = tool.__name__.split('.')
+    path = f'/{component}/{method}'
+    app.get(path, response_model=None)(tool)
diff --git a/test_langgraph/main_openai_api.py b/test_langgraph/main_openai_api.py
new file mode 100644
index 0000000..08817f9
--- /dev/null
+++ b/test_langgraph/main_openai_api.py
@@ -0,0 +1,173 @@
+import logging
+
+import json
+import prompt_toolkit
+import prompt_toolkit.auto_suggest
+import prompt_toolkit.history
+from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage
+from langchain_ollama import ChatOllama
+from langgraph.prebuilt import create_react_agent
+from langmem import create_memory_manager
+import dataclasses
+
+logger = logging.getLogger(__name__)
+
+from . import tools
+
+cli_history = prompt_toolkit.history.FileHistory('output/cli_history.txt')
+
+MODEL = 'hf.co/unsloth/Qwen3-30B-A3B-GGUF:Q4_K_M'
+
+
+def create_raw_model():
+    return ChatOllama(model=MODEL)
+
+def create_model():
+    available_tools = tools.get_tools()
+    logger.info('Available tools:')
+    for tool in available_tools:
+        logger.info('- %s', tool.__name__)  # plain functions since tools.py no longer wraps with @tool
+
+    llm = create_raw_model()
+    # bind_tools returns a new runnable; the original code discarded the result.
+    llm = llm.bind_tools(tools=available_tools)
+    return create_react_agent(llm, tools=available_tools)
+
+
+SYSTEM_MESSAGE = """
+You are a helpful assistant with access to built-in system tools.
+Format responses as markdown.
+Provide links when available.
+""" + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from fastapi.encoders import jsonable_encoder + +app = FastAPI() + +origins = [ + "http://localhost.tiangolo.com", + "https://localhost.tiangolo.com", + "http://localhost", + "http://localhost:8080", +] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +@dataclasses.dataclass(frozen=True) +class OpenAIMessage: + role: str + content: str + +@dataclasses.dataclass(frozen=True) +class OpenAIRequest: + model: str + messages: list[OpenAIMessage] + stream: bool + +@dataclasses.dataclass(frozen=True) +class OpenAIUsage: + prompt_tokens: int + completion_tokens: int + total_tokens: int + +@dataclasses.dataclass(frozen=True) +class OpenAIMessageSeq: + index: int + message: OpenAIMessage + +@dataclasses.dataclass(frozen=True) +class OpenAIResponse: + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: list[OpenAIMessageSeq] + usage: OpenAIUsage + +memory_manager = create_memory_manager( + create_raw_model(), + instructions="Extract all noteworthy facts, events, and relationships. Indicate their importance.", + enable_inserts=True, +) + +llm = create_model() + +def invoke_model(messages_input: list[OpenAIMessage]): + messages = [{'role': m.role, 'content': m.content} for m in messages_input] + return llm.invoke( + { + 'messages': messages, + }, + ) + +@app.post('/v1/chat/completions') +async def chat_completions( + request: OpenAIRequest +) -> OpenAIResponse: + print(request) + def fjerp(): + derp = invoke_model(request.messages)['messages'] + choices = [OpenAIMessageSeq(idx,OpenAIMessage(m.type, m.content)) for idx,m in enumerate(derp)] + return OpenAIResponse( + id = 'test1', + object='chat.completion', + created=1746999397, + model = request.model, + system_fingerprint=request.model, + choices=choices, + usage = OpenAIUsage(0,0,0) + ) + + async def response_stream(): + yield json.dumps(jsonable_encoder(fjerp())) + if request.stream: + return StreamingResponse(response_stream()) + return fjerp() + +@app.get('/v1/models') +async def models(): + return {"object":"list","data":[ + {"id":"test_langgraph","object":"model","created":1746919302,"owned_by":"jmaa"}, + ]} + + + +def main_cli(): + messages = [SystemMessage(SYSTEM_MESSAGE)] + prev_idx = 0 + while True: + user_input = prompt_toolkit.prompt( + 'Human: ', + history=cli_history, + auto_suggest=prompt_toolkit.auto_suggest.AutoSuggestFromHistory(), + ) + if user_input == '/memories': + memories = memory_manager.invoke({"messages": messages}) + print(memories) + else: + messages.append(HumanMessage(user_input)) + + result = invoke_model(messages) + messages = result['messages'] + for msg in messages[prev_idx:]: + print(msg.pretty_repr()) + del msg + prev_idx = len(messages) + + +def main_server(): + pass + +def main(): + logging.basicConfig(level='INFO') + main_server() + diff --git a/test_langgraph/tools.py b/test_langgraph/tools.py index b64f6a8..af2ea0c 100644 --- a/test_langgraph/tools.py +++ b/test_langgraph/tools.py @@ -99,7 +99,7 @@ def wrap_method(class_, method): wrapper.__name__ = f'{class_.__name__}.{method.__name__}' wrapper.__doc__ = method.__doc__ wrapper.__annotations__ = method.__annotations__ - return tool(wrapper) + return wrapper def wrap_all_methods_on_client(obj):