Hướng dẫn về Băng cướp nhiều cánh tay với các tính năng trên mỗi cánh tay

Bắt đầu

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Hướng dẫn này là hướng dẫn từng bước về cách sử dụng thư viện TF-Agents cho các vấn đề về kẻ cướp theo ngữ cảnh trong đó các hành động (cánh tay) có các tính năng riêng, chẳng hạn như danh sách các phim được thể hiện theo các tính năng (thể loại, năm phát hành, ...).

Điều kiện tiên quyết

Người ta cho rằng người đọc có phần quen thuộc với các thư viện Bandit của TF-Đại lý, đặc biệt, đã làm việc thông qua các hướng dẫn cho Bandits vào TF-Đại lý trước khi đọc hướng dẫn này.

Băng cướp nhiều vũ trang với các tính năng của cánh tay

Trong cài đặt Kẻ cướp đa vũ trang theo ngữ cảnh "cổ điển", tác nhân nhận được vectơ ngữ cảnh (còn gọi là quan sát) tại mỗi bước thời gian và phải chọn từ một tập hợp hữu hạn các hành động được đánh số (cánh tay) để tối đa hóa phần thưởng tích lũy của nó.

Bây giờ hãy xem xét tình huống mà nhân viên giới thiệu cho người dùng bộ phim tiếp theo để xem. Mỗi khi phải đưa ra quyết định, agent sẽ nhận được một số thông tin về người dùng theo ngữ cảnh (lịch sử xem, sở thích thể loại, v.v.), cũng như danh sách các bộ phim để lựa chọn.

Chúng ta có thể cố gắng để xây dựng vấn đề này bằng việc có thông tin người dùng như bối cảnh và các loại vũ sẽ movie_1, movie_2, ..., movie_K , nhưng phương pháp này có nhiều thiếu sót:

  • Số lượng hành động sẽ phải là tất cả các phim trong hệ thống và việc thêm một phim mới sẽ rất phức tạp.
  • Đặc vụ phải học một mô hình cho mỗi bộ phim.
  • Sự tương đồng giữa các bộ phim không được tính đến.

Thay vì đánh số thứ tự các bộ phim, chúng ta có thể làm điều gì đó trực quan hơn: chúng ta có thể trình bày các bộ phim bằng một tập hợp các tính năng bao gồm thể loại, thời lượng, dàn diễn viên, xếp hạng, năm, v.v. Ưu điểm của cách tiếp cận này là rất đa dạng:

  • Tổng quát trên các bộ phim.
  • Tác nhân chỉ học một chức năng phần thưởng mà mô hình phần thưởng với các tính năng của người dùng và phim.
  • Dễ dàng xóa hoặc giới thiệu phim mới vào hệ thống.

Trong cài đặt mới này, số lượng hành động thậm chí không phải giống nhau trong mỗi bước thời gian.

Băng cướp Per-Arm trong TF-Agent

Bộ TF-Agents Bandit được phát triển để người ta có thể sử dụng nó cho cả hộp đựng trên tay. Có các môi trường trên mỗi nhánh và hầu hết các chính sách và tác nhân có thể hoạt động ở chế độ mỗi nhánh.

Trước khi đi sâu vào viết mã một ví dụ, chúng ta cần nhập những thứ cần thiết.

Cài đặt

pip install tf-agents

Nhập khẩu

import functools
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tf_agents.bandits.agents import lin_ucb_agent
from tf_agents.bandits.environments import stationary_stochastic_per_arm_py_environment as p_a_env
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import tf_py_environment
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

nest = tf.nest

Tham số - Thoải mái chơi xung quanh

# The dimension of the global features.
GLOBAL_DIM = 40 
# The elements of the global feature will be integers in [-GLOBAL_BOUND, GLOBAL_BOUND).
GLOBAL_BOUND = 10 
# The dimension of the per-arm features.
PER_ARM_DIM = 50 
# The elements of the PER-ARM feature will be integers in [-PER_ARM_BOUND, PER_ARM_BOUND).
PER_ARM_BOUND = 6 
# The variance of the Gaussian distribution that generates the rewards.
VARIANCE = 100.0 
# The elements of the linear reward parameter will be integers in [-PARAM_BOUND, PARAM_BOUND).
PARAM_BOUND = 10 

