今天我們使用 PyTorch 實現一個簡單的 SRCNN 模型，對給定的低分辨率圖像進行超分辨率處理。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import ToTensor, ToPILImage
from PIL import Image
import numpy as np
# 定義 SRCNN 模型
class SRCNN(nn.Module):
    """Super-Resolution CNN (SRCNN) built from three convolutional layers.

    Every convolution is padded by ``kernel_size // 2``, so the output has the
    same height and width as the input. The low-resolution image therefore
    has to be upscaled to the target resolution before it is fed in.

    Args:
        num_channels: Number of image channels in both input and output.
            Defaults to 1 (a single channel, e.g. the Y channel of YCbCr).
    """

    def __init__(self, num_channels=1):
        # This must be ``__init__`` (not ``init``). Otherwise ``SRCNN()`` never
        # creates the layers, and ``parameters()`` comes back empty.
        super().__init__()
        # Patch extraction: 9x9 kernels producing 64 feature maps.
        self.layer1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        # Non-linear mapping: 64 -> 32 feature maps.
        self.layer2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        # Reconstruction back to ``num_channels`` channels.
        self.layer3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the reconstructed image for ``x`` of shape ``(N, C, H, W)``."""
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        # No activation after the last layer: the output is a linear reconstruction.
        return self.layer3(x)
# 載入圖像（僅保留 Y 通道）並轉換為張量
def load_image(img_path, size=None):
    """Load an image file and return its luminance (Y) channel as a tensor.

    The image is converted to YCbCr and only the Y channel is kept. The Cb and
    Cr channels are discarded.

    Args:
        img_path: Path of the image file to open.
        size: Optional ``(width, height)``. If given, the Y channel is resized
            to this size with bicubic interpolation (the usual way to prepare
            a low-resolution SRCNN input). ``None``, the default, keeps the
            original size.

    Returns:
        A float tensor of shape ``(1, 1, H, W)`` with values in ``[0, 1]``.
    """
    # The context manager closes the file handle deterministically. ``convert``
    # has already loaded the pixel data, so ``y`` stays valid after the block.
    with Image.open(img_path) as img:
        y, _, _ = img.convert('YCbCr').split()
    if size is not None:
        y = y.resize(tuple(size), Image.BICUBIC)
    # ToTensor yields (1, H, W); unsqueeze(0) adds the batch dimension.
    return ToTensor()(y).unsqueeze(0)
# 保存處理後的圖像
def save_image(tensor, path):
    """Save a single-channel image tensor to ``path`` as an 8-bit image.

    Args:
        tensor: Image tensor with values nominally in ``[0, 1]``. Every
            dimension except the last two must be 1, e.g. ``(H, W)``,
            ``(1, H, W)`` or the model's ``(1, 1, H, W)``. The tensor may live
            on any device and may require grad.
        path: Destination file path. PIL picks the format from its suffix.

    Raises:
        ValueError: If the tensor does not describe exactly one 2-D image.
    """
    data = tensor.detach().cpu()
    if data.dim() < 2 or data.numel() != data.shape[-2] * data.shape[-1]:
        raise ValueError(
            f'expected a single-channel image tensor, got shape {tuple(data.shape)}'
        )
    # Collapse all leading singleton dims. ``squeeze(0)`` alone would leave a
    # (1, H, W) array, which ``Image.fromarray`` cannot turn into a grayscale image.
    img = data.reshape(data.shape[-2:]).numpy()
    # Round (not truncate) to the nearest 8-bit level. The clip catches values
    # the network produced outside [0, 1].
    img = np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)
    Image.fromarray(img).save(path)
# 初始化模型、損失函數與優化器
# Network, mean-squared-error loss over all pixels, and Adam (learning rate 1e-3).
model = SRCNN()
criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# 載入訓練數據
# Load the training pair: the low-resolution input and the high-resolution target.
input_image = load_image('low_res_image.jpg')
target_image = load_image('high_res_image.jpg')
# SRCNN keeps the spatial size of its input. A smaller low-resolution image
# would give an output whose shape differs from the target, and MSELoss cannot
# compare them. So bicubic-upscale the input to the target size when needed,
# then clamp the bicubic overshoot back into [0, 1].
if input_image.shape[-2:] != target_image.shape[-2:]:
    input_image = nn.functional.interpolate(
        input_image,
        size=tuple(target_image.shape[-2:]),
        mode='bicubic',
        align_corners=False,
    ).clamp(0.0, 1.0)
# 訓練模型
# Full-batch training on the single (input, target) image pair.
num_epochs = 100
# The mode is never changed inside the loop, so switching to training mode
# once beforehand is sufficient.
model.train()
for epoch in range(num_epochs):
    # Clear gradients left over from the previous step before back-propagating.
    optimizer.zero_grad()
    output = model(input_image)
    loss = criterion(output, target_image)
    loss.backward()
    optimizer.step()
    # Report progress every 10 epochs.
    if not epoch % 10:
        print(f'Epoch {epoch}, Loss: {loss.item()}')
# 推論並保存結果圖像
# Run the trained model on the input and save its single-channel output.
# Prediction needs no gradients: eval() switches to inference mode and
# no_grad() stops autograd from recording a graph for this forward pass.
model.eval()
with torch.no_grad():
    output = model(input_image)
# NOTE: only the Y channel is reconstructed (Cb/Cr are dropped on load), so
# the saved file is a grayscale image.
save_image(output, 'output_image.jpg')