iT邦幫忙

0

Pytorch 的 BatchNormaliztion 跟 Tensorflow 的差異很大

小弟最近因為研究的需要,要把 Pytorch 的 Model Porting 到 TF 來,但是在 BatchNormaliztion 這裡遇到了問題,目前看過他們部分的 Source code 知道一些參數不相同也改了,也試著從 Pytorch 這邊 load weight 到 TF 這裡,還是一樣(如下圖)不行,不知道該如何是好,還請各位大大幫幫忙,感激不禁!

https://ithelp.ithome.com.tw/upload/images/20210903/20127932yFkmiJgtA9.png

Rorschach iT邦新手 5 級 ‧ 2021-09-03 16:32:57 檢舉
自己回答一下自己,看起來是裡面再訓練的時候不一樣,如果想在 tensorflow 上訓練時候也得到一模一樣的結果,重寫 BN 就可以了。

1 個回答

1
Capillary J
iT邦新手 4 級 ‧ 2021-09-24 11:56:32

我之前也有把pytorch的weight轉到TF使用過,是個費時費工的活= ="

在想你的input shape是(1,2,2,2)
以tensorflow來說順序是 NHWC,但pytorch是NCHW
你如果要在tensorflow取得和pytorch一樣的結果
是不是該transpose你的input?

我要發表回答

立即登入回答