Bản quyền 2021 Các tác giả TF-Agents.
Xem trên TensorFlow.org | Chạy trong Google Colab | Xem nguồn trên GitHub | Tải xuống sổ ghi chép |
Giới thiệu
tf_agents.utils.common.Checkpointer
là một tiện ích để lưu / tải tình trạng đào tạo, tình trạng chính sách, và tiểu bang replay_buffer đến / từ một lưu trữ địa phương.
tf_agents.policies.policy_saver.PolicySaver
là một công cụ để lưu / tải chỉ chủ trương, và nhẹ hơn Checkpointer
. Bạn có thể sử dụng PolicySaver
để triển khai các mô hình như cũng không có bất kỳ kiến thức về mã mà tạo ra chính sách.
Trong hướng dẫn này, chúng tôi sẽ sử dụng DQN để đào tạo người mẫu, sau đó sử dụng Checkpointer
và PolicySaver
để hiển thị như thế nào chúng ta có thể lưu trữ và tải các tiểu bang và mô hình một cách tương tác. Lưu ý rằng chúng tôi sẽ sử dụng dụng cụ saved_model mới TF2.0 và định dạng cho PolicySaver
.
Thành lập
Nếu bạn chưa cài đặt các phần phụ thuộc sau, hãy chạy:
sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython
try:
from google.colab import files
except ImportError:
files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()
Đại lý DQN
Chúng tôi sẽ thiết lập đại lý DQN, giống như trong chuyên mục trước. Các chi tiết được ẩn theo mặc định vì chúng không phải là phần cốt lõi của chuyên mục này, nhưng bạn có thể nhấp vào 'HIỂN THỊ MÃ' để xem chi tiết.
Siêu tham số
env_name = "CartPole-v1"
collect_steps_per_iteration = 100
replay_buffer_capacity = 100000
fc_layer_params = (100,)
batch_size = 64
learning_rate = 1e-3
log_interval = 5
num_eval_episodes = 10
eval_interval = 1000
Môi trường
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
Đại lý
q_net = q_network.QNetwork(
train_env.observation_spec(),
train_env.action_spec(),
fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
global_step = tf.compat.v1.train.get_or_create_global_step()
agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=global_step)
agent.initialize()
Thu thập dữ liệu
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
data_spec=agent.collect_data_spec,
batch_size=train_env.batch_size,
max_length=replay_buffer_capacity)
collect_driver = dynamic_step_driver.DynamicStepDriver(
train_env,
agent.collect_policy,
observers=[replay_buffer.add_batch],
num_steps=collect_steps_per_iteration)
# Initial data collection
collect_driver.run()
# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
num_parallel_calls=3, sample_batch_size=batch_size,
num_steps=2).prefetch(3)
iterator = iter(dataset)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version. Instructions for updating: Use `as_dataset(..., single_deterministic_pass=False) instead.
Đào tạo đại lý
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)
def train_one_iteration():
# Collect a few steps using collect_policy and save to the replay buffer.
collect_driver.run()
# Sample a batch of data from the buffer and update the agent's network.
experience, unused_info = next(iterator)
train_loss = agent.train(experience)
iteration = agent.train_step_counter.numpy()
print ('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))
Tạo video
def embed_gif(gif_buffer):
"""Embeds a gif file in the notebook."""
tag = '<img src="data:image/gif;base64,{0}"/>'.format(base64.b64encode(gif_buffer).decode())
return IPython.display.HTML(tag)
def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):
num_episodes = 3
frames = []
for _ in range(num_episodes):
time_step = eval_tf_env.reset()
frames.append(eval_py_env.render())
while not time_step.is_last():
action_step = policy.action(time_step)
time_step = eval_tf_env.step(action_step.action)
frames.append(eval_py_env.render())
gif_file = io.BytesIO()
imageio.mimsave(gif_file, frames, format='gif', fps=60)
IPython.display.display(embed_gif(gif_file.getvalue()))
Tạo video
Kiểm tra hiệu suất của chính sách bằng cách tạo video.
print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step: <tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>
Thiết lập Checkpointer và PolicySaver
Bây giờ chúng tôi đã sẵn sàng để sử dụng Checkpointer và PolicySaver.
Người kiểm tra
checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
ckpt_dir=checkpoint_dir,
max_to_keep=1,
agent=agent,
policy=agent.policy,
replay_buffer=replay_buffer,
global_step=global_step
)
Trình tiết kiệm chính sách
policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
Đào tạo một lần lặp lại
print('Training one iteration....')
train_one_iteration()
Training one iteration.... WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version. Instructions for updating: back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.foldr(fn, elems, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems)) WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version. Instructions for updating: back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.foldr(fn, elems, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems)) iteration: 1 loss: 1.0214563608169556
Lưu vào trạm kiểm soát
train_checkpointer.save(global_step)
Khôi phục trạm kiểm soát
Để điều này hoạt động, toàn bộ tập hợp các đối tượng phải được tạo lại theo cách giống như khi điểm kiểm tra được tạo.
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
Đồng thời lưu chính sách và xuất sang một vị trí
tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel. WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/policy/assets /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) INFO:tensorflow:Assets written to: /tmp/policy/assets
Có thể tải chính sách mà không cần biết tác nhân hoặc mạng nào đã được sử dụng để tạo chính sách. Điều này làm cho việc triển khai chính sách dễ dàng hơn nhiều.
Tải chính sách đã lưu và kiểm tra cách nó hoạt động
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
Xuất khẩu và nhập khẩu
Phần còn lại của chuyên mục sẽ giúp bạn xuất / nhập bộ kiểm tra và thư mục chính sách để bạn có thể tiếp tục đào tạo sau này và triển khai mô hình mà không cần phải đào tạo lại.
Bây giờ bạn có thể quay lại 'Huấn luyện một lần lặp' và huấn luyện thêm một vài lần nữa để sau này bạn có thể hiểu được sự khác biệt. Khi bạn bắt đầu thấy kết quả tốt hơn một chút, hãy tiếp tục bên dưới.
Tạo tệp zip và tải tệp zip lên (nhấp đúp để xem mã)
def create_zip_file(dirname, base_filename):
return shutil.make_archive(base_filename, 'zip', dirname)
def upload_and_unzip_file_to(dirname):
if files is None:
return
uploaded = files.upload()
for fn in uploaded.keys():
print('User uploaded file "{name}" with length {length} bytes'.format(
name=fn, length=len(uploaded[fn])))
shutil.rmtree(dirname)
zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')
zip_files.extractall(dirname)
zip_files.close()
Tạo một tệp nén từ thư mục trạm kiểm soát.
train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))
Tải xuống tệp zip.
if files is not None:
files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
Sau khi huấn luyện một thời gian (10-15 lần), hãy tải xuống tệp zip điểm kiểm tra và đi tới "Runtime> Restart and run all" để đặt lại quá trình huấn luyện và quay lại ô này. Bây giờ bạn có thể tải lên tệp zip đã tải xuống và tiếp tục đào tạo.
upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
Sau khi bạn đã tải lên thư mục điểm kiểm tra, hãy quay lại 'Đào tạo một lần lặp' để tiếp tục đào tạo hoặc quay lại 'Tạo video' để kiểm tra hiệu suất của chính sách đã tải.
Ngoài ra, bạn có thể lưu chính sách (mô hình) và khôi phục nó. Không giống như người kiểm tra, bạn không thể tiếp tục đào tạo, nhưng bạn vẫn có thể triển khai mô hình. Lưu ý rằng tệp đã tải xuống nhỏ hơn nhiều so với tệp của điểm kiểm tra.
tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel. WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/policy/assets /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
Tải lên thư mục chính sách đã tải xuống (export_policy.zip) và kiểm tra cách chính sách đã lưu hoạt động.
upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
SavedModelPyTFEagerPolicy
Nếu bạn không muốn sử dụng chính sách TF, sau đó bạn cũng có thể sử dụng saved_model trực tiếp với env Python thông qua việc sử dụng các py_tf_eager_policy.SavedModelPyTFEagerPolicy
.
Lưu ý rằng điều này chỉ hoạt động khi chế độ háo hức được bật.
eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())
# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)
Chuyển đổi chính sách thành TFLite
Xem TensorFlow Lite Chuyển đổi để biết thêm chi tiết.
converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. 2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges. WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
Chạy suy luận trên mô hình TFLite
Xem TensorFlow Lite suy luận để biết thêm chi tiết.
import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))
policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
'0/discount':tf.constant(0.0),
'0/observation':tf.zeros([1,4]),
'0/reward':tf.constant(0.0),
'0/step_type':tf.constant(0)})
{'action': array([0])}