iT邦幫忙

2024 iThome 鐵人賽

DAY 22
0

今天我們來使用PyTorch來實現一個簡單的 SRCNN 模型,對給定的低分辨率圖像來進行超分辦率處理。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import ToTensor, ToPILImage
from PIL import Image
import numpy as np

定義 SRCNN 模型
class SRCNN(nn.Module):
def init(self):
super(SRCNN, self).init()
self.layer1 = nn.Conv2d(1, 64, kernel_size=9, padding=4)
self.layer2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
self.layer3 = nn.Conv2d(32, 1, kernel_size=5, padding=2)
self.relu = nn.ReLU()

def forward(self, x):
    x = self.relu(self.layer1(x))
    x = self.relu(self.layer2(x))
    x = self.layer3(x)
    return x

載入圖像並轉換為張量
def load_image(img_path):
img = Image.open(img_path).convert('YCbCr')
y, _, _ = img.split()
return ToTensor()(y).unsqueeze(0)

#保存處理後的圖像
def save_image(tensor, path):
img = tensor.clone().detach().squeeze(0).numpy()
img = (img * 255.0).clip(0, 255).astype(np.uint8)
img = Image.fromarray(img)
img.save(path)

初始化模型與損失函數
model = SRCNN()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

載入訓練數據
input_image = load_image('low_res_image.jpg')
target_image = load_image('high_res_image.jpg')

訓練模型
num_epochs = 100
for epoch in range(num_epochs):
model.train()
output = model(input_image)
loss = criterion(output, target_image)

optimizer.zero_grad()
loss.backward()
optimizer.step()

if epoch % 10 == 0:
    print(f'Epoch {epoch}, Loss: {loss.item()}')

保存結果圖像
output = model(input_image)
save_image(output, 'output_image.jpg')


上一篇
SRCNN 3
下一篇
SRCNN模型 2
系列文
用AI做圖像super resolution 或用AI做圖像中的物件消除30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言