Google I/O एक लपेट है! TensorFlow सत्रों पर पकड़ बनाएं सत्र देखें

अपनी खुद की फेडरेटेड लर्निंग एल्गोरिथम का निर्माण

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें

हमारे शुरू करने से पहले

शुरू करने से पहले, कृपया यह सुनिश्चित करने के लिए निम्नलिखित चलाएँ कि आपका परिवेश सही ढंग से सेटअप है। आप एक ग्रीटिंग दिखाई नहीं देता है, का संदर्भ लें स्थापना निर्देश के लिए गाइड।

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import tensorflow as tf
import tensorflow_federated as tff

में छवि वर्गीकरण और पाठ पीढ़ी ट्यूटोरियल, हम कैसे संघीय लर्निंग (FL) के लिए मॉडल और डेटा पाइपलाइन स्थापित करने के लिए सीखा है, और के माध्यम से फ़ेडरेटेड प्रशिक्षण प्रदर्शन किया tff.learning TFF के एपीआई परत।

जब FL अनुसंधान की बात आती है तो यह केवल हिमशैल का सिरा होता है। इस ट्यूटोरियल में, हम कैसे टाल बिना फ़ेडरेटेड सीखने वाले एल्गोरिदम लागू करने के लिए विचार-विमर्श tff.learning एपीआई। हम निम्नलिखित को पूरा करने का लक्ष्य रखते हैं:

लक्ष्य:

  • फ़ेडरेटेड लर्निंग एल्गोरिदम की सामान्य संरचना को समझें।
  • TFF संघीय कोर का अन्वेषण करें।
  • फ़ेडरेटेड एवरेजिंग को सीधे लागू करने के लिए फ़ेडरेटेड कोर का उपयोग करें।

जबकि इस ट्यूटोरियल आत्म निहित है, हम पहले पढ़ने की सलाह छवि वर्गीकरण और पाठ पीढ़ी ट्यूटोरियल।

इनपुट डेटा तैयार करना

हम पहले TFF में शामिल EMNIST डेटासेट को लोड और प्रीप्रोसेस करते हैं। अधिक जानकारी के लिए, छवि वर्गीकरण ट्यूटोरियल।

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

ताकि हमारे मॉडल में डाटासेट को खिलाने के लिए में, हम डेटा समतल, और फार्म की एक टपल में प्रत्येक उदाहरण कन्वर्ट (flattened_image_vector, label)

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

अब हम ग्राहकों की एक छोटी संख्या का चयन करते हैं, और उपरोक्त प्रीप्रोसेसिंग को उनके डेटासेट पर लागू करते हैं।

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

मॉडल तैयार करना

हम में के रूप में ही मॉडल का उपयोग छवि वर्गीकरण ट्यूटोरियल। यह मॉडल (के माध्यम से कार्यान्वित किया tf.keras ) एक एकल छिपा परत, एक softmax परत द्वारा पीछा किया है।

def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

आदेश TFF में इस मॉडल का उपयोग करने में, हम एक के रूप में Keras मॉडल लपेट tff.learning.Model । यह हमारे मॉडल के प्रदर्शन करने के लिए अनुमति देता है फॉरवर्ड पास TFF के भीतर, और निकालने मॉडल आउटपुट । अधिक जानकारी के लिए, यह भी देखें छवि वर्गीकरण ट्यूटोरियल।

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

हम इस्तेमाल किया जबकि tf.keras एक बनाने के लिए tff.learning.Model , TFF और अधिक सामान्य मॉडल का समर्थन करता है। इन मॉडलों में मॉडल भार को कैप्चर करने वाली निम्नलिखित प्रासंगिक विशेषताएं हैं:

  • trainable_variables : tensors trainable परतों के लिए इसी का एक iterable।
  • non_trainable_variables : tensors गैर trainable परतों के लिए इसी का एक iterable।

हमारे प्रयोजनों के लिए, हम केवल का उपयोग करेगा trainable_variables । (जैसा कि हमारे मॉडल में केवल वही है!)

अपना स्वयं का फ़ेडरेटेड लर्निंग एल्गोरिथम बनाना