NUM_ACTIONS = 70 
BATCH_SIZE = 20 

# Parameter for linear reward function acting on the
# concatenation of global and per-arm features.
reward_param = list(np.random.randint(
      -PARAM_BOUND, PARAM_BOUND, [GLOBAL_DIM + PER_ARM_DIM]))

Một môi trường đơn giản cho mỗi cánh tay

Môi trường ngẫu nhiên tĩnh, giải thích trong khác hướng dẫn , có một người đồng mỗi cánh tay.

Để khởi tạo môi trường per-arm, người ta phải xác định các hàm tạo ra

  • toàn cầu và tính năng mỗi cánh tay: Các chức năng này không có thông số đầu vào và tạo ra một đơn (toàn cầu hoặc mỗi cánh tay) vector đặc trưng khi gọi.
  • phần thưởng: chức năng này có tham số được nối của một toàn cầu và một vector đặc trưng cho mỗi cánh tay, và tạo ra một phần thưởng. Về cơ bản đây là chức năng mà agent sẽ phải "đoán". Điều đáng chú ý ở đây là trong trường hợp mỗi nhánh, chức năng phần thưởng giống hệt nhau cho mỗi nhánh. Đây là điểm khác biệt cơ bản so với trường hợp cướp cổ điển, trong đó đặc vụ phải ước tính các chức năng phần thưởng cho từng nhánh một cách độc lập.
def global_context_sampling_fn():
  """This function generates a single global observation vector."""
  return np.random.randint(
      -GLOBAL_BOUND, GLOBAL_BOUND, [GLOBAL_DIM]).astype(np.float32)

def per_arm_context_sampling_fn():
  """"This function generates a single per-arm observation vector."""
  return np.random.randint(
      -PER_ARM_BOUND, PER_ARM_BOUND, [PER_ARM_DIM]).astype(np.float32)

def linear_normal_reward_fn(x):
  """This function generates a reward from the concatenated global and per-arm observations."""
  mu = np.dot(x, reward_param)
  return np.random.normal(mu, VARIANCE)

Bây giờ chúng ta đã được trang bị để khởi tạo môi trường của mình.

per_arm_py_env = p_a_env.StationaryStochasticPerArmPyEnvironment(
    global_context_sampling_fn,
    per_arm_context_sampling_fn,
    NUM_ACTIONS,
    linear_normal_reward_fn,
    batch_size=BATCH_SIZE
)
per_arm_tf_env = tf_py_environment.TFPyEnvironment(per_arm_py_env)

Dưới đây chúng ta có thể kiểm tra những gì môi trường này tạo ra.

print('observation spec: ', per_arm_tf_env.observation_spec())
print('\nAn observation: ', per_arm_tf_env.reset().observation)

action = tf.zeros(BATCH_SIZE, dtype=tf.int32)
time_step = per_arm_tf_env.step(action)
print('\nRewards after taking an action: ', time_step.reward)
observation spec:  {'global': TensorSpec(shape=(40,), dtype=tf.float32, name=None), 'per_arm': TensorSpec(shape=(70, 50), dtype=tf.float32, name=None)}

