Spaces:
Running
Running
File size: 3,388 Bytes
30f67dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
# memory_store.py
import sqlite3
import os
import logging
from typing import List, Tuple
DB_PATH = os.getenv("MEMORY_DB", "chat_memory.db")
MAX_MESSAGES_PER_USER = int(os.getenv("MAX_MESSAGES_PER_USER", 500))
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def _get_conn():
# check_same_thread=False so Gradio threads can use the DB concurrently
return sqlite3.connect(DB_PATH, timeout=10, check_same_thread=False)
def init_db():
conn = _get_conn()
try:
with conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS memory (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT,
role TEXT,
message TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
finally:
conn.close()
def save_message(user_id: str, role: str, message: str) -> None:
if not user_id:
raise ValueError("user_id is required")
conn = _get_conn()
try:
with conn:
conn.execute(
"INSERT INTO memory (user_id, role, message) VALUES (?, ?, ?)",
(user_id, role, message),
)
# prune if too many
if MAX_MESSAGES_PER_USER and MAX_MESSAGES_PER_USER > 0:
cur = conn.execute(
"SELECT id FROM memory WHERE user_id = ? ORDER BY id DESC",
(user_id,),
)
rows = cur.fetchall()
if len(rows) > MAX_MESSAGES_PER_USER:
ids_to_delete = [r[0] for r in rows[MAX_MESSAGES_PER_USER:]]
conn.executemany("DELETE FROM memory WHERE id = ?", [(i,) for i in ids_to_delete])
except Exception:
logger.exception("Failed to save message for user %s", user_id)
raise
finally:
conn.close()
def get_last_messages(user_id: str, limit: int = 200) -> List[Tuple[str, str, str]]:
"""
Return last `limit` messages in chronological order as (role, message, created_at)
"""
conn = _get_conn()
try:
cur = conn.cursor()
cur.execute(
"""
SELECT role, message, created_at FROM memory
WHERE user_id = ?
ORDER BY id DESC
LIMIT ?
""",
(user_id, limit),
)
rows = cur.fetchall()
return list(reversed(rows))
except Exception:
logger.exception("Failed to fetch messages for user %s", user_id)
return []
finally:
conn.close()
def clear_user_memory(user_id: str) -> int:
"""Delete memory for user. Returns deleted rowcount."""
conn = _get_conn()
try:
with conn:
cur = conn.execute("DELETE FROM memory WHERE user_id = ?", (user_id,))
return cur.rowcount
except Exception:
logger.exception("Failed to clear memory for user %s", user_id)
raise
finally:
conn.close()
def build_gradio_history(user_id: str) -> List[dict]:
"""
Return history formatted for gr.Chatbot with type='messages':
A chronological list of dicts: {'role':'user'|'assistant','content': '...'}
"""
rows = get_last_messages(user_id, limit=500)
return [{"role": r[0], "content": r[1]} for r in rows]
|