iT邦幫忙

1

(筆記)風格轉移

  • 分享至 

  • xImage
  •  

簡介

根據<< A Neural Algorithm of Artistic Style >>,畫家所畫的作品可以分成內容(content)和風格(style),此次要撰寫的程式需要兩張圖片-風格圖片以及內容圖片,使用風格圖片學習畫作的風格並和內容圖片的內容進行結合,產生一張新圖片(同時具有風格圖片風格和內容圖片的內容)。
https://ithelp.ithome.com.tw/upload/images/20200608/20126864JwZ1PSqsZn.jpg+https://ithelp.ithome.com.tw/upload/images/20200608/20126864fsYlhNCi3I.jpg=https://ithelp.ithome.com.tw/upload/images/20200608/20126864TRuEZiR3Lo.jpg

原理

既然是使用神經網路進行學習必定需要求loss,但在求loss前我們先了解圖案在不同層CNN進行重塑的結果(下面圖出自於論文)
https://ithelp.ithome.com.tw/upload/images/20200608/20126864eGeKZmY3IU.png

  • 內容重塑:從圖片下方可以發現越接近輸入層重塑出的圖片幾乎和原始內容圖片相同,而越深層重構出的圖片背景則是變得較模糊。這表示在第五層我們專注在提取內容圖片中真正的內容而不是詳細的像素信息。
  • 風格重塑:在CNN中我們可以認為同一層中的不同filter都提取了一種特徵,以內容的角度來說不同特徵間的差異相當大,但一副作品的風格應存在於每個特徵中,因此我們藉由計算同一層不同filter間的相關性來提取風格特徵。

網路

論文中使用的是vgg19(pretrained=True)的網路架構

loss

  • 計算loss前我們需要準備內容圖片、風格圖片和一張目標圖(此圖就是我們程式中的輸出圖,可以是全白圖、也可以直接使用內容圖片)
  • content loss:將內容圖片以及目標圖片輸入網路中,對兩個圖片在conv4層的輸出進行MSELoss得到content loss。
    https://ithelp.ithome.com.tw/upload/images/20200608/20126864Uo9ynhQXdZ.jpg
  • style loss:和計算content loss最大的差異在於,不能直接使用MSEloss計算style loss,如同上述所說的提取風格特徵時考慮的不是每個像素間的差異,而是feature map間的相關程度,因此我們需要先將風格圖和目標圖轉為Gram matrix的形式,接著才是使用MSELoss求出loss。
    https://ithelp.ithome.com.tw/upload/images/20200608/20126864HhSyZQvbPl.jpg
  • total loss = content loss + style loss

程式碼

(程式碼基本上都出自於pytorch的Tutorials,加上註解)
https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

  • import庫
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import copy
import matplotlib.pyplot as plt
from torchvision import models
from torchvision import transforms
from PIL import Image
  • 圖片前處理
imgsize = 128
loader = transforms.Compose(
    [
        transforms.Resize(imgsize),
        transforms.ToTensor(),
    ]
)
def image_loader(image_name):
    image = Image.open(image_name)
    image = loader(image)                #調整圖片大小並且轉為tensor
    image = torch.unsqueeze(image,dim=0) #調整至Conv2d的輸入格式
    image = image.float()
    return image

style_img = image_loader(".\image2\picasso.jpg")   #風格圖片
content_img = image_loader(".\image2\dancing.jpg") #內容圖片
assert style_img.size() == content_img.size()      #確保風格圖片和內容圖片大小相同
  • content loss
    此class並非真的是pytorch中的loss方法,而是繼承nn.Module用來計算content_loss的一層網路並不會更動到輸入值
class Contentloss(nn.Module):
    def __init__(self,target):
        super().__init__()
        self.target = target.detach()
    def forward(self,input):
        self.loss = F.mse_loss(input,self.target)
        return input
  • style loss
    如同前文所說我們需要先將feature maps轉成矩陣,接著求出Gram matrix(G=矩陣*transpose(矩陣))
def gram_matrix(input):
    a,b,c,d = input.shape               #a=batch_size b=featuremap c,d=length*height
    features = input.view(a*b,c*d)      #轉為矩陣形式 
    G = torch.mm(features,features.t()) #計算Gram matrix
    return G.div(a*b*c*d)