An observation:  {'global': <tf.Tensor: shape=(20, 40), dtype=float32, numpy=
array([[ -9.,  -4.,  -3.,   3.,   5.,  -9.,   6.,  -5.,   4.,  -8.,  -6.,
         -1.,  -7.,  -5.,   7.,   8.,   2.,   5.,  -8.,   0.,  -4.,   4.,
         -1.,  -1.,  -4.,   6.,   8.,   6.,   9.,  -5.,  -1.,  -1.,   2.,
          5.,  -1.,  -8.,   1.,   0.,   0.,   5.],
       [  5.,   7.,   0.,   3.,  -8.,  -7.,  -5.,  -2.,  -8.,  -7.,  -7.,
         -8.,   5.,  -3.,   5.,   4.,  -5.,   2.,  -6., -10.,  -4.,  -2.,
          2.,  -1.,  -1.,   8.,  -7.,   7.,   2.,  -3., -10.,  -1.,  -4.,
         -7.,   3.,   4.,   8.,  -2.,   9.,   5.],
       [ -6.,  -2.,  -1.,  -1.,   6.,  -3.,   4.,   9.,   2.,  -2.,   3.,
          1.,   0.,  -7.,   5.,   5.,  -8.,  -4.,   5.,   7., -10.,  -4.,
          5.,   6.,   8., -10.,   7.,  -1.,  -8.,  -8.,  -6.,  -6.,   4.,
        -10.,  -8.,   3.,   8.,  -9.,  -5.,   8.],
       [ -1.,   8.,  -8.,  -7.,   9.,   2.,  -6.,   8.,   4.,  -2.,   1.,
          8.,  -4.,   3.,   1.,  -6.,  -9.,   3.,  -5.,   7.,  -9.,   6.,
          6.,  -3.,   1.,   2.,  -1.,   3.,   7.,   4.,  -8.,   1.,   1.,
          3.,  -4.,   1.,  -4.,  -5.,  -9.,   4.],
       [ -9.,   8.,   9.,  -8.,   2.,   9.,  -1.,  -9.,  -1.,   9.,  -8.,
         -4.,   1.,   1.,   9.,   6.,  -6., -10.,  -6.,   2.,   6.,  -4.,
         -7.,  -2.,  -7.,  -8.,  -4.,   5.,  -6.,  -1.,   8.,  -3.,  -7.,
          4.,  -9.,  -9.,   6.,  -9.,   6.,  -2.],
       [  0.,  -6.,  -5.,  -8., -10.,   2.,  -4.,   9.,   9.,  -1.,   5.,
         -7.,  -1.,  -3., -10., -10.,   3.,  -2.,  -7.,  -9.,  -4.,  -8.,
         -4.,  -1.,   7.,  -2.,  -4.,  -4.,   9.,   2.,  -2.,  -8.,   6.,
          5.,  -4.,   7.,   0.,   6.,  -3.,   2.],
       [  8.,   5.,   3.,   5.,   9.,   4., -10.,  -5.,  -4.,  -4.,  -5.,
          3.,   5.,  -4.,   9.,  -2.,  -7.,  -6.,  -2.,  -8.,  -7., -10.,
          0.,  -2.,   3.,   1., -10.,  -8.,   3.,   9.,  -5.,  -6.,   1.,
         -7.,  -1.,   3.,  -7.,  -2.,   1.,  -1.],
       [  3.,   9.,   8.,   6.,  -2.,   9.,   9.,   7.,   0.,   5.,  -5.,
          6.,   9.,   3.,   2.,   9.,   4.,  -1.,  -3.,   3.,  -1.,  -4.,
         -9.,  -1.,  -3.,   8.,   0.,   4.,  -1.,   4.,  -3.,   4.,  -5.,
         -3.,  -6.,  -4.,   7.,  -9.,  -7.,  -1.],
       [  5.,  -1.,   9.,  -5.,   8.,   7.,  -7.,  -5.,   0.,  -4.,  -5.,
          6.,  -3.,  -1.,   7.,   3.,  -7.,  -9.,   6.,   4.,   9.,   6.,
         -3.,   3.,  -2.,  -6.,  -4.,  -7.,  -5.,  -6.,  -2.,  -1.,  -9.,
         -4.,  -9.,  -2.,  -7.,  -6.,  -3.,   6.],
       [ -7.,   1.,  -8.,   1.,  -8.,  -9.,   5.,   1.,  -4.,  -2.,  -5.,
          3.,  -1.,  -4.,  -4.,   5.,   0., -10.,  -4.,  -1.,  -5.,   3.,
          8.,  -5.,  -4., -10.,  -8.,  -6., -10.,  -1.,  -6.,   1.,   7.,
          8.,   6.,  -2.,  -4.,  -9.,   7.,  -1.],
       [ -2.,   3.,   8.,  -5.,   0.,   5.,   8.,  -5.,   6.,  -8.,   5.,
          8.,  -5.,  -5.,  -5., -10.,   4.,   8.,  -4.,  -7.,   4.,  -6.,
         -9.,  -8.,  -5.,   4.,  -1.,  -2.,  -7., -10.,  -6.,  -8.,  -6.,
          3.,   1.,   6.,   9.,   6.,  -8.,  -3.],
       [  9.,  -6.,  -2., -10.,   2.,  -8.,   8.,  -7.,  -5.,   8., -10.,
          4.,  -5.,   9.,   7.,   9.,  -2.,  -9.,  -5.,  -2.,   9.,   0.,
         -6.,   2.,   4.,   6.,  -7.,  -4.,  -5.,  -7.,  -8.,  -8.,   8.,
         -7.,  -1.,  -5.,   0.,  -7.,   7.,  -6.],
       [ -1.,  -3.,   1.,   8.,   4.,   7.,  -1.,  -8.,  -4.,   6.,   9.,
          5., -10.,   4.,  -4.,   5.,  -2.,   0.,   3.,   4.,   3.,  -5.,
         -2.,   7.,   4.,  -4.,  -9.,   9.,  -6.,  -5.,  -8.,   4., -10.,
         -6.,   3.,   0.,   6., -10.,   4.,   3.],
       [  8.,   8.,  -5.,   0.,  -7.,   5.,  -6.,  -8.,   2.,  -3.,  -5.,
          5.,   0.,   6., -10.,   3.,  -4.,   1.,  -8.,  -9.,   6.,  -5.,
          5., -10.,   1.,   0.,   3.,   5.,   2.,  -9.,  -6.,   9.,   7.,
          9., -10.,   4.,  -4., -10.,  -5.,   1.],
       [  8.,   3.,  -5.,  -2.,  -8.,  -6.,   6.,  -7.,   8.,   1.,  -8.,
          0.,  -2.,   3.,  -6.,   0., -10.,   6.,  -8.,  -2.,  -5.,   4.,
         -1.,  -9.,  -7.,   3.,  -1.,  -4.,  -1., -10.,  -3.,  -7.,  -3.,
          4.,  -7.,  -6.,  -1.,   9.,  -3.,   2.],
       [  8.,   7.,   6.,  -5.,  -3.,   0.,   1.,  -2.,   0.,  -3.,   9.,
         -8.,   5.,   1.,   1.,   1.,  -5.,   4.,  -4.,   0.,  -4.,  -3.,
          7., -10.,   3.,   6.,   4.,   5.,   2.,  -7.,   0.,  -3.,  -5.,
          2.,  -6.,   4.,   5.,   8.,  -1.,  -3.],
       [  8.,  -9.,  -4.,   8.,  -2.,   9.,   5.,   5.,  -3.,  -4.,   0.,
         -5.,   5.,  -2., -10.,  -4.,  -3.,   5.,   8.,   6.,  -2.,  -2.,
         -1.,  -8.,  -5.,  -9.,   1.,  -1.,   5.,   6.,   4.,   9.,  -5.,
          6.,  -2.,   7.,  -7.,  -9.,   4.,   2.],
       [  2.,   4.,   6.,   2.,   6.,  -6.,  -2.,   5.,   8.,   1.,   3.,
          8.,   6.,   9.,  -3.,  -1.,   4.,   7.,  -5.,   7.,   0., -10.,
          9.,  -6.,  -4.,  -7.,   1.,  -2.,  -2.,   3.,  -1.,   2.,   5.,
          8.,   4.,  -9.,   1.,  -4.,   9.,   6.],
       [ -8.,  -5.,   9.,   3.,   9., -10.,  -8.,   3.,  -8.,   0.,  -4.,
         -8.,  -3.,  -4.,  -3.,   0.,   8.,   3., -10.,   7.,   7.,  -3.,
          8.,   4.,  -3.,   9.,   3.,   7.,   2.,   7.,  -8.,  -3.,  -4.,
         -7.,   3.,  -9., -10.,   2.,   5.,   7.],
       [  5.,  -7.,  -8.,   6.,  -8.,   1.,  -8.,   4.,   2.,   6.,  -6.,
         -5.,   4.,  -1.,   3.,  -8.,  -3.,   6.,   5.,  -5.,   1.,  -7.,
          8., -10.,   8.,   1.,   3.,   7.,   2.,   2.,  -1.,   1.,  -3.,
          7.,   1.,   6.,  -6.,   0.,  -9.,   6.]], dtype=float32)>, 'per_arm': <tf.Tensor: shape=(20, 70, 50), dtype=float32, numpy=
array([[[ 5., -6.,  4., ..., -3.,  3.,  4.],
        [-5., -6., -4., ...,  3.,  4., -4.],
        [ 1., -1.,  5., ..., -1., -3.,  1.],
        ...,
        [ 3.,  3., -5., ...,  4.,  4.,  0.],
        [ 5.,  1., -3., ..., -2., -2., -3.],
        [-6.,  4.,  2., ...,  4.,  5., -5.]],

       [[-5., -3.,  1., ..., -2., -1.,  1.],
        [ 1.,  4., -1., ..., -1., -4., -4.],
        [ 4., -6.,  5., ...,  2., -2.,  4.],
        ...,
        [ 0.,  4., -4., ..., -1., -3.,  1.],
        [ 3.,  4.,  5., ..., -5., -2., -2.],
        [ 0.,  4., -3., ...,  5.,  1.,  3.]],

       [[-2., -6., -6., ..., -6.,  1., -5.],
        [ 4.,  5.,  5., ...,  1.,  4., -4.],
        [ 0.,  0., -3., ..., -5.,  0., -2.],
        ...,
        [-3., -1.,  4., ...,  5., -2.,  5.],
        [-3., -6., -2., ...,  3.,  1., -5.],
        [ 5., -3., -5., ..., -4.,  4., -5.]],

       ...,

       [[ 4.,  3.,  0., ...,  1., -6.,  4.],
        [-5., -3.,  5., ...,  0., -1., -5.],
        [ 0.,  4.,  3., ..., -2.,  1., -3.],
        ...,
        [-5., -2., -5., ..., -5., -5., -2.],
        [-2.,  5.,  4., ..., -2., -2.,  2.],
        [-1., -4.,  4., ..., -5.,  2., -3.]],

       [[-1., -4.,  4., ..., -3., -5.,  4.],
        [-4., -6., -2., ..., -1., -6.,  0.],
        [ 0.,  0.,  5., ...,  4., -4.,  0.],
        ...,
        [ 2.,  3.,  5., ..., -6., -5.,  5.],
        [-5., -5.,  2., ...,  0.,  4., -2.],
        [ 4., -5., -4., ..., -5., -5., -1.]],

       [[ 3.,  0.,  2., ...,  2.,  1., -3.],
        [-5., -4.,  3., ..., -6.,  0., -2.],
        [-4., -5.,  3., ..., -6., -3.,  0.],
        ...,
        [-6., -6.,  4., ..., -1., -5., -2.],
        [-4.,  3., -1., ...,  1.,  4.,  4.],
        [ 5.,  2.,  2., ..., -3.,  1., -4.]]], dtype=float32)>}

