Ver no TensorFlow.org | Executar no Google Colab | Ver fonte no GitHub | Baixar caderno |
A noção de um conjunto de dados codificado por clientes (por exemplo, usuários) é essencial para a computação federada, conforme modelado no TFF. TFF fornece a interface tff.simulation.datasets.ClientData
para abstrato sobre este conceito, e os conjuntos de dados que abriga TFF ( stackoverflow , shakespeare , emnist , cifar100 e gldv2 ) tudo implementar essa interface.
Se você estiver trabalhando na aprendizagem federado com seu próprio conjunto de dados, TFF encoraja fortemente que você implementar tanto o ClientData
de interface ou usar uma das funções auxiliares de TFF para gerar um ClientData
que representa seus dados no disco, por exemplo, tff.simulation.datasets.ClientData.from_clients_and_fn
.
Como a maioria de exemplos end-to-end da TFF começar com ClientData
objetos, implementando o ClientData
de interface com o seu conjunto de dados personalizado irá torná-lo mais fácil de spelunk através de código existente escrito com TFF. Além disso, os tf.data.Datasets
que ClientData
construções podem ser iterado para se obter directamente as estruturas de numpy
matrizes, de modo que ClientData
objectos podem ser utilizados com qualquer quadro ML baseado em Python antes de se mudar para TFF.
Existem vários padrões com os quais você pode tornar sua vida mais fácil se você pretende dimensionar suas simulações para muitas máquinas ou implantá-las. Abaixo vamos percorrer algumas das maneiras que podemos usar ClientData
e TFF para tornar a nossa pequena escala iteração-to-larga escala experimentação-to produção experiência de implantação o mais suave possível.
Qual padrão devo usar para passar ClientData para TFF?
Vamos discutir dois usos da TFF ClientData
em profundidade; se você se encaixa em uma das duas categorias abaixo, você claramente preferirá uma em vez da outra. Caso contrário, você pode precisar de uma compreensão mais detalhada dos prós e contras de cada um para fazer uma escolha com mais nuances.
Quero iterar o mais rápido possível em uma máquina local; Não preciso tirar proveito facilmente do tempo de execução distribuído da TFF.
- Você quer passar
tf.data.Datasets
para TFF diretamente. - Isso permite que você programar imperativamente com
tf.data.Dataset
objetos, e processá-los arbitrariamente. - Ele oferece mais flexibilidade do que a opção abaixo; enviar a lógica aos clientes requer que essa lógica seja serializável.
- Você quer passar
Desejo executar minha computação federada no tempo de execução remoto da TFF ou pretendo fazê-lo em breve.
- Nesse caso, você deseja mapear a construção e o pré-processamento do conjunto de dados para os clientes.
- Isso resulta em você passando simplesmente uma lista de
client_ids
diretamente para o seu cálculo federado. - Empurrar a construção e o pré-processamento do conjunto de dados para os clientes evita gargalos na serialização e aumenta significativamente o desempenho com centenas de milhares de clientes.
Configurar ambiente de código aberto
# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly
!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio
import nest_asyncio
nest_asyncio.apply()
Pacotes de importação
import collections
import time
import tensorflow as tf
import tensorflow_federated as tff
Manipulando um objeto ClientData
Vamos começar por carga e explorar da TFF EMNIST ClientData
:
client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s] 2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Inspecionando o primeiro conjunto de dados pode nos dizer que tipo de exemplos estão no ClientData
.
first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
Note-se que os rendimentos do conjunto de dados collections.OrderedDict
objetos que têm pixels
e label
chaves, onde pixels é um tensor com forma [28, 28]
. Suponha que queremos achatar nossos entradas para forma [784]
. Uma maneira possível nós podemos fazer isso seria aplicar uma função pré-processamento para o nosso ClientData
objeto.
def preprocess_dataset(dataset):
"""Create batches of 5 examples, and limit to 3 batches."""
def map_fn(input):
return collections.OrderedDict(
x=tf.reshape(input['pixels'], shape=(-1, 784)),
y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
)
return dataset.batch(5).map(
map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
preprocessed_client_data = client_data.preprocess(preprocess_dataset)
# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
Além disso, podemos desejar realizar um pré-processamento mais complexo (e possivelmente com estado), por exemplo, embaralhamento.
def preprocess_and_shuffle(dataset):
"""Applies `preprocess_dataset` above and shuffles the result."""
preprocessed = preprocess_dataset(dataset)
return preprocessed.shuffle(buffer_size=5)
preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)
# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
Interface com um tff.Computation
Agora que podemos executar algumas manipulações básicas com ClientData
objetos, estamos prontos para dados de alimentação para um tff.Computation
. Nós definimos um tff.templates.IterativeProcess
que implementa Federated Média , e explorar diferentes métodos de passá-lo dados.
def model_fn():
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer='zeros'),
])
return tff.learning.from_keras_model(
model,
# Note: input spec is the _batched_ shape, and includes the
# label tensor which will be passed to the loss function. This model is
# therefore configured to accept data _after_ it has been preprocessed.
input_spec=collections.OrderedDict(
x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
trainer = tff.learning.build_federated_averaging_process(
model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))
Antes de começar a trabalhar com este IterativeProcess
, um comentário sobre a semântica de ClientData
está em ordem. A ClientData
objeto representa a totalidade da população disponível para o treinamento federado, que em geral é não disponível para o ambiente de execução de um sistema de FL produção e é específico para simulação. ClientData
fato dá ao usuário a capacidade de desvio de computação federada totalmente e simplesmente treinar um modelo do lado do servidor, como de costume via ClientData.create_tf_dataset_from_all_clients
.
O ambiente de simulação da TFF coloca o pesquisador em controle total do loop externo. Em particular, isso implica que as considerações sobre a disponibilidade do cliente, o abandono do cliente, etc., devem ser tratadas pelo usuário ou pelo script do driver Python. Poderíamos, por exemplo, modelo de desistência dos clientes, ajustando a distribuição de amostragem sobre seus ClientData's
client_ids
tais que os usuários com mais dados (e, correspondentemente, a longo executar cálculos locais) seriam selecionados com menor probabilidade.
Em um sistema federado real, entretanto, os clientes não podem ser selecionados explicitamente pelo treinador do modelo; a seleção de clientes é delegada ao sistema que está executando a computação federada.
Passando tf.data.Datasets
diretamente para TFF
Uma opção que temos para fazer a interface entre um ClientData
e um IterativeProcess
é a de construir tf.data.Datasets
em Python, e passando esses conjuntos de dados para TFF.
Observe que, se usarmos os nossos pré-processadas ClientData
os conjuntos de dados que produzem são do tipo apropriado esperado pelo nosso modelo definido acima.
selected_client_ids = preprocessed_and_shuffled.client_ids[:10]
preprocessed_data_for_clients = [
preprocessed_and_shuffled.create_tf_dataset_for_client(
selected_client_ids[i]) for i in range(10)
]
state = trainer.initialize()
for _ in range(5):
t1 = time.time()
state, metrics = trainer.next(state, preprocessed_data_for_clients)
t2 = time.time()
print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.graph_util.extract_sub_graph` WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.graph_util.extract_sub_graph` loss 2.9005744457244873, round time 4.576513767242432 loss 3.113278388977051, round time 0.49641919136047363 loss 2.7581865787506104, round time 0.4904160499572754 loss 2.87259578704834, round time 0.48976993560791016 loss 3.1202380657196045, round time 0.6724586486816406
Se tomarmos este caminho, no entanto, não será capaz de mover-se trivialmente a simulação multimáquina. Os conjuntos de dados que construímos no tempo de execução TensorFlow local pode capturar o estado do ambiente python circundante, e falhar em serialização ou desserialização quando tentam estado de referência que não está mais disponível para eles é. Isto pode se manifestar por exemplo, o erro inescrutáveis de TensorFlow tensor_util.cc
:
Check failed: DT_VARIANT == input.dtype() (21 vs. 20)
Mapeamento de construção e pré-processamento sobre os clientes
Para evitar esse problema, TFF recomenda aos seus utilizadores a considerar dataset instanciação e pré-processamento como algo que acontece localmente em cada cliente, e usar ajudantes de TFF ou federated_map
para executar explicitamente este código pré-processamento em cada cliente.
Conceitualmente, a razão para preferir isso é clara: no tempo de execução local da TFF, os clientes apenas "acidentalmente" têm acesso ao ambiente Python global devido ao fato de que toda a orquestração federada está acontecendo em uma única máquina. Vale a pena observar neste ponto que pensamento semelhante dá origem à filosofia funcional de plataforma cruzada, sempre serializável e da TFF.
TFF faz essa simples mudança através ClientData's
atributo dataset_computation
, um tff.Computation
que leva um client_id
e retorna o associado tf.data.Dataset
.
Note-se que preprocess
simplesmente trabalha com dataset_computation
; o dataset_computation
atributo do pré-processados ClientData
incorpora toda a pipeline de pré-processamento que acabamos de definir:
print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing: (string -> <label=int32,pixels=float32[28,28]>*) dataset computation with preprocessing: (string -> <x=float32[?,784],y=int64[?,1]>*)
Poderíamos invocar dataset_computation
e receber um conjunto de dados ansioso no tempo de execução Python, mas o verdadeiro poder desta abordagem é exercido quando compor com um processo iterativo ou de outra computação para evitar a materialização destes conjuntos de dados no tempo de execução ansioso mundial em tudo. TFF fornece uma função auxiliar tff.simulation.compose_dataset_computation_with_iterative_process
que pode ser usado para fazer exatamente isso.
trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
preprocessed_and_shuffled.dataset_computation, trainer)
Tanto este tff.templates.IterativeProcesses
e aquele acima executado da mesma maneira; mas primeiro aceita conjuntos de dados de clientes pré-processadas, eo último aceita strings representando IDs de cliente, manipulação tanto a construção do conjunto de dados e pré-processamento em seu corpo - na verdade state
podem ser passados entre os dois.
for _ in range(5):
t1 = time.time()
state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
t2 = time.time()
print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023 loss 2.7670371532440186, round time 0.5207102298736572 loss 2.665048122406006, round time 0.5302855968475342 loss 2.7213189601898193, round time 0.5313887596130371 loss 2.580148935317993, round time 0.5283482074737549
Escalonando para um grande número de clientes
trainer_accepting_ids
pode imediatamente ser usado em tempo de execução multimáquina da TFF e evita materializando tf.data.Datasets
e o controlador (e, portanto, a serialização-los e enviá-los para os trabalhadores).
Isso acelera significativamente as simulações distribuídas, especialmente com um grande número de clientes, e permite a agregação intermediária para evitar sobrecarga de serialização / desserialização semelhante.
Deepdive opcional: compondo manualmente a lógica de pré-processamento no TFF
TFF é projetado para composicionalidade desde o início; o tipo de composição executado pelo ajudante da TFF está totalmente sob nosso controle como usuários. Poderíamos ter manualmente compor o cálculo de pré-processamento que acabamos de definir com o treinador próprio next
simplesmente:
selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)
@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
return trainer.next(server_state, preprocessed_data)
manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)
Na verdade, isso é efetivamente o que o auxiliar que usamos está fazendo nos bastidores (além de executar a verificação de tipo e manipulação apropriadas). Poderíamos até ter expressado a mesma lógica ligeiramente diferente, por serialização preprocess_and_shuffle
em um tff.Computation
, e decompondo o federated_map
em um passo que constrói conjuntos de dados pré-processados-un e outro que corre preprocess_and_shuffle
em cada cliente.
Podemos verificar que este caminho mais manual resulta em cálculos com a mesma assinatura de tipo que o auxiliar de TFF (nomes de parâmetros de módulo):
print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>) (<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)