OpenAI wip
This commit is contained in:
parent
60ae842764
commit
20bfd588f6
|
@ -1,12 +1,14 @@
|
|||
import logging
|
||||
|
||||
import json
|
||||
import prompt_toolkit
|
||||
import prompt_toolkit.auto_suggest
|
||||
import prompt_toolkit.history
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage
|
||||
from langchain_ollama import ChatOllama
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langmem import create_memory_manager
|
||||
import dataclasses
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -37,17 +39,110 @@ Format responses as markdown.
|
|||
Provide links when available.
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
def main():
|
||||
memory_manager = create_memory_manager(
|
||||
# Application setup: FastAPI instance with CORS enabled so browser-based
# clients served from other origins can call this API.
app = FastAPI()

# Origins allowed to make cross-origin requests against this server.
# NOTE(review): the tiangolo.com entries look copied from the FastAPI
# tutorial — confirm they are intentional.
origins = [
    "http://localhost.tiangolo.com",
    "https://localhost.tiangolo.com",
    "http://localhost",
    "http://localhost:8080",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class OpenAIMessage:
    """A single chat message in the OpenAI wire format."""

    # Speaker role, e.g. 'system', 'user' or 'assistant'.
    role: str
    # Plain-text message body.
    content: str
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class OpenAIRequest:
    """Request body of the OpenAI-compatible /v1/chat/completions endpoint."""

    # Model identifier requested by the client.
    model: str
    # Conversation so far, oldest message first.
    messages: list[OpenAIMessage]
    # When true, the client expects a streamed (SSE) response.
    stream: bool
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class OpenAIUsage:
    """Token-accounting block of an OpenAI response."""

    # Tokens consumed by the prompt.
    prompt_tokens: int
    # Tokens produced in the completion.
    completion_tokens: int
    # prompt_tokens + completion_tokens.
    total_tokens: int
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class OpenAIMessageSeq:
    """One entry of an OpenAI response's `choices` array: a message plus its index."""

    # Zero-based position of this choice.
    index: int
    # The message payload for this choice.
    message: OpenAIMessage
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class OpenAIResponse:
    """Response envelope of the OpenAI-compatible chat completion API."""

    # Unique identifier of this completion.
    id: str
    # Object type tag, e.g. 'chat.completion'.
    object: str
    # Unix timestamp of creation.
    created: int
    # Model name echoed back to the client.
    model: str
    # Backend fingerprint; this server reuses the model name here.
    system_fingerprint: str
    # Generated message alternatives.
    choices: list[OpenAIMessageSeq]
    # Token accounting for the request.
    usage: OpenAIUsage
|
||||
|
||||
# Long-term memory extraction pipeline, driven by the raw (non-agent) model.
memory_manager = create_memory_manager(
    create_raw_model(),
    instructions="Extract all noteworthy facts, events, and relationships. Indicate their importance.",
    enable_inserts=True,
)

# Agent model used to answer chat requests.
llm = create_model()
|
||||
|
||||
def invoke_model(messages_input: list[OpenAIMessage]):
    """Run a conversation through the agent.

    Converts the dataclass messages into the plain role/content dicts the
    agent graph expects, then invokes it and returns its raw result.
    """
    payload = {
        'messages': [
            {'role': message.role, 'content': message.content}
            for message in messages_input
        ],
    }
    return llm.invoke(payload)
|
||||
|
||||
# NOTE(review): module-level side effect — configures root logging on import.
# main() also calls basicConfig; consider keeping only one of the two.
logging.basicConfig(level='INFO')
|
||||
@app.post('/v1/chat/completions')
async def chat_completions(
    request: OpenAIRequest,
) -> OpenAIResponse:
    """OpenAI-compatible chat completion endpoint.

    Runs the incoming conversation through the local agent and wraps the
    resulting messages in an OpenAI-style response envelope.  Honours
    ``request.stream``: streaming clients get a Server-Sent-Events response,
    otherwise a plain JSON body is returned.
    """
    # Local stdlib imports keep the endpoint self-contained.
    import time
    import uuid

    # Was a bare print(); use the module logger like the rest of the file.
    logger.info('Chat completion request: %s', request)

    def build_response() -> OpenAIResponse:
        # Invoke the agent and convert every returned message into the
        # OpenAI wire format, preserving the original ordering.
        result_messages = invoke_model(request.messages)['messages']
        choices = [
            OpenAIMessageSeq(idx, OpenAIMessage(m.type, m.content))
            for idx, m in enumerate(result_messages)
        ]
        return OpenAIResponse(
            # Unique id and real timestamp instead of hard-coded test values.
            id='chatcmpl-' + uuid.uuid4().hex,
            object='chat.completion',
            created=int(time.time()),
            model=request.model,
            system_fingerprint=request.model,
            choices=choices,
            usage=OpenAIUsage(0, 0, 0),  # Token accounting not implemented yet.
        )

    if request.stream:
        async def response_stream():
            # OpenAI streaming clients expect SSE framing: each payload on a
            # "data:" line, terminated by a [DONE] sentinel.
            yield 'data: ' + json.dumps(jsonable_encoder(build_response())) + '\n\n'
            yield 'data: [DONE]\n\n'

        return StreamingResponse(response_stream(), media_type='text/event-stream')
    return build_response()
|
||||
|
||||
@app.get('/v1/models')
async def models():
    """OpenAI-compatible model listing endpoint (single fixed model)."""
    model_entry = {
        "id": "test_langgraph",
        "object": "model",
        "created": 1746919302,
        "owned_by": "jmaa",
    }
    return {"object": "list", "data": [model_entry]}
|
||||
|
||||
|
||||
|
||||
def main_cli():
|
||||
messages = [SystemMessage(SYSTEM_MESSAGE)]
|
||||
llm = create_model()
|
||||
prev_idx = 0
|
||||
while True:
|
||||
user_input = prompt_toolkit.prompt(
|
||||
|
@ -61,11 +156,7 @@ def main():
|
|||
else:
|
||||
messages.append(HumanMessage(user_input))
|
||||
|
||||
result = llm.invoke(
|
||||
{
|
||||
'messages': messages,
|
||||
},
|
||||
)
|
||||
result = invoke_model(messages)
|
||||
messages = result['messages']
|
||||
for msg in messages[prev_idx:]:
|
||||
print(msg.pretty_repr())
|
||||
|
@ -73,5 +164,12 @@ def main():
|
|||
prev_idx = len(messages)
|
||||
|
||||
|
||||
def main_server():
    """Entry point for the HTTP server variant.

    NOTE(review): currently a no-op stub — presumably meant to launch the
    ASGI server for ``app``; confirm before relying on it.
    """
|
||||
|
||||
def main():
    # Program entry point: configure logging, then hand off to the server
    # stub.  (basicConfig is also called at module level — see note there.)
    logging.basicConfig(level='INFO')
    main_server()
|
||||
|
||||
# Script entry guard: only run when executed directly, not on import.
if __name__ == '__main__':
    main()
|
||||
|
|
|
@ -12,6 +12,11 @@ try:
|
|||
except ImportError:
|
||||
pycountry = None
|
||||
|
||||
try:
|
||||
import fin_defs
|
||||
except ImportError:
|
||||
fin_defs = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -26,6 +31,10 @@ def search(query: str):
|
|||
def dataclasses_to_json(data):
|
||||
if pycountry and isinstance(data, pycountry.db.Country):
|
||||
return data.alpha_2
|
||||
if fin_defs and isinstance(data, fin_defs.AssetAmount):
|
||||
return str(data)
|
||||
if fin_defs and isinstance(data, fin_defs.Asset):
|
||||
return data.raw_short_name()
|
||||
if isinstance(data, list | tuple):
|
||||
return [dataclasses_to_json(d) for d in data]
|
||||
if isinstance(data, dict):
|
||||
|
@ -50,25 +59,29 @@ RETURN_FORMATS = {
|
|||
|
||||
RETURN_FORMAT = 'json'
|
||||
|
||||
MAX_TOOL_RESULT_LEN = 1000
|
||||
APPEND_RESULT_TYPE_DOCS = True
|
||||
|
||||
def wrap_method(class_, method):
|
||||
logger.info('Wrapping %s.%s', class_.__name__, method.__name__)
|
||||
is_iterator = str(method.__annotations__.get('return', '')).startswith(
|
||||
return_type = method.__annotations__.get('return', '')
|
||||
is_iterator = str(return_type).startswith(
|
||||
'collections.abc.Iterator',
|
||||
)
|
||||
|
||||
def wrapper(input_value):
|
||||
if isinstance(input_value, dict):
|
||||
logger.warning('Silently converting from dict to plain value!')
|
||||
input_value = next(input_value.values())
|
||||
logger.info(
|
||||
'AI called %s.%s(%s)', class_.__name__, method.__name__, repr(input_value),
|
||||
)
|
||||
try:
|
||||
if isinstance(input_value, dict):
|
||||
logger.warning('Silently converting from dict to plain value!')
|
||||
input_value = next(input_value.values())
|
||||
result = method(input_value)
|
||||
if is_iterator:
|
||||
result = list(result)
|
||||
return RETURN_FORMATS[RETURN_FORMAT](result)
|
||||
result_str: str = str(RETURN_FORMATS[RETURN_FORMAT](result))
|
||||
del result
|
||||
except:
|
||||
logger.exception(
|
||||
'AI invocation of %s.%s(%s) failed!',
|
||||
|
@ -77,6 +90,11 @@ def wrap_method(class_, method):
|
|||
repr(input_value),
|
||||
)
|
||||
raise
|
||||
if len(result_str) > MAX_TOOL_RESULT_LEN:
|
||||
result_str = result_str[:MAX_TOOL_RESULT_LEN] + ' (remaining tool result elicited...)'
|
||||
if APPEND_RESULT_TYPE_DOCS and (return_docs := getattr(return_type, '__doc__', None)):
|
||||
result_str = result_str+'\n'+return_docs
|
||||
return result_str
|
||||
|
||||
wrapper.__name__ = f'{class_.__name__}.{method.__name__}'
|
||||
wrapper.__doc__ = method.__doc__
|
||||
|
|
Loading…
Reference in New Issue
Block a user