जबकि tff.learning एपीआई एक संघीय औसत का के कई वेरिएंट बनाने की अनुमति देता है, वहाँ अन्य फ़ेडरेटेड एल्गोरिदम कि इस सांचे में बड़े करीने से फिट नहीं है कर रहे हैं। उदाहरण के लिए, आप इस तरह के रूप नियमितीकरण, कतरन, या अधिक जटिल एल्गोरिदम जोड़ सकते हैं फ़ेडरेटेड GAN प्रशिक्षण । तुम भी बजाय में रुचि हो किया जा सकता है फ़ेडरेटेड एनालिटिक्स

इन अधिक उन्नत एल्गोरिदम के लिए, हमें TFF का उपयोग करके अपना स्वयं का कस्टम एल्गोरिथम लिखना होगा। कई मामलों में, फ़ेडरेटेड एल्गोरिदम में 4 मुख्य घटक होते हैं:

  1. एक सर्वर-से-क्लाइंट प्रसारण चरण।
  2. एक स्थानीय क्लाइंट अद्यतन चरण।
  3. क्लाइंट-टू-सर्वर अपलोड चरण।
  4. एक सर्वर अद्यतन चरण।

TFF में, हम आम तौर पर एक के रूप में फ़ेडरेटेड एल्गोरिदम का प्रतिनिधित्व tff.templates.IterativeProcess (जो हम सिर्फ एक के रूप में उल्लेख IterativeProcess भर)। यह एक वर्ग है कि होता है initialize और next कार्य करता है। इधर, initialize सर्वर प्रारंभ करने में प्रयोग किया जाता है, और next फ़ेडरेटेड एल्गोरिथ्म के एक संचार दौर प्रदर्शन करेंगे। आइए एक रूपरेखा लिखें कि FedAvg के लिए हमारी पुनरावृत्ति प्रक्रिया कैसी दिखनी चाहिए।

सबसे पहले, हम एक इनिशियलाइज़ समारोह है कि बस एक बनाता है tff.learning.Model , और इसके trainable वजन देता है।

def initialize_fn():
  model = model_fn()
  return model.trainable_variables

यह फ़ंक्शन अच्छा दिखता है, लेकिन जैसा कि हम बाद में देखेंगे, हमें इसे "TFF गणना" बनाने के लिए एक छोटा संशोधन करने की आवश्यकता होगी।

हम यह भी स्केच करना चाहते next_fn

def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

हम इन चार घटकों को अलग-अलग लागू करने पर ध्यान देंगे। हम पहले उन हिस्सों पर ध्यान केंद्रित करते हैं जिन्हें शुद्ध TensorFlow में लागू किया जा सकता है, अर्थात् क्लाइंट और सर्वर अपडेट चरण।

TensorFlow Blocks

ग्राहक अद्यतन

हम अपने प्रयोग करेंगे tff.learning.Model मूलतः एक ही तरीका है कि आप एक TensorFlow मॉडल को प्रशिक्षित करेंगे में ग्राहक प्रशिक्षण करना है। विशेष रूप से, हम का उपयोग करेगा tf.GradientTape डेटा के बैच पर ढाल की गणना करने के लिए, तो एक का उपयोग कर इन ढाल लागू client_optimizer । हम केवल प्रशिक्षित वजन पर ध्यान केंद्रित करते हैं।

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights

सर्वर अपडेट

FedAvg के लिए सर्वर अपडेट क्लाइंट अपडेट की तुलना में आसान है। हम "वेनिला" फ़ेडरेटेड एवरेज को लागू करेंगे, जिसमें हम सर्वर मॉडल वेट को क्लाइंट मॉडल वेट के औसत से बदल देते हैं। फिर से, हम केवल प्रशिक्षित वजन पर ध्यान केंद्रित करते हैं।

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights

टुकड़ा बस वापस लौट कर सरल किया जा सकता है mean_client_weights । हालांकि, संघीय औसत का उपयोग की अधिक उन्नत कार्यान्वयन mean_client_weights इस तरह के गति या adaptivity के रूप में और अधिक परिष्कृत तकनीक, के साथ।

