2025 iThome Ironman Challenge

DAY 20
Generous Sharing: Self-Learning Skills for IT People

Learning LLM series, part 20

Day20: Building a CLI Version


1. Save the script as run_faq.py

%%bash
cat > run_faq.py <<'PY'
#!/usr/bin/env python3
# run_faq.py -- a simple CLI FAQ (RAG-style) program
# Usage:
#   interactive mode: python run_faq.py
#   single query:     python run_faq.py --query "如何退貨?" --retriever chroma --gen-mode echo
#   argument help:    python run_faq.py --help


import os
import sys
import json
import time
import argparse
from typing import List, Dict


# ---- Load required packages ----
try:
    import pandas as pd
    import numpy as np
    from sentence_transformers import SentenceTransformer
except ImportError:
    print("Please install the required packages first: pip install sentence-transformers pandas numpy")
    raise


# chroma / faiss are optional dependencies; if one is missing, we fall back
try:
    import chromadb
    CHROMA_AVAILABLE = True
except ImportError:
    CHROMA_AVAILABLE = False


try:
    import faiss
    FAISS_AVAILABLE = True
except ImportError:
    FAISS_AVAILABLE = False


# optional: OpenAI for hosted generation
try:
    import openai
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False


# optional: transformers for local generation
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False


# -------------------------
# Helper: prepare FAQ df
# -------------------------
def load_or_create_faqs(path="faqs.csv"):
    if os.path.exists(path):
        df = pd.read_csv(path, encoding="utf-8-sig")
    else:
        data = [
            {"id":"q1","question":"如何申請退貨?","answer":"請於訂單頁點選退貨申請並上傳商品照片,客服將於 3 個工作天內處理。"},
            {"id":"q2","question":"運費如何計算?","answer":"單筆訂單滿 1000 元享免運,未滿則收取 60 元運費。"},
            {"id":"q3","question":"可以更改收件地址嗎?","answer":"若訂單尚未出貨,您可在會員中心修改收件地址。"},
            {"id":"q4","question":"付款方式有哪些?","answer":"我們支援信用卡、LINE Pay 與貨到付款。"},
            {"id":"q5","question":"商品多久可以到貨?","answer":"一般商品 3–5 個工作天內送達,偏遠地區約 7 天。"},
            {"id":"q6","question":"如何查詢訂單狀態?","answer":"請至會員中心 → 訂單查詢頁面,即可查看目前狀態。"},
            {"id":"q7","question":"發票會如何提供?","answer":"電子發票將寄送至您填寫的 Email,也可於會員中心下載。"},
            {"id":"q8","question":"商品有瑕疵怎麼辦?","answer":"請拍照後至客服中心填寫表單,我們將盡快處理換貨或退款。"},
            {"id":"q9","question":"有提供客服聯絡方式嗎?","answer":"您可透過線上客服或來電 0800-123-456 與我們聯繫。"},
            {"id":"q10","question":"如何使用優惠券?","answer":"在結帳頁面輸入優惠碼,系統會自動折抵。"}
        ]
        df = pd.DataFrame(data)
        df.to_csv(path, index=False, encoding="utf-8-sig")
    return df


# -------------------------
# Embedding model
# -------------------------
def init_embedder(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
    print("載入 embedder:", model_name)
    return SentenceTransformer(model_name)
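

# A quick sanity check (a sketch, not executed by this script): the embedder
# returns numpy vectors, so two questions can be compared by cosine similarity:
#   emb = init_embedder()
#   a, b = emb.encode(["如何申請退貨?", "怎麼退貨?"], convert_to_numpy=True)
#   sim = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
#   print(sim)  # semantically close questions should score near 1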


# -------------------------
# Chroma retriever (if available)
# -------------------------
def init_chroma(collection_name, persist_path, df, embeddings, embedder):
    from chromadb import PersistentClient


    # PersistentClient replaces the old chromadb.Client + persist() pattern
    client = PersistentClient(path=persist_path)
    collection = client.get_or_create_collection(name=collection_name)


    # prepare the records: questions as documents, question/answer as metadata
    ids_list = df["id"].astype(str).tolist()
    documents = df["question"].astype(str).tolist()
    metadatas = df[["question", "answer"]].to_dict(orient="records")
    emb_list = embeddings.tolist()


    # upsert so that re-running the script does not create duplicates
    collection.upsert(
        ids=ids_list,
        documents=documents,
        metadatas=metadatas,
        embeddings=emb_list
    )


    return collection




def retrieve_chroma(collection, embedder, query, k=3):
    q_emb = embedder.encode([query], convert_to_numpy=True).tolist()
    res = collection.query(query_embeddings=q_emb, n_results=k, include=["documents","metadatas","distances"])
    docs = []
    # Chroma always returns ids alongside the included fields; the metadatas
    # carry question/answer. Note the score is a DISTANCE: lower is closer.
    for doc_id, doc, meta, dist in zip(res["ids"][0], res.get("documents",[[]])[0], res.get("metadatas",[[]])[0], res.get("distances",[[]])[0]):
        docs.append({
            "id": doc_id,
            "question": meta.get("question", doc),
            "answer": meta.get("answer", ""),
            "score": float(dist) if dist is not None else None
        })
    return docs


# -------------------------
# FAISS retriever (if available)
# -------------------------
def build_faiss_index(embeddings):
    if not FAISS_AVAILABLE:
        raise RuntimeError("FAISS is not installed. Run: pip install faiss-cpu")
    emb = embeddings.astype("float32")
    d = emb.shape[1]
    # L2-normalize so that inner-product search equals cosine similarity
    faiss.normalize_L2(emb)
    index = faiss.IndexFlatIP(d)
    index.add(emb)
    return index


def retrieve_faiss(index, embedder, df, ids, query, k=3):
    q_emb = embedder.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, k)
    results = []
    for score, idx in zip(D[0], I[0]):
        idx = int(idx)
        if idx < 0:  # FAISS pads with -1 when k exceeds the index size
            continue
        results.append({
            "id": ids[idx],
            "question": df.iloc[idx]["question"],
            "answer": df.iloc[idx]["answer"],
            "score": float(score)  # cosine similarity: higher is better
        })
    return results


# -------------------------
# Generator (openai / local / echo)
# -------------------------
def generate_echo(docs, query):
    # Simplest "generation": return the most relevant answer (the top hit)
    # and tag it with the source ids
    if not docs:
        return "查無資料", []
    top = docs[0]
    srcs = [d["id"] for d in docs if d.get("id")]
    answer = f"{top['answer']}\n\n[來源: {','.join(srcs)}]"
    return answer, srcs


def generate_openai(prompt, model="gpt-4o-mini", temperature=0.0, max_tokens=300):
    if not OPENAI_AVAILABLE:
        raise RuntimeError("The openai package is not installed")
    if not os.getenv("OPENAI_API_KEY"):
        raise RuntimeError("Please set OPENAI_API_KEY as an environment variable first")
    # openai>=1.0 client API (the legacy openai.ChatCompletion was removed);
    # the client reads OPENAI_API_KEY from the environment automatically
    client = openai.OpenAI()
    resp = client.chat.completions.create(
        model=model,
        messages=[{"role":"system","content":"你是中文客服助理;只使用提供的參考資料回答;若資料不足回答「查無資料」。"},
                  {"role":"user","content":prompt}],
        temperature=temperature,
        max_tokens=max_tokens
    )
    return resp.choices[0].message.content.strip()


def generate_local(prompt, local_model_name="uer/gpt2-chinese-cluecorpussmall", max_new_tokens=200, temperature=0.0):
    if not TRANSFORMERS_AVAILABLE:
        raise RuntimeError("transformers is not installed in this environment")
    import torch
    # Note: the model is reloaded on every call; cache it if you call this often
    tokenizer = AutoTokenizer.from_pretrained(local_model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(local_model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # only pass sampling arguments when we actually sample, to avoid
    # greedy-decoding warnings from newer transformers versions
    gen_kwargs = dict(max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id)
    if temperature > 0.0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    out = model.generate(**inputs, **gen_kwargs)
    # decode only the newly generated tokens
    gen = tokenizer.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    return gen.strip()


# -------------------------
# build prompt 
# -------------------------
def build_prompt(query, docs, max_context_items=5):
    # Concatenate the retrieved docs (already ranked) into a context block,
    # then wrap it with an instruction that restricts the model to that
    # context and asks for a trailing source-id line
    pieces = []
    for d in docs[:max_context_items]:
        pieces.append(f"[{d.get('id','')}] Q: {d.get('question','')}\nA: {d.get('answer','')}")
    context = "\n\n".join(pieces)
    prompt = f"以下為系統檢索到的參考資料,請僅根據這些資料回答使用者問題(若資料不足請回答「查無資料」):\n\n{context}\n\n使用者問題:{query}\n\n請用繁體中文精簡回答,並在最後一行標註來源 ID(格式 [來源: id1,id2])。"
    return prompt
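

# For illustration, with a single retrieved doc the rendered prompt looks like:
#   以下為系統檢索到的參考資料,請僅根據這些資料回答使用者問題(若資料不足請回答「查無資料」):
#
#   [q1] Q: 如何申請退貨?
#   A: 請於訂單頁點選退貨申請並上傳商品照片,客服將於 3 個工作天內處理。
#
#   使用者問題:如何退貨?
#
#   請用繁體中文精簡回答,並在最後一行標註來源 ID(格式 [來源: id1,id2])。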


# -------------------------
# RAG pipeline function
# -------------------------
def rag_answer(query, retriever_type, embedder, df, collection=None, faiss_index=None, ids=None, gen_mode="echo", openai_model="gpt-4o-mini", local_model_name="uer/gpt2-chinese-cluecorpussmall", k=3, score_threshold=None):
    # retrieve
    if retriever_type == "chroma":
        if collection is None:
            raise RuntimeError("Chroma collection is not initialized")
        docs = retrieve_chroma(collection, embedder, query, k=k)
    elif retriever_type == "faiss":
        if faiss_index is None or ids is None:
            raise RuntimeError("FAISS index or ids not provided")
        docs = retrieve_faiss(faiss_index, embedder, df, ids, query, k=k)
    else:
        raise ValueError("retriever_type must be chroma or faiss")
    # optional threshold check. Careful with score semantics: FAISS returns
    # cosine similarity (higher is better) while Chroma returns a distance
    # (lower is better), so the comparison direction depends on the retriever
    if score_threshold is not None and docs:
        top_score = docs[0].get("score")
        if top_score is not None:
            too_far = (top_score > score_threshold) if retriever_type == "chroma" else (top_score < score_threshold)
            if too_far:
                return {"answer":"查無資料", "sources": [], "retrieved": docs}
    # build prompt
    prompt = build_prompt(query, docs)
    # generate
    if gen_mode == "echo":
        ans, srcs = generate_echo(docs, query)
    elif gen_mode == "openai":
        ans = generate_openai(prompt, model=openai_model, temperature=0.0)
        # attribute sources heuristically: everything that was retrieved
        srcs = [d["id"] for d in docs]
    elif gen_mode == "local":
        ans = generate_local(prompt, local_model_name=local_model_name)
        srcs = [d["id"] for d in docs]
    else:
        raise ValueError("gen_mode must be echo / openai / local")
    return {"answer": ans, "sources": srcs, "retrieved": docs, "prompt": prompt}


# -------------------------
# CLI: main
# -------------------------
def main():
    parser = argparse.ArgumentParser(description="run_faq.py - CLI FAQ (simple RAG)")
    parser.add_argument("--retriever", choices=["chroma","faiss"], default="chroma", help="選擇檢索器 (chroma or faiss)")
    parser.add_argument("--gen-mode", choices=["echo","openai","local"], default="echo", help="生成模式: echo(回傳 top1 答案), openai, local")
    parser.add_argument("--query", type=str, default=None, help="直接在 CLI 傳入問題並輸出結果")
    parser.add_argument("--k", type=int, default=3, help="檢索 top-k")
    parser.add_argument("--score-threshold", type=float, default=None, help="相似度閾值 (若 top1 < 閾值 則回 查無資料)")
    parser.add_argument("--openai-model", type=str, default="gpt-4o-mini", help="OpenAI 模型名稱 (若使用 openai)")
    parser.add_argument("--local-model", type=str, default="uer/gpt2-chinese-cluecorpussmall", help="本地 transformer 模型名稱")
    parser.add_argument("--show-docs", action="store_true", help="是否顯示檢索到的 docs")
    parser.add_argument("--persist-path", type=str, default="./chroma_db", help="Chroma persist path")
    args = parser.parse_args()


    # load FAQs and the embedding model
    df = load_or_create_faqs("faqs.csv")
    embedder = init_embedder()


    # load cached embeddings if present, otherwise generate and cache them.
    # Note: delete the .npy file if faqs.csv changes, or the cache goes stale
    emb_path = "faq_question_embeddings.npy"
    if os.path.exists(emb_path):
        embeddings = np.load(emb_path)
    else:
        embeddings = embedder.encode(df["question"].tolist(), convert_to_numpy=True).astype("float32")
        np.save(emb_path, embeddings)


    # prepare retriever
    collection = None
    faiss_index = None
    ids = df["id"].astype(str).tolist()


    if args.retriever == "chroma":
        if not CHROMA_AVAILABLE:
            print("Chroma 未安裝,請改用 --retriever faiss 或安裝 chromadb")
            sys.exit(1)
        collection = init_chroma(collection_name="faq_collection", persist_path=args.persist_path, df=df, embeddings=embeddings, embedder=embedder)
    else:
        if not FAISS_AVAILABLE:
            print("FAISS 未安裝,請 pip install faiss-cpu 或改用 --retriever chroma")
            sys.exit(1)
        faiss_index = build_faiss_index(embeddings)


    # single query mode
    if args.query:
        res = rag_answer(args.query, retriever_type=args.retriever, embedder=embedder, df=df, collection=collection, faiss_index=faiss_index, ids=ids, gen_mode=args.gen_mode, openai_model=args.openai_model, local_model_name=args.local_model, k=args.k, score_threshold=args.score_threshold)
        print("=== Answer ===")
        print(res["answer"])
        if args.show_docs:
            print("\n=== Retrieved ===")
            for r in res["retrieved"]:
                print(r)
        # log to file
        log_entry = {"time": time.time(), "query": args.query, "answer": res["answer"], "sources": res.get("sources", []), "retrieved": res.get("retrieved", [])}
        with open("faq_logs.jsonl","a",encoding="utf-8") as f:
            f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
        sys.exit(0)


    # interactive loop
    print("Interactive mode started (Ctrl+C or type exit to quit)")
    while True:
        try:
            q = input("\nYour question: ").strip()
            if q.lower() in ("exit","quit","q"):
                print("Bye.")
                break
            if not q:
                continue
            res = rag_answer(q, retriever_type=args.retriever, embedder=embedder, df=df, collection=collection, faiss_index=faiss_index, ids=ids, gen_mode=args.gen_mode, openai_model=args.openai_model, local_model_name=args.local_model, k=args.k, score_threshold=args.score_threshold)
            print("\n>>> Answer:\n")
            print(res["answer"])
            if args.show_docs:
                print("\n>>> Retrieved docs:")
                for r in res["retrieved"]:
                    print(r)
            # append the interaction to a JSON Lines log
            log_entry = {"time": time.time(), "query": q, "answer": res["answer"], "sources": res.get("sources", []), "retrieved": res.get("retrieved", [])}
            with open("faq_logs.jsonl","a",encoding="utf-8") as f:
                f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
        except KeyboardInterrupt:
            print("Interrupted by user; exiting.")
            break
        except Exception as e:
            print("Error:", e)
            import traceback; traceback.print_exc()
            # keep the loop alive after an error


if __name__ == "__main__":
    main()
PY
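
The script also supports a one-shot query mode. For example (the flags come straight from the argparse definitions above; the commands assume the corresponding retriever backend, e.g. faiss-cpu, is installed):

!python run_faq.py --query "如何退貨?" --retriever faiss --gen-mode echo --show-docs
!python run_faq.py --query "運費如何計算?" --retriever faiss --k 5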

2. Launch interactive mode

!python run_faq.py

(Screenshot of an interactive session: https://ithelp.ithome.com.tw/upload/images/20251004/20169173tPllGVCVpH.png)
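
Every query is also appended to faq_logs.jsonl. Here is a minimal sketch for reviewing past interactions (just one way to read a JSON Lines file; the field names match the log_entry dict in the script):

import json

with open("faq_logs.jsonl", encoding="utf-8") as f:
    for line in f:
        entry = json.loads(line)
        # print each question and the first line of its answer
        print(entry["query"], "->", entry["answer"].splitlines()[0])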