Rewards after taking an action:  tf.Tensor(
[-130.17787    344.98013    371.39893     75.433975   396.35742
 -176.46881     56.62174   -158.03278    491.3239    -156.10696
   -1.0527252 -264.42285     22.356699  -395.89832    125.951546
  142.99467   -322.3012     -24.547596  -159.47539    -44.123775 ], shape=(20,), dtype=float32)

Chúng tôi thấy rằng thông số kỹ thuật quan sát là một từ điển có hai yếu tố:

  • Một với phím 'global' : đây là phần bối cảnh toàn cầu, với hình dạng phù hợp với các thông số GLOBAL_DIM .
  • Một với phím 'per_arm' : đây là bối cảnh mỗi cánh tay, và hình dạng của nó là [NUM_ACTIONS, PER_ARM_DIM] . Phần này là trình giữ chỗ cho các tính năng của nhánh cho mọi nhánh trong một bước thời gian.

Đại lý LinUCB

Tác nhân LinUCB thực hiện thuật toán Bandit có tên giống hệt, thuật toán này ước tính tham số của hàm phần thưởng tuyến tính trong khi cũng duy trì một ellipsoid tin cậy xung quanh ước tính. Tác nhân chọn nhánh có phần thưởng dự kiến ​​ước tính cao nhất, giả sử rằng tham số nằm trong ellipsoid tin cậy.

