Trabalho prático MO809 - Split Federated Learning com gRPC e Ray
Neste trabalho, vou explorar a implementação de um sistema de Split Federated Learning utilizando duas tecnologias principais: gRPC e Ray. A abordagem de Split Federated Learning permite distribuir o treinamento de modelos de aprendizado de máquina entre dispositivos, dividindo a carga de processamento e mantendo os dados sensíveis locais, contribuindo para a preservação da privacidade.
O uso do gRPC facilita a comunicação eficiente entre diferentes partes do sistema distribuído, proporcionando um canal rápido e seguro para a troca de informações entre o servidor central e os dispositivos remotos. Já o Ray será utilizado para a orquestração e escalabilidade do sistema, permitindo distribuir tarefas de forma paralela e otimizada.
Com esta implementação, espero demonstrar como o Split Federated Learning pode ser uma solução eficaz para cenários que exigem privacidade de dados e processamento distribuído, explorando os benefícios de desempenho e segurança proporcionados pelo uso de gRPC e Ray.
!pip install grpcio grpcio-tools ray
import grpc
import tensorflow as tf
import keras as K
from concurrent import futures
from time import sleep, time
import threading
import pandas as pd
import os
import splifed_pb2 as splifed
import splifed_pb2_grpc as splifed_grpc
from proto_messages import * # Certifique-se de que esse módulo está acessível
import time
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import flwr_datasets as fld
import grpc
from flwr.server.strategy.aggregate import aggregate
import ray
from proto_messages import *
import ray
# ---- Experiment configuration ----
N_CLIENTS = 3            # number of simulated federated clients
EPOCHS = 50              # training epochs per client
BATCH_SIZE = 128         # mini-batch size used by every client
USE_AGGREGATION = True   # if True, clients also synchronize weights via the aggregator

# Partition MNIST IID across the clients (one partition per client id).
partitioner = fld.partitioner.IidPartitioner(N_CLIENTS)
fd = fld.FederatedDataset(dataset='mnist', partitioners={'train': partitioner, 'test': partitioner})

# Peek at partition 0 to inspect shapes; pixel values scaled to [0, 1].
train_partition = fd.load_partition(0, 'train').with_format("numpy")
test_partition = fd.load_partition(0, 'test').with_format("numpy")
x_train, y_train, x_test, y_test = train_partition['image'] / 255.0, train_partition['label'], test_partition['image'] / 255.0, test_partition['label']
x_train.shape

# Visualize how labels are spread across the partitions (should be uniform for IID).
fld.visualization.plot_label_distributions(
    partitioner,
    label_name="label",
    title="Comparação de diferentes partições do dataset MNIST",
    legend=True,
    size_unit="percent",
    verbose_labels=False
)
Relembrando Split Federated Learning
Servidor de treinamento
import numpy as np
import tensorflow as tf
import splifed_pb2 as splifed
import time
# ------------------------------ CRIAÇÃO DE MENSAGENS ------------------------------
def create_shape_from_array(array):
    """Build a splifed.Shape message holding *array*'s dimensions."""
    msg = splifed.Shape()
    msg.shape.extend(array.shape)
    return msg
def create_model_msg(weights):
    """Serialize a list of weight tensors into a splifed.Model message.

    All layers are flattened into one repeated float field, in layer
    order; the per-layer shapes are stored alongside so the receiver can
    rebuild each tensor (see read_model).
    """
    msg = splifed.Model()
    for layer in (np.array(w) for w in weights):
        msg.weights.extend(layer.flatten().tolist())
        msg.s.append(create_shape_from_array(layer))
    return msg
def create_activations(activations_list):
    """Pack an activations tensor into a splifed.Activations message."""
    arr = np.array(activations_list)
    msg = splifed.Activations()
    msg.activations.extend(arr.flatten().tolist())
    msg.s.append(create_shape_from_array(arr))
    return msg
