昨天大致對 CNN 有了基本的認識,
今天就照第 8 天的流程實際跑一次。
首先是昨天已經寫好的建立模型函式:
/**
 * Build and compile the convolutional network used to classify MNIST digits.
 *
 * Layer stack:
 *   conv2d(5x5, 8 filters, ReLU) -> maxPool(2x2)
 *   -> conv2d(5x5, 16 filters, ReLU) -> maxPool(2x2)
 *   -> flatten -> dense(10 units, softmax)
 *
 * @returns {tf.Sequential} model compiled with the Adam optimizer,
 *   categorical cross-entropy loss and an accuracy metric.
 */
function createModel() {
  const model = tf.sequential();

  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const IMAGE_CHANNELS = 1;

  // conv2d uses the channels-last data format by default, whose input shape
  // is [rows (height), columns (width), channels], so height comes first.
  // (Both sides are 28 here; the order matters for non-square inputs.)
  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling',
  }));
  // 2x2 window with stride 2: halves the spatial resolution.
  model.add(tf.layers.maxPooling2d({
    poolSize: [2, 2],
    strides: [2, 2],
  }));

  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling',
  }));
  model.add(tf.layers.maxPooling2d({
    poolSize: [2, 2],
    strides: [2, 2],
  }));

  // Collapse the feature maps into a 1-D vector for the classifier head.
  model.add(tf.layers.flatten());

  // One output per digit class (0-9); softmax turns them into probabilities.
  model.add(tf.layers.dense({
    units: 10,
    kernelInitializer: 'varianceScaling',
    activation: 'softmax',
  }));

  model.compile({
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  return model;
}
再來是取得資料的部分:直接引用 MnistData.js 來載入資料。老實說,這份程式碼我研究了很久,還是想不出要怎麼改寫。
/**
 * Create an MnistData loader and wait until its load() call completes.
 *
 * @returns {Promise<MnistData>} the same loader instance, after load() resolved.
 */
async function getData() {
  const loader = new MnistData();
  // load() is asynchronous; only hand the loader back once it has finished.
  await loader.load();
  return loader;
}
再來就是把資料轉換成 tensor。
因為 MnistData.js 裡面已經包裝得差不多了,
所以只要呼叫它提供的方法(nextTrainBatch / nextTestBatch)就好。
/**
 * Draw one batch from the MNIST loader and reshape its images for the CNN.
 *
 * @param {MnistData} mnist_data - loaded dataset object.
 * @param {string} method - name of the batch method to invoke on `mnist_data`
 *   (this file uses "nextTrainBatch" and "nextTestBatch").
 * @param {number} size - number of examples to request.
 * @returns {{inputs: tf.Tensor, labels: tf.Tensor}} the batch images reshaped
 *   to [size, 28, 28, 1] together with the batch's labels.
 */
function convertToTensor(mnist_data, method, size) {
  // Per-example image shape: 28 x 28 with a single channel.
  const imageShape = [28, 28, 1];
  // tf.tidy disposes every tensor created inside the callback (such as the
  // un-reshaped `xs`) except the ones contained in the returned object.
  return tf.tidy(() => {
    // Member call keeps `this` bound to mnist_data inside the batch method.
    const batch = mnist_data[method](size);
    const inputs = batch.xs.reshape([size, ...imageShape]);
    return { inputs, labels: batch.labels };
  });
}
然後訓練模型
/**
 * Train `model` on the training batch, evaluating on the validation batch and
 * plotting progress with tfjs-vis.
 *
 * @param {tf.LayersModel} model - compiled model (e.g. from createModel()).
 * @param {{inputs: tf.Tensor, labels: tf.Tensor}} t_data - training batch.
 * @param {{inputs: tf.Tensor, labels: tf.Tensor}} v_data - validation batch.
 * @param {{batchSize?: number, epochs?: number}} [options] - training
 *   hyper-parameters: samples per gradient step (default 500) and number of
 *   passes over the training data (default 10).
 * @returns {Promise<tf.History>} whatever model.fit resolves with.
 */
async function trainModel(model, t_data, v_data, { batchSize = 500, epochs = 10 } = {}) {
  return model.fit(t_data.inputs, t_data.labels, {
    batchSize,
    epochs,
    shuffle: true,
    validationData: [v_data.inputs, v_data.labels],
    // Live loss/accuracy charts, refreshed at the end of each epoch.
    // NOTE: assumes the model was compiled with the 'accuracy' metric, whose
    // log keys TF.js reports as 'acc' / 'val_acc'.
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training Performance' },
      ['loss', 'val_loss', 'acc', 'val_acc'],
      { height: 200, callbacks: ['onEpochEnd'] },
    ),
  });
}
最後把整個流程串起來執行:
/**
 * Entry point: build the model, load MNIST, cut a 5000-example training batch
 * and a 1000-example validation batch, then train.
 *
 * TF.js tensors are not garbage-collected, so the batch tensors are disposed
 * explicitly once training finishes or fails; otherwise every run would leak
 * backend memory (e.g. WebGL textures).
 *
 * @returns {Promise<void>}
 */
async function runTensorFlow() {
  const model = createModel();
  const mnist_data = await getData();

  // Record each batch as soon as it exists so `finally` frees all of them,
  // even if a later step throws.
  const batches = [];
  try {
    const traindata = convertToTensor(mnist_data, "nextTrainBatch", 5000);
    batches.push(traindata);
    const validationdata = convertToTensor(mnist_data, "nextTestBatch", 1000);
    batches.push(validationdata);

    await trainModel(model, traindata, validationdata);
  } finally {
    for (const { inputs, labels } of batches) {
      inputs.dispose();
      labels.dispose();
    }
  }
}
// Start training once the DOM is ready. If this script executes after
// DOMContentLoaded has already fired (e.g. it was loaded with `async` or
// injected dynamically), that event never fires again, so start right away.
if (document.readyState === 'loading') {
  document.addEventListener('DOMContentLoaded', runTensorFlow);
} else {
  // Failures surface in the console as unhandled rejections, exactly as they
  // do on the listener path above.
  void runTensorFlow();
}
這時候就會在 tfjs-vis 的面板上看到訓練過程(每個 epoch 結束時更新的 loss 與 accuracy 曲線)。