Trabalho prático MO809 - Split Federated Learning com gRPC e Ray

Neste trabalho, vou explorar a implementação de um sistema de Split Federated Learning utilizando duas tecnologias principais: gRPC e Ray. A abordagem de Split Federated Learning permite distribuir o treinamento de modelos de aprendizado de máquina entre dispositivos, dividindo a carga de processamento e mantendo os dados sensíveis locais, contribuindo para a preservação da privacidade.

O uso do gRPC facilita a comunicação eficiente entre diferentes partes do sistema distribuído, proporcionando um canal rápido e seguro para a troca de informações entre o servidor central e os dispositivos remotos. Já o Ray será utilizado para a orquestração e escalabilidade do sistema, permitindo distribuir tarefas de forma paralela e otimizada.

Com esta implementação, espero demonstrar como o Split Federated Learning pode ser uma solução eficaz para cenários que exigem privacidade de dados e processamento distribuído, explorando os benefícios de desempenho e segurança proporcionados pelo uso de gRPC e Ray.

!pip install grpcio grpcio-tools ray
import grpc
import tensorflow as tf
import keras as K
from concurrent import futures
from time import sleep, time
import threading
import pandas as pd
import os
import splifed_pb2 as splifed
import splifed_pb2_grpc as splifed_grpc
from proto_messages import *  # Certifique-se de que esse módulo está acessível
import time
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import flwr_datasets as fld
import grpc
from flwr.server.strategy.aggregate import aggregate
import ray
from proto_messages import *
import ray
# Experiment hyper-parameters.
N_CLIENTS = 3
EPOCHS = 50
BATCH_SIZE = 128
USE_AGGREGATION = True
# Partition MNIST IID across the simulated clients (flwr-datasets).
partitioner = fld.partitioner.IidPartitioner(N_CLIENTS)
fd = fld.FederatedDataset(dataset='mnist', partitioners={'train':partitioner, 'test':partitioner})
# Load client 0's shard only to inspect/visualize the data below.
train_partition = fd.load_partition(0, 'train').with_format("numpy")
test_partition = fd.load_partition(0, 'test').with_format("numpy")
# Scale pixel values from [0, 255] to [0, 1].
x_train, y_train, x_test, y_test = train_partition['image']/255.0, train_partition['label'], test_partition['image']/255.0, test_partition['label']# X_train, y_train = test_partition["img"], test_partition["label"]
# Bare expression: notebook cell output showing the shard's shape.
x_train.shape
# Visual sanity check that the IID partitioning is balanced.
fld.visualization.plot_label_distributions(
    partitioner,
    label_name="label",
    title="Comparação de diferentes partições do dataset MNIST",
    legend=True,
    size_unit="percent",
    verbose_labels=False
)

Relembrando Split Federated Learning

splitfed_proto.png

Servidor de treinamento

import numpy as np
import tensorflow as tf
import splifed_pb2 as splifed
import time
# ------------------------------ MESSAGE CREATION ------------------------------
def create_shape_from_array(array):
    """Wrap a numpy array's dimensions in a Shape protobuf message."""
    msg = splifed.Shape()
    for dim in array.shape:
        msg.shape.append(dim)
    return msg

def create_model_msg(weights):
    """Serialize a list of layer weight arrays into a Model protobuf.

    All values are flattened into one repeated field; each layer's original
    shape is stored alongside so the receiver can reconstruct the arrays.
    """
    msg = splifed.Model()
    arrays = [np.array(w) for w in weights]
    for arr in arrays:
        msg.weights.extend(arr.flatten().tolist())
    for arr in arrays:
        msg.s.append(create_shape_from_array(arr))
    return msg

def create_activations(activations_list):
    """Pack a batch of cut-layer activations into an Activations protobuf,
    flattened with the original shape recorded for reconstruction."""
    msg = splifed.Activations()
    arr = np.array(activations_list)
    msg.activations.extend(arr.flatten().tolist())
    msg.s.append(create_shape_from_array(arr))
    return msg

