File size: 3,388 Bytes
30f67dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# memory_store.py
import sqlite3
import os
import logging
from typing import List, Tuple

# Path to the SQLite database file; override with the MEMORY_DB environment variable.
DB_PATH = os.getenv("MEMORY_DB", "chat_memory.db")
# Per-user cap on stored messages. save_message() deletes a user's oldest rows
# beyond this count after each insert; a value <= 0 disables pruning.
# Override with the MAX_MESSAGES_PER_USER environment variable. The default is
# given as a string because os.getenv returns str values.
MAX_MESSAGES_PER_USER = int(os.getenv("MAX_MESSAGES_PER_USER", "500"))

logger = logging.getLogger(__name__)
# NOTE(review): setting a level on a module logger overrides the application's
# logging configuration for this module; kept as-is to preserve behavior.
logger.setLevel(logging.INFO)


def _get_conn():
    """Open and return a new SQLite connection to ``DB_PATH``.

    The functions in this module each open their own connection and close it
    before returning. ``timeout=10`` makes the connection wait up to 10 seconds
    for a lock held by another connection before raising, and
    ``check_same_thread=False`` lifts sqlite3's restriction that a connection
    may only be used by the thread that created it.
    """
    connection = sqlite3.connect(
        DB_PATH,
        timeout=10,
        check_same_thread=False,
    )
    return connection


def init_db():
    """Create the ``memory`` table and its per-user lookup index if missing.

    Safe to call repeatedly: both statements use ``IF NOT EXISTS``. The
    connection is always closed, even if a statement fails.
    """
    conn = _get_conn()
    try:
        with conn:  # commit on success, roll back on error
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS memory (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT,
                    role TEXT,
                    message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
                """
            )
            # Every read, delete and prune in this module filters on user_id
            # and orders by id; without this index each of those scans the
            # whole table.
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_memory_user_id "
                "ON memory (user_id, id)"
            )
    finally:
        conn.close()


def save_message(user_id: str, role: str, message: str) -> None:
    """Append one message to ``user_id``'s history, then enforce the size cap.

    After inserting, if ``MAX_MESSAGES_PER_USER`` is positive, all but the
    newest ``MAX_MESSAGES_PER_USER`` rows for this user are deleted. The insert
    and the prune run in one transaction.

    Args:
        user_id: Owner of the message; must be non-empty.
        role: Speaker label stored verbatim (presumably ``"user"`` or
            ``"assistant"``, as build_gradio_history expects — not validated).
        message: Message text.

    Raises:
        ValueError: if ``user_id`` is empty or None.
        Exception: any database error is logged and re-raised.
    """
    if not user_id:
        raise ValueError("user_id is required")
    conn = _get_conn()
    try:
        with conn:  # insert + prune commit (or roll back) together
            conn.execute(
                "INSERT INTO memory (user_id, role, message) VALUES (?, ?, ?)",
                (user_id, role, message),
            )
            if MAX_MESSAGES_PER_USER > 0:
                # Drop everything older than the newest MAX_MESSAGES_PER_USER
                # rows in a single statement instead of pulling every id into
                # Python and deleting rows one at a time. In SQLite,
                # "LIMIT -1" means no upper bound; OFFSET skips the rows kept.
                conn.execute(
                    """
                    DELETE FROM memory WHERE id IN (
                        SELECT id FROM memory
                        WHERE user_id = ?
                        ORDER BY id DESC
                        LIMIT -1 OFFSET ?
                    )
                    """,
                    (user_id, MAX_MESSAGES_PER_USER),
                )
    except Exception:
        logger.exception("Failed to save message for user %s", user_id)
        raise
    finally:
        conn.close()


def get_last_messages(user_id: str, limit: int = 200) -> List[Tuple[str, str, str]]:
    """
    Return the last ``limit`` messages for ``user_id`` in chronological order.

    Each item is a ``(role, message, created_at)`` tuple, oldest first.
    A non-positive ``limit`` returns ``[]`` without touching the database:
    SQLite treats a negative LIMIT as "no limit", which would otherwise
    return the user's entire history.

    Best-effort: on any database error the exception is logged and ``[]`` is
    returned instead of raising.
    """
    if limit <= 0:
        return []
    conn = _get_conn()
    try:
        rows = conn.execute(
            """
            SELECT role, message, created_at FROM memory
            WHERE user_id = ?
            ORDER BY id DESC
            LIMIT ?
            """,
            (user_id, limit),
        ).fetchall()
    except Exception:
        logger.exception("Failed to fetch messages for user %s", user_id)
        return []
    finally:
        conn.close()
    # Read newest-first so LIMIT keeps the most recent rows, then flip back
    # to oldest-first for display.
    rows.reverse()
    return rows


def clear_user_memory(user_id: str) -> int:
    """Remove every stored message that belongs to ``user_id``.

    Returns:
        The cursor's rowcount for the DELETE, i.e. how many rows were removed.

    Raises:
        Any database error, after it has been logged.
    """
    connection = _get_conn()
    try:
        with connection:
            deleted = connection.execute(
                "DELETE FROM memory WHERE user_id = ?", (user_id,)
            ).rowcount
        # Leaving the `with` block has already committed the delete.
        return deleted
    except Exception:
        logger.exception("Failed to clear memory for user %s", user_id)
        raise
    finally:
        connection.close()


def build_gradio_history(user_id: str, limit: int = 500) -> List[dict]:
    """
    Return history formatted for gr.Chatbot with type='messages'.

    Produces a chronological list of ``{'role': ..., 'content': ...}`` dicts,
    one per stored message; the ``created_at`` column is dropped. Because
    get_last_messages is best-effort, a database error yields ``[]``.

    Args:
        user_id: Whose history to load.
        limit: Maximum number of most recent messages to include
            (default 500).
    """
    rows = get_last_messages(user_id, limit=limit)
    return [
        {"role": role, "content": content}
        for role, content, _created_at in rows
    ]