iT邦幫忙

2023 iThome Ironman Contest (鐵人賽)

DAY 29
Self-Challenge Group

AI Research Series, Part 29

Image Recognition in Practice (4): Implementing YOLO Image Recognition on Kaggle

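In this installment we convert the pretrained YOLO weights into a format that TensorFlow can use. We read the weights file step by step and reshape each slice as needed, building one tf.assign operation per variable so that every weight is loaded into the matching part of the model. The conversion has to proceed carefully, one layer at a time: the values are consumed strictly in the order in which they are stored, so a single mismatched shape shifts every weight that follows.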
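To make the file-reading step concrete, here is a minimal sketch of the binary layout that the load_weights function in this post relies on: five int32 header values followed by a flat stream of float32 weights. The helper names `read_darknet_weights` and `write_toy_weights` and the toy file are hypothetical and exist only for illustration; only NumPy is needed.

```python
import os
import tempfile

import numpy as np


def read_darknet_weights(file_name, header_count=5):
    """Return (header, weights) from a Darknet-style binary weights file.

    Assumes the layout used in this post: `header_count` int32 values,
    then float32 weights until the end of the file.
    """
    with open(file_name, "rb") as f:
        header = np.fromfile(f, dtype=np.int32, count=header_count)
        weights = np.fromfile(f, dtype=np.float32)
    return header, weights


def write_toy_weights(file_name, values):
    """Write a toy file (hypothetical data) with the same layout."""
    with open(file_name, "wb") as f:
        np.zeros(5, dtype=np.int32).tofile(f)            # placeholder header
        np.asarray(values, dtype=np.float32).tofile(f)   # flat weight stream


# Round-trip a tiny file to check that the header is skipped correctly.
toy_path = os.path.join(tempfile.mkdtemp(), "toy.weights")
write_toy_weights(toy_path, [0.5, -1.0, 2.0])
header, weights = read_darknet_weights(toy_path)
print(header.shape, weights)
```

Because the weights come back as one flat array, everything after this point is bookkeeping: a pointer walks through the array and each variable takes as many values as its shape requires.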

import numpy as np
import tensorflow as tf  # tf.assign is the TensorFlow 1.x API (tf.compat.v1.assign in 2.x)


def load_weights(variables, file_name):
    """Reshapes and loads official pretrained Yolo weights.

    Args:
        variables: A list of tf.Variable to be assigned.
        file_name: A name of a file containing weights.

    Returns:
        A list of assign operations.
    """
    with open(file_name, "rb") as f:
        # Skip first 5 values containing irrelevant info
        np.fromfile(f, dtype=np.int32, count=5)
        weights = np.fromfile(f, dtype=np.float32)

        assign_ops = []
        ptr = 0  # Read position in the flat weights array.

        # Load weights for Darknet part.
        # Each convolution layer has batch normalization.
        for i in range(52):
            conv_var = variables[5 * i]
            gamma, beta, mean, variance = variables[5 * i + 1:5 * i + 5]
            # The file is read in the order beta, gamma, mean, variance.
            batch_norm_vars = [beta, gamma, mean, variance]

            for var in batch_norm_vars:
                shape = var.shape.as_list()
                num_params = np.prod(shape)
                var_weights = weights[ptr:ptr + num_params].reshape(shape)
                ptr += num_params
                assign_ops.append(tf.assign(var, var_weights))

            # The file stores kernels as (out, in, height, width);
            # the TensorFlow variable is (height, width, in, out).
            shape = conv_var.shape.as_list()
            num_params = np.prod(shape)
            var_weights = weights[ptr:ptr + num_params].reshape(
                (shape[3], shape[2], shape[0], shape[1]))
            var_weights = np.transpose(var_weights, (2, 3, 1, 0))
            ptr += num_params
            assign_ops.append(tf.assign(conv_var, var_weights))

        # Loading weights for Yolo part.
        # The 7th, 15th and 23rd convolution layers have biases and no batch norm.
        ranges = [range(0, 6), range(6, 13), range(13, 20)]
        unnormalized = [6, 13, 20]
        for j in range(3):
            for i in ranges[j]:
                current = 52 * 5 + 5 * i + j * 2
                conv_var = variables[current]
                gamma, beta, mean, variance = \
                    variables[current + 1:current + 5]
                batch_norm_vars = [beta, gamma, mean, variance]

                for var in batch_norm_vars:
                    shape = var.shape.as_list()
                    num_params = np.prod(shape)
                    var_weights = weights[ptr:ptr + num_params].reshape(shape)
                    ptr += num_params
                    assign_ops.append(tf.assign(var, var_weights))

                shape = conv_var.shape.as_list()
                num_params = np.prod(shape)
                var_weights = weights[ptr:ptr + num_params].reshape(
                    (shape[3], shape[2], shape[0], shape[1]))
                var_weights = np.transpose(var_weights, (2, 3, 1, 0))
                ptr += num_params
                assign_ops.append(tf.assign(conv_var, var_weights))

            # Layer without batch norm: the bias is read first, then the kernel.
            bias = variables[52 * 5 + unnormalized[j] * 5 + j * 2 + 1]
            shape = bias.shape.as_list()
            num_params = np.prod(shape)
            var_weights = weights[ptr:ptr + num_params].reshape(shape)
            ptr += num_params
            assign_ops.append(tf.assign(bias, var_weights))

            conv_var = variables[52 * 5 + unnormalized[j] * 5 + j * 2]
            shape = conv_var.shape.as_list()
            num_params = np.prod(shape)
            var_weights = weights[ptr:ptr + num_params].reshape(
                (shape[3], shape[2], shape[0], shape[1]))
            var_weights = np.transpose(var_weights, (2, 3, 1, 0))
            ptr += num_params
            assign_ops.append(tf.assign(conv_var, var_weights))

    return assign_ops
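The reshape-then-transpose step is the part of load_weights that is easiest to get wrong, so it may help to see it in isolation. The sketch below uses NumPy only and a made-up 3x3 kernel with 2 input and 4 output channels; the helper name `darknet_to_tf_kernel` is hypothetical and not part of the original notebook. It reads a flat slice in the stored order (out, in, height, width) and reorders it into the (height, width, in, out) layout of the TensorFlow variable.

```python
import numpy as np


def darknet_to_tf_kernel(flat, tf_shape):
    """Reorder a flat slice of stored conv weights into a TensorFlow kernel.

    tf_shape is the target variable shape: (height, width, in, out).
    The flat slice is assumed to be stored as (out, in, height, width),
    which is the order the reshape in load_weights implies.
    """
    h, w, c_in, c_out = tf_shape
    kernel = np.asarray(flat).reshape((c_out, c_in, h, w))
    # Move height and width to the front, channels to the back.
    return np.transpose(kernel, (2, 3, 1, 0))


# Toy data: consecutive integers make it easy to trace where a value lands.
tf_shape = (3, 3, 2, 4)
flat = np.arange(np.prod(tf_shape), dtype=np.float32)
kernel = darknet_to_tf_kernel(flat, tf_shape)

stored = flat.reshape((4, 2, 3, 3))  # (out, in, height, width) view
print(kernel.shape)
print(kernel[1, 2, 1, 3] == stored[3, 1, 1, 2])
```

A plain `reshape(tf_shape)` on the flat slice would produce an array of the right shape but with the values in the wrong positions, which is why the function always reshapes to the stored order first and transposes afterwards.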