def create_forward_message(activations, labels, cid, rnd, batch):
    """Assemble the client->server forward-pass message.

    Carries the cut-layer activations, the batch labels, the client id,
    round and batch counters, plus a send timestamp used for uplink timing.
    """
    msg = splifed.ForwardMessage()
    msg.activations.CopyFrom(activations)
    msg.labels.extend(int(label) for label in labels)
    msg.cid = cid
    msg.rnd = rnd
    msg.batch = batch
    msg.time = time.time()
    return msg
def create_gradient(gradients):
    """Pack a gradient tensor into a splifed.Gradient message."""
    arr = np.array(gradients)
    msg = splifed.Gradient()
    msg.gradients.extend(arr.flatten().tolist())
    msg.s.append(create_shape_from_array(arr))
    return msg
def create_backward_message(gradient, loss):
    """Build the server->client reply: activation gradients, loss, timestamp."""
    msg = splifed.BackwardMessage()
    msg.gradients.CopyFrom(create_gradient(gradient))
    msg.loss = loss
    msg.time = time.time()
    return msg
def create_aggregate_message(model, size):
    """Wrap client weights and local dataset size for the aggregator."""
    msg = splifed.AggregateMessage()
    msg.model.CopyFrom(create_model_msg(model))
    msg.size = size
    return msg
def create_finish_training_message(cid, downlink_times):
    """Build the end-of-training notification for client *cid*."""
    msg = splifed.FinishTraining()
    msg.cid = cid
    msg.downlink_times.extend(downlink_times)
    return msg
# ------------------------------ LEITURA DE MENSAGENS ------------------------------
def read_aggregate_message(aggregate_msg):
    """Decode an AggregateMessage into weights, shapes and dataset size."""
    decoded = read_model(aggregate_msg.model)
    return {
        "weights": decoded["weights"],
        "shapes": decoded["shapes"],
        "size": aggregate_msg.size,
    }
def read_shape(shape_msg):
    """Return the dimensions stored in a Shape message as a plain list."""
    return [dim for dim in shape_msg.shape]
def read_model(model_msg):
    """Rebuild per-layer weight tensors from a flat Model message.

    The message stores all layers flattened into one float list plus the
    per-layer shapes; consecutive slices are cut and reshaped in order.
    """
    shapes = [read_shape(s) for s in model_msg.s]
    weights = []
    start = 0
    for shape in shapes:
        # int() guards against np.prod returning a float (np.prod(()) == 1.0
        # for a scalar weight), which would make the slice bounds invalid.
        size = int(np.prod(shape))
        weights.append(np.array(model_msg.weights[start:start + size]).reshape(shape))
        start += size
    return {
        "weights": weights,
        "shapes": shapes,
    }
def read_activations(activations_msg):
    """Decode an Activations message back into a numpy tensor."""
    shapes = [read_shape(s) for s in activations_msg.s]
    if shapes:
        data = np.array(activations_msg.activations).reshape(*shapes[0])
    else:
        data = activations_msg.activations
    return {
        "activations": data,
        "shapes": shapes,
    }
def read_forward_message(forward_msg):
    """Unpack a ForwardMessage into a plain dict for the server side."""
    decoded = read_activations(forward_msg.activations)
    sent_at = forward_msg.time  # renamed locally to avoid shadowing the time module
    return {
        "activations": decoded['activations'],
        "shapes": decoded['shapes'],
        "labels": np.array(list(forward_msg.labels)),
        "cid": forward_msg.cid,
        'rnd': forward_msg.rnd,
        "batch": forward_msg.batch,
        'time': sent_at,
    }
def read_gradient(gradient_msg):
    """Decode a Gradient message into a float32 TensorFlow tensor."""
    shapes = [read_shape(s) for s in gradient_msg.s]
    if shapes:
        raw = np.array(gradient_msg.gradients).reshape(*shapes[0])
    else:
        raw = gradient_msg.gradients
    return {
        "gradients": tf.convert_to_tensor(raw, dtype=tf.float32),
        "shapes": shapes,
    }
