Overview
Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.
The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:
- Generate training data from the input graph and sample features. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
- Create a neural network as a base model using the `Keras` sequential, functional, or subclass API.
- Wrap the base model with the `GraphRegularization` wrapper class, which is provided by the NSL framework, to create a new graph `Keras` model. This new model will include a graph regularization loss as the regularization term in its training objective.
- Train and evaluate the graph `Keras` model (a minimal sketch of these steps follows this list).
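To make the recipe concrete, here is a minimal, illustrative sketch of the last three steps, using the same NSL APIs exercised later in this tutorial (`nsl.configs.make_graph_reg_config` and `nsl.keras.GraphRegularization`). The architecture, input width, and the commented-out dataset variables are placeholders, not the models we build below:

```python
import neural_structured_learning as nsl
import tensorflow as tf

# Step 2: a base Keras model; the architecture and input width are placeholders.
base_model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(1433,), name='words'),
    tf.keras.layers.Dense(50, activation='relu'),
    tf.keras.layers.Dense(7),
])

# Step 3: wrap the base model so that each sample's prediction is regularized
# towards those of its graph neighbors.
graph_reg_config = nsl.configs.make_graph_reg_config(max_neighbors=1,
                                                     multiplier=0.1)
graph_reg_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)

# Step 4: compile, train, and evaluate like any other Keras model, assuming
# train_dataset/test_dataset yield neighbor-augmented examples from step 1.
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
# graph_reg_model.fit(train_dataset, epochs=5)
# graph_reg_model.evaluate(test_dataset)
```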
Setup
Install the Neural Structured Learning package.
pip install --quiet neural-structured-learning
Dependencies and imports
import neural_structured_learning as nsl
import tensorflow as tf
# Resets notebook state
tf.keras.backend.clear_session()
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
2022-12-14 12:30:10.701022: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 12:30:10.701136: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 12:30:10.701148: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly. Version: 2.11.0 Eager mode: True GPU is NOT AVAILABLE 2022-12-14 12:30:11.873055: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Cora dataset
The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task is document classification: the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.
Graph
The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph: if paper A cites paper B, we also consider paper B to have cited paper A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a symmetric relation.
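For illustration, the following small sketch shows what "making the graph undirected" means in code; the preprocessing script used later in this tutorial performs this step for us:

```python
def make_undirected(edges):
  """Returns the input citation edges plus their reverses, deduplicated."""
  undirected = set(edges)
  undirected.update((dst, src) for src, dst in edges)
  return sorted(undirected)

# If paper 'A' cites paper 'B', the undirected view also links 'B' to 'A'.
print(make_undirected([('A', 'B'), ('B', 'C')]))
# [('A', 'B'), ('B', 'A'), ('B', 'C'), ('C', 'B')]
```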
Features
Each paper in the input effectively contains 2 features:

- Words: A dense, multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words, so the length of this feature is 1433, and the value at position 'i' is 0 or 1, indicating whether word 'i' in the vocabulary exists in the given paper.
- Label: A single integer representing the class ID (category) of the paper.
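To make this representation concrete, the sketch below builds a toy `tf.train.Example` with these two features, using a hypothetical vocabulary of size 10 instead of 1433:

```python
import tensorflow as tf

VOCAB_SIZE = 10  # The real Cora vocabulary contains 1433 unique words.

# Multi-hot bag of words: words 2, 3, and 7 occur in this toy paper.
words = [0] * VOCAB_SIZE
for word_id in (2, 3, 7):
  words[word_id] = 1

toy_example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'words':
                tf.train.Feature(int64_list=tf.train.Int64List(value=words)),
            'label':
                tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
        }))
print(toy_example)
```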
Download the Cora dataset
wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/ cora/README cora/cora.cites cora/cora.content
Convert the Cora data to the NSL format
In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL GitHub repository. This script does the following:

- Generate neighbor features using the original node features and the graph.
- Generate train and test data splits containing `tf.train.Example` instances.
- Persist the resulting train and test data in the `TFRecord` format.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2022-12-14 12:30:12-- https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 11640 (11K) [text/plain] Saving to: ‘preprocess_cora_dataset.py’ preprocess_cora_dat 100%[===================>] 11.37K --.-KB/s in 0s 2022-12-14 12:30:13 (51.0 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640] 2022-12-14 12:30:14.426059: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 12:30:14.426163: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 12:30:14.426185: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly. 2022-12-14 12:30:15.551270: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected Reading graph file: /tmp/cora/cora.cites... Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds). Making all edges bi-directional... Done (0.11 seconds). Total graph nodes: 2708 Joining seed and neighbor tf.train.Examples with graph edges... Done creating and writing 2155 merged tf.train.Examples (1.36 seconds). Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)] Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr. Output test data written to TFRecord file: /tmp/cora/test_examples.tfr. Total running time: 0.04 minutes.
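If you would like to inspect the script's raw output directly, a sketch like the following decodes a single merged record. The exact set of feature keys is determined by the preprocessing script, so the keys noted in the comment are indicative rather than exhaustive:

```python
import tensorflow as tf

# Decode one merged record written by preprocess_cora_dataset.py.
raw_record = next(
    iter(tf.data.TFRecordDataset(['/tmp/cora/train_merged_examples.tfr'])))
example = tf.train.Example.FromString(raw_record.numpy())
# The keys should include 'words', 'label', and neighbor features such as
# 'NL_nbr_0_words' and 'NL_nbr_0_weight' (up to --max_nbrs neighbors).
print(sorted(example.features.feature.keys()))
```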
Global variables
The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.
### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'
### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
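Using these constants, the feature keys for a sample's i-th neighbor are constructed as shown below; this mirrors the key construction used in `make_dataset` later in this tutorial and matches the feature names we will see when peeking at the training data:

```python
# Feature keys for a sample's 0-th neighbor, as used throughout this tutorial:
i = 0
nbr_words_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)
print(nbr_words_key)   # NL_nbr_0_words
print(nbr_weight_key)  # NL_nbr_0_weight
```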
Hyperparameters
We will use an instance of `HParams` to include various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:

- `num_classes`: There are a total of 7 different classes.
- `max_seq_length`: This is the size of the vocabulary; all instances in the input have a dense multi-hot, bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input and a value of 0 indicates that it is not.
- `distance_type`: This is the distance metric used to regularize the sample with its neighbors.
- `graph_regularization_multiplier`: This controls the relative weight of the graph regularization term in the overall loss function.
- `num_neighbors`: The number of neighbors used for graph regularization. This value has to be less than or equal to the `max_nbrs` command-line argument used above when running `preprocess_cora_dataset.py`.
- `num_fc_units`: The number of units in each fully connected layer of our neural network.
- `train_epochs`: The number of training epochs.
- `batch_size`: Batch size used for training and evaluation.
- `dropout_rate`: Controls the rate of dropout following each fully connected layer.
- `eval_steps`: The number of batches to process before deeming evaluation complete. If set to `None`, all instances in the test set are evaluated.
class HParams(object):
  """Hyperparameters used for training."""

  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()
Load train and test data
As described earlier in this notebook, the input training and test data have been created by the 'preprocess_cora_dataset.py' script. We will load them into two `tf.data.Dataset` objects -- one for train and one for test.

In the input layer of our model, we will extract not just the 'words' and the 'label' features from each sample, but also the corresponding neighbor features based on the `hparams.num_neighbors` value. Instances with fewer neighbors than `hparams.num_neighbors` will be assigned dummy values for the non-existent neighbor features.
def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the
    `tf.train.Example` objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))
        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)
    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset

train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
Let's peek into the train dataset to look at its contents.
for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor weights: tf.Tensor( [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32) Batch of labels: tf.Tensor( [3 2 3 6 6 1 3 0 0 4 3 3 1 2 2 2 3 1 5 3 0 0 3 2 0 4 3 2 1 2 3 2 4 5 1 3 6 5 2 4 1 2 0 6 2 3 2 3 2 4 4 1 2 2 5 2 3 3 1 2 2 3 3 6 3 3 1 2 5 2 0 1 0 2 1 0 6 6 2 0 2 1 0 2 5 1 2 2 1 2 6 1 0 5 2 0 2 2 6 0 3 2 0 2 2 2 1 2 3 2 4 5 2 3 1 0 0 0 6 5 1 4 5 0 2 0 1 1], shape=(128,), dtype=int64)
Let's peek into the test dataset to look at its contents.
for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of labels: tf.Tensor( [5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)
Model definition
In order to demonstrate the use of graph regularization, we build a base model for this problem first. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the `tf.keras` framework -- sequential, functional, and subclass.
Sequential base model
def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout'
    # layer is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes))
  return model
Functional base model
def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout'
    # layer is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model
Subclass base model
def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(hparams.num_classes)

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)
      outputs = self.output_layer(cur_layer)
      return outputs

  return MLP()
Create base model(s)
# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= words (InputLayer) [(None, 1433)] 0 lambda (Lambda) (None, 1433) 0 dense (Dense) (None, 50) 71700 dropout (Dropout) (None, 50) 0 dense_1 (Dense) (None, 50) 2550 dropout_1 (Dropout) (None, 50) 0 dense_2 (Dense) (None, 7) 357 ================================================================= Total params: 74,607 Trainable params: 74,607 Non-trainable params: 0 _________________________________________________________________
Train base MLP model
# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/engine/functional.py:638: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model. inputs = self._flatten_to_reference_inputs(inputs) 17/17 [==============================] - 1s 22ms/step - loss: 1.9376 - accuracy: 0.1824 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8481 - accuracy: 0.2826 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7504 - accuracy: 0.3244 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6410 - accuracy: 0.3754 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5148 - accuracy: 0.4292 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3573 - accuracy: 0.5090 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2228 - accuracy: 0.5754 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0686 - accuracy: 0.6371 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9709 - accuracy: 0.6770 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8576 - accuracy: 0.7281 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7724 - accuracy: 0.7545 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7295 - accuracy: 0.7684 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6399 - accuracy: 0.7949 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6030 - accuracy: 0.8023 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5421 - accuracy: 0.8325 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5235 - accuracy: 0.8339 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4918 - accuracy: 0.8418 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4502 - accuracy: 0.8524 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3927 - accuracy: 0.8896 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3635 - accuracy: 0.8910 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3564 - accuracy: 0.8984 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3232 - accuracy: 0.9035 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2942 - accuracy: 0.9100 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2992 - accuracy: 0.9104 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2717 - accuracy: 0.9225 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2442 - accuracy: 0.9267 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2541 - accuracy: 0.9165 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2316 - accuracy: 0.9364 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2498 - accuracy: 0.9234 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2160 - accuracy: 0.9406 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2019 - accuracy: 0.9392 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1956 - accuracy: 0.9420 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 
0.1879 - accuracy: 0.9448 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1802 - accuracy: 0.9476 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1547 - accuracy: 0.9615 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1632 - accuracy: 0.9592 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1493 - accuracy: 0.9606 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1491 - accuracy: 0.9582 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1442 - accuracy: 0.9592 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1285 - accuracy: 0.9671 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1330 - accuracy: 0.9592 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1246 - accuracy: 0.9638 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1228 - accuracy: 0.9712 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1299 - accuracy: 0.9610 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1174 - accuracy: 0.9684 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1289 - accuracy: 0.9638 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1132 - accuracy: 0.9703 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1078 - accuracy: 0.9703 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1136 - accuracy: 0.9684 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1027 - accuracy: 0.9740 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1043 - accuracy: 0.9694 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0963 - accuracy: 0.9698 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0959 - accuracy: 0.9754 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0987 - accuracy: 0.9745 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0955 - accuracy: 0.9735 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0935 - accuracy: 0.9712 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0916 - accuracy: 0.9745 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0771 - accuracy: 0.9777 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0887 - accuracy: 0.9745 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0761 - accuracy: 0.9810 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0790 - accuracy: 0.9768 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0688 - accuracy: 0.9819 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0807 - accuracy: 0.9796 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0651 - accuracy: 0.9828 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0647 - accuracy: 0.9838 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0744 - accuracy: 0.9810 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0663 - accuracy: 0.9833 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0610 - accuracy: 0.9852 Epoch 69/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.0642 - accuracy: 0.9847 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0722 - accuracy: 0.9800 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0715 - accuracy: 0.9819 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0680 - accuracy: 0.9819 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0599 - accuracy: 0.9856 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0621 - accuracy: 0.9824 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0533 - accuracy: 0.9861 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0664 - accuracy: 0.9819 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0630 - accuracy: 0.9824 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0608 - accuracy: 0.9838 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0550 - accuracy: 0.9852 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0561 - accuracy: 0.9870 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0556 - accuracy: 0.9828 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0498 - accuracy: 0.9842 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0614 - accuracy: 0.9842 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0510 - accuracy: 0.9828 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0495 - accuracy: 0.9870 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0528 - accuracy: 0.9847 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0500 - accuracy: 0.9861 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0435 - accuracy: 0.9875 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0427 - accuracy: 0.9898 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0572 - accuracy: 0.9828 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0467 - accuracy: 0.9861 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0584 - accuracy: 0.9824 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0458 - accuracy: 0.9889 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0444 - accuracy: 0.9879 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0467 - accuracy: 0.9852 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0513 - accuracy: 0.9847 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0464 - accuracy: 0.9870 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0514 - accuracy: 0.9865 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0353 - accuracy: 0.9903 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0499 - accuracy: 0.9861 <keras.callbacks.History at 0x7fc17e2c9400>
Evaluate base MLP model
# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values.
      It must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])

eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.3725 - accuracy: 0.7866 Eval accuracy for Base MLP model : 0.7866184711456299 Eval loss for Base MLP model : 1.372455358505249
Train MLP model with graph regularization
Incorporating graph regularization into the loss term of an existing `tf.keras.Model` requires just a few lines of code. The base model is wrapped to create a new `tf.keras` subclass model, whose loss includes graph regularization.

To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because `base_model` has already been trained for a few iterations, and reusing this trained model to create a graph-regularized model would not be a fair comparison for `base_model`.
# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)

# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 17/17 [==============================] - 2s 5ms/step - loss: 1.9608 - accuracy: 0.2111 - scaled_graph_loss: 0.0321 Epoch 2/100 17/17 [==============================] - 0s 4ms/step - loss: 1.8863 - accuracy: 0.3012 - scaled_graph_loss: 0.0249 Epoch 3/100 17/17 [==============================] - 0s 4ms/step - loss: 1.8332 - accuracy: 0.3155 - scaled_graph_loss: 0.0383 Epoch 4/100 17/17 [==============================] - 0s 4ms/step - loss: 1.7775 - accuracy: 0.3309 - scaled_graph_loss: 0.0590 Epoch 5/100 17/17 [==============================] - 0s 4ms/step - loss: 1.7284 - accuracy: 0.3411 - scaled_graph_loss: 0.0774 Epoch 6/100 17/17 [==============================] - 0s 4ms/step - loss: 1.6866 - accuracy: 0.3800 - scaled_graph_loss: 0.0888 Epoch 7/100 17/17 [==============================] - 0s 4ms/step - loss: 1.6393 - accuracy: 0.4306 - scaled_graph_loss: 0.1097 Epoch 8/100 17/17 [==============================] - 0s 4ms/step - loss: 1.5881 - accuracy: 0.4780 - scaled_graph_loss: 0.1173 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5383 - accuracy: 0.5258 - scaled_graph_loss: 0.1445 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5083 - accuracy: 0.5787 - scaled_graph_loss: 0.1613 Epoch 11/100 17/17 [==============================] - 0s 4ms/step - loss: 1.4778 - accuracy: 0.6019 - scaled_graph_loss: 0.1713 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4453 - accuracy: 0.6260 - scaled_graph_loss: 0.1875 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4017 - accuracy: 0.6608 - scaled_graph_loss: 0.2000 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3734 - accuracy: 0.6831 - scaled_graph_loss: 0.2033 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3410 - accuracy: 0.7174 - scaled_graph_loss: 0.2236 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3265 - accuracy: 0.7285 - scaled_graph_loss: 0.2144 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3034 - accuracy: 0.7411 - scaled_graph_loss: 0.2351 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2642 - accuracy: 0.7689 - scaled_graph_loss: 0.2354 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2685 - accuracy: 0.7787 - scaled_graph_loss: 0.2526 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2325 - accuracy: 0.7879 - scaled_graph_loss: 0.2427 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2216 - accuracy: 0.7907 - scaled_graph_loss: 0.2515 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2019 - accuracy: 0.8074 - scaled_graph_loss: 0.2558 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1884 - accuracy: 0.8218 - scaled_graph_loss: 0.2637 Epoch 24/100 17/17 [==============================] - 0s 4ms/step - loss: 1.1718 - accuracy: 0.8200 - scaled_graph_loss: 0.2513 Epoch 25/100 
17/17 [==============================] - 0s 3ms/step - loss: 1.1739 - accuracy: 0.8190 - scaled_graph_loss: 0.2590 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1340 - accuracy: 0.8362 - scaled_graph_loss: 0.2662 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1293 - accuracy: 0.8450 - scaled_graph_loss: 0.2716 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1333 - accuracy: 0.8445 - scaled_graph_loss: 0.2764 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1139 - accuracy: 0.8450 - scaled_graph_loss: 0.2658 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1122 - accuracy: 0.8580 - scaled_graph_loss: 0.2771 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1114 - accuracy: 0.8589 - scaled_graph_loss: 0.2808 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0905 - accuracy: 0.8640 - scaled_graph_loss: 0.2837 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0762 - accuracy: 0.8705 - scaled_graph_loss: 0.2728 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0876 - accuracy: 0.8696 - scaled_graph_loss: 0.2884 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0570 - accuracy: 0.8710 - scaled_graph_loss: 0.2723 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0586 - accuracy: 0.8756 - scaled_graph_loss: 0.2857 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0606 - accuracy: 0.8687 - scaled_graph_loss: 0.2821 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0497 - accuracy: 0.8770 - scaled_graph_loss: 0.2800 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0423 - accuracy: 0.8886 - scaled_graph_loss: 0.2964 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0338 - accuracy: 0.8951 - scaled_graph_loss: 0.2987 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0386 - accuracy: 0.8826 - scaled_graph_loss: 0.2750 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0372 - accuracy: 0.8831 - scaled_graph_loss: 0.2909 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0110 - accuracy: 0.8914 - scaled_graph_loss: 0.2820 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0177 - accuracy: 0.8914 - scaled_graph_loss: 0.2923 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0204 - accuracy: 0.8914 - scaled_graph_loss: 0.2891 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0200 - accuracy: 0.8872 - scaled_graph_loss: 0.2865 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0147 - accuracy: 0.8933 - scaled_graph_loss: 0.2898 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0103 - accuracy: 0.8951 - scaled_graph_loss: 0.2895 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9986 - accuracy: 0.9002 - scaled_graph_loss: 0.2933 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0030 - accuracy: 0.9030 - scaled_graph_loss: 0.2980 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9813 - accuracy: 0.9039 - scaled_graph_loss: 0.2924 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9876 - accuracy: 0.9016 - 
scaled_graph_loss: 0.2918 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9796 - accuracy: 0.9053 - scaled_graph_loss: 0.2932 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9857 - accuracy: 0.9039 - scaled_graph_loss: 0.2885 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9690 - accuracy: 0.9142 - scaled_graph_loss: 0.2985 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9721 - accuracy: 0.9035 - scaled_graph_loss: 0.2956 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9712 - accuracy: 0.9063 - scaled_graph_loss: 0.2958 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9761 - accuracy: 0.9012 - scaled_graph_loss: 0.2907 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9596 - accuracy: 0.9160 - scaled_graph_loss: 0.2956 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9549 - accuracy: 0.9086 - scaled_graph_loss: 0.2940 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9604 - accuracy: 0.9123 - scaled_graph_loss: 0.2942 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9624 - accuracy: 0.9021 - scaled_graph_loss: 0.3018 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9660 - accuracy: 0.9095 - scaled_graph_loss: 0.2978 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9580 - accuracy: 0.9072 - scaled_graph_loss: 0.2891 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9557 - accuracy: 0.9151 - scaled_graph_loss: 0.3005 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9629 - accuracy: 0.9104 - scaled_graph_loss: 0.3093 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9447 - accuracy: 0.9104 - scaled_graph_loss: 0.2759 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9493 - accuracy: 0.9179 - scaled_graph_loss: 0.3004 Epoch 69/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9545 - accuracy: 0.9155 - scaled_graph_loss: 0.3110 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9402 - accuracy: 0.9155 - scaled_graph_loss: 0.2959 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9370 - accuracy: 0.9179 - scaled_graph_loss: 0.2873 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9261 - accuracy: 0.9155 - scaled_graph_loss: 0.3011 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9311 - accuracy: 0.9142 - scaled_graph_loss: 0.2978 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9255 - accuracy: 0.9118 - scaled_graph_loss: 0.3007 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9346 - accuracy: 0.9118 - scaled_graph_loss: 0.2841 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9231 - accuracy: 0.9197 - scaled_graph_loss: 0.3038 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9362 - accuracy: 0.9174 - scaled_graph_loss: 0.3007 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9274 - accuracy: 0.9109 - scaled_graph_loss: 0.2920 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8998 - accuracy: 0.9299 - scaled_graph_loss: 0.2786 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - 
loss: 0.9260 - accuracy: 0.9234 - scaled_graph_loss: 0.3090 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9118 - accuracy: 0.9230 - scaled_graph_loss: 0.2891 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9036 - accuracy: 0.9304 - scaled_graph_loss: 0.2999 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8991 - accuracy: 0.9202 - scaled_graph_loss: 0.2929 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8940 - accuracy: 0.9383 - scaled_graph_loss: 0.2983 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9079 - accuracy: 0.9253 - scaled_graph_loss: 0.2989 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9004 - accuracy: 0.9262 - scaled_graph_loss: 0.3064 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9124 - accuracy: 0.9244 - scaled_graph_loss: 0.2909 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8887 - accuracy: 0.9271 - scaled_graph_loss: 0.2817 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8870 - accuracy: 0.9304 - scaled_graph_loss: 0.3015 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9211 - accuracy: 0.9160 - scaled_graph_loss: 0.3102 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8937 - accuracy: 0.9309 - scaled_graph_loss: 0.2878 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8912 - accuracy: 0.9230 - scaled_graph_loss: 0.3018 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8865 - accuracy: 0.9336 - scaled_graph_loss: 0.3034 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9044 - accuracy: 0.9197 - scaled_graph_loss: 0.2976 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8919 - accuracy: 0.9230 - scaled_graph_loss: 0.2929 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9032 - accuracy: 0.9230 - scaled_graph_loss: 0.2964 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8892 - accuracy: 0.9244 - scaled_graph_loss: 0.2995 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8852 - accuracy: 0.9281 - scaled_graph_loss: 0.3017 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8813 - accuracy: 0.9216 - scaled_graph_loss: 0.2953 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8911 - accuracy: 0.9341 - scaled_graph_loss: 0.3030 <keras.callbacks.History at 0x7fc0d039a0a0>
Evaluate MLP model with graph regularization
eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8984 - accuracy: 0.7957 Eval accuracy for MLP + graph regularization : 0.7956600189208984 Eval loss for MLP + graph regularization : 0.8983543515205383
In this run, the graph-regularized model's accuracy is about 1% higher than that of the base model (`base_model`); the exact gain varies from run to run.
Conclusion
We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs based on sample embeddings before training a neural network with graph regularization. This approach is useful if the input does not contain an explicit graph.
We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.
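As a starting point for such experiments, the sketch below runs a small sweep over the graph regularization multiplier, reusing the helpers defined earlier in this tutorial; the candidate values are arbitrary and training output is suppressed:

```python
# A small sweep over the graph regularization strength; candidate values are
# arbitrary. Each run builds and trains a fresh graph-regularized model.
for multiplier in [0.0, 0.1, 1.0]:
  HPARAMS.graph_regularization_multiplier = multiplier
  model = make_mlp_functional_model(HPARAMS)
  config = nsl.configs.make_graph_reg_config(
      max_neighbors=HPARAMS.num_neighbors,
      multiplier=HPARAMS.graph_regularization_multiplier,
      distance_type=HPARAMS.distance_type,
      sum_over_axis=-1)
  swept_model = nsl.keras.GraphRegularization(model, config)
  swept_model.compile(
      optimizer='adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])
  swept_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=0)
  results = dict(
      zip(swept_model.metrics_names,
          swept_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
  print_metrics('multiplier={}'.format(multiplier), results)
```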