Tạo một tác nhân yêu cầu kiến ​​thức về quan sát và đặc tả hành động. Khi xác định các đại lý, chúng tôi thiết lập các tham số boolean accepts_per_arm_features thiết lập để True .

observation_spec = per_arm_tf_env.observation_spec()
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    dtype=tf.int32, shape=(), minimum=0, maximum=NUM_ACTIONS - 1)

agent = lin_ucb_agent.LinearUCBAgent(time_step_spec=time_step_spec,
                                     action_spec=action_spec,
                                     accepts_per_arm_features=True)

Luồng dữ liệu đào tạo

Phần này giới thiệu sơ lược về cơ chế hoạt động của các tính năng trên mỗi cánh tay từ chính sách đến đào tạo. Vui lòng chuyển sang phần tiếp theo (Xác định chỉ số tiếc nuối) và quay lại đây sau nếu quan tâm.

Đầu tiên, chúng ta hãy xem đặc tả dữ liệu trong tác nhân. Các training_data_spec thuộc tính của quy định cụ thể những gì đại diện các yếu tố và cấu trúc dữ liệu huấn luyện nên có.

print('training data spec: ', agent.training_data_spec)
training data spec:  Trajectory(
{'action': BoundedTensorSpec(shape=(), dtype=tf.int32, name=None, minimum=array(0, dtype=int32), maximum=array(69, dtype=int32)),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': {'global': TensorSpec(shape=(40,), dtype=tf.float32, name=None)},
 'policy_info': PerArmPolicyInfo(log_probability=(), predicted_rewards_mean=(), multiobjective_scalarized_predicted_rewards_mean=(), predicted_rewards_optimistic=(), predicted_rewards_sampled=(), bandit_policy_type=(), chosen_arm_features=TensorSpec(shape=(50,), dtype=tf.float32, name=None)),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type')})

Nếu chúng ta có một cái nhìn gần gũi hơn với các observation một phần của spec, chúng ta thấy rằng nó không có các tính năng mỗi cánh tay!

print('observation spec in training: ', agent.training_data_spec.observation)
observation spec in training:  {'global': TensorSpec(shape=(40,), dtype=tf.float32, name=None)}

Điều gì đã xảy ra với các tính năng trên mỗi nhánh? Để trả lời câu hỏi này, trước hết chúng ta lưu ý rằng khi xe lửa đại lý LinUCB, nó không cần các tính năng mỗi cánh tay của tất cả các vũ khí, nó chỉ cần những cánh tay chọn. Do đó, nó làm cho tinh thần để thả các tensor hình dạng [BATCH_SIZE, NUM_ACTIONS, PER_ARM_DIM] , vì nó là rất lãng phí, đặc biệt là nếu số lượng của các hành động là lớn.

Nhưng vẫn còn, các tính năng trên mỗi cánh tay của cánh tay đã chọn phải ở đâu đó! Để kết thúc này, chúng tôi đảm bảo rằng các cửa hàng chính sách LinUCB các tính năng của cánh tay chọn trong policy_info lĩnh vực dữ liệu huấn luyện:

print('chosen arm features: ', agent.training_data_spec.policy_info.chosen_arm_features)
chosen arm features:  TensorSpec(shape=(50,), dtype=tf.float32, name=None)

Chúng ta thấy từ hình dạng rằng chosen_arm_features lĩnh vực chỉ có các vector đặc trưng của một cánh tay, và đó sẽ là cánh tay chọn. Lưu ý rằng policy_info , và cùng với nó là chosen_arm_features , là một phần của dữ liệu huấn luyện, như chúng ta thấy từ kiểm tra việc đào tạo dữ liệu spec, và do đó nó có sẵn tại thời gian đào tạo.

Xác định chỉ số hối tiếc

Trước khi bắt đầu vòng lặp đào tạo, chúng tôi xác định một số hàm tiện ích giúp tính toán sự hối tiếc của tác nhân của chúng tôi. Các chức năng này giúp xác định phần thưởng dự kiến ​​tối ưu dựa trên tập hợp các hành động (được cung cấp bởi các tính năng nhánh của chúng) và tham số tuyến tính bị ẩn khỏi tác nhân.

def _all_rewards(observation, hidden_param):
  """Outputs rewards for all actions, given an observation."""
  hidden_param = tf.cast(hidden_param, dtype=tf.float32)
  global_obs = observation['global']
  per_arm_obs = observation['per_arm']
  num_actions = tf.shape(per_arm_obs)[1]
  tiled_global = tf.tile(
      tf.expand_dims(global_obs, axis=1), [1, num_actions, 1])
  concatenated = tf.concat([tiled_global, per_arm_obs], axis=-1)
  rewards = tf.linalg.matvec(concatenated, hidden_param)
  return rewards

def optimal_reward(observation):
  """Outputs the maximum expected reward for every element in the batch."""
  return tf.reduce_max(_all_rewards(observation, reward_param), axis=1)

regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward)

Bây giờ tất cả chúng ta đã sẵn sàng để bắt đầu vòng đào tạo tên cướp của mình. Trình điều khiển bên dưới sẽ đảm nhận việc chọn các hành động bằng cách sử dụng chính sách, lưu trữ phần thưởng của các hành động đã chọn trong bộ đệm phát lại, tính toán chỉ số hối tiếc được xác định trước và thực hiện bước đào tạo của tác nhân.

num_iterations = 20 # @param
steps_per_loop = 1 # @param

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.policy.trajectory_spec,
    batch_size=BATCH_SIZE,
    max_length=steps_per_loop)