def read_backward_message(backward_msg):
    """Unpack the server's BackwardMessage into gradients, loss and timestamp."""
    decoded = read_gradient(backward_msg.gradients)
    return {
        "gradients": decoded['gradients'],
        "loss": backward_msg.loss,
        "time": backward_msg.time,
    }
def read_finish_training_message(ft_msg):
    """Unpack a FinishTraining message into a plain dict."""
    return {
        "cid": ft_msg.cid,
        "downlink_times": list(ft_msg.downlink_times),
    }
# Definição do modelo de rede neural
class ServerModel(tf.keras.models.Model):
    """Server-side half of the split model: a dense head ending in softmax.

    Consumes the 64-dim cut-layer activations produced by the client model
    and outputs class probabilities over the 10 MNIST digits.
    """

    def __init__(self):
        super().__init__()
        self.dense1 = K.layers.Dense(64, activation='relu')
        self.dense2 = K.layers.Dense(128, activation='relu')
        self.dense3 = K.layers.Dense(64, activation='relu')
        self.dense4 = K.layers.Dense(10, activation='softmax')

    def call(self, input):
        # Apply the dense stack in declaration order.
        x = input
        for layer in (self.dense1, self.dense2, self.dense3, self.dense4):
            x = layer(x)
        return x
# Funções de criação e treinamento do modelo
def create_server_model():
    """Factory returning a fresh, untrained ServerModel."""
    return ServerModel()
def s_forward(model, X, labels):
    """Run the server half forward; return (tape, mean batch loss).

    The tape is persistent and watches X so the caller can take gradients
    both w.r.t. the server weights and w.r.t. the incoming activations.
    """
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(X)
        predictions = model(X)
        per_sample = tf.keras.losses.sparse_categorical_crossentropy(labels, predictions)
        loss = tf.reduce_mean(per_sample)
    return tape, loss
def s_backward(model, optimizer, tape, loss):
    """Apply one optimizer step to the server weights using *tape*."""
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return model
# Serviço Split Federated Learning
class SplitFedService(splifed_grpc.SplifedServicer):
    """gRPC servicer implementing the server side of split learning.

    Receives cut-layer activations from clients (Forward), completes the
    forward pass, backpropagates, and returns the activation gradients.
    Per-batch metrics are collected and written to CSV once every client
    has called Finish.
    """

    def __init__(self, n_clients, agg):
        self.server_model = create_server_model()
        self.optimizer = tf.keras.optimizers.Adam()
        self.finished = [False for _ in range(n_clients)]  # one done-flag per client id
        self.data = []   # per-batch metric rows, see save_df_losses header
        self.agg = agg   # whether the run uses the weight aggregator
        self.downlink_times = pd.DataFrame(columns=['cid', 'downlink'])

    def save_df_losses(self):
        """Append the collected metrics (joined with downlink times) to CSV."""
        filename = "losses.csv" if not self.agg else "losses_aggregated.csv"
        header = ['rnd', 'batch', 'cid', 'loss', 'ff_package_size', 'bk_package_size', 'uplink']
        data_df = pd.DataFrame(self.data, columns=header)
        new_entries = pd.merge(data_df, self.downlink_times, on='cid')
        if os.path.exists(filename):
            df_existing = pd.read_csv(filename)
            df_combined = pd.concat([df_existing, new_entries], ignore_index=True)
            df_combined.to_csv(filename, index=False)
        else:
            new_entries.to_csv(filename, index=False)

    def Forward(self, request, context):
        """Handle one client batch: forward, backward, return gradients."""
        ff_package_size = len(request.SerializeToString())
        response = read_forward_message(request)
        labels = response['labels']
        rnd = response['rnd']
        batch = response['batch']
        # abs() guards against small clock skew between processes.
        uplink = abs(time.time() - response['time'])
        act = tf.convert_to_tensor(response["activations"], dtype=tf.float32)
        print("FORWARD")
        tape, loss = s_forward(self.server_model, act, labels)
        print("BACKWARD")
        s_backward(self.server_model, self.optimizer, tape, loss)
        # Gradient w.r.t. the received activations -> shipped back to the client.
        activation_gradients = tape.gradient(loss, act)
        backward_msg = create_backward_message(
            activation_gradients,
            loss.numpy()
        )
        bk_package_size = len(backward_msg.SerializeToString())
        self.data.append([rnd, batch, response['cid'], loss.numpy(), ff_package_size, bk_package_size, uplink])
        return backward_msg

    def create_downlink_df(self, downlink_times, cid):
        """Build a DataFrame of per-batch downlink times for one client."""
        return pd.DataFrame(
            {
                'cid': [cid] * len(downlink_times),
                'downlink': downlink_times
            }
        )

    def Finish(self, request, context):
        """Record a client's downlink times and mark it finished.

        BUG FIX: the original else-branch assigned the *empty* incoming
        frame to self.downlink_times, discarding every previously
        accumulated measurement whenever one client reported no times.
        An empty report now simply leaves the accumulator untouched.
        """
        f_msg = read_finish_training_message(request)
        downlink_df = self.create_downlink_df(f_msg['downlink_times'], f_msg['cid'])
        self.finished[f_msg['cid']] = True
        if not downlink_df.empty:
            self.downlink_times = pd.concat([self.downlink_times, downlink_df], ignore_index=True)
        # Once every client reported Finish, persist the metrics.
        if all(self.finished):
            self.save_df_losses()
        return splifed.Empty()
