Este documento proporciona una descripción general sobre la poda de modelos para ayudarle a determinar cómo se adapta a su caso de uso.
- Para profundizar en un ejemplo de un extremo a otro, consulte el ejemplo de Poda con Keras .
- Para encontrar rápidamente las API que necesita para su caso de uso, consulte la guía completa de poda .
- Para explorar la aplicación de la poda para la inferencia en el dispositivo, consulte Poda para la inferencia en el dispositivo con XNNPACK .
- Para ver un ejemplo de poda estructural, ejecute el tutorial Poda estructural con dispersión de 2 por 4 .
Descripción general
La poda de peso basada en la magnitud reduce gradualmente a cero los pesos del modelo durante el proceso de entrenamiento para lograr la escasez del modelo. Los modelos dispersos son más fáciles de comprimir y podemos omitir los ceros durante la inferencia para mejorar la latencia.
Esta técnica aporta mejoras mediante la compresión del modelo. En el futuro, el soporte del marco para esta técnica proporcionará mejoras en la latencia. Hemos visto mejoras de hasta 6 veces en la compresión del modelo con una pérdida mínima de precisión.
La técnica se está evaluando en varias aplicaciones de voz, como el reconocimiento de voz y la conversión de texto a voz, y se ha experimentado con varios modelos de visión y traducción.
Matriz de compatibilidad de API
Los usuarios pueden aplicar la poda con las siguientes API:
- Construcción de modelos:
keras
solo con modelos Secuenciales y Funcionales - Versiones de TensorFlow: TF 1.x para las versiones 1.14+ y 2.x.
- No se admiten
tf.compat.v1
con un paquete TF 2.X ytf.compat.v2
con un paquete TF 1.X.
- No se admiten
- Modo de ejecución de TensorFlow: gráfico y ansioso
- Entrenamiento distribuido:
tf.distribute
con solo ejecución de gráficos
Está en nuestra hoja de ruta agregar soporte en las siguientes áreas:
Resultados
Clasificación de imágenes
Modelo | Precisión Top-1 no escasa | Precisión escasa aleatoria | Escasez aleatoria | Precisión escasa estructurada | Escasez estructurada |
---|---|---|---|---|---|
InicioV3 | 78,1% | 78,0% | 50% | 75,8% | 2 por 4 |
76,1% | 75% | ||||
74,6% | 87,5% | ||||
MóvilnetV1 224 | 71,04% | 70,84% | 50% | 67,35% | 2 por 4 |
MóvilnetV2 224 | 71,77% | 69,64% | 50% | 66,75% | 2 por 4 |
Los modelos fueron probados en Imagenet.
Traducción
Modelo | BLEU no disperso | BLEU escaso | Escasez |
---|---|---|---|
GNMT EN-DE | 26,77 | 26,86 | 80% |
26,52 | 85% | ||
26.19 | 90% | ||
GNMT DE-ES | 29,47 | 29,50 | 80% |
29.24 | 85% | ||
28,81 | 90% |
Los modelos utilizan el conjunto de datos WMT16 en alemán e inglés con news-test2013 como conjunto de desarrollo y news-test2015 como conjunto de prueba.
Modelo de detección de palabras clave
DS-CNN-L es un modelo de detección de palabras clave creado para dispositivos perimetrales. Se puede encontrar en el repositorio de ejemplos del software ARM.
Modelo | Precisión no escasa | Precisión dispersa estructurada (patrón de 2 por 4) | Precisión dispersa aleatoria (escasez objetivo 50%) |
---|---|---|---|
DS-CNN-L | 95,23 | 94,33 | 94,84 |
Ejemplos
Además del tutorial Podar con Keras , consulte los siguientes ejemplos:
- Entrene un modelo CNN en la tarea de clasificación de dígitos escritos a mano MNIST con poda: código
- Entrene un LSTM en la tarea de clasificación de sentimientos de IMDB con poda: código
Para obtener información general, consulte Podar o no podar: exploración de la eficacia de la poda para la compresión de modelos [ artículo ].