Migrate early stopping

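This notebook demonstrates how to set up model training with early stopping: first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training when a condition is met, for example when the validation loss reaches a certain threshold.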

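In TensorFlow 2, there are three ways to implement early stopping: use the built-in Keras callback tf.keras.callbacks.EarlyStopping and pass it to Model.fit; define a custom callback and pass it to Model.fit; or write a custom early stopping rule in a custom training loop (with tf.GradientTape).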

Setup

import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds

TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

First, define functions for loading and preprocessing the MNIST dataset, and a model definition to be used with tf.estimator.Estimator:

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
    name='mnist',
    split='train',
    shuffle_files=True,
    as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
    name='mnist',
    split='test',
    shuffle_files=True,
    as_supervised=True)
  ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

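In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function that takes no arguments as the should_stop_fn argument of make_early_stopping_hook, and then pass the returned hook to tf.estimator.TrainSpec through its hooks argument. Training stops once should_stop_fn returns True.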
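A minimal framework-free sketch of that contract follows. The factory `make_time_limit_stop_fn`, its `clock` parameter, and `FakeClock` are hypothetical illustrations, not part of the Estimator API; the point is only that a `should_stop_fn` is a zero-argument closure that captures whatever state it needs and returns a bool.

```python
import time
from typing import Callable


def make_time_limit_stop_fn(
    max_seconds: float,
    clock: Callable[[], float] = time.monotonic,
) -> Callable[[], bool]:
  """Return a zero-argument predicate that becomes True after max_seconds.

  Hypothetical helper: the returned function has the shape expected for
  should_stop_fn (no arguments, returns a bool). The clock starts when the
  factory is called, like start_time in the example that follows.
  """
  start = clock()

  def should_stop_fn() -> bool:
    # Elapsed time since the factory was called, compared to the budget.
    return clock() - start > max_seconds

  return should_stop_fn


class FakeClock:
  """Manually advanced clock, so the predicate can be exercised without sleeping."""

  def __init__(self) -> None:
    self.now = 0.0

  def __call__(self) -> float:
    return self.now


clock = FakeClock()
stop_fn = make_time_limit_stop_fn(20, clock=clock)
print(stop_fn())  # False: no time has passed yet
clock.now = 20.0
print(stop_fn())  # False: the comparison is strictly greater-than
clock.now = 20.5
print(stop_fn())  # True
```

The sketch defaults to `time.monotonic` rather than `time.time` so that wall-clock adjustments cannot shorten or stretch the budget; the returned `stop_fn` is the kind of object you would pass as `should_stop_fn`.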

The following example demonstrates how to implement early stopping that limits the training time to a maximum of 20 seconds:

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp1_ztd0uk
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp1_ztd0uk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp1_ztd0uk/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.3376417, step = 0
INFO:tensorflow:global_step/sec: 151.62
INFO:tensorflow:loss = 1.3996981, step = 100 (0.662 sec)
INFO:tensorflow:global_step/sec: 347.47
INFO:tensorflow:loss = 0.8579437, step = 200 (0.287 sec)
INFO:tensorflow:global_step/sec: 412.975
INFO:tensorflow:loss = 0.75041175, step = 300 (0.242 sec)
INFO:tensorflow:global_step/sec: 412.011
INFO:tensorflow:loss = 0.68900204, step = 400 (0.243 sec)
INFO:tensorflow:global_step/sec: 410.717
INFO:tensorflow:loss = 0.5152224, step = 500 (0.244 sec)
INFO:tensorflow:global_step/sec: 510.018
INFO:tensorflow:loss = 0.4416644, step = 600 (0.196 sec)
INFO:tensorflow:global_step/sec: 515.045
INFO:tensorflow:loss = 0.3775224, step = 700 (0.194 sec)
INFO:tensorflow:global_step/sec: 525.694
INFO:tensorflow:loss = 0.51862353, step = 800 (0.190 sec)
INFO:tensorflow:global_step/sec: 529.37
INFO:tensorflow:loss = 0.39542085, step = 900 (0.189 sec)
INFO:tensorflow:global_step/sec: 441.878
INFO:tensorflow:loss = 0.43643415, step = 1000 (0.226 sec)
INFO:tensorflow:global_step/sec: 548.7
INFO:tensorflow:loss = 0.4450134, step = 1100 (0.182 sec)
INFO:tensorflow:global_step/sec: 552.721
INFO:tensorflow:loss = 0.39866757, step = 1200 (0.181 sec)
INFO:tensorflow:global_step/sec: 548.957
INFO:tensorflow:loss = 0.47965074, step = 1300 (0.182 sec)
INFO:tensorflow:global_step/sec: 552.235
INFO:tensorflow:loss = 0.28847384, step = 1400 (0.181 sec)
INFO:tensorflow:global_step/sec: 457.179
INFO:tensorflow:loss = 0.30012336, step = 1500 (0.219 sec)
INFO:tensorflow:global_step/sec: 518.519
INFO:tensorflow:loss = 0.39735657, step = 1600 (0.193 sec)
INFO:tensorflow:global_step/sec: 550.054
INFO:tensorflow:loss = 0.37535176, step = 1700 (0.182 sec)
INFO:tensorflow:global_step/sec: 553.481
INFO:tensorflow:loss = 0.32334548, step = 1800 (0.181 sec)
INFO:tensorflow:global_step/sec: 513.917
INFO:tensorflow:loss = 0.49587888, step = 1900 (0.194 sec)
INFO:tensorflow:global_step/sec: 557.853
INFO:tensorflow:loss = 0.20265089, step = 2000 (0.179 sec)
INFO:tensorflow:global_step/sec: 534.383
INFO:tensorflow:loss = 0.26787725, step = 2100 (0.188 sec)
INFO:tensorflow:global_step/sec: 553.182
INFO:tensorflow:loss = 0.29813218, step = 2200 (0.180 sec)
INFO:tensorflow:global_step/sec: 549.305
INFO:tensorflow:loss = 0.33344537, step = 2300 (0.182 sec)
INFO:tensorflow:global_step/sec: 426.45
INFO:tensorflow:loss = 0.25255543, step = 2400 (0.235 sec)
INFO:tensorflow:global_step/sec: 542.406
INFO:tensorflow:loss = 0.21495241, step = 2500 (0.184 sec)
INFO:tensorflow:global_step/sec: 557.665
INFO:tensorflow:loss = 0.16201726, step = 2600 (0.179 sec)
INFO:tensorflow:global_step/sec: 553.981
INFO:tensorflow:loss = 0.29901528, step = 2700 (0.181 sec)
INFO:tensorflow:global_step/sec: 544.654
INFO:tensorflow:loss = 0.47195303, step = 2800 (0.184 sec)
INFO:tensorflow:global_step/sec: 437.786
INFO:tensorflow:loss = 0.23737015, step = 2900 (0.228 sec)
INFO:tensorflow:global_step/sec: 558.157
INFO:tensorflow:loss = 0.3160169, step = 3000 (0.179 sec)
INFO:tensorflow:global_step/sec: 569.531
INFO:tensorflow:loss = 0.20572095, step = 3100 (0.176 sec)
INFO:tensorflow:global_step/sec: 560.12
INFO:tensorflow:loss = 0.41672167, step = 3200 (0.179 sec)
INFO:tensorflow:global_step/sec: 515.62
INFO:tensorflow:loss = 0.3320045, step = 3300 (0.194 sec)
INFO:tensorflow:global_step/sec: 505.927
INFO:tensorflow:loss = 0.2794531, step = 3400 (0.198 sec)
INFO:tensorflow:global_step/sec: 543.322
INFO:tensorflow:loss = 0.21261992, step = 3500 (0.184 sec)
INFO:tensorflow:global_step/sec: 480.85
INFO:tensorflow:loss = 0.22373354, step = 3600 (0.207 sec)
INFO:tensorflow:global_step/sec: 466.387
INFO:tensorflow:loss = 0.27012455, step = 3700 (0.215 sec)
INFO:tensorflow:global_step/sec: 525.974
INFO:tensorflow:loss = 0.35595867, step = 3800 (0.190 sec)
INFO:tensorflow:global_step/sec: 554.295
INFO:tensorflow:loss = 0.2105535, step = 3900 (0.181 sec)
INFO:tensorflow:global_step/sec: 564.572
INFO:tensorflow:loss = 0.302787, step = 4000 (0.177 sec)
INFO:tensorflow:global_step/sec: 485.176
INFO:tensorflow:loss = 0.19806938, step = 4100 (0.206 sec)
INFO:tensorflow:global_step/sec: 506.325
INFO:tensorflow:loss = 0.26416677, step = 4200 (0.198 sec)
INFO:tensorflow:global_step/sec: 472.743
INFO:tensorflow:loss = 0.2805268, step = 4300 (0.213 sec)
INFO:tensorflow:global_step/sec: 550.783
INFO:tensorflow:loss = 0.30882898, step = 4400 (0.181 sec)
INFO:tensorflow:global_step/sec: 542.109
INFO:tensorflow:loss = 0.25414428, step = 4500 (0.185 sec)
INFO:tensorflow:global_step/sec: 549.018
INFO:tensorflow:loss = 0.3157362, step = 4600 (0.182 sec)
INFO:tensorflow:global_step/sec: 457.956
INFO:tensorflow:loss = 0.14194128, step = 4700 (0.218 sec)
INFO:tensorflow:global_step/sec: 489.526
INFO:tensorflow:loss = 0.27310932, step = 4800 (0.205 sec)
INFO:tensorflow:global_step/sec: 548.936
INFO:tensorflow:loss = 0.35096985, step = 4900 (0.181 sec)
INFO:tensorflow:global_step/sec: 552.546
INFO:tensorflow:loss = 0.26049778, step = 5000 (0.181 sec)
INFO:tensorflow:global_step/sec: 563.492
INFO:tensorflow:loss = 0.31700188, step = 5100 (0.178 sec)
INFO:tensorflow:global_step/sec: 446.663
INFO:tensorflow:loss = 0.21455708, step = 5200 (0.223 sec)
INFO:tensorflow:global_step/sec: 530.86
INFO:tensorflow:loss = 0.23256294, step = 5300 (0.189 sec)
INFO:tensorflow:global_step/sec: 541.861
INFO:tensorflow:loss = 0.15113032, step = 5400 (0.184 sec)
INFO:tensorflow:global_step/sec: 555.169
INFO:tensorflow:loss = 0.23720858, step = 5500 (0.180 sec)
INFO:tensorflow:global_step/sec: 552.955
INFO:tensorflow:loss = 0.17804441, step = 5600 (0.182 sec)
INFO:tensorflow:global_step/sec: 456.66
INFO:tensorflow:loss = 0.16684741, step = 5700 (0.219 sec)
INFO:tensorflow:global_step/sec: 459.406
INFO:tensorflow:loss = 0.27630642, step = 5800 (0.217 sec)
INFO:tensorflow:global_step/sec: 553.192
INFO:tensorflow:loss = 0.21493328, step = 5900 (0.181 sec)
INFO:tensorflow:global_step/sec: 520.514
INFO:tensorflow:loss = 0.24653763, step = 6000 (0.192 sec)
INFO:tensorflow:global_step/sec: 434.322
INFO:tensorflow:loss = 0.18108746, step = 6100 (0.230 sec)
INFO:tensorflow:global_step/sec: 542.961
INFO:tensorflow:loss = 0.25019443, step = 6200 (0.185 sec)
INFO:tensorflow:global_step/sec: 551.722
INFO:tensorflow:loss = 0.23493612, step = 6300 (0.181 sec)
INFO:tensorflow:global_step/sec: 546.622
INFO:tensorflow:loss = 0.28968674, step = 6400 (0.183 sec)
INFO:tensorflow:global_step/sec: 397.399
INFO:tensorflow:loss = 0.25742713, step = 6500 (0.252 sec)
INFO:tensorflow:global_step/sec: 501.325
INFO:tensorflow:loss = 0.19955175, step = 6600 (0.199 sec)
INFO:tensorflow:global_step/sec: 504.54
INFO:tensorflow:loss = 0.23529533, step = 6700 (0.198 sec)
INFO:tensorflow:global_step/sec: 516.526
INFO:tensorflow:loss = 0.391949, step = 6800 (0.194 sec)
INFO:tensorflow:global_step/sec: 554.826
INFO:tensorflow:loss = 0.1410023, step = 6900 (0.181 sec)
INFO:tensorflow:global_step/sec: 560.983
INFO:tensorflow:loss = 0.32429492, step = 7000 (0.178 sec)
INFO:tensorflow:global_step/sec: 525.523
INFO:tensorflow:loss = 0.19833161, step = 7100 (0.190 sec)
INFO:tensorflow:global_step/sec: 558.047
INFO:tensorflow:loss = 0.24470912, step = 7200 (0.179 sec)
INFO:tensorflow:global_step/sec: 557.232
INFO:tensorflow:loss = 0.1755924, step = 7300 (0.180 sec)
INFO:tensorflow:global_step/sec: 559.329
INFO:tensorflow:loss = 0.24296087, step = 7400 (0.179 sec)
INFO:tensorflow:Requesting early stopping at global step 7427
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7428...
INFO:tensorflow:Saving checkpoints for 7428 into /tmpfs/tmp/tmp1_ztd0uk/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7428...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:45:51
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp1_ztd0uk/model.ckpt-7428
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 1.10429s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:45:52
INFO:tensorflow:Saving dict for global step 7428: global_step = 7428, loss = 0.2138084
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 7428: /tmpfs/tmp/tmp1_ztd0uk/model.ckpt-7428
INFO:tensorflow:Loss for final step: 0.18073495.
({'loss': 0.2138084, 'global_step': 7428}, [])
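In the example above, `run_every_secs=1` and `run_every_steps=None` control how often the hook evaluates `should_stop_fn`, so training can run a little past the moment the stop condition first becomes true. The following is a simplified, framework-free model of the step-based cadence, assuming the predicate is checked on the first step and then whenever at least `run_every_steps` steps have passed since the previous check. `first_stop_step` is a hypothetical name, and this is an illustration, not the hook's actual implementation.

```python
from typing import Callable, Optional


def first_stop_step(
    condition: Callable[[int], bool],
    run_every_steps: int,
    max_steps: int,
) -> Optional[int]:
  """Return the step at which a periodically polled stop condition fires.

  condition(step) stands in for should_stop_fn. It is only evaluated on
  polling steps, so the returned step can lag the first step where the
  condition is true by up to run_every_steps - 1 steps.
  """
  last_polled = None
  for step in range(max_steps):
    due = last_polled is None or step >= last_polled + run_every_steps
    if not due:
      continue  # No check on this step.
    last_polled = step
    if condition(step):
      return step  # Stop is requested at this step.
  return None  # The condition was never observed within max_steps.


# The condition becomes true at step 103, but with a check every 10 steps
# it is first observed at step 110.
print(first_stop_step(lambda step: step >= 103, run_every_steps=10, max_steps=1000))
```

The time-based setting behaves analogously in seconds: at the roughly 500 to 550 global steps per second shown in the log above, a 1-second check interval can let up to several hundred extra steps run after the 20-second budget has been exceeded.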

TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing the built-in callback tf.keras.callbacks.EarlyStopping to the callbacks parameter of Model.fit.

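The EarlyStopping callback monitors a user-specified metric and ends training when that metric stops improving. (Check the Training and evaluation with the built-in methods guide or the API docs for more information.)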

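Below is an example of an early stopping callback that monitors the training loss and stops training once the loss has failed to improve for 3 consecutive epochs (patience=3):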

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Training stops well before the 100 requested epochs (after 30 in the run below).
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 4s 5ms/step - loss: 0.2293 - sparse_categorical_accuracy: 0.9324 - val_loss: 0.1224 - val_sparse_categorical_accuracy: 0.9636
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0983 - sparse_categorical_accuracy: 0.9696 - val_loss: 0.0962 - val_sparse_categorical_accuracy: 0.9721
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0679 - sparse_categorical_accuracy: 0.9787 - val_loss: 0.1048 - val_sparse_categorical_accuracy: 0.9671
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0513 - sparse_categorical_accuracy: 0.9838 - val_loss: 0.0930 - val_sparse_categorical_accuracy: 0.9733
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0397 - sparse_categorical_accuracy: 0.9869 - val_loss: 0.1034 - val_sparse_categorical_accuracy: 0.9718
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0353 - sparse_categorical_accuracy: 0.9885 - val_loss: 0.0994 - val_sparse_categorical_accuracy: 0.9747
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0332 - sparse_categorical_accuracy: 0.9888 - val_loss: 0.1131 - val_sparse_categorical_accuracy: 0.9736
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0277 - sparse_categorical_accuracy: 0.9902 - val_loss: 0.1190 - val_sparse_categorical_accuracy: 0.9747
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0238 - sparse_categorical_accuracy: 0.9919 - val_loss: 0.1157 - val_sparse_categorical_accuracy: 0.9747
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0225 - sparse_categorical_accuracy: 0.9923 - val_loss: 0.1235 - val_sparse_categorical_accuracy: 0.9743
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0237 - sparse_categorical_accuracy: 0.9917 - val_loss: 0.1359 - val_sparse_categorical_accuracy: 0.9751
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0197 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1446 - val_sparse_categorical_accuracy: 0.9756
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0243 - sparse_categorical_accuracy: 0.9925 - val_loss: 0.1542 - val_sparse_categorical_accuracy: 0.9732
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0174 - sparse_categorical_accuracy: 0.9944 - val_loss: 0.1482 - val_sparse_categorical_accuracy: 0.9749
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0188 - sparse_categorical_accuracy: 0.9940 - val_loss: 0.1308 - val_sparse_categorical_accuracy: 0.9770
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0199 - sparse_categorical_accuracy: 0.9936 - val_loss: 0.1596 - val_sparse_categorical_accuracy: 0.9738
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0173 - sparse_categorical_accuracy: 0.9945 - val_loss: 0.1918 - val_sparse_categorical_accuracy: 0.9742
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0166 - sparse_categorical_accuracy: 0.9945 - val_loss: 0.1604 - val_sparse_categorical_accuracy: 0.9740
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0154 - sparse_categorical_accuracy: 0.9948 - val_loss: 0.1554 - val_sparse_categorical_accuracy: 0.9795
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0171 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.1815 - val_sparse_categorical_accuracy: 0.9756
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0151 - sparse_categorical_accuracy: 0.9955 - val_loss: 0.1638 - val_sparse_categorical_accuracy: 0.9796
Epoch 22/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0164 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1840 - val_sparse_categorical_accuracy: 0.9760
Epoch 23/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0171 - sparse_categorical_accuracy: 0.9955 - val_loss: 0.1705 - val_sparse_categorical_accuracy: 0.9789
Epoch 24/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.1860 - val_sparse_categorical_accuracy: 0.9773
Epoch 25/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0123 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.1950 - val_sparse_categorical_accuracy: 0.9758
Epoch 26/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9955 - val_loss: 0.1777 - val_sparse_categorical_accuracy: 0.9792
Epoch 27/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0111 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.1930 - val_sparse_categorical_accuracy: 0.9780
Epoch 28/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.1942 - val_sparse_categorical_accuracy: 0.9773
Epoch 29/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0138 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2088 - val_sparse_categorical_accuracy: 0.9763
Epoch 30/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2386 - val_sparse_categorical_accuracy: 0.9760
30
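To make the patience rule concrete, here is a simplified, framework-free re-implementation of "stop after `patience` epochs without improvement" for a metric that should decrease. It is an approximation for illustration, not the Keras source; `epochs_until_early_stop` is a hypothetical name, and `min_delta` mirrors the EarlyStopping argument of the same name. The values are the rounded per-epoch numbers printed in the log above.

```python
from typing import Sequence


def epochs_until_early_stop(
    history: Sequence[float],
    patience: int,
    min_delta: float = 0.0,
) -> int:
  """Return how many epochs run before a patience-based stop (mode='min').

  An epoch counts as an improvement only if its value is lower than the
  best value so far by more than min_delta. Training stops at the end of
  the epoch in which patience consecutive epochs have failed to improve.
  """
  best = float("inf")
  wait = 0
  for epoch, value in enumerate(history, start=1):
    if value < best - min_delta:
      best = value
      wait = 0
    else:
      wait += 1
      if wait >= patience:
        return epoch
  return len(history)  # Never triggered: every epoch runs.


# Per-epoch 'loss' and 'val_loss' values copied from the Model.fit log above.
loss = [0.2293, 0.0983, 0.0679, 0.0513, 0.0397, 0.0353, 0.0332, 0.0277,
        0.0238, 0.0225, 0.0237, 0.0197, 0.0243, 0.0174, 0.0188, 0.0199,
        0.0173, 0.0166, 0.0154, 0.0171, 0.0151, 0.0164, 0.0171, 0.0141,
        0.0123, 0.0156, 0.0111, 0.0139, 0.0138, 0.0133]
val_loss = [0.1224, 0.0962, 0.1048, 0.0930, 0.1034, 0.0994, 0.1131]

print(epochs_until_early_stop(loss, patience=3))      # 30, matching len(history.history['loss'])
print(epochs_until_early_stop(val_loss, patience=3))  # 7
```

Because the run above monitors the training `loss`, it keeps going while the model overfits (the `val_loss` column climbs from about 0.09 to 0.24). Applied to the `val_loss` column instead, the same rule would stop after epoch 7, which is why `monitor='val_loss'` is usually the choice when early stopping is meant as a regularizer.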

TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback, which can likewise be passed to the callbacks parameter of Model.fit (or Model.evaluate).
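The mechanism a custom callback relies on can be modeled without TensorFlow: the training loop calls the callback's hooks and checks a `stop_training` flag that the callback may set. `ToyModel`, `BatchBudget`, and `run_fit` below are hypothetical stand-ins, not Keras classes, and the toy loop assumes the flag is checked after every batch and again at the end of each epoch, so a stop requested mid-epoch takes effect right away.

```python
from typing import List, Optional


class ToyModel:
  """Stand-in for a Keras model: it only carries the stop_training flag."""

  def __init__(self) -> None:
    self.stop_training = False


class BatchBudget:
  """Toy callback that asks training to stop after a fixed number of batches."""

  def __init__(self, max_batches: int) -> None:
    self.max_batches = max_batches
    self.seen = 0
    # Attached by the training loop (Keras does this via Callback.set_model).
    self.model: Optional[ToyModel] = None

  def on_train_batch_end(self, batch: int) -> None:
    self.seen += 1
    if self.seen >= self.max_batches:
      self.model.stop_training = True


def run_fit(model: ToyModel, callback: BatchBudget,
            epochs: int, steps_per_epoch: int) -> List[int]:
  """Run the toy loop and return how many batches ran in each epoch."""
  callback.model = model
  batches_per_epoch = []
  for _ in range(epochs):
    ran = 0
    for batch in range(steps_per_epoch):
      ran += 1  # A real loop would run one optimizer step here.
      callback.on_train_batch_end(batch)
      if model.stop_training:
        break  # Leave the epoch right after the current batch.
    batches_per_epoch.append(ran)
    if model.stop_training:
      break  # Do not start another epoch.
  return batches_per_epoch


print(run_fit(ToyModel(), BatchBudget(1000), epochs=100, steps_per_epoch=469))
# [469, 469, 62]
```

With 469 batches per epoch, as in the MNIST runs here, a 1,000-batch budget ends training partway through the third epoch.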

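In this example, training stops once self.model.stop_training is set to True. Note that the model from the previous section is reused without being re-initialized, so training continues from its already-trained weights: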

class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs=None):
    # Record when training starts so the budget excludes setup time.
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs=None):
    now = time.time()
    if now - self.start_time > self.max_time_s:
      # Ask Model.fit to stop training.
      self.model.stop_training = True

# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0130 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2281 - val_sparse_categorical_accuracy: 0.9742
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0103 - sparse_categorical_accuracy: 0.9973 - val_loss: 0.2193 - val_sparse_categorical_accuracy: 0.9778
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2398 - val_sparse_categorical_accuracy: 0.9762
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0176 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.2247 - val_sparse_categorical_accuracy: 0.9791
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0065 - sparse_categorical_accuracy: 0.9980 - val_loss: 0.2441 - val_sparse_categorical_accuracy: 0.9759
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2602 - val_sparse_categorical_accuracy: 0.9761
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0134 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.2380 - val_sparse_categorical_accuracy: 0.9770
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0113 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2523 - val_sparse_categorical_accuracy: 0.9779
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0106 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.2976 - val_sparse_categorical_accuracy: 0.9758
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0125 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2725 - val_sparse_categorical_accuracy: 0.9767
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0172 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.3325 - val_sparse_categorical_accuracy: 0.9747
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0138 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.3081 - val_sparse_categorical_accuracy: 0.9756
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0093 - sparse_categorical_accuracy: 0.9977 - val_loss: 0.2600 - val_sparse_categorical_accuracy: 0.9796
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0083 - sparse_categorical_accuracy: 0.9978 - val_loss: 0.2636 - val_sparse_categorical_accuracy: 0.9805
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.3270 - val_sparse_categorical_accuracy: 0.9750
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0147 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2812 - val_sparse_categorical_accuracy: 0.9777
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0090 - sparse_categorical_accuracy: 0.9979 - val_loss: 0.3137 - val_sparse_categorical_accuracy: 0.9779
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0113 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2774 - val_sparse_categorical_accuracy: 0.9779
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0073 - sparse_categorical_accuracy: 0.9982 - val_loss: 0.3150 - val_sparse_categorical_accuracy: 0.9780
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0072 - sparse_categorical_accuracy: 0.9982 - val_loss: 0.3140 - val_sparse_categorical_accuracy: 0.9791
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0188 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.3382 - val_sparse_categorical_accuracy: 0.9766
Epoch 22/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0108 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.3149 - val_sparse_categorical_accuracy: 0.9772
Epoch 23/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0122 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.3227 - val_sparse_categorical_accuracy: 0.9761
23
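The time-limit check in `LimitTrainingTime` uses `time.time()`, which follows the system clock and can jump if that clock is adjusted during training. One possible variant, sketched below, moves the check into a plain helper that defaults to `time.monotonic()` and accepts an injectable clock, so the logic can be tested without actually waiting. `TrainingDeadline` is a hypothetical name; inside a Keras callback one would call `start()` from `on_train_begin` and set `self.model.stop_training = True` when `expired()` returns True.

```python
import time


class TrainingDeadline:
    """Hypothetical helper that tracks whether a training time budget is used up.

    `clock` defaults to time.monotonic, which is not affected by system clock
    changes; a fake clock can be passed in for testing.
    """

    def __init__(self, max_time_s, clock=time.monotonic):
        self.max_time_s = max_time_s
        self.clock = clock
        self.start_time = None

    def start(self):
        # Counterpart of on_train_begin in LimitTrainingTime.
        self.start_time = self.clock()

    def expired(self):
        # Counterpart of the check in on_train_batch_end.
        return self.clock() - self.start_time > self.max_time_s


# Usage with a fake clock that returns scripted timestamps (in seconds).
ticks = iter([0.0, 10.0, 29.9, 30.5])
deadline = TrainingDeadline(30, clock=lambda: next(ticks))
deadline.start()            # start_time = 0.0
print(deadline.expired())   # False (10.0 s elapsed)
print(deadline.expired())   # False (29.9 s elapsed)
print(deadline.expired())   # True (30.5 s elapsed)
```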

TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, if you are not using the built-in Keras methods for training and evaluation, you can implement early stopping in a custom training loop.

Start by using the Keras API to define another simple model, an optimizer, a loss function, and metrics:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
# The model outputs logits, so the cross-entropy metrics need from_logits=True.
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
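The final Dense(10) layer has no softmax, so the model outputs raw logits: unbounded scores that need not be positive or sum to 1. That is why loss_fn is built with from_logits=True, and a cross-entropy metric fed the same logits needs the same setting, because with from_logits=False its input is treated as probabilities. The plain-Python sketch below (hypothetical helper names, no TensorFlow) shows the quantity involved: cross-entropy computed from logits via log-sum-exp equals -log of the softmax probability of the true class.

```python
import math


def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_ce_from_logits(label, logits):
    # -log(softmax(logits)[label]), computed stably with log-sum-exp.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def sparse_ce_from_probs(label, probs):
    # The from_logits=False view: the input must already be probabilities.
    return -math.log(probs[label])


logits = [2.0, 1.0, 0.1]
print(round(sparse_ce_from_logits(0, logits), 4))          # 0.417
print(round(sparse_ce_from_probs(0, softmax(logits)), 4))  # 0.417 (same value)
```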

Define the training and evaluation step functions, using tf.GradientTape to compute gradients and the @tf.function decorator to speed them up:

@tf.function
def train_step(x, y):
  # Record the forward pass so the tape can compute gradients.
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  # Accumulate this batch into the running training metrics.
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  # Evaluation only: no gradients and no weight updates.
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)
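Conceptually, train_step performs three steps per batch: a forward pass that computes the loss, a gradient computation (done automatically by tf.GradientTape), and a parameter update (done by the optimizer). The plain-Python sketch below walks through the same three steps for a hypothetical one-parameter model `y_hat = w * x`, assuming a mean-squared-error loss, a hand-derived gradient in place of the tape, and plain gradient descent in place of Adam. `train_step_sketch` is an illustrative name, not part of TensorFlow.

```python
def train_step_sketch(w, xs, ys, learning_rate):
    """Hypothetical stand-in for train_step on the model y_hat = w * x."""
    n = len(xs)
    # 1. Forward pass and loss (mean squared error here, not cross-entropy).
    preds = [w * x for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    # 2. Gradient dL/dw, derived by hand; tf.GradientTape automates this step.
    grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    # 3. Parameter update: plain gradient descent here, Adam in the guide.
    return w - learning_rate * grad, loss


w = 0.0
for _ in range(50):
    w, loss = train_step_sketch(w, [1.0, 2.0, 3.0], [2.0, 4.0, 6.0], learning_rate=0.05)
print(round(w, 3))  # converges toward 2.0, the slope that fits y = 2x
```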

Next, write a custom training loop in which you implement the early stopping rule manually.

The example below shows how to stop training when the validation loss has not improved for a given number of epochs:

epochs = 100
patience = 5
wait = 0
best = float('inf')  # Lowest validation loss seen so far.

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
        loss_value = train_step(x_batch_train, y_batch_train)
        if step % 200 == 0:
            print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
            print("Seen so far: %s samples" % ((step + 1) * 128))
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    for x_batch_val, y_batch_val in ds_test:
        test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease for `patience` consecutive epochs.
    wait += 1
    if val_loss < best:
        best = val_loss
        wait = 0
    if wait >= patience:
        break
Start of epoch 0
Training loss at step 0: 2.3400
Seen so far: 128 samples
Training loss at step 200: 0.2146
Seen so far: 25728 samples
Training loss at step 400: 0.2027
Seen so far: 51328 samples
Training acc over epoch: 0.9319
Validation acc: 0.9621
Time taken: 2.04s

Start of epoch 1
Training loss at step 0: 0.0858
Seen so far: 128 samples
Training loss at step 200: 0.1738
Seen so far: 25728 samples
Training loss at step 400: 0.1355
Seen so far: 51328 samples
Training acc over epoch: 0.9704
Validation acc: 0.9697
Time taken: 1.09s

Start of epoch 2
Training loss at step 0: 0.0521
Seen so far: 128 samples
Training loss at step 200: 0.1280
Seen so far: 25728 samples
Training loss at step 400: 0.0969
Seen so far: 51328 samples
Training acc over epoch: 0.9786
Validation acc: 0.9715
Time taken: 1.05s

Start of epoch 3
Training loss at step 0: 0.0667
Seen so far: 128 samples
Training loss at step 200: 0.0836
Seen so far: 25728 samples
Training loss at step 400: 0.0493
Seen so far: 51328 samples
Training acc over epoch: 0.9825
Validation acc: 0.9669
Time taken: 1.05s

Start of epoch 4
Training loss at step 0: 0.0226
Seen so far: 128 samples
Training loss at step 200: 0.0580
Seen so far: 25728 samples
Training loss at step 400: 0.0669
Seen so far: 51328 samples
Training acc over epoch: 0.9859
Validation acc: 0.9689
Time taken: 1.09s

Start of epoch 5
Training loss at step 0: 0.0274
Seen so far: 128 samples
Training loss at step 200: 0.0540
Seen so far: 25728 samples
Training loss at step 400: 0.0745
Seen so far: 51328 samples
Training acc over epoch: 0.9880
Validation acc: 0.9702
Time taken: 1.06s
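The early stopping rule in the custom loop can be factored into a small reusable helper. The sketch below is a hypothetical `EarlyStopper` class (not a TensorFlow API) that monitors a value to minimize, such as the validation loss, and signals a stop once `patience` consecutive epochs pass without an improvement larger than `min_delta`. The two settings are similar in spirit to the `patience` and `min_delta` arguments of tf.keras.callbacks.EarlyStopping. The validation losses in the usage example are made up.

```python
class EarlyStopper:
    """Hypothetical patience-based early stopping tracker for a value to minimize."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")  # Lowest value seen so far.
        self.best_epoch = None
        self.wait = 0  # Epochs since the last improvement.

    def should_stop(self, epoch, value):
        if value < self.best - self.min_delta:
            # Improvement: remember it and reset the counter.
            self.best = value
            self.best_epoch = epoch
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Made-up validation losses, for illustration only.
val_losses = [0.30, 0.25, 0.22, 0.23, 0.24, 0.22, 0.26, 0.27, 0.28]
stopper = EarlyStopper(patience=3)
for epoch, val_loss in enumerate(val_losses):
    if stopper.should_stop(epoch, val_loss):
        break
print(epoch, stopper.best_epoch, stopper.best)  # 5 2 0.22
```

In the custom training loop, one would call `stopper.should_stop(epoch, float(val_loss))` at the end of each epoch and `break` when it returns True. A value equal to the current best does not count as an improvement here, which matches the strict comparison used for `best`.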

Next steps