# Função para iniciar o servidor em uma thread separada
def serve(n_clients, agg=False):
    """Start the split-learning gRPC server and block until training ends."""
    max_len = 20 * 1024 * 1024 * 20  # generous cap for large tensor payloads
    options = [
        ('grpc.max_send_message_length', max_len),
        ('grpc.max_receive_message_length', max_len),
    ]
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=n_clients), options=options)
    service = SplitFedService(n_clients, agg)
    splifed_grpc.add_SplifedServicer_to_server(service, server)
    server.add_insecure_port('[::]:50051')
    print("Server started on port 50051")
    server.start()
    # Poll until every client has reported Finish.
    while not all(service.finished):
        sleep(1)
    print("Stopping server...")
    sleep(2)
    server.stop(0)
# Run the training server on a background thread so the main flow continues.
server_thread = threading.Thread(target=serve, args=(N_CLIENTS, USE_AGGREGATION))
server_thread.start()
Servidor de agregação
# Definição do serviço agregador
class AggregatorService(splifed_grpc.AggregatorServicer):
    """gRPC servicer performing weight aggregation across clients.

    Each client submits its weights via Aggregate(); the call blocks until
    all n_clients models of the current round have arrived, after which
    every caller receives the same aggregated model.
    """

    def __init__(self, n_clients):
        self.models_received = 0
        self.n_clients = n_clients
        self.finished = [False for _ in range(n_clients)]
        self.models = []              # (weights, num_examples) pairs of the current round
        self.aggregated_model = None  # result of the most recently completed round
        self.rounds_done = 0          # completed-round counter; wait predicate below
        self.lock = threading.Lock()
        self.clients_condition = threading.Condition(self.lock)

    def Aggregate(self, request, context):
        """Collect one client model; block until the round is aggregated.

        BUG FIX: the original reset `self.aggregated_model = None` at the
        top of every call *outside* the lock, so a fast client starting the
        next round could clear the result while slower clients were still
        serializing it; the bare `wait()` was also vulnerable to spurious
        wakeups. The result is now snapshotted into a local while holding
        the lock, and the wait loops on a round-counter predicate.
        """
        print("REQUEST RECEBIDO")
        model_msg = read_model(request.model)
        size = request.size
        with self.clients_condition:
            print(f"Received model from client. Current count: {self.models_received + 1}/{self.n_clients}")
            self.models.append((model_msg["weights"], size))
            self.models_received += 1
            my_round = self.rounds_done
            if self.models_received == self.n_clients:
                print("All models received. Aggregating...")
                self.aggregated_model = aggregate(self.models)
                self.models = []  # Reset for the next round
                self.models_received = 0
                self.rounds_done += 1
                print("Aggregation complete. Returning aggregated model.")
                self.clients_condition.notify_all()
            else:
                print(f"Waiting for more models. Current count: {self.models_received}/{self.n_clients}")
                # Loop guards against spurious wakeups.
                while self.rounds_done == my_round:
                    self.clients_condition.wait()
            result = self.aggregated_model  # snapshot while still holding the lock
        return create_model_msg(result)

    def Finish(self, request, context):
        """Mark client `request.cid` as done with training."""
        self.finished[request.cid] = True
        return splifed.Empty()
