

This notebook shows how to do lossy data compression using neural networks and TensorFlow Compression.


The example below uses an autoencoder-like model to compress images from the MNIST dataset. The method is based on the paper "End-to-end Optimized Image Compression".



Install TensorFlow Compression via pip.

# Installs the latest version of TFC compatible with the installed TF version.

read MAJOR MINOR <<< "$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\d+)\.(\d+).*/\1 \2/sg')"
pip install "tensorflow-compression<$MAJOR.$(($MINOR+1))"
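The shell snippet above reads the installed TensorFlow version with pip show and pins tensorflow-compression below the next TensorFlow minor release. A rough Python equivalent of that version arithmetic is sketched here; the helper name tfc_requirement is hypothetical, and the sketch assumes the usual MAJOR.MINOR.PATCH version format.

```python
import importlib.metadata


def tfc_requirement(tf_version: str) -> str:
    """Returns a pip requirement that caps TFC below the next TF minor release."""
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    return f"tensorflow-compression<{major}.{minor + 1}"


# Example with a fixed version string.
print(tfc_requirement("2.14.1"))  # tensorflow-compression<2.15

# Example with whatever TensorFlow is installed, if any.
try:
    print(tfc_requirement(importlib.metadata.version("tensorflow")))
except importlib.metadata.PackageNotFoundError:
    print("TensorFlow is not installed")
```

Pinning with `<MAJOR.(MINOR+1)` lets pip pick the newest TFC release that still targets the installed TensorFlow minor version.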


import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_datasets as tfds
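
The training model consists of three parts:

  • The analysis (or encoder) transform, which converts the image into a latent space.
  • The synthesis (or decoder) transform, which converts from the latent space back to image space.
  • A prior and entropy model, which models the marginal probabilities of the latents.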


def make_analysis_transform(latent_dims):
  """Creates the analysis (encoder) transform."""
  return tf.keras.Sequential([
      # 28x28x1 -> 14x14x20
      tf.keras.layers.Conv2D(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      # 14x14x20 -> 7x7x50
      tf.keras.layers.Conv2D(
          50, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
      # 7x7x50 -> 2450-dimensional vector
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          latent_dims, use_bias=True, activation=None, name="fc_2"),
  ], name="analysis_transform")


def make_synthesis_transform():
  """Creates the synthesis (decoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          2450, use_bias=True, activation="leaky_relu", name="fc_2"),
      tf.keras.layers.Reshape((7, 7, 50)),
      # 7x7x50 -> 14x14x20
      tf.keras.layers.Conv2DTranspose(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      # 14x14x20 -> 28x28x1
      tf.keras.layers.Conv2DTranspose(
          1, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
  ], name="synthesis_transform")


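The trainer holds an instance of both transforms, as well as the parameters of the prior. Its call method is set up to compute:

  • rate: an estimate of the average number of bits needed to represent each digit in the batch
  • distortion: the mean absolute difference between the pixels of the original digits and their reconstructions
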
class MNISTCompressionTrainer(tf.keras.Model):
  """Model that trains a compressor/decompressor for MNIST."""

  def __init__(self, latent_dims):
    super().__init__()
    self.analysis_transform = make_analysis_transform(latent_dims)
    self.synthesis_transform = make_synthesis_transform()
    # Learnable log-scales of the prior, one per latent dimension.
    self.prior_log_scales = tf.Variable(tf.zeros((latent_dims,)))

  @property
  def prior(self):
    return tfc.NoisyLogistic(loc=0., scale=tf.exp(self.prior_log_scales))

  def call(self, x, training):
    """Computes rate and distortion losses."""
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    x = tf.reshape(x, (-1, 28, 28, 1))

    # Compute latent space representation y, perturb it and model its entropy,
    # then compute the reconstructed pixel-level representation x_hat.
    y = self.analysis_transform(x)
    entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.prior, coding_rank=1, compression=False)
    y_tilde, rate = entropy_model(y, training=training)
    x_tilde = self.synthesis_transform(y_tilde)

    # Average number of bits per MNIST digit.
    rate = tf.reduce_mean(rate)

    # Mean absolute difference across pixels.
    distortion = tf.reduce_mean(abs(x - x_tilde))

    return dict(rate=rate, distortion=distortion)


Next, let's walk through the computation step by step, using a single image. First, load the MNIST dataset for training and validation:

training_dataset, validation_dataset = tfds.load(
    "mnist",
    split=["train", "test"],
    # Yield (image, label) tuples so that they can be unpacked below.
    as_supervised=True,
)

Extract one image $x$:

(x, _), = validation_dataset.take(1)

print(f"Data type: {x.dtype}")
print(f"Shape: {x.shape}")
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)
To get the latent representation $y$, we need to cast the image to float32 (scaled to [0, 1]), add a batch dimension, and pass it through the analysis transform.

x = tf.cast(x, tf.float32) / 255.
x = tf.reshape(x, (-1, 28, 28, 1))
y = make_analysis_transform(10)(x)

print("y:", y)
y: tf.Tensor(
[[ 0.03988452 -0.02631121 -0.05344866 -0.04364791  0.06735273 -0.00989169
  -0.05671643 -0.01362787 -0.0330795  -0.03137782]], shape=(1, 10), dtype=float32)

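The latents will be quantized at test time. To model this in a differentiable way during training, we add uniform noise in the interval $(-0.5, 0.5)$ and call the result $\tilde{y}$. This is the same terminology as used in the paper "End-to-end Optimized Image Compression".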

y_tilde = y + tf.random.uniform(y.shape, -.5, .5)

print("y_tilde:", y_tilde)
y_tilde: tf.Tensor(
[[-0.00847201 -0.19141883  0.01641406 -0.3539529   0.10089394  0.23102134
   0.12243177  0.16727103  0.22074556  0.08011779]], shape=(1, 10), dtype=float32)
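Rounding has a zero gradient almost everywhere, which is why a noisy proxy is used during training. The stdlib-only sketch below (with made-up latent values) contrasts the training-time perturbation with test-time rounding: both keep each value within half a unit of the original, which is what makes uniform noise of width 1 a reasonable stand-in for the quantization error. The helper names perturb and quantize are hypothetical, not TFC functions.

```python
import random


def perturb(y, rng):
    """Training-time proxy: add i.i.d. uniform noise from [-0.5, 0.5]."""
    return [v + rng.uniform(-0.5, 0.5) for v in y]


def quantize(y):
    """Test-time operation: round each value to the nearest integer."""
    return [float(round(v)) for v in y]


rng = random.Random(0)
y = [0.04, -0.03, 1.7, -2.4]  # hypothetical latent values
y_tilde = perturb(y, rng)
y_hat = quantize(y)

# Both deviations lie within half a quantization bin of the original values.
print([round(t - v, 3) for t, v in zip(y_tilde, y)])
print([round(h - v, 3) for h, v in zip(y_hat, y)])
print(y_hat)  # [0.0, 0.0, 2.0, -2.0]
```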

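The "prior" is a probability density that we train to model the marginal distribution of the noisy latents. For example, it could be a set of independent logistic distributions with a different scale for each latent dimension. tfc.NoisyLogistic accounts for the fact that the latents have additive noise. As the scale approaches zero, a logistic distribution approaches a Dirac delta (spike), but the added noise causes the "noisy" distribution to approach the uniform distribution instead.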

prior = tfc.NoisyLogistic(loc=0., scale=tf.linspace(.01, 2., 10))

_ = tf.linspace(-6., 6., 501)[:, None]
plt.plot(_, prior.prob(_));
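To make the limiting behavior concrete: adding Uniform(-0.5, 0.5) noise to a logistic variable with CDF F gives the density p(x) = F(x + 0.5) - F(x - 0.5). The stdlib-only sketch below evaluates this formula for a few scales. It illustrates the math only and is not TFC's implementation of tfc.NoisyLogistic; the helper names are hypothetical.

```python
import math


def logistic_cdf(x, loc=0.0, scale=1.0):
    """Numerically stable logistic CDF."""
    z = (x - loc) / scale
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def noisy_logistic_pdf(x, loc=0.0, scale=1.0):
    """Density of a logistic variable plus Uniform(-0.5, 0.5) noise."""
    return logistic_cdf(x + 0.5, loc, scale) - logistic_cdf(x - 0.5, loc, scale)


# For a tiny scale the density is ~1 inside (-0.5, 0.5) and ~0 outside.
for scale in (0.01, 0.5, 2.0):
    values = [noisy_logistic_pdf(x, scale=scale) for x in (0.0, 0.4, 0.6, 1.5)]
    print(scale, [round(v, 4) for v in values])
```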


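During training, tfc.ContinuousBatchedEntropyModel adds uniform noise, and uses the noise and the prior to compute a (differentiable) upper bound on the rate, that is, the average number of bits needed to encode the latent representation. That bound can be minimized as a loss.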

entropy_model = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=False)
y_tilde, rate = entropy_model(y, training=True)

print("rate:", rate)
print("y_tilde:", y_tilde)
rate: tf.Tensor([18.430876], shape=(1,), dtype=float32)
y_tilde: tf.Tensor(
[[ 0.00818415  0.4172811  -0.05954609  0.3539252  -0.02196757  0.2851495
  -0.00319849 -0.15237509 -0.46015334  0.07735881]], shape=(1, 10), dtype=float32)
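As a check on this description, the stdlib-only sketch below recomputes the rate for the noisy latents printed above, assuming the rate is the sum over the 10 latent dimensions of -log2 of the noisy-logistic density (coding_rank=1 sums over the last axis). The scales mirror tf.linspace(.01, 2., 10) from the prior cell; the helper names are hypothetical and this is not the library's code.

```python
import math


def logistic_cdf(z):
    """Numerically stable standard logistic CDF (sigmoid)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def noisy_logistic_pdf(x, scale):
    """Density of Logistic(0, scale) convolved with Uniform(-0.5, 0.5)."""
    return logistic_cdf((x + 0.5) / scale) - logistic_cdf((x - 0.5) / scale)


def rate_bits(y_tilde, scales):
    """Sum of -log2 p(y_tilde_i) over the latent dimensions."""
    return sum(-math.log2(noisy_logistic_pdf(v, s)) for v, s in zip(y_tilde, scales))


# Same scales as tf.linspace(.01, 2., 10) in the prior cell.
scales = [0.01 + i * (2.0 - 0.01) / 9 for i in range(10)]
# Noisy latents copied from the printed y_tilde above.
y_tilde = [0.00818415, 0.4172811, -0.05954609, 0.3539252, -0.02196757,
           0.2851495, -0.00319849, -0.15237509, -0.46015334, 0.07735881]

print(f"{rate_bits(y_tilde, scales):.2f} bits")  # about 18.43
```

The result lands close to the 18.43 bits reported by the entropy model, which supports reading the training-time rate as a negative log2-likelihood under the noisy prior.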

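Lastly, the noisy latents are passed back through the synthesis transform to produce an image reconstruction $\tilde{x}$. The distortion is the error between the original image and this reconstruction. Obviously, since the transforms are untrained, the reconstruction is not very useful yet.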

x_tilde = make_synthesis_transform()(y_tilde)

# Mean absolute difference across pixels.
distortion = tf.reduce_mean(abs(x - x_tilde))
print("distortion:", distortion)

x_tilde = tf.saturate_cast(x_tilde[0] * 255, tf.uint8)
print(f"Data type: {x_tilde.dtype}")
print(f"Shape: {x_tilde.shape}")
distortion: tf.Tensor(0.17073585, shape=(), dtype=float32)
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)


For every batch of digits, calling MNISTCompressionTrainer produces the rate and the distortion as averages over that batch:

(example_batch, _), = validation_dataset.batch(32).take(1)
trainer = MNISTCompressionTrainer(10)
example_output = trainer(example_batch)

print("rate: ", example_output["rate"])
print("distortion: ", example_output["distortion"])
rate:  tf.Tensor(20.296253, shape=(), dtype=float32)
distortion:  tf.Tensor(0.14659302, shape=(), dtype=float32)
In the next section, we set up the model to do gradient descent on these two losses.


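We compile the trainer so that it optimizes the rate-distortion Lagrangian, that is, the sum of rate and distortion, where one of the terms (here, the distortion) is weighted by the Lagrange parameter λ. This loss function affects the different parts of the model differently:


  • The analysis transform is trained to produce a latent representation that achieves the desired trade-off between rate and distortion.
  • The synthesis transform is trained to minimize distortion, given the latent representation.
  • The parameters of the prior are trained to minimize the rate, given the latent representation. This is equivalent to fitting the prior to the marginal distribution of the latents in the maximum likelihood sense.
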
def pass_through_loss(_, x):
  # Since rate and distortion are unsupervised, the loss doesn't need a target.
  return x

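def make_mnist_compression_trainer(lmbda, latent_dims=50):
  trainer = MNISTCompressionTrainer(latent_dims)
  trainer.compile(
      # Adam with a 1e-3 learning rate (assumed optimizer setting).
      optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
      # Just pass through rate and distortion as losses/metrics.
      loss=dict(rate=pass_through_loss, distortion=pass_through_loss),
      metrics=dict(rate=pass_through_loss, distortion=pass_through_loss),
      # Total loss = rate + lmbda * distortion.
      loss_weights=dict(rate=1., distortion=lmbda),
  )
  return trainer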
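In symbols, with the loss weights set in make_mnist_compression_trainer above (weight 1 for rate, λ for distortion), the training objective can be written as follows. Here $N$ is the batch size, $d$ the number of latent dimensions, and 784 = 28 × 28 the number of pixels; $R$ and $D$ correspond to the two quantities returned by MNISTCompressionTrainer.call. The notation is ours, not the library's.

```latex
\mathcal{L} = R + \lambda D, \qquad
R = \frac{1}{N} \sum_{n=1}^{N} \sum_{i=1}^{d} -\log_2 p\left(\tilde{y}_{n,i}\right), \qquad
D = \frac{1}{784\,N} \sum_{n=1}^{N} \sum_{j=1}^{784} \left| x_{n,j} - \tilde{x}_{n,j} \right|
```

A larger λ makes distortion more expensive relative to rate, so the model spends more bits per digit.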

Next, train the model. Human annotations aren't necessary here, since we only want to compress the images, so we drop them using a map and instead add "dummy" targets for rate and distortion.

def add_rd_targets(image, label):
  # Training is unsupervised, so labels aren't necessary here. However, we
  # need to add "dummy" targets for rate and distortion.
  return image, dict(rate=0., distortion=0.)

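def train_mnist_model(lmbda):
  trainer = make_mnist_compression_trainer(lmbda)
  trainer.fit(
      # Batches of 128 give 469 steps per epoch over the 60,000 training images.
      training_dataset.map(add_rd_targets).batch(128).prefetch(8),
      epochs=15,
      validation_data=validation_dataset.map(add_rd_targets).batch(128).cache(),
  )
  return trainer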

trainer = train_mnist_model(lmbda=2000)
Epoch 1/15
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
467/469 [============================>.] - ETA: 0s - loss: 217.8335 - distortion_loss: 0.0589 - rate_loss: 100.0286 - distortion_pass_through_loss: 0.0589 - rate_pass_through_loss: 100.0286
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 8s 8ms/step - loss: 217.6788 - distortion_loss: 0.0588 - rate_loss: 99.9971 - distortion_pass_through_loss: 0.0588 - rate_pass_through_loss: 99.9924 - val_loss: 178.6942 - val_distortion_loss: 0.0434 - val_rate_loss: 91.9543 - val_distortion_pass_through_loss: 0.0434 - val_rate_pass_through_loss: 91.9594
Epoch 2/15
469/469 [==============================] - 3s 6ms/step - loss: 166.7406 - distortion_loss: 0.0415 - rate_loss: 83.8010 - distortion_pass_through_loss: 0.0415 - rate_pass_through_loss: 83.7965 - val_loss: 157.1429 - val_distortion_loss: 0.0409 - val_rate_loss: 75.3245 - val_distortion_pass_through_loss: 0.0409 - val_rate_pass_through_loss: 75.3312
Epoch 3/15
469/469 [==============================] - 3s 6ms/step - loss: 151.3889 - distortion_loss: 0.0402 - rate_loss: 70.9442 - distortion_pass_through_loss: 0.0402 - rate_pass_through_loss: 70.9411 - val_loss: 144.4897 - val_distortion_loss: 0.0403 - val_rate_loss: 63.9764 - val_distortion_pass_through_loss: 0.0402 - val_rate_pass_through_loss: 63.9807
Epoch 4/15
469/469 [==============================] - 3s 6ms/step - loss: 142.9433 - distortion_loss: 0.0400 - rate_loss: 63.0224 - distortion_pass_through_loss: 0.0400 - rate_pass_through_loss: 63.0202 - val_loss: 137.4160 - val_distortion_loss: 0.0411 - val_rate_loss: 55.3039 - val_distortion_pass_through_loss: 0.0411 - val_rate_pass_through_loss: 55.2885
Epoch 5/15
469/469 [==============================] - 3s 6ms/step - loss: 137.3771 - distortion_loss: 0.0395 - rate_loss: 58.3224 - distortion_pass_through_loss: 0.0395 - rate_pass_through_loss: 58.3205 - val_loss: 132.2905 - val_distortion_loss: 0.0417 - val_rate_loss: 48.9274 - val_distortion_pass_through_loss: 0.0417 - val_rate_pass_through_loss: 48.9382
Epoch 6/15
469/469 [==============================] - 3s 6ms/step - loss: 133.5226 - distortion_loss: 0.0391 - rate_loss: 55.3185 - distortion_pass_through_loss: 0.0391 - rate_pass_through_loss: 55.3175 - val_loss: 127.0724 - val_distortion_loss: 0.0404 - val_rate_loss: 46.3232 - val_distortion_pass_through_loss: 0.0404 - val_rate_pass_through_loss: 46.3234
Epoch 7/15
469/469 [==============================] - 3s 6ms/step - loss: 130.3693 - distortion_loss: 0.0386 - rate_loss: 53.1581 - distortion_pass_through_loss: 0.0386 - rate_pass_through_loss: 53.1566 - val_loss: 123.5252 - val_distortion_loss: 0.0403 - val_rate_loss: 42.8875 - val_distortion_pass_through_loss: 0.0403 - val_rate_pass_through_loss: 42.8826
Epoch 8/15
469/469 [==============================] - 3s 6ms/step - loss: 128.0058 - distortion_loss: 0.0383 - rate_loss: 51.4280 - distortion_pass_through_loss: 0.0383 - rate_pass_through_loss: 51.4268 - val_loss: 121.3483 - val_distortion_loss: 0.0400 - val_rate_loss: 41.3487 - val_distortion_pass_through_loss: 0.0400 - val_rate_pass_through_loss: 41.3539
Epoch 9/15
469/469 [==============================] - 3s 6ms/step - loss: 125.6857 - distortion_loss: 0.0379 - rate_loss: 49.9369 - distortion_pass_through_loss: 0.0379 - rate_pass_through_loss: 49.9354 - val_loss: 119.4494 - val_distortion_loss: 0.0398 - val_rate_loss: 39.8691 - val_distortion_pass_through_loss: 0.0398 - val_rate_pass_through_loss: 39.8512
Epoch 10/15
469/469 [==============================] - 3s 6ms/step - loss: 123.4883 - distortion_loss: 0.0375 - rate_loss: 48.5796 - distortion_pass_through_loss: 0.0375 - rate_pass_through_loss: 48.5789 - val_loss: 118.5806 - val_distortion_loss: 0.0391 - val_rate_loss: 40.3094 - val_distortion_pass_through_loss: 0.0392 - val_rate_pass_through_loss: 40.3033
Epoch 11/15
469/469 [==============================] - 3s 6ms/step - loss: 121.5731 - distortion_loss: 0.0371 - rate_loss: 47.4418 - distortion_pass_through_loss: 0.0371 - rate_pass_through_loss: 47.4408 - val_loss: 115.8420 - val_distortion_loss: 0.0380 - val_rate_loss: 39.9038 - val_distortion_pass_through_loss: 0.0380 - val_rate_pass_through_loss: 39.8994
Epoch 12/15
469/469 [==============================] - 3s 6ms/step - loss: 119.7753 - distortion_loss: 0.0367 - rate_loss: 46.3968 - distortion_pass_through_loss: 0.0367 - rate_pass_through_loss: 46.3957 - val_loss: 114.8861 - val_distortion_loss: 0.0373 - val_rate_loss: 40.2797 - val_distortion_pass_through_loss: 0.0373 - val_rate_pass_through_loss: 40.2883
Epoch 13/15
469/469 [==============================] - 3s 6ms/step - loss: 118.1635 - distortion_loss: 0.0363 - rate_loss: 45.5972 - distortion_pass_through_loss: 0.0363 - rate_pass_through_loss: 45.5967 - val_loss: 114.0300 - val_distortion_loss: 0.0367 - val_rate_loss: 40.5612 - val_distortion_pass_through_loss: 0.0367 - val_rate_pass_through_loss: 40.5718
Epoch 14/15
469/469 [==============================] - 3s 6ms/step - loss: 116.8593 - distortion_loss: 0.0360 - rate_loss: 44.9107 - distortion_pass_through_loss: 0.0360 - rate_pass_through_loss: 44.9097 - val_loss: 112.6166 - val_distortion_loss: 0.0363 - val_rate_loss: 40.0470 - val_distortion_pass_through_loss: 0.0363 - val_rate_pass_through_loss: 40.0628
Epoch 15/15
469/469 [==============================] - 3s 6ms/step - loss: 115.6814 - distortion_loss: 0.0356 - rate_loss: 44.4095 - distortion_pass_through_loss: 0.0356 - rate_pass_through_loss: 44.4091 - val_loss: 112.2964 - val_distortion_loss: 0.0360 - val_rate_loss: 40.3579 - val_distortion_pass_through_loss: 0.0360 - val_rate_pass_through_loss: 40.3735
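To see why the dummy targets are harmless, here is a plain-Python sketch of how the weighted pass-through losses combine. It is a simplification of what Keras does, not its actual code. It plugs in the epoch-1 training values from the log above and recovers the logged total loss up to rounding of the displayed distortion.

```python
def pass_through_loss(_, x):
    # The target (first argument) is ignored; the model output is the loss.
    return x


# Epoch-1 training values copied from the lmbda=2000 log above.
outputs = {"rate": 99.9971, "distortion": 0.0588}
# What add_rd_targets supplies as targets.
dummy_targets = {"rate": 0.0, "distortion": 0.0}
# Mirrors loss_weights=dict(rate=1., distortion=lmbda) with lmbda=2000.
loss_weights = {"rate": 1.0, "distortion": 2000.0}

total = sum(
    loss_weights[name] * pass_through_loss(dummy_targets[name], outputs[name])
    for name in outputs
)
print(round(total, 2))  # 217.6 (the log shows 217.6788; distortion is rounded)
```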

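Compress MNIST images

For compression and decompression at test time, we split the trained model into two parts:

  • The encoder side consists of the analysis transform and the entropy model.
  • The decoder side consists of the synthesis transform and the same entropy model.

At test time, the latents do not have additive noise, but they are quantized and then losslessly compressed, so we give them new names: we call them $\hat{y}$, and the image reconstruction $\hat{x}$ (following "End-to-end Optimized Image Compression").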

class MNISTCompressor(tf.keras.Model):
  """Compresses MNIST images to strings."""

  def __init__(self, analysis_transform, entropy_model):
    super().__init__()
    self.analysis_transform = analysis_transform
    self.entropy_model = entropy_model

  def call(self, x):
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    y = self.analysis_transform(x)
    # Also return the exact information content of each digit.
    _, bits = self.entropy_model(y, training=False)
    return self.entropy_model.compress(y), bits


class MNISTDecompressor(tf.keras.Model):
  """Decompresses MNIST images from strings."""

  def __init__(self, entropy_model, synthesis_transform):
    super().__init__()
    self.entropy_model = entropy_model
    self.synthesis_transform = synthesis_transform

  def call(self, string):
    y_hat = self.entropy_model.decompress(string, ())
    x_hat = self.synthesis_transform(y_hat)
    # Scale and cast back to 8-bit integer.
    return tf.saturate_cast(tf.round(x_hat * 255.), tf.uint8)

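When instantiated with compression=True, the entropy model converts the learned prior into tables for a range coding algorithm. Calling compress() invokes this algorithm to convert the latent space vector into bit sequences. The length of each binary string approximates the information content of the latent (the negative log likelihood of the latent under the prior).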


def make_mnist_codec(trainer, **kwargs):
  # The entropy model must be created with `compression=True` and the same
  # instance must be shared between compressor and decompressor.
  entropy_model = tfc.ContinuousBatchedEntropyModel(
      trainer.prior, coding_rank=1, compression=True, **kwargs)
  compressor = MNISTCompressor(trainer.analysis_transform, entropy_model)
  decompressor = MNISTDecompressor(entropy_model, trainer.synthesis_transform)
  return compressor, decompressor

compressor, decompressor = make_mnist_codec(trainer)
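For intuition about the information content mentioned above, the sketch below computes it for a quantized latent vector under independent noisy-logistic priors. It assumes plain rounding to integers (the library may shift the quantization grid by an offset, as the warning in the training log hints). The probability mass of an integer k is F(k + 0.5) - F(k - 0.5), where F is the logistic CDF. The helper names and example values are hypothetical.

```python
import math


def logistic_cdf(z):
    """Numerically stable standard logistic CDF (sigmoid)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bin_mass(k, scale):
    """Probability of the quantization bin [k - 0.5, k + 0.5] under Logistic(0, scale)."""
    return logistic_cdf((k + 0.5) / scale) - logistic_cdf((k - 0.5) / scale)


def information_content(y, scales):
    """Rounds y and returns (y_hat, -log2 P(y_hat)) under independent priors."""
    y_hat = [round(v) for v in y]
    bits = sum(-math.log2(bin_mass(k, s)) for k, s in zip(y_hat, scales))
    return y_hat, bits


scales = [1.0, 1.0, 1.0]
for y in ([0.1, -0.2, 0.3], [3.1, -2.8, 0.3]):
    y_hat, bits = information_content(y, scales)
    print(y_hat, f"{bits:.2f} bits")
```

Latents far from the prior's mode fall into low-probability bins and therefore cost more bits; a range coder's output length tracks this number closely.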

Grab 16 images from the validation dataset. You can select a different subset by changing the argument to skip.

(originals, _), = validation_dataset.batch(16).skip(3).take(1)


strings, entropies = compressor(originals)

print(f"String representation of first digit in hexadecimal: 0x{strings[0].numpy().hex()}")
print(f"Number of bits actually needed to represent it: {entropies[0]:0.2f}")
String representation of first digit in hexadecimal: 0x39c3f87dec58
Number of bits actually needed to represent it: 44.04
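A quick arithmetic check on the output above: the hexadecimal string has 12 hex digits, i.e. 6 bytes or 48 bits, a little more than the 44.04 bits of information content. Some excess is expected, since the output is a whole number of bytes and range coding adds a small overhead.

```python
hex_string = "39c3f87dec58"  # compressed first digit, from the output above
information_bits = 44.04     # "Number of bits actually needed", from the output above

string_bits = len(bytes.fromhex(hex_string)) * 8
print(string_bits)                               # 48
print(round(string_bits - information_bits, 2))  # 3.96
```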


reconstructions = decompressor(strings)

Display each of the 16 original digits together with its compressed binary representation and the reconstructed digit.

def display_digits(originals, strings, entropies, reconstructions):
  """Visualizes 16 digits together with their reconstructions."""
  fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(12.5, 5))
  axes = axes.ravel()
  for i in range(len(axes)):
    # Original digit, a blank gap, and the reconstruction side by side.
    image = tf.concat([
        tf.squeeze(originals[i]),
        tf.zeros((28, 14), tf.uint8),
        tf.squeeze(reconstructions[i]),
    ], 1)
    axes[i].imshow(image)
    axes[i].text(
        .5, .5, f"→ 0x{strings[i].numpy().hex()}\n{entropies[i]:0.2f} bits",
        ha="center", va="top", color="white", fontsize="small",
        transform=axes[i].transAxes)
    axes[i].axis("off")
  plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

display_digits(originals, strings, entropies, reconstructions)





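Above, the model was trained for a specific trade-off (given by lmbda=2000) between the average number of bits used to represent each digit and the error incurred in the reconstruction.


Let's start by reducing λ to 500.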

def train_and_visualize_model(lmbda):
  trainer = train_mnist_model(lmbda=lmbda)
  compressor, decompressor = make_mnist_codec(trainer)
  strings, entropies = compressor(originals)
  reconstructions = decompressor(strings)
  display_digits(originals, strings, entropies, reconstructions)

train_and_visualize_model(lmbda=500)
Epoch 1/15
465/469 [============================>.] - ETA: 0s - loss: 127.4447 - distortion_loss: 0.0695 - rate_loss: 92.6941 - distortion_pass_through_loss: 0.0695 - rate_pass_through_loss: 92.6941
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 6s 6ms/step - loss: 127.2831 - distortion_loss: 0.0694 - rate_loss: 92.5993 - distortion_pass_through_loss: 0.0694 - rate_pass_through_loss: 92.5930 - val_loss: 107.4028 - val_distortion_loss: 0.0549 - val_rate_loss: 79.9750 - val_distortion_pass_through_loss: 0.0549 - val_rate_pass_through_loss: 79.9802
Epoch 2/15
469/469 [==============================] - 3s 6ms/step - loss: 97.1724 - distortion_loss: 0.0538 - rate_loss: 70.2783 - distortion_pass_through_loss: 0.0538 - rate_pass_through_loss: 70.2729 - val_loss: 86.0219 - val_distortion_loss: 0.0594 - val_rate_loss: 56.3162 - val_distortion_pass_through_loss: 0.0594 - val_rate_pass_through_loss: 56.3187
Epoch 3/15
469/469 [==============================] - 3s 6ms/step - loss: 81.1227 - distortion_loss: 0.0560 - rate_loss: 53.1034 - distortion_pass_through_loss: 0.0560 - rate_pass_through_loss: 53.0995 - val_loss: 71.9202 - val_distortion_loss: 0.0686 - val_rate_loss: 37.5965 - val_distortion_pass_through_loss: 0.0687 - val_rate_pass_through_loss: 37.5954
Epoch 4/15
469/469 [==============================] - 3s 6ms/step - loss: 71.5995 - distortion_loss: 0.0595 - rate_loss: 41.8626 - distortion_pass_through_loss: 0.0595 - rate_pass_through_loss: 41.8596 - val_loss: 64.1463 - val_distortion_loss: 0.0786 - val_rate_loss: 24.8592 - val_distortion_pass_through_loss: 0.0786 - val_rate_pass_through_loss: 24.8602
Epoch 5/15
469/469 [==============================] - 3s 6ms/step - loss: 66.1026 - distortion_loss: 0.0624 - rate_loss: 34.8940 - distortion_pass_through_loss: 0.0624 - rate_pass_through_loss: 34.8927 - val_loss: 58.4913 - val_distortion_loss: 0.0795 - val_rate_loss: 18.7568 - val_distortion_pass_through_loss: 0.0795 - val_rate_pass_through_loss: 18.7560
Epoch 6/15
469/469 [==============================] - 3s 6ms/step - loss: 62.6672 - distortion_loss: 0.0646 - rate_loss: 30.3623 - distortion_pass_through_loss: 0.0646 - rate_pass_through_loss: 30.3613 - val_loss: 54.9740 - val_distortion_loss: 0.0818 - val_rate_loss: 14.0641 - val_distortion_pass_through_loss: 0.0818 - val_rate_pass_through_loss: 14.0646
Epoch 7/15
469/469 [==============================] - 3s 6ms/step - loss: 60.1863 - distortion_loss: 0.0660 - rate_loss: 27.2017 - distortion_pass_through_loss: 0.0660 - rate_pass_through_loss: 27.2010 - val_loss: 52.4609 - val_distortion_loss: 0.0806 - val_rate_loss: 12.1524 - val_distortion_pass_through_loss: 0.0806 - val_rate_pass_through_loss: 12.1531
Epoch 8/15
469/469 [==============================] - 3s 6ms/step - loss: 58.0571 - distortion_loss: 0.0665 - rate_loss: 24.8082 - distortion_pass_through_loss: 0.0665 - rate_pass_through_loss: 24.8073 - val_loss: 49.9638 - val_distortion_loss: 0.0771 - val_rate_loss: 11.4078 - val_distortion_pass_through_loss: 0.0771 - val_rate_pass_through_loss: 11.4103
Epoch 9/15
469/469 [==============================] - 3s 6ms/step - loss: 56.1462 - distortion_loss: 0.0665 - rate_loss: 22.8890 - distortion_pass_through_loss: 0.0665 - rate_pass_through_loss: 22.8888 - val_loss: 48.1192 - val_distortion_loss: 0.0704 - val_rate_loss: 12.9410 - val_distortion_pass_through_loss: 0.0704 - val_rate_pass_through_loss: 12.9476
Epoch 10/15
469/469 [==============================] - 3s 6ms/step - loss: 54.1863 - distortion_loss: 0.0657 - rate_loss: 21.3211 - distortion_pass_through_loss: 0.0657 - rate_pass_through_loss: 21.3206 - val_loss: 47.0492 - val_distortion_loss: 0.0674 - val_rate_loss: 13.3331 - val_distortion_pass_through_loss: 0.0674 - val_rate_pass_through_loss: 13.3350
Epoch 11/15
469/469 [==============================] - 3s 6ms/step - loss: 52.4151 - distortion_loss: 0.0647 - rate_loss: 20.0704 - distortion_pass_through_loss: 0.0647 - rate_pass_through_loss: 20.0700 - val_loss: 46.5608 - val_distortion_loss: 0.0665 - val_rate_loss: 13.2897 - val_distortion_pass_through_loss: 0.0665 - val_rate_pass_through_loss: 13.2926
Epoch 12/15
469/469 [==============================] - 3s 6ms/step - loss: 50.9138 - distortion_loss: 0.0636 - rate_loss: 19.1121 - distortion_pass_through_loss: 0.0636 - rate_pass_through_loss: 19.1114 - val_loss: 45.9211 - val_distortion_loss: 0.0645 - val_rate_loss: 13.6699 - val_distortion_pass_through_loss: 0.0645 - val_rate_pass_through_loss: 13.6701
Epoch 13/15
469/469 [==============================] - 3s 6ms/step - loss: 49.7118 - distortion_loss: 0.0626 - rate_loss: 18.4105 - distortion_pass_through_loss: 0.0626 - rate_pass_through_loss: 18.4100 - val_loss: 45.6058 - val_distortion_loss: 0.0628 - val_rate_loss: 14.1970 - val_distortion_pass_through_loss: 0.0628 - val_rate_pass_through_loss: 14.1988
Epoch 14/15
469/469 [==============================] - 3s 6ms/step - loss: 48.7698 - distortion_loss: 0.0617 - rate_loss: 17.9120 - distortion_pass_through_loss: 0.0617 - rate_pass_through_loss: 17.9119 - val_loss: 45.1903 - val_distortion_loss: 0.0612 - val_rate_loss: 14.5991 - val_distortion_pass_through_loss: 0.0612 - val_rate_pass_through_loss: 14.6004
Epoch 15/15
469/469 [==============================] - 3s 6ms/step - loss: 48.0780 - distortion_loss: 0.0611 - rate_loss: 17.5208 - distortion_pass_through_loss: 0.0611 - rate_pass_through_loss: 17.5206 - val_loss: 45.0955 - val_distortion_loss: 0.0615 - val_rate_loss: 14.3562 - val_distortion_pass_through_loss: 0.0615 - val_rate_pass_through_loss: 14.3626



Let's reduce λ even further.

train_and_visualize_model(lmbda=300)
Epoch 1/15
466/469 [============================>.] - ETA: 0s - loss: 113.7090 - distortion_loss: 0.0753 - rate_loss: 91.1310 - distortion_pass_through_loss: 0.0753 - rate_pass_through_loss: 91.1310
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 6s 6ms/step - loss: 113.6090 - distortion_loss: 0.0752 - rate_loss: 91.0583 - distortion_pass_through_loss: 0.0752 - rate_pass_through_loss: 91.0516 - val_loss: 96.5233 - val_distortion_loss: 0.0679 - val_rate_loss: 76.1659 - val_distortion_pass_through_loss: 0.0679 - val_rate_pass_through_loss: 76.1617
Epoch 2/15
469/469 [==============================] - 3s 6ms/step - loss: 85.8572 - distortion_loss: 0.0613 - rate_loss: 67.4572 - distortion_pass_through_loss: 0.0613 - rate_pass_through_loss: 67.4516 - val_loss: 74.5242 - val_distortion_loss: 0.0793 - val_rate_loss: 50.7241 - val_distortion_pass_through_loss: 0.0793 - val_rate_pass_through_loss: 50.7321
Epoch 3/15
469/469 [==============================] - 3s 6ms/step - loss: 68.9031 - distortion_loss: 0.0650 - rate_loss: 49.4174 - distortion_pass_through_loss: 0.0650 - rate_pass_through_loss: 49.4135 - val_loss: 59.5138 - val_distortion_loss: 0.0954 - val_rate_loss: 30.8864 - val_distortion_pass_through_loss: 0.0954 - val_rate_pass_through_loss: 30.8949
Epoch 4/15
469/469 [==============================] - 3s 6ms/step - loss: 58.3479 - distortion_loss: 0.0696 - rate_loss: 37.4610 - distortion_pass_through_loss: 0.0696 - rate_pass_through_loss: 37.4585 - val_loss: 49.3487 - val_distortion_loss: 0.1029 - val_rate_loss: 18.4801 - val_distortion_pass_through_loss: 0.1028 - val_rate_pass_through_loss: 18.4910
Epoch 5/15
469/469 [==============================] - 3s 6ms/step - loss: 52.0953 - distortion_loss: 0.0740 - rate_loss: 29.8993 - distortion_pass_through_loss: 0.0740 - rate_pass_through_loss: 29.8978 - val_loss: 42.9612 - val_distortion_loss: 0.1054 - val_rate_loss: 11.3369 - val_distortion_pass_through_loss: 0.1054 - val_rate_pass_through_loss: 11.3445
Epoch 6/15
469/469 [==============================] - 3s 6ms/step - loss: 48.1743 - distortion_loss: 0.0775 - rate_loss: 24.9172 - distortion_pass_through_loss: 0.0775 - rate_pass_through_loss: 24.9160 - val_loss: 38.8429 - val_distortion_loss: 0.1035 - val_rate_loss: 7.7809 - val_distortion_pass_through_loss: 0.1035 - val_rate_pass_through_loss: 7.7837
Epoch 7/15
469/469 [==============================] - 3s 6ms/step - loss: 45.4033 - distortion_loss: 0.0800 - rate_loss: 21.4013 - distortion_pass_through_loss: 0.0800 - rate_pass_through_loss: 21.4004 - val_loss: 36.4476 - val_distortion_loss: 0.1025 - val_rate_loss: 5.7000 - val_distortion_pass_through_loss: 0.1025 - val_rate_pass_through_loss: 5.7030
Epoch 8/15
469/469 [==============================] - 3s 6ms/step - loss: 43.1902 - distortion_loss: 0.0815 - rate_loss: 18.7450 - distortion_pass_through_loss: 0.0815 - rate_pass_through_loss: 18.7442 - val_loss: 34.4560 - val_distortion_loss: 0.0938 - val_rate_loss: 6.3266 - val_distortion_pass_through_loss: 0.0938 - val_rate_pass_through_loss: 6.3243
Epoch 9/15
469/469 [==============================] - 3s 6ms/step - loss: 41.1994 - distortion_loss: 0.0816 - rate_loss: 16.7293 - distortion_pass_through_loss: 0.0816 - rate_pass_through_loss: 16.7284 - val_loss: 33.6424 - val_distortion_loss: 0.0906 - val_rate_loss: 6.4604 - val_distortion_pass_through_loss: 0.0906 - val_rate_pass_through_loss: 6.4591
Epoch 10/15
469/469 [==============================] - 3s 6ms/step - loss: 39.5689 - distortion_loss: 0.0811 - rate_loss: 15.2472 - distortion_pass_through_loss: 0.0811 - rate_pass_through_loss: 15.2467 - val_loss: 32.8275 - val_distortion_loss: 0.0851 - val_rate_loss: 7.3065 - val_distortion_pass_through_loss: 0.0851 - val_rate_pass_through_loss: 7.3087
Epoch 11/15
469/469 [==============================] - 3s 6ms/step - loss: 38.1226 - distortion_loss: 0.0800 - rate_loss: 14.1267 - distortion_pass_through_loss: 0.0800 - rate_pass_through_loss: 14.1260 - val_loss: 32.5285 - val_distortion_loss: 0.0841 - val_rate_loss: 7.2859 - val_distortion_pass_through_loss: 0.0841 - val_rate_pass_through_loss: 7.2903
Epoch 12/15
469/469 [==============================] - 3s 6ms/step - loss: 36.8895 - distortion_loss: 0.0786 - rate_loss: 13.3042 - distortion_pass_through_loss: 0.0786 - rate_pass_through_loss: 13.3038 - val_loss: 32.3595 - val_distortion_loss: 0.0830 - val_rate_loss: 7.4672 - val_distortion_pass_through_loss: 0.0830 - val_rate_pass_through_loss: 7.4686
Epoch 13/15
469/469 [==============================] - 3s 6ms/step - loss: 35.9890 - distortion_loss: 0.0778 - rate_loss: 12.6432 - distortion_pass_through_loss: 0.0778 - rate_pass_through_loss: 12.6427 - val_loss: 32.1735 - val_distortion_loss: 0.0807 - val_rate_loss: 7.9718 - val_distortion_pass_through_loss: 0.0807 - val_rate_pass_through_loss: 7.9728
Epoch 14/15
469/469 [==============================] - 3s 6ms/step - loss: 35.2725 - distortion_loss: 0.0771 - rate_loss: 12.1506 - distortion_pass_through_loss: 0.0771 - rate_pass_through_loss: 12.1504 - val_loss: 31.8718 - val_distortion_loss: 0.0777 - val_rate_loss: 8.5510 - val_distortion_pass_through_loss: 0.0778 - val_rate_pass_through_loss: 8.5569
Epoch 15/15
469/469 [==============================] - 3s 6ms/step - loss: 34.6280 - distortion_loss: 0.0764 - rate_loss: 11.6954 - distortion_pass_through_loss: 0.0764 - rate_pass_through_loss: 11.6957 - val_loss: 31.8556 - val_distortion_loss: 0.0772 - val_rate_loss: 8.7024 - val_distortion_pass_through_loss: 0.0772 - val_rate_pass_through_loss: 8.7201


The strings now begin to get much shorter, on the order of one byte per digit. However, this comes at a cost: more digits are becoming unrecognizable.
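The "one byte per digit" figure can be checked against the final-epoch val_rate_loss values in the three training logs above. These are training-time rate estimates on the validation set, not measured string lengths, so treat the conversion to bytes as approximate.

```python
# Final-epoch val_rate_loss values (bits per digit) copied from the logs above.
final_val_rate_bits = {
    "lmbda=2000": 40.3579,
    "lmbda=500": 14.3562,
    "smallest lmbda": 8.7024,
}

for run, bits in final_val_rate_bits.items():
    print(f"{run}: {bits:.2f} bits, about {bits / 8:.2f} bytes per digit")
```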





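Finally, the decoder can be turned into a generative model: feeding random bits into the decompressor effectively samples from the distribution the model has learned for digits. First, re-create the compressor/decompressor with decode_sanity_check=False, so that decompress() does not raise an error when an input string is not completely decoded.

compressor, decompressor = make_mnist_codec(trainer, decode_sanity_check=False)

Then decode 16 random 8-byte strings into digit samples.
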
import os

strings = tf.constant([os.urandom(8) for _ in range(16)])
samples = decompressor(strings)

fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(5, 5))
axes = axes.ravel()
for i in range(len(axes)):
  axes[i].imshow(tf.squeeze(samples[i]))
  axes[i].axis("off")
plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