class Styleloss(nn.Module):
    def __init__(self,target_feature):
        super().__init__()
        self.target = gram_matrix(target_feature).detach()
    def forward(self,input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G,self.target)
        return input
  • 導入模型
    vgg網路訓練時是使用mean=[0.485, 0.456, 0.406]和std=[0.229, 0.224, 0.225]來進行標準化,因此進入網路的圖面也必須使用這兩個參數進行一次標準化。
cnn = models.vgg19(pretrained=True).features.eval()       #pretrained = true表示保留參數值
cnn_normalization_mean = torch.tensor([0.485,0.456,0.406])
cnn_normalization_std = torch.tensor([0.229,0.224,0.225])

class Normalization(nn.Module):
    def __init__(self,mean,std):
        super().__init__()
        self.mean = torch.tensor(mean).view(-1,1,1)
        self.std = torch.tensor(std).view(-1,1,1)
    
    def forward(self,img):
        return (img-self.mean)/self.std
  • 搭建網路
    藉由nn.sequential一層一層的增加網路,vgg19的架構為conv2d->Relu->conv2d->Relu->Maxpool->conv2d...,由此特性判斷在何處應該添加content loss層和style loss層
def get_style_model_and_losses(cnn,normalization_mean,normalization_std,style_img,content_img,content_layers=content_layers_default,style_layers=style_layers_default):
    cnn = copy.deepcopy(cnn)
    normalization = Normalization(normalization_mean,normalization_std)
    content_losses=[]                       #用來存放content loss網路層的list
    style_losses=[]                         #用來存放style loss網路層的list
    model = nn.Sequential(normalization)    #加入第一層標準化層
    i=0
    for layer in cnn.children():                 
        if isinstance(layer, nn.Conv2d):
            i=i+1
            name = 'conv_{}'.format(i)

        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
		
        model.add_module(name,layer)

        if name in content_layers:
            
            target = model(content_img).detach()
            
            content_loss = Contentloss(target)#創建content loss網路層
                                              #並將target傳入init
            
            model.add_module("content_loss_{}".format(i),content_loss)
            
            content_losses.append(content_loss)#添加的是網路層

        if name in style_layers:
            
            target_feature = model(style_img).detach()
            
            style_loss = Styleloss(target_feature)#創建style loss網路層
                                                  #並將target傳入init
            
            model.add_module("style_loss_{}".format(i),style_loss)#添加的是網路層
            
            style_losses.append(style_loss)
    for i in range(len(model)-1,-1,-1):
        if isinstance(model[i],Contentloss) or isinstance(model[i],Styleloss):
            break
    model = model[:i+1]

    return model,style_losses,content_losses
  • 初始化要輸出的圖
input_img = content_img.clone()     #可以是內容圖片或白噪聲
  • optimizer
def get_input_optimizer(input_img): #採用論文中建議的LBFGS
    optimizer = optim.LBFGS([input_img.requires_grad_()])
    return optimizer
  • 運行程式
def run_style_transfer(cnn,normalization_mean,normalization_std,content_img,style_img,input_img,num_setps=300,style_weight=1000000,content_weight=1):
    print('Building the style transfer model...')
    model,style_losses,content_losses = get_style_model_and_losses(cnn,normalization_mean,normalization_std,style_img,content_img)
    print(model)
    optimizer = get_input_optimizer(input_img)
    print('optimizimg...')
    run = [0]
    while run[0]<= num_setps:
        def closure():
            input_img.data.clamp_(0,1)
            optimizer.zero_grad()
            model(input_img)
            style_score = 0
            content_score = 0

            for s1 in style_losses:
                style_score = style_score+s1.loss
            for c1 in content_losses:
                content_score = content_score+c1.loss
                
            style_score = style_weight*style_score
            content_score = content_weight*content_score

            loss = style_score + content_score
            loss.backward()
            run[0] = run[0]+1
            if run[0] % 50 == 0:
                print("run{}:".format(run))
                print('style loss:{:4f}  content loss:{:4f}'.format(style_score.item(),content_score.item()))
                print()
            return style_score + content_score
        optimizer.step(closure)
    input_img.data.clamp_(0,1)

    return input_img

圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言