How to build a simple text classifier with TF-Hub

Note: This tutorial uses deprecated TensorFlow 1 functionality. See the TensorFlow 2 version for a newer way to accomplish this task.

Run in Google Colab | View source on GitHub | See TF Hub models

TF-Hub is a platform for sharing machine learning expertise packaged in reusable resources, notably pre-trained modules. This tutorial is organized into two main parts.

Getting started: Training a text classifier with TF-Hub

We will use a TF-Hub text embedding module to train a simple sentiment classifier with a reasonable baseline accuracy. We will then analyze the predictions to make sure our model is reasonable and propose improvements to increase the accuracy.

Advanced: Transfer learning analysis

In this section, we will use various TF-Hub modules to compare their effect on the accuracy of the Estimator and demonstrate the advantages and pitfalls of transfer learning.

Optional prerequisites

Setup

# Install seaborn (used later for plotting the confusion matrix).
pip install seaborn

For more detailed information about installing TensorFlow, please visit https://tensorflow.google.cn/install/

from absl import logging

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns

Getting started

Data

We will try to solve the Large Movie Review Dataset v1.0 task (Maas et al., 2011). The dataset consists of IMDB movie reviews labeled by positivity from 1 to 10. The task is to label the reviews as negative or positive.

# Load all files from a directory in a DataFrame.
def load_directory_data(directory):
  data = {}
  data["sentence"] = []
  data["sentiment"] = []
  for file_path in os.listdir(directory):
    with tf.io.gfile.GFile(os.path.join(directory, file_path), "r") as f:
      data["sentence"].append(f.read())
      data["sentiment"].append(re.match("\d+_(\d+)\.txt", file_path).group(1))
  return pd.DataFrame.from_dict(data)

# Merge positive and negative examples, add a polarity column and shuffle.
def load_dataset(directory):
  pos_df = load_directory_data(os.path.join(directory, "pos"))
  neg_df = load_directory_data(os.path.join(directory, "neg"))
  pos_df["polarity"] = 1
  neg_df["polarity"] = 0
  return pd.concat([pos_df, neg_df]).sample(frac=1).reset_index(drop=True)

# Download and process the dataset files.
def download_and_load_datasets(force_download=False):
  dataset = tf.keras.utils.get_file(
      fname="aclImdb.tar.gz", 
      origin="http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz", 
      extract=True)

  train_df = load_dataset(os.path.join(os.path.dirname(dataset), 
                                       "aclImdb", "train"))
  test_df = load_dataset(os.path.join(os.path.dirname(dataset), 
                                      "aclImdb", "test"))

  return train_df, test_df

# Reduce logging output.
logging.set_verbosity(logging.ERROR)

train_df, test_df = download_and_load_datasets()
train_df.head()
Downloading data from http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
84125825/84125825 [==============================] - 3s 0us/step

Model

Input functions

The Estimator framework provides input functions that wrap Pandas dataframes.

# Training input on the whole training set with no limit on training epochs.
train_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(
    train_df, train_df["polarity"], num_epochs=None, shuffle=True)

# Prediction on the whole training set.
predict_train_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(
    train_df, train_df["polarity"], shuffle=False)
# Prediction on the test set.
predict_test_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(
    test_df, test_df["polarity"], shuffle=False)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_65948/2827882428.py:2: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

WARNING:tensorflow:From /tmpfs/tmp/ipykernel_65948/2827882428.py:2: The name tf.estimator.inputs.pandas_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.pandas_input_fn instead.

Feature columns

TF-Hub provides a feature column that applies a module on the given text feature and passes on the outputs of the module. In this tutorial we will be using the nnlm-en-dim128 module. For the purpose of this tutorial, the most important facts are:

  • The module takes a batch of sentences in a 1-D tensor of strings as input.
  • The module is responsible for preprocessing of sentences (e.g. removal of punctuation and splitting on spaces).
  • The module works with any input (e.g. nnlm-en-dim128 hashes words not present in the vocabulary into ~20,000 buckets).
embedded_text_feature_column = hub.text_embedding_column(
    key="sentence", 
    module_spec="https://tfhub.dev/google/nnlm-en-dim128/1")

Estimator

For classification we can use a DNN Classifier (note the additional remarks at the end of the tutorial about a different modeling of the label function).

estimator = tf.estimator.DNNClassifier(
    hidden_units=[500, 100],
    feature_columns=[embedded_text_feature_column],
    n_classes=2,
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.003))
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpudlj2nw3
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpudlj2nw3
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpudlj2nw3', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpudlj2nw3', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Training

Train the estimator for a reasonable number of steps.

# Training for 5,000 steps means 640,000 training examples with the default
# batch size. This is roughly equivalent to 25 epochs since the training dataset
# contains 25,000 examples.
estimator.train(input_fn=train_input_fn, steps=5000);
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:60: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:60: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:37:56.838873: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_v2/adagrad.py:86: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_v2/adagrad.py:86: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
2022-06-03 17:38:00.347586: W tensorflow/core/common_runtime/forward_type_inference.cc:231] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT64
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT32
    }
  }
}

    while inferring type of node 'dnn/zero_fraction/cond/output/_18'
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpudlj2nw3/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpudlj2nw3/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-0.data-00000-of-00001
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-0.data-00000-of-00001
INFO:tensorflow:499500
INFO:tensorflow:499500
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-0.index
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-0.index
INFO:tensorflow:499500
INFO:tensorflow:499500
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-0.meta
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-0.meta
INFO:tensorflow:499700
INFO:tensorflow:499700
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.70826715, step = 0
INFO:tensorflow:loss = 0.70826715, step = 0
INFO:tensorflow:global_step/sec: 78.4902
INFO:tensorflow:global_step/sec: 78.4902
INFO:tensorflow:loss = 0.6891204, step = 100 (1.276 sec)
INFO:tensorflow:loss = 0.6891204, step = 100 (1.276 sec)
INFO:tensorflow:global_step/sec: 87.7307
INFO:tensorflow:global_step/sec: 87.7307
INFO:tensorflow:loss = 0.6726319, step = 200 (1.140 sec)
INFO:tensorflow:loss = 0.6726319, step = 200 (1.140 sec)
INFO:tensorflow:global_step/sec: 91.2884
INFO:tensorflow:global_step/sec: 91.2884
INFO:tensorflow:loss = 0.679251, step = 300 (1.096 sec)
INFO:tensorflow:loss = 0.679251, step = 300 (1.096 sec)
INFO:tensorflow:global_step/sec: 87.7936
INFO:tensorflow:global_step/sec: 87.7936
INFO:tensorflow:loss = 0.65527, step = 400 (1.139 sec)
INFO:tensorflow:loss = 0.65527, step = 400 (1.139 sec)
INFO:tensorflow:global_step/sec: 88.9277
INFO:tensorflow:global_step/sec: 88.9277
INFO:tensorflow:loss = 0.6522072, step = 500 (1.125 sec)
INFO:tensorflow:loss = 0.6522072, step = 500 (1.125 sec)
INFO:tensorflow:global_step/sec: 83.4568
INFO:tensorflow:global_step/sec: 83.4568
INFO:tensorflow:loss = 0.65518993, step = 600 (1.197 sec)
INFO:tensorflow:loss = 0.65518993, step = 600 (1.197 sec)
INFO:tensorflow:global_step/sec: 91.2489
INFO:tensorflow:global_step/sec: 91.2489
INFO:tensorflow:loss = 0.65161276, step = 700 (1.096 sec)
INFO:tensorflow:loss = 0.65161276, step = 700 (1.096 sec)
INFO:tensorflow:global_step/sec: 92.5192
INFO:tensorflow:global_step/sec: 92.5192
INFO:tensorflow:loss = 0.619715, step = 800 (1.081 sec)
INFO:tensorflow:loss = 0.619715, step = 800 (1.081 sec)
INFO:tensorflow:global_step/sec: 90.1138
INFO:tensorflow:global_step/sec: 90.1138
INFO:tensorflow:loss = 0.59995437, step = 900 (1.178 sec)
INFO:tensorflow:loss = 0.59995437, step = 900 (1.178 sec)
INFO:tensorflow:global_step/sec: 84.0547
INFO:tensorflow:global_step/sec: 84.0547
INFO:tensorflow:loss = 0.6152175, step = 1000 (1.121 sec)
INFO:tensorflow:loss = 0.6152175, step = 1000 (1.121 sec)
INFO:tensorflow:global_step/sec: 89.2772
INFO:tensorflow:global_step/sec: 89.2772
INFO:tensorflow:loss = 0.5878293, step = 1100 (1.120 sec)
INFO:tensorflow:loss = 0.5878293, step = 1100 (1.120 sec)
INFO:tensorflow:global_step/sec: 85.3779
INFO:tensorflow:global_step/sec: 85.3779
INFO:tensorflow:loss = 0.5657145, step = 1200 (1.171 sec)
INFO:tensorflow:loss = 0.5657145, step = 1200 (1.171 sec)
INFO:tensorflow:global_step/sec: 90.3413
INFO:tensorflow:global_step/sec: 90.3413
INFO:tensorflow:loss = 0.56806666, step = 1300 (1.107 sec)
INFO:tensorflow:loss = 0.56806666, step = 1300 (1.107 sec)
INFO:tensorflow:global_step/sec: 86.0456
INFO:tensorflow:global_step/sec: 86.0456
INFO:tensorflow:loss = 0.5690006, step = 1400 (1.162 sec)
INFO:tensorflow:loss = 0.5690006, step = 1400 (1.162 sec)
INFO:tensorflow:global_step/sec: 91.9504
INFO:tensorflow:global_step/sec: 91.9504
INFO:tensorflow:loss = 0.56584495, step = 1500 (1.087 sec)
INFO:tensorflow:loss = 0.56584495, step = 1500 (1.087 sec)
INFO:tensorflow:global_step/sec: 90.0863
INFO:tensorflow:global_step/sec: 90.0863
INFO:tensorflow:loss = 0.5279515, step = 1600 (1.110 sec)
INFO:tensorflow:loss = 0.5279515, step = 1600 (1.110 sec)
INFO:tensorflow:global_step/sec: 86.0793
INFO:tensorflow:global_step/sec: 86.0793
INFO:tensorflow:loss = 0.5813559, step = 1700 (1.161 sec)
INFO:tensorflow:loss = 0.5813559, step = 1700 (1.161 sec)
INFO:tensorflow:global_step/sec: 89.9058
INFO:tensorflow:global_step/sec: 89.9058
INFO:tensorflow:loss = 0.47850803, step = 1800 (1.112 sec)
INFO:tensorflow:loss = 0.47850803, step = 1800 (1.112 sec)
INFO:tensorflow:global_step/sec: 91.0798
INFO:tensorflow:global_step/sec: 91.0798
INFO:tensorflow:loss = 0.51293707, step = 1900 (1.098 sec)
INFO:tensorflow:loss = 0.51293707, step = 1900 (1.098 sec)
INFO:tensorflow:global_step/sec: 91.8762
INFO:tensorflow:global_step/sec: 91.8762
INFO:tensorflow:loss = 0.49730766, step = 2000 (1.088 sec)
INFO:tensorflow:loss = 0.49730766, step = 2000 (1.088 sec)
INFO:tensorflow:global_step/sec: 82.7615
INFO:tensorflow:global_step/sec: 82.7615
INFO:tensorflow:loss = 0.47837245, step = 2100 (1.208 sec)
INFO:tensorflow:loss = 0.47837245, step = 2100 (1.208 sec)
INFO:tensorflow:global_step/sec: 89.6234
INFO:tensorflow:global_step/sec: 89.6234
INFO:tensorflow:loss = 0.42676574, step = 2200 (1.116 sec)
INFO:tensorflow:loss = 0.42676574, step = 2200 (1.116 sec)
INFO:tensorflow:global_step/sec: 83.3326
INFO:tensorflow:global_step/sec: 83.3326
INFO:tensorflow:loss = 0.4834714, step = 2300 (1.200 sec)
INFO:tensorflow:loss = 0.4834714, step = 2300 (1.200 sec)
INFO:tensorflow:global_step/sec: 85.027
INFO:tensorflow:global_step/sec: 85.027
INFO:tensorflow:loss = 0.4663557, step = 2400 (1.177 sec)
INFO:tensorflow:loss = 0.4663557, step = 2400 (1.177 sec)
INFO:tensorflow:global_step/sec: 91.9186
INFO:tensorflow:global_step/sec: 91.9186
INFO:tensorflow:loss = 0.4588431, step = 2500 (1.087 sec)
INFO:tensorflow:loss = 0.4588431, step = 2500 (1.087 sec)
INFO:tensorflow:global_step/sec: 91.5467
INFO:tensorflow:global_step/sec: 91.5467
INFO:tensorflow:loss = 0.49664712, step = 2600 (1.092 sec)
INFO:tensorflow:loss = 0.49664712, step = 2600 (1.092 sec)
INFO:tensorflow:global_step/sec: 88.3678
INFO:tensorflow:global_step/sec: 88.3678
INFO:tensorflow:loss = 0.4873639, step = 2700 (1.132 sec)
INFO:tensorflow:loss = 0.4873639, step = 2700 (1.132 sec)
INFO:tensorflow:global_step/sec: 84.3798
INFO:tensorflow:global_step/sec: 84.3798
INFO:tensorflow:loss = 0.4256683, step = 2800 (1.185 sec)
INFO:tensorflow:loss = 0.4256683, step = 2800 (1.185 sec)
INFO:tensorflow:global_step/sec: 91.4618
INFO:tensorflow:global_step/sec: 91.4618
INFO:tensorflow:loss = 0.35195667, step = 2900 (1.093 sec)
INFO:tensorflow:loss = 0.35195667, step = 2900 (1.093 sec)
INFO:tensorflow:global_step/sec: 89.9722
INFO:tensorflow:global_step/sec: 89.9722
INFO:tensorflow:loss = 0.47497162, step = 3000 (1.112 sec)
INFO:tensorflow:loss = 0.47497162, step = 3000 (1.112 sec)
INFO:tensorflow:global_step/sec: 83.4468
INFO:tensorflow:global_step/sec: 83.4468
INFO:tensorflow:loss = 0.48370215, step = 3100 (1.198 sec)
INFO:tensorflow:loss = 0.48370215, step = 3100 (1.198 sec)
INFO:tensorflow:global_step/sec: 89.5798
INFO:tensorflow:global_step/sec: 89.5798
INFO:tensorflow:loss = 0.44609123, step = 3200 (1.116 sec)
INFO:tensorflow:loss = 0.44609123, step = 3200 (1.116 sec)
INFO:tensorflow:global_step/sec: 88.3876
INFO:tensorflow:global_step/sec: 88.3876
INFO:tensorflow:loss = 0.4710789, step = 3300 (1.131 sec)
INFO:tensorflow:loss = 0.4710789, step = 3300 (1.131 sec)
INFO:tensorflow:global_step/sec: 82.133
INFO:tensorflow:global_step/sec: 82.133
INFO:tensorflow:loss = 0.46137106, step = 3400 (1.217 sec)
INFO:tensorflow:loss = 0.46137106, step = 3400 (1.217 sec)
INFO:tensorflow:global_step/sec: 86.479
INFO:tensorflow:global_step/sec: 86.479
INFO:tensorflow:loss = 0.38244545, step = 3500 (1.157 sec)
INFO:tensorflow:loss = 0.38244545, step = 3500 (1.157 sec)
INFO:tensorflow:global_step/sec: 90.1074
INFO:tensorflow:global_step/sec: 90.1074
INFO:tensorflow:loss = 0.46425316, step = 3600 (1.110 sec)
INFO:tensorflow:loss = 0.46425316, step = 3600 (1.110 sec)
INFO:tensorflow:global_step/sec: 90.2319
INFO:tensorflow:global_step/sec: 90.2319
INFO:tensorflow:loss = 0.4668852, step = 3700 (1.108 sec)
INFO:tensorflow:loss = 0.4668852, step = 3700 (1.108 sec)
INFO:tensorflow:global_step/sec: 87.2783
INFO:tensorflow:global_step/sec: 87.2783
INFO:tensorflow:loss = 0.5064398, step = 3800 (1.146 sec)
INFO:tensorflow:loss = 0.5064398, step = 3800 (1.146 sec)
INFO:tensorflow:global_step/sec: 86.2334
INFO:tensorflow:global_step/sec: 86.2334
INFO:tensorflow:loss = 0.4716252, step = 3900 (1.160 sec)
INFO:tensorflow:loss = 0.4716252, step = 3900 (1.160 sec)
INFO:tensorflow:global_step/sec: 91.1204
INFO:tensorflow:global_step/sec: 91.1204
INFO:tensorflow:loss = 0.52002573, step = 4000 (1.097 sec)
INFO:tensorflow:loss = 0.52002573, step = 4000 (1.097 sec)
INFO:tensorflow:global_step/sec: 90.9833
INFO:tensorflow:global_step/sec: 90.9833
INFO:tensorflow:loss = 0.44316882, step = 4100 (1.099 sec)
INFO:tensorflow:loss = 0.44316882, step = 4100 (1.099 sec)
INFO:tensorflow:global_step/sec: 92.7287
INFO:tensorflow:global_step/sec: 92.7287
INFO:tensorflow:loss = 0.4133105, step = 4200 (1.079 sec)
INFO:tensorflow:loss = 0.4133105, step = 4200 (1.079 sec)
INFO:tensorflow:global_step/sec: 86.3855
INFO:tensorflow:global_step/sec: 86.3855
INFO:tensorflow:loss = 0.43366018, step = 4300 (1.157 sec)
INFO:tensorflow:loss = 0.43366018, step = 4300 (1.157 sec)
INFO:tensorflow:global_step/sec: 91.6433
INFO:tensorflow:global_step/sec: 91.6433
INFO:tensorflow:loss = 0.49312478, step = 4400 (1.091 sec)
INFO:tensorflow:loss = 0.49312478, step = 4400 (1.091 sec)
INFO:tensorflow:global_step/sec: 80.5434
INFO:tensorflow:global_step/sec: 80.5434
INFO:tensorflow:loss = 0.44904065, step = 4500 (1.242 sec)
INFO:tensorflow:loss = 0.44904065, step = 4500 (1.242 sec)
INFO:tensorflow:global_step/sec: 83.7805
INFO:tensorflow:global_step/sec: 83.7805
INFO:tensorflow:loss = 0.47336102, step = 4600 (1.193 sec)
INFO:tensorflow:loss = 0.47336102, step = 4600 (1.193 sec)
INFO:tensorflow:global_step/sec: 89.698
INFO:tensorflow:global_step/sec: 89.698
INFO:tensorflow:loss = 0.4439745, step = 4700 (1.115 sec)
INFO:tensorflow:loss = 0.4439745, step = 4700 (1.115 sec)
INFO:tensorflow:global_step/sec: 90.8685
INFO:tensorflow:global_step/sec: 90.8685
INFO:tensorflow:loss = 0.4688311, step = 4800 (1.100 sec)
INFO:tensorflow:loss = 0.4688311, step = 4800 (1.100 sec)
INFO:tensorflow:global_step/sec: 82.8678
INFO:tensorflow:global_step/sec: 82.8678
INFO:tensorflow:loss = 0.50729465, step = 4900 (1.207 sec)
INFO:tensorflow:loss = 0.50729465, step = 4900 (1.207 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000...
INFO:tensorflow:Saving checkpoints for 5000 into /tmpfs/tmp/tmpudlj2nw3/model.ckpt.
INFO:tensorflow:Saving checkpoints for 5000 into /tmpfs/tmp/tmpudlj2nw3/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000.data-00000-of-00001
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000.data-00000-of-00001
INFO:tensorflow:499500
INFO:tensorflow:499500
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000.meta
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000.meta
INFO:tensorflow:499700
INFO:tensorflow:499700
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000.index
INFO:tensorflow:/tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000.index
INFO:tensorflow:499700
INFO:tensorflow:499700
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000...
INFO:tensorflow:Loss for final step: 0.3802716.
INFO:tensorflow:Loss for final step: 0.3802716.

Predictions

Run predictions for both the training set and the test set.

train_eval_result = estimator.evaluate(input_fn=predict_train_input_fn)
test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn)

print("Training set accuracy: {accuracy}".format(**train_eval_result))
print("Test set accuracy: {accuracy}".format(**test_eval_result))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
2022-06-03 17:39:00.753823: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:01
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:01
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 3.71303s
INFO:tensorflow:Inference Time : 3.71303s
INFO:tensorflow:Finished evaluation at 2022-06-03-17:39:05
INFO:tensorflow:Finished evaluation at 2022-06-03-17:39:05
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.79312, accuracy_baseline = 0.5, auc = 0.87429845, auc_precision_recall = 0.8750604, average_loss = 0.44668466, global_step = 5000, label/mean = 0.5, loss = 0.4473904, precision = 0.8062009, prediction/mean = 0.48745117, recall = 0.77176
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.79312, accuracy_baseline = 0.5, auc = 0.87429845, auc_precision_recall = 0.8750604, average_loss = 0.44668466, global_step = 5000, label/mean = 0.5, loss = 0.4473904, precision = 0.8062009, prediction/mean = 0.48745117, recall = 0.77176
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:39:05.667917: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:06
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:06
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 3.68073s
INFO:tensorflow:Inference Time : 3.68073s
INFO:tensorflow:Finished evaluation at 2022-06-03-17:39:10
INFO:tensorflow:Finished evaluation at 2022-06-03-17:39:10
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.78612, accuracy_baseline = 0.5, auc = 0.86889553, auc_precision_recall = 0.8715557, average_loss = 0.45473337, global_step = 5000, label/mean = 0.5, loss = 0.45439148, precision = 0.80508405, prediction/mean = 0.4827163, recall = 0.75504
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.78612, accuracy_baseline = 0.5, auc = 0.86889553, auc_precision_recall = 0.8715557, average_loss = 0.45473337, global_step = 5000, label/mean = 0.5, loss = 0.45439148, precision = 0.80508405, prediction/mean = 0.4827163, recall = 0.75504
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
Training set accuracy: 0.7931200265884399
Test set accuracy: 0.7861199975013733

Confusion matrix

We can visually check the confusion matrix to understand the distribution of misclassifications.

def get_predictions(estimator, input_fn):
  return [x["class_ids"][0] for x in estimator.predict(input_fn=input_fn)]

LABELS = [
    "negative", "positive"
]

# Create a confusion matrix on training data.
cm = tf.math.confusion_matrix(train_df["polarity"], 
                              get_predictions(estimator, predict_train_input_fn))

# Normalize the confusion matrix so that each row sums to 1.
cm = tf.cast(cm, dtype=tf.float32)
cm = cm / tf.math.reduce_sum(cm, axis=1)[:, np.newaxis]

sns.heatmap(cm, annot=True, xticklabels=LABELS, yticklabels=LABELS);
plt.xlabel("Predicted");
plt.ylabel("True");
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:39:10.579107: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.

(Figure: confusion-matrix heatmap, true polarity vs. predicted polarity)

Further improvements

  1. Sentiment regression: we used a classifier to assign each example to a polarity class, but we actually have another categorical feature at our disposal - the sentiment. Here the classes really represent a scale, and the underlying value (positive/negative) could be mapped well into a continuous range. We could make use of this property by computing a regression (DNN Regressor) instead of a classification (DNN Classifier); see the sketch after this list.
  2. Larger module: for the purposes of this tutorial we used a small module to restrict the memory use. There are modules with larger vocabularies and larger embedding spaces that could give additional accuracy points.
  3. Parameter tuning: we can improve the accuracy by tuning the meta-parameters such as the learning rate or the number of steps, especially if we use a different module. A validation set is very important if we want any reasonable results, because it is very easy to set up a model that learns to predict the training data without generalizing well to the test set.
  4. More complex model: we used a module that computes a sentence embedding by embedding each individual word and then combining them with an average. One could also use a sequential module (e.g. the Universal Sentence Encoder module) to better capture the nature of sentences, or an ensemble of two or more TF-Hub modules.
  5. Regularization: to avoid overfitting, we could try to use an optimizer that performs some sort of regularization, for example the Proximal Adagrad Optimizer (also included in the sketch below).

Advanced: Transfer learning analysis

Transfer learning makes it possible to save training resources and to achieve good model generalization even when training on a small dataset. In this part, we will demonstrate this by training with two different TF-Hub modules:

  • nnlm-en-dim128 - a pre-trained text embedding module;
  • random-nnlm-en-dim128 - a text embedding module with the same vocabulary and network as nnlm-en-dim128, but whose weights were only randomly initialized and never trained on real data.

And by training in two modes:

  • training only the classifier (i.e. freezing the module), and
  • training the classifier together with the module.

Let's run a couple of trainings and evaluations to see how using the various modules affects the accuracy.

def train_and_evaluate_with_module(hub_module, train_module=False):
  embedded_text_feature_column = hub.text_embedding_column(
      key="sentence", module_spec=hub_module, trainable=train_module)

  estimator = tf.estimator.DNNClassifier(
      hidden_units=[500, 100],
      feature_columns=[embedded_text_feature_column],
      n_classes=2,
      optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.003))

  estimator.train(input_fn=train_input_fn, steps=1000)

  train_eval_result = estimator.evaluate(input_fn=predict_train_input_fn)
  test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn)

  training_set_accuracy = train_eval_result["accuracy"]
  test_set_accuracy = test_eval_result["accuracy"]

  return {
      "Training accuracy": training_set_accuracy,
      "Test accuracy": test_set_accuracy
  }


results = {}
results["nnlm-en-dim128"] = train_and_evaluate_with_module(
    "https://tfhub.dev/google/nnlm-en-dim128/1")
results["nnlm-en-dim128-with-module-training"] = train_and_evaluate_with_module(
    "https://tfhub.dev/google/nnlm-en-dim128/1", True)
results["random-nnlm-en-dim128"] = train_and_evaluate_with_module(
    "https://tfhub.dev/google/random-nnlm-en-dim128/1")
results["random-nnlm-en-dim128-with-module-training"] = train_and_evaluate_with_module(
    "https://tfhub.dev/google/random-nnlm-en-dim128/1", True)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp5aep1srb
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp5aep1srb
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp5aep1srb', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp5aep1srb', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:39:13.748401: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp5aep1srb/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp5aep1srb/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-0.data-00000-of-00001
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-0.data-00000-of-00001
INFO:tensorflow:499500
INFO:tensorflow:499500
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-0.index
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-0.index
INFO:tensorflow:499500
INFO:tensorflow:499500
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-0.meta
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-0.meta
INFO:tensorflow:499700
INFO:tensorflow:499700
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.68280816, step = 0
INFO:tensorflow:loss = 0.68280816, step = 0
INFO:tensorflow:global_step/sec: 79.7628
INFO:tensorflow:global_step/sec: 79.7628
INFO:tensorflow:loss = 0.6813121, step = 100 (1.255 sec)
INFO:tensorflow:loss = 0.6813121, step = 100 (1.255 sec)
INFO:tensorflow:global_step/sec: 90.2023
INFO:tensorflow:global_step/sec: 90.2023
INFO:tensorflow:loss = 0.6718718, step = 200 (1.109 sec)
INFO:tensorflow:loss = 0.6718718, step = 200 (1.109 sec)
INFO:tensorflow:global_step/sec: 86.2219
INFO:tensorflow:global_step/sec: 86.2219
INFO:tensorflow:loss = 0.6620325, step = 300 (1.160 sec)
INFO:tensorflow:loss = 0.6620325, step = 300 (1.160 sec)
INFO:tensorflow:global_step/sec: 92.1523
INFO:tensorflow:global_step/sec: 92.1523
INFO:tensorflow:loss = 0.6521486, step = 400 (1.085 sec)
INFO:tensorflow:loss = 0.6521486, step = 400 (1.085 sec)
INFO:tensorflow:global_step/sec: 89.5852
INFO:tensorflow:global_step/sec: 89.5852
INFO:tensorflow:loss = 0.64504206, step = 500 (1.116 sec)
INFO:tensorflow:loss = 0.64504206, step = 500 (1.116 sec)
INFO:tensorflow:global_step/sec: 90.6965
INFO:tensorflow:global_step/sec: 90.6965
INFO:tensorflow:loss = 0.62079316, step = 600 (1.103 sec)
INFO:tensorflow:loss = 0.62079316, step = 600 (1.103 sec)
INFO:tensorflow:global_step/sec: 85.8023
INFO:tensorflow:global_step/sec: 85.8023
INFO:tensorflow:loss = 0.61648655, step = 700 (1.165 sec)
INFO:tensorflow:loss = 0.61648655, step = 700 (1.165 sec)
INFO:tensorflow:global_step/sec: 90.1708
INFO:tensorflow:global_step/sec: 90.1708
INFO:tensorflow:loss = 0.62497187, step = 800 (1.109 sec)
INFO:tensorflow:loss = 0.62497187, step = 800 (1.109 sec)
INFO:tensorflow:global_step/sec: 91.0562
INFO:tensorflow:global_step/sec: 91.0562
INFO:tensorflow:loss = 0.61609304, step = 900 (1.098 sec)
INFO:tensorflow:loss = 0.61609304, step = 900 (1.098 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1000...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1000...
INFO:tensorflow:Saving checkpoints for 1000 into /tmpfs/tmp/tmp5aep1srb/model.ckpt.
INFO:tensorflow:Saving checkpoints for 1000 into /tmpfs/tmp/tmp5aep1srb/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-1000.meta
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-1000.meta
INFO:tensorflow:200
INFO:tensorflow:200
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-1000.index
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-1000.index
INFO:tensorflow:200
INFO:tensorflow:200
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-1000.data-00000-of-00001
INFO:tensorflow:/tmpfs/tmp/tmp5aep1srb/model.ckpt-1000.data-00000-of-00001
INFO:tensorflow:499700
INFO:tensorflow:499700
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1000...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1000...
INFO:tensorflow:Loss for final step: 0.6102089.
INFO:tensorflow:Loss for final step: 0.6102089.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:39:28.366443: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:29
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:29
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp5aep1srb/model.ckpt-1000
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp5aep1srb/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 3.70011s
INFO:tensorflow:Inference Time : 3.70011s
INFO:tensorflow:Finished evaluation at 2022-06-03-17:39:33
INFO:tensorflow:Finished evaluation at 2022-06-03-17:39:33
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.72284, accuracy_baseline = 0.5, auc = 0.79714954, auc_precision_recall = 0.7990524, average_loss = 0.59645987, global_step = 1000, label/mean = 0.5, loss = 0.59674233, precision = 0.7322218, prediction/mean = 0.49653652, recall = 0.70264
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.72284, accuracy_baseline = 0.5, auc = 0.79714954, auc_precision_recall = 0.7990524, average_loss = 0.59645987, global_step = 1000, label/mean = 0.5, loss = 0.59674233, precision = 0.7322218, prediction/mean = 0.49653652, recall = 0.70264
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmpfs/tmp/tmp5aep1srb/model.ckpt-1000
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmpfs/tmp/tmp5aep1srb/model.ckpt-1000
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:39:33.218944: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:34
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:34
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp5aep1srb/model.ckpt-1000
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp5aep1srb/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 3.72721s
INFO:tensorflow:Inference Time : 3.72721s
INFO:tensorflow:Finished evaluation at 2022-06-03-17:39:37
INFO:tensorflow:Finished evaluation at 2022-06-03-17:39:37
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.72032, accuracy_baseline = 0.5, auc = 0.7895848, auc_precision_recall = 0.79116887, average_loss = 0.6001009, global_step = 1000, label/mean = 0.5, loss = 0.5998418, precision = 0.7328767, prediction/mean = 0.49439234, recall = 0.69336
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.72032, accuracy_baseline = 0.5, auc = 0.7895848, auc_precision_recall = 0.79116887, average_loss = 0.6001009, global_step = 1000, label/mean = 0.5, loss = 0.5998418, precision = 0.7328767, prediction/mean = 0.49439234, recall = 0.69336
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmpfs/tmp/tmp5aep1srb/model.ckpt-1000
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmpfs/tmp/tmp5aep1srb/model.ckpt-1000
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpl8_1sju2
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpl8_1sju2
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpl8_1sju2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpl8_1sju2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:39:37.940793: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpl8_1sju2/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpl8_1sju2/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-0.data-00000-of-00001
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-0.data-00000-of-00001
INFO:tensorflow:499500
INFO:tensorflow:499500
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-0.index
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-0.index
INFO:tensorflow:499500
INFO:tensorflow:499500
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-0.meta
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-0.meta
INFO:tensorflow:499700
INFO:tensorflow:499700
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.70170236, step = 0
INFO:tensorflow:loss = 0.70170236, step = 0
INFO:tensorflow:global_step/sec: 82.0201
INFO:tensorflow:global_step/sec: 82.0201
INFO:tensorflow:loss = 0.6839398, step = 100 (1.221 sec)
INFO:tensorflow:loss = 0.6839398, step = 100 (1.221 sec)
INFO:tensorflow:global_step/sec: 89.3805
INFO:tensorflow:global_step/sec: 89.3805
INFO:tensorflow:loss = 0.6783775, step = 200 (1.119 sec)
INFO:tensorflow:loss = 0.6783775, step = 200 (1.119 sec)
INFO:tensorflow:global_step/sec: 88.7376
INFO:tensorflow:global_step/sec: 88.7376
INFO:tensorflow:loss = 0.65822446, step = 300 (1.127 sec)
INFO:tensorflow:loss = 0.65822446, step = 300 (1.127 sec)
INFO:tensorflow:global_step/sec: 84.5886
INFO:tensorflow:global_step/sec: 84.5886
INFO:tensorflow:loss = 0.6663666, step = 400 (1.182 sec)
INFO:tensorflow:loss = 0.6663666, step = 400 (1.182 sec)
INFO:tensorflow:global_step/sec: 92.3105
INFO:tensorflow:global_step/sec: 92.3105
INFO:tensorflow:loss = 0.6259121, step = 500 (1.083 sec)
INFO:tensorflow:loss = 0.6259121, step = 500 (1.083 sec)
INFO:tensorflow:global_step/sec: 89.1299
INFO:tensorflow:global_step/sec: 89.1299
INFO:tensorflow:loss = 0.62898624, step = 600 (1.122 sec)
INFO:tensorflow:loss = 0.62898624, step = 600 (1.122 sec)
INFO:tensorflow:global_step/sec: 83.588
INFO:tensorflow:global_step/sec: 83.588
INFO:tensorflow:loss = 0.6309737, step = 700 (1.196 sec)
INFO:tensorflow:loss = 0.6309737, step = 700 (1.196 sec)
INFO:tensorflow:global_step/sec: 89.2609
INFO:tensorflow:global_step/sec: 89.2609
INFO:tensorflow:loss = 0.6064919, step = 800 (1.120 sec)
INFO:tensorflow:loss = 0.6064919, step = 800 (1.120 sec)
INFO:tensorflow:global_step/sec: 92.123
INFO:tensorflow:global_step/sec: 92.123
INFO:tensorflow:loss = 0.62309647, step = 900 (1.086 sec)
INFO:tensorflow:loss = 0.62309647, step = 900 (1.086 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1000...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1000...
INFO:tensorflow:Saving checkpoints for 1000 into /tmpfs/tmp/tmpl8_1sju2/model.ckpt.
INFO:tensorflow:Saving checkpoints for 1000 into /tmpfs/tmp/tmpl8_1sju2/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000.meta
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000.meta
INFO:tensorflow:200
INFO:tensorflow:200
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000.index
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000.index
INFO:tensorflow:200
INFO:tensorflow:200
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000.data-00000-of-00001
INFO:tensorflow:/tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000.data-00000-of-00001
INFO:tensorflow:499700
INFO:tensorflow:499700
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1000...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1000...
INFO:tensorflow:Loss for final step: 0.5894411.
INFO:tensorflow:Loss for final step: 0.5894411.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:39:52.804157: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:54
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:54
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 3.96815s
INFO:tensorflow:Inference Time : 3.96815s
INFO:tensorflow:Finished evaluation at 2022-06-03-17:39:58
INFO:tensorflow:Finished evaluation at 2022-06-03-17:39:58
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.72776, accuracy_baseline = 0.5, auc = 0.8026487, auc_precision_recall = 0.80292857, average_loss = 0.59141713, global_step = 1000, label/mean = 0.5, loss = 0.59178364, precision = 0.7367765, prediction/mean = 0.49760938, recall = 0.70872
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.72776, accuracy_baseline = 0.5, auc = 0.8026487, auc_precision_recall = 0.80292857, average_loss = 0.59141713, global_step = 1000, label/mean = 0.5, loss = 0.59178364, precision = 0.7367765, prediction/mean = 0.49760938, recall = 0.70872
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:39:58.253469: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:59
INFO:tensorflow:Starting evaluation at 2022-06-03T17:39:59
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 3.73931s
INFO:tensorflow:Finished evaluation at 2022-06-03-17:40:02
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.72088, accuracy_baseline = 0.5, auc = 0.79300195, auc_precision_recall = 0.792099, average_loss = 0.5959825, global_step = 1000, label/mean = 0.5, loss = 0.59570915, precision = 0.73362666, prediction/mean = 0.4958548, recall = 0.6936
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmpfs/tmp/tmpl8_1sju2/model.ckpt-1000
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp0xdt5vok
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp0xdt5vok', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:40:07.338940: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp0xdt5vok/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmp0xdt5vok/model.ckpt-0.data-00000-of-00001
INFO:tensorflow:499500
INFO:tensorflow:/tmpfs/tmp/tmp0xdt5vok/model.ckpt-0.index
INFO:tensorflow:499500
INFO:tensorflow:/tmpfs/tmp/tmp0xdt5vok/model.ckpt-0.meta
INFO:tensorflow:499700
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.7442793, step = 0
INFO:tensorflow:global_step/sec: 80.1827
INFO:tensorflow:loss = 0.6537382, step = 100 (1.249 sec)
INFO:tensorflow:global_step/sec: 88.687
INFO:tensorflow:loss = 0.66613746, step = 200 (1.128 sec)
INFO:tensorflow:global_step/sec: 89.2614
INFO:tensorflow:loss = 0.60222983, step = 300 (1.120 sec)
INFO:tensorflow:global_step/sec: 92.0801
INFO:tensorflow:loss = 0.59789896, step = 400 (1.086 sec)
INFO:tensorflow:global_step/sec: 90.762
INFO:tensorflow:loss = 0.6843546, step = 500 (1.102 sec)
INFO:tensorflow:global_step/sec: 83.4971
INFO:tensorflow:loss = 0.6567893, step = 600 (1.198 sec)
INFO:tensorflow:global_step/sec: 90.5604
INFO:tensorflow:loss = 0.6075423, step = 700 (1.105 sec)
INFO:tensorflow:global_step/sec: 82.2708
INFO:tensorflow:loss = 0.58856726, step = 800 (1.215 sec)
INFO:tensorflow:global_step/sec: 90.2744
INFO:tensorflow:loss = 0.61389387, step = 900 (1.108 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1000...
INFO:tensorflow:Saving checkpoints for 1000 into /tmpfs/tmp/tmp0xdt5vok/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmp0xdt5vok/model.ckpt-1000.meta
INFO:tensorflow:200
INFO:tensorflow:/tmpfs/tmp/tmp0xdt5vok/model.ckpt-1000.index
INFO:tensorflow:200
INFO:tensorflow:/tmpfs/tmp/tmp0xdt5vok/model.ckpt-1000.data-00000-of-00001
INFO:tensorflow:499700
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1000...
INFO:tensorflow:Loss for final step: 0.6218455.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:40:22.210983: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T17:40:23
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp0xdt5vok/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 3.81693s
INFO:tensorflow:Finished evaluation at 2022-06-03-17:40:27
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.6796, accuracy_baseline = 0.5, auc = 0.7449059, auc_precision_recall = 0.73526365, average_loss = 0.596061, global_step = 1000, label/mean = 0.5, loss = 0.596665, precision = 0.67440957, prediction/mean = 0.50572324, recall = 0.69448
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmpfs/tmp/tmp0xdt5vok/model.ckpt-1000
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:40:27.220114: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T17:40:28
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp0xdt5vok/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 3.74787s
INFO:tensorflow:Finished evaluation at 2022-06-03-17:40:31
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.66836, accuracy_baseline = 0.5, auc = 0.7271519, auc_precision_recall = 0.7177882, average_loss = 0.6094725, global_step = 1000, label/mean = 0.5, loss = 0.60918844, precision = 0.66327876, prediction/mean = 0.50516915, recall = 0.68392
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmpfs/tmp/tmp0xdt5vok/model.ckpt-1000
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp1_klqxsi
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp1_klqxsi', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:40:32.007811: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp1_klqxsi/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmp1_klqxsi/model.ckpt-0.data-00000-of-00001
INFO:tensorflow:499500
INFO:tensorflow:/tmpfs/tmp/tmp1_klqxsi/model.ckpt-0.index
INFO:tensorflow:499500
INFO:tensorflow:/tmpfs/tmp/tmp1_klqxsi/model.ckpt-0.meta
INFO:tensorflow:499700
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.8407335, step = 0
INFO:tensorflow:global_step/sec: 84.2802
INFO:tensorflow:loss = 0.6509659, step = 100 (1.189 sec)
INFO:tensorflow:global_step/sec: 89.1328
INFO:tensorflow:loss = 0.6030776, step = 200 (1.122 sec)
INFO:tensorflow:global_step/sec: 84.1412
INFO:tensorflow:loss = 0.63329923, step = 300 (1.188 sec)
INFO:tensorflow:global_step/sec: 91.2338
INFO:tensorflow:loss = 0.644066, step = 400 (1.096 sec)
INFO:tensorflow:global_step/sec: 90.1763
INFO:tensorflow:loss = 0.63977265, step = 500 (1.109 sec)
INFO:tensorflow:global_step/sec: 82.6851
INFO:tensorflow:loss = 0.6360935, step = 600 (1.209 sec)
INFO:tensorflow:global_step/sec: 84.829
INFO:tensorflow:loss = 0.5727933, step = 700 (1.179 sec)
INFO:tensorflow:global_step/sec: 89.8701
INFO:tensorflow:loss = 0.6245197, step = 800 (1.113 sec)
INFO:tensorflow:global_step/sec: 90.9964
INFO:tensorflow:loss = 0.5876771, step = 900 (1.099 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1000...
INFO:tensorflow:Saving checkpoints for 1000 into /tmpfs/tmp/tmp1_klqxsi/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmp1_klqxsi/model.ckpt-1000.meta
INFO:tensorflow:200
INFO:tensorflow:/tmpfs/tmp/tmp1_klqxsi/model.ckpt-1000.index
INFO:tensorflow:200
INFO:tensorflow:/tmpfs/tmp/tmp1_klqxsi/model.ckpt-1000.data-00000-of-00001
INFO:tensorflow:499700
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1000...
INFO:tensorflow:Loss for final step: 0.62289476.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:40:47.349089: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T17:40:48
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp1_klqxsi/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 3.80422s
INFO:tensorflow:Finished evaluation at 2022-06-03-17:40:52
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.67696, accuracy_baseline = 0.5, auc = 0.7428666, auc_precision_recall = 0.734137, average_loss = 0.5975216, global_step = 1000, label/mean = 0.5, loss = 0.598098, precision = 0.6784159, prediction/mean = 0.49787614, recall = 0.67288
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmpfs/tmp/tmp1_klqxsi/model.ckpt-1000
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:40:52.328737: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T17:40:53
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp1_klqxsi/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 3.77201s
INFO:tensorflow:Finished evaluation at 2022-06-03-17:40:57
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.66792, accuracy_baseline = 0.5, auc = 0.7287145, auc_precision_recall = 0.718277, average_loss = 0.608946, global_step = 1000, label/mean = 0.5, loss = 0.60857475, precision = 0.66859436, prediction/mean = 0.49823847, recall = 0.66592
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: /tmpfs/tmp/tmp1_klqxsi/model.ckpt-1000

Let's take a look at the results.

pd.DataFrame.from_dict(results, orient="index")

We can already see some patterns, but first we should establish the baseline accuracy on the test set, i.e. the lower bound achievable by outputting only the label of the most represented class:

estimator.evaluate(input_fn=predict_test_input_fn)["accuracy_baseline"]
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
2022-06-03 17:40:57.143723: W tensorflow/core/common_runtime/graph_constructor.cc:1526] Importing a graph with a lower producer version 26 into an existing graph with producer version 1087. Shape inference will have run different parts of the graph with different producer versions.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T17:40:58
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 3.75038s
INFO:tensorflow:Finished evaluation at 2022-06-03-17:41:01
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.78612, accuracy_baseline = 0.5, auc = 0.86889553, auc_precision_recall = 0.8715557, average_loss = 0.45473337, global_step = 5000, label/mean = 0.5, loss = 0.45439148, precision = 0.80508405, prediction/mean = 0.4827163, recall = 0.75504
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmpfs/tmp/tmpudlj2nw3/model.ckpt-5000
0.5
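
The same 50% figure can be cross-checked directly from the label distribution, without running the estimator. A minimal sketch, assuming the test_df DataFrame loaded earlier (the variable name simply mirrors that cell):

# The accuracy of a constant classifier equals the share of the most
# represented class among the test labels.
majority_baseline = test_df["polarity"].value_counts(normalize=True).max()
print(majority_baseline)  # 0.5 on the balanced IMDB test split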

Assigning the most represented class gives us an accuracy of 50%. There are a few things to note here:

  1. It may be surprising, but a model can still be learned on top of fixed, random embeddings. The reason is that even if every word in the dictionary is mapped to a random vector, the estimator can separate the space using only its fully connected layers.
  2. Allowing the module with random embeddings to be trained, rather than training only the classifier, improves both training and test accuracy.
  3. Training the module with pre-trained embeddings also improves both accuracies. Note, however, the overfitting on the training set. Even with regularization, training a pre-trained module can be risky, because the embedding weights no longer represent a language model trained on diverse data; instead, they converge to an ideal representation of the new dataset. Which of these regimes applies is controlled by how the feature column is built, as sketched below.
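
A minimal sketch of the frozen versus trainable setup, assuming the same nnlm-en-dim128 module used above; the hidden_units, optimizer, and learning rate below are illustrative rather than the exact settings of the runs logged earlier:

# Frozen module: only the classifier layers on top of the embedding are trained.
frozen_text_column = hub.text_embedding_column(
    key="sentence",
    module_spec="https://tfhub.dev/google/nnlm-en-dim128/1",
    trainable=False)

# Trainable module: the embedding weights are fine-tuned together with the
# classifier, which can overfit the new dataset as discussed in point 3.
trainable_text_column = hub.text_embedding_column(
    key="sentence",
    module_spec="https://tfhub.dev/google/nnlm-en-dim128/1",
    trainable=True)

# Estimator that fine-tunes the module; swap in frozen_text_column to keep
# the embeddings fixed.
estimator_with_finetuning = tf.estimator.DNNClassifier(
    hidden_units=[500, 100],
    feature_columns=[trainable_text_column],
    n_classes=2,
    optimizer=tf.compat.v1.train.AdagradOptimizer(learning_rate=0.003))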