High-performance Simulation with Kubernetes

This tutorial describes how to set up a high-performance simulation using a TFF runtime running on Kubernetes. The model is the same as in the previous tutorial, High-performance simulations with TFF. The only difference is that here we use a worker pool instead of a local executor.

This tutorial uses Google Cloud's GKE to create the Kubernetes cluster, but every step after cluster creation works with any Kubernetes installation.

Set up the TFF environment

!pip install --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio

import nest_asyncio
nest_asyncio.apply()  # Allow nested event loops, which notebook environments need.

Define the model to train

import collections
import time

import tensorflow as tf
import tensorflow_federated as tff

source, _ = tff.simulation.datasets.emnist.load_data()

def map_fn(example):
  return collections.OrderedDict(
      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label'])

def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).batch(20).map(map_fn)

train_data = [client_data(n) for n in range(10)]
input_spec = train_data[0].element_spec

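def model_fn():
  # A single softmax layer over the flattened 784-pixel images.
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])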

trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02))

def evaluate(num_rounds=10):
  state = trainer.initialize()
  for round in range(num_rounds):
    t1 = time.time()
    state, metrics = trainer.next(state, train_data)
    t2 = time.time()
    print('Round {}: loss {}, round time {}'.format(round, metrics.loss, t2 - t1))
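For intuition about what each round aggregates, here is a conceptual sketch of the weighted averaging step at the heart of federated averaging. In the standard algorithm, the server combines client results as a mean weighted by each client's example count. This is plain Python over lists of floats, not TFF's implementation; the function name `federated_average` and the toy numbers are made up for illustration.

```python
def federated_average(client_weights, client_sizes):
    """Return the example-weighted mean of per-client weight vectors.

    Hypothetical helper for illustration only; not part of TFF.
    client_weights: list of equal-length lists of floats, one per client.
    client_sizes: number of training examples held by each client.
    """
    total = sum(client_sizes)
    if total <= 0:
        raise ValueError('client_sizes must sum to a positive number')
    dim = len(client_weights[0])
    return [
        sum(w[i] * n for w, n in zip(client_weights, client_sizes)) / total
        for i in range(dim)
    ]


# Toy example: the second client holds three times as much data,
# so the average is pulled toward its weights.
averaged = federated_average([[1.0, 2.0], [3.0, 4.0]], [1, 3])
print(averaged)
```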

Set up the remote executors

By default, TFF executes all computations locally. In this step, we tell TFF to connect to the Kubernetes services set up above. Be sure to copy the IP address of your service here.

import grpc

ip_address = ''  # Paste the IP address of your service here.
port = 80

channels = [grpc.insecure_channel(f'{ip_address}:{port}') for _ in range(10)]

tff.backends.native.set_remote_execution_context(channels, rpc_mode='STREAMING')
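Because `ip_address` starts out empty, it is easy to create channels that point nowhere and only notice once training hangs or fails. A minimal sketch of a pre-flight check follows, using only the standard library. The helper name `worker_targets` is hypothetical and is not part of TFF or gRPC; it only builds and validates the `host:port` strings that would be passed to `grpc.insecure_channel`.

```python
import ipaddress


def worker_targets(ip_address, port, num_channels=10):
    """Return `num_channels` identical 'host:port' target strings.

    Hypothetical helper, not part of TFF or gRPC. Raises ValueError if the
    address is empty or malformed, or if the port is out of range.
    """
    if not ip_address:
        raise ValueError('ip_address is empty; paste the IP of your service')
    addr = ipaddress.ip_address(ip_address)  # ValueError on malformed input
    if not 0 < port < 65536:
        raise ValueError(f'port out of range: {port}')
    # IPv6 literals need brackets when combined with a port.
    host = f'[{addr}]' if addr.version == 6 else str(addr)
    return [f'{host}:{port}' for _ in range(num_channels)]


# Example with a placeholder private address (an assumption, not a real service).
print(worker_targets('10.0.0.1', 80, num_channels=2))
```

Each returned string could then be passed to `grpc.insecure_channel` in place of the inline f-string above.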

Run training

evaluate()

Round 0: loss 4.370407581329346, round time 4.201097726821899
Round 1: loss 4.1407670974731445, round time 3.3283166885375977
Round 2: loss 3.865147590637207, round time 3.098310947418213
Round 3: loss 3.534019708633423, round time 3.1565616130828857
Round 4: loss 3.272688388824463, round time 3.175067663192749
Round 5: loss 2.935391664505005, round time 3.008434534072876
Round 6: loss 2.7399251461029053, round time 3.31435227394104
Round 7: loss 2.5054931640625, round time 3.4411356449127197
Round 8: loss 2.290508985519409, round time 3.158798933029175
Round 9: loss 2.1194536685943604, round time 3.1348156929016113
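
To compare runs, for example with different worker pool sizes, it can help to reduce output like the above to a couple of numbers. The following is a small sketch that parses lines in the format printed by `evaluate` and reports the final loss and the mean round time. The helper name `summarize_rounds` is hypothetical, and the regular expression assumes plain decimal values, as in the output shown here.

```python
import re
import statistics

# Matches lines such as: "Round 3: loss 3.53, round time 3.15"
ROUND_RE = re.compile(r'Round (\d+): loss ([\d.]+), round time ([\d.]+)')


def summarize_rounds(log_text):
    """Return (final_loss, mean_round_time) parsed from evaluate() output.

    Hypothetical helper for illustration; assumes plain decimal numbers.
    """
    rows = [(int(r), float(loss), float(t))
            for r, loss, t in ROUND_RE.findall(log_text)]
    if not rows:
        raise ValueError('no round lines found')
    final_loss = rows[-1][1]
    mean_time = statistics.mean(t for _, _, t in rows)
    return final_loss, mean_time


# Made-up two-round log, used only to show the call.
sample = (
    'Round 0: loss 4.0, round time 2.0\n'
    'Round 1: loss 3.0, round time 4.0\n'
)
print(summarize_rounds(sample))
```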