This tutorial describes how to set up a high-performance simulation using a TFF runtime running on Kubernetes.
This tutorial uses Google Cloud's GKE to create the Kubernetes cluster, but all the steps after the cluster is created will work with any Kubernetes installation.
Launch the TFF Workers on GKE
Create a Kubernetes Cluster
The following step only needs to be done once. The cluster can be re-used for future workloads.
Follow the GKE instructions to create a container cluster. The rest of this tutorial assumes that the cluster is named tff-cluster, but the actual name isn't important.
Stop following the instructions when you get to "Step 5: Deploy your application".
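If you prefer the command line to the Cloud Console, the cluster can also be created with gcloud. This is a minimal sketch rather than the exact command from the GKE instructions; the zone and node count are assumptions you should adjust for your project.
gcloud container clusters create tff-cluster --zone us-central1-a --num-nodes 3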
Deploy the TFF Worker Application
The commands to interact with GCP can be run locally or in the Google Cloud Shell. We recommend the Google Cloud Shell since it doesn't require additional setup.
- Run the following command to deploy the TFF worker application.
kubectl create deployment tff-workers --image=gcr.io/tensorflow-federated/remote-executor-service:latest
- Add a load balancer for the application.
kubectl expose deployment tff-workers --type=LoadBalancer --port 80 --target-port 8000
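Since the training loop below opens ten channels, you may want more than one worker replica behind the load balancer. As a hedged sketch (the replica count here is an assumption; size it to your cluster):
kubectl scale deployment tff-workers --replicas=10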
Look up the IP address of the load balancer on the Google Cloud Console. You'll need it later to connect the training loop to the worker app.
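You can also read the external IP from the command line. Assuming the service is named tff-workers as above, the following prints it in the EXTERNAL-IP column once the load balancer has been provisioned:
kubectl get service tff-workers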
(Alternatively) Launch the Docker Container Locally
docker run --rm -p 8000:8000 gcr.io/tensorflow-federated/remote-executor-service:latest
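If you go this route, there is no load balancer to look up; instead, point the executor setup in the "Set Up the Remote Executors" step at the local container. A small sketch, assuming the container is reachable on localhost:
ip_address = '127.0.0.1'  # the locally running container
port = 8000               # the port published by docker run above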
Set Up TFF Environment
!pip install --quiet --upgrade tensorflow-federated
!pip install --quiet --upgrade nest-asyncio
# TFF drives computations with asyncio; nest_asyncio lets its event loop
# run inside the notebook's already-running loop.
import nest_asyncio
nest_asyncio.apply()
Define the Model to Train
import collections
import time
import tensorflow as tf
import tensorflow_federated as tff
# Load the federated EMNIST dataset; each client corresponds to one writer.
source, _ = tff.simulation.datasets.emnist.load_data()

def map_fn(example):
  # Flatten each batch of 28x28 images into vectors of 784 pixels.
  return collections.OrderedDict(
      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label'])

def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).batch(20).map(map_fn)

# Build training datasets for the first 10 simulated clients.
train_data = [client_data(n) for n in range(10)]
input_spec = train_data[0].element_spec
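To sanity-check the preprocessing, you can pull a single batch from one client's dataset; a quick sketch:
batch = next(iter(train_data[0]))
print(batch['x'].shape, batch['y'].shape)  # expect: (20, 784) (20,)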
def model_fn():
  # A simple softmax-regression model over the flattened pixels.
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
# Federated Averaging, with plain SGD as the client optimizer.
trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02))
def evaluate(num_rounds=10):
  # Runs federated training, printing the loss and wall-clock time per round.
  state = trainer.initialize()
  for round_num in range(num_rounds):
    t1 = time.time()
    result = trainer.next(state, train_data)
    state = result.state
    train_metrics = result.metrics['client_work']['train']
    t2 = time.time()
    print('Round {}: loss {}, round time {}'.format(
        round_num, train_metrics['loss'], t2 - t1))
Set Up the Remote Executors
By default, TFF executes all computations locally. In this step we tell TFF to connect to the Kubernetes service we set up above. Be sure to replace the placeholder IP address below with the address of your service.
import grpc

ip_address = '0.0.0.0'  # replace with the external IP of your load balancer
port = 80

# One channel per client, so the 10 clients can be dispatched in parallel.
channels = [grpc.insecure_channel(f'{ip_address}:{port}') for _ in range(10)]
tff.backends.native.set_remote_python_execution_context(channels)
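Before starting a full training run, it can be worth confirming that the workers are reachable. The following is a minimal sketch under the assumption that the remote context above is active; the trivial computation is delegated to the remote executors:
@tff.tf_computation(tf.int32)
def add_one(x):
  return x + 1

add_one(2)  # should return 3 if the workers are reachable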
Run Training
evaluate()
Round 0: loss 4.370407581329346, round time 4.201097726821899
Round 1: loss 4.1407670974731445, round time 3.3283166885375977
Round 2: loss 3.865147590637207, round time 3.098310947418213
Round 3: loss 3.534019708633423, round time 3.1565616130828857
Round 4: loss 3.272688388824463, round time 3.175067663192749
Round 5: loss 2.935391664505005, round time 3.008434534072876
Round 6: loss 2.7399251461029053, round time 3.31435227394104
Round 7: loss 2.5054931640625, round time 3.4411356449127197
Round 8: loss 2.290508985519409, round time 3.158798933029175
Round 9: loss 2.1194536685943604, round time 3.1348156929016113