השתמש בדגם מיומן מראש

במדריך זה תחקור יישום אינטרנט לדוגמה המדגים למידת העברה באמצעות TensorFlow.js Layers API. הדוגמה טוענת מודל שהוכשר מראש ולאחר מכן מכשירה מחדש את המודל בדפדפן.

המודל עבר הכשרה מראש ב-Python על הספרות 0-4 של מערך הנתונים של סיווג הספרות של MNIST . ההכשרה מחדש (או העברה למידה) בדפדפן משתמשת בספרות 5-9. הדוגמה מראה שניתן להשתמש במספר השכבות הראשונות של מודל שהוכשר מראש כדי לחלץ תכונות מנתונים חדשים במהלך למידת העברה, ובכך לאפשר אימון מהיר יותר על הנתונים החדשים.

היישום לדוגמה עבור מדריך זה זמין באינטרנט , כך שאינך צריך להוריד שום קוד או להגדיר סביבת פיתוח. אם תרצה להפעיל את הקוד באופן מקומי, השלם את השלבים האופציונליים ב- הפעל את הדוגמה באופן מקומי . אם אינך רוצה להגדיר סביבת פיתוח, תוכל לדלג אל חקור את הדוגמה .

הקוד לדוגמה זמין ב- GitHub .

(אופציונלי) הפעל את הדוגמה באופן מקומי

דרישות מוקדמות

כדי להפעיל את האפליקציה לדוגמה באופן מקומי, עליך להתקין את הדברים הבאים בסביבת הפיתוח שלך:

התקן והפעל את האפליקציה לדוגמה

  1. שכפל או הורד את מאגר tfjs-examples .
  2. עבור לספריית mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. תלות בהתקנה:

    yarn
    
  4. הפעל את שרת הפיתוח:

    yarn run watch
    

חקור את הדוגמה

פתח את האפליקציה לדוגמה . (או, אם אתה מפעיל את הדוגמה באופן מקומי, עבור אל http://localhost:1234 בדפדפן שלך.)

אתה אמור לראות עמוד שכותרתו MNIST CNN Transfer Learning . עקוב אחר ההוראות כדי לנסות את האפליקציה.

הנה כמה דברים שכדאי לנסות:

  • התנסו עם מצבי האימון השונים והשוו אובדן ודיוק.
  • בחר דוגמאות שונות של מפת סיביות ובדוק את הסתברויות הסיווג. שים לב שהמספרים בכל דוגמה של מפת סיביות הם ערכי מספרים שלמים בגווני אפור המייצגים פיקסלים מתמונה.
  • ערוך את ערכי המספרים השלמים של מפת הסיביות וראה כיצד השינויים משפיעים על הסתברויות הסיווג.

חקור את הקוד

אפליקציית האינטרנט לדוגמה טוענת מודל שעבר הכשרה מראש על תת-קבוצה של מערך הנתונים של MNIST. האימון המקדים מוגדר בתוכנית Python: mnist_transfer_cnn.py . תוכנית Python היא מחוץ לתחום עבור הדרכה זו, אך כדאי להסתכל עליה אם תרצה לראות דוגמה להמרת מודל .

הקובץ index.js מכיל את רוב קוד ההדרכה של ההדגמה. כאשר index.js פועל בדפדפן, פונקציית הגדרה, setupMnistTransferCNN , מפעילה ומאתחלת את MnistTransferCNNPredictor , המקופלת את שגרות האימון מחדש והניבוי.

שיטת האתחול, MnistTransferCNNPredictor.init , טוענת מודל, טוענת נתוני אימון מחדש ויוצרת נתוני בדיקה. הנה השורה שמטעינה את הדגם:

this.model = await loader.loadHostedPretrainedModel(urls.model);

אם תסתכל על ההגדרה של loader.loadHostedPretrainedModel , תראה שהיא מחזירה את התוצאה של קריאה ל- tf.loadLayersModel . זהו ה-API של TensorFlow.js לטעינת מודל המורכב מאובייקטי Layer.

היגיון האימון מחדש מוגדר ב- MnistTransferCNNPredictor.retrainModel . אם המשתמש בחר בשכבות תכונה הקפאה כמצב האימון, 7 השכבות הראשונות של מודל הבסיס מוקפאות, ורק 5 השכבות האחרונות מאומנות על נתונים חדשים. אם המשתמש בחר אתחול מחדש , כל המשקולות מאופסות, והאפליקציה מאמנת את הדגם מאפס.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

לאחר מכן המודל נערך , ולאחר מכן הוא מאומן על נתוני הבדיקה באמצעות model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

למידע נוסף על הפרמטרים model.fit() , עיין בתיעוד ה-API .

לאחר הכשרה על מערך הנתונים החדש (ספרות 5-9), ניתן להשתמש במודל לביצוע תחזיות. שיטת MnistTransferCNNPredictor.predict עושה זאת באמצעות model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

שימו לב לשימוש ב- tf.tidy , שעוזר במניעת דליפות זיכרון.

למד עוד

מדריך זה חקר אפליקציה לדוגמה המבצעת למידת העברה בדפדפן באמצעות TensorFlow.js. עיין במשאבים למטה כדי ללמוד עוד על מודלים שהוכשרו מראש ולמידה בהעברה.

TensorFlow.js

TensorFlow Core