از یک مدل از قبل آموزش دیده استفاده کنید

در این آموزش شما یک نمونه برنامه وب را بررسی خواهید کرد که یادگیری انتقال را با استفاده از TensorFlow.js Layers API نشان می دهد. مثال یک مدل از پیش آموزش دیده را بارگیری می کند و سپس مدل را در مرورگر دوباره آموزش می دهد.

این مدل در پایتون روی ارقام 0-4 مجموعه داده طبقه بندی ارقام MNIST از قبل آموزش داده شده است. آموزش مجدد (یا انتقال یادگیری) در مرورگر از ارقام 5-9 استفاده می کند. این مثال نشان می دهد که چندین لایه اول یک مدل از پیش آموزش دیده می تواند برای استخراج ویژگی ها از داده های جدید در طول یادگیری انتقال استفاده شود، بنابراین امکان آموزش سریعتر روی داده های جدید فراهم می شود.

برنامه نمونه برای این آموزش به صورت آنلاین در دسترس است، بنابراین نیازی به دانلود کد یا راه اندازی یک محیط توسعه ندارید. اگر می خواهید کد را به صورت محلی اجرا کنید، مراحل اختیاری را در اجرای مثال به صورت محلی کامل کنید. اگر نمی‌خواهید یک محیط توسعه راه‌اندازی کنید، می‌توانید به کاوش مثال بروید.

کد نمونه در GitHub موجود است.

(اختیاری) مثال را به صورت محلی اجرا کنید

پیش نیازها

برای اجرای برنامه نمونه به صورت محلی، باید موارد زیر را در محیط توسعه خود نصب کنید:

برنامه نمونه را نصب و اجرا کنید

  1. مخزن tfjs-examples کلون یا دانلود کنید.
  2. به دایرکتوری mnist-transfer-cnn تغییر دهید:

    cd tfjs-examples/mnist-transfer-cnn
    
  3. نصب وابستگی ها:

    yarn
    
  4. سرور توسعه را راه اندازی کنید:

    yarn run watch
    

مثال را بررسی کنید

برنامه نمونه را باز کنید . (یا اگر مثال را به صورت محلی اجرا می کنید، به http://localhost:1234 در مرورگر خود بروید.)

شما باید صفحه ای با عنوان MNIST CNN Transfer Learning را ببینید. دستورالعمل ها را دنبال کنید تا برنامه را امتحان کنید.

در اینجا چند چیز برای امتحان وجود دارد:

  • با حالت‌های مختلف تمرین آزمایش کنید و ضرر و دقت را مقایسه کنید.
  • نمونه های مختلف بیت مپ را انتخاب کنید و احتمالات طبقه بندی را بررسی کنید. توجه داشته باشید که اعداد در هر نمونه بیت مپ مقادیر صحیح مقیاس خاکستری هستند که پیکسل های یک تصویر را نشان می دهند.
  • مقادیر عدد صحیح بیت مپ را ویرایش کنید و ببینید که چگونه تغییرات بر احتمالات طبقه بندی تاثیر می گذارد.

کد را کاوش کنید

برنامه وب مثال مدلی را بارگیری می کند که از قبل روی زیرمجموعه ای از مجموعه داده MNIST آموزش داده شده است. پیش‌آموزش در یک برنامه پایتون تعریف شده است: mnist_transfer_cnn.py . برنامه پایتون برای این آموزش خارج از محدوده است، اما اگر می‌خواهید نمونه‌ای از تبدیل مدل را ببینید، ارزش دیدن آن را دارد.

فایل index.js شامل اکثر کدهای آموزشی نسخه ی نمایشی است. هنگامی که index.js در مرورگر اجرا می‌شود، یک تابع راه‌اندازی، setupMnistTransferCNN ، MnistTransferCNNPredictor را نمونه‌سازی و مقداردهی اولیه می‌کند، که روال‌های بازآموزی و پیش‌بینی را محصور می‌کند.

روش اولیه سازی، MnistTransferCNNPredictor.init ، یک مدل را بارگذاری می کند، داده های بازآموزی را بارگیری می کند و داده های آزمایشی را ایجاد می کند. این خطی است که مدل را بارگذاری می کند:

this.model = await loader.loadHostedPretrainedModel(urls.model);

اگر به تعریف loader.loadHostedPretrainedModel نگاه کنید، خواهید دید که نتیجه تماس به tf.loadLayersModel را برمی گرداند. این API TensorFlow.js برای بارگیری مدلی متشکل از اشیاء لایه است.

منطق بازآموزی در MnistTransferCNNPredictor.retrainModel تعریف شده است. اگر کاربر لایه‌های ویژگی Freeze را به عنوان حالت آموزشی انتخاب کرده باشد، 7 لایه اول مدل پایه ثابت می‌شوند و تنها 5 لایه نهایی بر روی داده‌های جدید آموزش داده می‌شوند. اگر کاربر وزن‌های اولیه را انتخاب کرده باشد، همه وزن‌ها بازنشانی می‌شوند و برنامه به طور موثر مدل را از ابتدا آموزش می‌دهد.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

سپس مدل کامپایل می شود و سپس با استفاده از model.fit() بر روی داده های تست آموزش داده می شود:

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

برای کسب اطلاعات بیشتر در مورد پارامترهای model.fit() به مستندات API مراجعه کنید.

پس از آموزش بر روی مجموعه داده جدید (رقم های 5-9)، این مدل می تواند برای پیش بینی استفاده شود. متد MnistTransferCNNPredictor.predict این کار را با استفاده از model.predict() انجام می دهد:

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

به استفاده از tf.tidy توجه کنید که به جلوگیری از نشت حافظه کمک می کند.

بیشتر بدانید

این آموزش یک برنامه نمونه را بررسی کرده است که یادگیری انتقال را در مرورگر با استفاده از TensorFlow.js انجام می دهد. برای کسب اطلاعات بیشتر در مورد مدل های از پیش آموزش دیده و انتقال یادگیری، منابع زیر را بررسی کنید.

TensorFlow.js

هسته TensorFlow