diff --git a/test_langgraph/__main__.py b/test_langgraph/__main__.py index 41371cd..bf99ecf 100644 --- a/test_langgraph/__main__.py +++ b/test_langgraph/__main__.py @@ -1,12 +1,14 @@ import logging +import json import prompt_toolkit import prompt_toolkit.auto_suggest import prompt_toolkit.history -from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage from langchain_ollama import ChatOllama from langgraph.prebuilt import create_react_agent from langmem import create_memory_manager +import dataclasses logger = logging.getLogger(__name__) @@ -37,17 +39,110 @@ Format responses as markdown. Provide links when available. """ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from fastapi.encoders import jsonable_encoder -def main(): - memory_manager = create_memory_manager( - create_raw_model(), - instructions="Extract all noteworthy facts, events, and relationships. Indicate their importance.", - enable_inserts=True, +app = FastAPI() + +origins = [ + "http://localhost.tiangolo.com", + "https://localhost.tiangolo.com", + "http://localhost", + "http://localhost:8080", +] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +@dataclasses.dataclass(frozen=True) +class OpenAIMessage: + role: str + content: str + +@dataclasses.dataclass(frozen=True) +class OpenAIRequest: + model: str + messages: list[OpenAIMessage] + stream: bool + +@dataclasses.dataclass(frozen=True) +class OpenAIUsage: + prompt_tokens: int + completion_tokens: int + total_tokens: int + +@dataclasses.dataclass(frozen=True) +class OpenAIMessageSeq: + index: int + message: OpenAIMessage + +@dataclasses.dataclass(frozen=True) +class OpenAIResponse: + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: list[OpenAIMessageSeq] + usage: OpenAIUsage + +memory_manager = create_memory_manager( + create_raw_model(), + instructions="Extract all noteworthy facts, events, and relationships. Indicate their importance.", + enable_inserts=True, +) + +llm = create_model() + +def invoke_model(messages_input: list[OpenAIMessage]): + messages = [{'role': m.role, 'content': m.content} for m in messages_input] + return llm.invoke( + { + 'messages': messages, + }, ) - logging.basicConfig(level='INFO') +@app.post('/v1/chat/completions') +async def chat_completions( + request: OpenAIRequest +) -> OpenAIResponse: + print(request) + def fjerp(): + derp = invoke_model(request.messages)['messages'] + choices = [OpenAIMessageSeq(idx,OpenAIMessage(m.type, m.content)) for idx,m in enumerate(derp)] + return OpenAIResponse( + id = 'test1', + object='chat.completion', + created=1746999397, + model = request.model, + system_fingerprint=request.model, + choices=choices, + usage = OpenAIUsage(0,0,0) + ) + + async def response_stream(): + yield json.dumps(jsonable_encoder(fjerp())) + if request.stream: + return StreamingResponse(response_stream()) + return fjerp() + +@app.get('/v1/models') +async def models(): + return {"object":"list","data":[ + {"id":"test_langgraph","object":"model","created":1746919302,"owned_by":"jmaa"}, + ]} + + + +def main_cli(): messages = [SystemMessage(SYSTEM_MESSAGE)] - llm = create_model() prev_idx = 0 while True: user_input = prompt_toolkit.prompt( @@ -61,11 +156,7 @@ def main(): else: messages.append(HumanMessage(user_input)) - result = llm.invoke( - { - 'messages': messages, - }, - ) + result = invoke_model(messages) messages = result['messages'] for msg in messages[prev_idx:]: print(msg.pretty_repr()) @@ -73,5 +164,12 @@ def main(): prev_idx = len(messages) +def main_server(): + pass + +def main(): + logging.basicConfig(level='INFO') + main_server() + if __name__ == '__main__': main() diff --git a/test_langgraph/tools.py b/test_langgraph/tools.py index bcc0b92..b64f6a8 100644 --- a/test_langgraph/tools.py +++ b/test_langgraph/tools.py @@ -12,6 +12,11 @@ try: except ImportError: pycountry = None +try: + import fin_defs +except ImportError: + fin_defs = None + logger = logging.getLogger(__name__) @@ -26,6 +31,10 @@ def search(query: str): def dataclasses_to_json(data): if pycountry and isinstance(data, pycountry.db.Country): return data.alpha_2 + if fin_defs and isinstance(data, fin_defs.AssetAmount): + return str(data) + if fin_defs and isinstance(data, fin_defs.Asset): + return data.raw_short_name() if isinstance(data, list | tuple): return [dataclasses_to_json(d) for d in data] if isinstance(data, dict): @@ -50,25 +59,29 @@ RETURN_FORMATS = { RETURN_FORMAT = 'json' +MAX_TOOL_RESULT_LEN = 1000 +APPEND_RESULT_TYPE_DOCS = True def wrap_method(class_, method): logger.info('Wrapping %s.%s', class_.__name__, method.__name__) - is_iterator = str(method.__annotations__.get('return', '')).startswith( + return_type = method.__annotations__.get('return', '') + is_iterator = str(return_type).startswith( 'collections.abc.Iterator', ) def wrapper(input_value): - if isinstance(input_value, dict): - logger.warning('Silently converting from dict to plain value!') - input_value = next(input_value.values()) logger.info( 'AI called %s.%s(%s)', class_.__name__, method.__name__, repr(input_value), ) try: + if isinstance(input_value, dict): + logger.warning('Silently converting from dict to plain value!') + input_value = next(input_value.values()) result = method(input_value) if is_iterator: result = list(result) - return RETURN_FORMATS[RETURN_FORMAT](result) + result_str: str = str(RETURN_FORMATS[RETURN_FORMAT](result)) + del result except: logger.exception( 'AI invocation of %s.%s(%s) failed!', @@ -77,6 +90,11 @@ def wrap_method(class_, method): repr(input_value), ) raise + if len(result_str) > MAX_TOOL_RESULT_LEN: + result_str = result_str[:MAX_TOOL_RESULT_LEN] + ' (remaining tool result elicited...)' + if APPEND_RESULT_TYPE_DOCS and (return_docs := getattr(return_type, '__doc__', None)): + result_str = result_str+'\n'+return_docs + return result_str wrapper.__name__ = f'{class_.__name__}.{method.__name__}' wrapper.__doc__ = method.__doc__