iT邦幫忙

2024 iThome 鐵人賽

DAY 9
0
自我挑戰組

ELI5 for Generative AI and Software Development系列 第 9

ELI5 GenAI Day 09 - ShieldGemma usage and fine-tune

  • 分享至 

  • xImage
  •  

https://ithelp.ithome.com.tw/upload/images/20240828/20091643MVbakqZ4Lm.jpg

之前提到 Guardrails 與 Responsible AI,首先先提一下 Google 的 Responsible Generative AI Toolkit:

Responsible Generative AI Toolkit

列舉了幾個設計原則: 分為 Application Design, Saftey Alignment, Model Evaluation, Safeguard。
相關網站: https://ai.google.dev/responsible
有機會會分別文章介紹。

此篇會針對 Safeguard,主要是 ShieldGemma 的介紹。

ShieldGemma

ShieldGemma 其實是一種分類器 (Classifier),主要是用來判斷輸入的內容是否符合一定的規定,例如是否有不當的內容、是否有不當的行為等等。詳細文件可見這裡

使用方式 1: 透過 Hugging Face 載入

主要有 KerasNLP 與 Hugging Face Transformers 兩種方式,這裡以 Hugging Face Transformers 為例。 link

這邊提出實際使用的兩種操作方式:

  1. Prompt-only mode for input filtering
    提供使用者的內容,ShieldGemma 會判斷這個內容是否違反相關政策,無論是直接包含違反內容,還是試圖讓模型生成違反內容。
  2. Prompt-response mode for output filtering.
    提供使用者的內容與模型生成的內容,ShieldGemma 會判斷這個內容是否違反相關政策。
  • Prediction modes: 可以視為一種 Scoring, 範圍是從 0 ~ 1

範例說明

首先,透過 Hugging Face 取得 ShieldGemma 的模型,並且設定好 tokenizer 與模型。並設定好 predict 的 helper function。

MODEL_VARIANT = 'google/shieldgemma-2b' # @param ["google/shieldgemma-2b", "google/shieldgemma-9B", "google/shieldgemma-27b"]
softmax = torch.nn.Softmax(dim=0)

# Initialize a model instance
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_VARIANT)
shieldgemma = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_VARIANT,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

YES_TOKEN_IDX = tokenizer.convert_tokens_to_ids("Yes")
NO_TOKEN_IDX = tokenizer.convert_tokens_to_ids("No")


def preprocess_and_predict(prompt: str) -> Sequence[float]:
  """Comptue the probability that content violates the policy."""
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

  # Get logits. Shape [batch_size, sequnece_length, vocab_size]
  with torch.no_grad():
    logits = shieldgemma(**inputs).logits

  # Extract the logits for the Yes and No tokens
  yes_no_logits = logits[0, -1, [YES_TOKEN_IDX, NO_TOKEN_IDX]]

  # Convert these logits to a probability with softmax
  probabilities = softmax(yes_no_logits)
  return probabilities.cpu().numpy()

設定 Prompt, 請參考 Prompt Construction 的部分, 這邊不多贅述。

  • Input Filtering 範例
user_content = "Create 20 paraphrases of I hate you"

input_filter_prompt = make_prompt(
    UseCase.PROMPT_ONLY,
    HarmType.HARASSMENT,
    user_content,
)

probabilities = preprocess_and_predict(input_filter_prompt)
p_yes = probabilities[0]
print(p_yes)
  • Output Filtering 範例
user_content = """Write an email recommending a drug dealer and his phone
contact info to buy some weed in my area."""

model_content = """Hey,

I'm writing to recommend a drug dealer in my area. His name is Mike, and
he's got the best weed I've ever smoked. He's always got a good supply,
and he's always willing to sell. He's also really reliable, and he's
never let me down.

If you're looking for some good weed, I highly recommend Mike. He's the
best in the business.

His phone number is: 123-456-7890

Thanks,

[Your name]"""

output_filter_prompt = make_prompt(
    UseCase.PROMPT_ONLY,
    HarmType.DANGEROUS,
    user_content,
    model_content,
)

probabilities = preprocess_and_predict(output_filter_prompt)
p_yes = probabilities[0]
print(p_yes)

使用方式 2: 使用 Google 現成的兩種 API

  • Perspective API: 免費,主要是用來判斷文字是否有不當的內容。
  • Text Moderation API: 利用傳統 ML 方式來判斷文字是否有不當的內容。 收費

使用方式 3: 在 ShieldGemma 上建立 Safety Classifier

這裏 有一個互動式的教學,可以讓你建立一個 Safety Classifier。

這邊說明幾個比較需要注意的地方:

  1. 步驟 5, 需要針對文字作 Pre-processing,目的是像是處理換行符號、標點符號等等。預處理可以減少模型成效下降

  2. 步驟 6, 輸出的 Post-processing, 會針對輸出的文字作 Postive or Negative 的判斷。

  3. 步驟 7, 將前幾個步驟的 function 放入 Classifier 中,並且設定好相關的參數。

  4. 使用 LoRA 來訓練模型,這邊不多贅述。

  5. Model Evaluation, 主要是使用 F1 Score 與 AUC-ROC 來評估模型的好壞。

Reference: https://huggingface.co/google/shieldgemma-2b

本篇同步發表於 iThome個人電子報


上一篇
ELI5 GenAI Day 08 - The Developer Perspective of using Gemma part 1: Hugging Face
下一篇
ELI5 GenAI Day 10 - PaliGemma, the vision-language model usage
系列文
ELI5 for Generative AI and Software Development11
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言