حقوق الطبع والنشر 2021 The TF-Agents Authors.
عرض على TensorFlow.org | تشغيل في Google Colab | عرض المصدر على جيثب | تحميل دفتر |
مقدمة
tf_agents.utils.common.Checkpointer
هو أداة لحفظ / تحميل دولة والتدريب، ودولة السياسة، والدولة replay_buffer إلى / من التخزين المحلي.
tf_agents.policies.policy_saver.PolicySaver
هو أداة لحفظ / تحميل فقط السياسة، وأخف من Checkpointer
. يمكنك استخدام PolicySaver
لنشر نموذج، وكذلك دون أي معرفة من التعليمات البرمجية التي تم إنشاؤها السياسة.
في هذا البرنامج التعليمي، سوف نستخدم DQN لتدريب نموذج، ثم استخدم Checkpointer
و PolicySaver
لإظهار كيف يمكننا تخزين وتحميل الدول ونموذج بطريقة تفاعلية. علما بأن سوف نستخدم الأدوات saved_model جديدة TF2.0 وتنسيق PolicySaver
.
يثبت
إذا لم تكن قد قمت بتثبيت التبعيات التالية ، فقم بتشغيل:
sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython
try:
from google.colab import files
except ImportError:
files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()
وكيل DQN
سنقوم بإعداد وكيل DQN ، تمامًا كما في colab السابق. يتم إخفاء التفاصيل افتراضيًا لأنها ليست جزءًا أساسيًا من هذا الكولاب ، ولكن يمكنك النقر فوق "إظهار الكود" للاطلاع على التفاصيل.
Hyperparameters
env_name = "CartPole-v1"
collect_steps_per_iteration = 100
replay_buffer_capacity = 100000
fc_layer_params = (100,)
batch_size = 64
learning_rate = 1e-3
log_interval = 5
num_eval_episodes = 10
eval_interval = 1000
بيئة
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
وكيلات
q_net = q_network.QNetwork(
train_env.observation_spec(),
train_env.action_spec(),
fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
global_step = tf.compat.v1.train.get_or_create_global_step()
agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=global_step)
agent.initialize()
جمع البيانات
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
data_spec=agent.collect_data_spec,
batch_size=train_env.batch_size,
max_length=replay_buffer_capacity)
collect_driver = dynamic_step_driver.DynamicStepDriver(
train_env,
agent.collect_policy,
observers=[replay_buffer.add_batch],
num_steps=collect_steps_per_iteration)
# Initial data collection
collect_driver.run()
# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
num_parallel_calls=3, sample_batch_size=batch_size,
num_steps=2).prefetch(3)
iterator = iter(dataset)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version. Instructions for updating: Use `as_dataset(..., single_deterministic_pass=False) instead.
تدريب الوكيل
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)
def train_one_iteration():
# Collect a few steps using collect_policy and save to the replay buffer.
collect_driver.run()
# Sample a batch of data from the buffer and update the agent's network.
experience, unused_info = next(iterator)
train_loss = agent.train(experience)
iteration = agent.train_step_counter.numpy()
print ('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))
توليد الفيديو
def embed_gif(gif_buffer):
"""Embeds a gif file in the notebook."""
tag = '<img src="data:image/gif;base64,{0}"/>'.format(base64.b64encode(gif_buffer).decode())
return IPython.display.HTML(tag)
def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):
num_episodes = 3
frames = []
for _ in range(num_episodes):
time_step = eval_tf_env.reset()
frames.append(eval_py_env.render())
while not time_step.is_last():
action_step = policy.action(time_step)
time_step = eval_tf_env.step(action_step.action)
frames.append(eval_py_env.render())
gif_file = io.BytesIO()
imageio.mimsave(gif_file, frames, format='gif', fps=60)
IPython.display.display(embed_gif(gif_file.getvalue()))
قم بإنشاء مقطع فيديو
تحقق من أداء السياسة عن طريق إنشاء مقطع فيديو.
print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step: <tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>
إعداد Checkpointer و PolicySaver
الآن نحن جاهزون لاستخدام Checkpointer و PolicySaver.
Checkpointer
checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
ckpt_dir=checkpoint_dir,
max_to_keep=1,
agent=agent,
policy=agent.policy,
replay_buffer=replay_buffer,
global_step=global_step
)
حافظ السياسة
policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
تدريب تكرار واحد
print('Training one iteration....')
train_one_iteration()
Training one iteration.... WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version. Instructions for updating: back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.foldr(fn, elems, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems)) WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version. Instructions for updating: back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.foldr(fn, elems, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems)) iteration: 1 loss: 1.0214563608169556
حفظ في نقطة التفتيش
train_checkpointer.save(global_step)
استعادة نقطة التفتيش
لكي يعمل هذا ، يجب إعادة إنشاء مجموعة الكائنات بالكامل بنفس الطريقة التي تم بها إنشاء نقطة التحقق.
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
أيضا حفظ السياسة والتصدير إلى الموقع
tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel. WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/policy/assets /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) INFO:tensorflow:Assets written to: /tmp/policy/assets
يمكن تحميل السياسة دون معرفة أي وكيل أو شبكة تم استخدامها لإنشائها. هذا يجعل نشر السياسة أسهل بكثير.
قم بتحميل السياسة المحفوظة وتحقق من كيفية أدائها
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
التصدير والاستيراد
سيساعدك باقي colab على تصدير / استيراد دليل الفحص وأدلة السياسة بحيث يمكنك متابعة التدريب في وقت لاحق ونشر النموذج دون الحاجة إلى التدريب مرة أخرى.
يمكنك الآن العودة إلى "تدريب تكرار واحد" والتدريب بضع مرات أخرى حتى تتمكن من فهم الفرق لاحقًا. بمجرد أن تبدأ في رؤية نتائج أفضل قليلاً ، تابع أدناه.
إنشاء ملف مضغوط وتحميل ملف مضغوط (انقر نقرا مزدوجا لرؤية الرمز)
def create_zip_file(dirname, base_filename):
return shutil.make_archive(base_filename, 'zip', dirname)
def upload_and_unzip_file_to(dirname):
if files is None:
return
uploaded = files.upload()
for fn in uploaded.keys():
print('User uploaded file "{name}" with length {length} bytes'.format(
name=fn, length=len(uploaded[fn])))
shutil.rmtree(dirname)
zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')
zip_files.extractall(dirname)
zip_files.close()
قم بإنشاء ملف مضغوط من دليل نقاط التفتيش.
train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))
قم بتنزيل الملف المضغوط.
if files is not None:
files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
بعد التدريب لبعض الوقت (10-15 مرة) ، قم بتنزيل ملف checkpoint zip ، وانتقل إلى "Runtime> Restart and Run all" لإعادة تعيين التدريب ، والعودة إلى هذه الخلية. يمكنك الآن تحميل الملف المضغوط الذي تم تنزيله ومتابعة التدريب.
upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
بمجرد تحميل دليل نقاط التفتيش ، ارجع إلى "تدريب واحد على التكرار" لمواصلة التدريب أو ارجع إلى "إنشاء فيديو" للتحقق من أداء السياسة المحملة.
بدلاً من ذلك ، يمكنك حفظ السياسة (النموذج) واستعادتها. على عكس checkpointer ، لا يمكنك متابعة التدريب ، ولكن لا يزال بإمكانك نشر النموذج. لاحظ أن الملف الذي تم تنزيله أصغر بكثير من ملف checkpointer.
tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel. WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/policy/assets /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
قم بتحميل دليل السياسة الذي تم تنزيله (exported_policy.zip) وتحقق من كيفية أداء السياسة المحفوظة.
upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
SavedModelPyTFEagerPolicy
إذا كنت لا تريد استخدام سياسة TF، ثم يمكنك أيضا استخدام saved_model مباشرة مع الحياة الفطرية بيثون من خلال استخدام py_tf_eager_policy.SavedModelPyTFEagerPolicy
.
لاحظ أن هذا لا يعمل إلا عند تمكين الوضع الدائم.
eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())
# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)
تحويل السياسة إلى TFLite
انظر TensorFlow لايت تحويل لمزيد من التفاصيل.
converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. 2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges. WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
تشغيل الاستدلال على نموذج TFLite
انظر TensorFlow لايت الاستدلال لمزيد من التفاصيل.
import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))
policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
'0/discount':tf.constant(0.0),
'0/observation':tf.zeros([1,4]),
'0/reward':tf.constant(0.0),
'0/step_type':tf.constant(0)})
{'action': array([0])}