DAY 10 · AI & Data

## Day 10: 利用 numpy 和 scipy 客製化 PyTorch 模型

1. 繼承 `torch.autograd.Function`
2. 複寫 forward 方法
3. 複寫 backward 方法，並在方法中實踐對參數求梯度的方法。

### 是 `autograd.Function`，不是 `nn.Module`

``````from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

class ScipyConv2d(Module):
def __init__(self, filter_width, filter_height):
super(ScipyConv2d, self).__init__()
# 兩個可學習的 Parameters，我們將會對這兩個參數求梯度
self.filter = Parameter(torch.randn(filter_width, filter_height))
self.bias = Parameter(torch.randn(1, 1))

def forward(self, input):
# 呼叫 ScipyConv2dFunction 的 apply
return ScipyConv2dFunction.apply(input, self.filter, self.bias)

module = ScipyConv2d(2, 2)
print("Filter and bias: ", list(module.parameters()))
#=>Filter and bias:  [Parameter containing:
tensor([[ 0.8463,  1.5286],
``````

``````from scipy.signal import correlate2d
class ScipyConv2dFunction(Function):
@staticmethod
def forward(ctx, input, filter, bias):
# detach so we can cast to NumPy
input, filter, bias = input.detach(), filter.detach(), bias.detach()
# scipy correlate2d method
result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
result += bias.numpy()
ctx.save_for_backward(input, filter, bias)

@staticmethod
pass

output = module(input)
print("Output from the convolution: ", output)
=> Output from the convolution:  tensor([[ 2.8685,  0.4486,  3.8280,  1.1291],
[ 3.1263,  1.2215,  0.1854,  4.9828],
[-0.0896,  4.3831,  1.1503,  0.5699],
[-2.5296, -3.6354, -2.2648, -0.7458]],
``````

``````    @staticmethod
input, filter, bias = ctx.saved_tensors

output.backward(torch.randn(4, 4))
# => Gradient for the input map:  tensor([[-2.8094,  2.4611,  2.0011, -4.5112,  2.1211],
[-3.1463, -0.4110,  0.5583,  0.4608, -1.8372],
[ 1.2061, -3.3830, -3.0883,  3.0992, -1.3506],
[ 2.8678, -0.3643, -3.5285, -2.4539,  2.3888],
[ 0.8490,  1.4713, -1.1287, -2.4457, -1.4976]])
``````

``````from torch.autograd.gradcheck import gradcheck
moduleConv = ScipyConv2d(3, 3)
input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
print("Are the gradients correct: ", test)
=> Are the gradients correct:  True
``````