observers = [replay_buffer.add_batch, regret_metric]

driver = dynamic_step_driver.DynamicStepDriver(
    env=per_arm_tf_env,
    policy=agent.collect_policy,
    num_steps=steps_per_loop * BATCH_SIZE,
    observers=observers)

regret_values = []

for _ in range(num_iterations):
  driver.run()
  loss_info = agent.train(replay_buffer.gather_all())
  replay_buffer.clear()
  regret_values.append(regret_metric.result())
WARNING:tensorflow:From /tmp/ipykernel_12052/1190294793.py:21: ReplayBuffer.gather_all (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=True)` instead.

Bây giờ chúng ta hãy xem kết quả. Nếu chúng tôi đã làm đúng mọi thứ, tác nhân có thể ước tính tốt hàm phần thưởng tuyến tính và do đó, chính sách có thể chọn các hành động có phần thưởng dự kiến ​​gần với phần thưởng tối ưu. Điều này được chỉ ra bởi chỉ số hối tiếc đã xác định ở trên của chúng tôi, chỉ số này đi xuống và tiến gần đến con số không.

plt.plot(regret_values)
plt.title('Regret of LinUCB on the Linear per-arm environment')
plt.xlabel('Number of Iterations')
_ = plt.ylabel('Average Regret')

png

Cái gì tiếp theo?

Ví dụ trên được thực hiện trong codebase chúng tôi, nơi bạn có thể chọn từ các đại lý khác là tốt, bao gồm cả các đại lý epsilon-tham lam thần kinh .