View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook
This notebook demonstrates how to set up model training with early stopping: first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with the Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training when, for example, the validation loss reaches a certain threshold.
In TensorFlow 2, there are three ways to implement early stopping:

- Use the built-in Keras callback tf.keras.callbacks.EarlyStopping and pass it to Model.fit.
- Define a custom callback and pass it to Keras Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).
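The first two options are demonstrated below with Model.fit. The stopping logic of the third option can be sketched independently of TensorFlow. In the sketch below, train_step and compute_val_loss are hypothetical stand-ins for one epoch of tf.GradientTape updates and a validation pass; only the stopping rule itself is the point:

```python
def train_with_early_stopping(train_step, compute_val_loss,
                              max_epochs=100, threshold=0.1):
  # Hypothetical stand-ins: train_step() runs one epoch of gradient
  # updates (with tf.GradientTape in a real loop), and
  # compute_val_loss() evaluates the model on validation data.
  history = []
  for epoch in range(max_epochs):
    train_step()
    val_loss = compute_val_loss()
    history.append(val_loss)
    # Custom early stopping rule: stop once the validation loss
    # falls below the chosen threshold.
    if val_loss < threshold:
      break
  return history
```

For example, feeding it a simulated sequence of decreasing validation losses stops the loop at the first value below the threshold rather than running all max_epochs iterations.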
Setup
import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
2022-12-14 20:45:28.744945: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:45:28.745036: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:45:28.745045: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
TensorFlow 1: Early stopping with an early stopping hook and tf.estimator
Begin by defining functions for MNIST dataset loading and preprocessing, and model definition to be used with tf.estimator.Estimator:
def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label
def _input_fn():
  ds_train = tfds.load(
      name='mnist',
      split='train',
      shuffle_files=True,
      as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train
def _eval_input_fn():
  ds_test = tfds.load(
      name='mnist',
      split='test',
      shuffle_files=True,
      as_supervised=True)

  ds_test = ds_test.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test
def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function to make_early_stopping_hook as its should_stop_fn parameter; the function must accept no arguments. Training stops as soon as should_stop_fn returns True.
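Because should_stop_fn is just a zero-argument callable, the rule it encodes does not have to be time-based. As an illustration (not part of the original guide), here is a hypothetical factory that produces a should_stop_fn requesting a stop after the hook has polled it a fixed number of times; any state the function needs must be captured in a closure, since it receives no arguments:

```python
def make_count_based_should_stop_fn(max_checks):
  # Hypothetical helper: the returned callable takes no arguments,
  # as required by make_early_stopping_hook's should_stop_fn
  # parameter, and counts how many times it has been polled.
  state = {'checks': 0}
  def should_stop_fn():
    state['checks'] += 1
    return state['checks'] > max_checks
  return should_stop_fn
```

Passing the result to make_early_stopping_hook in place of the time-based function below would stop training roughly max_checks polling intervals after it starts.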
The following example demonstrates how to implement an early stopping technique that limits training time to at most 20 seconds:
estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp1_ztd0uk
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp1_ztd0uk/model.ckpt.
INFO:tensorflow:loss = 2.3376417, step = 0
INFO:tensorflow:loss = 1.3996981, step = 100 (0.662 sec)
...
INFO:tensorflow:loss = 0.24296087, step = 7400 (0.179 sec)
INFO:tensorflow:Requesting early stopping at global step 7427
INFO:tensorflow:Saving checkpoints for 7428 into /tmpfs/tmp/tmp1_ztd0uk/model.ckpt.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:45:51
INFO:tensorflow:Finished evaluation at 2022-12-14-20:45:52
INFO:tensorflow:Saving dict for global step 7428: global_step = 7428, loss = 0.2138084
INFO:tensorflow:Loss for final step: 0.18073495.
({'loss': 0.2138084, 'global_step': 7428}, [])
TensorFlow 2: Early stopping with a built-in callback and Model.fit
Prepare the MNIST dataset and a simple Keras model:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing the built-in tf.keras.callbacks.EarlyStopping callback to the callbacks parameter of Model.fit.

The EarlyStopping callback monitors a user-specified metric and ends training when that metric stops improving. (Check the Training and evaluation with the built-in methods guide or the API docs for more details.)
Below is an example of an early stopping callback that monitors the loss and stops training once the number of epochs showing no improvement reaches 3 (patience):
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# Only around 25 epochs are run during training, instead of 100.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 4s 5ms/step - loss: 0.2293 - sparse_categorical_accuracy: 0.9324 - val_loss: 0.1224 - val_sparse_categorical_accuracy: 0.9636
...
Epoch 30/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2386 - val_sparse_categorical_accuracy: 0.9760
30
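The patience rule above can be illustrated without running any training. The function below is a simplified re-implementation of the EarlyStopping bookkeeping (best value seen so far plus a wait counter), not the actual Keras code, applied to a recorded sequence of per-epoch losses; it returns how many epochs would run before stopping:

```python
def epochs_before_stopping(losses, patience=3):
  # Simplified sketch of the EarlyStopping rule for a monitored loss:
  # track the best value seen so far and count epochs without
  # improvement; stop once that count reaches `patience`.
  best = float('inf')
  wait = 0
  for epoch, loss in enumerate(losses, start=1):
    if loss < best:
      best = loss
      wait = 0
    else:
      wait += 1
      if wait >= patience:
        return epoch  # training ends at this epoch
  return len(losses)
```

For instance, with losses [1.0, 0.8, 0.9, 0.85, 0.95] and patience=3, the best value 0.8 from epoch 2 is never beaten, so training ends at epoch 5.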
TensorFlow 2: Early stopping with a custom callback and Model.fit
You can also implement a custom early stopping callback, which can likewise be passed to the callbacks parameter of Model.fit (or Model.evaluate).
In this example, the training process stops once self.model.stop_training is set to True:
class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs):
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs):
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True
# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0130 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2281 - val_sparse_categorical_accuracy: 0.9742
...
Epoch 23/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0122 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.3227 - val_sparse_categorical_accuracy: 0.9761
23
TensorFlow 2: Early stopping with a custom training loop
In TensorFlow 2, if you are not using the built-in Keras methods for training and evaluation, you can implement early stopping in a custom training loop.
Start by using the Keras APIs to define another simple model, an optimizer, a loss function, and metrics:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
Define the parameter update function with `tf.GradientTape`, and decorate it with `@tf.function` for a speedup:
@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)
Next, write a custom training loop, where you can implement your early stopping rule manually.
The example below shows how to stop training when the validation loss does not improve over a certain number of epochs:
epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
  print("\nStart of epoch %d" % (epoch,))
  start_time = time.time()

  for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
    loss_value = train_step(x_batch_train, y_batch_train)
    if step % 200 == 0:
      print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
      print("Seen so far: %s samples" % ((step + 1) * 128))

  train_acc = train_acc_metric.result()
  train_loss = train_loss_metric.result()
  train_acc_metric.reset_states()
  train_loss_metric.reset_states()
  print("Training acc over epoch: %.4f" % (train_acc.numpy()))

  for x_batch_val, y_batch_val in ds_test:
    test_step(x_batch_val, y_batch_val)
  val_acc = val_acc_metric.result()
  val_loss = val_loss_metric.result()
  val_acc_metric.reset_states()
  val_loss_metric.reset_states()
  print("Validation acc: %.4f" % (float(val_acc),))
  print("Time taken: %.2fs" % (time.time() - start_time))

  # The early stopping strategy: stop the training if `val_loss` does not
  # decrease over a certain number of epochs.
  wait += 1
  if val_loss < best:
    best = val_loss
    wait = 0
  if wait >= patience:
    break
Start of epoch 0
Training loss at step 0: 2.3400
Seen so far: 128 samples
Training loss at step 200: 0.2146
Seen so far: 25728 samples
Training loss at step 400: 0.2027
Seen so far: 51328 samples
Training acc over epoch: 0.9319
Validation acc: 0.9621
Time taken: 2.04s

Start of epoch 1
Training loss at step 0: 0.0858
Seen so far: 128 samples
Training loss at step 200: 0.1738
Seen so far: 25728 samples
Training loss at step 400: 0.1355
Seen so far: 51328 samples
Training acc over epoch: 0.9704
Validation acc: 0.9697
Time taken: 1.09s

Start of epoch 2
Training loss at step 0: 0.0521
Seen so far: 128 samples
Training loss at step 200: 0.1280
Seen so far: 25728 samples
Training loss at step 400: 0.0969
Seen so far: 51328 samples
Training acc over epoch: 0.9786
Validation acc: 0.9715
Time taken: 1.05s

Start of epoch 3
Training loss at step 0: 0.0667
Seen so far: 128 samples
Training loss at step 200: 0.0836
Seen so far: 25728 samples
Training loss at step 400: 0.0493
Seen so far: 51328 samples
Training acc over epoch: 0.9825
Validation acc: 0.9669
Time taken: 1.05s

Start of epoch 4
Training loss at step 0: 0.0226
Seen so far: 128 samples
Training loss at step 200: 0.0580
Seen so far: 25728 samples
Training loss at step 400: 0.0669
Seen so far: 51328 samples
Training acc over epoch: 0.9859
Validation acc: 0.9689
Time taken: 1.09s

Start of epoch 5
Training loss at step 0: 0.0274
Seen so far: 128 samples
Training loss at step 200: 0.0540
Seen so far: 25728 samples
Training loss at step 400: 0.0745
Seen so far: 51328 samples
Training acc over epoch: 0.9880
Validation acc: 0.9702
Time taken: 1.06s
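The patience rule at the end of the loop (track the best `val_loss`, reset a counter on improvement, stop once the counter reaches `patience`) can also be factored into a small helper, which makes it easy to test without TensorFlow. This is a minimal framework-free sketch; the `EarlyStopper` class is hypothetical and not part of any TensorFlow API:

```python
class EarlyStopper:
  """Signals when validation loss has stopped improving.

  Mirrors the `wait`/`best`/`patience` bookkeeping of the loop above:
  `should_stop` returns True once `val_loss` has failed to improve for
  `patience` consecutive epochs.
  """

  def __init__(self, patience=5):
    self.patience = patience
    self.wait = 0
    self.best = float('inf')

  def should_stop(self, val_loss):
    self.wait += 1
    if val_loss < self.best:
      # New best value: record it and reset the counter.
      self.best = val_loss
      self.wait = 0
    return self.wait >= self.patience

# Example: losses improve twice, then plateau for five epochs.
stopper = EarlyStopper(patience=5)
losses = [0.30, 0.25, 0.28, 0.27, 0.29, 0.26, 0.31]
stop_epochs = [stopper.should_stop(l) for l in losses]
print(stop_epochs)  # → [False, False, False, False, False, False, True]
```

In the training loop, the manual bookkeeping at the bottom of each epoch could then be replaced with `if stopper.should_stop(val_loss): break`.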
Next steps
- Learn more about the Keras built-in early stopping callback API in the API docs.
- Learn how to write custom Keras callbacks, including early stopping at a minimum loss.
- Learn about training and evaluation with the Keras built-in methods.
- Explore common regularization techniques in the Overfit and underfit tutorial, which uses the `EarlyStopping` callback.