Day 30: VLM Fine-tuning Hasn't Succeeded Yet; This Is Another Beginning!


With 20 minutes left, this battle is about to end. To avoid stumbling at the very last step, I am recording the progress I made while writing code together with an AI assistant, even though the VLM fine-tuning has not succeeded. My takeaway is that the AI's debugging ability was not great: its fixes made things progressively worse. Another possible reason is that my prompts were poorly written, which pushed it in the wrong direction.

Fine-tuning on a Colab T4 also failed: the model loaded successfully, but the following steps then reported that RAM was exhausted.
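If I retry on the T4, one option that might keep memory in check is loading the model in 4-bit via bitsandbytes. This is only a sketch I have not verified, and the model id below is a placeholder:

# Untested sketch: load LLaVA with 4-bit quantization to reduce VRAM usage on a T4
import torch
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights as 4-bit NF4
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  # the T4 has no bf16 support, so compute in fp16
)

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",            # placeholder id; replace with the checkpoint you actually use
    quantization_config=bnb_config,
    device_map="auto",
)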

Below is the code I ran on an A100:


# Step 1: Install the required libraries (resolves a tokenizers conflict)
!pip uninstall torchvision torchaudio bitsandbytes tokenizers -y  # remove conflicting packages
!pip install --no-cache-dir transformers==4.41.2 datasets pillow peft==0.10.0 bitsandbytes==0.44.1 triton==2.3.0 tokenizers==0.19.1 --no-deps
# Reinstall compatible torchvision/torchaudio (to match PyTorch 2.3.1)
!pip install torchvision==0.18.1+cu121 torchaudio==2.3.1+cu121 --index-url https://download.pytorch.org/whl/cu121
# Step 2: Import modules and verify the environment
import json
import torch
from transformers import AutoTokenizer, AutoImageProcessor, LlavaForConditionalGeneration, Trainer, TrainingArguments
from PIL import Image
from peft import LoraConfig, get_peft_model 
from torch.utils.data import Dataset
from transformers.data.data_collator import default_data_collator
# Step 3: Load the model and processors
# Checkpoint ids (assumed example values; replace them with the checkpoints you actually use)
model_name = "llava-hf/llava-1.5-7b-hf"
clip_model_name = "openai/clip-vit-large-patch14-336"
# Load the tokenizer (from LLaVA)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Load the image processor (from CLIP)
image_processor = AutoImageProcessor.from_pretrained(clip_model_name, trust_remote_code=True)

# Custom LlavaProcessor class
class LlavaProcessor:
    def __init__(self, tokenizer, image_processor):
        self.tokenizer = tokenizer
        self.image_processor = image_processor
    
    def __call__(self, text, images, return_tensors="pt", padding=True, truncation=False):
        image_inputs = self.image_processor(images, return_tensors=return_tensors)
        text_inputs = self.tokenizer(text, return_tensors=return_tensors, padding=padding, truncation=truncation)
        inputs = {**text_inputs, **{k: v for k, v in image_inputs.items() if k != 'pixel_values' or v is not None}}
        if 'pixel_values' in image_inputs:
            inputs['pixel_values'] = image_inputs['pixel_values']
        return inputs

processor = LlavaProcessor(tokenizer, image_processor)
print("處理器載入成功(自訂組合)")

# Load the model (no quantization, using bf16)
model = LlavaForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # the A100 natively supports bf16
    device_map={"": 0},  # load onto GPU 0
    trust_remote_code=True
)
model.gradient_checkpointing_enable()  # enable gradient checkpointing
print("Model loaded successfully (bf16, not quantized)")
# Step 4: Configure the LoRA adapter (train only a small set of weights instead of full fine-tuning)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["multi_modal_projector.linear_1", "multi_modal_projector.linear_2"], 
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)

# Freeze all parameters, then re-enable training for the multi_modal_projector
for name, param in model.named_parameters():
    param.requires_grad = False
    if "multi_modal_projector" in name:
        param.requires_grad = True
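# Optional check: PEFT can report how many parameters are actually trainable,
# which confirms that only the projector and its LoRA layers will be updated.
model.print_trainable_parameters()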
# Step 5: Load the dataset and convert its format (omitted)
# Step 6: Build the dataset
import json
from PIL import Image
import torch
from torch.utils.data import Dataset
import os

class LLaVADataset(Dataset):
    def __init__(self, jsonl_path, image_folder, processor, max_length=512):
        self.jsonl_path = jsonl_path
        self.image_folder = image_folder
        self.processor = processor
        self.max_length = max_length
        self.data = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                self.data.append(json.loads(line))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        # Resolve the image path (join with image_folder when the JSONL stores relative paths)
        image_path = item["image"]
        if not os.path.isabs(image_path):
            image_path = os.path.join(self.image_folder, image_path)
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")
        image = Image.open(image_path).convert("RGB")
        
        # Extract the conversation (USER and ASSISTANT turns)
        conversations = item["conversations"]
        user_prompt = next(c["value"] for c in conversations if c["from"] == "human")
        assistant_response = next(c["value"] for c in conversations if c["from"] == "gpt")
        
        # Make sure user_prompt contains the <image> token
        if "<image>" not in user_prompt:
            user_prompt = f"<image>\n{user_prompt.strip()}"
        
        # Process the image
        image_inputs = self.processor.image_processor(
            images=image,
            return_tensors="pt",
            size={"height": 336, "width": 336},  # matches CLIP-ViT-L-336px
            do_resize=True,
            do_center_crop=False
        )
        
        # Tokenize the text (user_prompt)
        text_inputs = self.processor.tokenizer(
            user_prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True
        )
        
        # Merge the inputs
        inputs = {
            "input_ids": text_inputs["input_ids"].squeeze(0),
            "attention_mask": text_inputs["attention_mask"].squeeze(0),
            "pixel_values": image_inputs["pixel_values"].squeeze(0)
        }
        
        # Build the labels (the ASSISTANT response)
        labels = self.processor.tokenizer(
            assistant_response,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True
        )["input_ids"].squeeze(0)
        
        inputs["labels"] = labels
        return inputs
# Validate the dataset
dataset = LLaVADataset(
    jsonl_path="/content/llava_finetune.jsonl",
    image_folder="/content/data/images",
    processor=processor,
    max_length=512
)
print(f"資料集大小:{len(dataset)}")

# Inspect a sample
sample = dataset[0]
print(f"Sample input_ids shape: {sample['input_ids'].shape}")
print(f"Sample attention_mask shape: {sample['attention_mask'].shape}")
print(f"Sample pixel_values shape: {sample['pixel_values'].shape}")
print(f"Sample user prompt: {next(c['value'] for c in dataset.data[0]['conversations'] if c['from'] == 'human')}")
# Step 7: Configure the training arguments (for an A100)
training_args = TrainingArguments(
    output_dir="/content/output/checkpoints",
    per_device_train_batch_size=4,  # makes use of the A100's 40GB of VRAM
    gradient_accumulation_steps=4,  # effective batch size of 16
    learning_rate=2e-5,
    num_train_epochs=1,
    max_steps=100,
    logging_steps=10,
    save_steps=50,
    bf16=True,  # the model was loaded in bfloat16, so use bf16 mixed precision (supported on the A100)
    gradient_checkpointing=True,
    remove_unused_columns=False,
    report_to="none"
)
# Step 8: Initialize the Trainer (with a custom collate function)
from transformers import Trainer, TrainingArguments

def custom_collate_fn(batch):
    input_ids = torch.stack([item["input_ids"] for item in batch])
    attention_mask = torch.stack([item["attention_mask"] for item in batch])
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
        "labels": labels
    }

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=custom_collate_fn
)
# Step 9: Start training
trainer.train()
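# Not reached in this run, but once training succeeds the LoRA adapter (and tokenizer) can be saved like this:
model.save_pretrained("/content/output/lora_adapter")
tokenizer.save_pretrained("/content/output/lora_adapter")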

An error occurred during training; the main problem lies in the format of the training data. I will keep at it!!!
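My current suspicion is that __getitem__ tokenizes the ASSISTANT response separately, so labels and input_ids are not aligned token by token. A common approach, and a sketch of what I plan to try next (untested, reusing the variables already defined inside __getitem__), is to tokenize the prompt and the response together and mask the prompt portion with -100 so that only the response contributes to the loss:

        # Untested sketch of a replacement for the end of __getitem__:
        # tokenize prompt + response together so labels line up with input_ids.
        full_text = user_prompt + "\n" + assistant_response
        encoded = self.processor.tokenizer(
            full_text,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True
        )
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        # Number of tokens that belong to the prompt alone, used to decide how much to mask.
        prompt_len = self.processor.tokenizer(
            user_prompt + "\n",
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length
        )["input_ids"].shape[1]

        labels = input_ids.clone()
        labels[:prompt_len] = -100            # ignore the prompt when computing the loss
        labels[attention_mask == 0] = -100    # ignore padding as well

        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": image_inputs["pixel_values"].squeeze(0),
            "labels": labels
        }
        return inputs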


A lot of this content was published before I had fully absorbed it. I will go back over it piece by piece, find a suitable platform, recheck and revise the articles, and republish them, so that mistakes do not mislead readers who take the time to read.


Previous post
Day 29: Improving Grounding (improving a VLM's visual grounding ability)