Personalizando MinDiffModel

Introdução

Na maioria dos casos, utilizando MinDiffModel directamente tal como descrito no "Integrando MinDiff com MinDiffModel" guia é suficiente. No entanto, é possível que você precise de um comportamento personalizado. As duas principais razões para isso são:

  • O keras.Model você está usando tem comportamento personalizado que você deseja preservar.
  • Você quer que o MinDiffModel a se comportar de forma diferente do padrão.

Em ambos os casos, você precisará subclasse MinDiffModel para alcançar os resultados desejados.

Configurar

pip install -q --upgrade tensorflow-model-remediation
import tensorflow as tf
tf.get_logger().setLevel('ERROR')  # Avoid TF warnings.
from tensorflow_model_remediation import min_diff
from tensorflow_model_remediation.tools.tutorials_utils import uci as tutorials_utils

Primeiro, baixe os dados. Para concisão, a lógica de entrada preparação foi fatoramos em funções auxiliares como descrito no guia de preparação de entrada . Você pode ler o guia completo para obter detalhes sobre esse processo.

# Original Dataset for training, sampled at 0.3 for reduced runtimes.
train_df = tutorials_utils.get_uci_data(split='train', sample=0.3)
train_ds = tutorials_utils.df_to_dataset(train_df, batch_size=128)

# Dataset needed to train with MinDiff.
train_with_min_diff_ds = (
    tutorials_utils.get_uci_with_min_diff_dataset(split='train', sample=0.3))

Preservando as personalizações do modelo original

tf.keras.Model é projetado para ser facilmente customizado via subclasse como descrito aqui . Se o seu modelo tiver personalizado implementações que deseja preservar ao aplicar MinDiff, você precisará subclasse MinDiffModel .

Modelo personalizado original

Para ver como você pode preservar personalizações, criar um modelo personalizado que define um atributo para True quando o seu costume train_step é chamado. Esta não é uma personalização útil, mas servirá para ilustrar o comportamento.

class CustomModel(tf.keras.Model):

  # Customized train_step
  def train_step(self, *args, **kwargs):
    self.used_custom_train_step = True  # Marker that we can check for.
    return super(CustomModel, self).train_step(*args, **kwargs)

Treinando tal modelo um ficaria o mesmo que um normal, Sequential modelo.

model = tutorials_utils.get_uci_model(model_class=CustomModel)  # Use CustomModel.

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(train_ds.take(1), epochs=1, verbose=0)

# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step'))  # True
Model used the custom train_step:
True

Subclassificando MinDiffModel

Se você estava a tentar usar MinDiffModel diretamente, o modelo não usaria o costume train_step .

model = tutorials_utils.get_uci_model(model_class=CustomModel)
model = min_diff.keras.MinDiffModel(model, min_diff.losses.MMDLoss())

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)

# Model has not used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step'))  # False
Model used the custom train_step:
False

Para utilizar o correto train_step método, você precisa de uma classe personalizada que subclasses tanto MinDiffModel e CustomModel .

class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):
  pass  # No need for any further implementation.

Treinando este modelo usará o train_step de CustomModel .

model = tutorials_utils.get_uci_model(model_class=CustomModel)

model = CustomMinDiffModel(model, min_diff.losses.MMDLoss())

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)

# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step'))  # True
Model used the custom train_step:
True

Personalização de comportamentos padrão de MinDiffModel

Em outros casos, você pode querer mudar comportamentos padrão específicas de MinDiffModel . O caso de uso comum a maior parte desta está mudando o comportamento descompactação padrão para tratar adequadamente seus dados se você não usar pack_min_diff_data .

Ao empacotar os dados em um formato personalizado, isso pode aparecer da seguinte maneira.

def _reformat_input(inputs, original_labels):
  min_diff_data = min_diff.keras.utils.unpack_min_diff_data(inputs)
  original_inputs = min_diff.keras.utils.unpack_original_inputs(inputs)

  return ({
      'min_diff_data': min_diff_data,
      'original_inputs': original_inputs}, original_labels)

customized_train_with_min_diff_ds = train_with_min_diff_ds.map(_reformat_input)

Os customized_train_with_min_diff_ds conjunto de dados retorna lotes compostos de tuplos (x, y) onde x é um dicionário contendo min_diff_data e original_inputs e y representa os original_labels .

for x, _ in customized_train_with_min_diff_ds.take(1):
  print('Type of x:', type(x))  # dict
  print('Keys of x:', x.keys())  # 'min_diff_data', 'original_inputs'
Type of x: <class 'dict'>
Keys of x: dict_keys(['min_diff_data', 'original_inputs'])

Este formato de dados não é o que MinDiffModel espera por padrão e passando customized_train_with_min_diff_ds para isso resultaria em um comportamento inesperado. Para corrigir isso, você precisará criar sua própria subclasse.

class CustomUnpackingMinDiffModel(min_diff.keras.MinDiffModel):

  def unpack_min_diff_data(self, inputs):
    return inputs['min_diff_data']

  def unpack_original_inputs(self, inputs):
    return inputs['original_inputs']

Com esta subclasse, você pode treinar como nos outros exemplos.

model = tutorials_utils.get_uci_model()
model = CustomUnpackingMinDiffModel(model, min_diff.losses.MMDLoss())

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(customized_train_with_min_diff_ds, epochs=1)
77/77 [==============================] - 4s 30ms/step - loss: 0.6690 - min_diff_loss: 0.0395

Limitações de um personalizado MinDiffModel

Criando um costume MinDiffModel fornece uma enorme quantidade de flexibilidade para casos de uso mais complexos. No entanto, ainda existem alguns casos extremos que ele não suportará.

Pré-processamento ou validação de entradas antes de call

A maior limitação para uma subclasse de MinDiffModel é que ele exige a x componente dos dados de entrada (ou seja, o primeiro ou o único elemento no lote devolvido pelo tf.data.Dataset ) para ser passado através de pré-processamento sem validação ou para call .

Isto é simplesmente porque o min_diff_data é embalado para o x componente dos dados de entrada. Qualquer pré-processamento ou validação não vai esperar a estrutura adicional contendo min_diff_data e provavelmente vai quebrar.

Se o pré-processamento ou validação for facilmente personalizável (por exemplo, fatorado em seu próprio método), isso será facilmente resolvido substituindo-o para garantir que ele lide com a estrutura adicional corretamente.

Um exemplo com validação pode ser assim:

class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):

  # Override so that it correctly handles additional `min_diff_data`.
  def validate_inputs(self, inputs):
    original_inputs = self.unpack_original_inputs(inputs)
    ...  # Optionally also validate min_diff_data
    # Call original validate method with correct inputs
    return super(CustomMinDiffModel, self).validate(original_inputs)

Se o pré-processamento ou validação não é facilmente customizável, em seguida, usando MinDiffModel não pode trabalhar para você e você vai precisar para integrar MinDiff sem ele, como descrito no presente guia .

Colisões de nomes de métodos

É possível que o seu modelo possui métodos cujos nomes se chocam com as implementadas em MinDiffModel (ver lista completa de métodos públicos na documentação da API ).

Isso só é problemático se eles forem chamados em uma instância do modelo (em vez de internamente em algum outro método). Embora altamente improvável, se você está nesta situação, você terá que quer substituir e renomear alguns métodos ou, se não for possível, você pode precisar de considerar a integração MinDiff sem MinDiffModel conforme descrito no este guia sobre o assunto .

Recursos adicionais