def create_forward_message(activations, labels, cid, rnd, batch):
    """Build the client-to-server ForwardMessage.

    Carries the (already serialized) activations, the batch labels and the
    bookkeeping ids; timestamped at send so the server can measure uplink.
    """
    msg = splifed.ForwardMessage()
    msg.activations.CopyFrom(activations)
    msg.labels.extend(int(label) for label in labels)
    msg.cid = cid
    msg.rnd = rnd
    msg.batch = batch
    msg.time = time.time()
    return msg

def create_gradient(gradients):
    """Pack an activation-gradient tensor into a Gradient protobuf,
    flattened with the original shape recorded for reconstruction."""
    msg = splifed.Gradient()
    arr = np.array(gradients)
    msg.gradients.extend(arr.flatten().tolist())
    msg.s.append(create_shape_from_array(arr))
    return msg

def create_backward_message(gradient, loss):
    """Build the server-to-client BackwardMessage with the gradient w.r.t.
    the cut-layer activations and the batch loss; timestamped at send so
    the client can measure downlink."""
    msg = splifed.BackwardMessage()
    msg.gradients.CopyFrom(create_gradient(gradient))
    msg.loss = loss
    msg.time = time.time()
    return msg

def create_aggregate_message(model, size):
    """Wrap a client model plus its local dataset size (used as the
    FedAvg weighting factor) into an AggregateMessage."""
    msg = splifed.AggregateMessage()
    msg.model.CopyFrom(create_model_msg(model))
    msg.size = size
    return msg

def create_finish_training_message(cid, downlink_times):
    """Signal end of training for client `cid`, shipping its measured
    downlink latencies for server-side logging."""
    msg = splifed.FinishTraining()
    msg.cid = cid
    msg.downlink_times.extend(downlink_times)
    return msg

# ------------------------------ MESSAGE PARSING ------------------------------
def read_aggregate_message(aggregate_msg):
    """Decode an AggregateMessage into its weights, shapes and dataset size."""
    decoded = read_model(aggregate_msg.model)
    return {
        "weights": decoded["weights"],
        "shapes": decoded["shapes"],
        "size": aggregate_msg.size,
    }

def read_shape(shape_msg):
    """Convert a Shape protobuf message into a plain list of dimensions."""
    return [dim for dim in shape_msg.shape]

def read_model(model_msg):
    """Reconstruct per-layer weight arrays from a flat Model protobuf.

    The message stores all layer weights concatenated into one flat repeated
    field plus the original shape of each layer; this slices the flat vector
    back into correctly shaped numpy arrays.

    Returns a dict with "weights" (list of np.ndarray) and "shapes"
    (list of list-of-int).
    """
    shapes = [list(s.shape) for s in model_msg.s]
    weights = []
    offset = 0
    for shape in shapes:
        # int() guards against np.prod returning a float (np.prod([]) == 1.0
        # for a scalar's empty shape), which would break slice indexing.
        count = int(np.prod(shape))
        flat_chunk = np.array(model_msg.weights[offset:offset + count])
        weights.append(flat_chunk.reshape(shape))
        offset += count
    return {
        "weights": weights,
        "shapes": shapes
    }

def read_activations(activations_msg):
    """Decode an Activations protobuf back into a shaped numpy array.

    Falls back to the raw flat sequence when no shape was recorded.
    """
    dims = [list(s.shape) for s in activations_msg.s]
    if dims:
        acts = np.array(activations_msg.activations).reshape(*dims[0])
    else:
        acts = activations_msg.activations
    return {
        "activations": acts,
        "shapes": dims,
    }

def read_forward_message(forward_msg):
    """Decode a ForwardMessage into a flat dict for the training server."""
    decoded = read_activations(forward_msg.activations)
    return {
        "activations": decoded['activations'],
        "shapes": decoded['shapes'],
        "labels": np.array(list(forward_msg.labels)),
        "cid": forward_msg.cid,
        'rnd': forward_msg.rnd,
        "batch": forward_msg.batch,
        'time': forward_msg.time,
    }

def read_gradient(gradient_msg):
    """Decode a Gradient protobuf into a float32 TensorFlow tensor,
    reshaped when a shape was recorded and left flat otherwise."""
    dims = [list(s.shape) for s in gradient_msg.s]
    if dims:
        raw = np.array(gradient_msg.gradients).reshape(*dims[0])
    else:
        raw = gradient_msg.gradients
    return {
        "gradients": tf.convert_to_tensor(raw, dtype=tf.float32),
        "shapes": dims,
    }

