চেকপয়েন্টার এবং পলিসি সেভার

TensorFlow.org এ দেখুন Google Colab-এ চালান GitHub-এ উৎস দেখুন নোটবুক ডাউনলোড করুন

ভূমিকা

tf_agents.utils.common.Checkpointer সংরক্ষণ / একটি স্থানীয় সংগ্রহস্থল থেকে প্রশিক্ষণ রাষ্ট্র, নীতি রাষ্ট্র, এবং / replay_buffer রাষ্ট্র লোড করতে একটি ইউটিলিটি।

tf_agents.policies.policy_saver.PolicySaver সংরক্ষণ করতে / লোড শুধুমাত্র নীতিটি টুল, এবং চেয়ে অনেক লঘুতর Checkpointer । আপনি ব্যবহার করতে পারেন PolicySaver কোডটি নীতি তৈরি কোন অজ্ঞাতসারে পাশাপাশি মডেল স্থাপন করা।

এই টিউটোরিয়াল, আমরা DQN ব্যবহার একটি মডেল প্রশিক্ষণ, তারপর ব্যবহার করা হবে Checkpointer এবং PolicySaver প্রদর্শনী আমরা কিভাবে সংরক্ষণ এবং একটি ইন্টারেক্টিভ ভাবে রাজ্য এবং মডেল লোড পারেন। মনে রাখবেন আমরা TF2.0 এর নতুন saved_model সাধনী দ্বারা প্রয়োগকরণ এবং জন্য ফর্ম্যাট ব্যবহার করবে PolicySaver

সেটআপ

আপনি যদি নিম্নলিখিত নির্ভরতাগুলি ইনস্টল না করে থাকেন তবে চালান:

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython

try:
  from google.colab import files
except ImportError:
  files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()

DQN এজেন্ট

আমরা আগের কোলাবের মতই DQN এজেন্ট সেট আপ করতে যাচ্ছি। বিশদগুলি ডিফল্টরূপে লুকানো থাকে কারণ সেগুলি এই কোল্যাবের মূল অংশ নয়, তবে আপনি বিস্তারিত দেখতে 'কোড দেখান' এ ক্লিক করতে পারেন৷

হাইপারপ্যারামিটার

env_name = "CartPole-v1"

collect_steps_per_iteration = 100
replay_buffer_capacity = 100000

fc_layer_params = (100,)

batch_size = 64
learning_rate = 1e-3
log_interval = 5

num_eval_episodes = 10
eval_interval = 1000

পরিবেশ

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

প্রতিনিধি

তথ্য সংগ্রহ

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

এজেন্টকে প্রশিক্ষণ দিন

ভিডিও জেনারেশন

একটি ভিডিও তৈরি করুন

একটি ভিডিও তৈরি করে নীতির কার্যকারিতা পরীক্ষা করুন৷

print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step:
<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>

gif

চেকপয়েন্টার এবং পলিসিসেভার সেটআপ করুন

এখন আমরা Checkpointer এবং PolicySaver ব্যবহার করার জন্য প্রস্তুত।

চেকপয়েন্টার

checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

পলিসি সেভার

policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.

এক পুনরাবৃত্তি ট্রেন

print('Training one iteration....')
train_one_iteration()
Training one iteration....
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
iteration: 1 loss: 1.0214563608169556

চেকপয়েন্টে সংরক্ষণ করুন

train_checkpointer.save(global_step)

চেকপয়েন্ট পুনরুদ্ধার করুন

এটি কাজ করার জন্য, চেকপয়েন্ট তৈরি করার সময় বস্তুর পুরো সেটটিকে একইভাবে পুনরায় তৈরি করা উচিত।

train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

এছাড়াও নীতি সংরক্ষণ করুন এবং একটি অবস্থানে রপ্তানি করুন

tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets

পলিসিটি তৈরি করার জন্য কোন এজেন্ট বা নেটওয়ার্ক ব্যবহার করা হয়েছে সে সম্পর্কে কোনো জ্ঞান ছাড়াই লোড করা যেতে পারে। এটি নীতির স্থাপনাকে অনেক সহজ করে তোলে।