चैलेंज: का एक संस्करण को लागू server_update कि सर्वर भार को अद्यतन करता model_weights और mean_client_weights के मध्य किया जाना है। (नोट: "मध्य" दृष्टिकोण इस तरह की पर हाल ही में काम के अनुरूप है अग्रावलोकन अनुकूलक !)।

अब तक, हमने केवल शुद्ध TensorFlow कोड लिखा है। यह डिज़ाइन द्वारा है, क्योंकि TFF आपको उस TensorFlow कोड का अधिक उपयोग करने की अनुमति देता है जिससे आप पहले से परिचित हैं। हालांकि, अब हम आर्केस्ट्रा तर्क यह है कि, तर्क यह है कि तय कर ग्राहक के लिए क्या सर्वर प्रसारण, और क्या ग्राहक अपलोड सर्वर को निर्दिष्ट किया जाना है।

यह TFF संघीय कोर की आवश्यकता होगी।

फ़ेडरेटेड कोर का परिचय

संघीय कोर (एफसी) निचले स्तर इंटरफ़ेस के लिए नींव के रूप में सेवा का एक सेट है tff.learning एपीआई। हालाँकि, ये इंटरफेस सीखने तक सीमित नहीं हैं। वास्तव में, उनका उपयोग विश्लेषिकी और वितरित डेटा पर कई अन्य संगणनाओं के लिए किया जा सकता है।

उच्च स्तर पर, फ़ेडरेटेड कोर एक विकास वातावरण है जो वितरित संचार ऑपरेटरों (जैसे वितरित रकम और प्रसारण) के साथ TensorFlow कोड को संयोजित करने के लिए कॉम्पैक्ट रूप से व्यक्त प्रोग्राम तर्क को सक्षम बनाता है। लक्ष्य शोधकर्ताओं और चिकित्सकों को उनके सिस्टम में वितरित संचार पर स्पष्ट नियंत्रण देना है, बिना सिस्टम कार्यान्वयन विवरण (जैसे पॉइंट-टू-पॉइंट नेटवर्क संदेश एक्सचेंजों को निर्दिष्ट करना) की आवश्यकता के बिना।

एक महत्वपूर्ण बिंदु यह है कि TFF को गोपनीयता-संरक्षण के लिए डिज़ाइन किया गया है। इसलिए, यह केंद्रीकृत सर्वर स्थान पर डेटा के अवांछित संचय को रोकने के लिए, जहां डेटा रहता है, उस पर स्पष्ट नियंत्रण की अनुमति देता है।

फ़ेडरेटेड डेटा

TFF में एक प्रमुख अवधारणा "संघीय डेटा" है, जो एक वितरित सिस्टम (जैसे क्लाइंट डेटासेट, या सर्वर मॉडल वज़न) में उपकरणों के एक समूह में होस्ट किए गए डेटा आइटम के संग्रह को संदर्भित करता है। हम एक भी फ़ेडरेटेड मूल्य के रूप में सभी उपकरणों के डेटा आइटम का पूरा संग्रह मॉडल।

उदाहरण के लिए, मान लें कि हमारे पास क्लाइंट डिवाइस हैं जिनमें से प्रत्येक में सेंसर के तापमान का प्रतिनिधित्व करने वाला एक फ्लोट होता है। हम द्वारा एक फ़ेडरेटेड नाव के रूप में यह प्रतिनिधित्व कर सकता है

federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)

संघीय प्रकार एक प्रकार से निर्दिष्ट कर रहे हैं T अपने सदस्य घटकों में से (उदाहरण के लिए। tf.float32 ) और एक समूह G उपकरणों की। हम ऐसे मामलों में जहां पर ध्यान दिया जाएगा G या तो है tff.CLIENTS या tff.SERVER । इस तरह की एक फ़ेडरेटेड प्रकार के रूप में प्रस्तुत किया जाता है {T}@G नीचे दिखाए गए।

str(federated_float_on_clients)
'{float32}@CLIENTS'