def read_backward_message(backward_msg):
    """Decode a BackwardMessage into gradients, loss and send timestamp."""
    decoded = read_gradient(backward_msg.gradients)
    return {
        "gradients": decoded['gradients'],
        "loss": backward_msg.loss,
        "time": backward_msg.time,
    }

def read_finish_training_message(ft_msg):
    """Decode a FinishTraining message into cid plus its downlink samples."""
    return {
        "cid": ft_msg.cid,
        "downlink_times": [t for t in ft_msg.downlink_times],
    }


# Server-side portion of the split network.
class ServerModel(tf.keras.models.Model):
    """Dense head (64-128-64-10 softmax) that finishes the split model on
    the server, consuming the clients' cut-layer activations."""

    def __init__(self):
        super(ServerModel, self).__init__()
        self.dense1 = K.layers.Dense(64, activation='relu')
        self.dense2 = K.layers.Dense(128, activation='relu')
        self.dense3 = K.layers.Dense(64, activation='relu')
        self.dense4 = K.layers.Dense(10, activation='softmax')

    def call(self, input):
        # Feed the incoming activations through the dense stack in order.
        out = input
        for layer in (self.dense1, self.dense2, self.dense3, self.dense4):
            out = layer(out)
        return out

# Model creation / training helpers.
def create_server_model():
    """Factory for a fresh (untrained) server-side model."""
    return ServerModel()

def s_forward(model, X, labels):
    """Run the server half forward and compute the mean sparse-CE loss.

    Returns the persistent tape — needed afterwards for gradients w.r.t.
    both the server weights and the incoming activations — plus the loss.
    """
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(X)  # X is a plain tensor, not a variable; watch explicitly
        predictions = model(X)
        per_sample = tf.keras.losses.sparse_categorical_crossentropy(labels, predictions)
        mean_loss = tf.reduce_mean(per_sample)
    return tape, mean_loss

def s_backward(model, optimizer, tape, loss):
    """Apply one optimizer step to the server model from the taped loss."""
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return model
# Split Federated Learning service: runs the server half of the model.
class SplitFedService(splifed_grpc.SplifedServicer):
    """gRPC servicer implementing the server side of split learning.

    Clients send cut-layer activations (Forward); the server finishes the
    forward pass, takes an optimizer step on its own weights and returns the
    gradient w.r.t. the activations so clients can finish backprop locally.
    Per-batch metrics are accumulated in memory and flushed to CSV once
    every client has called Finish.
    """

    def __init__(self, n_clients, agg):
        self.server_model = create_server_model()
        self.optimizer = tf.keras.optimizers.Adam()
        # One flag per client; the serve() loop polls this to know when to stop.
        self.finished = [False for _ in range(n_clients)]
        # Rows of [rnd, batch, cid, loss, ff_package_size, bk_package_size, uplink].
        self.data = []
        self.agg = agg
        self.downlink_times = pd.DataFrame(columns=['cid', 'downlink'])

    def save_df_losses(self):
        """Join per-batch metrics with downlink samples and append to CSV."""
        filename = "losses.csv" if not self.agg else "losses_aggregated.csv"
        header = ['rnd', 'batch', 'cid', 'loss', 'ff_package_size', 'bk_package_size', 'uplink']
        data_df = pd.DataFrame(self.data, columns=header)
        # NOTE(review): many-to-many merge on cid — every batch row pairs with
        # every downlink sample of the same client; confirm this is intended.
        new_entries = pd.merge(data_df, self.downlink_times, on='cid')

        if os.path.exists(filename):
            df_existing = pd.read_csv(filename)
            df_combined = pd.concat([df_existing, new_entries], ignore_index=True)
            df_combined.to_csv(filename, index=False)
        else:
            new_entries.to_csv(filename, index=False)

    def Forward(self, request, context):
        """Handle one client batch: forward, optimizer step, return gradients."""
        ff_package_size = len(request.SerializeToString())

        response = read_forward_message(request)
        labels = response['labels']
        rnd = response['rnd']
        batch = response['batch']
        # abs() guards against small negative deltas from clock skew.
        uplink = abs(time.time() - response['time'])
        act = tf.convert_to_tensor(response["activations"], dtype=tf.float32)

        print("FORWARD")
        tape, loss = s_forward(self.server_model, act, labels)

        print("BACKWARD")
        s_backward(self.server_model, self.optimizer, tape, loss)
        # Gradient of the loss w.r.t. the cut-layer activations; possible only
        # because the tape is persistent (it was already consumed once above).
        activation_gradients = tape.gradient(loss, act)

        backward_msg = create_backward_message(
            activation_gradients,
            loss.numpy()
        )

        bk_package_size = len(backward_msg.SerializeToString())
        self.data.append([rnd, batch, response['cid'], loss.numpy(), ff_package_size, bk_package_size, uplink])
        return backward_msg

    def create_downlink_df(self, downlink_times, cid):
        """Build a DataFrame of one client's downlink latency samples."""
        return pd.DataFrame(
            {
                'cid': [cid for _ in range(len(downlink_times))],
                'downlink': downlink_times
            }
        )

    def Finish(self, request, context):
        """Record a client's downlink samples and mark it as finished.

        Bug fix: the original replaced the accumulated `downlink_times`
        DataFrame with an *empty* one whenever a client reported no samples,
        discarding every other client's data; empty reports are now ignored.
        """
        f_msg = read_finish_training_message(request)
        downlink_df = self.create_downlink_df(f_msg['downlink_times'], f_msg['cid'])
        self.finished[f_msg['cid']] = True
        if not downlink_df.empty:
            self.downlink_times = pd.concat([self.downlink_times, downlink_df], ignore_index=True)

        if all(self.finished):
            self.save_df_losses()
        return splifed.Empty()
