TFF의 랜덤 노이즈 생성

이 자습서에서는 TFF에서 무작위 노이즈 생성에 대한 권장 모범 사례에 대해 설명합니다. 무작위 잡음 생성은 연합 학습 알고리즘(예: 차등 개인 정보 보호)에서 많은 개인 정보 보호 기술의 중요한 구성 요소입니다.

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

시작하기 전에

먼저 노트북이 관련 구성 요소가 컴파일된 백엔드에 연결되어 있는지 확인합니다.

!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio

import nest_asyncio
nest_asyncio.apply()
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

다음 "Hello World" 예제를 실행하여 TFF 환경이 올바르게 설정되었는지 확인하십시오. 그것이 작동하지 않는 경우를 참조하십시오 설치 지침은 가이드.

@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

클라이언트의 무작위 소음

클라이언트에 대한 노이즈의 필요성은 일반적으로 동일한 노이즈와 iid 노이즈의 두 가지 경우로 나뉩니다.

  • 동일한 소음의 경우, 권장 패턴은 클라이언트에 방송 서버의 씨앗을 유지하고 사용하는 것입니다 tf.random.stateless 소음 생성하는 기능.
  • iid 노이즈의 경우 tf.random.<distribution> 기능을 피하라는 TF의 권장 사항에 따라 from_non_deterministic_state로 클라이언트에서 초기화된 tf.random.Generator를 사용합니다.

클라이언트 동작은 서버와 다릅니다(나중에 설명할 함정이 없음). 각 클라이언트가 자체 계산 그래프를 만들고 자체 기본 시드를 초기화하기 때문입니다.

클라이언트에 동일한 소음

# Set to use 10 clients.
tff.backends.native.set_local_python_execution_context(num_clients=10)

@tff.tf_computation
def noise_from_seed(seed):
  return tf.random.stateless_normal((), seed=seed)

seed_type_at_server = tff.type_at_server(tff.to_type((tf.int64, [2])))

@tff.federated_computation(seed_type_at_server)
def get_random_min_and_max_deterministic(seed):
  # Broadcast seed to all clients.
  seed_on_clients = tff.federated_broadcast(seed)

  # Clients generate noise from seed deterministicly.
  noise_on_clients = tff.federated_map(noise_from_seed, seed_on_clients)

  # Aggregate and return the min and max of the values generated on clients.
  min = tff.aggregators.federated_min(noise_on_clients)
  max = tff.aggregators.federated_max(noise_on_clients)
  return min, max

seed = tf.constant([1, 1], dtype=tf.int64)
min, max = get_random_min_and_max_deterministic(seed)
assert min == max
print(f'Seed: {seed.numpy()}. All clients sampled value {min:8.3f}.')

seed += 1
min, max = get_random_min_and_max_deterministic(seed)
assert min == max
print(f'Seed: {seed.numpy()}. All clients sampled value {min:8.3f}.')
Seed: [1 1]. All clients sampled value    1.665.
Seed: [2 2]. All clients sampled value   -0.219.

클라이언트에 대한 독립적인 소음

@tff.tf_computation
def nondeterministic_noise():
  gen = tf.random.Generator.from_non_deterministic_state()
  return gen.normal(())

@tff.federated_computation(seed_type_at_server)
def get_random_min_and_max_nondeterministic(seed):
  noise_on_clients = tff.federated_eval(nondeterministic_noise, tff.CLIENTS)
  min = tff.aggregators.federated_min(noise_on_clients)
  max = tff.aggregators.federated_max(noise_on_clients)
  return min, max

min, max = get_random_min_and_max_nondeterministic(seed)
assert min != max
print(f'Values differ across clients. {min:8.3f},{max:8.3f}.')

new_min, new_max = get_random_min_and_max_nondeterministic(seed)
assert new_min != new_max
assert new_min != min and new_max != max
print(f'Values differ across rounds.  {new_min:8.3f},{new_max:8.3f}.')
Values differ across clients.   -1.810,   1.079.
Values differ across rounds.    -1.205,   0.851.

