Classification on imbalanced data


This tutorial demonstrates how to classify a highly imbalanced dataset, in which the number of examples in one class vastly outnumbers the examples in the other. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions from 284,807 transactions in total. You will use Keras to define the model, and class weights to help the model learn from the imbalanced data.

This tutorial contains complete code to:

  • Load a CSV file using Pandas.
  • Create train, validation, and test sets.
  • Define and train a model using Keras (including setting class weights).
  • Evaluate the model using various metrics (including precision and recall).
  • Try common techniques for dealing with imbalanced data, such as:
    • Class weighting
    • Oversampling

Setup

import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

Data processing and exploration

Download the Kaggle Credit Card Fraud dataset

Pandas is a Python library with many helpful utilities for loading and working with structured data, and can be used to download CSVs into a DataFrame.

Note: This dataset was collected and analysed during a research collaboration of Worldline and the Machine Learning Group of ULB (Université Libre de Bruxelles) on big data mining and fraud detection. More details on current and past projects on related topics are available on the DefeatFraud project page.

raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()

Examine the class label imbalance

Let's look at the dataset imbalance:

neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)

This shows the small fraction of positive samples.

Clean, split and normalize the data

The raw data has a few issues. First, the Time and Amount columns are too variable to use directly. Drop the Time column (since it's not clear what it means) and take the log of the Amount column to reduce its range.

cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Amount'] = np.log(cleaned_df.pop('Amount')+eps)

Split the dataset into train, validation, and test sets. The validation set is used during model fitting to evaluate the loss and any metrics, and to judge how well the model fits the data. The test set is completely unused during training and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets, where overfitting is a significant concern due to the lack of training data.

# Use a utility from sklearn to split and shuffle your dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)

Normalize the input features using the sklearn StandardScaler. This will set the mean to 0 and the standard deviation to 1.

Note: The StandardScaler is only fit using the train_features, to make sure the model is not peeking at the validation or test sets.

scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)

Caution: If you want to deploy a model, it's critical that you preserve the preprocessing calculations. The easiest way is to implement them as layers, and attach them to your model before export.
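For example, here is a minimal sketch (not part of the original notebook) that bakes the fitted StandardScaler statistics into a Keras preprocessing layer, so that an exported model can accept raw, unscaled features. The trained model it wraps is assumed to be one like the baseline defined later in this tutorial:

# Assumed: `scaler` was fit as above and `model` is a trained Keras model.
normalizer = tf.keras.layers.Normalization(
    mean=scaler.mean_, variance=scaler.var_)

# Note: the np.clip(-5, 5) step above would also need an equivalent layer
# for the exported pipeline to match exactly.
export_model = tf.keras.Sequential([normalizer, model])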

Look at the data distribution

Next, compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:

  • Do these distributions make sense?
    • Yes. You've normalized the input and these are mostly concentrated in the +/- 2 range.
  • Can you see the difference between the distributions?
    • Yes. The positive examples contain a much higher rate of extreme values.
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)

sns.jointplot(x=pos_df['V5'], y=pos_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(x=neg_df['V5'], y=neg_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")

[Figure: hexbin joint plot of V5 vs. V6 for the positive examples ("Positive distribution")]

[Figure: hexbin joint plot of V5 vs. V6 for the negative examples ("Negative distribution")]

Define the model and metrics

Define a function that creates a simple neural network with a densely connected hidden layer, a dropout layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent:

METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
      keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]

def make_model(metrics=METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(learning_rate=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model

Understanding useful metrics

Notice that there are a few metrics defined above that can be computed by the model and that will be helpful when evaluating its performance.

  • False negatives and false positives are samples that were incorrectly classified
  • True negatives and true positives are samples that were correctly classified
  • Accuracy is the percentage of examples correctly classified

\(\frac{\text{true samples} }{\text{total samples} }\)

  • Precision is the percentage of predicted positives that were correctly classified

\(\frac{\text{true positives} }{\text{true positives + false positives} }\)

  • Recall is the percentage of actual positives that were correctly classified

\(\frac{\text{true positives} }{\text{true positives + false negatives} }\)

  • AUC refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.
  • AUPRC refers to the Area Under the Curve of the Precision-Recall curve. This metric computes precision-recall pairs for different probability thresholds.

Note: Accuracy is not a helpful metric for this task. You can achieve 99.8%+ accuracy on this task by predicting False all the time.
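A quick sanity check of that claim, using the counts computed earlier (a small addition, not in the original notebook):

# Accuracy of a classifier that always predicts "not fraud".
print('Always-negative accuracy: {:.4f}%'.format(100 * neg / total))
# Always-negative accuracy: 99.8273%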


Baseline model

Build the model

Now create and train your model using the function that was defined earlier. Notice that the model is fit using a larger than default batch size of 2048. This is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size was too small, they would likely have no fraudulent transactions to learn from.
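As a rough check (not part of the original notebook), the expected number of positives per batch follows directly from the class fraction:

# With ~0.17% positives, a batch of 2048 contains about 3.5 positives on average.
print('Expected positives per batch: {:.2f}'.format(2048 * pos / total))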

Note: This model will not handle the class imbalance well. You will improve it later in this tutorial.

EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 16)                480       
                                                                 
 dropout (Dropout)           (None, 16)                0         
                                                                 
 dense_1 (Dense)             (None, 1)                 17        
                                                                 
=================================================================
Total params: 497 (1.94 KB)
Trainable params: 497 (1.94 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

Test run the model:

model.predict(train_features[:10])
1/1 [==============================] - 0s 377ms/step
array([[0.12993479],
       [0.0831413 ],
       [0.24497728],
       [0.3017688 ],
       [0.33697775],
       [0.0949156 ],
       [0.1872398 ],
       [0.27218476],
       [0.29461417],
       [0.6366021 ]], dtype=float32)

Optional: Set the correct initial bias.

These initial guesses are not great. You know the dataset is imbalanced, so set the output layer's bias to reflect that (see: A Recipe for Training Neural Networks: "init well"). This can help with initial convergence.

With the default bias initialization the loss should be about math.log(2) = 0.69314

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.3724

The correct bias to set can be derived from:

\[ p_0 = pos/(pos + neg) = 1/(1+e^{-b_0}) \]

\[ b_0 = -\log_e(1/p_0 - 1) \]

\[ b_0 = \log_e(pos/neg) \]

initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])

Set that as the initial bias, and the model will give much more reasonable initial guesses.

It should be near: pos/total = 0.0018

model = make_model(output_bias = initial_bias)
model.predict(train_features[:10])
1/1 [==============================] - 0s 81ms/step
array([[0.00022771],
       [0.00144273],
       [0.00385517],
       [0.00052363],
       [0.00052137],
       [0.00077223],
       [0.00100834],
       [0.00055533],
       [0.00064206],
       [0.00016787]], dtype=float32)

With this initialization the initial loss should be approximately:

\[ -p_0 \log(p_0) - (1-p_0) \log(1-p_0) = 0.01317 \]

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0174

This initial loss is about 50 times smaller than it would have been with naive initialization.

This way the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. It also makes it easier to read plots of the loss during training.
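That factor of roughly 50 can be verified directly from the two theoretical loss values (a small check, not in the original notebook):

import math

p0 = pos / total
naive_loss = math.log(2)                                      # ~0.693
careful_loss = -p0*math.log(p0) - (1 - p0)*math.log(1 - p0)   # ~0.013
print('Ratio: {:.0f}x'.format(naive_loss / careful_loss))     # roughly 50x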

Checkpoint the initial weights

To make the various training runs more comparable, keep this initial model's weights in a checkpoint file, and load them into each model before training:

initial_weights = os.path.join(tempfile.mkdtemp(),'initial_weights')
model.save_weights(initial_weights)

Confirm that the bias fix helps

Before moving on, quickly confirm that the careful bias initialization actually helped.

Train the model for 20 epochs, with and without this careful initialization, and compare the losses:

model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
def plot_loss(history, label, n):
  # Use a log scale to show the wide range of values.
  plt.semilogy(history.epoch,  history.history['loss'],
               color=colors[n], label='Train '+label)
  plt.semilogy(history.epoch,  history.history['val_loss'],
          color=colors[n], label='Val '+label,
          linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')

  plt.legend()
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)

[Figure: training and validation loss for zero bias vs. careful bias initialization]

The above figure makes it clear: in terms of validation loss, this careful initialization gives a clear advantage on this problem.

Train the model

model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels))
Epoch 1/100
90/90 [==============================] - 3s 12ms/step - loss: 0.0135 - tp: 101.0000 - fp: 26.0000 - tn: 227434.0000 - fn: 284.0000 - accuracy: 0.9986 - precision: 0.7953 - recall: 0.2623 - auc: 0.7220 - prc: 0.3198 - val_loss: 0.0076 - val_tp: 3.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 66.0000 - val_accuracy: 0.9984 - val_precision: 0.3750 - val_recall: 0.0435 - val_auc: 0.8258 - val_prc: 0.4747
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0094 - tp: 98.0000 - fp: 16.0000 - tn: 181944.0000 - fn: 218.0000 - accuracy: 0.9987 - precision: 0.8596 - recall: 0.3101 - auc: 0.7451 - prc: 0.3475 - val_loss: 0.0050 - val_tp: 18.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 51.0000 - val_accuracy: 0.9987 - val_precision: 0.7500 - val_recall: 0.2609 - val_auc: 0.9127 - val_prc: 0.6222
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0079 - tp: 118.0000 - fp: 25.0000 - tn: 181935.0000 - fn: 198.0000 - accuracy: 0.9988 - precision: 0.8252 - recall: 0.3734 - auc: 0.8315 - prc: 0.4644 - val_loss: 0.0039 - val_tp: 38.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 31.0000 - val_accuracy: 0.9992 - val_precision: 0.8636 - val_recall: 0.5507 - val_auc: 0.9199 - val_prc: 0.6663
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0069 - tp: 132.0000 - fp: 25.0000 - tn: 181935.0000 - fn: 184.0000 - accuracy: 0.9989 - precision: 0.8408 - recall: 0.4177 - auc: 0.8565 - prc: 0.5263 - val_loss: 0.0036 - val_tp: 43.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 26.0000 - val_accuracy: 0.9993 - val_precision: 0.8776 - val_recall: 0.6232 - val_auc: 0.9272 - val_prc: 0.6980
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0065 - tp: 138.0000 - fp: 28.0000 - tn: 181932.0000 - fn: 178.0000 - accuracy: 0.9989 - precision: 0.8313 - recall: 0.4367 - auc: 0.8619 - prc: 0.5629 - val_loss: 0.0034 - val_tp: 44.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8800 - val_recall: 0.6377 - val_auc: 0.9345 - val_prc: 0.7071
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0061 - tp: 149.0000 - fp: 29.0000 - tn: 181931.0000 - fn: 167.0000 - accuracy: 0.9989 - precision: 0.8371 - recall: 0.4715 - auc: 0.8767 - prc: 0.5928 - val_loss: 0.0032 - val_tp: 44.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8800 - val_recall: 0.6377 - val_auc: 0.9417 - val_prc: 0.7401
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0061 - tp: 150.0000 - fp: 29.0000 - tn: 181931.0000 - fn: 166.0000 - accuracy: 0.9989 - precision: 0.8380 - recall: 0.4747 - auc: 0.8655 - prc: 0.5760 - val_loss: 0.0031 - val_tp: 44.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8800 - val_recall: 0.6377 - val_auc: 0.9418 - val_prc: 0.7652
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0061 - tp: 144.0000 - fp: 22.0000 - tn: 181938.0000 - fn: 172.0000 - accuracy: 0.9989 - precision: 0.8675 - recall: 0.4557 - auc: 0.8658 - prc: 0.5889 - val_loss: 0.0029 - val_tp: 45.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.8824 - val_recall: 0.6522 - val_auc: 0.9418 - val_prc: 0.7871
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0059 - tp: 144.0000 - fp: 26.0000 - tn: 181934.0000 - fn: 172.0000 - accuracy: 0.9989 - precision: 0.8471 - recall: 0.4557 - auc: 0.8610 - prc: 0.5845 - val_loss: 0.0028 - val_tp: 46.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.8846 - val_recall: 0.6667 - val_auc: 0.9490 - val_prc: 0.8028
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0058 - tp: 144.0000 - fp: 28.0000 - tn: 181932.0000 - fn: 172.0000 - accuracy: 0.9989 - precision: 0.8372 - recall: 0.4557 - auc: 0.8690 - prc: 0.5951 - val_loss: 0.0027 - val_tp: 45.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.8824 - val_recall: 0.6522 - val_auc: 0.9490 - val_prc: 0.8107
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0058 - tp: 150.0000 - fp: 24.0000 - tn: 181936.0000 - fn: 166.0000 - accuracy: 0.9990 - precision: 0.8621 - recall: 0.4747 - auc: 0.8626 - prc: 0.5828 - val_loss: 0.0027 - val_tp: 45.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.8824 - val_recall: 0.6522 - val_auc: 0.9491 - val_prc: 0.8156
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0053 - tp: 154.0000 - fp: 26.0000 - tn: 181934.0000 - fn: 162.0000 - accuracy: 0.9990 - precision: 0.8556 - recall: 0.4873 - auc: 0.8739 - prc: 0.6229 - val_loss: 0.0026 - val_tp: 48.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8889 - val_recall: 0.6957 - val_auc: 0.9563 - val_prc: 0.8279
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0056 - tp: 142.0000 - fp: 28.0000 - tn: 181932.0000 - fn: 174.0000 - accuracy: 0.9989 - precision: 0.8353 - recall: 0.4494 - auc: 0.8627 - prc: 0.5919 - val_loss: 0.0025 - val_tp: 49.0000 - val_fp: 7.0000 - val_tn: 45493.0000 - val_fn: 20.0000 - val_accuracy: 0.9994 - val_precision: 0.8750 - val_recall: 0.7101 - val_auc: 0.9563 - val_prc: 0.8343
Epoch 14/100
90/90 [==============================] - 0s 6ms/step - loss: 0.0054 - tp: 155.0000 - fp: 25.0000 - tn: 181935.0000 - fn: 161.0000 - accuracy: 0.9990 - precision: 0.8611 - recall: 0.4905 - auc: 0.8752 - prc: 0.6058 - val_loss: 0.0024 - val_tp: 51.0000 - val_fp: 8.0000 - val_tn: 45492.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8644 - val_recall: 0.7391 - val_auc: 0.9563 - val_prc: 0.8385
Epoch 15/100
90/90 [==============================] - 0s 6ms/step - loss: 0.0054 - tp: 160.0000 - fp: 30.0000 - tn: 181930.0000 - fn: 156.0000 - accuracy: 0.9990 - precision: 0.8421 - recall: 0.5063 - auc: 0.8722 - prc: 0.6190 - val_loss: 0.0024 - val_tp: 50.0000 - val_fp: 7.0000 - val_tn: 45493.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8772 - val_recall: 0.7246 - val_auc: 0.9563 - val_prc: 0.8455
Epoch 16/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0053 - tp: 155.0000 - fp: 29.0000 - tn: 181931.0000 - fn: 161.0000 - accuracy: 0.9990 - precision: 0.8424 - recall: 0.4905 - auc: 0.8787 - prc: 0.6159 - val_loss: 0.0023 - val_tp: 50.0000 - val_fp: 7.0000 - val_tn: 45493.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8772 - val_recall: 0.7246 - val_auc: 0.9563 - val_prc: 0.8522
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0052 - tp: 172.0000 - fp: 35.0000 - tn: 181925.0000 - fn: 144.0000 - accuracy: 0.9990 - precision: 0.8309 - recall: 0.5443 - auc: 0.8802 - prc: 0.6151 - val_loss: 0.0023 - val_tp: 47.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.8868 - val_recall: 0.6812 - val_auc: 0.9563 - val_prc: 0.8593
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0055 - tp: 138.0000 - fp: 26.0000 - tn: 181934.0000 - fn: 178.0000 - accuracy: 0.9989 - precision: 0.8415 - recall: 0.4367 - auc: 0.8657 - prc: 0.5980 - val_loss: 0.0022 - val_tp: 50.0000 - val_fp: 7.0000 - val_tn: 45493.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8772 - val_recall: 0.7246 - val_auc: 0.9563 - val_prc: 0.8602
Epoch 19/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0055 - tp: 149.0000 - fp: 26.0000 - tn: 181934.0000 - fn: 167.0000 - accuracy: 0.9989 - precision: 0.8514 - recall: 0.4715 - auc: 0.8546 - prc: 0.5791 - val_loss: 0.0022 - val_tp: 50.0000 - val_fp: 7.0000 - val_tn: 45493.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8772 - val_recall: 0.7246 - val_auc: 0.9563 - val_prc: 0.8629
Epoch 20/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0053 - tp: 147.0000 - fp: 30.0000 - tn: 181930.0000 - fn: 169.0000 - accuracy: 0.9989 - precision: 0.8305 - recall: 0.4652 - auc: 0.8752 - prc: 0.6042 - val_loss: 0.0021 - val_tp: 50.0000 - val_fp: 7.0000 - val_tn: 45493.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8772 - val_recall: 0.7246 - val_auc: 0.9563 - val_prc: 0.8668
Epoch 21/100
90/90 [==============================] - 0s 6ms/step - loss: 0.0049 - tp: 164.0000 - fp: 21.0000 - tn: 181939.0000 - fn: 152.0000 - accuracy: 0.9991 - precision: 0.8865 - recall: 0.5190 - auc: 0.8737 - prc: 0.6472 - val_loss: 0.0021 - val_tp: 52.0000 - val_fp: 7.0000 - val_tn: 45493.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.8814 - val_recall: 0.7536 - val_auc: 0.9563 - val_prc: 0.8674
Epoch 22/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 152.0000 - fp: 31.0000 - tn: 181929.0000 - fn: 164.0000 - accuracy: 0.9989 - precision: 0.8306 - recall: 0.4810 - auc: 0.8848 - prc: 0.6228 - val_loss: 0.0021 - val_tp: 53.0000 - val_fp: 8.0000 - val_tn: 45492.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8689 - val_recall: 0.7681 - val_auc: 0.9563 - val_prc: 0.8662
Epoch 23/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0051 - tp: 158.0000 - fp: 31.0000 - tn: 181929.0000 - fn: 158.0000 - accuracy: 0.9990 - precision: 0.8360 - recall: 0.5000 - auc: 0.8833 - prc: 0.6207 - val_loss: 0.0020 - val_tp: 51.0000 - val_fp: 4.0000 - val_tn: 45496.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9273 - val_recall: 0.7391 - val_auc: 0.9563 - val_prc: 0.8708
Epoch 24/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0049 - tp: 159.0000 - fp: 25.0000 - tn: 181935.0000 - fn: 157.0000 - accuracy: 0.9990 - precision: 0.8641 - recall: 0.5032 - auc: 0.8801 - prc: 0.6329 - val_loss: 0.0020 - val_tp: 52.0000 - val_fp: 4.0000 - val_tn: 45496.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9286 - val_recall: 0.7536 - val_auc: 0.9563 - val_prc: 0.8745
Epoch 25/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0049 - tp: 157.0000 - fp: 25.0000 - tn: 181935.0000 - fn: 159.0000 - accuracy: 0.9990 - precision: 0.8626 - recall: 0.4968 - auc: 0.8864 - prc: 0.6369 - val_loss: 0.0020 - val_tp: 54.0000 - val_fp: 4.0000 - val_tn: 45496.0000 - val_fn: 15.0000 - val_accuracy: 0.9996 - val_precision: 0.9310 - val_recall: 0.7826 - val_auc: 0.9563 - val_prc: 0.8758
Epoch 26/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0049 - tp: 154.0000 - fp: 21.0000 - tn: 181939.0000 - fn: 162.0000 - accuracy: 0.9990 - precision: 0.8800 - recall: 0.4873 - auc: 0.8706 - prc: 0.6314 - val_loss: 0.0020 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 14.0000 - val_accuracy: 0.9996 - val_precision: 0.9016 - val_recall: 0.7971 - val_auc: 0.9563 - val_prc: 0.8750
Epoch 27/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0051 - tp: 146.0000 - fp: 33.0000 - tn: 181927.0000 - fn: 170.0000 - accuracy: 0.9989 - precision: 0.8156 - recall: 0.4620 - auc: 0.8736 - prc: 0.6080 - val_loss: 0.0019 - val_tp: 55.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 14.0000 - val_accuracy: 0.9996 - val_precision: 0.9167 - val_recall: 0.7971 - val_auc: 0.9563 - val_prc: 0.8780
Epoch 28/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0050 - tp: 161.0000 - fp: 24.0000 - tn: 181936.0000 - fn: 155.0000 - accuracy: 0.9990 - precision: 0.8703 - recall: 0.5095 - auc: 0.8705 - prc: 0.6133 - val_loss: 0.0019 - val_tp: 55.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 14.0000 - val_accuracy: 0.9996 - val_precision: 0.9167 - val_recall: 0.7971 - val_auc: 0.9563 - val_prc: 0.8793
Epoch 29/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 171.0000 - fp: 24.0000 - tn: 181936.0000 - fn: 145.0000 - accuracy: 0.9991 - precision: 0.8769 - recall: 0.5411 - auc: 0.8912 - prc: 0.6686 - val_loss: 0.0019 - val_tp: 56.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 13.0000 - val_accuracy: 0.9996 - val_precision: 0.9180 - val_recall: 0.8116 - val_auc: 0.9563 - val_prc: 0.8791
Epoch 30/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 173.0000 - fp: 18.0000 - tn: 181942.0000 - fn: 143.0000 - accuracy: 0.9991 - precision: 0.9058 - recall: 0.5475 - auc: 0.8928 - prc: 0.6809 - val_loss: 0.0019 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9563 - val_prc: 0.8784
Epoch 31/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0047 - tp: 165.0000 - fp: 30.0000 - tn: 181930.0000 - fn: 151.0000 - accuracy: 0.9990 - precision: 0.8462 - recall: 0.5222 - auc: 0.8849 - prc: 0.6529 - val_loss: 0.0019 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9563 - val_prc: 0.8799
Epoch 32/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0048 - tp: 173.0000 - fp: 32.0000 - tn: 181928.0000 - fn: 143.0000 - accuracy: 0.9990 - precision: 0.8439 - recall: 0.5475 - auc: 0.8846 - prc: 0.6361 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9062 - val_recall: 0.8406 - val_auc: 0.9563 - val_prc: 0.8813
Epoch 33/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 166.0000 - fp: 28.0000 - tn: 181932.0000 - fn: 150.0000 - accuracy: 0.9990 - precision: 0.8557 - recall: 0.5253 - auc: 0.8799 - prc: 0.6417 - val_loss: 0.0019 - val_tp: 59.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 10.0000 - val_accuracy: 0.9996 - val_precision: 0.9077 - val_recall: 0.8551 - val_auc: 0.9563 - val_prc: 0.8787
Epoch 34/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 170.0000 - fp: 24.0000 - tn: 181936.0000 - fn: 146.0000 - accuracy: 0.9991 - precision: 0.8763 - recall: 0.5380 - auc: 0.8847 - prc: 0.6515 - val_loss: 0.0019 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9062 - val_recall: 0.8406 - val_auc: 0.9563 - val_prc: 0.8822
Epoch 35/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0048 - tp: 166.0000 - fp: 30.0000 - tn: 181930.0000 - fn: 150.0000 - accuracy: 0.9990 - precision: 0.8469 - recall: 0.5253 - auc: 0.8816 - prc: 0.6244 - val_loss: 0.0018 - val_tp: 56.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 13.0000 - val_accuracy: 0.9996 - val_precision: 0.9180 - val_recall: 0.8116 - val_auc: 0.9563 - val_prc: 0.8823
Epoch 36/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 159.0000 - fp: 27.0000 - tn: 181933.0000 - fn: 157.0000 - accuracy: 0.9990 - precision: 0.8548 - recall: 0.5032 - auc: 0.8976 - prc: 0.6680 - val_loss: 0.0018 - val_tp: 59.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 10.0000 - val_accuracy: 0.9996 - val_precision: 0.9077 - val_recall: 0.8551 - val_auc: 0.9564 - val_prc: 0.8842
Epoch 37/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 177.0000 - fp: 25.0000 - tn: 181935.0000 - fn: 139.0000 - accuracy: 0.9991 - precision: 0.8762 - recall: 0.5601 - auc: 0.8896 - prc: 0.6850 - val_loss: 0.0018 - val_tp: 59.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 10.0000 - val_accuracy: 0.9996 - val_precision: 0.9077 - val_recall: 0.8551 - val_auc: 0.9563 - val_prc: 0.8839
Epoch 38/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 168.0000 - fp: 24.0000 - tn: 181936.0000 - fn: 148.0000 - accuracy: 0.9991 - precision: 0.8750 - recall: 0.5316 - auc: 0.8832 - prc: 0.6569 - val_loss: 0.0018 - val_tp: 59.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 10.0000 - val_accuracy: 0.9996 - val_precision: 0.9077 - val_recall: 0.8551 - val_auc: 0.9563 - val_prc: 0.8821
Epoch 39/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0047 - tp: 165.0000 - fp: 29.0000 - tn: 181931.0000 - fn: 151.0000 - accuracy: 0.9990 - precision: 0.8505 - recall: 0.5222 - auc: 0.8833 - prc: 0.6458 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9564 - val_prc: 0.8834
Epoch 40/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0047 - tp: 161.0000 - fp: 27.0000 - tn: 181933.0000 - fn: 155.0000 - accuracy: 0.9990 - precision: 0.8564 - recall: 0.5095 - auc: 0.8801 - prc: 0.6468 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9564 - val_prc: 0.8841
Epoch 41/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 174.0000 - fp: 30.0000 - tn: 181930.0000 - fn: 142.0000 - accuracy: 0.9991 - precision: 0.8529 - recall: 0.5506 - auc: 0.8881 - prc: 0.6824 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9564 - val_prc: 0.8830
Epoch 42/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 161.0000 - fp: 27.0000 - tn: 181933.0000 - fn: 155.0000 - accuracy: 0.9990 - precision: 0.8564 - recall: 0.5095 - auc: 0.8707 - prc: 0.6479 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9563 - val_prc: 0.8825
Epoch 43/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0048 - tp: 172.0000 - fp: 24.0000 - tn: 181936.0000 - fn: 144.0000 - accuracy: 0.9991 - precision: 0.8776 - recall: 0.5443 - auc: 0.8753 - prc: 0.6130 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9636 - val_prc: 0.8914
Epoch 44/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0045 - tp: 173.0000 - fp: 30.0000 - tn: 181930.0000 - fn: 143.0000 - accuracy: 0.9991 - precision: 0.8522 - recall: 0.5475 - auc: 0.8833 - prc: 0.6598 - val_loss: 0.0018 - val_tp: 57.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 12.0000 - val_accuracy: 0.9996 - val_precision: 0.9194 - val_recall: 0.8261 - val_auc: 0.9636 - val_prc: 0.8937
Epoch 45/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 164.0000 - fp: 24.0000 - tn: 181936.0000 - fn: 152.0000 - accuracy: 0.9990 - precision: 0.8723 - recall: 0.5190 - auc: 0.8865 - prc: 0.6724 - val_loss: 0.0018 - val_tp: 57.0000 - val_fp: 4.0000 - val_tn: 45496.0000 - val_fn: 12.0000 - val_accuracy: 0.9996 - val_precision: 0.9344 - val_recall: 0.8261 - val_auc: 0.9636 - val_prc: 0.8939
Epoch 46/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 165.0000 - fp: 23.0000 - tn: 181937.0000 - fn: 151.0000 - accuracy: 0.9990 - precision: 0.8777 - recall: 0.5222 - auc: 0.8960 - prc: 0.6687 - val_loss: 0.0018 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 45496.0000 - val_fn: 10.0000 - val_accuracy: 0.9997 - val_precision: 0.9365 - val_recall: 0.8551 - val_auc: 0.9636 - val_prc: 0.8948
Epoch 47/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 166.0000 - fp: 25.0000 - tn: 181935.0000 - fn: 150.0000 - accuracy: 0.9990 - precision: 0.8691 - recall: 0.5253 - auc: 0.8801 - prc: 0.6463 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 4.0000 - val_tn: 45496.0000 - val_fn: 11.0000 - val_accuracy: 0.9997 - val_precision: 0.9355 - val_recall: 0.8406 - val_auc: 0.9636 - val_prc: 0.8953
Epoch 48/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 169.0000 - fp: 21.0000 - tn: 181939.0000 - fn: 147.0000 - accuracy: 0.9991 - precision: 0.8895 - recall: 0.5348 - auc: 0.8738 - prc: 0.6581 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9636 - val_prc: 0.8945
Epoch 49/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0047 - tp: 173.0000 - fp: 30.0000 - tn: 181930.0000 - fn: 143.0000 - accuracy: 0.9991 - precision: 0.8522 - recall: 0.5475 - auc: 0.8786 - prc: 0.6424 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 4.0000 - val_tn: 45496.0000 - val_fn: 11.0000 - val_accuracy: 0.9997 - val_precision: 0.9355 - val_recall: 0.8406 - val_auc: 0.9636 - val_prc: 0.8973
Epoch 50/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 165.0000 - fp: 22.0000 - tn: 181938.0000 - fn: 151.0000 - accuracy: 0.9991 - precision: 0.8824 - recall: 0.5222 - auc: 0.8914 - prc: 0.6806 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9636 - val_prc: 0.8958
Epoch 51/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 186.0000 - fp: 35.0000 - tn: 181925.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8416 - recall: 0.5886 - auc: 0.8864 - prc: 0.6577 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 3.0000 - val_tn: 45497.0000 - val_fn: 11.0000 - val_accuracy: 0.9997 - val_precision: 0.9508 - val_recall: 0.8406 - val_auc: 0.9636 - val_prc: 0.8976
Epoch 52/100
90/90 [==============================] - 0s 6ms/step - loss: 0.0048 - tp: 155.0000 - fp: 27.0000 - tn: 181933.0000 - fn: 161.0000 - accuracy: 0.9990 - precision: 0.8516 - recall: 0.4905 - auc: 0.8769 - prc: 0.6301 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 4.0000 - val_tn: 45496.0000 - val_fn: 11.0000 - val_accuracy: 0.9997 - val_precision: 0.9355 - val_recall: 0.8406 - val_auc: 0.9636 - val_prc: 0.8972
Epoch 53/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0043 - tp: 178.0000 - fp: 26.0000 - tn: 181934.0000 - fn: 138.0000 - accuracy: 0.9991 - precision: 0.8725 - recall: 0.5633 - auc: 0.8818 - prc: 0.6617 - val_loss: 0.0017 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9636 - val_prc: 0.8960
Epoch 54/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 167.0000 - fp: 21.0000 - tn: 181939.0000 - fn: 149.0000 - accuracy: 0.9991 - precision: 0.8883 - recall: 0.5285 - auc: 0.8897 - prc: 0.6597 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9636 - val_prc: 0.8952
Epoch 55/100
90/90 [==============================] - 1s 6ms/step - loss: 0.0046 - tp: 159.0000 - fp: 26.0000 - tn: 181934.0000 - fn: 157.0000 - accuracy: 0.9990 - precision: 0.8595 - recall: 0.5032 - auc: 0.8754 - prc: 0.6306 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9062 - val_recall: 0.8406 - val_auc: 0.9636 - val_prc: 0.8949
Epoch 56/100
90/90 [==============================] - ETA: 0s - loss: 0.0044 - tp: 182.0000 - fp: 27.0000 - tn: 181933.0000 - fn: 134.0000 - accuracy: 0.9991 - precision: 0.8708 - recall: 0.5759 - auc: 0.8976 - prc: 0.6778Restoring model weights from the end of the best epoch: 46.
90/90 [==============================] - 1s 6ms/step - loss: 0.0044 - tp: 182.0000 - fp: 27.0000 - tn: 181933.0000 - fn: 134.0000 - accuracy: 0.9991 - precision: 0.8708 - recall: 0.5759 - auc: 0.8976 - prc: 0.6778 - val_loss: 0.0018 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45495.0000 - val_fn: 11.0000 - val_accuracy: 0.9996 - val_precision: 0.9206 - val_recall: 0.8406 - val_auc: 0.9636 - val_prc: 0.8943
Epoch 56: early stopping

Check training history

In this section, you will produce plots of your model's accuracy and loss on the training and validation sets. These are useful to check for overfitting, which you can learn more about in the Overfit and underfit tutorial.

Additionally, you can produce these plots for any of the metrics you created above. False negatives are included as an example.

def plot_metrics(history):
  metrics = ['loss', 'prc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()
plot_metrics(baseline_history)

[Figure: loss, AUPRC, precision, and recall for the baseline model on the training and validation sets]

Note: The validation curves generally perform better than the training curves. This is mainly caused by the fact that the dropout layer is not active when evaluating the model.
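You can observe this effect directly by calling the model with the training flag set explicitly (an illustrative check, not in the original notebook). With training=True the dropout mask is applied; with training=False it is not:

# Two forward passes over the same example; only the first applies dropout.
p_train = model(train_features[:1], training=True)
p_eval = model(train_features[:1], training=False)
print(p_train.numpy(), p_eval.numpy())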

Evaluate metrics

You can use a confusion matrix to summarize the actual vs. predicted labels, where the X axis is the predicted label and the Y axis is the actual label:

train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step
28/28 [==============================] - 0s 2ms/step
def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))

Evaluate your model on the test dataset and display the results for the metrics you created above:

baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
loss :  0.0030497945845127106
tp :  84.0
fp :  12.0
tn :  56843.0
fn :  23.0
accuracy :  0.9993855357170105
precision :  0.875
recall :  0.7850467562675476
auc :  0.9389954805374146
prc :  0.8279591798782349

Legitimate Transactions Detected (True Negatives):  56843
Legitimate Transactions Incorrectly Detected (False Positives):  12
Fraudulent Transactions Missed (False Negatives):  23
Fraudulent Transactions Detected (True Positives):  84
Total Fraudulent Transactions:  107

[Figure: confusion matrix for the baseline model at threshold 0.5]

If the model had predicted everything perfectly, this would be a diagonal matrix where values off the main diagonal, indicating incorrect predictions, would be zero. In this case, the matrix shows that you have relatively few false positives, meaning that relatively few legitimate transactions were incorrectly flagged. However, you would likely want to have even fewer false negatives despite the cost of increasing the number of false positives. This trade-off may be preferable because false negatives would allow fraudulent transactions to go through, whereas false positives may cause an email to be sent to a customer asking them to verify their card activity.
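Since plot_cm takes the threshold as a parameter, you can explore that trade-off directly. For example, with a lower threshold (0.1 here is a hypothetical value, not from the original notebook) the model catches more fraud at the cost of more false alarms:

# Re-plot the confusion matrix with a lower decision threshold: fewer false
# negatives, more false positives.
plot_cm(test_labels, test_predictions_baseline, p=0.1)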

Plot the ROC

Now plot the ROC. This plot is useful because it shows, at a glance, the range of performance the model can reach just by tuning the output threshold.

def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right');

[Figure: ROC curves for the baseline model on the train and test sets]

Plot the AUPRC

Now plot the AUPRC: the area under the interpolated precision-recall curve, obtained by plotting (recall, precision) points for different values of the classification threshold. Depending on how it's calculated, PR AUC may be equivalent to the average precision of the model.

def plot_prc(name, labels, predictions, **kwargs):
    precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, predictions)

    plt.plot(precision, recall, label=name, linewidth=2, **kwargs)
    plt.xlabel('Precision')
    plt.ylabel('Recall')
    plt.grid(True)
    ax = plt.gca()
    ax.set_aspect('equal')
plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right');

[Figure: precision-recall curves for the baseline model on the train and test sets]
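For comparison, scikit-learn exposes average precision directly; this is a quick cross-check, not part of the original notebook:

from sklearn.metrics import average_precision_score

# Average precision summarizes the PR curve as a weighted mean of precisions.
print(average_precision_score(test_labels, test_predictions_baseline))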

It looks like the precision is relatively high, but the recall and the area under the ROC curve (AUC) aren't as high as you might like. Classifiers often face challenges when trying to maximize both precision and recall, which is especially true when working with imbalanced datasets. It is important to consider the costs of different types of errors in the context of the problem you care about. In this example, a false negative (a fraudulent transaction is missed) may have a financial cost, while a false positive (a transaction is incorrectly flagged as fraudulent) may decrease user happiness.

Class weights

Calculate class weights

The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want the classifier to heavily weight the few examples that are available. You can do this by passing Keras weights for each class through a parameter. These will cause the model to "pay more attention" to examples from an under-represented class.

# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50
Weight for class 1: 289.44
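As a cross-check (not part of the original notebook), scikit-learn's 'balanced' heuristic computes the same total / (2 * count) weights from the labels; applied to the training labels, it yields values very close to the ones above:

from sklearn.utils.class_weight import compute_class_weight

# 'balanced' weights are n_samples / (n_classes * n_per_class).
weights = compute_class_weight(
    class_weight='balanced', classes=np.array([0, 1]), y=train_labels)
print(dict(zip([0, 1], weights)))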

Train a model with class weights

Now try re-training and evaluating the model with class weights to see how that affects the predictions.

Note: Using class_weights changes the range of the loss. This may affect the stability of the training, depending on the optimizer. Optimizers whose step size is dependent on the magnitude of the gradient, like optimizers.SGD, may fail. The optimizer used here, optimizers.Adam, is unaffected by the scaling change. Also note that because of the weighting, the total losses are not comparable between the two models.
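Conceptually, class_weight just scales each example's loss term by the weight of its label. A rough NumPy sketch of that effect (this mirrors what class weighting does; it is not Keras's internal implementation):

def weighted_bce(y_true, y_pred, w0=weight_for_0, w1=weight_for_1):
  # Each example's cross-entropy term is scaled by its class weight.
  w = np.where(y_true == 1, w1, w0)
  ce = -(y_true*np.log(y_pred) + (1 - y_true)*np.log(1 - y_pred))
  return np.mean(w * ce)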

weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight)
Epoch 1/100
90/90 [==============================] - 3s 12ms/step - loss: 3.4540 - tp: 127.0000 - fp: 28.0000 - tn: 238787.0000 - fn: 296.0000 - accuracy: 0.9986 - precision: 0.8194 - recall: 0.3002 - auc: 0.7327 - prc: 0.3224 - val_loss: 0.0070 - val_tp: 9.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 60.0000 - val_accuracy: 0.9986 - val_precision: 0.6000 - val_recall: 0.1304 - val_auc: 0.8538 - val_prc: 0.4746
Epoch 2/100
90/90 [==============================] - 1s 6ms/step - loss: 1.8967 - tp: 106.0000 - fp: 48.0000 - tn: 181912.0000 - fn: 210.0000 - accuracy: 0.9986 - precision: 0.6883 - recall: 0.3354 - auc: 0.7954 - prc: 0.3507 - val_loss: 0.0049 - val_tp: 48.0000 - val_fp: 6.0000 - val_tn: 45494.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8889 - val_recall: 0.6957 - val_auc: 0.9396 - val_prc: 0.6825
Epoch 3/100
90/90 [==============================] - 1s 6ms/step - loss: 1.2148 - tp: 165.0000 - fp: 131.0000 - tn: 181829.0000 - fn: 151.0000 - accuracy: 0.9985 - precision: 0.5574 - recall: 0.5222 - auc: 0.8611 - prc: 0.4673 - val_loss: 0.0056 - val_tp: 53.0000 - val_fp: 11.0000 - val_tn: 45489.0000 - val_fn: 16.0000 - val_accuracy: 0.9994 - val_precision: 0.8281 - val_recall: 0.7681 - val_auc: 0.9592 - val_prc: 0.7340
Epoch 4/100
90/90 [==============================] - 1s 6ms/step - loss: 0.9954 - tp: 184.0000 - fp: 354.0000 - tn: 181606.0000 - fn: 132.0000 - accuracy: 0.9973 - precision: 0.3420 - recall: 0.5823 - auc: 0.8707 - prc: 0.4738 - val_loss: 0.0071 - val_tp: 58.0000 - val_fp: 15.0000 - val_tn: 45485.0000 - val_fn: 11.0000 - val_accuracy: 0.9994 - val_precision: 0.7945 - val_recall: 0.8406 - val_auc: 0.9857 - val_prc: 0.7529
Epoch 5/100
90/90 [==============================] - 1s 6ms/step - loss: 0.8191 - tp: 212.0000 - fp: 650.0000 - tn: 181310.0000 - fn: 104.0000 - accuracy: 0.9959 - precision: 0.2459 - recall: 0.6709 - auc: 0.8850 - prc: 0.4946 - val_loss: 0.0097 - val_tp: 59.0000 - val_fp: 29.0000 - val_tn: 45471.0000 - val_fn: 10.0000 - val_accuracy: 0.9991 - val_precision: 0.6705 - val_recall: 0.8551 - val_auc: 0.9918 - val_prc: 0.7607
Epoch 6/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6886 - tp: 225.0000 - fp: 1145.0000 - tn: 180815.0000 - fn: 91.0000 - accuracy: 0.9932 - precision: 0.1642 - recall: 0.7120 - auc: 0.8986 - prc: 0.4293 - val_loss: 0.0142 - val_tp: 59.0000 - val_fp: 89.0000 - val_tn: 45411.0000 - val_fn: 10.0000 - val_accuracy: 0.9978 - val_precision: 0.3986 - val_recall: 0.8551 - val_auc: 0.9930 - val_prc: 0.7548
Epoch 7/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6061 - tp: 240.0000 - fp: 1686.0000 - tn: 180274.0000 - fn: 76.0000 - accuracy: 0.9903 - precision: 0.1246 - recall: 0.7595 - auc: 0.9086 - prc: 0.3771 - val_loss: 0.0198 - val_tp: 60.0000 - val_fp: 182.0000 - val_tn: 45318.0000 - val_fn: 9.0000 - val_accuracy: 0.9958 - val_precision: 0.2479 - val_recall: 0.8696 - val_auc: 0.9946 - val_prc: 0.7211
Epoch 8/100
90/90 [==============================] - 1s 6ms/step - loss: 0.6127 - tp: 233.0000 - fp: 2287.0000 - tn: 179673.0000 - fn: 83.0000 - accuracy: 0.9870 - precision: 0.0925 - recall: 0.7373 - auc: 0.9071 - prc: 0.3235 - val_loss: 0.0276 - val_tp: 64.0000 - val_fp: 340.0000 - val_tn: 45160.0000 - val_fn: 5.0000 - val_accuracy: 0.9924 - val_precision: 0.1584 - val_recall: 0.9275 - val_auc: 0.9953 - val_prc: 0.6965
Epoch 9/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4657 - tp: 244.0000 - fp: 3136.0000 - tn: 178824.0000 - fn: 72.0000 - accuracy: 0.9824 - precision: 0.0722 - recall: 0.7722 - auc: 0.9334 - prc: 0.2770 - val_loss: 0.0385 - val_tp: 64.0000 - val_fp: 514.0000 - val_tn: 44986.0000 - val_fn: 5.0000 - val_accuracy: 0.9886 - val_precision: 0.1107 - val_recall: 0.9275 - val_auc: 0.9953 - val_prc: 0.6873
Epoch 10/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4505 - tp: 254.0000 - fp: 4030.0000 - tn: 177930.0000 - fn: 62.0000 - accuracy: 0.9776 - precision: 0.0593 - recall: 0.8038 - auc: 0.9320 - prc: 0.2187 - val_loss: 0.0484 - val_tp: 64.0000 - val_fp: 642.0000 - val_tn: 44858.0000 - val_fn: 5.0000 - val_accuracy: 0.9858 - val_precision: 0.0907 - val_recall: 0.9275 - val_auc: 0.9953 - val_prc: 0.6514
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4498 - tp: 252.0000 - fp: 4605.0000 - tn: 177355.0000 - fn: 64.0000 - accuracy: 0.9744 - precision: 0.0519 - recall: 0.7975 - auc: 0.9242 - prc: 0.2122 - val_loss: 0.0548 - val_tp: 64.0000 - val_fp: 694.0000 - val_tn: 44806.0000 - val_fn: 5.0000 - val_accuracy: 0.9847 - val_precision: 0.0844 - val_recall: 0.9275 - val_auc: 0.9953 - val_prc: 0.6065
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.4053 - tp: 260.0000 - fp: 5182.0000 - tn: 176778.0000 - fn: 56.0000 - accuracy: 0.9713 - precision: 0.0478 - recall: 0.8228 - auc: 0.9284 - prc: 0.1952 - val_loss: 0.0608 - val_tp: 64.0000 - val_fp: 767.0000 - val_tn: 44733.0000 - val_fn: 5.0000 - val_accuracy: 0.9831 - val_precision: 0.0770 - val_recall: 0.9275 - val_auc: 0.9953 - val_prc: 0.6025
Epoch 13/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3426 - tp: 263.0000 - fp: 5964.0000 - tn: 175996.0000 - fn: 53.0000 - accuracy: 0.9670 - precision: 0.0422 - recall: 0.8323 - auc: 0.9503 - prc: 0.1911 - val_loss: 0.0683 - val_tp: 64.0000 - val_fp: 849.0000 - val_tn: 44651.0000 - val_fn: 5.0000 - val_accuracy: 0.9813 - val_precision: 0.0701 - val_recall: 0.9275 - val_auc: 0.9950 - val_prc: 0.5961
Epoch 14/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3627 - tp: 268.0000 - fp: 6356.0000 - tn: 175604.0000 - fn: 48.0000 - accuracy: 0.9649 - precision: 0.0405 - recall: 0.8481 - auc: 0.9365 - prc: 0.1903 - val_loss: 0.0714 - val_tp: 64.0000 - val_fp: 880.0000 - val_tn: 44620.0000 - val_fn: 5.0000 - val_accuracy: 0.9806 - val_precision: 0.0678 - val_recall: 0.9275 - val_auc: 0.9951 - val_prc: 0.5914
Epoch 15/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3306 - tp: 271.0000 - fp: 6756.0000 - tn: 175204.0000 - fn: 45.0000 - accuracy: 0.9627 - precision: 0.0386 - recall: 0.8576 - auc: 0.9458 - prc: 0.2015 - val_loss: 0.0741 - val_tp: 64.0000 - val_fp: 914.0000 - val_tn: 44586.0000 - val_fn: 5.0000 - val_accuracy: 0.9798 - val_precision: 0.0654 - val_recall: 0.9275 - val_auc: 0.9948 - val_prc: 0.5921
Epoch 16/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3222 - tp: 272.0000 - fp: 6837.0000 - tn: 175123.0000 - fn: 44.0000 - accuracy: 0.9622 - precision: 0.0383 - recall: 0.8608 - auc: 0.9446 - prc: 0.1837 - val_loss: 0.0752 - val_tp: 65.0000 - val_fp: 923.0000 - val_tn: 44577.0000 - val_fn: 4.0000 - val_accuracy: 0.9797 - val_precision: 0.0658 - val_recall: 0.9420 - val_auc: 0.9947 - val_prc: 0.5894
Epoch 17/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3474 - tp: 267.0000 - fp: 7262.0000 - tn: 174698.0000 - fn: 49.0000 - accuracy: 0.9599 - precision: 0.0355 - recall: 0.8449 - auc: 0.9397 - prc: 0.1785 - val_loss: 0.0808 - val_tp: 65.0000 - val_fp: 979.0000 - val_tn: 44521.0000 - val_fn: 4.0000 - val_accuracy: 0.9784 - val_precision: 0.0623 - val_recall: 0.9420 - val_auc: 0.9947 - val_prc: 0.5926
Epoch 18/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3525 - tp: 267.0000 - fp: 7350.0000 - tn: 174610.0000 - fn: 49.0000 - accuracy: 0.9594 - precision: 0.0351 - recall: 0.8449 - auc: 0.9373 - prc: 0.1762 - val_loss: 0.0806 - val_tp: 65.0000 - val_fp: 977.0000 - val_tn: 44523.0000 - val_fn: 4.0000 - val_accuracy: 0.9785 - val_precision: 0.0624 - val_recall: 0.9420 - val_auc: 0.9947 - val_prc: 0.5927
Epoch 19/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3757 - tp: 269.0000 - fp: 7315.0000 - tn: 174645.0000 - fn: 47.0000 - accuracy: 0.9596 - precision: 0.0355 - recall: 0.8513 - auc: 0.9277 - prc: 0.1830 - val_loss: 0.0779 - val_tp: 65.0000 - val_fp: 949.0000 - val_tn: 44551.0000 - val_fn: 4.0000 - val_accuracy: 0.9791 - val_precision: 0.0641 - val_recall: 0.9420 - val_auc: 0.9948 - val_prc: 0.5936
Epoch 20/100
90/90 [==============================] - 1s 6ms/step - loss: 0.3146 - tp: 271.0000 - fp: 7173.0000 - tn: 174787.0000 - fn: 45.0000 - accuracy: 0.9604 - precision: 0.0364 - recall: 0.8576 - auc: 0.9483 - prc: 0.1861 - val_loss: 0.0790 - val_tp: 65.0000 - val_fp: 944.0000 - val_tn: 44556.0000 - val_fn: 4.0000 - val_accuracy: 0.9792 - val_precision: 0.0644 - val_recall: 0.9420 - val_auc: 0.9948 - val_prc: 0.5963
Epoch 21/100
87/90 [============================>.] - ETA: 0s - loss: 0.3226 - tp: 268.0000 - fp: 7036.0000 - tn: 170831.0000 - fn: 41.0000 - accuracy: 0.9603 - precision: 0.0367 - recall: 0.8673 - auc: 0.9444 - prc: 0.1947Restoring model weights from the end of the best epoch: 11.
90/90 [==============================] - 1s 6ms/step - loss: 0.3172 - tp: 275.0000 - fp: 7203.0000 - tn: 174757.0000 - fn: 41.0000 - accuracy: 0.9603 - precision: 0.0368 - recall: 0.8703 - auc: 0.9455 - prc: 0.1949 - val_loss: 0.0773 - val_tp: 65.0000 - val_fp: 908.0000 - val_tn: 44592.0000 - val_fn: 4.0000 - val_accuracy: 0.9800 - val_precision: 0.0668 - val_recall: 0.9420 - val_auc: 0.9950 - val_prc: 0.6021
Epoch 21: early stopping

Check training history

plot_metrics(weighted_history)

[Figure: training metrics for the class-weighted model]

Evaluate metrics

train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step
28/28 [==============================] - 0s 2ms/step
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss :  0.054440323263406754
tp :  94.0
fp :  832.0
tn :  56023.0
fn :  13.0
accuracy :  0.985165536403656
precision :  0.1015118807554245
recall :  0.8785046935081482
auc :  0.9787468910217285
prc :  0.6335744857788086

Legitimate Transactions Detected (True Negatives):  56023
Legitimate Transactions Incorrectly Detected (False Positives):  832
Fraudulent Transactions Missed (False Negatives):  13
Fraudulent Transactions Detected (True Positives):  94
Total Fraudulent Transactions:  107

[Figure: confusion matrix for the class-weighted model at threshold 0.5]

Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also found more true positives. Despite having lower accuracy, this model has higher recall (and identifies more fraudulent transactions). Of course, there is a cost to both types of error (you wouldn't want to bug users by flagging too many legitimate transactions as fraudulent, either). Carefully consider the trade-offs between these different types of errors for your application.

Plot the ROC

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right');

[Figure: ROC curves for the baseline and class-weighted models]

Plot the AUPRC

plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right');

[Figure: precision-recall curves for the baseline and class-weighted models]

Oversampling

Oversample the minority class

A related approach would be to resample the dataset by oversampling the minority class.

pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]

Using NumPy

You can balance the dataset manually by choosing the right number of random indices from the positive examples:

ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
(181960, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
(363920, 29)

Using tf.data

If you're using tf.data, the easiest way to produce balanced examples is to start with a positive and a negative dataset, and merge them. See the tf.data guide for more examples.

BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

Each dataset provides (feature, label) pairs:

for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
Features:
 [ 0.9559428   0.80097054 -1.53757364  2.98558412  0.98626604 -0.64296745
  0.58623809 -0.1389485  -1.36773827  0.23749502  1.42062045 -0.28331678
 -0.7594046  -2.88674106 -1.86392462  1.81905823  2.28344955  1.42037765
 -1.88417858 -0.29778025 -0.2085476  -0.49742697  0.06843843 -0.3974017
  0.32761487 -0.00762779 -0.06715681  0.01644712 -0.48884179]

Label:  1

Merge the two together using tf.data.Dataset.sample_from_datasets:

resampled_ds = tf.data.Dataset.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
0.49072265625

To use this dataset, you'll need the number of steps per epoch.

The definition of "epoch" in this case is less clear. Say it's the number of batches required to see each negative example once:

resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0

Train on the oversampled data

Now try training the model with the resampled dataset instead of using class weights, to see how these methods compare.

Note: Because the data was balanced by replicating the positive examples, the total dataset size is larger, and each epoch runs for more training steps.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks = [early_stopping],
    validation_data=val_ds)
Epoch 1/100
278/278 [==============================] - 9s 26ms/step - loss: 0.6244 - tp: 196970.0000 - fp: 41577.0000 - tn: 299585.0000 - fn: 88174.0000 - accuracy: 0.7928 - precision: 0.8257 - recall: 0.6908 - auc: 0.8304 - prc: 0.8662 - val_loss: 0.1924 - val_tp: 63.0000 - val_fp: 819.0000 - val_tn: 44681.0000 - val_fn: 6.0000 - val_accuracy: 0.9819 - val_precision: 0.0714 - val_recall: 0.9130 - val_auc: 0.9944 - val_prc: 0.7565
Epoch 2/100
278/278 [==============================] - 7s 24ms/step - loss: 0.2485 - tp: 223798.0000 - fp: 14518.0000 - tn: 270446.0000 - fn: 60582.0000 - accuracy: 0.8681 - precision: 0.9391 - recall: 0.7870 - auc: 0.9596 - prc: 0.9661 - val_loss: 0.1043 - val_tp: 66.0000 - val_fp: 746.0000 - val_tn: 44754.0000 - val_fn: 3.0000 - val_accuracy: 0.9836 - val_precision: 0.0813 - val_recall: 0.9565 - val_auc: 0.9961 - val_prc: 0.7801
Epoch 3/100
278/278 [==============================] - 7s 23ms/step - loss: 0.1958 - tp: 251598.0000 - fp: 13396.0000 - tn: 271898.0000 - fn: 32452.0000 - accuracy: 0.9195 - precision: 0.9494 - recall: 0.8858 - auc: 0.9782 - prc: 0.9796 - val_loss: 0.0813 - val_tp: 66.0000 - val_fp: 799.0000 - val_tn: 44701.0000 - val_fn: 3.0000 - val_accuracy: 0.9824 - val_precision: 0.0763 - val_recall: 0.9565 - val_auc: 0.9957 - val_prc: 0.7847
Epoch 4/100
278/278 [==============================] - 6s 23ms/step - loss: 0.1683 - tp: 263076.0000 - fp: 13950.0000 - tn: 270190.0000 - fn: 22128.0000 - accuracy: 0.9366 - precision: 0.9496 - recall: 0.9224 - auc: 0.9847 - prc: 0.9851 - val_loss: 0.0677 - val_tp: 66.0000 - val_fp: 760.0000 - val_tn: 44740.0000 - val_fn: 3.0000 - val_accuracy: 0.9833 - val_precision: 0.0799 - val_recall: 0.9565 - val_auc: 0.9954 - val_prc: 0.7691
Epoch 5/100
278/278 [==============================] - 6s 23ms/step - loss: 0.1513 - tp: 267889.0000 - fp: 14968.0000 - tn: 269576.0000 - fn: 16911.0000 - accuracy: 0.9440 - precision: 0.9471 - recall: 0.9406 - auc: 0.9878 - prc: 0.9878 - val_loss: 0.0582 - val_tp: 66.0000 - val_fp: 747.0000 - val_tn: 44753.0000 - val_fn: 3.0000 - val_accuracy: 0.9835 - val_precision: 0.0812 - val_recall: 0.9565 - val_auc: 0.9957 - val_prc: 0.7716
Epoch 6/100
278/278 [==============================] - 6s 23ms/step - loss: 0.1383 - tp: 272263.0000 - fp: 15458.0000 - tn: 269001.0000 - fn: 12622.0000 - accuracy: 0.9507 - precision: 0.9463 - recall: 0.9557 - auc: 0.9897 - prc: 0.9895 - val_loss: 0.0514 - val_tp: 66.0000 - val_fp: 716.0000 - val_tn: 44784.0000 - val_fn: 3.0000 - val_accuracy: 0.9842 - val_precision: 0.0844 - val_recall: 0.9565 - val_auc: 0.9961 - val_prc: 0.7749
Epoch 7/100
278/278 [==============================] - 6s 23ms/step - loss: 0.1279 - tp: 275144.0000 - fp: 15791.0000 - tn: 268338.0000 - fn: 10071.0000 - accuracy: 0.9546 - precision: 0.9457 - recall: 0.9647 - auc: 0.9912 - prc: 0.9909 - val_loss: 0.0474 - val_tp: 65.0000 - val_fp: 720.0000 - val_tn: 44780.0000 - val_fn: 4.0000 - val_accuracy: 0.9841 - val_precision: 0.0828 - val_recall: 0.9420 - val_auc: 0.9955 - val_prc: 0.7776
Epoch 8/100
278/278 [==============================] - 6s 23ms/step - loss: 0.1200 - tp: 276934.0000 - fp: 16125.0000 - tn: 267921.0000 - fn: 8364.0000 - accuracy: 0.9570 - precision: 0.9450 - recall: 0.9707 - auc: 0.9921 - prc: 0.9918 - val_loss: 0.0443 - val_tp: 65.0000 - val_fp: 716.0000 - val_tn: 44784.0000 - val_fn: 4.0000 - val_accuracy: 0.9842 - val_precision: 0.0832 - val_recall: 0.9420 - val_auc: 0.9948 - val_prc: 0.7683
Epoch 9/100
278/278 [==============================] - 7s 24ms/step - loss: 0.1126 - tp: 278043.0000 - fp: 16472.0000 - tn: 267971.0000 - fn: 6858.0000 - accuracy: 0.9590 - precision: 0.9441 - recall: 0.9759 - auc: 0.9928 - prc: 0.9925 - val_loss: 0.0407 - val_tp: 65.0000 - val_fp: 664.0000 - val_tn: 44836.0000 - val_fn: 4.0000 - val_accuracy: 0.9853 - val_precision: 0.0892 - val_recall: 0.9420 - val_auc: 0.9941 - val_prc: 0.7696
Epoch 10/100
278/278 [==============================] - 7s 24ms/step - loss: 0.1061 - tp: 278163.0000 - fp: 16128.0000 - tn: 269479.0000 - fn: 5574.0000 - accuracy: 0.9619 - precision: 0.9452 - recall: 0.9804 - auc: 0.9936 - prc: 0.9931 - val_loss: 0.0370 - val_tp: 65.0000 - val_fp: 620.0000 - val_tn: 44880.0000 - val_fn: 4.0000 - val_accuracy: 0.9863 - val_precision: 0.0949 - val_recall: 0.9420 - val_auc: 0.9935 - val_prc: 0.7715
Epoch 11/100
278/278 [==============================] - 6s 23ms/step - loss: 0.0997 - tp: 278905.0000 - fp: 15891.0000 - tn: 269709.0000 - fn: 4839.0000 - accuracy: 0.9636 - precision: 0.9461 - recall: 0.9829 - auc: 0.9942 - prc: 0.9937 - val_loss: 0.0342 - val_tp: 65.0000 - val_fp: 585.0000 - val_tn: 44915.0000 - val_fn: 4.0000 - val_accuracy: 0.9871 - val_precision: 0.1000 - val_recall: 0.9420 - val_auc: 0.9931 - val_prc: 0.7760
Epoch 12/100
276/278 [============================>.] - ETA: 0s - loss: 0.0944 - tp: 278484.0000 - fp: 15374.0000 - tn: 267452.0000 - fn: 3938.0000 - accuracy: 0.9658 - precision: 0.9477 - recall: 0.9861 - auc: 0.9946 - prc: 0.9942Restoring model weights from the end of the best epoch: 2.
278/278 [==============================] - 6s 23ms/step - loss: 0.0944 - tp: 280510.0000 - fp: 15490.0000 - tn: 269381.0000 - fn: 3963.0000 - accuracy: 0.9658 - precision: 0.9477 - recall: 0.9861 - auc: 0.9946 - prc: 0.9942 - val_loss: 0.0319 - val_tp: 65.0000 - val_fp: 576.0000 - val_tn: 44924.0000 - val_fn: 4.0000 - val_accuracy: 0.9873 - val_precision: 0.1014 - val_recall: 0.9420 - val_auc: 0.9869 - val_prc: 0.7762
Epoch 12: early stopping

If the training process were considering the whole dataset on each gradient update, this oversampling would be basically identical to class weighting.

But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: instead of each positive example being shown in one batch with a large weight, they're shown in many different batches, each time with a small weight.

This smoother gradient signal makes it easier to train the model.
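A minimal sketch of the full-dataset equivalence, using toy labels, logits, and an illustrative 3x positive weight (none of these values come from the tutorial's model):

# Toy demonstration: over one full pass, weighting the positive example 3x
# yields the same total loss (and hence the same gradient) as repeating that
# example 3 times with uniform weights.
labels = tf.constant([0., 0., 0., 1.])    # one positive among four examples
logits = tf.constant([-2., -1., -3., 0.5])

per_example = tf.keras.losses.binary_crossentropy(
    labels[:, None], logits[:, None], from_logits=True)

# Class weighting: a single pass with positives scaled by 3.
weights = tf.where(labels == 1., 3., 1.)
weighted_loss = tf.reduce_sum(per_example * weights)

# Oversampling: the positive (index 3) repeated 3 times, uniform weights.
oversampled_loss = tf.reduce_sum(tf.gather(per_example, [0, 1, 2, 3, 3, 3]))

print(float(weighted_loss), float(oversampled_loss))  # identical totals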

Check training history

Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data.

plot_metrics(resampled_history)

png

Re-train

Because training is easier on the balanced data, the above training procedure may overfit quickly.

So, break up the epochs to give callbacks.EarlyStopping finer control over when to stop training: with steps_per_epoch = 20, the validation metrics that drive early stopping are checked every 20 batches instead of every 278.
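For reference, the early_stopping callback reused here was defined earlier in the tutorial. A sketch of a configuration consistent with the logs (the patience of 10 epochs matches both runs; the exact monitored validation metric is an assumption) would be:

# Assumed EarlyStopping setup: watch a validation metric, wait 10 "epochs"
# for an improvement, then stop and restore the best weights, producing the
# "Restoring model weights from the end of the best epoch" messages above.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_prc',
    mode='max',
    patience=10,
    verbose=1,
    restore_best_weights=True)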

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch = 20,
    epochs=10*EPOCHS,
    callbacks=[early_stopping],
    validation_data=val_ds)
Epoch 1/1000
20/20 [==============================] - 3s 55ms/step - loss: 1.8818 - tp: 7147.0000 - fp: 3701.0000 - tn: 62207.0000 - fn: 13474.0000 - accuracy: 0.8015 - precision: 0.6588 - recall: 0.3466 - auc: 0.7204 - prc: 0.5663 - val_loss: 0.3562 - val_tp: 18.0000 - val_fp: 3378.0000 - val_tn: 42122.0000 - val_fn: 51.0000 - val_accuracy: 0.9248 - val_precision: 0.0053 - val_recall: 0.2609 - val_auc: 0.5310 - val_prc: 0.0789
Epoch 2/1000
20/20 [==============================] - 1s 30ms/step - loss: 1.2506 - tp: 9432.0000 - fp: 3742.0000 - tn: 16690.0000 - fn: 11096.0000 - accuracy: 0.6377 - precision: 0.7160 - recall: 0.4595 - auc: 0.5656 - prc: 0.7088 - val_loss: 0.3813 - val_tp: 46.0000 - val_fp: 4274.0000 - val_tn: 41226.0000 - val_fn: 23.0000 - val_accuracy: 0.9057 - val_precision: 0.0106 - val_recall: 0.6667 - val_auc: 0.7818 - val_prc: 0.3564
Epoch 3/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.9031 - tp: 11130.0000 - fp: 4216.0000 - tn: 16301.0000 - fn: 9313.0000 - accuracy: 0.6697 - precision: 0.7253 - recall: 0.5444 - auc: 0.6484 - prc: 0.7620 - val_loss: 0.3923 - val_tp: 57.0000 - val_fp: 4628.0000 - val_tn: 40872.0000 - val_fn: 12.0000 - val_accuracy: 0.8982 - val_precision: 0.0122 - val_recall: 0.8261 - val_auc: 0.9040 - val_prc: 0.5297
Epoch 4/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.6819 - tp: 12966.0000 - fp: 4063.0000 - tn: 16442.0000 - fn: 7489.0000 - accuracy: 0.7180 - precision: 0.7614 - recall: 0.6339 - auc: 0.7355 - prc: 0.8188 - val_loss: 0.3901 - val_tp: 60.0000 - val_fp: 4341.0000 - val_tn: 41159.0000 - val_fn: 9.0000 - val_accuracy: 0.9045 - val_precision: 0.0136 - val_recall: 0.8696 - val_auc: 0.9306 - val_prc: 0.6600
Epoch 5/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.5618 - tp: 14319.0000 - fp: 4049.0000 - tn: 16518.0000 - fn: 6074.0000 - accuracy: 0.7529 - precision: 0.7796 - recall: 0.7022 - auc: 0.7971 - prc: 0.8594 - val_loss: 0.3759 - val_tp: 62.0000 - val_fp: 3600.0000 - val_tn: 41900.0000 - val_fn: 7.0000 - val_accuracy: 0.9208 - val_precision: 0.0169 - val_recall: 0.8986 - val_auc: 0.9432 - val_prc: 0.6929
Epoch 6/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.4880 - tp: 15347.0000 - fp: 3721.0000 - tn: 16696.0000 - fn: 5196.0000 - accuracy: 0.7823 - precision: 0.8049 - recall: 0.7471 - auc: 0.8385 - prc: 0.8877 - val_loss: 0.3542 - val_tp: 63.0000 - val_fp: 2857.0000 - val_tn: 42643.0000 - val_fn: 6.0000 - val_accuracy: 0.9372 - val_precision: 0.0216 - val_recall: 0.9130 - val_auc: 0.9539 - val_prc: 0.7195
Epoch 7/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.4461 - tp: 15816.0000 - fp: 3280.0000 - tn: 17103.0000 - fn: 4761.0000 - accuracy: 0.8037 - precision: 0.8282 - recall: 0.7686 - auc: 0.8619 - prc: 0.9038 - val_loss: 0.3301 - val_tp: 64.0000 - val_fp: 2195.0000 - val_tn: 43305.0000 - val_fn: 5.0000 - val_accuracy: 0.9517 - val_precision: 0.0283 - val_recall: 0.9275 - val_auc: 0.9639 - val_prc: 0.7314
Epoch 8/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.4191 - tp: 15893.0000 - fp: 2891.0000 - tn: 17619.0000 - fn: 4557.0000 - accuracy: 0.8182 - precision: 0.8461 - recall: 0.7772 - auc: 0.8764 - prc: 0.9124 - val_loss: 0.3054 - val_tp: 64.0000 - val_fp: 1742.0000 - val_tn: 43758.0000 - val_fn: 5.0000 - val_accuracy: 0.9617 - val_precision: 0.0354 - val_recall: 0.9275 - val_auc: 0.9734 - val_prc: 0.7366
Epoch 9/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.3895 - tp: 16135.0000 - fp: 2567.0000 - tn: 17914.0000 - fn: 4344.0000 - accuracy: 0.8313 - precision: 0.8627 - recall: 0.7879 - auc: 0.8927 - prc: 0.9221 - val_loss: 0.2812 - val_tp: 63.0000 - val_fp: 1410.0000 - val_tn: 44090.0000 - val_fn: 6.0000 - val_accuracy: 0.9689 - val_precision: 0.0428 - val_recall: 0.9130 - val_auc: 0.9811 - val_prc: 0.7399
Epoch 10/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.3720 - tp: 16397.0000 - fp: 2348.0000 - tn: 18019.0000 - fn: 4196.0000 - accuracy: 0.8402 - precision: 0.8747 - recall: 0.7962 - auc: 0.9013 - prc: 0.9290 - val_loss: 0.2592 - val_tp: 62.0000 - val_fp: 1199.0000 - val_tn: 44301.0000 - val_fn: 7.0000 - val_accuracy: 0.9735 - val_precision: 0.0492 - val_recall: 0.8986 - val_auc: 0.9869 - val_prc: 0.7455
Epoch 11/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.3445 - tp: 16567.0000 - fp: 1952.0000 - tn: 18519.0000 - fn: 3922.0000 - accuracy: 0.8566 - precision: 0.8946 - recall: 0.8086 - auc: 0.9149 - prc: 0.9377 - val_loss: 0.2394 - val_tp: 62.0000 - val_fp: 1050.0000 - val_tn: 44450.0000 - val_fn: 7.0000 - val_accuracy: 0.9768 - val_precision: 0.0558 - val_recall: 0.8986 - val_auc: 0.9902 - val_prc: 0.7474
Epoch 12/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.3299 - tp: 16021.0000 - fp: 1682.0000 - tn: 18721.0000 - fn: 4536.0000 - accuracy: 0.8482 - precision: 0.9050 - recall: 0.7793 - auc: 0.9221 - prc: 0.9422 - val_loss: 0.2220 - val_tp: 62.0000 - val_fp: 946.0000 - val_tn: 44554.0000 - val_fn: 7.0000 - val_accuracy: 0.9791 - val_precision: 0.0615 - val_recall: 0.8986 - val_auc: 0.9922 - val_prc: 0.7511
Epoch 13/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.3143 - tp: 15899.0000 - fp: 1497.0000 - tn: 19005.0000 - fn: 4559.0000 - accuracy: 0.8521 - precision: 0.9139 - recall: 0.7772 - auc: 0.9297 - prc: 0.9467 - val_loss: 0.2055 - val_tp: 62.0000 - val_fp: 882.0000 - val_tn: 44618.0000 - val_fn: 7.0000 - val_accuracy: 0.9805 - val_precision: 0.0657 - val_recall: 0.8986 - val_auc: 0.9934 - val_prc: 0.7536
Epoch 14/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.3082 - tp: 15953.0000 - fp: 1450.0000 - tn: 18984.0000 - fn: 4573.0000 - accuracy: 0.8530 - precision: 0.9167 - recall: 0.7772 - auc: 0.9327 - prc: 0.9487 - val_loss: 0.1903 - val_tp: 63.0000 - val_fp: 806.0000 - val_tn: 44694.0000 - val_fn: 6.0000 - val_accuracy: 0.9822 - val_precision: 0.0725 - val_recall: 0.9130 - val_auc: 0.9942 - val_prc: 0.7563
Epoch 15/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2970 - tp: 15952.0000 - fp: 1287.0000 - tn: 19055.0000 - fn: 4666.0000 - accuracy: 0.8547 - precision: 0.9253 - recall: 0.7737 - auc: 0.9373 - prc: 0.9520 - val_loss: 0.1775 - val_tp: 63.0000 - val_fp: 774.0000 - val_tn: 44726.0000 - val_fn: 6.0000 - val_accuracy: 0.9829 - val_precision: 0.0753 - val_recall: 0.9130 - val_auc: 0.9949 - val_prc: 0.7595
Epoch 16/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2834 - tp: 15633.0000 - fp: 1229.0000 - tn: 19595.0000 - fn: 4503.0000 - accuracy: 0.8601 - precision: 0.9271 - recall: 0.7764 - auc: 0.9431 - prc: 0.9542 - val_loss: 0.1665 - val_tp: 63.0000 - val_fp: 754.0000 - val_tn: 44746.0000 - val_fn: 6.0000 - val_accuracy: 0.9833 - val_precision: 0.0771 - val_recall: 0.9130 - val_auc: 0.9953 - val_prc: 0.7615
Epoch 17/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2778 - tp: 15809.0000 - fp: 1186.0000 - tn: 19384.0000 - fn: 4581.0000 - accuracy: 0.8592 - precision: 0.9302 - recall: 0.7753 - auc: 0.9469 - prc: 0.9576 - val_loss: 0.1567 - val_tp: 63.0000 - val_fp: 738.0000 - val_tn: 44762.0000 - val_fn: 6.0000 - val_accuracy: 0.9837 - val_precision: 0.0787 - val_recall: 0.9130 - val_auc: 0.9957 - val_prc: 0.7651
Epoch 18/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2702 - tp: 15965.0000 - fp: 1113.0000 - tn: 19348.0000 - fn: 4534.0000 - accuracy: 0.8621 - precision: 0.9348 - recall: 0.7788 - auc: 0.9508 - prc: 0.9601 - val_loss: 0.1489 - val_tp: 63.0000 - val_fp: 738.0000 - val_tn: 44762.0000 - val_fn: 6.0000 - val_accuracy: 0.9837 - val_precision: 0.0787 - val_recall: 0.9130 - val_auc: 0.9960 - val_prc: 0.7679
Epoch 19/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2586 - tp: 16031.0000 - fp: 1081.0000 - tn: 19414.0000 - fn: 4434.0000 - accuracy: 0.8654 - precision: 0.9368 - recall: 0.7833 - auc: 0.9553 - prc: 0.9633 - val_loss: 0.1420 - val_tp: 63.0000 - val_fp: 746.0000 - val_tn: 44754.0000 - val_fn: 6.0000 - val_accuracy: 0.9835 - val_precision: 0.0779 - val_recall: 0.9130 - val_auc: 0.9961 - val_prc: 0.7724
Epoch 20/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2562 - tp: 16026.0000 - fp: 1043.0000 - tn: 19414.0000 - fn: 4477.0000 - accuracy: 0.8652 - precision: 0.9389 - recall: 0.7816 - auc: 0.9569 - prc: 0.9641 - val_loss: 0.1360 - val_tp: 64.0000 - val_fp: 744.0000 - val_tn: 44756.0000 - val_fn: 5.0000 - val_accuracy: 0.9836 - val_precision: 0.0792 - val_recall: 0.9275 - val_auc: 0.9962 - val_prc: 0.7750
Epoch 21/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2461 - tp: 16002.0000 - fp: 1012.0000 - tn: 19574.0000 - fn: 4372.0000 - accuracy: 0.8686 - precision: 0.9405 - recall: 0.7854 - auc: 0.9604 - prc: 0.9667 - val_loss: 0.1305 - val_tp: 64.0000 - val_fp: 743.0000 - val_tn: 44757.0000 - val_fn: 5.0000 - val_accuracy: 0.9836 - val_precision: 0.0793 - val_recall: 0.9275 - val_auc: 0.9962 - val_prc: 0.7679
Epoch 22/1000
20/20 [==============================] - 1s 34ms/step - loss: 0.2421 - tp: 16160.0000 - fp: 973.0000 - tn: 19431.0000 - fn: 4396.0000 - accuracy: 0.8689 - precision: 0.9432 - recall: 0.7861 - auc: 0.9627 - prc: 0.9685 - val_loss: 0.1254 - val_tp: 64.0000 - val_fp: 743.0000 - val_tn: 44757.0000 - val_fn: 5.0000 - val_accuracy: 0.9836 - val_precision: 0.0793 - val_recall: 0.9275 - val_auc: 0.9962 - val_prc: 0.7700
Epoch 23/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2362 - tp: 16012.0000 - fp: 998.0000 - tn: 19606.0000 - fn: 4344.0000 - accuracy: 0.8696 - precision: 0.9413 - recall: 0.7866 - auc: 0.9645 - prc: 0.9693 - val_loss: 0.1207 - val_tp: 65.0000 - val_fp: 744.0000 - val_tn: 44756.0000 - val_fn: 4.0000 - val_accuracy: 0.9836 - val_precision: 0.0803 - val_recall: 0.9420 - val_auc: 0.9962 - val_prc: 0.7720
Epoch 24/1000
20/20 [==============================] - 1s 34ms/step - loss: 0.2289 - tp: 16312.0000 - fp: 884.0000 - tn: 19492.0000 - fn: 4272.0000 - accuracy: 0.8741 - precision: 0.9486 - recall: 0.7925 - auc: 0.9675 - prc: 0.9719 - val_loss: 0.1163 - val_tp: 65.0000 - val_fp: 739.0000 - val_tn: 44761.0000 - val_fn: 4.0000 - val_accuracy: 0.9837 - val_precision: 0.0808 - val_recall: 0.9420 - val_auc: 0.9962 - val_prc: 0.7738
Epoch 25/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2256 - tp: 16361.0000 - fp: 889.0000 - tn: 19458.0000 - fn: 4252.0000 - accuracy: 0.8745 - precision: 0.9485 - recall: 0.7937 - auc: 0.9690 - prc: 0.9730 - val_loss: 0.1134 - val_tp: 66.0000 - val_fp: 756.0000 - val_tn: 44744.0000 - val_fn: 3.0000 - val_accuracy: 0.9833 - val_precision: 0.0803 - val_recall: 0.9565 - val_auc: 0.9961 - val_prc: 0.7748
Epoch 26/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2208 - tp: 16267.0000 - fp: 893.0000 - tn: 19573.0000 - fn: 4227.0000 - accuracy: 0.8750 - precision: 0.9480 - recall: 0.7937 - auc: 0.9700 - prc: 0.9737 - val_loss: 0.1106 - val_tp: 66.0000 - val_fp: 762.0000 - val_tn: 44738.0000 - val_fn: 3.0000 - val_accuracy: 0.9832 - val_precision: 0.0797 - val_recall: 0.9565 - val_auc: 0.9961 - val_prc: 0.7758
Epoch 27/1000
20/20 [==============================] - 1s 34ms/step - loss: 0.2170 - tp: 16197.0000 - fp: 864.0000 - tn: 19723.0000 - fn: 4176.0000 - accuracy: 0.8770 - precision: 0.9494 - recall: 0.7950 - auc: 0.9717 - prc: 0.9745 - val_loss: 0.1082 - val_tp: 66.0000 - val_fp: 774.0000 - val_tn: 44726.0000 - val_fn: 3.0000 - val_accuracy: 0.9829 - val_precision: 0.0786 - val_recall: 0.9565 - val_auc: 0.9961 - val_prc: 0.7769
Epoch 28/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2147 - tp: 16390.0000 - fp: 867.0000 - tn: 19544.0000 - fn: 4159.0000 - accuracy: 0.8773 - precision: 0.9498 - recall: 0.7976 - auc: 0.9721 - prc: 0.9754 - val_loss: 0.1051 - val_tp: 66.0000 - val_fp: 774.0000 - val_tn: 44726.0000 - val_fn: 3.0000 - val_accuracy: 0.9829 - val_precision: 0.0786 - val_recall: 0.9565 - val_auc: 0.9961 - val_prc: 0.7771
Epoch 29/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2100 - tp: 16237.0000 - fp: 868.0000 - tn: 19885.0000 - fn: 3970.0000 - accuracy: 0.8819 - precision: 0.9493 - recall: 0.8035 - auc: 0.9737 - prc: 0.9759 - val_loss: 0.1026 - val_tp: 66.0000 - val_fp: 785.0000 - val_tn: 44715.0000 - val_fn: 3.0000 - val_accuracy: 0.9827 - val_precision: 0.0776 - val_recall: 0.9565 - val_auc: 0.9960 - val_prc: 0.7781
Epoch 30/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2117 - tp: 17870.0000 - fp: 1046.0000 - tn: 19512.0000 - fn: 2532.0000 - accuracy: 0.9126 - precision: 0.9447 - recall: 0.8759 - auc: 0.9734 - prc: 0.9757 - val_loss: 0.1003 - val_tp: 66.0000 - val_fp: 790.0000 - val_tn: 44710.0000 - val_fn: 3.0000 - val_accuracy: 0.9826 - val_precision: 0.0771 - val_recall: 0.9565 - val_auc: 0.9960 - val_prc: 0.7791
Epoch 31/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2027 - tp: 18028.0000 - fp: 1001.0000 - tn: 19511.0000 - fn: 2420.0000 - accuracy: 0.9165 - precision: 0.9474 - recall: 0.8817 - auc: 0.9759 - prc: 0.9781 - val_loss: 0.0977 - val_tp: 66.0000 - val_fp: 782.0000 - val_tn: 44718.0000 - val_fn: 3.0000 - val_accuracy: 0.9828 - val_precision: 0.0778 - val_recall: 0.9565 - val_auc: 0.9959 - val_prc: 0.7807
Epoch 32/1000
20/20 [==============================] - 1s 29ms/step - loss: 0.2036 - tp: 18110.0000 - fp: 1003.0000 - tn: 19526.0000 - fn: 2321.0000 - accuracy: 0.9188 - precision: 0.9475 - recall: 0.8864 - auc: 0.9761 - prc: 0.9780 - val_loss: 0.0956 - val_tp: 66.0000 - val_fp: 789.0000 - val_tn: 44711.0000 - val_fn: 3.0000 - val_accuracy: 0.9826 - val_precision: 0.0772 - val_recall: 0.9565 - val_auc: 0.9959 - val_prc: 0.7723
Epoch 33/1000
20/20 [==============================] - ETA: 0s - loss: 0.2003 - tp: 18169.0000 - fp: 1012.0000 - tn: 19520.0000 - fn: 2259.0000 - accuracy: 0.9201 - precision: 0.9472 - recall: 0.8894 - auc: 0.9773 - prc: 0.9786Restoring model weights from the end of the best epoch: 23.
20/20 [==============================] - 1s 29ms/step - loss: 0.2003 - tp: 18169.0000 - fp: 1012.0000 - tn: 19520.0000 - fn: 2259.0000 - accuracy: 0.9201 - precision: 0.9472 - recall: 0.8894 - auc: 0.9773 - prc: 0.9786 - val_loss: 0.0931 - val_tp: 66.0000 - val_fp: 780.0000 - val_tn: 44720.0000 - val_fn: 3.0000 - val_accuracy: 0.9828 - val_precision: 0.0780 - val_recall: 0.9565 - val_auc: 0.9959 - val_prc: 0.7733
Epoch 33: early stopping

Re-check training history

plot_metrics(resampled_history)

png

Evaluate metrics

train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step
28/28 [==============================] - 0s 1ms/step
resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_resampled)
loss :  0.12108473479747772
tp :  95.0
fp :  915.0
tn :  55940.0
fn :  12.0
accuracy :  0.9837259650230408
precision :  0.09405940771102905
recall :  0.8878504633903503
auc :  0.9800193309783936
prc :  0.7522088289260864

Legitimate Transactions Detected (True Negatives):  55940
Legitimate Transactions Incorrectly Detected (False Positives):  915
Fraudulent Transactions Missed (False Negatives):  12
Fraudulent Transactions Detected (True Positives):  95
Total Fraudulent Transactions:  107

png
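The reported precision and recall follow directly from the confusion-matrix counts above; a quick cross-check with the values copied from the printout:

# Cross-check precision and recall from the confusion-matrix counts.
tp, fp, fn = 95, 915, 12
print(tp / (tp + fp))  # precision ≈ 0.0941, matching the value reported above
print(tp / (tp + fn))  # recall ≈ 0.8879, matching the value reported above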

Plot the ROC

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right');

png

Plot the AUPRC
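The plot_prc helper was defined earlier in the tutorial; a minimal sketch of an equivalent helper built on sklearn.metrics.precision_recall_curve (the axis labels and styling here are assumptions) is:

import sklearn.metrics

def plot_prc(name, labels, predictions, **kwargs):
  # Precision and recall at every decision threshold.
  precision, recall, _ = sklearn.metrics.precision_recall_curve(
      labels, predictions)
  plt.plot(precision, recall, label=name, linewidth=2, **kwargs)
  plt.xlabel('Precision')
  plt.ylabel('Recall')
  plt.grid(True)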

plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_prc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_prc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right');

png

Applying this tutorial to your problem

Imbalanced data classification is an inherently difficult task since there are so few examples to learn from. You should always start with the data first and do your best to collect as many examples as possible, and give substantial thought to what features may be relevant so the model can get the most out of your minority class. At some point your model may struggle to improve and yield the results you want, so it is important to keep in mind the context of your problem and the trade-offs between different types of errors.