# Launches the split-learning gRPC server and blocks until every client
# has reported that it finished training.
def serve(n_clients, agg=False):
    """Start the Splifed gRPC service on port 50051 and block until done."""
    max_len = 20 * 1024 * 1024 * 20  # generous limit: activation batches are large
    opts = [
        ('grpc.max_send_message_length', max_len),
        ('grpc.max_receive_message_length', max_len),
    ]
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=n_clients), options=opts)

    service = SplitFedService(n_clients, agg)
    splifed_grpc.add_SplifedServicer_to_server(service, grpc_server)
    grpc_server.add_insecure_port('[::]:50051')

    print("Server started on port 50051")
    grpc_server.start()

    # Poll until every client has called Finish.
    while not all(service.finished):
        sleep(1)

    print("Stopping server...")
    sleep(2)  # grace period for in-flight RPCs
    grpc_server.stop(0)
# Run the training server on a background thread so the notebook stays responsive.
server_thread = threading.Thread(target=serve, args=(N_CLIENTS, USE_AGGREGATION))
server_thread.start()

Servidor de agregação

# Aggregation service: FedAvg over the client-half models each round.
class AggregatorService(splifed_grpc.AggregatorServicer):
    """gRPC servicer that barrier-synchronizes clients and averages models.

    Every client calls Aggregate once per round; calls block until all
    n_clients models for the round have arrived, then each caller receives
    the same aggregated model.
    """

    def __init__(self, n_clients):
        self.models_received = 0
        self.n_clients = n_clients
        self.finished = [False for _ in range(n_clients)]
        self.models = []
        self.aggregated_model = None
        # Monotonic counter of completed rounds; the wait predicate below.
        self.round = 0
        self.lock = threading.Lock()
        self.clients_condition = threading.Condition(self.lock)

    def Aggregate(self, request, context):
        """Collect one client model; block until the round is aggregated.

        Bug fixes vs. the original:
        * `self.aggregated_model` was reset to None by every incoming call
          and read *after* the lock was released, so a fast client starting
          the next round could clobber the result before waiters returned
          it. The result is now written only under the lock and captured
          into a local before the lock is released.
        * The bare `wait()` had no predicate loop, so a spurious wakeup
          could return a stale/None model; we now wait on a round counter.
        """
        print("REQUEST RECEBIDO")
        with self.clients_condition:
            print(f"Received model from client. Current count: {self.models_received + 1}/{self.n_clients}")
            model_msg = read_model(request.model)
            self.models.append((model_msg["weights"], request.size))

            self.models_received += 1

            if self.models_received == self.n_clients:
                print("All models received. Aggregating...")
                self.aggregated_model = aggregate(self.models)
                self.models = []  # Reset for the next round
                self.models_received = 0
                self.round += 1
                print("Aggregation complete. Returning aggregated model.")
                self.clients_condition.notify_all()
            else:
                print(f"Waiting for more models. Current count: {self.models_received}/{self.n_clients}")
                starting_round = self.round
                while self.round == starting_round:
                    self.clients_condition.wait()
            result = self.aggregated_model

        return create_model_msg(result)

    def Finish(self, request, context):
        """Mark client `request.cid` as done so the server loop can stop."""
        self.finished[request.cid] = True
        return splifed.Empty()