서버의 랜덤 노이즈

낙담 사용 : 직접 사용 tf.random.normal

TF1.x API를 같은 tf.random.normal 랜덤 노이즈 생성을위한 강력에 따라 TF2에 낙담 TF에서 랜덤 노이즈 발생 튜토리얼 . 이러한 API가 함께 사용될 때 놀라운 동작이 발생할 수 있습니다 tf.functiontf.random.set_seed . 예를 들어 다음 코드는 각 호출에서 동일한 값을 생성합니다. 이 놀라운 동작은 TF 예상되며, 설명은에서 찾을 수 의 문서 tf.random.set_seed .

tf.random.set_seed(1)

@tf.function
def return_one_noise(_):
  return tf.random.normal([])

n1=return_one_noise(1)
n2=return_one_noise(2) 
assert n1 == n2
print(n1.numpy(), n2.numpy())
0.3052047 0.3052047

TFF에서는 상황이 약간 다릅니다. 우리가 소음 발생을 포장하는 경우 tff.tf_computation 대신 tf.function , 비 결정적 랜덤 노이즈가 생성됩니다. 우리가이 코드 조각을 여러 번 실행하는 경우, 다른 세트 (n1, n2) 때마다 생성됩니다. TFF에 대한 전역 무작위 시드를 설정하는 쉬운 방법은 없습니다.

tf.random.set_seed(1)

@tff.tf_computation
def return_one_noise(_):
  return tf.random.normal([])

n1=return_one_noise(1)
n2=return_one_noise(2) 
assert n1 != n2
print(n1, n2)
1.3283143 0.45740178

또한 명시적으로 시드를 설정하지 않고도 TFF에서 결정적 노이즈가 생성될 수 있습니다. 함수 return_two_noise 다음 코드는 반환에게 두 개의 동일한 노이즈 값을 니펫을. 이는 TFF가 실행 전에 미리 계산 그래프를 작성하기 때문에 예상되는 동작입니다. 그러나이 사용자가의 사용에 지불 관심을 가지고 제안 tf.random.normal TFF있다.

@tff.tf_computation
def tff_return_one_noise():
  return tf.random.normal([])

@tff.federated_computation
def return_two_noise():
  return (tff_return_one_noise(), tff_return_one_noise())

n1, n2=return_two_noise() 
assert n1 == n2
print(n1, n2)
-0.15665223 -0.15665223

주의 사용법 : tf.random.Generator

우리는 사용할 수 tf.random.Generator 에 제안 TF 튜토리얼 .

@tff.tf_computation
def tff_return_one_noise(i):
  g=tf.random.Generator.from_seed(i)
  @tf.function
  def tf_return_one_noise():
    return g.normal([])
  return tf_return_one_noise()

@tff.federated_computation
def return_two_noise():
  return (tff_return_one_noise(1), tff_return_one_noise(2))

n1, n2 = return_two_noise() 
assert n1 != n2
print(n1, n2)
0.3052047 -0.38260338

다만, 이용자는 사용상의 주의가 필요할 수 있습니다.

일반적으로, TFF는 기능 작업을 선호하고 우리의 사용 선보일 예정 tf.random.stateless_* 다음 섹션에서 기능을.

연합 학습을 위한 TFF에서 우리는 종종 스칼라 대신 중첩 구조로 작업하며 이전 코드 스니펫은 자연스럽게 중첩 구조로 확장될 수 있습니다.

@tff.tf_computation
def tff_return_one_noise(i):
  g=tf.random.Generator.from_seed(i)
  weights = [
         tf.ones([2, 2], dtype=tf.float32),
         tf.constant([2], dtype=tf.float32)
     ]
  @tf.function
  def tf_return_one_noise():
    return tf.nest.map_structure(lambda x: g.normal(tf.shape(x)), weights)
  return tf_return_one_noise()

@tff.federated_computation
def return_two_noise():
  return (tff_return_one_noise(1), tff_return_one_noise(2))

