iT邦幫忙

2025 iThome 鐵人賽

DAY 18
0
AI & Data

來都來了,那就做一個GCP從0到100的AI助理系列 第 18

長期記憶與短期記憶:打造 AI 的智慧記憶系統 - 1

  • 分享至 

  • xImage
  •  

昨天我們學會了用 List 管理對話歷史,讓 AI 終於能記住剛才說過的話。這個簡單而有效的方法,已經讓我們的 AI 應用從「健忘機器人」進化成「有記憶的助手」。

# 昨天學會的基本方法
conversation_history = []
conversation_history.append({"role": "user", "content": "我想學Python"})
conversation_history.append({"role": "assistant", "content": "好的!Python是很棒的入門語言"})

但在實際使用中,我們很快就會遇到一些挑戰:

List 方法的瓶頸

  • 對話越來越長,記憶體消耗不斷增加
  • 重要資訊和閒聊內容混在一起
  • 程式重啟後,所有記憶都消失了
  • 無法區分「剛才說的話」和「重要的個人資訊」

想像一下這個場景:

# 第100輪對話後...
conversation_history = [
    {"role": "user", "content": "今天天氣真好"},
    {"role": "assistant", "content": "是啊,很適合出門"},
    # ... 98 輪各種閒聊 ...
    {"role": "user", "content": "我叫小明,今年25歲,是軟體工程師"},  # 重要資訊埋在中間
    {"role": "assistant", "content": "很高興認識您,小明!"},
    {"role": "user", "content": "你還記得我的職業嗎?"}  # AI 需要在100條記錄中找答案
]

當我們與 AI 對話時,最理想的狀態是什麼?就像與一位老朋友聊天一樣——他既記得剛才說的話,也記得你們多年來的點點滴滴。

這就需要 AI 擁有類似人類的記憶系統:短期記憶處理當下的對話,長期記憶儲存重要的歷史資訊。今天將在昨天 List 管理的基礎上,深入探討如何為 AI 對話系統設計智慧的記憶架構

理解記憶的運作原理

人類的記憶系統可以簡化為三個層次:

感官記憶 → 短期記憶 → 長期記憶 (幾秒鐘) (幾分鐘) (永久保存)

短期記憶 (Working Memory)

  • 容量有限:大約能記住 7±2 個項目
  • 時間短暫:通常只能維持 15-30 秒
  • 主要功能:處理當前任務的資訊

長期記憶 (Long-term Memory)

  • 容量幾乎無限
  • 可以永久保存
  • 分為不同類型:事實記憶、程序記憶、情節記憶等

對應AI 記憶系統的設計

# AI 記憶系統的核心概念
class AIMemorySystem:
    def __init__(self):
        # 短期記憶:當前對話上下文
        self.working_memory = []  # 最近幾輪對話
        
        # 長期記憶:重要的歷史資訊
        self.long_term_memory = {
            'user_profile': {},      # 用戶基本資訊
            'preferences': {},       # 偏好設定
            'important_facts': [],   # 重要事實
            'conversation_summaries': []  # 對話摘要
        }
        
        # 記憶管理器:決定什麼該記住、什麼該忘記
        self.memory_manager = MemoryManager()

短期記憶:專注當下的對話

什麼是短期記憶?

短期記憶負責處理當前對話的上下文,包括:

  • 最近 5-10 輪的對話
  • 當前話題的相關資訊
  • 正在進行的任務狀態

簡易實作短期記憶系統

from collections import deque
from datetime import datetime, timedelta

class ShortTermMemory:
    def __init__(self, max_turns=10, max_age_minutes=30):
        self.max_turns = max_turns  # 最多保留幾輪對話
        self.max_age = timedelta(minutes=max_age_minutes)  # 最長保留時間
        self.conversations = deque(maxlen=max_turns * 2)  # 使用雙端佇列
        self.current_topic = None
        self.task_context = {}
    
    def add_message(self, role, content, topic=None):
        """新增訊息到短期記憶"""
        message = {
            'role': role,
            'content': content,
            'timestamp': datetime.now(),
            'topic': topic or self.current_topic
        }
        
        self.conversations.append(message)
        
        # 更新當前話題
        if topic:
            self.current_topic = topic
        
        # 清理過期訊息
        self._cleanup_expired_messages()
    
    def _cleanup_expired_messages(self):
        """清理過期的訊息"""
        cutoff_time = datetime.now() - self.max_age
        
        # 從左側移除過期訊息
        while (self.conversations and 
               self.conversations[0]['timestamp'] < cutoff_time):
            self.conversations.popleft()
    
    def get_recent_context(self, turns=None):
        """取得最近的對話上下文"""
        if turns is None:
            turns = self.max_turns
        
        # 取得最近的對話,保持 user-assistant 配對
        recent = list(self.conversations)[-(turns * 2):]
        
        return [
            {
                "role": msg['role'],
                "parts": [{"text": msg['content']}]
            }
            for msg in recent
        ]
    
    def get_current_topic(self):
        """取得當前話題"""
        return self.current_topic
    
    def set_task_context(self, task_name, context):
        """設定任務上下文"""
        self.task_context[task_name] = {
            'context': context,
            'timestamp': datetime.now()
        }
    
    def get_task_context(self, task_name):
        """取得任務上下文"""
        return self.task_context.get(task_name, {}).get('context')
    
    def clear_expired_tasks(self, max_age_hours=2):
        """清理過期的任務上下文"""
        cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
        
        expired_tasks = [
            task for task, data in self.task_context.items()
            if data['timestamp'] < cutoff_time
        ]
        
        for task in expired_tasks:
            del self.task_context[task]