# Função para iniciar o servidor em uma thread separada
def start_aggregator_server(port, n_clients):
    """Start the aggregation gRPC server and block until all clients finish."""
    max_len = 20 * 1024 * 1024 * 20  # generous cap for large model payloads
    options = [
        ('grpc.max_send_message_length', max_len),
        ('grpc.max_receive_message_length', max_len),
    ]
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=n_clients), options=options)
    agg_server = AggregatorService(n_clients)
    splifed_grpc.add_AggregatorServicer_to_server(agg_server, server)
    server.add_insecure_port(f'[::]:{port}')
    print(f"Aggregator server started on port {port}")
    server.start()
    # Poll until every client has reported Finish.
    while not all(agg_server.finished):
        sleep(1)
    print("Stopping server...")
    sleep(2)
    server.stop(0)
# Run the aggregator server on a parallel thread alongside the training server.
aggregator_thread = threading.Thread(target=start_aggregator_server, args=(50052, N_CLIENTS))
aggregator_thread.start()
Cliente
# ------------------- Funções para gRPC -------------------
def connect_to_server():
    """Open an insecure channel to the training server; return its stub."""
    max_len = 20 * 1024 * 1024 * 20
    channel = grpc.insecure_channel(
        'localhost:50051',
        options=[
            ('grpc.max_send_message_length', max_len),
            ('grpc.max_receive_message_length', max_len),
        ],
    )
    return splifed_grpc.SplifedStub(channel)
def connect_to_aggregator():
    """Open an insecure channel to the aggregation server; return its stub."""
    max_len = 20 * 1024 * 1024 * 20
    channel = grpc.insecure_channel(
        'localhost:50052',
        options=[
            ('grpc.max_send_message_length', max_len),
            ('grpc.max_receive_message_length', max_len),
        ],
    )
    return splifed_grpc.AggregatorStub(channel)
# ------------------- Funções de Treinamento -------------------
def create_model(input_shape):
    """Build the client half of the split model plus its optimizer.

    A small conv trunk whose output (a 64-dim activation vector) is the
    cut layer shipped to the server for the remaining dense head.
    """
    inputs = K.layers.Input(shape=(input_shape[0], input_shape[1], 1))
    x = K.layers.Conv2D(32, (3, 3), activation='relu')(inputs)
    x = K.layers.MaxPooling2D((2, 2))(x)
    x = K.layers.Flatten()(x)
    outputs = K.layers.Dense(64, activation='relu')(x)
    return K.Model(inputs=inputs, outputs=outputs), K.optimizers.Adam()
def load_dataset(dataset_name, n_partitions):
    """Return a FederatedDataset partitioned IID into n_partitions."""
    part = fld.partitioner.IidPartitioner(n_partitions)
    return fld.FederatedDataset(
        dataset=dataset_name,
        partitioners={'train': part, 'test': part},
    )
def load_partitions(fd, cid):
    """Load client *cid*'s training split with images scaled to [0, 1]."""
    split = fd.load_partition(cid, 'train').with_format("numpy")
    return split['image'] / 255.0, split['label']
def forward(model, X):
    """Client forward pass; return (model, activations, recording tape)."""
    with tf.GradientTape() as tape:
        activations = model(X)
    return model, activations, tape
