Welcome to the comprehensive guide for Keras quantization aware training.

This page documents a variety of use cases and shows how to use the API for each one. Once you know which APIs you need, look up the parameters and low-level details in the API docs.

The following use cases are covered:

- Deploy a model with 8-bit quantization by following these steps.
  - Define a quantization aware model.
  - Only for Keras HDF5 models, use special checkpointing and deserialization logic. Training is otherwise standard.
  - Create a quantized model from the quantization aware one.
- Experiment with quantization.
  - Anything for experimentation has no supported path to deployment.
  - Custom Keras layers fall under experimentation.
Setup

If you are just looking for the APIs you need and want to understand their purposes, you can run this section without reading it.
! pip uninstall -y tensorflow
! pip install -q tf-nightly
! pip install -q tensorflow-model-optimization
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
import tempfile
input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)
def setup_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(20, input_shape=input_shape),
      tf.keras.layers.Flatten()
  ])
  return model
def setup_pretrained_weights():
  model = setup_model()
  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )
  model.fit(x_train, y_train)
  _, pretrained_weights = tempfile.mkstemp('.tf')
  model.save_weights(pretrained_weights)
  return pretrained_weights
def setup_pretrained_model():
  model = setup_model()
  pretrained_weights = setup_pretrained_weights()
  model.load_weights(pretrained_weights)
  return model
setup_model()
pretrained_weights = setup_pretrained_weights()
Define quantization aware model

By defining models in the following ways, there are available paths to deployment to the backends listed in the overview page. By default, 8-bit quantization is used.

Note: a quantization aware model is not actually quantized. Creating a quantized model is a separate step.

Quantize whole model

Your use case:

- Subclassed models are not supported.

Tips for better model accuracy:

- Try "Quantize some layers" to skip quantizing the layers that hurt accuracy the most.
- It is generally better to fine-tune with quantization aware training than to train from scratch.

To make the whole model quantization aware, apply tfmot.quantization.keras.quantize_model to the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
quant_aware_model.summary()
Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer (QuantizeLaye (None, 20) 3 _________________________________________________________________ quant_dense_2 (QuantizeWrapp (None, 20) 425 _________________________________________________________________ quant_flatten_2 (QuantizeWra (None, 20) 1 ================================================================= Total params: 429 Trainable params: 420 Non-trainable params: 9 _________________________________________________________________
Quantize some layers

Quantizing a model can have a negative effect on accuracy. You can selectively quantize layers of a model to explore the trade-off between accuracy, speed, and model size.

Your use case:

- To deploy to a backend that only works well with fully quantized models (e.g. EdgeTPU v1, most DSPs), try "Quantize whole model" above.

Tips for better model accuracy:

- It is generally better to fine-tune with quantization aware training than to train from scratch.
- Try quantizing the later layers instead of the first layers.
- Avoid quantizing critical layers (e.g. attention mechanisms).

In the example below, only the Dense layers are quantized.
# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
# Helper function uses `quantize_annotate_layer` to annotate that only the 
# Dense layers should be quantized.
def apply_quantization_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  return layer
# Use `tf.keras.models.clone_model` to apply `apply_quantization_to_dense` 
# to the layers of the model.
annotated_model = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_quantization_to_dense,
)
# Now that the Dense layers are annotated,
# `quantize_apply` actually makes the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
quant_aware_model.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_1 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_3 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
flatten_3 (Flatten)          (None, 20)                0         
=================================================================
Total params: 428
Trainable params: 420
Non-trainable params: 8
_________________________________________________________________
While this example used the type of the layer to decide what to quantize, the easiest way to quantize a particular layer is to set its name property and look for that name in the clone_function, as sketched below.
print(base_model.layers[0].name)
dense_3
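For instance, a hypothetical name-based variant of the clone function above ('dense_3' is the name printed in this output; substitute your own layer's name):

# Hypothetical: annotate a layer by name instead of by type.
def apply_quantization_by_name(layer):
  if layer.name == 'dense_3':
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  return layer

annotated_model = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_quantization_by_name,
)
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)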
More readable but potentially lower model accuracy

This approach is not compatible with fine-tuning with quantization aware training, which is why it may be less accurate than the examples above.
Functional example
# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
i = tf.keras.Input(shape=(20,))
x = tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(10))(i)
o = tf.keras.layers.Flatten()(x)
annotated_model = tf.keras.Model(inputs=i, outputs=o)
# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
# For deployment purposes, the tool adds `QuantizeLayer` after `InputLayer` so that the
# quantized model can take in float inputs instead of only uint8.
quant_aware_model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 20)] 0 _________________________________________________________________ quantize_layer_2 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_4 (QuantizeWrapp (None, 10) 215 _________________________________________________________________ flatten_4 (Flatten) (None, 10) 0 ================================================================= Total params: 218 Trainable params: 210 Non-trainable params: 8 _________________________________________________________________
Sequential example
# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
annotated_model = tf.keras.Sequential([
  tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])
# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
quant_aware_model.summary()
Model: "sequential_4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_3 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_5 (QuantizeWrapp (None, 20) 425 _________________________________________________________________ flatten_5 (Flatten) (None, 20) 0 ================================================================= Total params: 428 Trainable params: 420 Non-trainable params: 8 _________________________________________________________________
Checkpoint and deserialize

Your use case: this code is only needed for the HDF5 model format (not HDF5 weights or other formats).
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
# Save or checkpoint the model.
_, keras_model_file = tempfile.mkstemp('.h5')
quant_aware_model.save(keras_model_file)
# `quantize_scope` is needed for deserializing HDF5 models.
with tfmot.quantization.keras.quantize_scope():
  loaded_model = tf.keras.models.load_model(keras_model_file)
loaded_model.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_4 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_6 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
quant_flatten_6 (QuantizeWra (None, 20)                1         
=================================================================
Total params: 429
Trainable params: 420
Non-trainable params: 9
_________________________________________________________________
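The same quantize_scope requirement applies if you checkpoint during training. Below is a hypothetical sketch using tf.keras.callbacks.ModelCheckpoint; the callback wiring and toy training data are our own additions, not part of the original example:

# Hypothetical: checkpoint the quantization aware model during training,
# then reload it inside `quantize_scope` as above.
_, ckpt_file = tempfile.mkstemp('.h5')
quant_aware_model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer='adam')
quant_aware_model.fit(
    x_train, y_train,
    callbacks=[tf.keras.callbacks.ModelCheckpoint(ckpt_file)])
with tfmot.quantization.keras.quantize_scope():
  loaded_model = tf.keras.models.load_model(ckpt_file)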
Create and deploy quantized model

In general, reference the documentation for the deployment backend that you will use.

Below is an example for the TFLite backend.
base_model = setup_pretrained_model()
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
# Typically you train the model here.
converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()
1/1 [==============================] - 0s 280ms/step - loss: 0.5727 - accuracy: 0.0000e+00
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2342: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:1395: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`layer.updates` will be removed in a future version. '
WARNING:absl:Found untraced functions such as dense_7_layer_call_and_return_conditional_losses, dense_7_layer_call_fn, flatten_7_layer_call_and_return_conditional_losses, flatten_7_layer_call_fn, dense_7_layer_call_fn while saving (showing 5 of 10). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpnpakkrcw/assets
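To sanity-check the converted model, you can run it with the TFLite interpreter. A minimal sketch, reusing the toy x_train from Setup (the quantization aware graph keeps float inputs and outputs, so float32 data works directly):

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# `x_train` is float32 with shape (1, 20), matching the model's input.
interpreter.set_tensor(input_details['index'], x_train)
interpreter.invoke()
print(interpreter.get_tensor(output_details['index']))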
Experiment with quantization

Your use case: using the following APIs means that there is no supported path to deployment. These features are also experimental and are not subject to backward compatibility guarantees.

- tfmot.quantization.keras.QuantizeConfig
- tfmot.quantization.keras.quantizers.Quantizer
- tfmot.quantization.keras.quantizers.LastValueQuantizer
- tfmot.quantization.keras.quantizers.MovingAverageQuantizer

Setup: DefaultDenseQuantizeConfig

Experimenting requires using tfmot.quantization.keras.QuantizeConfig, which describes how to quantize the weights, activations, and outputs of a layer.

Below is an example that defines the same QuantizeConfig used for the Dense layer in the API defaults.

During forward propagation in this example, the LastValueQuantizer returned in get_weights_and_quantizers is called with layer.kernel as its input, producing an output. That output replaces layer.kernel in the original forward propagation of the Dense layer, via the logic defined in set_quantize_weights. The same idea applies to the activations and outputs.
LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer
class DefaultDenseQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    # Configure how to quantize weights.
    def get_weights_and_quantizers(self, layer):
      return [(layer.kernel, LastValueQuantizer(num_bits=8, symmetric=True, narrow_range=False, per_axis=False))]
    # Configure how to quantize activations.
    def get_activations_and_quantizers(self, layer):
      return [(layer.activation, MovingAverageQuantizer(num_bits=8, symmetric=False, narrow_range=False, per_axis=False))]
    def set_quantize_weights(self, layer, quantize_weights):
      # Add this line for each item returned in `get_weights_and_quantizers`,
      # in the same order.
      layer.kernel = quantize_weights[0]
    def set_quantize_activations(self, layer, quantize_activations):
      # Add this line for each item returned in `get_activations_and_quantizers`,
      # in the same order.
      layer.activation = quantize_activations[0]
    # Configure how to quantize outputs (may be equivalent to activations).
    def get_output_quantizers(self, layer):
      return []
    def get_config(self):
      return {}
Quantize custom Keras layer

This example uses the DefaultDenseQuantizeConfig to quantize the CustomLayer.

Applying the configuration is the same across the "Experiment with quantization" use cases.

- Apply tfmot.quantization.keras.quantize_annotate_layer to the CustomLayer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue to quantize the rest of the model with the API defaults.
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope
class CustomLayer(tf.keras.layers.Dense):
  pass
model = quantize_annotate_model(tf.keras.Sequential([
   quantize_annotate_layer(CustomLayer(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))
# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`
# as well as the custom Keras layer.
with quantize_scope(
  {'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig,
   'CustomLayer': CustomLayer}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)
quant_aware_model.summary()
Model: "sequential_8" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_6 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_custom_layer (Quantize (None, 20) 425 _________________________________________________________________ quant_flatten_9 (QuantizeWra (None, 20) 1 ================================================================= Total params: 429 Trainable params: 420 Non-trainable params: 9 _________________________________________________________________
Modify quantization parameters

Common mistake: quantizing the bias to fewer than 32 bits usually harms model accuracy too much.

This example modifies the Dense layer to use 4 bits for its weights instead of the default 8 bits. The rest of the model continues to use the API defaults.
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope
class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to quantize with 4 bits instead of the default 8 bits.
    def get_weights_and_quantizers(self, layer):
      return [(layer.kernel, LastValueQuantizer(num_bits=4, symmetric=True, narrow_range=False, per_axis=False))]
Applying the configuration is the same across the "Experiment with quantization" use cases.

- Apply tfmot.quantization.keras.quantize_annotate_layer to the Dense layer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue to quantize the rest of the model with the API defaults.
model = quantize_annotate_model(tf.keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this Dense layer.
   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))
# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)
quant_aware_model.summary()
Model: "sequential_9" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_7 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_9 (QuantizeWrapp (None, 20) 425 _________________________________________________________________ quant_flatten_10 (QuantizeWr (None, 20) 1 ================================================================= Total params: 429 Trainable params: 420 Non-trainable params: 9 _________________________________________________________________
Modify parts of layer to quantize

This example modifies the Dense layer to skip quantizing the activation. The rest of the model continues to use the API defaults.
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope
class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    def get_activations_and_quantizers(self, layer):
      # Skip quantizing activations.
      return []
    def set_quantize_activations(self, layer, quantize_activations):
      # Empty since `get_activations_and_quantizers` returns
      # an empty list.
      return
Applying the configuration is the same across the "Experiment with quantization" use cases.

- Apply tfmot.quantization.keras.quantize_annotate_layer to the Dense layer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue to quantize the rest of the model with the API defaults.
model = quantize_annotate_model(tf.keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this Dense layer.
   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))
# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)
quant_aware_model.summary()
Model: "sequential_10" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_8 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_10 (QuantizeWrap (None, 20) 423 _________________________________________________________________ quant_flatten_11 (QuantizeWr (None, 20) 1 ================================================================= Total params: 427 Trainable params: 420 Non-trainable params: 7 _________________________________________________________________
Use custom quantization algorithm

The tfmot.quantization.keras.quantizers.Quantizer class is a callable that can apply any algorithm to its inputs.

In this example, the inputs are the weights, and we apply the math in the FixedRangeQuantizer's __call__ function to the weights. Instead of the original weight values, the output of the FixedRangeQuantizer is now passed to whatever would have used those weights.
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope
class FixedRangeQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
  """Quantizer which forces outputs to be between -1 and 1."""
  def build(self, tensor_shape, name, layer):
    # Not needed. No new TensorFlow variables needed.
    return {}
  def __call__(self, inputs, training, weights, **kwargs):
    return tf.keras.backend.clip(inputs, -1.0, 1.0)
  def get_config(self):
    # Not needed. No __init__ parameters to serialize.
    return {}
class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to quantize with the custom algorithm instead of the default quantizer.
    def get_weights_and_quantizers(self, layer):
      # Use custom algorithm defined in `FixedRangeQuantizer` instead of default Quantizer.
      return [(layer.kernel, FixedRangeQuantizer())]
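Because a Quantizer is just a callable, you can sanity-check it in isolation before wiring it into a model. A quick sketch (the test tensor is our own):

quantizer = FixedRangeQuantizer()
# Values outside [-1, 1] are clipped: expected output [-1.  0.5  1.]
print(quantizer(tf.constant([-2.0, 0.5, 3.0]), training=False, weights={}))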
Applying the configuration is the same across the "Experiment with quantization" use cases.

- Apply tfmot.quantization.keras.quantize_annotate_layer to the Dense layer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue to quantize the rest of the model with the API defaults.
model = quantize_annotate_model(tf.keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this `Dense` layer.
   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))
# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)
quant_aware_model.summary()
Model: "sequential_11" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_9 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_11 (QuantizeWrap (None, 20) 423 _________________________________________________________________ quant_flatten_12 (QuantizeWr (None, 20) 1 ================================================================= Total params: 427 Trainable params: 420 Non-trainable params: 7 _________________________________________________________________