-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Add SQLAlchemy session backend for conversation history management #1357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
The new
The Temporary Solution As a temporary measure to allow the feature to be merged without breaking the build, the following changes have been made:
Next Steps This solution is a trade-off. It prevents the CI from breaking, but it means that the I think the correct long-term fix is to update Let me know what you think. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At a glance, this looks good to me; @rm-openai any thoughts?
@habema could you resolve the conflict? |
@seratch you got it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, this enhancement looks good to me; great work 👍
@habema Thanks for working on this. Overall, this PR looks good to me, but I don't have the bandwidth to do thorough testing with it before merging. I will take a look this week or early next week, but until then, please feel free to make further adjustments and/or adding unit tests. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this PR! Overall, it already looks great, but I have a few minor suggestions. Let me know what you think!
diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py
index 038f845..20f0555 100644
--- a/src/agents/extensions/memory/sqlalchemy_session.py
+++ b/src/agents/extensions/memory/sqlalchemy_session.py
@@ -31,6 +31,7 @@ from sqlalchemy import (
TIMESTAMP,
Column,
ForeignKey,
+ Index,
Integer,
MetaData,
String,
@@ -122,18 +123,10 @@ class SQLAlchemySession(SessionABC):
server_default=sql_text("CURRENT_TIMESTAMP"),
nullable=False,
),
+ Index(f"idx_{messages_table}_session_time", "session_id", "created_at"),
sqlite_autoincrement=True,
)
- # Index for efficient retrieval of messages per session ordered by time
- from sqlalchemy import Index
-
- Index(
- f"idx_{messages_table}_session_time",
- self._messages.c.session_id,
- self._messages.c.created_at,
- )
-
# Async session factory
self._session_factory = async_sessionmaker(
self._engine, expire_on_commit=False
@@ -180,21 +173,28 @@ class SQLAlchemySession(SessionABC):
await conn.run_sync(self._metadata.create_all)
self._create_tables = False # Only create once
+ async def _serialize_message_data(self, item: TResponseInputItem) -> str:
+ return json.dumps(item, separators=(",", ":"))
+
+ async def _deserialize_message_data(self, item: str) -> TResponseInputItem:
+ return json.loads(item)
+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
await self._ensure_tables()
async with self._session_factory() as sess:
- if limit is None:
+ if limit is not None:
stmt = (
select(self._messages.c.message_data)
.where(self._messages.c.session_id == self.session_id)
- .order_by(self._messages.c.created_at.asc())
+ # pick up the top-n new messages here, then reverse the order below
+ .order_by(self._messages.c.created_at.desc())
+ .limit(limit)
)
else:
stmt = (
select(self._messages.c.message_data)
.where(self._messages.c.session_id == self.session_id)
- .order_by(self._messages.c.created_at.desc())
- .limit(limit)
+ .order_by(self._messages.c.created_at.asc())
)
result = await sess.execute(stmt)
@@ -206,7 +206,7 @@ class SQLAlchemySession(SessionABC):
items: list[TResponseInputItem] = []
for raw in rows:
try:
- items.append(json.loads(raw))
+ items.append(await self._deserialize_message_data(raw))
except json.JSONDecodeError:
# Skip corrupted rows
continue
@@ -220,7 +220,7 @@ class SQLAlchemySession(SessionABC):
payload = [
{
"session_id": self.session_id,
- "message_data": json.dumps(item, separators=(",", ":")),
+ "message_data": await self._serialize_message_data(item),
}
for item in items
]
# Index for efficient retrieval of messages per session ordered by time | ||
from sqlalchemy import Index | ||
|
||
Index( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move this to the table definition? The following code should work.
self._messages = Table(
messages_table,
self._metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("session_id", String, ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE"), nullable=False),
Column("message_data", Text, nullable=False),
Column("created_at", TIMESTAMP(timezone=False), server_default=sql_text("CURRENT_TIMESTAMP"), nullable=False),
Index(f"idx_{messages_table}_session_time", "session_id", "created_at"),
sqlite_autoincrement=True,
)
stmt = ( | ||
select(self._messages.c.message_data) | ||
.where(self._messages.c.session_id == self.session_id) | ||
.order_by(self._messages.c.created_at.desc()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic looks good to me, but can you add a quick comment to the code explaining why desc here and we do rows.reverse() later?
payload = [ | ||
{ | ||
"session_id": self.session_id, | ||
"message_data": json.dumps(item, separators=(",", ":")), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may want to have a method to do this; with that, users can enhance this part just by inheriting this class.
@@ -0,0 +1,153 @@ | |||
from __future__ import annotations |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for adding tests!
This PR introduces
SQLAlchemySession
, providing a production-grade session storage backend that can connect to any database supported by SQLAlchemy (e.g., PostgreSQL, MySQL).This implementation is based on the discussion in Issue #1328 and follows the agreed-upon architectural pattern:
src/agents/extensions/memory/
.[sqlalchemy]
extra.A test has been added.
Resolves #1328.