昨天我們學會了用 List 管理對話歷史,讓 AI 終於能記住剛才說過的話。這個簡單而有效的方法,已經讓我們的 AI 應用從「健忘機器人」進化成「有記憶的助手」。
# Baseline approach from the previous lesson: the history is a plain list of
# {"role", "content"} dicts, appended to in chronological order.
conversation_history = []
conversation_history.extend([
    {"role": "user", "content": "我想學Python"},
    {"role": "assistant", "content": "好的!Python是很棒的入門語言"},
])
但在實際使用中,我們很快就會遇到一些挑戰:
List 方法的瓶頸
想像一下這個場景:
# State of the history after roughly 100 turns of conversation.
conversation_history = [
    dict(role="user", content="今天天氣真好"),
    dict(role="assistant", content="是啊,很適合出門"),
    # ... 98 turns of assorted small talk omitted ...
    # The important facts below are buried in the middle of the history.
    dict(role="user", content="我叫小明,今年25歲,是軟體工程師"),
    dict(role="assistant", content="很高興認識您,小明!"),
    # To answer this, the AI has to search through ~100 records.
    dict(role="user", content="你還記得我的職業嗎?"),
]
當我們與 AI 對話時,最理想的狀態是什麼?就像與一位老朋友聊天一樣——他既記得剛才說的話,也記得你們多年來的點點滴滴。
這就需要 AI 擁有類似人類的記憶系統:短期記憶處理當下的對話,長期記憶儲存重要的歷史資訊。今天將在昨天 List 管理的基礎上,深入探討如何為 AI 對話系統設計智慧的記憶架構。
人類的記憶系統可以簡化為三個層次:
感官記憶 (幾秒鐘) → 短期記憶 (幾分鐘) → 長期記憶 (永久保存)
短期記憶 (Working Memory)
長期記憶 (Long-term Memory)
對應 AI 記憶系統的設計
# Core concept of the AI memory system
class AIMemorySystem:
    """Skeleton grouping a working-memory list, a long-term store and a manager."""

    def __init__(self):
        # Short-term tier: the current conversation context (the last few turns).
        self.working_memory = list()

        # Long-term tier: durable historical information, grouped by category.
        self.long_term_memory = dict(
            user_profile=dict(),            # basic user information
            preferences=dict(),             # preference settings
            important_facts=list(),         # important facts
            conversation_summaries=list(),  # conversation summaries
        )

        # Memory manager: meant to decide what to remember and what to forget.
        # NOTE(review): MemoryManager is not defined in this snippet — confirm
        # it is provided elsewhere before instantiating this class.
        self.memory_manager = MemoryManager()
短期記憶負責處理當前對話的上下文,包括最近幾輪對話、當前話題與任務上下文:
from collections import deque
from datetime import datetime, timedelta
class ShortTermMemory:
    """Bounded, time-limited buffer holding the current conversation context.

    Messages are kept in a ``deque`` capped at ``max_turns * 2`` entries (one
    user and one assistant message per turn) and are additionally dropped once
    they are older than ``max_age_minutes``. The class also tracks the current
    topic and a dictionary of named task contexts.
    """

    def __init__(self, max_turns=10, max_age_minutes=30):
        """Create an empty buffer.

        Args:
            max_turns: Maximum number of turns to keep; the deque holds twice
                this many messages.
            max_age_minutes: Messages older than this many minutes are purged.
        """
        self.max_turns = max_turns
        self.max_age = timedelta(minutes=max_age_minutes)
        # A deque with maxlen silently evicts the oldest message when full.
        self.conversations = deque(maxlen=max_turns * 2)
        self.current_topic = None
        self.task_context = {}

    def add_message(self, role, content, topic=None):
        """Append a message, update the current topic, and purge expired ones.

        Args:
            role: Speaker role stored with the message (e.g. "user").
            content: Message text.
            topic: Optional topic; when truthy it becomes the current topic.
                Otherwise the message inherits the current topic.
        """
        if topic:
            self.current_topic = topic

        self.conversations.append({
            'role': role,
            'content': content,
            'timestamp': datetime.now(),
            'topic': self.current_topic,
        })

        self._cleanup_expired_messages()

    def _cleanup_expired_messages(self):
        """Drop messages older than ``max_age`` from the left (oldest) end."""
        cutoff_time = datetime.now() - self.max_age
        # Messages are appended chronologically, so we can stop at the first
        # message that is still fresh.
        while (self.conversations
               and self.conversations[0]['timestamp'] < cutoff_time):
            self.conversations.popleft()

    def get_recent_context(self, turns=None):
        """Return the most recent messages as ``{"role", "parts"}`` dicts.

        Args:
            turns: Number of turns to return (two messages per turn). Defaults
                to ``max_turns``. Zero or a negative value yields an empty list.

        Returns:
            A list of dicts shaped ``{"role": ..., "parts": [{"text": ...}]}``,
            oldest first.
        """
        if turns is None:
            turns = self.max_turns
        # Guard explicitly: ``seq[-0:]`` is the whole sequence, so without this
        # check ``turns=0`` would return every message instead of none.
        if turns <= 0:
            return []

        # Enforce the age limit on reads too, so context that expired while the
        # conversation was idle is not handed back to the caller.
        self._cleanup_expired_messages()

        # Two messages per turn; assumes user/assistant messages alternate.
        recent = list(self.conversations)[-(turns * 2):]
        return [
            {
                "role": msg['role'],
                "parts": [{"text": msg['content']}],
            }
            for msg in recent
        ]

    def get_current_topic(self):
        """Return the current topic, or ``None`` if none has been set."""
        return self.current_topic

    def set_task_context(self, task_name, context):
        """Store ``context`` under ``task_name`` together with the current time."""
        self.task_context[task_name] = {
            'context': context,
            'timestamp': datetime.now(),
        }

    def get_task_context(self, task_name):
        """Return the context stored for ``task_name``, or ``None`` if absent."""
        return self.task_context.get(task_name, {}).get('context')

    def clear_expired_tasks(self, max_age_hours=2):
        """Delete task contexts that were set more than ``max_age_hours`` ago."""
        cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
        # Collect the keys first; deleting while iterating the dict would raise.
        expired_tasks = [
            task for task, data in self.task_context.items()
            if data['timestamp'] < cutoff_time
        ]
        for task in expired_tasks:
            del self.task_context[task]