हम प्लेसमेंट की इतनी परवाह क्यों करते हैं? TFF का एक प्रमुख लक्ष्य लेखन कोड को सक्षम करना है जिसे वास्तविक वितरित सिस्टम पर तैनात किया जा सकता है। इसका मतलब यह है कि यह तर्क करना महत्वपूर्ण है कि उपकरणों के कौन से सबसेट किस कोड को निष्पादित करते हैं, और जहां डेटा के विभिन्न टुकड़े रहते हैं।

TFF तीन बातों पर केंद्रित है: डेटा, जहां डाटा रखा गया है, और डेटा कैसे तब्दील किया जा रहा है। पहले दो, फ़ेडरेटेड प्रकार में समाहित हैं, जबकि पिछले फ़ेडरेटेड संगणना में समझाया गया है।

संघीय संगणना

TFF एक जोरदार टाइप कार्यात्मक प्रोग्रामिंग वातावरण जिसका बुनियादी इकाइयों फ़ेडरेटेड संगणना कर रहे हैं। ये तर्क के टुकड़े हैं जो फ़ेडरेटेड मानों को इनपुट के रूप में स्वीकार करते हैं, और फ़ेडरेटेड मानों को आउटपुट के रूप में वापस करते हैं।

उदाहरण के लिए, मान लीजिए कि हम अपने क्लाइंट सेंसर पर तापमान को औसत करना चाहते हैं। हम निम्नलिखित को परिभाषित कर सकते हैं (हमारे फ़ेडरेटेड फ्लोट का उपयोग करके):

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

आप पूछ सकते हैं, कैसे से अलग है tf.function TensorFlow में डेकोरेटर? कुंजी जवाब यह है कि द्वारा बनाया गया कोड है tff.federated_computation न TensorFlow है और न ही अजगर कोड है, यह एक आंतरिक मंच स्वतंत्र गोंद भाषा में एक वितरित प्रणाली के एक विनिर्देश है।

हालांकि यह जटिल लग सकता है, आप टीएफएफ गणनाओं को अच्छी तरह से परिभाषित प्रकार के हस्ताक्षर वाले कार्यों के रूप में सोच सकते हैं। इस प्रकार के हस्ताक्षर सीधे पूछे जा सकते हैं।

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

यह tff.federated_computation फ़ेडरेटेड प्रकार के तर्कों को स्वीकार करता है {float32}@CLIENTS , और फ़ेडरेटेड प्रकार के रिटर्न मान {float32}@SERVER । फ़ेडरेटेड कंप्यूटेशंस सर्वर से क्लाइंट, क्लाइंट से क्लाइंट या सर्वर से सर्वर तक भी जा सकते हैं। फ़ेडरेटेड कंप्यूटेशंस को सामान्य कार्यों की तरह भी बनाया जा सकता है, जब तक कि उनके प्रकार के हस्ताक्षर मेल खाते हैं।

विकास का समर्थन करने के लिए, आप एक TFF आह्वान करने के लिए अनुमति देता है tff.federated_computation एक अजगर समारोह के रूप में। उदाहरण के लिए, हम कॉल कर सकते हैं

get_average_temperature([68.5, 70.3, 69.8])
69.53334

गैर-उत्सुक संगणना और TensorFlow

जागरूक होने के लिए दो प्रमुख प्रतिबंध हैं। सबसे पहले, जब अजगर दुभाषिया एक का सामना करना पड़ता tff.federated_computation डेकोरेटर, समारोह एक बार पता लगाया और भविष्य में उपयोग के लिए धारावाहिक है। फ़ेडरेटेड लर्निंग की विकेंद्रीकृत प्रकृति के कारण, यह भविष्य का उपयोग कहीं और हो सकता है, जैसे कि दूरस्थ निष्पादन वातावरण। इसलिए, TFF संगणना मौलिक गैर उत्सुक हैं। यह व्यवहार कुछ हद तक की है कि के अनुरूप है tf.function TensorFlow में डेकोरेटर।

