pengantar
Dalam kebanyakan kasus, menggunakan MinDiffModel
langsung seperti yang dijelaskan dalam "Mengintegrasikan MinDiff dengan MinDiffModel" panduan cukup. Namun, ada kemungkinan Anda memerlukan perilaku yang disesuaikan. Dua alasan utama untuk ini adalah:
- The
keras.Model
yang Anda gunakan memiliki perilaku kustom yang Anda ingin mempertahankan. - Anda ingin
MinDiffModel
untuk berperilaku berbeda dari default.
Dalam kedua kasus, Anda akan perlu subclass MinDiffModel
untuk mencapai hasil yang diinginkan.
Mempersiapkan
pip install -q --upgrade tensorflow-model-remediation
import tensorflow as tf
tf.get_logger().setLevel('ERROR') # Avoid TF warnings.
from tensorflow_model_remediation import min_diff
from tensorflow_model_remediation.tools.tutorials_utils import uci as tutorials_utils
Pertama, unduh datanya. Untuk kekompakan, logika persiapan input telah diperhitungkan keluar ke fungsi pembantu seperti yang dijelaskan dalam panduan persiapan masukan . Anda dapat membaca panduan lengkap untuk detail tentang proses ini.
# Original Dataset for training, sampled at 0.3 for reduced runtimes.
train_df = tutorials_utils.get_uci_data(split='train', sample=0.3)
train_ds = tutorials_utils.df_to_dataset(train_df, batch_size=128)
# Dataset needed to train with MinDiff.
train_with_min_diff_ds = (
tutorials_utils.get_uci_with_min_diff_dataset(split='train', sample=0.3))
Mempertahankan Kustomisasi Model Asli
tf.keras.Model
dirancang untuk mudah dikustomisasi melalui subclassing seperti yang dijelaskan di sini . Jika model Anda telah disesuaikan implementasi yang ingin Anda melestarikan ketika menerapkan MinDiff, Anda akan perlu subclass MinDiffModel
.
Model Kustom Asli
Untuk melihat bagaimana Anda dapat melestarikan kustomisasi, membuat model kustom yang menetapkan atribut untuk True
ketika kustom train_step
disebut. Ini bukan penyesuaian yang berguna tetapi akan berfungsi untuk menggambarkan perilaku.
class CustomModel(tf.keras.Model):
# Customized train_step
def train_step(self, *args, **kwargs):
self.used_custom_train_step = True # Marker that we can check for.
return super(CustomModel, self).train_step(*args, **kwargs)
Pelatihan model seperti itu akan terlihat sama seperti normal Sequential
Model.
model = tutorials_utils.get_uci_model(model_class=CustomModel) # Use CustomModel.
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_ds.take(1), epochs=1, verbose=0)
# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step')) # True
Model used the custom train_step: True
Subclassing MinDiffModel
Jika Anda adalah untuk mencoba dan menggunakan MinDiffModel
langsung, model tidak akan menggunakan kustom train_step
.
model = tutorials_utils.get_uci_model(model_class=CustomModel)
model = min_diff.keras.MinDiffModel(model, min_diff.losses.MMDLoss())
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)
# Model has not used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step')) # False
Model used the custom train_step: False
Dalam rangka untuk menggunakan yang benar train_step
metode, Anda perlu kelas kustom yang subclass baik MinDiffModel
dan CustomModel
.
class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):
pass # No need for any further implementation.
Pelatihan model ini akan menggunakan train_step
dari CustomModel
.
model = tutorials_utils.get_uci_model(model_class=CustomModel)
model = CustomMinDiffModel(model, min_diff.losses.MMDLoss())
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)
# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step')) # True
Model used the custom train_step: True
Menyesuaikan perilaku default MinDiffModel
Dalam kasus lain, Anda mungkin ingin mengubah perilaku standar tertentu MinDiffModel
. Kasus penggunaan yang paling umum dari ini adalah mengubah default membongkar perilaku untuk benar menangani data Anda jika Anda tidak menggunakan pack_min_diff_data
.
Saat mengemas data ke dalam format khusus, ini mungkin muncul sebagai berikut.
def _reformat_input(inputs, original_labels):
min_diff_data = min_diff.keras.utils.unpack_min_diff_data(inputs)
original_inputs = min_diff.keras.utils.unpack_original_inputs(inputs)
return ({
'min_diff_data': min_diff_data,
'original_inputs': original_inputs}, original_labels)
customized_train_with_min_diff_ds = train_with_min_diff_ds.map(_reformat_input)
The customized_train_with_min_diff_ds
dataset kembali batch terdiri dari tupel (x, y)
di mana x
adalah dict mengandung min_diff_data
dan original_inputs
dan y
adalah original_labels
.
for x, _ in customized_train_with_min_diff_ds.take(1):
print('Type of x:', type(x)) # dict
print('Keys of x:', x.keys()) # 'min_diff_data', 'original_inputs'
Type of x: <class 'dict'> Keys of x: dict_keys(['min_diff_data', 'original_inputs'])
Format data ini bukan apa MinDiffModel
mengharapkan secara default dan melewati customized_train_with_min_diff_ds
untuk itu akan menghasilkan perilaku yang tidak diharapkan. Untuk memperbaikinya, Anda perlu membuat subkelas Anda sendiri.
class CustomUnpackingMinDiffModel(min_diff.keras.MinDiffModel):
def unpack_min_diff_data(self, inputs):
return inputs['min_diff_data']
def unpack_original_inputs(self, inputs):
return inputs['original_inputs']
Dengan subclass ini, Anda dapat berlatih seperti contoh lainnya.
model = tutorials_utils.get_uci_model()
model = CustomUnpackingMinDiffModel(model, min_diff.losses.MMDLoss())
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(customized_train_with_min_diff_ds, epochs=1)
77/77 [==============================] - 4s 30ms/step - loss: 0.6690 - min_diff_loss: 0.0395
Keterbatasan dari Customized MinDiffModel
Membuat kustom MinDiffModel
menyediakan sejumlah besar fleksibilitas untuk kasus penggunaan yang lebih kompleks. Namun, masih ada beberapa kasus tepi yang tidak didukungnya.
Preprocessing atau Validasi input sebelum call
Keterbatasan terbesar untuk subclass dari MinDiffModel
adalah bahwa hal itu memerlukan x
komponen input data (yaitu elemen pertama atau hanya dalam batch dikembalikan oleh tf.data.Dataset
) untuk melewati tanpa preprocessing atau validasi untuk call
.
Hal ini hanya karena min_diff_data
dikemas ke dalam x
komponen dari input data. Setiap preprocessing atau validasi tidak akan mengharapkan struktur tambahan yang mengandung min_diff_data
dan kemungkinan akan pecah.
Jika pra-pemrosesan atau validasi mudah disesuaikan (misalnya, diperhitungkan dalam metodenya sendiri) maka hal ini mudah diatasi dengan menimpanya untuk memastikannya menangani struktur tambahan dengan benar.
Contoh dengan validasi mungkin terlihat seperti ini:
class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):
# Override so that it correctly handles additional `min_diff_data`.
def validate_inputs(self, inputs):
original_inputs = self.unpack_original_inputs(inputs)
... # Optionally also validate min_diff_data
# Call original validate method with correct inputs
return super(CustomMinDiffModel, self).validate(original_inputs)
Jika preprocessing atau validasi tidak mudah disesuaikan, kemudian menggunakan MinDiffModel
mungkin tidak bekerja untuk Anda dan Anda akan perlu untuk mengintegrasikan MinDiff tanpa itu seperti yang dijelaskan dalam panduan ini .
Tabrakan nama metode
Ada kemungkinan bahwa model Anda memiliki metode yang namanya berbenturan dengan yang diterapkan di MinDiffModel
(lihat daftar lengkap metode umum dalam dokumentasi API ).
Ini hanya bermasalah jika ini akan dipanggil pada instance model (bukan secara internal dalam beberapa metode lain). Sementara sangat tidak mungkin, jika Anda berada dalam situasi ini Anda akan harus baik override dan mengubah nama beberapa metode atau, jika tidak memungkinkan, Anda mungkin perlu mempertimbangkan mengintegrasikan MinDiff tanpa MinDiffModel
seperti yang dijelaskan dalam panduan ini pada subjek .
Sumber daya tambahan
- Untuk dalam diskusi mendalam mengenai evaluasi kewajaran melihat bimbingan Keadilan Indikator
- Untuk informasi umum tentang Remediasi dan MinDiff, melihat remediasi gambaran .
- Untuk rincian tentang persyaratan sekitarnya MinDiff melihat panduan ini .
- Untuk melihat tutorial end-to-end menggunakan MinDiff di Keras, lihat tutorial ini .