def backward(model, optimizer, tape, activations, server_grad):
    """Backprop the server's activation gradients into the client weights."""
    grads = tape.gradient(
        activations, model.trainable_variables, output_gradients=server_grad
    )
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return model
@ray.remote
def train_model(X_train, y_train, cid, aggragation, epochs):
    """Ray task running one client's full split-training loop.

    Per batch: local forward to the cut layer, ship activations to the
    training server, receive activation gradients, local backward. After
    each epoch, optionally synchronize weights through the aggregator.
    NOTE(review): parameter name `aggragation` is a typo for `aggregation`
    — kept unchanged for call-site compatibility.
    """
    stub = connect_to_server()
    downlink_times = []  # one round-trip downlink measurement per batch
    client_model, optimizer = create_model((28, 28))
    if aggragation:
        agg_stub = connect_to_aggregator()
    for epoch in range(epochs):
        print(f"EPOCH {epoch}")
        for i in range(0, len(X_train), BATCH_SIZE):
            X_batch = X_train[i:i + BATCH_SIZE]
            y_batch = y_train[i:i + BATCH_SIZE]
            # ---- Forward ----
            client_model, activations, tape = forward(client_model, X_batch)
            # ---- Send forward message ----
            msg_act = create_activations(activations)
            # NOTE(review): the global batch index offsets by epoch * BATCH_SIZE,
            # so indices repeat across epochs when batches-per-epoch exceeds
            # BATCH_SIZE — confirm the intended numbering for the plots.
            forward_msg = create_forward_message(msg_act, y_batch, cid, epoch, batch=i // BATCH_SIZE + epoch * BATCH_SIZE)
            response = stub.Forward(forward_msg)
            # ---- Backward ----
            server_grad = read_backward_message(response)
            # abs() guards against small clock skew between processes.
            downlink = abs(time.time() - server_grad['time'])
            downlink_times.append(downlink)
            client_model = backward(client_model, optimizer, tape, activations, server_grad['gradients'])
        # ---- Aggregation Weights ----
        print("Send model to aggregator")
        weights = client_model.get_weights()
        if aggragation:
            w_msg = create_aggregate_message(weights, X_train.shape[0])
            updated_w_msg = agg_stub.Aggregate(w_msg)  # blocks until all clients arrive
            upd_w = read_model(updated_w_msg)
            client_model.set_weights(upd_w['weights'])
    # Report downlink timings and signal completion to both servers.
    ft_msg = create_finish_training_message(cid, downlink_times)
    stub.Finish(ft_msg)
    if aggragation:
        agg_stub.Finish(ft_msg)
# Launch one Ray training task per client and block until all finish.
fd = load_dataset('mnist', N_CLIENTS)
ray_results = ray.get([
    train_model.remote(*load_partitions(fd, cid), cid, USE_AGGREGATION, EPOCHS)
    for cid in range(N_CLIENTS)
])
Análise dos dados
# ---- Analysis: compare the runs with and without weight aggregation ----
agg_losses = pd.read_csv('losses_aggregated.csv')
losses = pd.read_csv('losses.csv')
agg_losses['approach'] = 'With Aggregation'
losses['approach'] = 'Without Aggregation'
df = pd.concat([agg_losses, losses])

# Loss curve per approach over batches.
sns.lineplot(data=df, x='batch', y='loss', hue='approach')
plt.xlabel('Round')
plt.ylabel('Loss')
plt.title('Loss over Rounds')
plt.show()

# Mean uplink/downlink latency per approach.
t = df.groupby(['approach']).agg({'uplink': 'mean', 'downlink': 'mean'}).reset_index().set_index('approach')
t.plot(kind='bar')

# Throughput in kbit/s: total payload bits over the measured round-trip time.
df['throughput'] = (df['ff_package_size'] + df['bk_package_size']) * 8 / 1000 / (df['uplink'] + df['downlink'])
sns.lineplot(data=df, x='batch', y='throughput', hue='approach')
plt.xlabel('Round')
plt.ylabel('Network Throughput (Kbts/s)')
plt.title('Network Throughput over Rounds')
plt.show()
