OpenAI wip

This commit is contained in:
Jon Michael Aanes 2025-05-12 00:28:21 +02:00
parent 60ae842764
commit 20bfd588f6
2 changed files with 134 additions and 18 deletions

View File

@ -1,12 +1,14 @@
import logging
import json
import prompt_toolkit
import prompt_toolkit.auto_suggest
import prompt_toolkit.history
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage
from langchain_ollama import ChatOllama
from langgraph.prebuilt import create_react_agent
from langmem import create_memory_manager
import dataclasses
logger = logging.getLogger(__name__)
@ -37,17 +39,110 @@ Format responses as markdown.
Provide links when available.
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
def main():
memory_manager = create_memory_manager(
app = FastAPI()
origins = [
"http://localhost.tiangolo.com",
"https://localhost.tiangolo.com",
"http://localhost",
"http://localhost:8080",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@dataclasses.dataclass(frozen=True)
class OpenAIMessage:
    """A single chat message in OpenAI wire format.

    Mirrors the ``{"role": ..., "content": ...}`` objects used by the
    OpenAI chat-completions API on both requests and responses.
    """

    # Message author role. Inbound this is the client's value
    # ('system'/'user'/'assistant'); outbound the handler fills it with
    # the langchain message ``type`` — presumably 'human'/'ai'; TODO confirm.
    role: str
    # The message text.
    content: str
@dataclasses.dataclass(frozen=True)
class OpenAIRequest:
    """Request body for ``POST /v1/chat/completions`` (OpenAI wire format).

    ``stream`` now defaults to ``False``: the OpenAI API treats it as an
    optional field, so clients that omit it must not fail validation.
    """

    # Model identifier requested by the client (echoed back in responses).
    model: str
    # Conversation history, oldest message first.
    messages: list[OpenAIMessage]
    # When True, the endpoint returns a StreamingResponse instead of a
    # single JSON body. Optional per the OpenAI spec.
    stream: bool = False
@dataclasses.dataclass(frozen=True)
class OpenAIUsage:
    """Token accounting block of an OpenAI chat completion.

    All counts default to 0 because the langgraph backend does not
    report token usage yet; positional construction (``OpenAIUsage(0, 0, 0)``)
    keeps working unchanged.
    """

    # Tokens consumed by the prompt/context.
    prompt_tokens: int = 0
    # Tokens produced by the completion.
    completion_tokens: int = 0
    # Sum of prompt and completion tokens.
    total_tokens: int = 0
@dataclasses.dataclass(frozen=True)
class OpenAIMessageSeq:
    """One entry of the ``choices`` array in an OpenAI chat completion."""

    # Position of this choice within the response.
    index: int
    # The generated message itself.
    message: OpenAIMessage
    # Why generation stopped. OpenAI clients expect this field on every
    # choice; 'stop' is the normal-completion value. Added with a default
    # so existing positional construction keeps working.
    finish_reason: str = 'stop'
@dataclasses.dataclass(frozen=True)
class OpenAIResponse:
    """Top-level response body for ``/v1/chat/completions`` (OpenAI wire format)."""

    # Completion identifier returned to the client.
    id: str
    # Wire-format discriminator; the handler sets 'chat.completion'.
    # (Field name shadows the builtin, but the OpenAI schema requires it.)
    object: str
    # Unix timestamp (seconds) of when the completion was created.
    created: int
    # Model name, echoed from the request.
    model: str
    # The handler currently reuses the model name here — presumably a
    # placeholder for a real backend fingerprint; TODO confirm.
    system_fingerprint: str
    # Generated alternatives; this server emits one entry per message
    # in the agent's returned history.
    choices: list[OpenAIMessageSeq]
    # Token accounting (currently always zeros).
    usage: OpenAIUsage
# Long-term memory extraction over the raw (non-agent) model.
# NOTE(review): `memory_manager` is not referenced by the visible request
# handlers — confirm whether it is wired in elsewhere or still WIP.
memory_manager = create_memory_manager(
    create_raw_model(),
    instructions="Extract all noteworthy facts, events, and relationships. Indicate their importance.",
    enable_inserts=True,
)

# The agent used by invoke_model() to answer chat requests.
llm = create_model()
def invoke_model(messages_input: list[OpenAIMessage]):
    """Convert wire-format messages to plain dicts and run the agent.

    Returns whatever ``llm.invoke`` returns (a state mapping containing
    a 'messages' entry).
    """
    payload = []
    for msg in messages_input:
        payload.append({'role': msg.role, 'content': msg.content})
    return llm.invoke({'messages': payload})
logging.basicConfig(level='INFO')
@app.post('/v1/chat/completions')
async def chat_completions(
    request: OpenAIRequest,
) -> OpenAIResponse:
    """OpenAI-compatible chat-completions endpoint.

    Runs the langgraph agent over the request's message history and maps
    the resulting messages back into OpenAI ``choices``. When
    ``request.stream`` is set, the response is wrapped in a
    StreamingResponse.
    """
    import time  # local: only needed for response timestamps

    # Debug print replaced with the module logger (lazy %-formatting).
    logger.info('Chat completion requested: model=%s, %d messages',
                request.model, len(request.messages))

    def build_response() -> OpenAIResponse:
        # The agent returns the whole message history; each message
        # becomes one choice, preserving order.
        result_messages = invoke_model(request.messages)['messages']
        choices = [
            OpenAIMessageSeq(idx, OpenAIMessage(m.type, m.content))
            for idx, m in enumerate(result_messages)
        ]
        return OpenAIResponse(
            id='test1',  # NOTE(review): placeholder id — should be unique per completion.
            object='chat.completion',
            created=int(time.time()),  # was a hard-coded timestamp
            model=request.model,
            system_fingerprint=request.model,
            choices=choices,
            usage=OpenAIUsage(0, 0, 0),  # backend reports no token usage yet
        )

    async def response_stream():
        # NOTE(review): OpenAI streaming clients expect SSE frames
        # ("data: {...}\n\n" per chunk, terminated by "data: [DONE]\n\n")
        # carrying chat.completion.chunk objects; this yields the whole
        # completion as one JSON blob — confirm whether full SSE framing
        # is needed for the intended clients.
        yield json.dumps(jsonable_encoder(build_response()))

    if request.stream:
        return StreamingResponse(response_stream())
    return build_response()
@app.get('/v1/models')
async def models():
    """List the models this server exposes (OpenAI ``/v1/models`` format)."""
    model_entry = {
        "id": "test_langgraph",
        "object": "model",
        "created": 1746919302,
        "owned_by": "jmaa",
    }
    return {"object": "list", "data": [model_entry]}
def main_cli():
messages = [SystemMessage(SYSTEM_MESSAGE)]
llm = create_model()
prev_idx = 0
while True:
user_input = prompt_toolkit.prompt(
@ -61,11 +156,7 @@ def main():
else:
messages.append(HumanMessage(user_input))
result = llm.invoke(
{
'messages': messages,
},
)
result = invoke_model(messages)
messages = result['messages']
for msg in messages[prev_idx:]:
print(msg.pretty_repr())
@ -73,5 +164,12 @@ def main():
prev_idx = len(messages)
def main_server():
    """Start the HTTP server (currently a stub).

    NOTE(review): no server is launched here — presumably the module-level
    FastAPI ``app`` is meant to be served by an external ASGI runner
    (e.g. ``uvicorn module:app``), or this body is still to be written;
    confirm the intended launch mechanism.
    """
    pass
def main():
    """Configure logging, then hand off to the server entry point."""
    logging.basicConfig(level='INFO')
    main_server()


if __name__ == '__main__':
    main()

View File

@ -12,6 +12,11 @@ try:
except ImportError:
pycountry = None
# Optional dependency: financial asset definitions. When absent,
# dataclasses_to_json() simply skips the fin_defs-specific branches.
try:
    import fin_defs
except ImportError:
    fin_defs = None
logger = logging.getLogger(__name__)
@ -26,6 +31,10 @@ def search(query: str):
def dataclasses_to_json(data):
if pycountry and isinstance(data, pycountry.db.Country):
return data.alpha_2
if fin_defs and isinstance(data, fin_defs.AssetAmount):
return str(data)
if fin_defs and isinstance(data, fin_defs.Asset):
return data.raw_short_name()
if isinstance(data, list | tuple):
return [dataclasses_to_json(d) for d in data]
if isinstance(data, dict):
@ -50,25 +59,29 @@ RETURN_FORMATS = {
RETURN_FORMAT = 'json'
MAX_TOOL_RESULT_LEN = 1000
APPEND_RESULT_TYPE_DOCS = True
def wrap_method(class_, method):
logger.info('Wrapping %s.%s', class_.__name__, method.__name__)
is_iterator = str(method.__annotations__.get('return', '')).startswith(
return_type = method.__annotations__.get('return', '')
is_iterator = str(return_type).startswith(
'collections.abc.Iterator',
)
def wrapper(input_value):
if isinstance(input_value, dict):
logger.warning('Silently converting from dict to plain value!')
input_value = next(input_value.values())
logger.info(
'AI called %s.%s(%s)', class_.__name__, method.__name__, repr(input_value),
)
try:
if isinstance(input_value, dict):
logger.warning('Silently converting from dict to plain value!')
input_value = next(input_value.values())
result = method(input_value)
if is_iterator:
result = list(result)
return RETURN_FORMATS[RETURN_FORMAT](result)
result_str: str = str(RETURN_FORMATS[RETURN_FORMAT](result))
del result
except:
logger.exception(
'AI invocation of %s.%s(%s) failed!',
@ -77,6 +90,11 @@ def wrap_method(class_, method):
repr(input_value),
)
raise
if len(result_str) > MAX_TOOL_RESULT_LEN:
result_str = result_str[:MAX_TOOL_RESULT_LEN] + ' (remaining tool result elicited...)'
if APPEND_RESULT_TYPE_DOCS and (return_docs := getattr(return_type, '__doc__', None)):
result_str = result_str+'\n'+return_docs
return result_str
wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
wrapper.__doc__ = method.__doc__