दूसरा, एक फ़ेडरेटेड गणना केवल (जैसे फ़ेडरेटेड ऑपरेटरों हो सकते हैं tff.federated_mean ), वे TensorFlow संचालन नहीं हो सकते। TensorFlow कोड के साथ सजाया ब्लॉक तक सीमित रखना चाहिए tff.tf_computation । अधिकांश साधारण TensorFlow कोड सीधे इस तरह के निम्नलिखित समारोह है कि एक नंबर लेता है और शामिल करेगा, सजाया जा सकता 0.5 यह करने के लिए।

@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

ये भी प्रकार हस्ताक्षर है, लेकिन प्लेसमेंट के बिना। उदाहरण के लिए, हम कॉल कर सकते हैं

str(add_half.type_signature)
'(float32 -> float32)'

यहाँ हम बीच एक महत्वपूर्ण अंतर देख tff.federated_computation और tff.tf_computation । पूर्व में स्पष्ट प्लेसमेंट हैं, जबकि बाद वाले में नहीं है।

हम उपयोग कर सकते हैं tff.tf_computation प्लेसमेंट को निर्दिष्ट करके फ़ेडरेटेड संगणना में ब्लॉक। आइए एक फ़ंक्शन बनाएं जो आधा जोड़ता है, लेकिन केवल क्लाइंट पर फ़ेडरेटेड फ़्लोट्स के लिए। हम का उपयोग करके ऐसा कर सकते हैं tff.federated_map , किसी दिए गए लागू होता है tff.tf_computation , जबकि नियुक्ति संरक्षण।

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

इस समारोह के लगभग समान है add_half , सिवाय इसके कि यह केवल पर नियुक्ति के साथ मान स्वीकार tff.CLIENTS उसी नियुक्ति के साथ रिटर्न मूल्यों, और। हम इसे इसके टाइप सिग्नेचर में देख सकते हैं:

str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'

सारांश:

  • TFF फ़ेडरेटेड मूल्यों पर कार्य करता है।
  • प्रत्येक फ़ेडरेटेड मूल्य एक प्रकार (उदाहरण के लिए। के साथ एक फ़ेडरेटेड प्रकार होता है, tf.float32 ) और एक प्लेसमेंट (जैसे। tff.CLIENTS )।
  • संघीय मूल्यों फ़ेडरेटेड संगणना, जो के साथ सजाया जाना चाहिए का उपयोग कर तब्दील किया जा सकता tff.federated_computation और एक फ़ेडरेटेड प्रकार हस्ताक्षर।
  • TensorFlow कोड के साथ ब्लॉक में समाहित किया जाना चाहिए tff.tf_computation सज्जाकार।
  • फिर इन ब्लॉकों को फ़ेडरेटेड कंप्यूटेशंस में शामिल किया जा सकता है।

अपना स्वयं का फ़ेडरेटेड लर्निंग एल्गोरिथम बनाना, फिर से देखना

अब जब हमें फ़ेडरेटेड कोर की एक झलक मिल गई है, तो हम अपना फ़ेडरेटेड लर्निंग एल्गोरिथम बना सकते हैं। याद रखें कि इसके बाद के संस्करण, हम एक परिभाषित initialize_fn और next_fn हमारे एल्गोरिथ्म के लिए। next_fn का उपयोग करेगा client_update और server_update हम शुद्ध TensorFlow कोड का उपयोग कर परिभाषित किया।

तथापि, हमारी एल्गोरिथ्म एक फ़ेडरेटेड गणना करने के लिए, हम दोनों की आवश्यकता होगी next_fn और initialize_fn करने के लिए प्रत्येक एक हो tff.federated_computation

TensorFlow फ़ेडरेटेड ब्लॉक

आरंभीकरण गणना बनाना

इनिशियलाइज़ समारोह काफी सरल हो जाएगा: हम एक मॉडल का उपयोग कर पैदा करेगा model_fn । हालांकि, याद है कि हम का उपयोग कर हमारे TensorFlow कोड को अलग करना होगा tff.tf_computation

@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

हम तो का उपयोग कर एक फ़ेडरेटेड गणना में सीधे इस पारित कर सकते हैं tff.federated_value

