更新task env
驗證
提交 training metric。根據上一篇的追蹤，這部分應該不包含 model 註冊，需要另外手動撰寫。
@task(enable_deck=True, container_image=custom_image, requests=Resources(mem="3000Mi"))
@mlflow_autolog(framework=mlflow.keras)
def train_model(epochs: int):
    """Train a small dense classifier on the Fashion-MNIST training split.

    Runs as a Flyte task (deck enabled, custom container image, 3000Mi memory
    request) wrapped by ``mlflow_autolog`` with ``framework=mlflow.keras``.
    That decorator presumably turns on MLflow's Keras autologging so the
    metrics reported by ``model.fit`` are tracked without explicit logging
    calls here. Confirm against the plugin docs.

    NOTE(review): per the accompanying notes, autologging is not expected to
    register the trained model in the model registry. If registration is
    needed it has to be added manually. TODO.

    Args:
        epochs: Number of passes over the training set, forwarded to
            ``model.fit``.

    Returns:
        None. The trained model is not returned from the task.
    """
    # Refer to https://www.tensorflow.org/tutorials/keras/classification
    # Only the training split is used; the test split is discarded.
    (train_images, train_labels), _ = tf.keras.datasets.fashion_mnist.load_data()

    # Scale pixel values by 1/255. Casting to float32 *before* dividing avoids
    # materialising a float64 copy of the whole training set (twice the
    # memory) inside the 3000Mi request. Keras' default float type is
    # float32, so training precision is unaffected.
    train_images = train_images.astype("float32") / 255.0

    model = tf.keras.Sequential(
        [
            # An explicit Input replaces the `input_shape=` keyword on
            # Flatten, which newer Keras versions deprecate with a warning.
            tf.keras.Input(shape=(28, 28)),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation="relu"),
            # Raw logits (no activation): the loss below uses from_logits=True.
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    model.fit(train_images, train_labels, epochs=epochs)