iT邦幫忙

2018 iT 邦幫忙鐵人賽
DAY 25

Introduction

Sampled softmax addresses the inefficiency of the regular softmax when the number of label classes is very large: instead of normalizing over every class, it pre-selects a small subset of candidate labels by importance sampling and computes the loss over only that subset, which makes training much faster.
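
The idea can be sketched in plain NumPy (a toy illustration, not CNTK code; the class count and sample size below are made up): rather than scoring every class, we score only a small random subset.

```python
import numpy as np

rng = np.random.default_rng(0)

num_classes = 10000   # full label vocabulary (assumed size)
num_samples = 32      # classes actually scored per step (assumed)
hidden_dim = 16

h = rng.standard_normal(hidden_dim)                # hidden vector
W = rng.standard_normal((num_classes, hidden_dim)) # output weights

# Regular softmax: score every class (expensive for large num_classes).
full_logits = W @ h                                # shape (num_classes,)

# Sampled softmax: score only a random subset of classes.
sampled_ids = rng.choice(num_classes, size=num_samples, replace=False)
sampled_logits = W[sampled_ids] @ h                # shape (num_samples,)

print(full_logits.shape, sampled_logits.shape)
```

Only `num_samples` dot products are needed per step instead of `num_classes`, which is where the speedup comes from.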

Cross-entropy is a cost function used to evaluate model training: it measures the discrepancy between the expected (target) distribution and the distribution the model outputs.
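
For a one-hot target and a softmax output, cross-entropy reduces to the negative log-probability assigned to the true class. A minimal NumPy sketch (the logits are made-up values):

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.1])
target = np.array([1.0, 0.0, 0.0])   # one-hot expected output

probs = np.exp(logits) / np.exp(logits).sum()     # softmax
cross_entropy = -np.sum(target * np.log(probs))   # equals -log(probs[0])

print(cross_entropy)
```

A smaller value means the model assigns higher probability to the expected class.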

Tasks

Import the required modules.

from __future__ import print_function # Use a function definition from future version (say 3.x from 2.7 interpreter)
from __future__ import division

import os
import cntk as C
import cntk.tests.test_utils
cntk.tests.test_utils.set_device_from_pytest_env()
C.cntk_py.set_fixed_random_seed(1)

Define the function `cross_entropy_with_sampled_softmax_and_embedding`, which applies sampled softmax and computes the cross-entropy.

def cross_entropy_with_sampled_softmax_and_embedding(
    hidden_vector,            # input hidden vector
    target_vector,            # expected output (one-hot target)
    num_classes,              # number of classes
    hidden_dim,               # hidden dimension
    num_samples,              # number of sampled classes
    sampling_weights,         # per-class sampling weights
    allow_duplicates = True,  # whether sampling is with replacement
    ):

    # Define the learnable parameters
    b = C.Parameter(shape = (num_classes, 1), init = 0)
    W = C.Parameter(shape = (num_classes, hidden_dim), init = C.glorot_uniform())

    # Randomly sample a set of classes for this minibatch
    sample_selector = C.random_sample(sampling_weights, num_samples, allow_duplicates)

    # Expected inclusion frequency of each class under this sampling scheme
    inclusion_probs = C.random_sample_inclusion_frequency(sampling_weights, num_samples, allow_duplicates) # dense row [1 * num_classes]
    log_prior = C.log(inclusion_probs) # dense row [1 * num_classes]

    # Scores of the sampled classes, corrected by the log sampling prior
    W_sampled = C.times(sample_selector, W) # [num_samples * hidden_dim]
    z_sampled = C.times_transpose(W_sampled, hidden_vector) + C.times(sample_selector, b) - C.times_transpose(sample_selector, log_prior) # [num_samples]

    # Score of the target class, corrected the same way
    W_target = C.times(target_vector, W) # [1 * hidden_dim]
    z_target = C.times_transpose(W_target, hidden_vector) + C.times(target_vector, b) - C.times_transpose(target_vector, log_prior) # [1]

    z_reduced = C.reduce_log_sum_exp(z_sampled)

    # Cross-entropy estimated on the sampled classes
    cross_entropy_on_samples = C.log_add_exp(z_target, z_reduced) - z_target

    # Full network output over all classes (used for evaluation)
    z = C.times_transpose(W, hidden_vector) + b
    z = C.reshape(z, shape = (num_classes))

    # Error indicator: 1 when the target score falls below the best sampled score
    zSMax = C.reduce_max(z_sampled)
    error_on_samples = C.less(z_target, zSMax)

    return (z, cross_entropy_on_samples, error_on_samples)
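
The core of the estimate is the last loss step: `log_add_exp(z_target, reduce_log_sum_exp(z_sampled)) - z_target` is exactly `-log softmax(z_target)` computed over the target plus the sampled classes. A small NumPy check of that identity, using made-up scores:

```python
import numpy as np

z_target = 1.5
z_sampled = np.array([0.3, -0.7, 2.1, 0.0])

# CNTK-style form: log_add_exp(z_target, reduce_log_sum_exp(z_sampled)) - z_target
z_reduced = np.log(np.sum(np.exp(z_sampled)))
ce_samples = np.logaddexp(z_target, z_reduced) - z_target

# Direct form: -log softmax of the target over {target} plus the samples
all_scores = np.concatenate(([z_target], z_sampled))
ce_direct = -np.log(np.exp(z_target) / np.exp(all_scores).sum())

print(np.isclose(ce_samples, ce_direct))  # the two forms agree
```

This is why subtracting `log_prior` from both `z_sampled` and `z_target` matters: it corrects for the fact that the normalizer is estimated from a non-uniform sample of classes.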

Previous: Deep Convolutional Generative Adversarial Network
Next: Speech Recognition
Series: Exploring the Microsoft CNTK Machine Learning Toolkit (30 articles)