n1, n2 = return_two_noise() 
assert n1[1] != n2[1]
print('n1', n1)
print('n2', n2)
n1 [array([[0.3052047 , 0.5671378 ],
       [0.41852272, 0.2326421 ]], dtype=float32), array([1.1675092], dtype=float32)]
n2 [array([[-0.38260338, -0.47804865],
       [-0.5187485 , -1.8471988 ]], dtype=float32), array([-0.77835274], dtype=float32)]

TFF의 일반적인 권장 기능적 사용하는 tf.random.stateless_* 랜덤 잡음을 생성하는 함수. 이러한 기능은 가지고 seed (형상 텐서 [2] 또는 tuple 랜덤 잡음을 생성하는 명시적인 입력 인수 개의 스칼라 텐서의 참조). 먼저 시드를 의사 상태로 유지하기 위해 도우미 클래스를 정의합니다. 도우미 RandomSeedGenerator 상태 -의 상태 - 아웃 방식으로 기능 연산자를 가지고있다. 이는 대한 의사 상태로 카운터를 사용하는 것이 합리적이다 tf.random.stateless_* 이러한 기능으로 출격 통계적으로 상관 상관 씨앗에 의해 생성 된 소음을 위해 사용하기 전에 씨앗을.

def timestamp_seed():
  # tf.timestamp returns microseconds as decimal places, thus scaling by 1e6.
  return tf.math.cast(tf.timestamp() * 1e6, tf.int64)

class RandomSeedGenerator():

  def initialize(self, seed=None):
    if seed is None:
      return tf.stack([timestamp_seed(), 0])
    else:
      return tf.constant(self.seed, dtype=tf.int64, shape=(2,))

  def next(self, state):
    return state + tf.constant([0, 1], tf.int64)

  def structure_next(self, state, nest_structure):
    "Returns seed in nested structure and the next state seed."
    flat_structure = tf.nest.flatten(nest_structure)
    flat_seeds = [state + tf.constant([0, i], tf.int64) for
                  i in range(len(flat_structure))]
    nest_seeds = tf.nest.pack_sequence_as(nest_structure, flat_seeds)
    return nest_seeds, flat_seeds[-1] + tf.constant([0, 1], tf.int64)

이제 우리가 도우미 클래스와 사용할 수 있도록 tf.random.stateless_normal (의 중첩 된 구조) TFF에서 랜덤 노이즈를 생성합니다. 다음 코드는 TFF 반복적 인 과정과 같은 많은 참조 보이는 simple_fedavg TFF 반복적 인 과정으로 연합 학습 알고리즘을 표현의 예로서. 랜덤 노이즈 발생 여기를 의사의 씨앗 상태입니다 tf.Tensor 쉽게 TFF와 TF 기능에 수송 할 수있다.

@tff.tf_computation
def tff_return_one_noise(seed_state):
  g=RandomSeedGenerator()
  weights = [
         tf.ones([2, 2], dtype=tf.float32),
         tf.constant([2], dtype=tf.float32)
     ]
  @tf.function
  def tf_return_one_noise():
    nest_seeds, updated_state = g.structure_next(seed_state, weights)
    nest_noise = tf.nest.map_structure(lambda x,s: tf.random.stateless_normal(
        shape=tf.shape(x), seed=s), weights, nest_seeds)
    return nest_noise, updated_state
  return tf_return_one_noise()

@tff.tf_computation
def tff_init_state():
  g=RandomSeedGenerator()
  return g.initialize()

@tff.federated_computation
def return_two_noise():
  seed_state = tff_init_state()
  n1, seed_state = tff_return_one_noise(seed_state)
  n2, seed_state = tff_return_one_noise(seed_state)
  return (n1, n2)

n1, n2 = return_two_noise() 
assert n1[1] != n2[1]
print('n1', n1)
print('n2', n2)
n1 [array([[-0.21598858, -0.30700883],
       [ 0.7562299 , -0.21218438]], dtype=float32), array([-1.0359321], dtype=float32)]
n2 [array([[ 1.0722181 ,  0.81287116],
       [-0.7140338 ,  0.5896157 ]], dtype=float32), array([0.44190162], dtype=float32)]