Bản quyền 2021 Các tác giả TF-Agents.
Xem trên TensorFlow.org | Chạy trong Google Colab | Xem nguồn trên GitHub | Tải xuống sổ ghi chép |
Giới thiệu
Ví dụ này cho thấy cách để đào tạo một củng cố đại lý đối với môi trường Cartpole sử dụng thư viện TF-Đại lý, tương tự như hướng dẫn DQN .
Chúng tôi sẽ hướng dẫn bạn qua tất cả các thành phần trong đường dẫn Học tập củng cố (RL) để đào tạo, đánh giá và thu thập dữ liệu.
Thành lập
Nếu bạn chưa cài đặt các phần phụ thuộc sau, hãy chạy:
sudo apt-get update
sudo apt-get install -y xvfb ffmpeg freeglut3-dev
pip install 'imageio==2.4.0'
pip install pyvirtualdisplay
pip install tf-agents[reverb]
pip install pyglet xvfbwrapper
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import imageio
import IPython
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay
import reverb
import tensorflow as tf
from tf_agents.agents.reinforce import reinforce_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import py_tf_eager_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
Siêu tham số
env_name = "CartPole-v0" # @param {type:"string"}
num_iterations = 250 # @param {type:"integer"}
collect_episodes_per_iteration = 2 # @param {type:"integer"}
replay_buffer_capacity = 2000 # @param {type:"integer"}
fc_layer_params = (100,)
learning_rate = 1e-3 # @param {type:"number"}
log_interval = 25 # @param {type:"integer"}
num_eval_episodes = 10 # @param {type:"integer"}
eval_interval = 50 # @param {type:"integer"}
Môi trường
Các môi trường trong RL đại diện cho nhiệm vụ hoặc vấn đề mà chúng tôi đang cố gắng giải quyết. Môi trường tiêu chuẩn có thể dễ dàng tạo ra trong TF-Đại lý sử dụng suites
. Chúng tôi có khác nhau suites
cho tải môi trường từ các nguồn như OpenAI phòng tập thể dục, Atari, DM kiểm soát, vv cho một tên môi trường chuỗi.
Bây giờ chúng ta hãy tải môi trường CartPole từ bộ OpenAI Gym.
env = suite_gym.load(env_name)
Chúng ta có thể kết xuất môi trường này để xem nó trông như thế nào. Một cây sào đung đưa tự do được gắn vào một xe đẩy. Mục đích là di chuyển xe sang phải hoặc trái để giữ cho cột hướng lên trên.
env.reset()
PIL.Image.fromarray(env.render())
Các time_step = environment.step(action)
tuyên bố sẽ đưa action
trong môi trường. Các TimeStep
tuple trở chứa quan sát tiếp theo của môi trường và phần thưởng cho hành động đó. Các time_step_spec()
và action_spec()
phương pháp trong môi trường trả lại thông số kỹ thuật (chủng loại, hình dạng, giới hạn) của time_step
và action
tương ứng.
print('Observation Spec:')
print(env.time_step_spec().observation)
print('Action Spec:')
print(env.action_spec())
Observation Spec: BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]) Action Spec: BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)
Vì vậy, chúng ta thấy rằng quan sát là một mảng gồm 4 vật nổi: vị trí và vận tốc của xe đẩy, vị trí góc và vận tốc của cực. Vì chỉ có hai hành động có thể xảy ra (di chuyển sang trái hoặc di chuyển bên phải), các action_spec
là một vô hướng trong đó 0 có nghĩa là "di chuyển sang trái" và 1 có nghĩa là "bước đi đúng đắn."
time_step = env.reset()
print('Time step:')
print(time_step)
action = np.array(1, dtype=np.int32)
next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step: TimeStep( {'discount': array(1., dtype=float32), 'observation': array([ 0.02284177, -0.04785635, 0.04171623, 0.04942273], dtype=float32), 'reward': array(0., dtype=float32), 'step_type': array(0, dtype=int32)}) Next time step: TimeStep( {'discount': array(1., dtype=float32), 'observation': array([ 0.02188464, 0.14664337, 0.04270469, -0.22981201], dtype=float32), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)})
Thông thường chúng tôi tạo ra hai môi trường: một để đào tạo và một để đánh giá. Hầu hết các môi trường được viết bằng python tinh khiết, nhưng họ có thể dễ dàng chuyển đổi sang sử dụng TensorFlow TFPyEnvironment
wrapper. API môi trường ban đầu của sử dụng mảng NumPy, các TFPyEnvironment
chuyển đổi những đến / từ Tensors
để bạn có thể dễ dàng hơn tương tác với các chính sách và các đại lý TensorFlow.
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
Đại lý
Các thuật toán mà chúng tôi sử dụng để giải quyết một vấn đề RL được biểu diễn dưới dạng một Agent
. Ngoài các đại lý củng cố, TF-Đại lý cung cấp triển khai tiêu chuẩn của một loạt các Agents
như DQN , DDPG , TD3 , PPO và SAC .
Để tạo một củng cố Agent, đầu tiên chúng ta cần một Actor Network
đó có thể học để dự đoán hành động đưa ra một quan sát từ môi trường.
Chúng ta có thể dễ dàng tạo một Actor Network
sử dụng các thông số kỹ thuật của các quan sát và hành động. Chúng tôi có thể xác định các lớp trong mạng đó, trong ví dụ này, là fc_layer_params
bộ tranh luận với một tuple của ints
hiện các kích thước của mỗi lớp ẩn (xem phần siêu tham số trên).
actor_net = actor_distribution_network.ActorDistributionNetwork(
train_env.observation_spec(),
train_env.action_spec(),
fc_layer_params=fc_layer_params)
Chúng ta cũng cần một optimizer
để đào tạo mạng, chúng tôi vừa tạo ra, và một train_step_counter
biến để theo dõi bao nhiêu lần mạng đã được cập nhật.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)
tf_agent = reinforce_agent.ReinforceAgent(
train_env.time_step_spec(),
train_env.action_spec(),
actor_network=actor_net,
optimizer=optimizer,
normalize_returns=True,
train_step_counter=train_step_counter)
tf_agent.initialize()
Chính sách
Trong TF-Đại lý, chính sách đại diện cho quan điểm tiêu chuẩn của chính sách trong RL: cho một time_step
tạo ra một hành động hoặc một bản phân phối qua hành động. Phương pháp chính là policy_step = policy.action(time_step)
nơi policy_step
là một tên tuple PolicyStep(action, state, info)
. Các policy_step.action
là action
được áp dụng đối với môi trường, state
đại diện cho nhà nước cho stateful chính sách và (RNN) info
có thể chứa thông tin phụ trợ như xác suất log của hành động.
Đại lý chứa hai chính sách: chính sách chính được sử dụng để đánh giá / triển khai (agent.policy) và một chính sách khác được sử dụng để thu thập dữ liệu (agent.collect_policy).
eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy
Số liệu và Đánh giá
Số liệu phổ biến nhất được sử dụng để đánh giá một chính sách là lợi tức trung bình. Lợi nhuận là tổng số phần thưởng nhận được khi chạy chính sách trong môi trường cho một tập và chúng tôi thường tính trung bình con số này qua một vài tập. Chúng tôi có thể tính toán số liệu lợi nhuận trung bình như sau.
def compute_avg_return(environment, policy, num_episodes=10):
total_return = 0.0
for _ in range(num_episodes):
time_step = environment.reset()
episode_return = 0.0
while not time_step.is_last():
action_step = policy.action(time_step)
time_step = environment.step(action_step.action)
episode_return += time_step.reward
total_return += episode_return
avg_return = total_return / num_episodes
return avg_return.numpy()[0]
# Please also see the metrics module for standard implementations of different
# metrics.
Replay Buffer
Để theo dõi các dữ liệu thu thập từ môi trường, chúng tôi sẽ sử dụng Reverb , một hệ thống phát lại hiệu quả, mở rộng, và dễ dàng sử dụng bởi Deepmind. Nó lưu trữ dữ liệu kinh nghiệm khi chúng tôi thu thập quỹ đạo và được tiêu thụ trong quá trình đào tạo.
Đệm phát lại này được xây dựng bằng kỹ thuật mô tả tensors mà phải được lưu trữ, có thể được lấy từ các đại lý sử dụng tf_agent.collect_data_spec
.
table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(
tf_agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(
replay_buffer_signature)
table = reverb.Table(
table_name,
max_size=replay_buffer_capacity,
sampler=reverb.selectors.Uniform(),
remover=reverb.selectors.Fifo(),
rate_limiter=reverb.rate_limiters.MinSize(1),
signature=replay_buffer_signature)
reverb_server = reverb.Server([table])
replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
tf_agent.collect_data_spec,
table_name=table_name,
sequence_length=None,
local_server=reverb_server)
rb_observer = reverb_utils.ReverbAddEpisodeObserver(
replay_buffer.py_client,
table_name,
replay_buffer_capacity
)
[reverb/cc/platform/tfrecord_checkpointer.cc:150] Initializing TFRecordCheckpointer in /tmp/tmpem6la471. [reverb/cc/platform/tfrecord_checkpointer.cc:385] Loading latest checkpoint from /tmp/tmpem6la471 [reverb/cc/platform/default/server.cc:71] Started replay server on port 19822
Đối với hầu hết các đại lý, các collect_data_spec
là một Trajectory
tên là tuple chứa các quan sát, hành động, khen thưởng, vv
Thu thập dữ liệu
Khi REINFORCE học từ toàn bộ các tập, chúng tôi xác định một hàm để thu thập một tập bằng cách sử dụng chính sách thu thập dữ liệu đã cho và lưu dữ liệu (quan sát, hành động, phần thưởng, v.v.) dưới dạng quỹ đạo trong bộ đệm phát lại. Ở đây chúng tôi đang sử dụng 'PyDriver' để chạy vòng lặp thu thập kinh nghiệm. Bạn có thể tìm hiểu thêm về lái xe TF Đại lý tại chúng tôi lái xe hướng dẫn .
def collect_episode(environment, policy, num_episodes):
driver = py_driver.PyDriver(
environment,
py_tf_eager_policy.PyTFEagerPolicy(
policy, use_tf_function=True),
[rb_observer],
max_episodes=num_episodes)
initial_time_step = environment.reset()
driver.run(initial_time_step)
Đào tạo đại lý
Vòng huấn luyện bao gồm cả việc thu thập dữ liệu từ môi trường và tối ưu hóa mạng của tác nhân. Trên đường đi, chúng tôi sẽ thỉnh thoảng đánh giá chính sách của đại lý để xem chúng tôi đang hoạt động như thế nào.
Phần sau sẽ mất ~ 3 phút để chạy.
try:
%%time
except:
pass
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
tf_agent.train = common.function(tf_agent.train)
# Reset the train step
tf_agent.train_step_counter.assign(0)
# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]
for _ in range(num_iterations):
# Collect a few episodes using collect_policy and save to the replay buffer.
collect_episode(
train_py_env, tf_agent.collect_policy, collect_episodes_per_iteration)
# Use data from the buffer and update the agent's network.
iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
trajectories, _ = next(iterator)
train_loss = tf_agent.train(experience=trajectories)
replay_buffer.clear()
step = tf_agent.train_step_counter.numpy()
if step % log_interval == 0:
print('step = {0}: loss = {1}'.format(step, train_loss.loss))
if step % eval_interval == 0:
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
print('step = {0}: Average Return = {1}'.format(step, avg_return))
returns.append(avg_return)
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. step = 25: loss = 0.8549901247024536 [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. step = 50: loss = 1.0025296211242676 step = 50: Average Return = 23.200000762939453 [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. step = 75: loss = 1.1377763748168945 step = 100: loss = 1.318871021270752 step = 100: Average Return = 159.89999389648438 step = 125: loss = 1.5053682327270508 [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. step = 150: loss = 0.8051948547363281 step = 150: Average Return = 184.89999389648438 step = 175: loss = 0.6872963905334473 step = 200: loss = 2.7238712310791016 step = 200: Average Return = 186.8000030517578 step = 225: loss = 0.7495002746582031 step = 250: loss = -0.3333401679992676 step = 250: Average Return = 200.0
Hình dung
Lô đất
Chúng tôi có thể vẽ biểu đồ trở lại so với các bước toàn cầu để xem hiệu suất của đại lý của chúng tôi. Trong Cartpole-v0
, môi trường cung cấp cho một phần thưởng của 1 cho mỗi bước thời gian ở lại cực lên, và vì số lượng tối đa các bước là 200, có thể tối đa lợi nhuận cũng là 200.
steps = range(0, num_iterations + 1, eval_interval)
plt.plot(steps, returns)
plt.ylabel('Average Return')
plt.xlabel('Step')
plt.ylim(top=250)
(-0.2349997997283939, 250.0)
Video
Sẽ rất hữu ích nếu bạn hình dung hiệu suất của một tác nhân bằng cách hiển thị môi trường ở mỗi bước. Trước khi làm điều đó, trước tiên chúng ta hãy tạo một chức năng để nhúng video vào chuyên mục này.
def embed_mp4(filename):
"""Embeds an mp4 file in the notebook."""
video = open(filename,'rb').read()
b64 = base64.b64encode(video)
tag = '''
<video width="640" height="480" controls>
<source src="data:video/mp4;base64,{0}" type="video/mp4">
Your browser does not support the video tag.
</video>'''.format(b64.decode())
return IPython.display.HTML(tag)
Đoạn mã sau hiển thị chính sách của đại lý trong một vài tập:
num_episodes = 3
video_filename = 'imageio.mp4'
with imageio.get_writer(video_filename, fps=60) as video:
for _ in range(num_episodes):
time_step = eval_env.reset()
video.append_data(eval_py_env.render())
while not time_step.is_last():
action_step = tf_agent.policy.action(time_step)
time_step = eval_env.step(action_step.action)
video.append_data(eval_py_env.render())
embed_mp4(video_filename)
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned. [swscaler @ 0x5604d224f3c0] Warning: data is not aligned! This can lead to a speed loss