# Launches the aggregation gRPC server and blocks until all clients finish.
def start_aggregator_server(port, n_clients):
    """Serve the Aggregator service on `port` until every client is done."""
    max_len = 20 * 1024 * 1024 * 20  # generous limit: full models are large
    grpc_server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=n_clients),
        options=[
            ('grpc.max_send_message_length', max_len),
            ('grpc.max_receive_message_length', max_len),
        ],
    )

    agg_service = AggregatorService(n_clients)
    splifed_grpc.add_AggregatorServicer_to_server(agg_service, grpc_server)
    grpc_server.add_insecure_port(f'[::]:{port}')

    print(f"Aggregator server started on port {port}")
    grpc_server.start()

    # Poll until every client has called Finish.
    while not all(agg_service.finished):
        sleep(1)

    print("Stopping server...")
    sleep(2)  # grace period for in-flight RPCs
    grpc_server.stop(0)
# Start the aggregator server on a background thread (port 50052).
aggregator_thread = threading.Thread(target=start_aggregator_server, args=(50052, N_CLIENTS))
aggregator_thread.start()

Cliente

# ------------------- gRPC helpers -------------------
def connect_to_server():
    """Open an insecure channel to the training server and return its stub."""
    max_len = 20 * 1024 * 1024 * 20  # must match the server's message limits
    channel = grpc.insecure_channel(
        'localhost:50051',
        options=[
            ('grpc.max_send_message_length', max_len),
            ('grpc.max_receive_message_length', max_len),
        ],
    )
    return splifed_grpc.SplifedStub(channel)

def connect_to_aggregator():
    """Open an insecure channel to the aggregator and return its stub."""
    max_len = 20 * 1024 * 1024 * 20  # must match the aggregator's message limits
    channel = grpc.insecure_channel(
        'localhost:50052',
        options=[
            ('grpc.max_send_message_length', max_len),
            ('grpc.max_receive_message_length', max_len),
        ],
    )
    return splifed_grpc.AggregatorStub(channel)
# ------------------- Training helpers -------------------
def create_model(input_shape):
    """Build the client half of the split network plus its optimizer: one
    conv block followed by a 64-unit dense cut layer whose activations are
    shipped to the server."""
    inputs = K.layers.Input(shape=(input_shape[0], input_shape[1], 1))
    hidden = K.layers.Conv2D(32, (3, 3), activation='relu')(inputs)
    hidden = K.layers.MaxPooling2D((2, 2))(hidden)
    hidden = K.layers.Flatten()(hidden)
    outputs = K.layers.Dense(64, activation='relu')(hidden)
    return K.Model(inputs=inputs, outputs=outputs), K.optimizers.Adam()

def load_dataset(dataset_name, n_partitions):
    """Return a FederatedDataset with IID train/test partitioning."""
    part = fld.partitioner.IidPartitioner(n_partitions)
    return fld.FederatedDataset(
        dataset=dataset_name,
        partitioners={'train': part, 'test': part},
    )

def load_partitions(fd, cid):
    """Load client `cid`'s training shard as numpy, pixels scaled to [0, 1]."""
    shard = fd.load_partition(cid, 'train').with_format("numpy")
    return shard['image'] / 255.0, shard['label']

def forward(model, X):
    """Run the client half forward under a tape so the cut-layer output can
    be backpropagated later with the server-supplied gradient."""
    with tf.GradientTape() as tape:
        activations = model(X)
    return model, activations, tape

