Migrating model checkpoints


Note: Checkpoints saved with tf.compat.v1.Saver are often referred to as TF1 or name-based checkpoints. Checkpoints saved with tf.train.Checkpoint are referred to as TF2 or object-based checkpoints.

Overview

This guide assumes that you have a model that saves and loads checkpoints with tf.compat.v1.Saver, and want to migrate the code to use the TF2 tf.train.Checkpoint API, or use pre-existing checkpoints in your TF2 model.

Below are some common scenarios that you may encounter:

Scenario 1

There are existing TF1 checkpoints from previous training runs that need to be loaded or converted to TF2.

Scenario 2

You are adjusting your model in a way that risks changing variable names and paths (such as when incrementally migrating away from get_variable to explicit tf.Variable creation), and would like to maintain saving/loading of existing checkpoints along the way.

Refer to the section How to maintain checkpoint compatibility during model migration.

Scenario 3

You are migrating your training code and checkpoints to TF2, but your inference pipeline continues to require TF1 checkpoints for now (for production stability).

Option 1

Save both TF1 and TF2 checkpoints when training.

Option 2

Convert the TF2 checkpoint to TF1.


The examples below show all the combinations of saving and loading checkpoints in TF1/TF2, so you have some flexibility in determining how to migrate your model.

Setup

import tensorflow as tf
import tensorflow.compat.v1 as tf1

def print_checkpoint(save_path):
  reader = tf.train.load_checkpoint(save_path)
  shapes = reader.get_variable_to_shape_map()
  dtypes = reader.get_variable_to_dtype_map()
  print(f"Checkpoint at '{save_path}':")
  for key in shapes:
    print(f"  (key='{key}', shape={shapes[key]}, dtype={dtypes[key].name}, "
          f"value={reader.get_tensor(key)})")

Changes from TF1 to TF2

Read this section if you are curious about what has changed between TF1 and TF2, and what is meant by "name-based" (TF1) versus "object-based" (TF2) checkpoints.

The two types of checkpoints are actually saved in the same format, which is essentially a key-value table. The difference lies in how the keys are generated.

The keys in name-based checkpoints are the names of the variables. The keys in object-based checkpoints refer to the path from the root object to the variable (the examples below will help to give a better sense of what this means).

First, save some checkpoints:

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    saver.save(sess, 'tf1-ckpt')

print_checkpoint('tf1-ckpt')
Checkpoint at 'tf1-ckpt':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)
a = tf.Variable(5.0, name='a')
b = tf.Variable(6.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(7.0, name='c')

ckpt = tf.train.Checkpoint(variables=[a, b, c])
save_path_v2 = ckpt.save('tf2-ckpt')
print_checkpoint(save_path_v2)
Checkpoint at 'tf2-ckpt-1':
  (key='variables/2/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=7.0)
  (key='variables/1/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=6.0)
  (key='variables/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=5.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n%\n\r\x08\x01\x12\tvariables\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n\x19\n\x05\x08\x03\x12\x010\n\x05\x08\x04\x12\x011\n\x05\x08\x05\x12\x012*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nA\x12;\n\x0eVARIABLE_VALUE\x12\x01a\x1a&variables/0/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nA\x12;\n\x0eVARIABLE_VALUE\x12\x01b\x1a&variables/1/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nH\x12B\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a&variables/2/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

If you look at the keys in tf2-ckpt, they all refer to the object paths of each variable. For example, the variable a is the first element in the variables list, so its key becomes variables/0/... (feel free to ignore the .ATTRIBUTES/VARIABLE_VALUE constant).

A closer inspection of the Checkpoint object below:

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
root = ckpt = tf.train.Checkpoint(variables=[a, b, c])
print("root type =", type(root).__name__)
print("root.variables =", root.variables)
print("root.variables[0] =", root.variables[0])
root type = Checkpoint
root.variables = ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>])
root.variables[0] = <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>

Try it out with the snippet below, and see how the checkpoint keys change with the object structure:

module = tf.Module()
module.d = tf.Variable(0.)
test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b}, 
                                c=c,
                                module=module)
test_ckpt_path = test_ckpt.save('root-tf2-ckpt')
print_checkpoint(test_ckpt_path)
Checkpoint at 'root-tf2-ckpt-1':
  (key='v/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='v/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='module/d/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='c/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n0\n\x05\x08\x01\x12\x01c\n\n\x08\x02\x12\x06module\n\x05\x08\x03\x12\x01v\n\x10\x08\x04\x12\x0csave_counter*\x02\x08\x01\n>\x128\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1cc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n\x0b\n\x05\x08\x05\x12\x01d*\x02\x08\x01\n\x12\n\x05\x08\x06\x12\x01a\n\x05\x08\x07\x12\x01b*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nE\x12?\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a#module/d/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1ev/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1ev/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

Why does TF2 use this mechanism?

Because there is no longer a global graph in TF2, variable names are unreliable and can be inconsistent between programs. TF2 encourages an object-oriented modeling approach where variables are owned by layers, and layers are owned by a model:

variable = tf.Variable(...)
layer.variable_name = variable
model.layer_name = layer
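
For example, renaming a variable does not change the checkpoint key, because the key is derived from the attribute path. Here is a minimal sketch (the module and variable below are made up for illustration):

demo = tf.Module()
demo.v = tf.Variable(3.0, name='this_name_is_not_the_key')
demo_path = tf.train.Checkpoint(demo=demo).save('demo-obj-ckpt')
# The key is 'demo/v/.ATTRIBUTES/VARIABLE_VALUE', derived from the attribute
# path (demo -> v), not from the variable's `name` argument.
print_checkpoint(demo_path)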

How to maintain checkpoint compatibility during model migration

One important step in the migration process is ensuring that all variables are initialized to the correct values, which in turn allows you to validate that the ops/functions are doing the correct computations. To accomplish this, you must consider the checkpoint compatibility between models in the various stages of migration. Essentially, this section answers the question, how do I keep using the same checkpoint while changing the model?

Below, in order of increasing flexibility, are three ways of maintaining checkpoint compatibility:

  1. The model has the same variable names as before.
  2. The model has different variable names, and maintains an assignment map that maps variable names in the checkpoint to the new names.
  3. The model has different variable names, and maintains a TF2 checkpoint object that stores all of the variables.

When variable names match

Long title: How to re-use checkpoints when variable names match.

Short answer: You can directly load the pre-existing checkpoint with either tf1.train.Saver or tf.train.Checkpoint.


If you are using tf.compat.v1.keras.utils.track_tf1_style_variables, then it will ensure that your model variable names are the same as before. You can also manually make sure that variable names match.

When the variable names in the migrated models match, you may directly use either tf.train.Checkpoint or tf.compat.v1.train.Saver to load the checkpoint. Both APIs are compatible with eager and graph mode, so you can use them at any stage of the migration.

Note: You can use tf.train.Checkpoint to load TF1 checkpoints, but you cannot use tf.compat.v1.Saver to load TF2 checkpoints without complicated name matching.

Below are examples of using the same checkpoint with different models. First, save a TF1 checkpoint with tf1.train.Saver:

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    save_path = saver.save(sess, 'tf1-ckpt')
print_checkpoint(save_path)
Checkpoint at 'tf1-ckpt':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

The example below uses tf.compat.v1.Saver to load the checkpoint in eager mode:

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# With the removal of collections in TF2, you must pass in the list of variables
# to the Saver object:
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path=save_path)
print(f"loaded values of [a, b, c]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}]")

# Saving also works in eager (sess must be None).
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
INFO:tensorflow:Restoring parameters from tf1-ckpt
loaded values of [a, b, c]:  [1.0, 2.0, 3.0]
Checkpoint at 'tf1-ckpt-saved-in-eager':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

The next snippet loads the checkpoint using the TF2 API tf.train.Checkpoint:

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# Without the name_scope, name="scoped/c" works too:
c_2 = tf.Variable(0.0, name='scoped/c')

print("Variable names: ")
print(f"  a.name = {a.name}")
print(f"  b.name = {b.name}")
print(f"  c.name = {c.name}")
print(f"  c_2.name = {c_2.name}")

# Restore the values with tf.train.Checkpoint
ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])
ckpt.restore(save_path)
print(f"loaded values of [a, b, c, c_2]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]")
Variable names: 
  a.name = a:0
  b.name = b:0
  c.name = scoped/c:0
  c_2.name = scoped/c:0
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/checkpoint/checkpoint.py:1473: NameBasedSaverStatus.__init__ (from tensorflow.python.checkpoint.checkpoint) is deprecated and will be removed in a future version.
Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
loaded values of [a, b, c, c_2]:  [1.0, 2.0, 3.0, 3.0]

Variable names in TF2

  • Variables still have a name argument you can set.
  • Keras models also take a name argument, which they set as the prefix for their variables.
  • The v1.name_scope function can be used to set variable name prefixes. This is very different from tf.variable_scope. It only affects names, and doesn't track variables or reuse (see the sketch after this list).
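
Here is a small sketch of both behaviors (the names my_dense and prefix are arbitrary), showing how Keras layer names and tf.name_scope become variable name prefixes:

layer = tf.keras.layers.Dense(3, name='my_dense')
layer.build(input_shape=(None, 2))
print(layer.kernel.name)  # my_dense/kernel:0

with tf.name_scope('prefix'):
  v = tf.Variable(1.0, name='v')
print(v.name)             # prefix/v:0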

The tf.compat.v1.keras.utils.track_tf1_style_variables decorator is a shim that helps you maintain variable names and TF1 checkpoint compatibility, by keeping the naming and reuse semantics of tf.variable_scope and tf.compat.v1.get_variable unchanged. See the model mapping guide for more info.
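
For instance, a shim-decorated layer might look like the following minimal sketch (the class, scope, and variable names here are placeholders, not part of the shim API):

class ShimmedDense(tf.keras.layers.Layer):
  def __init__(self, units, **kwargs):
    super().__init__(**kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf1.variable_scope('shimmed_dense'):
      # get_variable keeps TF1 naming/reuse semantics, so this variable is
      # named 'shimmed_dense/kernel' just as it would be in TF1.
      kernel = tf1.get_variable('kernel',
                                shape=[inputs.shape[-1], self.units])
    return tf.matmul(inputs, kernel)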

Note 1: If you are using the shim, use the TF2 APIs to load your checkpoints (even when using pre-trained TF1 checkpoints).

See the section Checkpointing Keras models.

Note 2: When migrating to tf.Variable from get_variable:

If your shim-decorated layer or module consists of some variables (or Keras layers/models) that use tf.Variable instead of tf.compat.v1.get_variable and get attached as properties/tracked in an object-oriented way, they may have different variable naming semantics in TF1.x graphs/sessions versus during eager execution.

In short, the names may not be what you expect them to be when running in TF2.

Warning: Variables may have duplicate names in eager execution, which may cause problems if multiple variables in the name-based checkpoint need to be mapped to the same name. You can explicitly adjust the layer and variable names using tf.name_scope and layer constructor or tf.Variable name arguments to tweak variable names and ensure there are no duplicates.
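
The sketch below demonstrates the pitfall: in eager execution two variables can share a name, and a tf.name_scope keeps them distinguishable:

v1_dup = tf.Variable(1.0, name='v')
v2_dup = tf.Variable(2.0, name='v')
print(v1_dup.name, v2_dup.name)  # v:0 v:0 -- duplicates are allowed in eager

with tf.name_scope('tower_1'):
  v3 = tf.Variable(3.0, name='v')
print(v3.name)  # tower_1/v:0 -- the scope makes the name unique again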

Maintaining assignment maps

Assignment maps are commonly used to transfer weights between TF1 models, and can also be used during your model migration if the variable names change.

You can use these maps with tf.compat.v1.train.init_from_checkpoint, tf.compat.v1.train.Saver, and tf.train.load_checkpoint to load weights into models in which the variable or scope names may have changed.

The examples in this section will use a previously saved checkpoint:

print_checkpoint('tf1-ckpt')
Checkpoint at 'tf1-ckpt':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

Loading with init_from_checkpoint

tf1.train.init_from_checkpoint must be called while in a graph/session, because it places the values in the variable initializers instead of creating assign ops.

You can use the assignment_map argument to configure how the variables are loaded. From the documentation:

Assignment map supports the following syntax:

  • 'checkpoint_scope_name/': 'scope_name/' - will load all variables in the current scope_name from checkpoint_scope_name with matching tensor names.
  • 'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name' - will initialize scope_name/variable_name from checkpoint_scope_name/some_other_variable.
  • 'scope_variable_name': variable - will initialize the given tf.Variable object with the tensor 'scope_variable_name' from the checkpoint.
  • 'scope_variable_name': list(variable) - will initialize a list of partitioned variables with the tensor 'scope_variable_name' from the checkpoint.
  • '/': 'scope_name/' - will load all variables in the current scope_name from the checkpoint's root (e.g. no scope).

# Restoring with tf1.train.init_from_checkpoint:

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # The assignment map will remap all variables in the checkpoint to the
    # new scope:
    tf1.train.init_from_checkpoint(
        'tf1-ckpt',
        assignment_map={'/': 'new_scope/'})
    # `init_from_checkpoint` adds the initializers to these variables.
    # Use `sess.run` to run these initializers.
    sess.run(tf1.global_variables_initializer())

    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Restored [a, b, c]:  [1.0, 2.0, 3.0]

Loading with tf1.train.Saver

Unlike init_from_checkpoint, tf.compat.v1.train.Saver runs in both graph and eager mode. The var_list argument optionally accepts a dictionary, but it must map variable names to tf.Variable objects.

# Restoring with tf1.train.Saver (works in both graph and eager):

# A new model with a different scope for the variables.
with tf1.variable_scope('new_scope'):
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
# Initialize the saver with a dictionary with the original variable names:
saver = tf1.train.Saver({'a': a, 'b': b, 'scoped/c': c})
saver.restore(sess=None, save_path='tf1-ckpt')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
INFO:tensorflow:Restoring parameters from tf1-ckpt
Restored [a, b, c]:  [1.0, 2.0, 3.0]

Loading with tf.train.load_checkpoint

This option is for you if you need precise control over the variable values. Again, this works in both graph and eager modes.

# Restoring with tf.train.load_checkpoint (works in both graph and eager):

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # It may be easier writing a loop if your model has a lot of variables.
    reader = tf.train.load_checkpoint('tf1-ckpt')
    sess.run(a.assign(reader.get_tensor('a')))
    sess.run(b.assign(reader.get_tensor('b')))
    sess.run(c.assign(reader.get_tensor('scoped/c')))
    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Restored [a, b, c]:  [1.0, 2.0, 3.0]

Maintaining TF2 checkpoint objects

If the variable and scope names may change drastically during the migration, use tf.train.Checkpoint and TF2 checkpoints. TF2 uses the object structure instead of variable names (more details in Changes from TF1 to TF2).

In short, when creating a tf.train.Checkpoint to save or restore checkpoints, make sure it uses the same ordering (for lists) and keys (for dictionaries and keyword arguments to the Checkpoint initializer). Some examples of checkpoint compatibility:

ckpt = tf.train.Checkpoint(foo=[var_a, var_b])

# compatible with ckpt
tf.train.Checkpoint(foo=[var_a, var_b])

# not compatible with ckpt
tf.train.Checkpoint(foo=[var_b, var_a])
tf.train.Checkpoint(bar=[var_a, var_b])

The code samples below show how to use the "same" tf.train.Checkpoint to load variables with different names. First, save a TF2 checkpoint:

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("[a, b, c]: ", sess.run([a, b, c]))

    # Save a TF2 checkpoint
    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    tf2_ckpt_path = ckpt.save('tf2-ckpt')
    print_checkpoint(tf2_ckpt_path)
[a, b, c]:  [1.0, 2.0, 3.0]
Checkpoint at 'tf2-ckpt-1':
  (key='unscoped/1/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='unscoped/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='scoped/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n0\n\n\x08\x01\x12\x06scoped\n\x0c\x08\x02\x12\x08unscoped\n\x10\x08\x03\x12\x0csave_counter*\x02\x08\x01\n\x0b\n\x05\x08\x04\x12\x010*\x02\x08\x01\n\x12\n\x05\x08\x05\x12\x010\n\x05\x08\x06\x12\x011*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nE\x12?\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a#scoped/0/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%unscoped/0/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%unscoped/1/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

You can keep using tf.train.Checkpoint even if the variable/scope names change:

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.variable_scope('different_scope'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))

    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    # `assert_consumed` validates that all checkpoint objects are restored from
    # the checkpoint. `run_restore_ops` is required when running in a TF1
    # session.
    ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()

    # Removing `assert_consumed` is fine if you want to skip the validation.
    # ckpt.restore(tf2_ckpt_path).run_restore_ops()

    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
Restored [a, b, c]:  [1.0, 2.0, 3.0]

And in eager mode:

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

# The keys "scoped" and "unscoped" are no longer relevant, but are used to
# maintain compatibility with the saved checkpoints.
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])

ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
Restored [a, b, c]:  [1.0, 2.0, 3.0]

TF2 checkpoints in Estimator

The sections above describe how to maintain checkpoint compatibility while migrating your model. These concepts also apply to Estimator models, although the way the checkpoint is saved/loaded is slightly different. As you migrate your Estimator model to use TF2 APIs, you may want to switch from TF1 to TF2 checkpoints while the model is still using the Estimator. This section shows how to do so.

tf.estimator.Estimator and MonitoredSession have a saving mechanism called the scaffold, a tf.compat.v1.train.Scaffold object. The Scaffold can contain a tf1.train.Saver or tf.train.Checkpoint, which enables Estimator and MonitoredSession to save TF1-style or TF2-style checkpoints.

# A model_fn that saves a TF1 checkpoint
def model_fn_tf1_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=None
  )

!rm -rf est-tf1
est = tf.estimator.Estimator(model_fn_tf1_ckpt, 'est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf1')
print_checkpoint(latest_checkpoint)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'est-tf1', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into est-tf1/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into est-tf1/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:Loss for final step: 1.0.
Checkpoint at 'est-tf1/model.ckpt-1':
  (key='var', shape=[], dtype=float32, value=2.0)
  (key='global_step', shape=[], dtype=int64, value=1)
# A model_fn that saves a TF2 checkpoint
def model_fn_tf2_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  ckpt = tf.train.Checkpoint(var_list={'var': v}, step=train_step)
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=tf1.train.Scaffold(saver=ckpt)
  )

!rm -rf est-tf2
est = tf.estimator.Estimator(model_fn_tf2_ckpt, 'est-tf2',
                             warm_start_from='est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf2')
print_checkpoint(latest_checkpoint)  

assert est.get_variable_value('var_list/var/.ATTRIBUTES/VARIABLE_VALUE') == 4
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'est-tf2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='est-tf1', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: est-tf1
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 1 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into est-tf2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into est-tf2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:Loss for final step: 1.0.
Checkpoint at 'est-tf2/model.ckpt-1':
  (key='var_list/var/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=4.0)
  (key='step/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n\x1c\n\x08\x08\x01\x12\x04step\n\x0c\x08\x02\x12\x08var_list*\x02\x08\x01\nD\x12>\n\x0eVARIABLE_VALUE\x12\x0bglobal_step\x1a\x1fstep/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n\r\n\x07\x08\x03\x12\x03var*\x02\x08\x01\nD\x12>\n\x0eVARIABLE_VALUE\x12\x03var\x1a'var_list/var/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

After warm-starting from est-tf1 and training for one more step, the final value of v is 4 (it starts at 2 from the warm start, and the single train step adds 2, matching the assert above). The train step value does not carry over from the warm_start checkpoint.

Checkpointing Keras models

Models built with Keras still use tf1.train.Saver and tf.train.Checkpoint to load pre-existing weights. When your model is fully migrated, switch to using model.save_weights and model.load_weights, especially if you are using the ModelCheckpoint callback when training.
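
For example, a fully migrated workflow might look like this minimal sketch (the model architecture, file prefixes, and the x/y training data are placeholders):

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')
model.save_weights('keras-ckpt')  # writes a TF2 object-based checkpoint
model.load_weights('keras-ckpt')

# To save automatically during training, use the ModelCheckpoint callback:
# model.fit(x, y, callbacks=[tf.keras.callbacks.ModelCheckpoint(
#     'keras-ckpt', save_weights_only=True)])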

Some things to know about checkpoints and Keras:

Initialization vs Building

Keras models and layers must go through two steps before being fully created. First is the initialization of the Python object: layer = tf.keras.layers.Dense(x). Second is the build step, in which most of the weights are actually created: layer.build(input_shape). You can also build a model by calling it or running a single train, eval, or predict step (the first time only).
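
A minimal sketch of the two steps with a single Dense layer:

layer = tf.keras.layers.Dense(3)        # Step 1: initialize the Python object
print(layer.weights)                    # [] -- no variables created yet
layer.build(input_shape=(None, 5))      # Step 2: build; creates the weights
print([v.name for v in layer.weights])  # e.g. ['dense/kernel:0', 'dense/bias:0']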

If you find that model.load_weights(path).assert_consumed() is raising an error, it is likely that the model/layers have not been built.

Keras uses TF2 checkpoints

tf.train.Checkpoint(model).write is equivalent to model.save_weights. Same with tf.train.Checkpoint(model).read and model.load_weights. Note that Checkpoint(model) != Checkpoint(model=model).
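
A minimal sketch of the difference (the model and file prefixes are arbitrary):

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])

# These two write equivalent object-based checkpoints:
tf.train.Checkpoint(model).write('keras-ckpt-root')
model.save_weights('keras-ckpt-weights')

# This one nests every key under a 'model/' prefix, so it is NOT
# interchangeable with the two calls above:
tf.train.Checkpoint(model=model).write('keras-ckpt-attr')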

TF2 checkpoints work with Keras's build() step

tf.train.Checkpoint.restore has a mechanism called deferred restoration, which allows tf.Module and Keras objects to store variable values if the variable has not yet been created. This allows initialized models to load weights and build after.

m = YourKerasModel()
status = m.load_weights(path)

# This call builds the model. The variables are created with the restored
# values.
m.predict(inputs)

status.assert_consumed()

Because of this mechanism, we highly recommend that you use the TF2 checkpoint loading APIs with Keras models (even when restoring pre-existing TF1 checkpoints into the model mapping shims). See more in the checkpoint guide.

Code snippets

The snippets below show the TF1/TF2 version compatibility in the checkpoint saving APIs.

Save a TF1 checkpoint in TF2

a = tf.Variable(1.0, name='a')
b = tf.Variable(2.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(3.0, name='c')

saver = tf1.train.Saver(var_list=[a, b, c])
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
Checkpoint at 'tf1-ckpt-saved-in-eager':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

Load a TF1 checkpoint in TF2

a = tf.Variable(0., name='a')
b = tf.Variable(0., name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0., name='c')
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path='tf1-ckpt-saved-in-eager')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
INFO:tensorflow:Restoring parameters from tf1-ckpt-saved-in-eager
Restored [a, b, c]:  [1.0, 2.0, 3.0]

Save a TF2 checkpoint in TF1

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    tf2_in_tf1_path = ckpt.save('tf2-ckpt-saved-in-session')
    print_checkpoint(tf2_in_tf1_path)
Checkpoint at 'tf2-ckpt-saved-in-session-1':
  (key='var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='var_list/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='var_list/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n$\n\x0c\x08\x01\x12\x08var_list\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n \n\x05\x08\x03\x12\x01a\n\x05\x08\x04\x12\x01b\n\x0c\x08\x05\x12\x08scoped/c*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%var_list/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%var_list/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nO\x12I\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a-var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

Load a TF2 checkpoint in TF1

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(0))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    ckpt.restore('tf2-ckpt-saved-in-session-1').run_restore_ops()
    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
Restored [a, b, c]:  [1.0, 2.0, 3.0]

Checkpoint conversion

You can convert checkpoints between TF1 and TF2 by loading and re-saving the checkpoints. An alternative is tf.train.load_checkpoint, shown in the code below.

Convert a TF1 checkpoint to TF2

def convert_tf1_to_tf2(checkpoint_path, output_prefix):
  """Converts a TF1 checkpoint to TF2.

  To load the converted checkpoint, you must build a dictionary that maps
  variable names to variable objects.
  ```
  ckpt = tf.train.Checkpoint(vars={name: variable})
  ckpt.restore(converted_ckpt_path)
  ```

  Args:
    checkpoint_path: Path to the TF1 checkpoint.
    output_prefix: Path prefix to the converted checkpoint.

  Returns:
    Path to the converted checkpoint.
  """
  vars = {}
  reader = tf.train.load_checkpoint(checkpoint_path)
  dtypes = reader.get_variable_to_dtype_map()
  for key in dtypes.keys():
    vars[key] = tf.Variable(reader.get_tensor(key))
  return tf.train.Checkpoint(vars=vars).save(output_prefix)

Convert the checkpoint saved in the snippet `Save a TF1 checkpoint in TF2`:

# Make sure to run the snippet in `Save a TF1 checkpoint in TF2`.
print_checkpoint('tf1-ckpt-saved-in-eager')
converted_path = convert_tf1_to_tf2('tf1-ckpt-saved-in-eager', 
                                     'converted-tf1-to-tf2')
print("\n[Converted]")
print_checkpoint(converted_path)

# Try loading the converted checkpoint.
a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
ckpt = tf.train.Checkpoint(vars={'a': a, 'b': b, 'scoped/c': c})
ckpt.restore(converted_path).assert_consumed()
print("\nRestored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Checkpoint at 'tf1-ckpt-saved-in-eager':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

[Converted]
Checkpoint at 'converted-tf1-to-tf2-1':
  (key='vars/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='vars/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='vars/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n \n\x08\x08\x01\x12\x04vars\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n \n\x0c\x08\x03\x12\x08scoped/c\n\x05\x08\x04\x12\x01b\n\x05\x08\x05\x12\x01a*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nK\x12E\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a)vars/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nC\x12=\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a!vars/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nC\x12=\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a!vars/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

Restored [a, b, c]:  [1.0, 2.0, 3.0]

Convert a TF2 checkpoint to TF1

def convert_tf2_to_tf1(checkpoint_path, output_prefix):
  """Converts a TF2 checkpoint to TF1.

  The checkpoint must be saved using a
  `tf.train.Checkpoint(var_list={name: variable})`

  To load the converted checkpoint with `tf.compat.v1.Saver`:
  ```
  saver = tf.compat.v1.train.Saver(var_list={name: variable})

  # An alternative, if the variable names match the keys:
  saver = tf.compat.v1.train.Saver(var_list=[variables])
  saver.restore(sess, output_path)
  ```
  """
  vars = {}
  reader = tf.train.load_checkpoint(checkpoint_path)
  dtypes = reader.get_variable_to_dtype_map()
  for key in dtypes.keys():
    # Get the variable name from the checkpoint key.
    if key.startswith('var_list/'):
      var_name = key.split('/')[1]
      # TF2 checkpoint keys use '/', so if '/' appears in a user-defined
      # variable name it is escaped to '.S'. Undo that escaping here.
      var_name = var_name.replace('.S', '/')
      vars[var_name] = tf.Variable(reader.get_tensor(key))

  return tf1.train.Saver(var_list=vars).save(sess=None, save_path=output_prefix)

Convert the checkpoint saved in the snippet `Save a TF2 checkpoint in TF1`:

# Make sure to run the snippet in `Save a TF2 checkpoint in TF1`.
print_checkpoint('tf2-ckpt-saved-in-session-1')
converted_path = convert_tf2_to_tf1('tf2-ckpt-saved-in-session-1',
                                    'converted-tf2-to-tf1')
print("\n[Converted]")
print_checkpoint(converted_path)

# Try loading the converted checkpoint.
with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(0))
  with tf1.Session() as sess:
    saver = tf1.train.Saver([a, b, c])
    saver.restore(sess, converted_path)
    print("\nRestored [a, b, c]: ", sess.run([a, b, c]))
Checkpoint at 'tf2-ckpt-saved-in-session-1':
  (key='var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='var_list/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='var_list/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n$\n\x0c\x08\x01\x12\x08var_list\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n \n\x05\x08\x03\x12\x01a\n\x05\x08\x04\x12\x01b\n\x0c\x08\x05\x12\x08scoped/c*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%var_list/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%var_list/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nO\x12I\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a-var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.

[Converted]
Checkpoint at 'converted-tf2-to-tf1':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)
INFO:tensorflow:Restoring parameters from converted-tf2-to-tf1

Restored [a, b, c]:  [1.0, 2.0, 3.0]

Related guides