昨天只是先畫了流程圖,今天終於要動手把程式寫出來了。這次的程式算是我AI問答機器人的第一個雛型,功能很單純,就是先做摘要,再做翻譯。
不過老實說,寫起來比我想像中的還要麻煩...程式碼比之前長很多,中間也錯了相當多次,不過最後總算能順利跑起來!
下面我先把完整程式碼貼出來:
!pip -q install transformers==4.44.2 accelerate huggingface_hub regex
import os, re, torch, regex as re2
from google.colab import userdata
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import StoppingCriteria, StoppingCriteriaList
#Login via Colab Secrets
hf_token = userdata.get('HF_TOKEN')
if not hf_token:
raise RuntimeError("找不到 HF_TOKEN,請到左側 Secret 新增同名密鑰。")
login(token=hf_token)
MODEL_ID = "google/gemma-3-1b-it"
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", torch_dtype="auto", trust_remote_code=True
)
def _apply_chat(messages):
"""用chat template產生input_ids與attention_mask。"""
input_ids = tok.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
attention_mask = torch.ones_like(input_ids, device=model.device) # 明確給 mask
return dict(input_ids=input_ids, attention_mask=attention_mask)
def _strip_roles_and_blank(text):
lines = [ln.strip() for ln in text.splitlines()]
lines = [ln for ln in lines if ln and not ln.lower().startswith(("user", "system", "assistant"))]
return lines
class ThreeSentenceStop(StoppingCriteria):
"""偵測新生成文本中句尾符號 . ? ! 的數量,達到3句就停止。"""
def __init__(self, tokenizer, prompt_len_ids: int):
super().__init__()
self.tok = tokenizer
self.prompt_len_ids = prompt_len_ids
def __call__(self, input_ids, scores, **kwargs):
gen_ids = input_ids[0, self.prompt_len_ids:]
text = self.tok.decode(gen_ids, skip_special_tokens=True)
end_marks = sum(text.count(x) for x in [".", "?", "!"])
return end_marks >= 3
def _generate_stop3(messages, max_new_tokens=256):
inputs = _apply_chat(messages)
stop_list = StoppingCriteriaList([ThreeSentenceStop(tok, inputs["input_ids"].shape[-1])])
with torch.no_grad():
out = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
eos_token_id=tok.eos_token_id,
pad_token_id=tok.eos_token_id,
do_sample=False,
no_repeat_ngram_size=3,
repetition_penalty=1.05,
stopping_criteria=stop_list,
)
new_tokens = out[0, inputs["input_ids"].shape[-1]:]
return tok.decode(new_tokens, skip_special_tokens=True)
def _generate_plain(messages, max_new_tokens=256, do_sample=False, temperature=0.2, top_p=0.9):
inputs = _apply_chat(messages)
gen_kwargs = dict(
max_new_tokens=max_new_tokens,
eos_token_id=tok.eos_token_id,
pad_token_id=tok.eos_token_id,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
)
with torch.no_grad():
out = model.generate(**inputs, **gen_kwargs)
new_tokens = out[0, inputs["input_ids"].shape[-1]:]
return tok.decode(new_tokens, skip_special_tokens=True)
#Cleaning for summary
_word_re = re.compile(r"[A-Za-z]+(?:'[A-Za-z]+)?")
def _is_english_sentence(line, max_words=25):
# 僅允許ASCII英文和常見標點
if re.search(r"[^\x09\x0A\x0D\x20-\x7E]", line): #非ASCII
return False
if re.search(r"[*_#>|`]{2,}", line): #避免Markdown噪音
return False
words = _word_re.findall(line)
return 1 <= len(words) <= max_words
def clean_summary_3_lines(text):
lines = _strip_roles_and_blank(text)
# Step1:先挑出乾淨英文句
cleaned = []
for ln in lines:
ln2 = ln.strip()
ln2 = re.sub(r"[\*`_#>|]+.*$", "", ln2) #去掉尾端奇怪符號
if _is_english_sentence(ln2):
if not re.search(r"[.!?]$", ln2):
ln2 += "."
cleaned.append(ln2)
if len(cleaned) == 3:
break
# Step2:保底補齊到三句(從原始行挑、必要時補句號)
if len(cleaned) < 3:
for ln in lines:
if ln not in cleaned:
cand = ln.strip()
if not cand.endswith((".", "?", "!")):
cand += "."
cleaned.append(cand)
if len(cleaned) == 3:
break
return "\n".join(cleaned[:3])
#Cleaning for Chinese translation
_cjk_re = re2.compile(r"[\p{Script=Han}。,、!?;:「」『』()《》〈〉—…¥·.-﹔﹖﹗﹙﹚]")
def clean_zh_translation(raw_text):
lines = _strip_roles_and_blank(raw_text)
kept = []
for ln in lines:
cjk = len(_cjk_re.findall(ln))
ratio = cjk / max(len(ln), 1)
if ratio >= 0.35:
kept.append(ln.strip())
#若出現大量連續符號,提前停止收集
if re.search(r"(\*{3,}|_{3,}|`{3,}|\.{6,})", ln):
break
text = "\n".join(kept)
#收尾清理
text = re.sub(r"(\*{2,}|_{2,}|`{2,})+", "", text)
text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE)
text = re.sub(r"\n{3,}", "\n\n", text).strip()
return text
def gen_summary_3sent(text):
msgs = [{
"role": "user",
"content": (
"You are a concise summarization assistant.\n"
"Given the TEXT, produce EXACTLY three English sentences summarizing it.\n"
"Rules:\n"
"1) Three lines, one sentence per line.\n"
"2) Each sentence ≤ 25 words.\n"
"3) No headings, no extra text.\n\n"
f"TEXT:\n{text}\n\n"
"OUTPUT (three lines only):"
)
}]
raw = _generate_stop3(msgs, max_new_tokens=140)
return clean_summary_3_lines(raw)
def translate_to_zh(text):
msgs = [
{"role": "system", "content": "你是嚴謹的翻譯助手。只輸出譯文,不要任何解釋或前綴。"},
{"role": "user", "content": "請將以下英文完整翻譯為繁體中文,只輸出譯文本身:\n" + text}
]
raw = _generate_plain(msgs, max_new_tokens=800, do_sample=False) # ← 保持原邏輯
return clean_zh_translation(raw)
# ---------- Demo ----------
user_text = """Artificial Intelligence is transforming the field of medicine.
It can help in diagnosing diseases, analyzing medical images,
and even assisting in drug discovery.
This progress is driven by advances in machine learning and deep learning.
However, there are challenges such as data privacy and security.
"""
print("=== 三句摘要(英文) ===")
print(gen_summary_3sent(user_text))
print("\n=== 原文翻譯(繁中) ===")
print(translate_to_zh(user_text))
程式拆解說明
因為這次的程式比較長,所以我把它分成幾段來看,這樣會比較清楚。
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
hf_token = userdata.get('HF_TOKEN')
login(token=hf_token)
MODEL_ID = "google/gemma-3-1b-it"
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", torch_dtype="auto", trust_remote_code=True
)
這段就是先登入Hugging Face,然後把Gemma模型載進來。因為我有把token存在Colab的Secret,所以這邊直接讀就好。
def _apply_chat(messages):
input_ids = tok.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
attention_mask = torch.ones_like(input_ids, device=model.device)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def _generate(messages, max_new_tokens=256, do_sample=False):
inputs = _apply_chat(messages)
with torch.no_grad():
out = model.generate(**inputs, max_new_tokens=max_new_tokens)
new_tokens = out[0, inputs["input_ids"].shape[-1]:]
return tok.decode(new_tokens, skip_special_tokens=True)
這邊是我自己包了一個通用的_generate函式。主要是讓模型能根據我給的prompt生成文字,而且加了attention_mask,不然會跳警告。
def gen_summary_3sent(text):
msgs = [{
"role": "user",
"content": (
"You are a concise summarization assistant.\n"
"Given the TEXT, produce EXACTLY three English sentences summarizing it.\n"
"Rules:\n"
"1) Three lines, one sentence per line.\n"
"2) Each sentence ≤ 25 words.\n"
"3) No headings, no extra text.\n\n"
f"TEXT:\n{text}\n\n"
"OUTPUT (three lines only):"
)
}]
raw = _generate(msgs, max_new_tokens=140, do_sample=False)
return raw
這段就是負責做三句英文摘要。我有特別限制三句,而且要求一行一句,避免模型亂接太多內容。
def translate_to_zh(text):
msgs = [
{"role": "system", "content": "你是嚴謹的翻譯助手。只輸出譯文,不要任何解釋或前綴。"},
{"role": "user", "content": "請將以下英文完整翻譯為繁體中文,只輸出譯文本身:\n" + text}
]
raw = _generate(msgs, max_new_tokens=800, do_sample=False)
return raw
這個就是翻譯的部分。我有特別寫system prompt,要求它只輸出翻譯,不要加多餘的解釋,這樣才會乾淨。
user_text = """Artificial Intelligence is transforming the field of medicine.
It can help in diagnosing diseases, analyzing medical images,
and even assisting in drug discovery.
This progress is driven by advances in machine learning and deep learning.
However, there are challenges such as data privacy and security.
"""
print("=== 三句摘要(英文) ===")
print(gen_summary_3sent(user_text))
print("\n=== 原文翻譯(繁中) ===")
print(translate_to_zh(user_text))
最後就是測試。這裡我拿了一段關於AI在醫學的文章,丟給模型做摘要跟翻譯。
執行結果
今天花的時間比想像的多,尤其是中間模型一直亂接龍、或是翻譯尾巴會冒出奇怪的字,Debug了好一陣子。
不過最後能把摘要和翻譯順利組合起來,真的有一種終於看到雛型的成就感。