Menyesuaikan MinDiffModel

pengantar

Dalam kebanyakan kasus, menggunakan MinDiffModel langsung seperti yang dijelaskan dalam "Mengintegrasikan MinDiff dengan MinDiffModel" panduan cukup. Namun, ada kemungkinan Anda memerlukan perilaku yang disesuaikan. Dua alasan utama untuk ini adalah:

  • The keras.Model yang Anda gunakan memiliki perilaku kustom yang Anda ingin mempertahankan.
  • Anda ingin MinDiffModel untuk berperilaku berbeda dari default.

Dalam kedua kasus, Anda akan perlu subclass MinDiffModel untuk mencapai hasil yang diinginkan.

Mempersiapkan

pip install -q --upgrade tensorflow-model-remediation
import tensorflow as tf
tf.get_logger().setLevel('ERROR')  # Avoid TF warnings.
from tensorflow_model_remediation import min_diff
from tensorflow_model_remediation.tools.tutorials_utils import uci as tutorials_utils

Pertama, unduh datanya. Untuk kekompakan, logika persiapan input telah diperhitungkan keluar ke fungsi pembantu seperti yang dijelaskan dalam panduan persiapan masukan . Anda dapat membaca panduan lengkap untuk detail tentang proses ini.

# Original Dataset for training, sampled at 0.3 for reduced runtimes.
train_df = tutorials_utils.get_uci_data(split='train', sample=0.3)
train_ds = tutorials_utils.df_to_dataset(train_df, batch_size=128)

# Dataset needed to train with MinDiff.
train_with_min_diff_ds = (
    tutorials_utils.get_uci_with_min_diff_dataset(split='train', sample=0.3))

Mempertahankan Kustomisasi Model Asli

tf.keras.Model dirancang untuk mudah dikustomisasi melalui subclassing seperti yang dijelaskan di sini . Jika model Anda telah disesuaikan implementasi yang ingin Anda melestarikan ketika menerapkan MinDiff, Anda akan perlu subclass MinDiffModel .

Model Kustom Asli

Untuk melihat bagaimana Anda dapat melestarikan kustomisasi, membuat model kustom yang menetapkan atribut untuk True ketika kustom train_step disebut. Ini bukan penyesuaian yang berguna tetapi akan berfungsi untuk menggambarkan perilaku.

class CustomModel(tf.keras.Model):

  # Customized train_step
  def train_step(self, *args, **kwargs):
    self.used_custom_train_step = True  # Marker that we can check for.
    return super(CustomModel, self).train_step(*args, **kwargs)

Pelatihan model seperti itu akan terlihat sama seperti normal Sequential Model.

model = tutorials_utils.get_uci_model(model_class=CustomModel)  # Use CustomModel.

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(train_ds.take(1), epochs=1, verbose=0)

# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step'))  # True
Model used the custom train_step:
True

Subclassing MinDiffModel

Jika Anda adalah untuk mencoba dan menggunakan MinDiffModel langsung, model tidak akan menggunakan kustom train_step .

model = tutorials_utils.get_uci_model(model_class=CustomModel)
model = min_diff.keras.MinDiffModel(model, min_diff.losses.MMDLoss())

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)

# Model has not used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step'))  # False
Model used the custom train_step:
False

Dalam rangka untuk menggunakan yang benar train_step metode, Anda perlu kelas kustom yang subclass baik MinDiffModel dan CustomModel .

class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):
  pass  # No need for any further implementation.

Pelatihan model ini akan menggunakan train_step dari CustomModel .

model = tutorials_utils.get_uci_model(model_class=CustomModel)

model = CustomMinDiffModel(model, min_diff.losses.MMDLoss())

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)

# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step'))  # True
Model used the custom train_step:
True

Menyesuaikan perilaku default MinDiffModel

Dalam kasus lain, Anda mungkin ingin mengubah perilaku standar tertentu MinDiffModel . Kasus penggunaan yang paling umum dari ini adalah mengubah default membongkar perilaku untuk benar menangani data Anda jika Anda tidak menggunakan pack_min_diff_data .

Saat mengemas data ke dalam format khusus, ini mungkin muncul sebagai berikut.

