Ao treinar um modelo de aprendizado de máquina, é comum ter um loop onde os dados de treinamento são ingeridos (ou gerados), os lotes são executados em um modelo, os gradientes são obtidos e o modelo é atualizado por meio de um otimizador. Embora você possa escrever seu próprio loop de treinamento para cada aplicativo de treinamento, o Swift para TensorFlow fornece uma abstração de loop de treinamento experimental que pode simplificar esse processo.
O módulo TrainingLoop
no repositório de modelos contém a versão atual deste loop de treinamento experimental generalizado. Ele é estruturado de forma a se integrar a wrappers de conjuntos de dados em conformidade com a API Epochs para facilitar a ingestão de dados e a automatizar a interação de modelos, conjuntos de dados e otimizadores com back-ends de aceleradores para obter desempenho ideal. A personalização pesada do processo de treinamento pode ser alcançada por meio do uso de retornos de chamada.
A maioria dos exemplos baseados em imagens no repositório do modelo foram convertidos para usar esta abstração do loop de treinamento, bem como os exemplos de treinamento do modelo de texto supervisionado. No entanto, o ciclo de formação pode não ser apropriado na sua concepção actual para todos os modelos de aprendizagem automática.
A implementação do loop de treinamento generalizado do Swift para TensorFlow é fortemente influenciada pelo Learner do fastai . Para obter mais informações sobre seu design, consulte "fastai: Uma API em camadas para aprendizado profundo" e a apresentação de Sylvain Gugger "Fast.ai - Um ciclo de treinamento infinitamente personalizável" .
Uso
O exemplo ResNet-CIFAR10 fornece uma boa demonstração de como usar esse ciclo de treinamento na prática. Primeiro, importe o módulo:
import TrainingLoop
em seguida, escolha um back-end do acelerador configurando um Device
. Nesse caso, selecionaremos o backend X10 baseado em XLA e usaremos o primeiro acelerador disponível:
let device = Device.defaultXLA
A próxima etapa é configurar o conjunto de dados, o modelo e o otimizador para usar com seu loop de treinamento:
let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)
e, em seguida, configure o ciclo de treinamento:
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
metrics: [.accuracy])
O loop de treinamento pressupõe que o conjunto de dados que você está usando está em conformidade com a API Epochs e permite especificar quais divisões do conjunto de dados usar para treinamento e validação. Qualquer função de perda pode ser usada uma vez colocada em um wrapper compatível, como softmaxCrossEntropy
está aqui .
As métricas atuais que podem ser capturadas incluem:
-
loss
-
accuracy
-
top5Accuracy
-
matthewsCorrelationCoefficient
-
perplexity
Por fim, para realizar o treinamento, você chama o seguinte:
try! trainingLoop.fit(&model, epochs: 10, on: device)
Isso treinará o modelo por 10 épocas usando o back-end do acelerador que especificamos. As estatísticas serão exibidas durante o treinamento no console usando um prompt animado.
Retornos de chamada
A customização desse loop de treinamento generalizado ocorre por meio do uso de retornos de chamada. Esses retornos de chamada podem ser conectados a vários pontos do loop.
Vários retornos de chamada integrados fornecem funcionalidades que podem ser adicionadas a qualquer loop de treinamento. Estes incluem:
- Registrando estatísticas em arquivos de valores separados por vírgula (CSV)
- Ajustando a taxa de aprendizagem de acordo com uma programação personalizada
- Monitoramento e gráficos do progresso do treinamento via TensorBoard
Além disso, você pode criar seus próprios retornos de chamada personalizados para adicionar uma variedade de funcionalidades adicionais a um loop de treinamento padrão.
Registro CSV
A classe CSVLogger
encapsula um retorno de chamada que gravará estatísticas de treinamento em um formato de valores separados por vírgula em um arquivo de sua escolha. Este arquivo começará com colunas denominadas epoch
, batch
e quaisquer métricas que você tenha habilitado em seu loop de treinamento. Uma linha será então escrita para cada lote, com os valores atuais dessas colunas.
Para adicionar log CSV ao seu loop de treinamento, adicione algo como o seguinte a uma matriz de retornos de chamada fornecidos ao parâmetro callbacks:
para seu TrainingLoop
:
try! CSVLogger(path: "file.csv").log
Por exemplo, a amostra LeNet-MNIST
usa isso em seu loop de treinamento.
Programações de taxas de aprendizagem
É comum, ao treinar um modelo, alterar a taxa de aprendizado fornecida a um otimizador durante o processo de treinamento. Isto pode ser tão simples como uma diminuição linear ao longo do tempo, ou tão complexo como ciclos de aquecimento e declínio descritos por funções complicadas.
O retorno de chamada learningRateScheduler
fornece os meios de descrever programações de taxas de aprendizagem compostas de diferentes segmentos, cada um com seu próprio formato distinto. Isso é feito definindo um LearningRateSchedule
composto de ScheduleSegment
s, cada um com um Shape
definido por uma função, uma taxa de aprendizado inicial e uma taxa de aprendizado final.
Por exemplo, a amostra BERT-CoLA utiliza um aumento linear na taxa de aprendizagem durante um período de aquecimento e uma diminuição linear depois disso. Para fazer isso, o retorno de chamada do cronograma de taxa de aprendizagem é definido da seguinte forma:
learningRateScheduler(
schedule: makeSchedule(
[
ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
ScheduleSegment(shape: linear, endRate: 0)
]
)
)
Os dois ScheduleSegment
s definem uma taxa de aprendizado que começa em 0 e aumenta linearmente até peakLearningRate
em uma série de 10 etapas discretas, depois começa na taxa de aprendizado final da etapa anterior e diminui linearmente até 0 no final do processo de treinamento.
Integração TensorBoard
TensorBoard é uma ferramenta de visualização poderosa para monitorar o treinamento do modelo, analisar o treinamento quando concluído ou comparar execuções de treinamento. Swift para TensorFlow oferece suporte à visualização do TensorBoard por meio do uso do módulo TensorBoard
no repositório de modelos, que fornece retornos de chamada que registram métricas de treinamento.
O exemplo GPT2-WikiText2 ilustra como adicionar o log do TensorBoard ao treinamento do seu modelo. Primeiro, importe o módulo TensorBoard
. Então é tão simples quanto adicionar tensorBoardStatisticsLogger()
aos retornos de chamada do seu TrainingLoop
callbacks:
array.
Por padrão, isso registrará cada execução de treinamento em um diretório run/tensorboard/stats
. Para visualizar isso no Tensorboard, execute
tensorboard --logdir ./run/tensorboard/stats
e o TensorBoard deve iniciar um servidor local onde você possa visualizar suas métricas de treinamento. Os resultados do treinamento e da validação devem ser mostrados separadamente, e cada execução possui um carimbo de data/hora exclusivo para permitir fácil comparação entre várias execuções do mesmo modelo.
O design da integração Swift para TensorFlow TensorBoard foi inspirado em tensorboardX . Os retornos de chamada do TensorBoard criam diretamente os eventos apropriados e os buffers de protocolo de resumo e os gravam em um arquivo de log durante o treinamento.
Retornos de chamada personalizados
Além dos retornos de chamada integrados descritos acima, você tem a capacidade de personalizar a função dos loops de treinamento criando seus próprios retornos de chamada. Esses retornos de chamada são funções que possuem uma assinatura semelhante à seguinte:
func customCallback<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
{
if event == .updateStart {
...
}
}
O loop de treinamento e o estado associado são passados como o primeiro parâmetro. A parte atual do loop à qual o retorno de chamada está respondendo é fornecida por meio de event
. O evento do loop de treinamento possui um dos seguintes estados, cada um correspondendo a um ponto diferente no ciclo de vida do loop:
-
fitStart
-
fitEnd
-
epochStart
-
epochEnd
-
trainingStart
-
trainingEnd
-
validationStart
-
validationEnd
-
batchStart
-
batchEnd
-
updateStart
-
inferencePredictionEnd
Sua função de retorno de chamada pode optar por ativar sua lógica em qualquer combinação dos estados acima, o que permite extrair dados ou controlar o loop de treinamento de várias maneiras.