def backward(model, optimizer, tape, activations, server_grad):
    """Backpropagate the server's activation gradient into the client weights
    (chain rule via `output_gradients`) and apply one optimizer step."""
    grads = tape.gradient(activations, model.trainable_variables,
                          output_gradients=server_grad)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return model
@ray.remote
def train_model(X_train, y_train, cid, aggragation, epochs):
    """Ray task: run split training for one client against both servers.

    NOTE(review): the parameter is spelled `aggragation` (sic); kept as-is
    since callers pass it positionally.

    Args:
        X_train, y_train: this client's local data shard (numpy arrays).
        cid: client id used for server-side bookkeeping.
        aggragation: when True, push weights to the aggregator every epoch.
        epochs: number of local training epochs.
    """
    stub = connect_to_server()
    downlink_times = []
    client_model, optimizer = create_model((28, 28))
    if aggragation:
        agg_stub = connect_to_aggregator()

    for epoch in range(epochs):
        print(f"EPOCH {epoch}")
        for i in range(0, len(X_train), BATCH_SIZE):
            X_batch = X_train[i:i + BATCH_SIZE]
            y_batch = y_train[i:i + BATCH_SIZE]
            # ---- Forward ----
            client_model, activations, tape = forward(client_model, X_batch)
            # ---- Send forward message ----
            msg_act = create_activations(activations)
            # NOTE(review): the global batch index mixes units — it adds
            # `epoch * BATCH_SIZE` rather than `epoch * num_batches`; confirm
            # the intended numbering before relying on the logged `batch` column.
            forward_msg = create_forward_message(msg_act, y_batch, cid, epoch, batch=i//BATCH_SIZE + epoch * BATCH_SIZE)
            response = stub.Forward(forward_msg)
            # ---- Backward ----
            server_grad = read_backward_message(response)
            # abs() guards against small negative deltas from clock skew.
            downlink = abs(time.time() - server_grad['time'])
            downlink_times.append(downlink)
            client_model = backward(client_model, optimizer, tape, activations, server_grad['gradients'])

        # ---- Aggregation Weights ----
        print("Send model to aggregator")
        weights = client_model.get_weights()

        # Blocks inside Aggregate until every client's model has arrived.
        if aggragation:
            w_msg = create_aggregate_message(weights, X_train.shape[0])
            updated_w_msg = agg_stub.Aggregate(w_msg)
            upd_w = read_model(updated_w_msg)
            client_model.set_weights(upd_w['weights'])


    # Tell both servers this client is done (aggregator only when used).
    ft_msg = create_finish_training_message(cid, downlink_times)
    stub.Finish(ft_msg)
    if aggragation:
        agg_stub.Finish(ft_msg)


# Build the federated dataset and launch one Ray training task per client.
fd = load_dataset('mnist', N_CLIENTS)


# train_model(stub, agg_stub, client_model, *load_partitions(fd, args.cid), args.cid)

# Block until every client task has finished training.
ray_results = ray.get([
    train_model.remote(*load_partitions(fd, cid), cid,USE_AGGREGATION, EPOCHS)
    for cid in range(N_CLIENTS)
])

Análise dos dados

# Load both runs' logs and compare loss / network metrics per approach.
agg_losses = pd.read_csv('losses_aggregated.csv')
losses = pd.read_csv('losses.csv')
agg_losses['approach'] = 'With Aggregation'
losses['approach'] = 'Without Aggregation'
df = pd.concat([agg_losses, losses])
# Loss curves over batches, one line per approach.
sns.lineplot(data=df, x='batch', y='loss', hue='approach')
plt.xlabel('Round')
plt.ylabel('Loss')
plt.title('Loss over Rounds')
plt.show()
# Mean uplink/downlink latency per approach, as a bar chart.
t = df.groupby(['approach']).agg({'uplink': 'mean', 'downlink': 'mean'}).reset_index().set_index('approach')
t.plot(kind='bar')
# Throughput in kilobits/s: total payload bits over round-trip latency.
df['throughput'] = (df['ff_package_size'] + df['bk_package_size']) * 8 / 1000 / (df['uplink'] + df['downlink'])
sns.lineplot(data=df, x='batch', y='throughput', hue='approach')
plt.xlabel('Round')
plt.ylabel('Network Throughput (Kbts/s)')
plt.title('Network Throughput over Rounds')
plt.show()