Spaces:
Running
Running
| # memory_store.py | |
| import sqlite3 | |
| import os | |
| import logging | |
| from typing import List, Tuple | |
# Location of the SQLite file; override with the MEMORY_DB environment variable.
DB_PATH = os.environ.get("MEMORY_DB", "chat_memory.db")
# Per-user cap on stored rows; a value <= 0 disables pruning in save_message.
MAX_MESSAGES_PER_USER = int(os.environ.get("MAX_MESSAGES_PER_USER", "500"))

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def _get_conn():
    """Open and return a brand-new SQLite connection to ``DB_PATH``.

    The caller owns the returned connection and is responsible for closing it.
    ``timeout=10`` makes a locked database wait up to 10 seconds before raising.
    ``check_same_thread=False`` allows the connection object to be used from a
    thread other than the one that created it (e.g. Gradio worker threads).
    """
    return sqlite3.connect(
        DB_PATH,
        check_same_thread=False,
        timeout=10,
    )
def init_db():
    """Create the ``memory`` table and its per-user lookup index if missing.

    Safe to call repeatedly: both statements use ``IF NOT EXISTS``.
    The connection is always closed, even if a statement fails.
    """
    conn = _get_conn()
    try:
        # ``with conn`` commits on success / rolls back on error; it does not
        # close the connection, hence the explicit ``finally`` below.
        with conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS memory (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT,
                    role TEXT,
                    message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
                """
            )
            # The queries in this module filter on user_id and order by id;
            # without this composite index each of them scans the whole table.
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_memory_user_id_id "
                "ON memory (user_id, id)"
            )
    finally:
        conn.close()
def save_message(user_id: str, role: str, message: str) -> None:
    """Append one message to a user's history, then prune that user's oldest rows.

    When ``MAX_MESSAGES_PER_USER`` is positive, only the newest
    ``MAX_MESSAGES_PER_USER`` rows (by ``id``) for ``user_id`` are kept after the
    insert. The insert and the prune run inside the same transaction.

    Raises:
        ValueError: if ``user_id`` is empty.
        Exception: any database error is logged and re-raised unchanged.
    """
    if not user_id:
        raise ValueError("user_id is required")
    conn = _get_conn()
    try:
        with conn:
            conn.execute(
                "INSERT INTO memory (user_id, role, message) VALUES (?, ?, ?)",
                (user_id, role, message),
            )
            if MAX_MESSAGES_PER_USER > 0:
                # Find the id of the first row *past* the cap (newest-first),
                # and delete it together with everything older. Ids are unique,
                # so this removes exactly the rows beyond the newest
                # MAX_MESSAGES_PER_USER. When the user has no more than the cap,
                # the subquery yields NULL and ``id <= NULL`` matches nothing.
                # One statement replaces fetching all ids into Python and
                # deleting them one by one.
                conn.execute(
                    """
                    DELETE FROM memory
                    WHERE user_id = ?
                      AND id <= (
                          SELECT id FROM memory
                          WHERE user_id = ?
                          ORDER BY id DESC
                          LIMIT 1 OFFSET ?
                      )
                    """,
                    (user_id, user_id, MAX_MESSAGES_PER_USER),
                )
    except Exception:
        logger.exception("Failed to save message for user %s", user_id)
        raise
    finally:
        conn.close()
def get_last_messages(user_id: str, limit: int = 200) -> List[Tuple[str, str, str]]:
    """
    Fetch up to `limit` of the user's most recent messages, oldest first.

    Each item is a (role, message, created_at) row. This is best-effort: on any
    error the failure is logged and an empty list is returned instead of raising.
    """
    conn = _get_conn()
    try:
        newest_first = conn.execute(
            """
            SELECT role, message, created_at FROM memory
            WHERE user_id = ?
            ORDER BY id DESC
            LIMIT ?
            """,
            (user_id, limit),
        ).fetchall()
        # Reading newest-first lets LIMIT keep the latest rows; flip the result
        # back to chronological order for the caller.
        return newest_first[::-1]
    except Exception:
        logger.exception("Failed to fetch messages for user %s", user_id)
        return []
    finally:
        conn.close()
def clear_user_memory(user_id: str) -> int:
    """Remove every stored message belonging to ``user_id``.

    Returns the row count reported by the DELETE. Database errors are logged
    and then re-raised.
    """
    conn = _get_conn()
    try:
        with conn:
            removed = conn.execute(
                "DELETE FROM memory WHERE user_id = ?", (user_id,)
            ).rowcount
        # The ``with`` block has committed by this point.
        return removed
    except Exception:
        logger.exception("Failed to clear memory for user %s", user_id)
        raise
    finally:
        conn.close()
def build_gradio_history(user_id: str) -> List[dict]:
    """
    Build a chronological chat history for gr.Chatbot(type='messages').

    Reads up to 500 of the user's latest stored rows and converts each into a
    {'role': <stored role>, 'content': <stored message>} dict; the created_at
    column is dropped.
    """
    stored_rows = get_last_messages(user_id, limit=500)
    return [
        {"role": role, "content": content}
        for role, content, *_ in stored_rows
    ]