Learned data compression


Overview

This notebook shows how to do lossy data compression using neural networks and TensorFlow Compression.

Lossy compression involves making a trade-off between rate, the expected number of bits needed to encode a sample, and distortion, the expected error in the reconstruction of the sample.

The examples below use an autoencoder-like model to compress images from the MNIST dataset. The approach is based on the paper End-to-end Optimized Image Compression.

For more background on learned data compression, see this paper aimed at readers familiar with classical data compression, or this survey targeted at a machine learning audience.

Setup

Install TensorFlow Compression via pip.

# Installs the latest version of TFC compatible with the installed TF version.

read MAJOR MINOR <<< "$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\d+)\.(\d+).*/\1 \2/sg')"
pip install "tensorflow-compression<$MAJOR.$(($MINOR+1))"

Import library dependencies.

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_datasets as tfds
2023-11-07 19:11:50.603526: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-07 19:11:50.603572: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-07 19:11:50.603609: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

Define the trainer model

Because the model resembles an autoencoder, and we need to perform a different set of functions during training and inference, the setup is a little different from, say, a classifier.

The training model consists of three parts:

  • the analysis (or encoder) transform, converting from the image into a latent space,
  • the synthesis (or decoder) transform, converting from the latent space back into image space, and
  • a prior and entropy model, modeling the marginal probabilities of the latents.

First, define the transforms:

def make_analysis_transform(latent_dims):
  """Creates the analysis (encoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2D(
          50, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          latent_dims, use_bias=True, activation=None, name="fc_2"),
  ], name="analysis_transform")
def make_synthesis_transform():
  """Creates the synthesis (decoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          2450, use_bias=True, activation="leaky_relu", name="fc_2"),
      tf.keras.layers.Reshape((7, 7, 50)),
      tf.keras.layers.Conv2DTranspose(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2DTranspose(
          1, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
  ], name="synthesis_transform")

The trainer holds an instance of both transforms, as well as the parameters of the prior.

Its call method is set up to compute:

  • rate, an estimate of the number of bits needed to represent the batch of digits, and
  • distortion, the mean absolute difference between the pixels of the original digits and their reconstructions.

class MNISTCompressionTrainer(tf.keras.Model):
  """Model that trains a compressor/decompressor for MNIST."""

  def __init__(self, latent_dims):
    super().__init__()
    self.analysis_transform = make_analysis_transform(latent_dims)
    self.synthesis_transform = make_synthesis_transform()
    self.prior_log_scales = tf.Variable(tf.zeros((latent_dims,)))

  @property
  def prior(self):
    return tfc.NoisyLogistic(loc=0., scale=tf.exp(self.prior_log_scales))

  def call(self, x, training):
    """Computes rate and distortion losses."""
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    x = tf.reshape(x, (-1, 28, 28, 1))

    # Compute latent space representation y, perturb it and model its entropy,
    # then compute the reconstructed pixel-level representation x_hat.
    y = self.analysis_transform(x)
    entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.prior, coding_rank=1, compression=False)
    y_tilde, rate = entropy_model(y, training=training)
    x_tilde = self.synthesis_transform(y_tilde)

    # Average number of bits per MNIST digit.
    rate = tf.reduce_mean(rate)

    # Mean absolute difference across pixels.
    distortion = tf.reduce_mean(abs(x - x_tilde))

    return dict(rate=rate, distortion=distortion)

Compute rate and distortion

Let's walk through this step by step, using one image from the validation set. Load the MNIST dataset for training and validation:

training_dataset, validation_dataset = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=False,
)
2023-11-07 19:11:54.590926: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...

Extract one image \(x\) from it:

(x, _), = validation_dataset.take(1)

plt.imshow(tf.squeeze(x))
print(f"Data type: {x.dtype}")
print(f"Shape: {x.shape}")
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)
2023-11-07 19:11:54.882247: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

[Figure: the extracted MNIST digit]

To get the latent representation \(y\), we need to cast the image to float32, add a batch dimension, and pass it through the analysis transform.

x = tf.cast(x, tf.float32) / 255.
x = tf.reshape(x, (-1, 28, 28, 1))
y = make_analysis_transform(10)(x)

print("y:", y)
y: tf.Tensor(
[[-0.05629363  0.07482004 -0.12003498  0.01332713  0.03305333 -0.03994709
   0.03772307  0.06138449  0.03466422 -0.02930463]], shape=(1, 10), dtype=float32)

The latents will be quantized at test time. To model this in a differentiable way during training, we add uniform noise in the interval \((-.5, .5)\) and call the result \(\tilde y\). This is the same terminology as used in the paper End-to-end Optimized Image Compression.

y_tilde = y + tf.random.uniform(y.shape, -.5, .5)

print("y_tilde:", y_tilde)
y_tilde: tf.Tensor(
[[-0.144532   -0.1691755   0.2158155  -0.48123586 -0.10522515 -0.3356318
  -0.29996884  0.14784485  0.33987206 -0.16194786]], shape=(1, 10), dtype=float32)

The "prior" is a probability density that we train to model the marginal distribution of the noisy latents. For example, it could be a set of independent logistic distributions with a different scale for each latent dimension. tfc.NoisyLogistic accounts for the fact that the latents have additive noise. As its scale approaches zero, a logistic distribution approaches a Dirac delta (spike), but the added noise causes the "noisy" distribution to instead approach the uniform distribution.

prior = tfc.NoisyLogistic(loc=0., scale=tf.linspace(.01, 2., 10))

_ = tf.linspace(-6., 6., 501)[:, None]
plt.plot(_, prior.prob(_));

[Figure: densities of the noisy logistic prior for a range of scales]

During training, tfc.ContinuousBatchedEntropyModel adds uniform noise, and uses the noise and the prior to compute a (differentiable) upper bound on the rate (the average number of bits needed to encode the latent representation). That bound can be minimized as a loss.

entropy_model = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=False)
y_tilde, rate = entropy_model(y, training=True)

print("rate:", rate)
print("y_tilde:", y_tilde)
rate: tf.Tensor([18.3915], shape=(1,), dtype=float32)
y_tilde: tf.Tensor(
[[-0.45884362 -0.30689514 -0.10407083 -0.44574344  0.329508   -0.49761873
  -0.15345061  0.4179291  -0.34433204  0.07495283]], shape=(1, 10), dtype=float32)

Lastly, the noisy latents are passed back through the synthesis transform to produce an image reconstruction \(\tilde x\). Distortion is the error between the original image and the reconstruction. Obviously, with the transforms untrained, the reconstruction is not very useful.

x_tilde = make_synthesis_transform()(y_tilde)

# Mean absolute difference across pixels.
distortion = tf.reduce_mean(abs(x - x_tilde))
print("distortion:", distortion)

x_tilde = tf.saturate_cast(x_tilde[0] * 255, tf.uint8)
plt.imshow(tf.squeeze(x_tilde))
print(f"Data type: {x_tilde.dtype}")
print(f"Shape: {x_tilde.shape}")
distortion: tf.Tensor(0.17111982, shape=(), dtype=float32)
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)

[Figure: reconstruction of the digit by the untrained transforms]

For every batch of digits, calling the MNISTCompressionTrainer produces the rate and distortion, averaged over that batch:

(example_batch, _), = validation_dataset.batch(32).take(1)
trainer = MNISTCompressionTrainer(10)
example_output = trainer(example_batch)

print("rate: ", example_output["rate"])
print("distortion: ", example_output["distortion"])
rate:  tf.Tensor(20.296253, shape=(), dtype=float32)
distortion:  tf.Tensor(0.14659302, shape=(), dtype=float32)
2023-11-07 19:11:55.744184: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

In the next section, we set up the model to do gradient descent on these two losses.

Train the model

We compile the trainer in a way that it optimizes the rate–distortion Lagrangian, that is, the sum of rate and distortion, where one of them is weighted by the Lagrange parameter \(\lambda\).
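
Concretely, with the loss weights used below, the objective being minimized is \( \text{loss} = \text{rate} + \lambda \cdot \text{distortion} \): larger values of \(\lambda\) put more weight on reconstruction quality, at the cost of a higher bit rate.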

This loss function affects the different parts of the model differently:

  • The analysis transform is trained to produce a latent representation that achieves the desired trade-off between rate and distortion.
  • The synthesis transform is trained to minimize distortion, given the latent representation.
  • The parameters of the prior are trained to minimize the rate given the latent representation. This is identical to fitting the prior to the marginal distribution of the latents in a maximum-likelihood sense.

def pass_through_loss(_, x):
  # Since rate and distortion are unsupervised, the loss doesn't need a target.
  return x

def make_mnist_compression_trainer(lmbda, latent_dims=50):
  trainer = MNISTCompressionTrainer(latent_dims)
  trainer.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # Just pass through rate and distortion as losses/metrics.
    loss=dict(rate=pass_through_loss, distortion=pass_through_loss),
    metrics=dict(rate=pass_through_loss, distortion=pass_through_loss),
    loss_weights=dict(rate=1., distortion=lmbda),
  )
  return trainer

Next, train the model. The human annotations are not necessary here, since we only want to compress the images, so we drop them via a map and instead add "dummy" targets for rate and distortion.

def add_rd_targets(image, label):
  # Training is unsupervised, so labels aren't necessary here. However, we
  # need to add "dummy" targets for rate and distortion.
  return image, dict(rate=0., distortion=0.)

def train_mnist_model(lmbda):
  trainer = make_mnist_compression_trainer(lmbda)
  trainer.fit(
      training_dataset.map(add_rd_targets).batch(128).prefetch(8),
      epochs=15,
      validation_data=validation_dataset.map(add_rd_targets).batch(128).cache(),
      validation_freq=1,
      verbose=1,
  )
  return trainer

trainer = train_mnist_model(lmbda=2000)
Epoch 1/15
468/469 [============================>.] - ETA: 0s - loss: 220.8049 - distortion_loss: 0.0604 - rate_loss: 99.9450 - distortion_pass_through_loss: 0.0604 - rate_pass_through_loss: 99.9450
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 14s 23ms/step - loss: 220.7253 - distortion_loss: 0.0604 - rate_loss: 99.9313 - distortion_pass_through_loss: 0.0604 - rate_pass_through_loss: 99.9268 - val_loss: 178.0563 - val_distortion_loss: 0.0427 - val_rate_loss: 92.6118 - val_distortion_pass_through_loss: 0.0427 - val_rate_pass_through_loss: 92.6175
Epoch 2/15
469/469 [==============================] - 10s 20ms/step - loss: 167.0754 - distortion_loss: 0.0414 - rate_loss: 84.3045 - distortion_pass_through_loss: 0.0414 - rate_pass_through_loss: 84.3002 - val_loss: 156.9387 - val_distortion_loss: 0.0404 - val_rate_loss: 76.0390 - val_distortion_pass_through_loss: 0.0404 - val_rate_pass_through_loss: 76.0476
Epoch 3/15
469/469 [==============================] - 9s 20ms/step - loss: 151.9107 - distortion_loss: 0.0402 - rate_loss: 71.5396 - distortion_pass_through_loss: 0.0402 - rate_pass_through_loss: 71.5365 - val_loss: 144.9222 - val_distortion_loss: 0.0404 - val_rate_loss: 64.1234 - val_distortion_pass_through_loss: 0.0404 - val_rate_pass_through_loss: 64.1224
Epoch 4/15
469/469 [==============================] - 10s 20ms/step - loss: 142.9746 - distortion_loss: 0.0398 - rate_loss: 63.3548 - distortion_pass_through_loss: 0.0398 - rate_pass_through_loss: 63.3520 - val_loss: 136.4569 - val_distortion_loss: 0.0406 - val_rate_loss: 55.2759 - val_distortion_pass_through_loss: 0.0406 - val_rate_pass_through_loss: 55.2598
Epoch 5/15
469/469 [==============================] - 9s 20ms/step - loss: 137.2784 - distortion_loss: 0.0394 - rate_loss: 58.3862 - distortion_pass_through_loss: 0.0394 - rate_pass_through_loss: 58.3849 - val_loss: 131.0564 - val_distortion_loss: 0.0410 - val_rate_loss: 48.9617 - val_distortion_pass_through_loss: 0.0411 - val_rate_pass_through_loss: 48.9352
Epoch 6/15
469/469 [==============================] - 9s 20ms/step - loss: 133.4307 - distortion_loss: 0.0391 - rate_loss: 55.3152 - distortion_pass_through_loss: 0.0391 - rate_pass_through_loss: 55.3141 - val_loss: 127.2503 - val_distortion_loss: 0.0409 - val_rate_loss: 45.5429 - val_distortion_pass_through_loss: 0.0409 - val_rate_pass_through_loss: 45.5313
Epoch 7/15
469/469 [==============================] - 9s 20ms/step - loss: 130.2906 - distortion_loss: 0.0385 - rate_loss: 53.1917 - distortion_pass_through_loss: 0.0385 - rate_pass_through_loss: 53.1902 - val_loss: 125.3515 - val_distortion_loss: 0.0415 - val_rate_loss: 42.2906 - val_distortion_pass_through_loss: 0.0415 - val_rate_pass_through_loss: 42.2763
Epoch 8/15
469/469 [==============================] - 9s 20ms/step - loss: 127.7567 - distortion_loss: 0.0382 - rate_loss: 51.4476 - distortion_pass_through_loss: 0.0382 - rate_pass_through_loss: 51.4464 - val_loss: 121.5973 - val_distortion_loss: 0.0399 - val_rate_loss: 41.8566 - val_distortion_pass_through_loss: 0.0399 - val_rate_pass_through_loss: 41.8531
Epoch 9/15
469/469 [==============================] - 9s 20ms/step - loss: 125.2723 - distortion_loss: 0.0377 - rate_loss: 49.9508 - distortion_pass_through_loss: 0.0377 - rate_pass_through_loss: 49.9505 - val_loss: 118.3108 - val_distortion_loss: 0.0381 - val_rate_loss: 42.1530 - val_distortion_pass_through_loss: 0.0381 - val_rate_pass_through_loss: 42.1452
Epoch 10/15
469/469 [==============================] - 9s 20ms/step - loss: 123.1027 - distortion_loss: 0.0372 - rate_loss: 48.6255 - distortion_pass_through_loss: 0.0372 - rate_pass_through_loss: 48.6245 - val_loss: 117.7949 - val_distortion_loss: 0.0385 - val_rate_loss: 40.8236 - val_distortion_pass_through_loss: 0.0385 - val_rate_pass_through_loss: 40.8248
Epoch 11/15
469/469 [==============================] - 9s 20ms/step - loss: 121.0393 - distortion_loss: 0.0368 - rate_loss: 47.4335 - distortion_pass_through_loss: 0.0368 - rate_pass_through_loss: 47.4328 - val_loss: 115.8072 - val_distortion_loss: 0.0374 - val_rate_loss: 40.9586 - val_distortion_pass_through_loss: 0.0374 - val_rate_pass_through_loss: 40.9627
Epoch 12/15
469/469 [==============================] - 9s 20ms/step - loss: 119.2975 - distortion_loss: 0.0364 - rate_loss: 46.5058 - distortion_pass_through_loss: 0.0364 - rate_pass_through_loss: 46.5050 - val_loss: 113.8450 - val_distortion_loss: 0.0366 - val_rate_loss: 40.6510 - val_distortion_pass_through_loss: 0.0366 - val_rate_pass_through_loss: 40.6602
Epoch 13/15
469/469 [==============================] - 9s 20ms/step - loss: 117.8968 - distortion_loss: 0.0361 - rate_loss: 45.7277 - distortion_pass_through_loss: 0.0361 - rate_pass_through_loss: 45.7267 - val_loss: 113.4282 - val_distortion_loss: 0.0366 - val_rate_loss: 40.1531 - val_distortion_pass_through_loss: 0.0366 - val_rate_pass_through_loss: 40.1618
Epoch 14/15
469/469 [==============================] - 9s 20ms/step - loss: 116.6584 - distortion_loss: 0.0358 - rate_loss: 45.0070 - distortion_pass_through_loss: 0.0358 - rate_pass_through_loss: 45.0067 - val_loss: 112.4899 - val_distortion_loss: 0.0355 - val_rate_loss: 41.4291 - val_distortion_pass_through_loss: 0.0355 - val_rate_pass_through_loss: 41.4347
Epoch 15/15
469/469 [==============================] - 9s 20ms/step - loss: 115.5233 - distortion_loss: 0.0355 - rate_loss: 44.4816 - distortion_pass_through_loss: 0.0355 - rate_pass_through_loss: 44.4811 - val_loss: 111.9730 - val_distortion_loss: 0.0355 - val_rate_loss: 41.0487 - val_distortion_pass_through_loss: 0.0355 - val_rate_pass_through_loss: 41.0486

Compress some MNIST images

For compression and decompression at test time, we split the trained model in two parts:

  • The encoder side consists of the analysis transform and the entropy model.
  • The decoder side consists of the synthesis transform and the same entropy model.

At test time, the latents will not have additive noise, but they will be quantized and then losslessly compressed, so we give them new names. We call the quantized latents \(\hat y\) and the image reconstruction \(\hat x\) (following End-to-end Optimized Image Compression).

class MNISTCompressor(tf.keras.Model):
  """Compresses MNIST images to strings."""

  def __init__(self, analysis_transform, entropy_model):
    super().__init__()
    self.analysis_transform = analysis_transform
    self.entropy_model = entropy_model

  def call(self, x):
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    y = self.analysis_transform(x)
    # Also return the exact information content of each digit.
    _, bits = self.entropy_model(y, training=False)
    return self.entropy_model.compress(y), bits
class MNISTDecompressor(tf.keras.Model):
  """Decompresses MNIST images from strings."""

  def __init__(self, entropy_model, synthesis_transform):
    super().__init__()
    self.entropy_model = entropy_model
    self.synthesis_transform = synthesis_transform

  def call(self, string):
    y_hat = self.entropy_model.decompress(string, ())
    x_hat = self.synthesis_transform(y_hat)
    # Scale and cast back to 8-bit integer.
    return tf.saturate_cast(tf.round(x_hat * 255.), tf.uint8)

When instantiated with compression=True, the entropy model converts the learned prior into tables for a range coding algorithm. Calling compress() invokes this algorithm to convert the latent vector into bit sequences. The length of each binary string approximates the information content of the latent (the negative log likelihood of the latent under the prior).
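
As a rough sanity check (a sketch that ignores the quantization offsets the entropy model may apply): since tfc.NoisyLogistic is the base logistic convolved with unit-width uniform noise, its density at an integer equals the probability mass of the surrounding unit interval under the base distribution. Rounding the latents and evaluating the trained prior therefore approximates the information content in bits:

# Illustrative approximation only; `x` is the normalized image from the
# "Compute rate and distortion" section above, and quantization offsets
# are ignored by rounding to the nearest integer.
y = trainer.analysis_transform(x)
y_hat = tf.round(y)
approx_bits = -tf.reduce_sum(
    tf.math.log(trainer.prior.prob(y_hat)), axis=-1) / tf.math.log(2.)
print("approximate information content:", approx_bits.numpy())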

The entropy model used for compression and decompression must be the same instance, because the range coding tables need to be exactly identical on both sides. Otherwise, decoding errors can occur.

def make_mnist_codec(trainer, **kwargs):
  # The entropy model must be created with `compression=True` and the same
  # instance must be shared between compressor and decompressor.
  entropy_model = tfc.ContinuousBatchedEntropyModel(
      trainer.prior, coding_rank=1, compression=True, **kwargs)
  compressor = MNISTCompressor(trainer.analysis_transform, entropy_model)
  decompressor = MNISTDecompressor(entropy_model, trainer.synthesis_transform)
  return compressor, decompressor

compressor, decompressor = make_mnist_codec(trainer)

Grab 16 images from the validation dataset. You can select a different subset by changing the argument to skip.

(originals, _), = validation_dataset.batch(16).skip(3).take(1)

Compress them to strings, and keep track of each of their information content in bits.

strings, entropies = compressor(originals)

print(f"String representation of first digit in hexadecimal: 0x{strings[0].numpy().hex()}")
print(f"Number of bits actually needed to represent it: {entropies[0]:0.2f}")
String representation of first digit in hexadecimal: 0x24a97dd328
Number of bits actually needed to represent it: 36.21

Decompress the images back from the strings.

reconstructions = decompressor(strings)

Display each of the 16 original digits together with its compressed binary representation, and the reconstructed digit.
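
The display_digits helper used below is defined earlier in the full notebook but is not shown in this excerpt. A minimal sketch of such a helper, consistent with how it is called here (the exact layout is an assumption):

def display_digits(originals, strings, entropies, reconstructions):
  """Visualizes 16 digits together with their reconstructions."""
  fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(12.5, 5))
  axes = axes.ravel()
  for i in range(len(axes)):
    # Show original and reconstruction side by side, separated by a gap.
    image = tf.concat([
        tf.squeeze(originals[i]),
        tf.zeros((28, 14), tf.uint8),
        tf.squeeze(reconstructions[i]),
    ], 1)
    axes[i].imshow(image)
    # Annotate each pair with its code string and information content.
    axes[i].text(
        .5, .5, f"→ 0x{strings[i].numpy().hex()} →\n{entropies[i]:0.2f} bits",
        ha="center", va="top", color="white", fontsize="small",
        transform=axes[i].transAxes)
    axes[i].axis("off")
  plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)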

display_digits(originals, strings, entropies, reconstructions)

[Figure: 16 original digits with their compressed strings, bit counts, and reconstructions]

Note that the length of each encoded string differs from the information content of the corresponding digit.

This is because the range coding process works with discretized probabilities and has a small amount of overhead. So, particularly for short strings, the correspondence is only approximate. However, range coding is asymptotically optimal: in the limit, the expected bit count approaches the cross entropy (the expected information content), for which the rate term in the training model is an upper bound.
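
To see this overhead directly, we can compare the actual encoded lengths with the estimated information content (a small illustrative check):

# Each encoded string occupies a whole number of bytes, so its length in
# bits is always somewhat larger than the estimated information content.
for s, e in zip(strings, entropies):
  print(f"encoded: {len(s.numpy()) * 8:3d} bits, "
        f"information content: {e:6.2f} bits")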

The rate–distortion trade-off

Above, the model was trained for a specific trade-off (given by lmbda=2000) between the average number of bits used to represent each digit and the error incurred in the reconstruction.

What happens when we repeat the experiment with different values?

Let's start by reducing \(\lambda\) to 500.

def train_and_visualize_model(lmbda):
  trainer = train_mnist_model(lmbda=lmbda)
  compressor, decompressor = make_mnist_codec(trainer)
  strings, entropies = compressor(originals)
  reconstructions = decompressor(strings)
  display_digits(originals, strings, entropies, reconstructions)

train_and_visualize_model(lmbda=500)
Epoch 1/15
469/469 [==============================] - ETA: 0s - loss: 127.5774 - distortion_loss: 0.0701 - rate_loss: 92.5198 - distortion_pass_through_loss: 0.0701 - rate_pass_through_loss: 92.5137
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 12s 21ms/step - loss: 127.5774 - distortion_loss: 0.0701 - rate_loss: 92.5198 - distortion_pass_through_loss: 0.0701 - rate_pass_through_loss: 92.5137 - val_loss: 107.2737 - val_distortion_loss: 0.0543 - val_rate_loss: 80.1066 - val_distortion_pass_through_loss: 0.0543 - val_rate_pass_through_loss: 80.1077
Epoch 2/15
469/469 [==============================] - 9s 20ms/step - loss: 97.0805 - distortion_loss: 0.0537 - rate_loss: 70.2336 - distortion_pass_through_loss: 0.0537 - rate_pass_through_loss: 70.2283 - val_loss: 86.2141 - val_distortion_loss: 0.0607 - val_rate_loss: 55.8584 - val_distortion_pass_through_loss: 0.0607 - val_rate_pass_through_loss: 55.8653
Epoch 3/15
469/469 [==============================] - 10s 20ms/step - loss: 81.0770 - distortion_loss: 0.0559 - rate_loss: 53.1267 - distortion_pass_through_loss: 0.0559 - rate_pass_through_loss: 53.1230 - val_loss: 71.6339 - val_distortion_loss: 0.0679 - val_rate_loss: 37.7086 - val_distortion_pass_through_loss: 0.0678 - val_rate_pass_through_loss: 37.7265
Epoch 4/15
469/469 [==============================] - 9s 20ms/step - loss: 71.4976 - distortion_loss: 0.0591 - rate_loss: 41.9352 - distortion_pass_through_loss: 0.0591 - rate_pass_through_loss: 41.9328 - val_loss: 62.8137 - val_distortion_loss: 0.0751 - val_rate_loss: 25.2747 - val_distortion_pass_through_loss: 0.0751 - val_rate_pass_through_loss: 25.2795
Epoch 5/15
469/469 [==============================] - 9s 20ms/step - loss: 65.9054 - distortion_loss: 0.0619 - rate_loss: 34.9402 - distortion_pass_through_loss: 0.0619 - rate_pass_through_loss: 34.9385 - val_loss: 57.3453 - val_distortion_loss: 0.0781 - val_rate_loss: 18.2953 - val_distortion_pass_through_loss: 0.0781 - val_rate_pass_through_loss: 18.3022
Epoch 6/15
469/469 [==============================] - 9s 20ms/step - loss: 62.4431 - distortion_loss: 0.0640 - rate_loss: 30.4191 - distortion_pass_through_loss: 0.0640 - rate_pass_through_loss: 30.4172 - val_loss: 55.2744 - val_distortion_loss: 0.0853 - val_rate_loss: 12.6452 - val_distortion_pass_through_loss: 0.0852 - val_rate_pass_through_loss: 12.6545
Epoch 7/15
469/469 [==============================] - 9s 20ms/step - loss: 59.8767 - distortion_loss: 0.0654 - rate_loss: 27.1681 - distortion_pass_through_loss: 0.0654 - rate_pass_through_loss: 27.1672 - val_loss: 51.8969 - val_distortion_loss: 0.0795 - val_rate_loss: 12.1563 - val_distortion_pass_through_loss: 0.0794 - val_rate_pass_through_loss: 12.1664
Epoch 8/15
469/469 [==============================] - 9s 20ms/step - loss: 57.8186 - distortion_loss: 0.0660 - rate_loss: 24.8123 - distortion_pass_through_loss: 0.0660 - rate_pass_through_loss: 24.8117 - val_loss: 49.5085 - val_distortion_loss: 0.0729 - val_rate_loss: 13.0673 - val_distortion_pass_through_loss: 0.0729 - val_rate_pass_through_loss: 13.0776
Epoch 9/15
469/469 [==============================] - 9s 20ms/step - loss: 55.8623 - distortion_loss: 0.0659 - rate_loss: 22.9370 - distortion_pass_through_loss: 0.0658 - rate_pass_through_loss: 22.9366 - val_loss: 48.2837 - val_distortion_loss: 0.0715 - val_rate_loss: 12.5473 - val_distortion_pass_through_loss: 0.0715 - val_rate_pass_through_loss: 12.5685
Epoch 10/15
469/469 [==============================] - 9s 20ms/step - loss: 54.0484 - distortion_loss: 0.0652 - rate_loss: 21.4430 - distortion_pass_through_loss: 0.0652 - rate_pass_through_loss: 21.4425 - val_loss: 47.3454 - val_distortion_loss: 0.0687 - val_rate_loss: 13.0092 - val_distortion_pass_through_loss: 0.0687 - val_rate_pass_through_loss: 13.0121
Epoch 11/15
469/469 [==============================] - 9s 20ms/step - loss: 52.4273 - distortion_loss: 0.0644 - rate_loss: 20.2257 - distortion_pass_through_loss: 0.0644 - rate_pass_through_loss: 20.2251 - val_loss: 46.6972 - val_distortion_loss: 0.0666 - val_rate_loss: 13.3847 - val_distortion_pass_through_loss: 0.0667 - val_rate_pass_through_loss: 13.3839
Epoch 12/15
469/469 [==============================] - 9s 20ms/step - loss: 50.9907 - distortion_loss: 0.0636 - rate_loss: 19.2036 - distortion_pass_through_loss: 0.0636 - rate_pass_through_loss: 19.2033 - val_loss: 46.1560 - val_distortion_loss: 0.0640 - val_rate_loss: 14.1343 - val_distortion_pass_through_loss: 0.0641 - val_rate_pass_through_loss: 14.1383
Epoch 13/15
469/469 [==============================] - 9s 20ms/step - loss: 49.7850 - distortion_loss: 0.0628 - rate_loss: 18.4064 - distortion_pass_through_loss: 0.0628 - rate_pass_through_loss: 18.4059 - val_loss: 45.6184 - val_distortion_loss: 0.0635 - val_rate_loss: 13.8763 - val_distortion_pass_through_loss: 0.0636 - val_rate_pass_through_loss: 13.8705
Epoch 14/15
469/469 [==============================] - 9s 20ms/step - loss: 48.8055 - distortion_loss: 0.0621 - rate_loss: 17.7620 - distortion_pass_through_loss: 0.0621 - rate_pass_through_loss: 17.7618 - val_loss: 45.3379 - val_distortion_loss: 0.0622 - val_rate_loss: 14.2591 - val_distortion_pass_through_loss: 0.0622 - val_rate_pass_through_loss: 14.2615
Epoch 15/15
469/469 [==============================] - 9s 20ms/step - loss: 48.0368 - distortion_loss: 0.0614 - rate_loss: 17.3278 - distortion_pass_through_loss: 0.0614 - rate_pass_through_loss: 17.3274 - val_loss: 45.1061 - val_distortion_loss: 0.0621 - val_rate_loss: 14.0476 - val_distortion_pass_through_loss: 0.0622 - val_rate_pass_through_loss: 14.0455

[Figure: the same 16 digits compressed and reconstructed at lmbda=500]

The bit rate of our code goes down, and so does the fidelity of the digits. However, most of the digits remain recognizable.

Let's reduce \(\lambda\) further.

train_and_visualize_model(lmbda=300)
Epoch 1/15
469/469 [==============================] - ETA: 0s - loss: 114.0797 - distortion_loss: 0.0770 - rate_loss: 90.9856 - distortion_pass_through_loss: 0.0770 - rate_pass_through_loss: 90.9790
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 11s 20ms/step - loss: 114.0797 - distortion_loss: 0.0770 - rate_loss: 90.9856 - distortion_pass_through_loss: 0.0770 - rate_pass_through_loss: 90.9790 - val_loss: 96.4566 - val_distortion_loss: 0.0668 - val_rate_loss: 76.4093 - val_distortion_pass_through_loss: 0.0668 - val_rate_pass_through_loss: 76.4141
Epoch 2/15
469/469 [==============================] - 9s 20ms/step - loss: 85.9912 - distortion_loss: 0.0617 - rate_loss: 67.4928 - distortion_pass_through_loss: 0.0617 - rate_pass_through_loss: 67.4874 - val_loss: 73.8890 - val_distortion_loss: 0.0744 - val_rate_loss: 51.5573 - val_distortion_pass_through_loss: 0.0744 - val_rate_pass_through_loss: 51.5612
Epoch 3/15
469/469 [==============================] - 9s 20ms/step - loss: 68.9426 - distortion_loss: 0.0650 - rate_loss: 49.4320 - distortion_pass_through_loss: 0.0650 - rate_pass_through_loss: 49.4282 - val_loss: 58.8838 - val_distortion_loss: 0.0916 - val_rate_loss: 31.3956 - val_distortion_pass_through_loss: 0.0916 - val_rate_pass_through_loss: 31.4006
Epoch 4/15
469/469 [==============================] - 9s 20ms/step - loss: 58.3601 - distortion_loss: 0.0696 - rate_loss: 37.4756 - distortion_pass_through_loss: 0.0696 - rate_pass_through_loss: 37.4729 - val_loss: 49.2228 - val_distortion_loss: 0.1028 - val_rate_loss: 18.3848 - val_distortion_pass_through_loss: 0.1028 - val_rate_pass_through_loss: 18.3872
Epoch 5/15
469/469 [==============================] - 9s 20ms/step - loss: 52.0172 - distortion_loss: 0.0738 - rate_loss: 29.8754 - distortion_pass_through_loss: 0.0738 - rate_pass_through_loss: 29.8737 - val_loss: 42.4726 - val_distortion_loss: 0.1023 - val_rate_loss: 11.7721 - val_distortion_pass_through_loss: 0.1024 - val_rate_pass_through_loss: 11.7677
Epoch 6/15
469/469 [==============================] - 9s 20ms/step - loss: 48.1506 - distortion_loss: 0.0774 - rate_loss: 24.9305 - distortion_pass_through_loss: 0.0774 - rate_pass_through_loss: 24.9294 - val_loss: 38.6997 - val_distortion_loss: 0.1037 - val_rate_loss: 7.5762 - val_distortion_pass_through_loss: 0.1038 - val_rate_pass_through_loss: 7.5757
Epoch 7/15
469/469 [==============================] - 9s 20ms/step - loss: 45.3835 - distortion_loss: 0.0799 - rate_loss: 21.4162 - distortion_pass_through_loss: 0.0799 - rate_pass_through_loss: 21.4161 - val_loss: 36.0530 - val_distortion_loss: 0.0992 - val_rate_loss: 6.2864 - val_distortion_pass_through_loss: 0.0992 - val_rate_pass_through_loss: 6.2964
Epoch 8/15
469/469 [==============================] - 9s 20ms/step - loss: 43.2323 - distortion_loss: 0.0816 - rate_loss: 18.7436 - distortion_pass_through_loss: 0.0816 - rate_pass_through_loss: 18.7432 - val_loss: 34.5369 - val_distortion_loss: 0.0974 - val_rate_loss: 5.3176 - val_distortion_pass_through_loss: 0.0975 - val_rate_pass_through_loss: 5.3151
Epoch 9/15
469/469 [==============================] - 9s 20ms/step - loss: 41.3475 - distortion_loss: 0.0823 - rate_loss: 16.6684 - distortion_pass_through_loss: 0.0823 - rate_pass_through_loss: 16.6681 - val_loss: 33.3808 - val_distortion_loss: 0.0913 - val_rate_loss: 5.9764 - val_distortion_pass_through_loss: 0.0914 - val_rate_pass_through_loss: 5.9756
Epoch 10/15
469/469 [==============================] - 9s 20ms/step - loss: 39.4882 - distortion_loss: 0.0813 - rate_loss: 15.0862 - distortion_pass_through_loss: 0.0813 - rate_pass_through_loss: 15.0858 - val_loss: 32.6920 - val_distortion_loss: 0.0849 - val_rate_loss: 7.2103 - val_distortion_pass_through_loss: 0.0850 - val_rate_pass_through_loss: 7.2137
Epoch 11/15
469/469 [==============================] - 9s 20ms/step - loss: 37.8729 - distortion_loss: 0.0798 - rate_loss: 13.9354 - distortion_pass_through_loss: 0.0798 - rate_pass_through_loss: 13.9349 - val_loss: 32.1751 - val_distortion_loss: 0.0800 - val_rate_loss: 8.1763 - val_distortion_pass_through_loss: 0.0800 - val_rate_pass_through_loss: 8.1914
Epoch 12/15
469/469 [==============================] - 9s 20ms/step - loss: 36.5992 - distortion_loss: 0.0783 - rate_loss: 13.1068 - distortion_pass_through_loss: 0.0783 - rate_pass_through_loss: 13.1064 - val_loss: 31.9699 - val_distortion_loss: 0.0797 - val_rate_loss: 8.0496 - val_distortion_pass_through_loss: 0.0798 - val_rate_pass_through_loss: 8.0581
Epoch 13/15
469/469 [==============================] - 9s 20ms/step - loss: 35.5226 - distortion_loss: 0.0768 - rate_loss: 12.4821 - distortion_pass_through_loss: 0.0768 - rate_pass_through_loss: 12.4820 - val_loss: 31.7481 - val_distortion_loss: 0.0770 - val_rate_loss: 8.6345 - val_distortion_pass_through_loss: 0.0770 - val_rate_pass_through_loss: 8.6493
Epoch 14/15
469/469 [==============================] - 9s 20ms/step - loss: 34.7314 - distortion_loss: 0.0759 - rate_loss: 11.9717 - distortion_pass_through_loss: 0.0759 - rate_pass_through_loss: 11.9713 - val_loss: 31.7234 - val_distortion_loss: 0.0779 - val_rate_loss: 8.3608 - val_distortion_pass_through_loss: 0.0779 - val_rate_pass_through_loss: 8.3658
Epoch 15/15
469/469 [==============================] - 9s 20ms/step - loss: 34.1140 - distortion_loss: 0.0751 - rate_loss: 11.5740 - distortion_pass_through_loss: 0.0751 - rate_pass_through_loss: 11.5740 - val_loss: 31.4139 - val_distortion_loss: 0.0745 - val_rate_loss: 9.0766 - val_distortion_pass_through_loss: 0.0745 - val_rate_pass_through_loss: 9.0960

[Figure: the same 16 digits compressed and reconstructed at lmbda=300]

The strings now begin to get much shorter, on the order of one byte per digit. However, this comes at a cost: more digits are becoming unrecognizable.

This demonstrates that the model is agnostic to human perception of error; it only measures absolute deviation in pixel values. To achieve a better perceived image quality, we would need to replace the pixel loss with a perceptual loss.
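
For example (a hypothetical modification), the mean absolute difference in MNISTCompressionTrainer.call could be swapped for an SSIM-based distortion using tf.image.ssim, a common if crude perceptual proxy:

# Hypothetical replacement for the distortion term in
# MNISTCompressionTrainer.call. SSIM is computed on images in [0, 1];
# the filter size is reduced so the window fits into 28x28 digits.
distortion = tf.reduce_mean(
    1 - tf.image.ssim(x, x_tilde, max_val=1., filter_size=7))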

Use the decoder as a generative model

If we feed the decoder random bits, this effectively samples from the distribution that the model has learned to represent digits.

First, re-instantiate the compressor/decompressor without the sanity check that would detect whether the input string was not completely decoded.

compressor, decompressor = make_mnist_codec(trainer, decode_sanity_check=False)

Now, feed strings that are long enough into the decompressor so that it can decode/sample digits from them.

import os

strings = tf.constant([os.urandom(8) for _ in range(16)])
samples = decompressor(strings)

fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(5, 5))
axes = axes.ravel()
for i in range(len(axes)):
  axes[i].imshow(tf.squeeze(samples[i]))
  axes[i].axis("off")
plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

[Figure: 16 digits sampled by decompressing random strings]