Checkpointer and PolicySaver

Introduction

tf_agents.utils.common.Checkpointer is a utility to save/load the training state, the policy state, and the replay_buffer state to/from local storage.

tf_agents.policies.policy_saver.PolicySaver is a tool to save/load only the policy, and it is lighter than Checkpointer. You can use PolicySaver to deploy the model without any knowledge of the code that created the policy.

In this tutorial, we will use DQN to train a model, then use Checkpointer and PolicySaver to show how the states and the model can be stored and loaded interactively. Note that we will use the new TF2.0 saved_model tooling and format for PolicySaver.

Setup

If you haven't installed the following dependencies, run:

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython

try:
  from google.colab import files
except ImportError:
  files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()

DQN agent

We are going to set up the DQN agent, just like in the previous colab. The details are hidden by default, as they are not a core part of this colab, but you can click "SHOW CODE" to see them.

Hyperparameters

env_name = "CartPole-v1"

collect_steps_per_iteration = 100
replay_buffer_capacity = 100000

fc_layer_params = (100,)

batch_size = 64
learning_rate = 1e-3
log_interval = 5

num_eval_episodes = 10
eval_interval = 1000

Environment

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Agent

q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

global_step = tf.compat.v1.train.get_or_create_global_step()

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=global_step)
agent.initialize()

Data collection

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

collect_driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=collect_steps_per_iteration)

# Initial data collection
collect_driver.run()

# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=2).prefetch(3)

iterator = iter(dataset)

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

Train the agent

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

def train_one_iteration():

  # Collect a few steps using collect_policy and save to the replay buffer.
  collect_driver.run()

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience)

  iteration = agent.train_step_counter.numpy()
  print ('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))
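Later sections will ask you to train for a few more iterations before saving or comparing results. If you prefer, the helper above can simply be called in a loop; this is just an optional convenience and uses only objects defined in the cells above:

# Optional: run several training iterations in one go.
for _ in range(5):
  train_one_iteration()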

Video generation

def embed_gif(gif_buffer):
  """Embeds a gif file in the notebook."""
  tag = '<img src="data:image/gif;base64,{0}"/>'.format(base64.b64encode(gif_buffer).decode())
  return IPython.display.HTML(tag)

def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):
  num_episodes = 3
  frames = []
  for _ in range(num_episodes):
    time_step = eval_tf_env.reset()
    frames.append(eval_py_env.render())
    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = eval_tf_env.step(action_step.action)
      frames.append(eval_py_env.render())
  gif_file = io.BytesIO()
  imageio.mimsave(gif_file, frames, format='gif', fps=60)
  IPython.display.display(embed_gif(gif_file.getvalue()))

Generate a video

Check the performance of the policy by generating a video.

print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step:
<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>

gif

Setup Checkpointer and PolicySaver

Now we are ready to use Checkpointer and PolicySaver.

Checkpointer

checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

Policy Saver

policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.

Train one iteration

print('Training one iteration....')
train_one_iteration()
Training one iteration....
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
iteration: 1 loss: 1.0214563608169556

Save to checkpoint

train_checkpointer.save(global_step)
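If you are curious what the Checkpointer actually wrote, you can peek into the checkpoint directory. This is an optional sanity check using standard TensorFlow and Python calls:

# Optional: inspect the files written by the Checkpointer.
print('latest checkpoint:', tf.train.latest_checkpoint(checkpoint_dir))
print(sorted(os.listdir(checkpoint_dir)))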

Restore checkpoint

For this to work, the whole set of objects should be recreated the same way as when the checkpoint was created.

train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
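In practice, restoring after a runtime restart means re-running the construction cells above so that the agent, replay buffer and global step exist again, then building a Checkpointer that points at the same directory. A sketch of that flow, assuming those cells have been re-executed:

# After a restart: rebuild the same object graph, then restore from the same directory.
restored_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,   # must be the directory used when saving
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step)
restored_checkpointer.initialize_or_restore()
print('restored global_step:', tf.compat.v1.train.get_global_step().numpy())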

Also save the policy and export it to a location

tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets

The policy can be loaded without any knowledge of which agent or network was used to create it. This makes deployment of the policy much easier.

Load the saved policy and check how it performs

saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif
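To confirm that the exported SavedModel is self-describing, you can also list its signatures with plain TensorFlow. The exact set of keys depends on the policy, but 'action' is the one used for the TFLite conversion later in this tutorial:

# Optional: inspect the signatures exposed by the exported policy.
print(list(saved_policy.signatures.keys()))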

Export and import

The rest of the colab will help you export/import the checkpointer and policy directories, so that you can continue training at a later point and deploy the model without having to train again.

Now you can go back to "Train one iteration" and train a few more times, so that you can tell the difference later on. Once you start to see slightly better results, continue below.

Create zip file and upload zip file (double-click to see the code)

def create_zip_file(dirname, base_filename):
  return shutil.make_archive(base_filename, 'zip', dirname)

def upload_and_unzip_file_to(dirname):
  if files is None:
    return
  uploaded = files.upload()
  for fn in uploaded.keys():
    print('User uploaded file "{name}" with length {length} bytes'.format(
        name=fn, length=len(uploaded[fn])))
    shutil.rmtree(dirname)
    zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')
    zip_files.extractall(dirname)
    zip_files.close()

Create a zipped file from the checkpoint directory.

train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))

Download the zip file.

if files is not None:
  files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

After training for a while (10-15 iterations), download the checkpoint zip file, go to "Runtime > Restart and run all" to reset the training, and come back to this cell. You can then upload the downloaded zip file and continue training.

upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

Once you have uploaded the checkpoint directory, go back to "Train one iteration" to continue training, or go back to "Generate a video" to check the performance of the loaded policy.

Alternatively, you can save the policy (model) and restore it. Unlike the checkpointer, you cannot continue training with it, but you can still deploy the model. Note that the downloaded file is much smaller than the checkpointer's (see the quick size check a few cells below).

tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
  files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
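To make the size difference mentioned above concrete, you can compare the two archives created in the cells above. This is just a quick optional check:

# Compare the exported checkpoint archive with the exported policy archive.
print('checkpoint zip: {} bytes'.format(os.path.getsize(checkpoint_zip_filename)))
print('policy zip:     {} bytes'.format(os.path.getsize(policy_zip_filename)))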

Upload the downloaded policy directory (exported_policy.zip) and check how the saved policy performs.

upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

SavedModelPyTFEagerPolicy

If you don't want to use the TF policy, you can also use the saved_model directly with the Python environment through py_tf_eager_policy.SavedModelPyTFEagerPolicy.

Note that this only works when eager mode is enabled.
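TF2 executes eagerly by default, so no extra setup is normally needed; if you want to make the requirement explicit, a one-line check is enough:

# SavedModelPyTFEagerPolicy needs eager execution (the TF2 default).
assert tf.executing_eagerly()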

eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())

# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)

gif
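As a variation on the helper above, you can also drive the Python environment with the eager policy by hand. A minimal rollout without rendering, using only objects already defined in this tutorial:

# Minimal manual rollout with the eager policy and the Python environment.
time_step = eval_py_env.reset()
episode_return = 0.0
while not time_step.is_last():
  action_step = eager_py_policy.action(time_step)
  time_step = eval_py_env.step(action_step.action)
  episode_return += time_step.reward
print('episode return:', episode_return)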

Convert the policy to TFLite

See the TensorFlow Lite converter for more details.

converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
  f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

Run inference on the TFLite model

See TensorFlow Lite inference for more details.

import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))

policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
    '0/discount':tf.constant(0.0),
    '0/observation':tf.zeros([1,4]),
    '0/reward':tf.constant(0.0),
    '0/step_type':tf.constant(0)})
{'action': array([0])}
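If you need to know which dtypes and shapes the converted policy expects, for example before wiring it into an application, the standard TFLite Interpreter API can list them. This is an optional inspection step:

# Optional: inspect the expected dtype/shape of each TFLite input tensor.
for detail in interpreter.get_input_details():
  print(detail['name'], detail['dtype'], detail['shape'])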