Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Importowanie modelu Keras do TensorFlow.js

Modele Keras (zazwyczaj tworzone przez Python API) można zapisać w jednym z kilku formatów . Format „całego modelu” można przekonwertować na format warstw TensorFlow.js, który można załadować bezpośrednio do TensorFlow.js w celu wnioskowania lub dalszego szkolenia.

Docelowy format warstw TensorFlow.js to katalog zawierający plik model.json i zestaw podzielonych na fragmenty plików wagi w formacie binarnym. Plik model.json zawiera zarówno topologię modelu (inaczej „architekturę” lub „graf”: opis warstw i sposobu ich połączenia), jak i manifest plików wagi.

Wymagania

Procedura konwersji wymaga środowiska Python; możesz zachować izolację za pomocą pipenv lub virtualenv . Aby zainstalować konwerter, użyj pip install tensorflowjs .

Importowanie modelu Keras do TensorFlow.js to proces dwuetapowy. Najpierw przekonwertuj istniejący model Keras do formatu TF.js Layers, a następnie załaduj go do TensorFlow.js.

Krok 1. Przekonwertuj istniejący model Keras na format TF.js Layers

Modele Keras są zwykle zapisywane za pośrednictwem pliku model.save (ścieżka pliku model.save(filepath) , który tworzy pojedynczy plik HDF5 (.h5) zawierający zarówno topologię modelu, jak i wagi. Aby przekonwertować taki plik na format TF.js Layers, uruchom następującą komendę, gdzie path/to/my_model.h5 to źródłowy plik Keras .h5, a path/to/tfjs_target_dir to docelowy katalog wyjściowy dla plików TF.js:

# bash

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Alternatywnie: użyj interfejsu API języka Python, aby wyeksportować bezpośrednio do formatu warstw TF.js

Jeśli masz model Keras w Pythonie, możesz wyeksportować go bezpośrednio do formatu TensorFlow.js Layers w następujący sposób:

# Python

import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)

Krok 2: Załaduj model do TensorFlow.js

Użyj serwera internetowego do obsługi przekonwertowanych plików modelu wygenerowanych w kroku 1. Pamiętaj, że może być konieczne skonfigurowanie serwera tak, aby zezwalał na współdzielenie zasobów między źródłami (CORS) , aby umożliwić pobieranie plików w JavaScript.

Następnie załaduj model do TensorFlow.js, podając adres URL do pliku model.json:

// JavaScript

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');

Teraz model jest gotowy do wnioskowania, oceny lub ponownego szkolenia. Na przykład wczytany model można od razu wykorzystać do wykonania prognozy:

// JavaScript

const example = tf.fromPixels(webcamElement);  // for example
const prediction = model.predict(example);

W wielu przykładach TensorFlow.js zastosowano to podejście, używając wstępnie wytrenowanych modeli, które zostały przekonwertowane i umieszczone w Google Cloud Storage.

Zwróć uwagę, że odwołujesz się do całego modelu, używając nazwy pliku model.json . loadModel(...) pobiera model.json , a następnie model.json dodatkowe żądania HTTP (S) w celu uzyskania podzielonych plików wagi, do których model.json manifest wagi model.json . Takie podejście umożliwia buforowanie wszystkich tych plików przez przeglądarkę (i być może przez dodatkowe serwery pamięci podręcznej w Internecie), ponieważ model.json i fragmenty wagi są mniejsze niż typowy limit rozmiaru pliku pamięci podręcznej. W ten sposób model będzie prawdopodobnie ładował się szybciej przy kolejnych okazjach.

Obsługiwane funkcje

Warstwy TensorFlow.js obecnie obsługują tylko modele Keras używające standardowych konstrukcji Keras. Modele wykorzystujące nieobsługiwane operacje lub warstwy - np. Warstwy niestandardowe, warstwy Lambda, niestandardowe straty lub niestandardowe metryki - nie mogą być automatycznie importowane, ponieważ zależą od kodu Pythona, którego nie można wiarygodnie przetłumaczyć na JavaScript.