# Usage example: a 5-turn buffer whose messages expire after 15 minutes.
stm = ShortTermMemory(max_turns=5, max_age_minutes=15)

# Simulate a conversation; only the first message names the topic, the rest
# inherit it.
stm.add_message(role="user", content="我想學習 Python", topic="程式學習")
stm.add_message(role="assistant", content="很好!Python 是很棒的入門語言")
stm.add_message(role="user", content="應該從哪裡開始?")
stm.add_message(role="assistant", content="建議從基本語法開始,然後練習小專案")

# Attach a task context describing the learning plan.
stm.set_task_context(
    task_name="learning_plan",
    context=dict(subject="Python", level="beginner", goal="web development"),
)

print("當前話題:", stm.get_current_topic())
print("最近對話:", stm.get_recent_context(turns=2))
長期記憶負責儲存持久且重要的資訊,包括用戶基本資料、重要事實、對話摘要與學習進度:
import json
import sqlite3
from datetime import datetime
from typing import Dict, List, Any
class LongTermMemory:
    """SQLite-backed long-term memory for a single user.

    Persists four kinds of records, all keyed by ``user_id``: a JSON user
    profile, important facts, conversation summaries and learning progress.
    The user profile is additionally cached in memory for a short time.
    """

    # How long (in seconds) a cached profile is served before the database is
    # read again.
    _CACHE_TTL_SECONDS = 300

    # Table definitions, created on first use.
    _SCHEMA = (
        # User profile table
        '''
        CREATE TABLE IF NOT EXISTS user_profiles (
            user_id TEXT PRIMARY KEY,
            profile_data TEXT,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        ''',
        # Important facts table
        '''
        CREATE TABLE IF NOT EXISTS important_facts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id TEXT,
            fact_type TEXT,
            content TEXT,
            confidence REAL DEFAULT 1.0,
            source TEXT,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        ''',
        # Conversation summaries table
        '''
        CREATE TABLE IF NOT EXISTS conversation_summaries (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id TEXT,
            summary TEXT,
            topics TEXT,
            date_range TEXT,
            message_count INTEGER,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        ''',
        # Learning progress table
        '''
        CREATE TABLE IF NOT EXISTS learning_progress (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id TEXT,
            subject TEXT,
            skill_level TEXT,
            progress_data TEXT,
            last_activity DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        ''',
    )

    def __init__(self, user_id: str, db_path: str = "long_term_memory.db"):
        """Open (and if needed create) the database for ``user_id``.

        Args:
            user_id: Identifier used to scope every query.
            db_path: Path of the SQLite database file.
        """
        self.user_id = user_id
        self.db_path = db_path
        # In-memory cache to speed up repeated profile reads.
        self.cache = {
            'user_profile': None,
            'preferences': None,
            'last_cache_update': None,
        }
        self.init_database()

    def _fetch_all(self, sql: str, params: tuple = ()) -> list:
        """Run a read query and return all rows; the connection is always closed."""
        conn = sqlite3.connect(self.db_path)
        try:
            return conn.execute(sql, params).fetchall()
        finally:
            conn.close()

    def _write(self, sql: str, params: tuple = (),
               insert_sql: str = None, insert_params: tuple = ()) -> None:
        """Run a write statement in one transaction.

        When ``insert_sql`` is given and the first statement affected no rows,
        ``insert_sql`` is executed as well (update-or-insert). The transaction
        is committed on success and rolled back on error; the connection is
        closed either way.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commits on success, rolls back on exception
                cursor = conn.execute(sql, params)
                if insert_sql is not None and cursor.rowcount == 0:
                    conn.execute(insert_sql, insert_params)
        finally:
            conn.close()

    def init_database(self):
        """Create the long-term memory tables if they do not exist yet."""
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                for ddl in self._SCHEMA:
                    conn.execute(ddl)
        finally:
            conn.close()

    def get_user_profile(self) -> Dict[str, Any]:
        """Return the user's profile, served from cache while it is fresh.

        Returns:
            The stored profile dict, or a default profile (``name``, ``age``,
            ``interests``, ``goals``, ``communication_style``) when the user
            has no stored profile.
        """
        cached = self.cache['user_profile']
        stamp = self.cache['last_cache_update']
        if cached and stamp:
            # total_seconds() is required here: timedelta.seconds ignores whole
            # days, which would make a cache 1 day + 10 s old look fresh.
            age = (datetime.now() - stamp).total_seconds()
            if 0 <= age < self._CACHE_TTL_SECONDS:
                return cached

        rows = self._fetch_all(
            'SELECT profile_data FROM user_profiles WHERE user_id = ?',
            (self.user_id,),
        )
        if rows:
            profile = json.loads(rows[0][0])
        else:
            profile = {
                'name': None,
                'age': None,
                'interests': [],
                'goals': [],
                'communication_style': 'normal',
            }

        self.cache['user_profile'] = profile
        self.cache['last_cache_update'] = datetime.now()
        return profile

    def update_user_profile(self, profile_updates: Dict[str, Any]):
        """Merge ``profile_updates`` into the stored profile and persist it.

        The existing row is updated in place so that ``created_at`` keeps its
        original value; a row is inserted only when none exists yet. The cache
        is refreshed only after the database write succeeded.
        """
        # Merge into a copy so a failed write cannot leave the cache modified.
        merged = dict(self.get_user_profile())
        merged.update(profile_updates)
        payload = json.dumps(merged, ensure_ascii=False)

        self._write(
            '''
            UPDATE user_profiles
            SET profile_data = ?, updated_at = CURRENT_TIMESTAMP
            WHERE user_id = ?
            ''',
            (payload, self.user_id),
            insert_sql='''
            INSERT INTO user_profiles (user_id, profile_data)
            VALUES (?, ?)
            ''',
            insert_params=(self.user_id, payload),
        )

        self.cache['user_profile'] = merged
        self.cache['last_cache_update'] = datetime.now()

    def add_important_fact(self, fact_type: str, content: str,
                           confidence: float = 1.0, source: str = "conversation"):
        """Insert one important fact for this user.

        Args:
            fact_type: Category label stored with the fact.
            content: The fact text.
            confidence: Confidence score stored with the fact.
            source: Where the fact came from.
        """
        self._write(
            '''
            INSERT INTO important_facts
            (user_id, fact_type, content, confidence, source)
            VALUES (?, ?, ?, ?, ?)
            ''',
            (self.user_id, fact_type, content, confidence, source),
        )

    def get_important_facts(self, fact_type: str = None,
                            min_confidence: float = 0.7) -> List[Dict]:
        """Return this user's facts with ``confidence >= min_confidence``.

        Args:
            fact_type: When truthy, only facts of this type are returned.
            min_confidence: Lower bound (inclusive) on confidence.

        Returns:
            Dicts with keys ``type``, ``content``, ``confidence``, ``source``
            and ``created_at``, ordered by confidence, then newest first.
        """
        sql = '''
            SELECT fact_type, content, confidence, source, created_at
            FROM important_facts
            WHERE user_id = ? AND confidence >= ?
        '''
        params = [self.user_id, min_confidence]
        if fact_type:
            sql += ' AND fact_type = ?'
            params.append(fact_type)
        # created_at has one-second resolution; id breaks ties deterministically.
        sql += ' ORDER BY confidence DESC, created_at DESC, id DESC'

        return [
            {
                'type': row[0],
                'content': row[1],
                'confidence': row[2],
                'source': row[3],
                'created_at': row[4],
            }
            for row in self._fetch_all(sql, tuple(params))
        ]

    def save_conversation_summary(self, summary: str, topics: List[str],
                                  date_range: str, message_count: int):
        """Store one conversation summary; ``topics`` is saved as JSON text."""
        self._write(
            '''
            INSERT INTO conversation_summaries
            (user_id, summary, topics, date_range, message_count)
            VALUES (?, ?, ?, ?, ?)
            ''',
            (self.user_id, summary, json.dumps(topics, ensure_ascii=False),
             date_range, message_count),
        )

    def get_recent_summaries(self, limit: int = 5) -> List[Dict]:
        """Return up to ``limit`` conversation summaries, newest first.

        Returns:
            Dicts with keys ``summary``, ``topics`` (decoded list),
            ``date_range``, ``message_count`` and ``created_at``.
        """
        rows = self._fetch_all(
            '''
            SELECT summary, topics, date_range, message_count, created_at
            FROM conversation_summaries
            WHERE user_id = ?
            ORDER BY created_at DESC, id DESC
            LIMIT ?
            ''',
            (self.user_id, limit),
        )
        return [
            {
                'summary': row[0],
                'topics': json.loads(row[1]),
                'date_range': row[2],
                'message_count': row[3],
                'created_at': row[4],
            }
            for row in rows
        ]

    def update_learning_progress(self, subject: str, skill_level: str,
                                 progress_data: Dict[str, Any]):
        """Create or overwrite the learning progress for ``subject``.

        The table has no unique key on (user_id, subject), so an
        ``INSERT OR REPLACE`` would add a new row on every call. Instead the
        existing row(s) for the subject are updated, and a row is inserted
        only when none exists.
        """
        payload = json.dumps(progress_data, ensure_ascii=False)
        self._write(
            '''
            UPDATE learning_progress
            SET skill_level = ?, progress_data = ?,
                last_activity = CURRENT_TIMESTAMP
            WHERE user_id = ? AND subject = ?
            ''',
            (skill_level, payload, self.user_id, subject),
            insert_sql='''
            INSERT INTO learning_progress
            (user_id, subject, skill_level, progress_data)
            VALUES (?, ?, ?, ?)
            ''',
            insert_params=(self.user_id, subject, skill_level, payload),
        )

    def get_learning_progress(self, subject: str = None) -> List[Dict]:
        """Return learning progress rows, most recently active first.

        Args:
            subject: When truthy, only rows for this subject are returned.

        Returns:
            Dicts with keys ``subject``, ``skill_level``, ``progress_data``
            (decoded dict) and ``last_activity``.
        """
        sql = '''
            SELECT subject, skill_level, progress_data, last_activity
            FROM learning_progress
            WHERE user_id = ?
        '''
        params = [self.user_id]
        if subject:
            sql += ' AND subject = ?'
            params.append(subject)
        sql += ' ORDER BY last_activity DESC, id DESC'

        return [
            {
                'subject': row[0],
                'skill_level': row[1],
                'progress_data': json.loads(row[2]),
                'last_activity': row[3],
            }
            for row in self._fetch_all(sql, tuple(params))
        ]
# Usage example: long-term memory for the user "alice" (default database file).
ltm = LongTermMemory(user_id="alice")

# Update the user's profile.
ltm.update_user_profile(profile_updates=dict(
    name='Alice',
    age=28,
    interests=['programming', 'AI', 'reading'],
    goals=['learn Python', 'build AI chatbot'],
))

# Record a couple of important facts with their confidence scores.
ltm.add_important_fact(fact_type="preference", content="喜歡簡潔的解釋", confidence=0.9)
ltm.add_important_fact(fact_type="skill", content="有基礎程式經驗", confidence=0.8)

# Store the learning progress for Python.
ltm.update_learning_progress(
    subject="Python",
    skill_level="beginner",
    progress_data=dict(
        completed_topics=['variables', 'functions'],
        current_topic='classes',
        difficulty_areas=['object-oriented programming'],
    ),
)

print("用戶資料:", ltm.get_user_profile())
print("重要事實:", ltm.get_important_facts())
print("學習進度:", ltm.get_learning_progress("Python"))