def _reformat_input(inputs, original_labels):
  min_diff_data = min_diff.keras.utils.unpack_min_diff_data(inputs)
  original_inputs = min_diff.keras.utils.unpack_original_inputs(inputs)

  return ({
      'min_diff_data': min_diff_data,
      'original_inputs': original_inputs}, original_labels)

customized_train_with_min_diff_ds = train_with_min_diff_ds.map(_reformat_input)

The customized_train_with_min_diff_ds dataset kembali batch terdiri dari tupel (x, y) di mana x adalah dict mengandung min_diff_data dan original_inputs dan y adalah original_labels .

for x, _ in customized_train_with_min_diff_ds.take(1):
  print('Type of x:', type(x))  # dict
  print('Keys of x:', x.keys())  # 'min_diff_data', 'original_inputs'
Type of x: <class 'dict'>
Keys of x: dict_keys(['min_diff_data', 'original_inputs'])

Format data ini bukan apa MinDiffModel mengharapkan secara default dan melewati customized_train_with_min_diff_ds untuk itu akan menghasilkan perilaku yang tidak diharapkan. Untuk memperbaikinya, Anda perlu membuat subkelas Anda sendiri.

class CustomUnpackingMinDiffModel(min_diff.keras.MinDiffModel):

  def unpack_min_diff_data(self, inputs):
    return inputs['min_diff_data']

  def unpack_original_inputs(self, inputs):
    return inputs['original_inputs']

Dengan subclass ini, Anda dapat berlatih seperti contoh lainnya.

model = tutorials_utils.get_uci_model()
model = CustomUnpackingMinDiffModel(model, min_diff.losses.MMDLoss())

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(customized_train_with_min_diff_ds, epochs=1)
77/77 [==============================] - 4s 30ms/step - loss: 0.6690 - min_diff_loss: 0.0395

Keterbatasan dari Customized MinDiffModel

Membuat kustom MinDiffModel menyediakan sejumlah besar fleksibilitas untuk kasus penggunaan yang lebih kompleks. Namun, masih ada beberapa kasus tepi yang tidak didukungnya.

Preprocessing atau Validasi input sebelum call

Keterbatasan terbesar untuk subclass dari MinDiffModel adalah bahwa hal itu memerlukan x komponen input data (yaitu elemen pertama atau hanya dalam batch dikembalikan oleh tf.data.Dataset ) untuk melewati tanpa preprocessing atau validasi untuk call .

Hal ini hanya karena min_diff_data dikemas ke dalam x komponen dari input data. Setiap preprocessing atau validasi tidak akan mengharapkan struktur tambahan yang mengandung min_diff_data dan kemungkinan akan pecah.

Jika pra-pemrosesan atau validasi mudah disesuaikan (misalnya, diperhitungkan dalam metodenya sendiri) maka hal ini mudah diatasi dengan menimpanya untuk memastikannya menangani struktur tambahan dengan benar.

Contoh dengan validasi mungkin terlihat seperti ini:

class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):

  # Override so that it correctly handles additional `min_diff_data`.
  def validate_inputs(self, inputs):
    original_inputs = self.unpack_original_inputs(inputs)
    ...  # Optionally also validate min_diff_data
    # Call original validate method with correct inputs
    return super(CustomMinDiffModel, self).validate(original_inputs)

Jika preprocessing atau validasi tidak mudah disesuaikan, kemudian menggunakan MinDiffModel mungkin tidak bekerja untuk Anda dan Anda akan perlu untuk mengintegrasikan MinDiff tanpa itu seperti yang dijelaskan dalam panduan ini .

Tabrakan nama metode

Ada kemungkinan bahwa model Anda memiliki metode yang namanya berbenturan dengan yang diterapkan di MinDiffModel (lihat daftar lengkap metode umum dalam dokumentasi API ).

Ini hanya bermasalah jika ini akan dipanggil pada instance model (bukan secara internal dalam beberapa metode lain). Sementara sangat tidak mungkin, jika Anda berada dalam situasi ini Anda akan harus baik override dan mengubah nama beberapa metode atau, jika tidak memungkinkan, Anda mungkin perlu mempertimbangkan mengintegrasikan MinDiff tanpa MinDiffModel seperti yang dijelaskan dalam panduan ini pada subjek .

Sumber daya tambahan