@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)

बनाना next_fn

अब हम वास्तविक एल्गोरिथम लिखने के लिए अपने क्लाइंट और सर्वर अपडेट कोड का उपयोग करते हैं। हम पहले हमारे बदल जाएगी client_update एक में tff.tf_computation है कि एक ग्राहक डेटासेट और सर्वर भार स्वीकार करता है, और एक अद्यतन ग्राहक वजन टेन्सर आउटपुट।

हमें अपने फ़ंक्शन को ठीक से सजाने के लिए संबंधित प्रकारों की आवश्यकता होगी। सौभाग्य से, सर्वर भार के प्रकार को सीधे हमारे मॉडल से निकाला जा सकता है।

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

आइए डेटासेट प्रकार के हस्ताक्षर को देखें। याद रखें कि हमने 28 गुणा 28 छवियां लीं (पूर्णांक लेबल के साथ) और उन्हें समतल किया।

str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'

हम अपने का उपयोग करके मॉडल वेट प्रकार निकाल सकते हैं server_init ऊपर कार्य करते हैं।

model_weights_type = server_init.type_signature.result

टाइप सिग्नेचर की जांच करते हुए, हम अपने मॉडल के आर्किटेक्चर को देख पाएंगे!

str(model_weights_type)
'<float32[784,10],float32[10]>'

अब हम अपने बना सकते हैं tff.tf_computation ग्राहक अद्यतन के लिए।

@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

tff.tf_computation सर्वर अद्यतन के संस्करण एक समान तरीके से परिभाषित किया जा सकता, प्रकार हम पहले से ही निकाला है का उपयोग कर।

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

पिछले है, लेकिन कम से कम नहीं है, हम बनाने की जरूरत tff.federated_computation कि यह सब एक साथ लाता है। यह समारोह दो फ़ेडरेटेड मूल्यों, एक सर्वर वजन करने के लिए इसी (नियुक्ति के साथ स्वीकार करेंगे tff.SERVER ), और अन्य ग्राहक डेटासेट के लिए इसी (नियुक्ति के साथ tff.CLIENTS )।

ध्यान दें कि इन दोनों प्रकारों को ऊपर परिभाषित किया गया था! हम बस उन्हें उचित स्थान का उपयोग कर देने की आवश्यकता tff.FederatedType

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

FL एल्गोरिथम के 4 तत्व याद रखें?

  1. एक सर्वर-से-क्लाइंट प्रसारण चरण।
  2. एक स्थानीय क्लाइंट अद्यतन चरण।
  3. क्लाइंट-टू-सर्वर अपलोड चरण।
  4. एक सर्वर अद्यतन चरण।

अब जब हमने उपरोक्त का निर्माण कर लिया है, तो प्रत्येक भाग को TFF कोड की एकल पंक्ति के रूप में संक्षिप्त रूप से दर्शाया जा सकता है। यह सरलता इसलिए है कि हमें फ़ेडरेटेड प्रकार जैसी चीज़ों को निर्दिष्ट करने के लिए अतिरिक्त सावधानी बरतनी पड़ी!

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

अब हम एक है tff.federated_computation दोनों एल्गोरिथ्म प्रारंभ, और एल्गोरिथ्म के एक कदम को चलाने के लिए के लिए। हमारे एल्गोरिथ्म समाप्त करने के लिए, हम में इन पारित tff.templates.IterativeProcess

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

के प्रकार के हस्ताक्षर पर आइए नज़र initialize और next हमारे सतत प्रक्रिया का कार्य करता है।

str(federated_algorithm.initialize.type_signature)
'( -> <float32[784,10],float32[10]>@SERVER)'

यह तथ्य यह है कि दर्शाता है federated_algorithm.initialize नो आर्ग समारोह है कि रिटर्न एक एकल परत मॉडल (एक 784-दर-10 वजन मैट्रिक्स के साथ, और 10 पूर्वाग्रह इकाइयों)।

str(federated_algorithm.next.type_signature)
'(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