সংরক্ষিত নীতি লোড করুন এবং এটি কীভাবে কাজ করে তা পরীক্ষা করুন

saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

রপ্তানি এবং আমদানি

বাকি কোল্যাব আপনাকে চেকপয়েন্টার এবং নীতি নির্দেশিকাগুলি রপ্তানি/আমদানি করতে সাহায্য করবে যাতে আপনি পরবর্তী সময়ে প্রশিক্ষণ চালিয়ে যেতে পারেন এবং পুনরায় প্রশিক্ষণ না নিয়েই মডেলটি স্থাপন করতে পারেন।

এখন আপনি 'ট্রেন ওয়ান আইটারেশন'-এ ফিরে যেতে পারেন এবং আরও কয়েকবার প্রশিক্ষণ দিতে পারেন যাতে আপনি পরে পার্থক্য বুঝতে পারেন। একবার আপনি একটু ভালো ফলাফল দেখতে শুরু করলে, নিচে চালিয়ে যান।

জিপ ফাইল তৈরি করুন এবং জিপ ফাইল আপলোড করুন (কোডটি দেখতে ডাবল ক্লিক করুন)

চেকপয়েন্ট ডিরেক্টরি থেকে একটি জিপ করা ফাইল তৈরি করুন।

train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))

জিপ ফাইলটি ডাউনলোড করুন।

if files is not None:
  files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

কিছু সময়ের জন্য প্রশিক্ষণের পর (10-15 বার), চেকপয়েন্ট জিপ ফাইলটি ডাউনলোড করুন এবং ট্রেনিং রিসেট করতে "Runtime > Restart and run all" এ যান এবং এই সেলে ফিরে আসুন। এখন আপনি ডাউনলোড করা জিপ ফাইল আপলোড করতে পারেন, এবং প্রশিক্ষণ চালিয়ে যেতে পারেন।

upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

একবার আপনি চেকপয়েন্ট ডিরেক্টরি আপলোড করলে, প্রশিক্ষণ চালিয়ে যেতে 'একটি পুনরাবৃত্তির ট্রেন'-এ ফিরে যান বা লোড করা নীতির কার্যকারিতা পরীক্ষা করতে 'একটি ভিডিও তৈরি করুন'-এ ফিরে যান।

বিকল্পভাবে, আপনি নীতি (মডেল) সংরক্ষণ করতে পারেন এবং এটি পুনরুদ্ধার করতে পারেন। চেকপয়েন্টারের বিপরীতে, আপনি প্রশিক্ষণ চালিয়ে যেতে পারবেন না, তবে আপনি এখনও মডেলটি স্থাপন করতে পারেন। উল্লেখ্য যে ডাউনলোড করা ফাইলটি চেকপয়েন্টারের তুলনায় অনেক ছোট।

tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
  files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

ডাউনলোড করা পলিসি ডাইরেক্টরি (exported_policy.zip) আপলোড করুন এবং সেভ করা পলিসি কীভাবে কাজ করে তা দেখুন।

upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

SavedModelPyTFEagerPolicy

আপনি মেমরি নীতি ব্যবহার করতে না চান, তাহলে আপনার কাছে saved_model সরাসরি পাইথন env সঙ্গে ব্যবহারের মাধ্যমে ব্যবহার করতে পারেন py_tf_eager_policy.SavedModelPyTFEagerPolicy

মনে রাখবেন যে এটি শুধুমাত্র তখনই কাজ করে যখন ইজার মোড সক্ষম থাকে৷

eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())

# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)

gif

নীতিকে TFLite-এ রূপান্তর করুন

দেখুন TensorFlow লাইট রূপান্তরকারী আরো বিস্তারিত জানার জন্য।

converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
  f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

TFLite মডেলে অনুমান চালান

দেখুন TensorFlow লাইট ইনফিরেনস আরো বিস্তারিত জানার জন্য।

import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))

policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
    '0/discount':tf.constant(0.0),
    '0/observation':tf.zeros([1,4]),
    '0/reward':tf.constant(0.0),
    '0/step_type':tf.constant(0)})
{'action': array([0])}