कॉपीराइट 2021 टीएफ-एजेंट लेखक।
TensorFlow.org पर देखें | Google Colab में चलाएं | GitHub पर स्रोत देखें | नोटबुक डाउनलोड करें |
परिचय
इस उदाहरण से पता चलता है कि कैसे एक प्रशिक्षित करने के लिए सुदृढ़ TF-एजेंटों पुस्तकालय, के लिए इसी तरह का उपयोग कर Cartpole पर्यावरण पर एजेंट DQN ट्यूटोरियल ।
हम आपको प्रशिक्षण, मूल्यांकन और डेटा संग्रह के लिए सुदृढीकरण सीखने (आरएल) पाइपलाइन में सभी घटकों के माध्यम से चलेंगे।
सेट अप
यदि आपने निम्नलिखित निर्भरताएँ स्थापित नहीं की हैं, तो चलाएँ:
sudo apt-get update
sudo apt-get install -y xvfb ffmpeg freeglut3-dev
pip install 'imageio==2.4.0'
pip install pyvirtualdisplay
pip install tf-agents[reverb]
pip install pyglet xvfbwrapper
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import imageio
import IPython
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay
import reverb
import tensorflow as tf
from tf_agents.agents.reinforce import reinforce_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import py_tf_eager_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
हाइपरपैरामीटर
env_name = "CartPole-v0" # @param {type:"string"}
num_iterations = 250 # @param {type:"integer"}
collect_episodes_per_iteration = 2 # @param {type:"integer"}
replay_buffer_capacity = 2000 # @param {type:"integer"}
fc_layer_params = (100,)
learning_rate = 1e-3 # @param {type:"number"}
log_interval = 25 # @param {type:"integer"}
num_eval_episodes = 10 # @param {type:"integer"}
eval_interval = 50 # @param {type:"integer"}
वातावरण
RL में परिवेश उस कार्य या समस्या का प्रतिनिधित्व करते हैं जिसे हम हल करने का प्रयास कर रहे हैं। स्टैंडर्ड वातावरण आसानी से उपयोग कर रहा TF-एजेंटों में बनाया जा सकता suites
। हम अलग-अलग है suites
ऐसे OpenAI जिम, अटारी, डीएम नियंत्रण, आदि जैसे स्रोतों से वातावरण लोड हो रहा है, एक स्ट्रिंग पर्यावरण नाम दिया है।
अब हम OpenAI जिम सूट से CartPole के वातावरण को लोड करते हैं।
env = suite_gym.load(env_name)
हम इस वातावरण को यह देखने के लिए प्रस्तुत कर सकते हैं कि यह कैसा दिखता है। एक फ्री-स्विंगिंग पोल एक गाड़ी से जुड़ा हुआ है। लक्ष्य ध्रुव को ऊपर की ओर रखने के लिए गाड़ी को दाएं या बाएं ले जाना है।
env.reset()
PIL.Image.fromarray(env.render())
time_step = environment.step(action)
बयान लेता action
के माहौल में। TimeStep
टपल लौटे पर्यावरण की अगली अवलोकन और उस कार्य के लिए पुरस्कार शामिल हैं। time_step_spec()
और action_spec()
वातावरण में तरीकों की विशिष्टताओं (प्रकार, आकार, सीमा) लौट time_step
और action
क्रमशः।
print('Observation Spec:')
print(env.time_step_spec().observation)
print('Action Spec:')
print(env.action_spec())
Observation Spec: BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]) Action Spec: BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)
तो, हम देखते हैं कि अवलोकन 4 फ्लोट्स की एक सरणी है: गाड़ी की स्थिति और वेग, और ध्रुव की कोणीय स्थिति और वेग। के बाद से केवल दो कार्यों संभव हो रहे हैं (इस कदम बाएं या दाएं ले जाते हैं), action_spec
एक अदिश जहां 0 का अर्थ "चाल छोड़ दिया" और 1 का अर्थ है "इस कदम सही है।"
time_step = env.reset()
print('Time step:')
print(time_step)
action = np.array(1, dtype=np.int32)
next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step: TimeStep( {'discount': array(1., dtype=float32), 'observation': array([ 0.02284177, -0.04785635, 0.04171623, 0.04942273], dtype=float32), 'reward': array(0., dtype=float32), 'step_type': array(0, dtype=int32)}) Next time step: TimeStep( {'discount': array(1., dtype=float32), 'observation': array([ 0.02188464, 0.14664337, 0.04270469, -0.22981201], dtype=float32), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)})
आमतौर पर हम दो वातावरण बनाते हैं: एक प्रशिक्षण के लिए और दूसरा मूल्यांकन के लिए। अधिकांश वातावरण शुद्ध अजगर में लिखे गए हैं, लेकिन वे आसानी से उपयोग कर रहा TensorFlow में बदला जा सकता TFPyEnvironment
आवरण। मूल पर्यावरण के एपीआई NumPy सरणी का उपयोग करता है, TFPyEnvironment
इन / करने से धर्मान्तरित Tensors
आप के लिए के लिए और अधिक आसानी से बातचीत TensorFlow नीतियों और एजेंटों के साथ।
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
एजेंट
एल्गोरिथ्म है कि हम एक आर एल समस्या को हल करने के लिए उपयोग एक के रूप में प्रस्तुत किया जाता है Agent
। सुदृढ़ एजेंट के अलावा, TF-एजेंटों की एक किस्म के मानक कार्यान्वयन प्रदान करता है Agents
जैसे DQN , DDPG , TD3 , पीपीओ और सैक ।
एक सुदृढ़ एजेंट बनाने के लिए, हम पहले एक जरूरत Actor Network
है कि पर्यावरण से एक अवलोकन दिया कार्रवाई भविष्यवाणी करने के लिए सीख सकते हैं।
हम आसानी से एक बना सकते हैं Actor Network
टिप्पणियों और कार्यों की ऐनक का उपयोग कर। हम नेटवर्क है, जो इस उदाहरण में, है में परतों निर्दिष्ट कर सकते हैं fc_layer_params
की एक टपल करने के लिए तर्क सेट ints
प्रत्येक छिपा परत के आकार (ऊपर Hyperparameters अनुभाग देखें) का प्रतिनिधित्व।
actor_net = actor_distribution_network.ActorDistributionNetwork(
train_env.observation_spec(),
train_env.action_spec(),
fc_layer_params=fc_layer_params)
हम यह भी एक जरूरत optimizer
नेटवर्क हम अभी बनाया प्रशिक्षित करने के लिए, और एक train_step_counter
कितनी बार नेटवर्क अद्यतन किया गया था का ट्रैक रखने के चर।
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)
tf_agent = reinforce_agent.ReinforceAgent(
train_env.time_step_spec(),
train_env.action_spec(),
actor_network=actor_net,
optimizer=optimizer,
normalize_returns=True,
train_step_counter=train_step_counter)
tf_agent.initialize()
नीतियों
TF-एजेंटों में, नीतियों आर एल में नीतियों के मानक धारणा का प्रतिनिधित्व करते हैं: किसी दिए गए time_step
एक कार्रवाई या कार्यों पर एक वितरण का उत्पादन। मुख्य विधि है policy_step = policy.action(time_step)
जहां policy_step
एक नामित टपल है PolicyStep(action, state, info)
। policy_step.action
है action
पर्यावरण के लिए लागू किया जाना है, state
स्टेटफुल (RNN) की नीतियों और के लिए राज्य का प्रतिनिधित्व करता info
इस तरह के कार्यों के लॉग संभावनाओं के रूप में सहायक जानकारी हो सकती है।
एजेंटों में दो नीतियां होती हैं: मुख्य नीति जिसका उपयोग मूल्यांकन/तैनाती (एजेंट.नीति) के लिए किया जाता है और दूसरी नीति जो डेटा संग्रह के लिए उपयोग की जाती है (agent.collect_policy)।
eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy
मेट्रिक्स और मूल्यांकन
किसी पॉलिसी का मूल्यांकन करने के लिए उपयोग की जाने वाली सबसे आम मीट्रिक औसत रिटर्न है। वापसी एक एपिसोड के लिए एक वातावरण में पॉलिसी चलाते समय प्राप्त पुरस्कारों का योग है, और हम आमतौर पर इसे कुछ एपिसोड में औसत करते हैं। हम औसत रिटर्न मीट्रिक की गणना निम्नानुसार कर सकते हैं।
def compute_avg_return(environment, policy, num_episodes=10):
total_return = 0.0
for _ in range(num_episodes):
time_step = environment.reset()
episode_return = 0.0
while not time_step.is_last():
action_step = policy.action(time_step)
time_step = environment.step(action_step.action)
episode_return += time_step.reward
total_return += episode_return
avg_return = total_return / num_episodes
return avg_return.numpy()[0]
# Please also see the metrics module for standard implementations of different
# metrics.
फिर से खेलना बफर
आदेश वातावरण से एकत्र किए गए आंकड़ों का ट्रैक रखने के लिए, हम का उपयोग करेगा गूंज , Deepmind द्वारा एक कुशल, एक्स्टेंसिबल, और आसान से उपयोग पुनरावृत्ति प्रणाली। जब हम प्रक्षेपवक्र एकत्र करते हैं और प्रशिक्षण के दौरान उपभोग किया जाता है तो यह अनुभव डेटा संग्रहीत करता है।
इस पुनरावृत्ति बफर tensors कि संग्रहीत करने के लिए कर रहे हैं, जो का उपयोग कर एजेंट से प्राप्त किया जा सकता का वर्णन चश्मा का उपयोग कर निर्माण किया है tf_agent.collect_data_spec
।
table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(
tf_agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(
replay_buffer_signature)
table = reverb.Table(
table_name,
max_size=replay_buffer_capacity,
sampler=reverb.selectors.Uniform(),
remover=reverb.selectors.Fifo(),
rate_limiter=reverb.rate_limiters.MinSize(1),
signature=replay_buffer_signature)
reverb_server = reverb.Server([table])
replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
tf_agent.collect_data_spec,
table_name=table_name,
sequence_length=None,
local_server=reverb_server)
rb_observer = reverb_utils.ReverbAddEpisodeObserver(
replay_buffer.py_client,
table_name,
replay_buffer_capacity
)
[reverb/cc/platform/tfrecord_checkpointer.cc:150] Initializing TFRecordCheckpointer in /tmp/tmpem6la471. [reverb/cc/platform/tfrecord_checkpointer.cc:385] Loading latest checkpoint from /tmp/tmpem6la471 [reverb/cc/platform/default/server.cc:71] Started replay server on port 19822
सबसे एजेंटों के लिए, collect_data_spec
एक है Trajectory
अवलोकन, कार्रवाई युक्त टपल नाम है, इनाम आदि
आंकड़ा संग्रहण
जैसा कि REINFORCE पूरे एपिसोड से सीखता है, हम दिए गए डेटा संग्रह नीति का उपयोग करके एक एपिसोड एकत्र करने के लिए एक फ़ंक्शन को परिभाषित करते हैं और डेटा (टिप्पणियों, कार्यों, पुरस्कार आदि) को रीप्ले बफर में प्रक्षेपवक्र के रूप में सहेजते हैं। यहां हम अनुभव संग्रह लूप को चलाने के लिए 'पाइड्राइवर' का उपयोग कर रहे हैं। आप हमारे में TF एजेंटों ड्राइवर के बारे में अधिक सीख सकते हैं ड्राइवरों ट्यूटोरियल ।
def collect_episode(environment, policy, num_episodes):
driver = py_driver.PyDriver(
environment,
py_tf_eager_policy.PyTFEagerPolicy(
policy, use_tf_function=True),
[rb_observer],
max_episodes=num_episodes)
initial_time_step = environment.reset()
driver.run(initial_time_step)
एजेंट को प्रशिक्षण
प्रशिक्षण लूप में पर्यावरण से डेटा एकत्र करना और एजेंट के नेटवर्क को अनुकूलित करना दोनों शामिल हैं। रास्ते में, हम कभी-कभी एजेंट की नीति का मूल्यांकन करके देखेंगे कि हम कैसे कर रहे हैं।
निम्नलिखित को चलने में ~3 मिनट का समय लगेगा।
try:
%%time
except:
pass
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
tf_agent.train = common.function(tf_agent.train)
# Reset the train step
tf_agent.train_step_counter.assign(0)
# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]
for _ in range(num_iterations):
# Collect a few episodes using collect_policy and save to the replay buffer.
collect_episode(
train_py_env, tf_agent.collect_policy, collect_episodes_per_iteration)
# Use data from the buffer and update the agent's network.
iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
trajectories, _ = next(iterator)
train_loss = tf_agent.train(experience=trajectories)
replay_buffer.clear()
step = tf_agent.train_step_counter.numpy()
if step % log_interval == 0:
print('step = {0}: loss = {1}'.format(step, train_loss.loss))
if step % eval_interval == 0:
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
print('step = {0}: Average Return = {1}'.format(step, avg_return))
returns.append(avg_return)
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. step = 25: loss = 0.8549901247024536 [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. step = 50: loss = 1.0025296211242676 step = 50: Average Return = 23.200000762939453 [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. step = 75: loss = 1.1377763748168945 step = 100: loss = 1.318871021270752 step = 100: Average Return = 159.89999389648438 step = 125: loss = 1.5053682327270508 [reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC. step = 150: loss = 0.8051948547363281 step = 150: Average Return = 184.89999389648438 step = 175: loss = 0.6872963905334473 step = 200: loss = 2.7238712310791016 step = 200: Average Return = 186.8000030517578 step = 225: loss = 0.7495002746582031 step = 250: loss = -0.3333401679992676 step = 250: Average Return = 200.0
VISUALIZATION
भूखंडों
हम अपने एजेंट के प्रदर्शन को देखने के लिए रिटर्न बनाम ग्लोबल स्टेप्स की साजिश रच सकते हैं। में Cartpole-v0
, पर्यावरण +1 के ईनाम की हर बार कदम के लिए पोल रहता अप देता है, और के बाद से चरणों की अधिकतम संख्या 200 है, अधिकतम संभावित वापसी भी 200 है।
steps = range(0, num_iterations + 1, eval_interval)
plt.plot(steps, returns)
plt.ylabel('Average Return')
plt.xlabel('Step')
plt.ylim(top=250)
(-0.2349997997283939, 250.0)
वीडियो
प्रत्येक चरण पर परिवेश का प्रतिपादन करके किसी एजेंट के प्रदर्शन की कल्पना करना सहायक होता है। ऐसा करने से पहले, आइए पहले इस कोलाब में वीडियो एम्बेड करने के लिए एक फ़ंक्शन बनाएं।
def embed_mp4(filename):
"""Embeds an mp4 file in the notebook."""
video = open(filename,'rb').read()
b64 = base64.b64encode(video)
tag = '''
<video width="640" height="480" controls>
<source src="data:video/mp4;base64,{0}" type="video/mp4">
Your browser does not support the video tag.
</video>'''.format(b64.decode())
return IPython.display.HTML(tag)
निम्नलिखित कोड कुछ एपिसोड के लिए एजेंट की नीति की कल्पना करता है:
num_episodes = 3
video_filename = 'imageio.mp4'
with imageio.get_writer(video_filename, fps=60) as video:
for _ in range(num_episodes):
time_step = eval_env.reset()
video.append_data(eval_py_env.render())
while not time_step.is_last():
action_step = tf_agent.policy.action(time_step)
time_step = eval_env.step(action_step.action)
video.append_data(eval_py_env.render())
embed_mp4(video_filename)
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned. [swscaler @ 0x5604d224f3c0] Warning: data is not aligned! This can lead to a speed loss