यहाँ, हम देखते हैं कि federated_algorithm.next एक सर्वर मॉडल और ग्राहक डेटा, और रिटर्न एक अद्यतन सर्वर मॉडल स्वीकार करता है।

एल्गोरिथ्म का मूल्यांकन

आइए कुछ दौर चलाएं, और देखें कि नुकसान कैसे बदलता है। सबसे पहले, हम एक मूल्यांकन समारोह केंद्रीकृत दृष्टिकोण दूसरे ट्यूटोरियल में चर्चा का उपयोग कर परिभाषित करेगा।

हम पहले एक केंद्रीकृत मूल्यांकन डेटासेट बनाते हैं, और फिर उसी प्रीप्रोसेसिंग को लागू करते हैं जिसका उपयोग हमने प्रशिक्षण डेटा के लिए किया था।

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

अगला, हम एक फ़ंक्शन लिखते हैं जो सर्वर स्थिति को स्वीकार करता है, और परीक्षण डेटासेट पर मूल्यांकन करने के लिए केरस का उपयोग करता है। आप के साथ परिचित हैं, तो tf.Keras , इस होगा सब नज़र परिचित है, हालांकि टिप्पणी के उपयोग set_weights !

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

अब, आइए हमारे एल्गोरिदम को प्रारंभ करें और परीक्षण सेट पर मूल्यांकन करें।

server_state = federated_algorithm.initialize()
evaluate(server_state)
2042/2042 [==============================] - 2s 767us/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027

आइए कुछ राउंड के लिए प्रशिक्षण लें और देखें कि क्या कुछ बदलता है।

for round in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
2042/2042 [==============================] - 2s 738us/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980

हम नुकसान समारोह में थोड़ी कमी देखते हैं। जबकि छलांग छोटी है, हमने केवल 15 प्रशिक्षण राउंड किए हैं, और ग्राहकों के एक छोटे से सबसेट पर। बेहतर परिणाम देखने के लिए हमें हजारों नहीं तो सैकड़ों चक्कर लगाने पड़ सकते हैं।

हमारे एल्गोरिथ्म को संशोधित करना

इस बिंदु पर, आइए रुकें और सोचें कि हमने क्या हासिल किया है। हमने TFF के फ़ेडरेटेड कोर से फ़ेडरेटेड कंप्यूटेशंस के साथ शुद्ध TensorFlow कोड (क्लाइंट और सर्वर अपडेट के लिए) को मिलाकर सीधे फ़ेडरेटेड एवरेजिंग को लागू किया है।

अधिक परिष्कृत सीखने के लिए, हम जो ऊपर है उसे आसानी से बदल सकते हैं। विशेष रूप से, ऊपर दिए गए शुद्ध TF कोड को संपादित करके, हम यह बदल सकते हैं कि क्लाइंट प्रशिक्षण कैसे करता है, या सर्वर अपने मॉडल को कैसे अपडेट करता है।

चैलेंज: जोड़े ढाल कतरन को client_update कार्य करते हैं।

अगर हम बड़े बदलाव करना चाहते हैं, तो हम सर्वर स्टोर भी कर सकते हैं और अधिक डेटा प्रसारित कर सकते हैं। उदाहरण के लिए, सर्वर क्लाइंट सीखने की दर को भी स्टोर कर सकता है, और इसे समय के साथ क्षय कर सकता है! ध्यान दें कि यह में इस्तेमाल किया प्रकार हस्ताक्षर करने के लिए परिवर्तन की आवश्यकता होगी tff.tf_computation ऊपर कहता है।

कठिन चुनौती: ग्राहकों पर दर क्षय सीखने के साथ लागू संघीय औसत।

इस बिंदु पर, आप महसूस करना शुरू कर सकते हैं कि इस ढांचे में आप जो लागू कर सकते हैं उसमें कितना लचीलापन है। विचारों (ऊपर कठिन चुनौती का जवाब सहित) के लिए आप के लिए स्रोत-कोड देख सकते हैं tff.learning.build_federated_averaging_process , या विभिन्न की जाँच अनुसंधान परियोजनाओं TFF का उपयोग कर।