Skip to content

Commit

Permalink
store and recover sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
heiko-braun committed Jan 15, 2024
1 parent 6b772f5 commit 459ebdc
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 28 deletions.
4 changes: 2 additions & 2 deletions core/agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@

__all__ = ['agent_executor', 'agent_llm', 'agent_memory']

from langchain.embeddings import OpenAIEmbeddings
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_openai.chat_models import ChatOpenAI
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor

from langchain.chat_models import ChatOpenAI
from langchain.schema import SystemMessage
from langchain.prompts import MessagesPlaceholder

Expand Down
96 changes: 88 additions & 8 deletions core/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
AgentTokenBufferMemory,
)

from langchain.schema import messages_from_dict, messages_to_dict
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory

import datetime

from abc import ABC, abstractmethod
Expand All @@ -17,6 +20,9 @@
from uuid import UUID
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union

import psycopg2
import json

# --

# the time in seconds, after which a conversation will be retried if inactive
Expand Down Expand Up @@ -117,7 +123,7 @@ class Conversation(StateMachine):
request_docs = running.to(lookup)
docs_supplied = lookup.to(running)

def __init__(self, slack_client, channel, thread_ts):
def __init__(self, slack_client, channel, thread_ts, memory, start_message="How can I help you?"):

self.last_activity = datetime.datetime.now()
self.client = slack_client
Expand All @@ -128,30 +134,39 @@ def __init__(self, slack_client, channel, thread_ts):
self.thread_ts = thread_ts

# internal states
self.prompt_text = None
self.thread = None
self.run = None
self.prompt_text = None
self.start_message = start_message

# the main interface towards the LLM
self.agent = agent_executor

# keeps track of previous messages
self.memory = AgentTokenBufferMemory(llm=agent_llm)
self.memory = memory

# interim, runtime states
self.response_handle = None

super().__init__()

def get_channel(self):
return self.channel

def get_thread(self):
return self.thread_ts

def export_memory(self):
extracted_messages = self.memory.chat_memory.messages
return messages_to_dict(extracted_messages)

def is_expired(self):
return self.last_activity < datetime.datetime.now()-datetime.timedelta(seconds=CONVERSATION_EXPIRY_TIME)

def on_enter_greeting(self):
# mimic the first LLM response to get things started
self.response_handle = {
"output": "How can I help you?"
"output": self.start_message
}
self.feedback.set_tagline("ID "+self.thread_ts)
self.feedback.set_tagline("New Session: "+self.thread_ts)

# starting a thinking loop
def on_enter_running(self):
Expand Down Expand Up @@ -256,4 +271,69 @@ async def on_chat_model_start(
"""Run when a chat model starts running."""
print("chat model start")



def save_session(conversation):

json_export = json.dumps(conversation.export_memory())

conn = None
try:
conn = psycopg2.connect(os.environ['PG_URL'])
cur = conn.cursor()

cur.execute("""
INSERT INTO slack_sessions (channel, thread, data)
VALUES (%s, %s, %s);
""",
(conversation.get_channel(), conversation.get_thread(), (json_export))
)

conn.commit()
cur.close()
except (Exception, psycopg2.DatabaseError) as error:
print("Failed to persist session: ", str(error))
finally:
if conn is not None:
conn.close()


def restore_session(client, channel, thread_ts):

conn = None
try:
conn = psycopg2.connect(os.environ['PG_URL'])
cur = conn.cursor()


cur.execute("""
SELECT id, data FROM slack_sessions WHERE channel=%s AND thread=%s;
""",
(channel, thread_ts)
)

row = cur.fetchone()
data = row[1]
cur.close()

if row is not None:
message_import = messages_from_dict(data)
message_history = ChatMessageHistory(messages=message_import)

restored_memory = AgentTokenBufferMemory(llm=agent_llm, chat_memory=message_history)

conversation = Conversation(
slack_client=client,
channel=channel,
thread_ts=thread_ts,
memory=restored_memory,
start_message="Trying to remember what we talked about ..."
)
return conversation
else:
return None

except (Exception, psycopg2.DatabaseError) as error:
print("Failed to recover session: ", str(error))
finally:
if conn is not None:
conn.close()
57 changes: 39 additions & 18 deletions slack_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from slack_bolt import App, Ack, Respond

from core.agent import agent_executor, agent_llm
from core.slack import Conversation
from core.slack import Conversation, save_session, restore_session

from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
AgentTokenBufferMemory,
)

from http.server import BaseHTTPRequestHandler, HTTPServer
from multiprocessing import Process
Expand Down Expand Up @@ -50,17 +54,23 @@ def retire_inactive_conversation():
conversation = ref["conversation"]
if(conversation.is_expired()):
if(conversation.current_state!='answered'):
conversation.retire()
handle_retirement(conversation)
active_conversations.remove(ref)
else:
print("Conversation is still active, keep for next cycle: ", str(conversation))

def handle_retirement(conversation):
# persist session
save_session(conversation)
# noity client
conversation.retire()


# This gets activated when the bot is tagged in a channel
# it will start a new thread that will hold the conversation
@app.event("app_mention")
def handle_message_events(body, logger):

thread_ts = body["event"].get("thread_ts")

if thread_ts:
Expand All @@ -76,7 +86,8 @@ def handle_message_events(body, logger):
conversation = Conversation(
slack_client=client,
channel=response_channel,
thread_ts=response_thread
thread_ts=response_thread,
memory=AgentTokenBufferMemory(llm=agent_llm)
)

with conversation_lock:
Expand All @@ -96,22 +107,33 @@ def handle_message_events(event, say):
if event.get("thread_ts"):
# within threads we listen to messages
print("handle message within thread")

response_channel = event.get("channel")
response_thread = event.get("thread_ts")

conversation = find_conversation(response_channel, response_thread)

if(conversation is None):
slack_response = client.chat_postMessage(
channel=response_channel,
thread_ts=response_thread,
text=f"Cannot find conversation ({response_thread}), is it expired?"

# any active conversation?
conversation = find_conversation(
channel=response_channel,
thread_ts=response_thread
)

else:

if(conversation is None):

conversation = restore_session(client, response_channel, response_thread)
if(conversation is not None):
with conversation_lock:
active_conversations.append({
"channel": response_channel,
"thread": str(response_thread),
"conversation": conversation
})
conversation.listening()

if(conversation is not None):
text = event.get('text')
conversation.inquire(text)
else:
print("Cannot find conversation")

else:
# outside thread we ingore messages
Expand Down Expand Up @@ -143,15 +165,14 @@ def run_healthcheck():
finally:
httpd.server_close()


healthcheck_process = Process(target=run_healthcheck, daemon=True)

# make sure conversation are retried when bot stops
def graceful_shutdown(signum, frame):
print("Shutdown bot ...")

# retire all active conversations
[ref["conversation"].retire() for ref in active_conversations]
[handle_retirement(ref["conversation"]) for ref in active_conversations]
time.sleep(3)

# stop the scheduler
Expand Down

0 comments on commit 459ebdc

Please sign in to comment.