import os
import random
from typing import Tuple
import numpy as np
from matplotlib import pyplot as plt
# Hide INFO-level log messages by setting the log-level variable to "1".
# This has to run before the other modules below are loaded.
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "1"})
上面最後一行設定 TensorFlow 的日誌等級:設為 "1" 時不會印出 INFO 訊息,而且必須在載入其他模組之前執行。
import tensorflow as tf
# Install TensorFlow Similarity on demand, then import it as `tfsim`.
try:
    import tensorflow_similarity as tfsim  # main package
except ModuleNotFoundError:
    # Run pip through the interpreter that is executing this code so the
    # package is installed into the active environment. Unlike the
    # notebook-only "!pip" shell escape, this is also valid plain Python.
    import importlib
    import subprocess
    import sys

    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "tensorflow_similarity"]
    )
    # Make sure the import system notices the freshly installed package.
    importlib.invalidate_caches()
    import tensorflow_similarity as tfsim
這一段會檢查是否已安裝 TensorFlow Similarity;如果還沒安裝,就先用 pip 安裝,再重新匯入。
# Size of the synthetic multi-shot dataset (Colab form sliders).
num_ms_examples = 100000 # @param {type:"slider", min:1000, max:1000000}
num_ms_features = 784 # @param {type:"slider", min:10, max:1000}
num_ms_classes = 10 # @param {type:"slider", min:2, max:1000}
# Random floats stand in for a dense feature vector, one row per example.
X_ms = np.random.rand(num_ms_examples, num_ms_features)
# Random integer labels in [0, num_ms_classes). The upper bound follows the
# slider value rather than a hard-coded 10, so changing the number of classes
# actually changes the labels that are generated.
y_ms = np.random.randint(low=0, high=num_ms_classes, size=num_ms_examples)
這裡變數名稱中常見的 ms 是 multi-shot 的縮寫。
# Sampler settings (Colab form values).
num_known_ms_classes = 5 # @param {type:"slider", min:2, max:1000}
ms_examples_per_class_per_batch = 2 # @param {type:"integer"}
# Each batch uses as many classes as there are known classes.
ms_classes_per_batch = num_known_ms_classes
# Randomly choose which class ids the sampler is allowed to draw from.
ms_class_list = random.sample(range(num_ms_classes), num_known_ms_classes)
# Build the sampler over the in-memory dataset, limited to the chosen classes.
ms_sampler = tfsim.samplers.MultiShotMemorySampler(
    X_ms,
    y_ms,
    class_list=ms_class_list,
    examples_per_class_per_batch=ms_examples_per_class_per_batch,
    classes_per_batch=ms_classes_per_batch,
)
這樣就建好一個取樣器了。
接著用這個取樣器產生一個批次(batch),並檢查結果是否符合上面的設定。
# Ask the sampler for a batch and print its features and labels.
X_ms_batch, y_ms_batch = ms_sampler.generate_batch(100)
batch_rule = "#" * 10
print(batch_rule + " X " + batch_rule)
print(X_ms_batch)
print("\n" + batch_rule + " y " + batch_rule)
print(y_ms_batch)
batch_shape = tf.shape(X_ms_batch)
# Row count must equal classes per batch times examples per class.
assert batch_shape[0] == ms_classes_per_batch * ms_examples_per_class_per_batch
# Column count must equal the number of features.
assert batch_shape[1] == num_ms_features
batch_classes = tf.unique(y_ms_batch)[0]
# Every label in the batch must come from ms_class_list.
assert set(batch_classes.numpy()).issubset(ms_class_list)
# The batch must contain exactly ms_classes_per_batch distinct labels.
assert len(batch_classes) == ms_classes_per_batch
這裡也可以用 get_slice 從指定位置開始,取出固定數量的樣本來使用。
# Fetch 10 examples beginning at example 200 and print them.
slice_size = 10
X_ms_slice, y_ms_slice = ms_sampler.get_slice(begin=200, size=slice_size)
slice_rule = "#" * 10
print(slice_rule + " X " + slice_rule)
print(X_ms_slice)
print("\n" + slice_rule + " y " + slice_rule)
print(y_ms_slice)
slice_shape = tf.shape(X_ms_slice)
# Row count must equal the requested slice size.
assert slice_shape[0] == slice_size
# Column count must equal the number of features.
assert slice_shape[1] == num_ms_features
# Every label in the slice must come from ms_class_list.
assert set(tf.unique(y_ms_slice)[0].numpy()).issubset(ms_class_list)