![]() | ![]() | ![]() | ![]() |
Oprócz bycia częścią ekosystemu TensorFlow, TFF ma na celu umożliwienie interoperacyjności z innymi frontendowymi i backendowymi frameworkami ML. W tej chwili wsparcie dla innych frameworków ML jest jeszcze w fazie inkubacji, a API i obsługiwana funkcjonalność mogą ulec zmianie (głównie w funkcji zapotrzebowania ze strony użytkowników TFF). W tym samouczku opisano, jak używać TFF z JAX jako alternatywnego frontendu ML oraz kompilatora XLA jako alternatywnego backendu. Przedstawione tutaj przykłady są oparte na całkowicie natywnym stosie JAX/XLA, kompleksowo. Możliwość mieszania kodu w różnych frameworkach (np. JAX z TensorFlow) zostanie omówiona w jednym z przyszłych tutoriali.
Jak zawsze, czekamy na Twój wkład. Jeśli wsparcie dla JAX/XLA lub możliwość współdziałania z innymi frameworkami ML jest dla Ciebie ważne, rozważ pomoc nam w rozwijaniu tych możliwości w kierunku zgodności z resztą TFF.
Zanim zaczniemy
Zapoznaj się z główną treścią dokumentacji TFF, aby dowiedzieć się, jak skonfigurować swoje środowisko. W zależności od tego, gdzie prowadzisz ten samouczek, możesz chcieć odkomentować i uruchomić część lub całość poniższego kodu.
# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()
Ten samouczek zakłada również, że zapoznałeś się z podstawowymi samouczkami TFF TensorFlow i znasz podstawowe koncepcje TFF. Jeśli jeszcze tego nie zrobiłeś, rozważ przejrzenie przynajmniej jednego z nich.
Obliczenia JAX
Obsługa JAX w TFF została zaprojektowana tak, aby była symetryczna do sposobu, w jaki TFF współdziała z TensorFlow, zaczynając od importu:
import jax
import numpy as np
import tensorflow_federated as tff
Podobnie jak w przypadku TensorFlow, podstawą wyrażenia dowolnego kodu TFF jest logika działająca lokalnie. Można wyrazić tę logikę w JAX, jak pokazano poniżej, używając @tff.experimental.jax_computation
opakowanie. Zachowuje się podobnie jak @tff.tf_computation
że teraz Twoje są znane. Zacznijmy od czegoś prostego, np. obliczenia, które dodaje dwie liczby całkowite:
@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
return jax.numpy.add(x, y)
Możesz użyć obliczeń JAX zdefiniowanych powyżej, tak jak normalnie używasz obliczeń TFF. Na przykład możesz sprawdzić jego sygnaturę typu w następujący sposób:
str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'
Zauważ, że użyliśmy np.int32
aby określić typ argumentów. TFF nie rozróżnia typy NumPy (takich jak np.int32
) i typ TensorFlow (np tf.int32
). Z perspektywy TFF to tylko sposoby na odniesienie się do tego samego.
Teraz pamiętaj, że TFF to nie Python (a jeśli to nie przeszkadza, zapoznaj się z niektórymi z naszych wcześniejszych samouczków, np. o algorytmach niestandardowych). Można użyć @tff.experimental.jax_computation
otoki z dowolnym JAX kod, który można prześledzić i odcinkach, czyli z kodem, który normalnie annotate z @jax.jit
oczekuje się być kompilowane do XLA (ale nie muszą faktycznie użyć @jax.jit
adnotacji umieścić swój kod JAX w TFF).
Rzeczywiście, pod maską, TFF natychmiast kompiluje obliczenia JAX do XLA. Można to sprawdzić na własne oczy ręcznie wyodrębnianie i drukowanie odcinkach kod XLA z add_numbers
, co następuje:
comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7 ENTRY xla_computation_add_numbers.7 { constant.4 = pred[] constant(false) parameter.1 = (s32[], s32[]) parameter(0) get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0 get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1 add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3) ROOT tuple.6 = (s32[]) tuple(add.5) }
Myśleć reprezentacji obliczeń JAX jak kod XLA jako funkcjonalny równoważnik tf.GraphDef
obliczeń wyrażone w TensorFlow. Jest przenośny i wykonywalny w różnych środowiskach, które obsługują XLA, podobnie jak tf.GraphDef
mogą być wykonywane w dowolnym czasie wykonywania TensorFlow.
TFF zapewnia stos runtime oparty na kompilatorze XLA jako backend. Możesz go aktywować w następujący sposób:
tff.backends.xla.set_local_python_execution_context()
Teraz możesz wykonać obliczenia, które zdefiniowaliśmy powyżej:
add_numbers(2, 3)
5
Wystarczająco łatwe. Pójdźmy za ciosem i zróbmy coś bardziej skomplikowanego, na przykład MNIST.
Przykład szkolenia MNIST z interfejsem API w puszce
Jak zwykle zaczynamy od zdefiniowania kilku typów TFF dla partii danych i dla modelu (pamiętaj, że TFF jest frameworkiem o silnym typie).
import collections
BATCH_TYPE = collections.OrderedDict([
('pixels', tff.TensorType(np.float32, (50, 784))),
('labels', tff.TensorType(np.int32, (50,)))
])
MODEL_TYPE = collections.OrderedDict([
('weights', tff.TensorType(np.float32, (784, 10))),
('bias', tff.TensorType(np.float32, (10,)))
])
Teraz zdefiniujmy funkcję straty dla modelu w JAX, biorąc model i pojedynczą partię danych jako parametr:
def loss(model, batch):
y = jax.nn.softmax(
jax.numpy.add(
jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))
Teraz jednym sposobem jest użycie gotowego interfejsu API. Oto przykład, w jaki sposób możesz użyć naszego API do stworzenia procesu szkoleniowego w oparciu o właśnie zdefiniowaną funkcję straty.
STEP_SIZE = 0.001
trainer = tff.experimental.learning.build_jax_federated_averaging_process(
BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)
Można użyć wyżej tak jak byłoby użyć kompilacji trener z tf.Keras
modelu w TensorFlow. Na przykład, oto jak możesz stworzyć początkowy model do szkolenia:
initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])
Aby przeprowadzić faktyczne szkolenie, potrzebujemy pewnych danych. Zróbmy losowe dane, aby było to proste. Ponieważ dane są losowe, będziemy oceniać na danych uczących, ponieważ w przeciwnym razie, przy losowych danych eval, trudno byłoby oczekiwać, że model będzie działał. Ponadto w przypadku tego demo na małą skalę nie będziemy się martwić losowym próbkowaniem klientów (pozostawiamy użytkownikowi ćwiczenie tego typu zmian, postępując zgodnie z szablonami z innych samouczków):
def random_batch():
pixels = np.random.uniform(
low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
return collections.OrderedDict([('pixels', pixels), ('labels', labels)])
NUM_CLIENTS = 2
NUM_BATCHES = 10
train_data = [
[random_batch() for _ in range(NUM_BATCHES)]
for _ in range(NUM_CLIENTS)]
Dzięki temu możemy wykonać jeden krok szkolenia, jak następuje:
trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05, 2.54597180e-05, ..., 5.61640409e-05, -5.32875274e-05, -4.62881755e-04], [ 7.30908650e-05, 4.67643113e-05, 2.03352147e-06, ..., 3.77510623e-05, 3.52839161e-05, -4.59865667e-04], [ 8.14835730e-05, 3.03147244e-05, -1.89143739e-05, ..., 1.12527239e-04, 4.09212225e-06, -4.59960109e-04], ..., [ 9.23552434e-05, 2.44302555e-06, -2.20817346e-05, ..., 7.61375341e-05, 1.76906979e-05, -4.43495519e-04], [ 1.17451040e-04, 2.47748958e-05, 1.04728279e-05, ..., 5.26388249e-07, 7.21131510e-05, -4.67137404e-04], [ 3.75041491e-05, 6.58061981e-05, 1.14522081e-05, ..., 2.52584141e-05, 3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04, 2.6502126e-05, -1.9462314e-05, 8.1269856e-05, 2.1832302e-04, 1.6636557e-04, 1.2815947e-04, 9.0642272e-05, 7.7109929e-05, -9.1987278e-04], dtype=float32))])
Oceńmy wynik etapu szkolenia. Aby było to łatwe, możemy ocenić to w sposób scentralizowany:
import itertools
eval_data = list(itertools.chain.from_iterable(train_data))
def average_loss(model, data):
return np.mean([loss(model, batch) for batch in data])
print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854 2.282762
Strata maleje. Świetny! Przeanalizujmy teraz kilka rund:
NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
trained_model = trainer.next(trained_model, train_data)
print(average_loss(trained_model, eval_data))
2.2685437 2.257856 2.2495182 2.2428129 2.2372835 2.2326245 2.2286277 2.2251441 2.2220676 2.219318 2.2168345 2.2145717 2.2124937 2.2105706 2.2087805 2.2071042 2.2055268 2.2040353 2.2026198 2.2012706
Jak widać, używanie JAX z TFF nie różni się aż tak bardzo, chociaż eksperymentalne API nie są jeszcze na równi z funkcjonalnością API TensorFlow.
Pod maską
Jeśli wolisz nie używać naszego standardowego interfejsu API, możesz zaimplementować własne niestandardowe obliczenia, podobnie jak w samouczkach dotyczących niestandardowych algorytmów dla TensorFlow, z wyjątkiem tego, że użyjesz mechanizmu JAX do opadania gradientu. Na przykład poniżej znajduje się sposób zdefiniowania obliczenia JAX, które aktualizuje model w pojedynczej minipartii:
@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
grads = jax.grad(loss)(model, batch)
return collections.OrderedDict([
(k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
])
Oto jak możesz sprawdzić, czy to działa:
sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854 2.2977567
Jedno zastrzeżenie pracy z JAX jest to, że nie oferuje równowartość tf.data.Dataset
. Tak więc, aby iterować po zbiorach danych, będziesz musiał użyć deklaratywnych konstrukcji TFF dla operacji na sekwencjach, takich jak ta pokazana poniżej:
@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
return tff.sequence_reduce(batches, model, train_on_one_batch)
Zobaczmy, że to działa:
sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854 2.2284968
Obliczenie, które wykonuje pojedynczą rundę szkolenia, wygląda dokładnie tak samo, jak w samouczkach TensorFlow:
@tff.federated_computation(
tff.FederatedType(MODEL_TYPE, tff.SERVER),
tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
locally_trained_models = tff.federated_map(
train_on_one_client,
collections.OrderedDict([
('model', tff.federated_broadcast(model)),
('batches', federated_data)]))
return tff.federated_mean(locally_trained_models)
Zobaczmy, że to działa:
trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854 2.282762
Jak widać, używanie JAX w TFF, czy to za pośrednictwem gotowych interfejsów API, czy bezpośrednio przy użyciu niskopoziomowych konstrukcji TFF, jest podobne do używania TFF z TensorFlow. Bądź na bieżąco z przyszłymi aktualizacjami, a jeśli chcesz zobaczyć lepszą obsługę interoperacyjności w ramach ML, wyślij nam żądanie ściągnięcia!