Checkpointer dan PolicySaver

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

pengantar

tf_agents.utils.common.Checkpointer adalah utilitas untuk menyimpan / memuat negara pelatihan, kebijakan negara, dan negara replay_buffer ke / dari penyimpanan lokal.

tf_agents.policies.policy_saver.PolicySaver adalah alat untuk menyimpan / beban hanya kebijakan, dan lebih ringan dari Checkpointer . Anda dapat menggunakan PolicySaver untuk menyebarkan model serta tanpa pengetahuan tentang kode yang dibuat kebijakan.

Dalam tutorial ini, kita akan menggunakan DQN untuk melatih model, kemudian gunakan Checkpointer dan PolicySaver untuk menunjukkan bagaimana kita dapat menyimpan dan memuat negara dan model dengan cara yang interaktif. Perhatikan bahwa kita akan menggunakan perkakas saved_model baru TF2.0 dan format untuk PolicySaver .

Mempersiapkan

Jika Anda belum menginstal dependensi berikut, jalankan:

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython

try:
  from google.colab import files
except ImportError:
  files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()

agen DQN

Kami akan menyiapkan agen DQN, seperti di colab sebelumnya. Detailnya disembunyikan secara default karena bukan bagian inti dari colab ini, tetapi Anda dapat mengklik 'TAMPILKAN KODE' untuk melihat detailnya.

Hyperparameter

env_name = "CartPole-v1"

collect_steps_per_iteration = 100
replay_buffer_capacity = 100000

fc_layer_params = (100,)

batch_size = 64
learning_rate = 1e-3
log_interval = 5

num_eval_episodes = 10
eval_interval = 1000

Lingkungan

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Agen

Pengumpulan data

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

Latih agennya

Pembuatan Video

Buat video

Periksa kinerja kebijakan dengan membuat video.

print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step:
<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>

gif

Siapkan Checkpointer dan PolicySaver

Sekarang kita siap untuk menggunakan Checkpointer dan PolicySaver.

pemeriksa

checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

Penghemat Kebijakan

policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.

Latih satu iterasi

print('Training one iteration....')
train_one_iteration()
Training one iteration....
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
iteration: 1 loss: 1.0214563608169556

Simpan ke pos pemeriksaan

train_checkpointer.save(global_step)

Pulihkan pos pemeriksaan

Agar ini berfungsi, seluruh rangkaian objek harus dibuat ulang dengan cara yang sama seperti saat pos pemeriksaan dibuat.

train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

Juga simpan kebijakan dan ekspor ke suatu lokasi

tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets

Kebijakan dapat dimuat tanpa mengetahui agen atau jaringan apa yang digunakan untuk membuatnya. Hal ini membuat penerapan kebijakan jauh lebih mudah.

Muat kebijakan yang disimpan dan periksa kinerjanya

saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

Ekspor dan impor

Kolab lainnya akan membantu Anda mengekspor / mengimpor checkpointer dan direktori kebijakan sehingga Anda dapat melanjutkan pelatihan di lain waktu dan menerapkan model tanpa harus melatih lagi.

Sekarang Anda dapat kembali ke 'Latih satu iterasi' dan berlatih beberapa kali lagi sehingga Anda dapat memahami perbedaannya nanti. Setelah Anda mulai melihat hasil yang sedikit lebih baik, lanjutkan di bawah.

Buat file zip dan unggah file zip (klik dua kali untuk melihat kodenya)

Buat file zip dari direktori pos pemeriksaan.

train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))

Unduh file zipnya.

if files is not None:
  files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

Setelah pelatihan selama beberapa waktu (10-15 kali), unduh file zip pos pemeriksaan, dan buka "Waktu Proses > Mulai ulang dan jalankan semua" untuk mengatur ulang pelatihan, dan kembali ke sel ini. Sekarang Anda dapat mengunggah file zip yang diunduh, dan melanjutkan pelatihan.

upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

Setelah Anda mengunggah direktori pos pemeriksaan, kembali ke 'Latih satu iterasi' untuk melanjutkan pelatihan atau kembali ke 'Buat video' untuk memeriksa kinerja kebijakan yang dimuat.

Atau, Anda dapat menyimpan kebijakan (model) dan memulihkannya. Tidak seperti checkpointer, Anda tidak dapat melanjutkan pelatihan, tetapi Anda masih dapat menerapkan model. Perhatikan bahwa file yang diunduh jauh lebih kecil daripada file checkpointer.

tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
  files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

Unggah direktori kebijakan yang diunduh (exported_policy.zip) dan periksa bagaimana kinerja kebijakan yang disimpan.

upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

SavedModelPyTFEAgerPolicy

Jika Anda tidak ingin menggunakan kebijakan TF, maka Anda juga dapat menggunakan saved_model langsung dengan env Python melalui penggunaan py_tf_eager_policy.SavedModelPyTFEagerPolicy .

Perhatikan bahwa ini hanya berfungsi ketika mode bersemangat diaktifkan.

eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())

# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)

gif

Ubah kebijakan menjadi TFLite

Lihat TensorFlow Lite converter untuk lebih jelasnya.

converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
  f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

Jalankan inferensi pada model TFLite

Lihat TensorFlow Lite Inference untuk lebih jelasnya.

import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))

policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
    '0/discount':tf.constant(0.0),
    '0/observation':tf.zeros([1,4]),
    '0/reward':tf.constant(0.0),
    '0/step_type':tf.constant(0)})
{'action': array([0])}