# 使用範例
stm = ShortTermMemory(max_turns=5, max_age_minutes=15)

# 模擬對話
stm.add_message("user", "我想學習 Python", topic="程式學習")
stm.add_message("assistant", "很好!Python 是很棒的入門語言")
stm.add_message("user", "應該從哪裡開始?")
stm.add_message("assistant", "建議從基本語法開始,然後練習小專案")

# 設定任務上下文
stm.set_task_context("learning_plan", {
    "subject": "Python",
    "level": "beginner",
    "goal": "web development"
})

print("當前話題:", stm.get_current_topic())
print("最近對話:", stm.get_recent_context(turns=2))

長期記憶:保存重要的歷史資訊

什麼是長期記憶?

長期記憶負責儲存持久且重要的資訊,包括:

  • 用戶的基本資料和偏好
  • 重要的對話片段和決定
  • 學習到的知識和模式
  • 長期的目標和計畫

實作長期記憶

import json
import sqlite3
from datetime import datetime
from typing import Dict, List, Any

class LongTermMemory:
    def __init__(self, user_id: str, db_path: str = "long_term_memory.db"):
        self.user_id = user_id
        self.db_path = db_path
        self.init_database()
        
        # 記憶體快取,提升存取速度
        self.cache = {
            'user_profile': None,
            'preferences': None,
            'last_cache_update': None
        }
    
    def init_database(self):
        """初始化長期記憶資料庫"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        # 用戶基本資料表
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS user_profiles (
                user_id TEXT PRIMARY KEY,
                profile_data TEXT,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        
        # 重要事實表
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS important_facts (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id TEXT,
                fact_type TEXT,
                content TEXT,
                confidence REAL DEFAULT 1.0,
                source TEXT,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        
        # 對話摘要表
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS conversation_summaries (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id TEXT,
                summary TEXT,
                topics TEXT,
                date_range TEXT,
                message_count INTEGER,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        
        # 學習進度表
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS learning_progress (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id TEXT,
                subject TEXT,
                skill_level TEXT,
                progress_data TEXT,
                last_activity DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        
        conn.commit()
        conn.close()
    
    def get_user_profile(self) -> Dict[str, Any]:
        """取得用戶基本資料"""
        # 檢查快取
        if (self.cache['user_profile'] and 
            self.cache['last_cache_update'] and
            (datetime.now() - self.cache['last_cache_update']).seconds < 300):  # 5分鐘快取
            return self.cache['user_profile']
        
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        cursor.execute('''
            SELECT profile_data FROM user_profiles WHERE user_id = ?
        ''', (self.user_id,))
        
        result = cursor.fetchone()
        conn.close()
        
        if result:
            profile = json.loads(result[0])
        else:
            profile = {
                'name': None,
                'age': None,
                'interests': [],
                'goals': [],
                'communication_style': 'normal'
            }
        
        # 更新快取
        self.cache['user_profile'] = profile
        self.cache['last_cache_update'] = datetime.now()
        
        return profile
    
    def update_user_profile(self, profile_updates: Dict[str, Any]):
        """更新用戶基本資料"""
        current_profile = self.get_user_profile()
        current_profile.update(profile_updates)
        
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        cursor.execute('''
            INSERT OR REPLACE INTO user_profiles 
            (user_id, profile_data, updated_at)
            VALUES (?, ?, CURRENT_TIMESTAMP)
        ''', (self.user_id, json.dumps(current_profile, ensure_ascii=False)))
        
        conn.commit()
        conn.close()
        
        # 更新快取
        self.cache['user_profile'] = current_profile
        self.cache['last_cache_update'] = datetime.now()
    
    def add_important_fact(self, fact_type: str, content: str, 
                          confidence: float = 1.0, source: str = "conversation"):
        """新增重要事實"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        cursor.execute('''
            INSERT INTO important_facts 
            (user_id, fact_type, content, confidence, source)
            VALUES (?, ?, ?, ?, ?)
        ''', (self.user_id, fact_type, content, confidence, source))
        
        conn.commit()
        conn.close()
    
    def get_important_facts(self, fact_type: str = None, 
                           min_confidence: float = 0.7) -> List[Dict]:
        """取得重要事實"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        if fact_type:
            cursor.execute('''
                SELECT fact_type, content, confidence, source, created_at
                FROM important_facts 
                WHERE user_id = ? AND fact_type = ? AND confidence >= ?
                ORDER BY confidence DESC, created_at DESC
            ''', (self.user_id, fact_type, min_confidence))
        else:
            cursor.execute('''
                SELECT fact_type, content, confidence, source, created_at
                FROM important_facts 
                WHERE user_id = ? AND confidence >= ?
                ORDER BY confidence DESC, created_at DESC
            ''', (self.user_id, min_confidence))
        
        results = cursor.fetchall()
        conn.close()
        
        return [
            {
                'type': row[0],
                'content': row[1],
                'confidence': row[2],
                'source': row[3],
                'created_at': row[4]
            }
            for row in results
        ]
    
    def save_conversation_summary(self, summary: str, topics: List[str], 
                                 date_range: str, message_count: int):
        """儲存對話摘要"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        cursor.execute('''
            INSERT INTO conversation_summaries 
            (user_id, summary, topics, date_range, message_count)
            VALUES (?, ?, ?, ?, ?)
        ''', (self.user_id, summary, json.dumps(topics), date_range, message_count))
        
        conn.commit()
        conn.close()
    
    def get_recent_summaries(self, limit: int = 5) -> List[Dict]:
        """取得最近的對話摘要"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        cursor.execute('''
            SELECT summary, topics, date_range, message_count, created_at
            FROM conversation_summaries 
            WHERE user_id = ?
            ORDER BY created_at DESC
            LIMIT ?
        ''', (self.user_id, limit))
        
        results = cursor.fetchall()
        conn.close()
        
        return [
            {
                'summary': row[0],
                'topics': json.loads(row[1]),
                'date_range': row[2],
                'message_count': row[3],
                'created_at': row[4]
            }
            for row in results
        ]
    
    def update_learning_progress(self, subject: str, skill_level: str, 
                               progress_data: Dict[str, Any]):
        """更新學習進度"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        cursor.execute('''
            INSERT OR REPLACE INTO learning_progress 
            (user_id, subject, skill_level, progress_data, last_activity)
            VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
        ''', (self.user_id, subject, skill_level, 
              json.dumps(progress_data, ensure_ascii=False)))
        
        conn.commit()
        conn.close()
    
    def get_learning_progress(self, subject: str = None) -> List[Dict]:
        """取得學習進度"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        if subject:
            cursor.execute('''
                SELECT subject, skill_level, progress_data, last_activity
                FROM learning_progress 
                WHERE user_id = ? AND subject = ?
            ''', (self.user_id, subject))
        else:
            cursor.execute('''
                SELECT subject, skill_level, progress_data, last_activity
                FROM learning_progress 
                WHERE user_id = ?
                ORDER BY last_activity DESC
            ''', (self.user_id,))
        
        results = cursor.fetchall()
        conn.close()
        
        return [
            {
                'subject': row[0],
                'skill_level': row[1],
                'progress_data': json.loads(row[2]),
                'last_activity': row[3]
            }
            for row in results
        ]

# 使用範例
ltm = LongTermMemory("alice")

# 更新用戶資料
ltm.update_user_profile({
    'name': 'Alice',
    'age': 28,
    'interests': ['programming', 'AI', 'reading'],
    'goals': ['learn Python', 'build AI chatbot']
})

# 新增重要事實
ltm.add_important_fact("preference", "喜歡簡潔的解釋", confidence=0.9)
ltm.add_important_fact("skill", "有基礎程式經驗", confidence=0.8)

# 儲存學習進度
ltm.update_learning_progress("Python", "beginner", {
    'completed_topics': ['variables', 'functions'],
    'current_topic': 'classes',
    'difficulty_areas': ['object-oriented programming']
})

print("用戶資料:", ltm.get_user_profile())
print("重要事實:", ltm.get_important_facts())
print("學習進度:", ltm.get_learning_progress("Python"))

上一篇
簡單的上下文管理:用 List 打造記憶型對話系統
系列文
來都來了,那就做一個GCP從0到100的AI助理18
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言