# test-langgraph/test_langgraph/__main__.py
import dataclasses
import json
import logging

import prompt_toolkit
import prompt_toolkit.auto_suggest
import prompt_toolkit.history
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_ollama import ChatOllama
from langgraph.prebuilt import create_react_agent
from langmem import create_memory_manager

from . import tools

logger = logging.getLogger(__name__)

cli_history = prompt_toolkit.history.FileHistory('output/cli_history.txt')

MODEL = 'hf.co/unsloth/Qwen3-30B-A3B-GGUF:Q4_K_M'

def create_raw_model():
    return ChatOllama(model=MODEL)

def create_model():
    available_tools = tools.get_tools()
    logger.info('Available tools:')
    for tool in available_tools:
        logger.info('- %s', tool.name)
    llm = create_raw_model()
    # create_react_agent binds the tools to the model itself; the original
    # llm.bind_tools(...) call here discarded its return value and was a no-op.
    return create_react_agent(llm, tools=available_tools)

SYSTEM_MESSAGE = """
You are a helpful assistant with access to built-in system tools.
Format responses as Markdown.
Provide links when available.
"""

app = FastAPI()

origins = [
    "http://localhost.tiangolo.com",
    "https://localhost.tiangolo.com",
    "http://localhost",
    "http://localhost:8080",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
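# Note: the tiangolo.com entries in `origins` above are carried over from the
# FastAPI CORS tutorial example; in a real deployment they would be replaced
# with the frontend's actual origins.
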
@dataclasses.dataclass(frozen=True)
class OpenAIMessage:
    role: str
    content: str


@dataclasses.dataclass(frozen=True)
class OpenAIRequest:
    model: str
    messages: list[OpenAIMessage]
    stream: bool


@dataclasses.dataclass(frozen=True)
class OpenAIUsage:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclasses.dataclass(frozen=True)
class OpenAIMessageSeq:
    index: int
    message: OpenAIMessage


@dataclasses.dataclass(frozen=True)
class OpenAIResponse:
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: list[OpenAIMessageSeq]
    usage: OpenAIUsage
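
# For reference, these dataclasses mirror the OpenAI chat-completions wire
# format, which (trimmed) looks like:
#   {"id": "...", "object": "chat.completion", "created": 1746999397,
#    "model": "...", "system_fingerprint": "...",
#    "choices": [{"index": 0,
#                 "message": {"role": "assistant", "content": "..."}}],
#    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}}
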
memory_manager = create_memory_manager(
    create_raw_model(),
    instructions="Extract all noteworthy facts, events, and relationships. Indicate their importance.",
    enable_inserts=True,
)

# The compiled react agent (create_model() wraps the chat model with tools).
agent = create_model()


def invoke_model(messages_input: list[OpenAIMessage | BaseMessage]) -> dict:
    # The agent accepts both LangChain message objects and role/content
    # dicts, so only the OpenAIMessage dataclasses need converting
    # (main_cli passes LangChain messages, which have .type, not .role).
    messages = [
        m if isinstance(m, BaseMessage) else {'role': m.role, 'content': m.content}
        for m in messages_input
    ]
    return agent.invoke({'messages': messages})
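
# Usage sketch (hypothetical prompt): the agent returns its full state, so
# the latest reply is the last entry in state['messages'], e.g.
#   state = invoke_model([OpenAIMessage('user', 'What tools do you have?')])
#   print(state['messages'][-1].content)
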
@app.post('/v1/chat/completions')
async def chat_completions(request: OpenAIRequest) -> OpenAIResponse:
    logger.info('Request: %s', request)

    def build_response() -> OpenAIResponse:
        result_messages = invoke_model(request.messages)['messages']
        # LangChain message types ('human', 'ai', ...) are not OpenAI roles
        # ('user', 'assistant', ...); map the common ones.
        role_map = {'human': 'user', 'ai': 'assistant'}
        choices = [
            OpenAIMessageSeq(idx, OpenAIMessage(role_map.get(m.type, m.type), m.content))
            for idx, m in enumerate(result_messages)
        ]
        return OpenAIResponse(
            id='test1',
            object='chat.completion',
            created=1746999397,
            model=request.model,
            system_fingerprint=request.model,
            choices=choices,
            usage=OpenAIUsage(0, 0, 0),
        )

    async def response_stream():
        # Yields the whole response as a single JSON blob; note that real
        # OpenAI-style streaming sends incremental SSE "data:" chunks.
        yield json.dumps(jsonable_encoder(build_response()))

    if request.stream:
        return StreamingResponse(response_stream())
    return build_response()
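
# A request sketch against this endpoint (assuming the server listens on
# localhost:8000):
#   curl http://localhost:8000/v1/chat/completions \
#     -H 'Content-Type: application/json' \
#     -d '{"model": "test_langgraph", "stream": false,
#          "messages": [{"role": "user", "content": "hello"}]}'
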
@app.get('/v1/models')
async def models():
    return {
        "object": "list",
        "data": [
            {"id": "test_langgraph", "object": "model", "created": 1746919302, "owned_by": "jmaa"},
        ],
    }

def main_cli():
    messages = [SystemMessage(SYSTEM_MESSAGE)]
    prev_idx = 0
    while True:
        user_input = prompt_toolkit.prompt(
            'Human: ',
            history=cli_history,
            auto_suggest=prompt_toolkit.auto_suggest.AutoSuggestFromHistory(),
        )
        if user_input == '/memories':
            # Run the langmem extractor over the conversation so far.
            memories = memory_manager.invoke({"messages": messages})
            print(memories)
        else:
            messages.append(HumanMessage(user_input))
            result = invoke_model(messages)
            messages = result['messages']
            # Print only the messages produced since the last turn.
            for msg in messages[prev_idx:]:
                print(msg.pretty_repr())
            prev_idx = len(messages)

def main_server():
    # Stub in the original; a minimal sketch assuming the standard uvicorn
    # runner (host and port are arbitrary choices).
    import uvicorn
    uvicorn.run(app, host='127.0.0.1', port=8000)

def main():
    logging.basicConfig(level='INFO')
    main_server()


if __name__ == '__main__':
    main()