ลิขสิทธิ์ 2021 The TF-Agents Authors.
ดูบน TensorFlow.org | ทำงานใน Google Colab | ดูแหล่งที่มาบน GitHub | ดาวน์โหลดโน๊ตบุ๊ค |
บทนำ
tf_agents.utils.common.Checkpointer
เป็นโปรแกรมที่จะบันทึก / โหลดรัฐการฝึกอบรมของรัฐนโยบายและรัฐ replay_buffer / จากการจัดเก็บในท้องถิ่น
tf_agents.policies.policy_saver.PolicySaver
เป็นเครื่องมือในการบันทึก / โหลดเพียงนโยบายและมีน้ำหนักเบากว่า Checkpointer
คุณสามารถใช้ PolicySaver
ในการปรับรูปแบบได้เป็นอย่างดีโดยปราศจากความรู้ของรหัสที่สร้างนโยบายใด ๆ
ในการกวดวิชานี้เราจะใช้ในการฝึกอบรม DQN รูปแบบแล้วใช้ Checkpointer
และ PolicySaver
แสดงให้เห็นว่าเราสามารถจัดเก็บและโหลดรัฐและในรูปแบบวิธีการโต้ตอบ โปรดทราบว่าเราจะใช้ TF2.0 ของเครื่องมือ saved_model ใหม่และรูปแบบสำหรับ PolicySaver
ติดตั้ง
หากคุณไม่ได้ติดตั้งการพึ่งพาต่อไปนี้ ให้เรียกใช้:
sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython
try:
from google.colab import files
except ImportError:
files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()
ตัวแทน DQN
เราจะตั้งค่าตัวแทน DQN เช่นเดียวกับใน colab ก่อนหน้า รายละเอียดจะถูกซ่อนไว้โดยค่าเริ่มต้น เนื่องจากไม่ใช่ส่วนหลักของ colab นี้ แต่คุณสามารถคลิกที่ 'แสดงรหัส' เพื่อดูรายละเอียดได้
ไฮเปอร์พารามิเตอร์
env_name = "CartPole-v1"
collect_steps_per_iteration = 100
replay_buffer_capacity = 100000
fc_layer_params = (100,)
batch_size = 64
learning_rate = 1e-3
log_interval = 5
num_eval_episodes = 10
eval_interval = 1000
สิ่งแวดล้อม
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
ตัวแทน
q_net = q_network.QNetwork(
train_env.observation_spec(),
train_env.action_spec(),
fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
global_step = tf.compat.v1.train.get_or_create_global_step()
agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=global_step)
agent.initialize()
การเก็บรวบรวมข้อมูล
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
data_spec=agent.collect_data_spec,
batch_size=train_env.batch_size,
max_length=replay_buffer_capacity)
collect_driver = dynamic_step_driver.DynamicStepDriver(
train_env,
agent.collect_policy,
observers=[replay_buffer.add_batch],
num_steps=collect_steps_per_iteration)
# Initial data collection
collect_driver.run()
# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
num_parallel_calls=3, sample_batch_size=batch_size,
num_steps=2).prefetch(3)
iterator = iter(dataset)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version. Instructions for updating: Use `as_dataset(..., single_deterministic_pass=False) instead.
อบรมตัวแทน
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)
def train_one_iteration():
# Collect a few steps using collect_policy and save to the replay buffer.
collect_driver.run()
# Sample a batch of data from the buffer and update the agent's network.
experience, unused_info = next(iterator)
train_loss = agent.train(experience)
iteration = agent.train_step_counter.numpy()
print ('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))
การสร้างวิดีโอ
def embed_gif(gif_buffer):
"""Embeds a gif file in the notebook."""
tag = '<img src="data:image/gif;base64,{0}"/>'.format(base64.b64encode(gif_buffer).decode())
return IPython.display.HTML(tag)
def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):
num_episodes = 3
frames = []
for _ in range(num_episodes):
time_step = eval_tf_env.reset()
frames.append(eval_py_env.render())
while not time_step.is_last():
action_step = policy.action(time_step)
time_step = eval_tf_env.step(action_step.action)
frames.append(eval_py_env.render())
gif_file = io.BytesIO()
imageio.mimsave(gif_file, frames, format='gif', fps=60)
IPython.display.display(embed_gif(gif_file.getvalue()))
สร้างวิดีโอ
ตรวจสอบประสิทธิภาพของนโยบายโดยการสร้างวิดีโอ
print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step: <tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>
ตั้งค่า Checkpointer และ PolicySaver
ตอนนี้เราพร้อมที่จะใช้ Checkpointer และ PolicySaver แล้ว
ด่านตรวจ
checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
ckpt_dir=checkpoint_dir,
max_to_keep=1,
agent=agent,
policy=agent.policy,
replay_buffer=replay_buffer,
global_step=global_step
)
ตัวรักษานโยบาย
policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
ฝึกวนซ้ำ
print('Training one iteration....')
train_one_iteration()
Training one iteration.... WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version. Instructions for updating: back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.foldr(fn, elems, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems)) WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version. Instructions for updating: back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.foldr(fn, elems, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems)) iteration: 1 loss: 1.0214563608169556
บันทึกลงด่าน
train_checkpointer.save(global_step)
คืนค่าด่าน
เพื่อให้ใช้งานได้ ควรสร้างออบเจ็กต์ทั้งชุดในลักษณะเดียวกับเมื่อสร้างจุดตรวจ
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
บันทึกนโยบายและส่งออกไปยังตำแหน่งด้วย
tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel. WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/policy/assets /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) INFO:tensorflow:Assets written to: /tmp/policy/assets
สามารถโหลดนโยบายได้โดยไม่ต้องรู้ว่ามีการใช้ตัวแทนหรือเครือข่ายใดในการสร้าง ทำให้การปรับใช้นโยบายง่ายขึ้นมาก
โหลดนโยบายที่บันทึกไว้และตรวจสอบว่ามันทำงานอย่างไร
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
ส่งออกและนำเข้า
ส่วนที่เหลือของ colab จะช่วยคุณส่งออก / นำเข้าจุดตรวจสอบและไดเร็กทอรีนโยบายเพื่อให้คุณสามารถดำเนินการฝึกอบรมต่อในภายหลังและปรับใช้โมเดลโดยไม่ต้องฝึกอีกครั้ง
ตอนนี้คุณสามารถกลับไปที่ 'ฝึกซ้ำหนึ่งซ้ำ' และฝึกอีกสองสามครั้งเพื่อให้คุณสามารถเข้าใจความแตกต่างได้ในภายหลัง เมื่อคุณเริ่มเห็นผลลัพธ์ที่ดีขึ้นเล็กน้อย ให้ดำเนินการต่อด้านล่าง
สร้างไฟล์ zip และอัปโหลดไฟล์ zip (ดับเบิลคลิกเพื่อดูรหัส)
def create_zip_file(dirname, base_filename):
return shutil.make_archive(base_filename, 'zip', dirname)
def upload_and_unzip_file_to(dirname):
if files is None:
return
uploaded = files.upload()
for fn in uploaded.keys():
print('User uploaded file "{name}" with length {length} bytes'.format(
name=fn, length=len(uploaded[fn])))
shutil.rmtree(dirname)
zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')
zip_files.extractall(dirname)
zip_files.close()
สร้างไฟล์ซิปจากไดเร็กทอรีจุดตรวจสอบ
train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))
ดาวน์โหลดไฟล์ซิป
if files is not None:
files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
หลังจากการฝึกมาระยะหนึ่ง (10-15 ครั้ง) ให้ดาวน์โหลดไฟล์ zip ของจุดตรวจ และไปที่ "รันไทม์ > รีสตาร์ทและเรียกใช้ทั้งหมด" เพื่อรีเซ็ตการฝึก แล้วกลับมาที่เซลล์นี้ ตอนนี้คุณสามารถอัปโหลดไฟล์ zip ที่ดาวน์โหลดและดำเนินการฝึกอบรมต่อได้
upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
เมื่อคุณอัปโหลดไดเรกทอรีจุดตรวจสอบแล้ว ให้กลับไปที่ 'ฝึกการวนซ้ำหนึ่งครั้ง' เพื่อดำเนินการฝึกอบรมต่อ หรือกลับไปที่ 'สร้างวิดีโอ' เพื่อตรวจสอบประสิทธิภาพของนโยบายที่โหลดไว้
หรือคุณสามารถบันทึกนโยบาย (รุ่น) และกู้คืนได้ คุณไม่สามารถดำเนินการฝึกอบรมต่อได้ ซึ่งแตกต่างจากจุดตรวจสอบ แต่คุณยังสามารถปรับใช้โมเดลได้ โปรดทราบว่าไฟล์ที่ดาวน์โหลดนั้นเล็กกว่าไฟล์ของจุดตรวจสอบมาก
tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel. WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/policy/assets /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
อัปโหลดไดเร็กทอรีนโยบายที่ดาวน์โหลด (exported_policy.zip) และตรวจสอบว่านโยบายที่บันทึกไว้ทำงานอย่างไร
upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
ที่บันทึกไว้รุ่นPyTFEagerPolicy
หากคุณไม่ต้องการที่จะใช้นโยบาย TF แล้วคุณยังสามารถใช้ saved_model โดยตรงกับ env งูใหญ่ผ่านการใช้ py_tf_eager_policy.SavedModelPyTFEagerPolicy
โปรดทราบว่าใช้งานได้เฉพาะเมื่อเปิดใช้งานโหมดกระตือรือร้น
eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())
# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)
แปลงนโยบายเป็น TFLite
ดู แปลง TensorFlow Lite สำหรับรายละเอียดเพิ่มเติม
converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. 2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges. WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
เรียกใช้การอนุมานบนโมเดล TFLite
ดู TensorFlow Lite อนุมาน สำหรับรายละเอียดเพิ่มเติม
import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))
policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
'0/discount':tf.constant(0.0),
'0/observation':tf.zeros([1,4]),
'0/reward':tf.constant(0.0),
'0/step_type':tf.constant(0)})
{'action': array([0])}