
FedBN: Federated Learning on Non-IID Features via Local Batch Normalization

Practical Seminar - MO809A (Federated Learning)

Name: Leonardo dos Santos Marcondes
RA (student ID): 291206

Dataset: download link for the dataset
Notebook: download link for the notebook

!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
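!pip install "flwr[simulation]" seaborn joblib  # Flower is published on PyPI as "flwr"; the [simulation] extra installs Ray, needed by start_simulation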

1. The Non-IID Data Problem

Non-IID (non-independent and identically distributed) data are data that violate the common assumption that samples are independent and identically distributed (IID):

  • Independent: each data point is drawn independently of the others, i.e., observing one sample carries no information about another.
  • Identically distributed: all data points come from the same statistical distribution.

In federated learning, the data are non-IID when different clients have different marginal distributions of features or labels, very different amounts of data, or when the conditional distribution of the features given a label varies across clients. The impact of this heterogeneity across clients is illustrated in the figure below.


Figure: The non-IID data problem.

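Non-IID data slow down and destabilize the convergence of federated learning algorithms such as FedAvg: each local model drifts towards its own client's distribution, so more communication rounds are needed to reach a given accuracy, which increases both energy consumption and bandwidth usage.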
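The feature-shift case that FedBN targets (same labels, different appearance) can be reproduced with a toy generator. The sketch below is hypothetical and not part of the original notebook: `make_feature_shift_clients` gives every client labels from the same uniform distribution, but applies a client-specific scale and offset to the features, so P(y) is shared while P(x) and P(x|y) differ across clients.

```python
import numpy as np

def make_feature_shift_clients(num_clients=3, samples=1000, seed=0):
    """Toy feature-shift data: shared label distribution, client-specific feature "style"."""
    rng = np.random.default_rng(seed)
    clients = []
    for k in range(num_clients):
        y = rng.integers(0, 10, size=samples)               # same uniform P(y) for every client
        x = y[:, None] + rng.normal(size=(samples, 2))      # 2 features that depend on the label
        scale, offset = 1.0 + k, 3.0 * k                    # client k's own scale and offset
        clients.append((scale * x + offset, y))             # shifts P(x) and P(x | y)
    return clients

for k, (x, y) in enumerate(make_feature_shift_clients()):
    print(f"client {k}: feature mean={x.mean():6.2f} std={x.std():5.2f} label mean={y.mean():4.2f}")
```

The label statistics stay roughly constant across clients while the feature statistics drift, which is the kind of shift FedBN is designed for.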

2. Batch Normalization and Non-IID Data

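Batch Normalization (BN) is a fundamental technique for training deep neural networks, introduced by Ioffe and Szegedy in 2015, that aims to speed up and stabilize training. During training, each BN layer normalizes its input with the mean and variance of the current mini-batch and then applies a learnable scale and shift.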

Figure: Batch Normalization.
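For a mini-batch B, BN computes per channel y = γ · (x - μ_B) / sqrt(σ²_B + ε) + β, where μ_B and σ²_B are the batch mean and variance and γ, β are learned. The NumPy sketch below is an illustration, not PyTorch's actual implementation: it shows the training-mode forward pass of a 2D BN layer together with the running-statistics update. The momentum of 0.1, ε = 1e-5 and the unbiased variance in the running estimate follow the defaults described in the PyTorch `BatchNorm2d` documentation. The four per-channel arrays involved (γ, β, running mean, running variance) are exactly the BN entries that FedBN keeps local.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, running_mean, running_var, momentum=0.1, eps=1e-5):
    """Training-mode BatchNorm2d forward pass for x of shape (N, C, H, W).

    Returns the normalized output and the updated running mean/variance.
    """
    def per_channel(v):
        return v.reshape(1, -1, 1, 1)         # broadcast a (C,) vector over (N, C, H, W)

    axes = (0, 2, 3)                          # statistics per channel: reduce over batch, height, width
    mu = x.mean(axis=axes)                    # batch mean, shape (C,)
    var = x.var(axis=axes)                    # biased batch variance, used for normalizing
    n = x.size // x.shape[1]                  # values per channel (N * H * W)
    x_hat = (x - per_channel(mu)) / np.sqrt(per_channel(var) + eps)
    y = per_channel(gamma) * x_hat + per_channel(beta)   # learnable per-channel scale and shift

    # Exponential moving averages used in eval mode; the running variance
    # stores the unbiased estimate var * n / (n - 1)
    new_mean = (1 - momentum) * running_mean + momentum * mu
    new_var = (1 - momentum) * running_var + momentum * var * n / (n - 1)
    return y, new_mean, new_var

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(32, 4, 8, 8))   # batch of 32 images, 4 channels
C = x.shape[1]
y, rm, rv = batch_norm_train(x, np.ones(C), np.zeros(C), np.zeros(C), np.ones(C))
print(y.mean(axis=(0, 2, 3)).round(6), y.std(axis=(0, 2, 3)).round(3))
```

In eval mode (`model.eval()`) the layer normalizes with the running mean and variance instead of the batch statistics, which is why those buffers matter when a model is evaluated on a client's data.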

The proposed federated learning method, FedBN, performs local updates and averages the local models like FedAvg, but excludes the batch normalization parameters from the averaging, so each client keeps its own BN layers. As we will see, this simple change yields significant improvements in non-IID scenarios, in particular when clients differ in their feature distributions (feature shift).
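The aggregation rule can be sketched with plain dictionaries of NumPy arrays standing in for PyTorch state dicts. This is a simplified illustration, not Flower's API: `fedbn_aggregate` is a hypothetical helper that averages every entry whose name does not contain "bn" (weighted by each client's number of examples, as FedAvg does) and leaves each client's BN entries untouched. In the notebook the same split is done on the client side, by filtering the "bn" keys in `FedBNFlowerClient`, while the server runs an unmodified FedAvg.

```python
import numpy as np

def fedbn_aggregate(client_states, num_examples):
    """One FedBN aggregation step over dict-based "state dicts".

    Non-BN entries become the example-weighted average across clients;
    entries whose name contains "bn" stay on their client unchanged.
    """
    total = sum(num_examples)
    shared_keys = [k for k in client_states[0] if "bn" not in k]
    global_shared = {
        k: sum(n * s[k] for n, s in zip(num_examples, client_states)) / total
        for k in shared_keys
    }
    # Each client ends the round with the global shared weights plus its own BN entries
    return [
        {**global_shared, **{k: v for k, v in s.items() if "bn" in k}}
        for s in client_states
    ]

clients = [
    {"conv1.weight": np.array([1.0, 1.0]), "bn1.running_mean": np.array([0.0])},
    {"conv1.weight": np.array([3.0, 5.0]), "bn1.running_mean": np.array([10.0])},
]
new_states = fedbn_aggregate(clients, num_examples=[100, 300])
print(new_states)
```

Both clients receive the same averaged `conv1.weight`, (100·[1, 1] + 300·[3, 5]) / 400 = [2.5, 4.0], while each keeps its own `bn1.running_mean`.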

3. Loading and Visualizing the Data

Five digit-image datasets (digits 0-9) are used to compare the effect of batch normalization on the performance of federated learning systems affected by non-IID data heterogeneity. These datasets are:

  • MNIST: handwritten digits.
  • MNIST-M: MNIST digits blended with random patches of color photographs.
  • SVHN: house numbers cropped from Google Street View images.
  • USPS: handwritten digits from envelopes of the United States Postal Service.
  • SynthDigits: digits synthetically rendered from Windows™ fonts, varying the amount of blur, the orientation and the stroke color.

The table below summarizes the characteristics of each dataset, and the following cell displays sample images from each one.

| Dataset | MNIST | MNIST-M | SVHN | USPS | SynthDigits |
|:---:|:---:|:---:|:---:|:---:|:---:|
| **Color** | Greyscale | RGB | RGB | Greyscale | RGB |
| **Resolution (pixels)** | 28x28 | 28x28 | 32x32 | 16x16 | 32x32 |
| **Labels** | 0-9 | 0-9 | 1-10 | 0-9 | 1-10 |
| **Training split size** | 60,000 | 60,000 | 73,257 | 9,298 | 50,000 |
| **Test split size** | 10,000 | 10,000 | 26,032 | - | - |
| **Image array shape** | (28,28) | (28,28,3) | (32,32,3) | (16,16) | (32,32,3) |
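The table lists SVHN and SynthDigits with labels 1-10. In the original SVHN release the digit 0 is stored as label 10, which would be out of range for the 10-output network and `nn.CrossEntropyLoss` used later. The preprocessed partitions used here may already be in the 0-9 range (check with `np.unique(y_train)`); if they are not, a hypothetical helper like the one below remaps them, assuming the "10 means 0" convention applies.

```python
import numpy as np

def remap_digit_labels(labels):
    """Map 1-10 digit labels (10 standing for the digit 0) to 0-9.

    Labels already in 0-9 are unchanged, because x % 10 == x for 0 <= x <= 9.
    """
    labels = np.asarray(labels).squeeze()   # accept (N,) or (N, 1) arrays
    return labels % 10

print(remap_digit_labels([1, 5, 10, 9]))
```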

import matplotlib.pyplot as plt
from random import sample
import seaborn as sns
import joblib

NUM_PARTITIONS = 10
DATASETS = ["MNIST", "MNIST_M", "SVHN", "SynthDigits", "USPS"]
BATCH_SIZE = 32

def display_samples(x_train, y_train, dataset):
    """Show 10 images in a 2x5 grid, each titled with its label."""
    sns.set_theme(style="whitegrid")

    fig, axes = plt.subplots(2, 5, figsize=(12, 5))

    for i in range(10):
        # The i-th sample goes to row i // 5, column i % 5 (the position is unrelated to the class)
        ax = axes[i // 5, i % 5]
        ax.imshow(x_train[i], cmap='gray')  # cmap only affects single-channel images
        ax.axis('off')
        ax.set_title(f'Class {y_train[i]}')

    plt.suptitle(f"Sample classes from {dataset}")
    plt.tight_layout()
    plt.show()

for dataset in DATASETS:
    # Only partition 0 is displayed, so the other partitions do not need to be loaded
    x_train, y_train = joblib.load(f'./data/{dataset}/partitions/train_part0.pkl')
    indices = sample(range(len(x_train)), 10)  # 10 random images from the partition
    # squeeze() turns (10, 1) label arrays into (10,) so titles show plain numbers
    display_samples(x_train=x_train[indices], y_train=y_train[indices].squeeze(), dataset=dataset)
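The images make the feature shift visible, but it is worth checking that the partitions do not also differ in their label distribution. The sketch below uses a hypothetical helper, `label_histogram`, that computes the fraction of samples per class with `np.bincount`; similar histograms across datasets would indicate that the heterogeneity lies in P(x|y) rather than in P(y). The commented loop shows how it could be applied to the partition files, assuming the same file layout as above.

```python
import numpy as np

def label_histogram(labels, num_classes=10):
    """Return the fraction of samples in each of the num_classes classes."""
    labels = np.asarray(labels).ravel().astype(int)   # accepts (N,) or (N, 1) label arrays
    counts = np.bincount(labels, minlength=num_classes)
    return counts / counts.sum()

# Possible usage with the notebook's partition files:
# for dataset in DATASETS:
#     _, y = joblib.load(f"./data/{dataset}/partitions/train_part0.pkl")
#     print(dataset, label_histogram(y).round(2))

print(label_histogram([0, 0, 1, 3], num_classes=4))
```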

Dataset class:

from torch.utils.data import DataLoader, Dataset
from typing import List, Optional, Tuple
from PIL import Image # PIL is used to turn the stored NumPy arrays back into images
import numpy as np
import os

class DigitsDataset(Dataset):
    """Digits dataset built from one or more pre-split training partitions (.pkl) or from test.pkl."""

    total_partitions: int = 10

    def __init__(  # pylint: disable=too-many-arguments
        self,
        data_path: str,
        channels: int,
        train: bool,
        partitions: Optional[List[int]] = None,
        transform=None,
    ):
        if train and partitions is not None:
            # Construct dataset by loading one or more partitions
            self.images, self.labels = joblib.load(  # Using joblib to load .pkl files
                os.path.join(data_path, f"partitions/train_part{partitions[0]}.pkl")
            )
            for part in partitions[1:]:
                images, labels = joblib.load(
                    os.path.join(data_path, f"partitions/train_part{part}.pkl")
                )
                self.images = np.concatenate([self.images, images], axis=0)
                self.labels = np.concatenate([self.labels, labels], axis=0)

        else:
            self.images, self.labels = joblib.load(
                os.path.join(data_path, "test.pkl")
            )

        self.transform = transform
        self.channels = channels
        self.labels = self.labels.squeeze()

    def __len__(self) -> int:
        """Return number of images."""
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return a transformed example of the dataset."""
        image = self.images[idx]
        label = self.labels[idx]
        if self.channels == 1:
            image = Image.fromarray(image, mode="L")
        elif self.channels == 3:
            image = Image.fromarray(image, mode="RGB")
        else:
            raise ValueError(f"{self.channels} channel is not allowed.")

        if self.transform is not None:
            image = self.transform(image)
        return image, label

Function to load the partitions and preprocess the images:

from torchvision import transforms

def load_partition(
    dataset: str, path_to_data: str, partition_indx: List[int], batch_size: int
) -> Tuple[DataLoader, DataLoader]:
    data_path = os.path.join(path_to_data, dataset)

    if dataset == "MNIST":
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),  
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Per-channel mean and standard deviation
        ])

        trainset = DigitsDataset(data_path, channels=1, partitions=partition_indx, train=True, transform=transform)
        testset = DigitsDataset(data_path, channels=1, train=False, transform=transform)

    elif dataset == "SVHN":
        transform = transforms.Compose([
            transforms.Resize([28, 28]), # Resize 32x32 to 28x28
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        trainset = DigitsDataset(data_path, channels=3, partitions=partition_indx, train=True, transform=transform)
        testset = DigitsDataset(data_path, channels=3, train=False, transform=transform)

    elif dataset == "USPS":
        transform = transforms.Compose([
            transforms.Resize([28, 28]), # Resize 16x16 to 28x28
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        trainset = DigitsDataset(data_path, channels=1, partitions=partition_indx, train=True, transform=transform)
        testset = DigitsDataset(data_path, channels=1, train=False, transform=transform)

    elif dataset == "SynthDigits":
        transform = transforms.Compose([
            transforms.Resize([28, 28]), # Resize 32x32 to 28x28
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        trainset = DigitsDataset(data_path, channels=3, partitions=partition_indx, train=True, transform=transform)
        testset = DigitsDataset(data_path, channels=3, train=False, transform=transform)

    elif dataset == "MNIST_M":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        trainset = DigitsDataset(data_path, channels=3, partitions=partition_indx, train=True, transform=transform)
        testset = DigitsDataset(data_path, channels=3, train=False, transform=transform)

    else:
        raise NotImplementedError(f"dataset: {dataset} is not available")

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    return trainloader, testloader
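All five pipelines end with `ToTensor()` followed by `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`, and the greyscale ones first replicate their single channel with `Grayscale(num_output_channels=3)`, so every client feeds the model 3 x 28 x 28 tensors with values in [-1, 1]. The NumPy re-implementation below only illustrates that arithmetic and is not how torchvision implements it: ToTensor scales uint8 pixels to [0, 1] and puts channels first, and Normalize computes (x - 0.5) / 0.5. The common input range does not remove the differences in color, texture and background between datasets.

```python
import numpy as np

def to_tensor_and_normalize(img_uint8, mean=0.5, std=0.5):
    """Illustrative equivalent of Grayscale(3) -> ToTensor() -> Normalize((0.5,)*3, (0.5,)*3)
    for a single-channel uint8 image of shape (H, W)."""
    x = img_uint8.astype(np.float32) / 255.0          # ToTensor: [0, 255] -> [0, 1]
    x = np.repeat(x[None, :, :], 3, axis=0)           # Grayscale(3): copy the channel, shape (3, H, W)
    return (x - mean) / std                           # Normalize: [0, 1] -> [-1, 1]

img = np.array([[0, 255], [51, 204]], dtype=np.uint8)
out = to_tensor_and_normalize(img)
print(out.shape, out[0])
```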

Function to build the dataloaders:

from random import shuffle

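def get_data():
    """Return one (trainloader, testloader, dataset_name) tuple per client; client i gets one random partition of DATASETS[i]."""
    client_data = []

    for dataset_name in DATASETS:
        parts = list(range(NUM_PARTITIONS))
        shuffle(parts) # Shuffle the partition indices...
        parts_for_client = parts[0:1] # ...and keep the first one, i.e. a random partition (10% of the dataset)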
        trainloader, testloader = load_partition(
            dataset_name,
            path_to_data="./data",
            partition_indx=parts_for_client,
            batch_size=BATCH_SIZE,
        )
        client_data.append((trainloader, testloader, dataset_name))

    return client_data
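`get_data()` draws each client's partition from the global `random` state, so every run of the notebook may use different data (both strategies do share the same draw, since `get_data()` is called once before the simulation loop). If reproducibility across runs matters, a seeded variant could look like the hypothetical `pick_partitions` below; it uses its own `random.Random(seed)` instance so the global generator is left untouched.

```python
import random

def pick_partitions(datasets, num_partitions=10, parts_per_client=1, seed=42):
    """Return {dataset_name: [partition indices]} chosen reproducibly from a seed."""
    rng = random.Random(seed)              # private generator: same seed -> same choice
    selection = {}
    for name in datasets:
        parts = list(range(num_partitions))
        rng.shuffle(parts)                 # same shuffle-then-slice idea as get_data()
        selection[name] = parts[:parts_per_client]
    return selection

selection = pick_partitions(["MNIST", "MNIST_M", "SVHN", "SynthDigits", "USPS"])
print(selection)
```

Each `selection[dataset_name]` could then be passed as `partition_indx` to `load_partition`.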

4. Computational Resources and Configuration

A graphics processing unit (GPU) is not required to run reduced versions of the experiment (in which each client uses only 10% of its dataset). Any machine with at least 8 CPU cores can run 100 communication rounds of FedAvg and FedBN. The training hyperparameters are:

| Description | Value |
|-------------------------------|--------------------|
| Rounds | 10 |
| Number of clients | 5 |
| Strategy fraction fit | 1.0 |
| Strategy fraction evaluate | 0.0 |
| Training samples per client | 743 |
| Client learning rate | 1e-2 (0.01) |
| Local epochs | 1 |
| Loss | Cross-entropy loss |
| Optimizer | SGD |
| Client resources (CPUs) | 2 |
| Client resources (GPUs) | 0.0 |

NUM_CLIENTES = 5 # Number of clients (one per dataset)
NUM_ROUNDS = 10 # Number of communication rounds
LEARNING_RATE = 0.01 # Client (local SGD) learning rate
FRACTION_FIT = 1.0 # Fraction of clients sampled for training in each round
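With the table's 743 training samples per client, a batch size of 32 and one local epoch, the number of SGD steps each client takes per round follows directly. The calculation below is illustrative; `local_sgd_steps` is a hypothetical helper and assumes the `DataLoader` keeps the last, incomplete batch (PyTorch's default `drop_last=False`).

```python
import math

def local_sgd_steps(samples, batch_size, local_epochs, rounds):
    """Return (steps per round, size of the last batch, total steps over all rounds)."""
    steps_per_round = local_epochs * math.ceil(samples / batch_size)
    last_batch = samples % batch_size or batch_size   # size of the final batch of each epoch
    return steps_per_round, last_batch, rounds * steps_per_round

print(local_sgd_steps(samples=743, batch_size=32, local_epochs=1, rounds=10))
```

The last batch holds only 7 images. This matters for the `BatchNorm1d` layers of the CNN used below: in training mode PyTorch raises an error for a batch with a single sample, so a partition size leaving a remainder of 1 would need `drop_last=True`.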

5. Model Definition and Training and Test Functions

import torch
from torch import nn # Neural Network module
from typing import Tuple

class CNNModel(nn.Module):
    """CNN model proposed in the FedBN paper."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 5, 1, 2) # 3 input channels, 64 output channels, 5x5 kernel, stride 1, padding 2
        self.bn1 = nn.BatchNorm2d(64) 
        self.conv2 = nn.Conv2d(64, 64, 5, 1, 2) 
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 5, 1, 2)
        self.bn3 = nn.BatchNorm2d(128)

        self.fc1 = nn.Linear(6272, 2048) # 6272 = 128 channels x 7 x 7 after two 2x2 poolings
        self.bn4 = nn.BatchNorm1d(2048)
        self.fc2 = nn.Linear(2048, 512)
        self.bn5 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Forward pass."""
        x = nn.functional.relu(self.bn1(self.conv1(x))) # 64 x 28 x 28
        x = nn.functional.max_pool2d(x, 2) # 2x2 max pooling halves each spatial dimension: 64 x 14 x 14

        x = nn.functional.relu(self.bn2(self.conv2(x))) # 64 x 14 x 14
        x = nn.functional.max_pool2d(x, 2) # 64 x 7 x 7

        x = nn.functional.relu(self.bn3(self.conv3(x))) # 128 x 7 x 7

        x = x.view(x.shape[0], -1) # 6272

        x = self.fc1(x) # 2048
        x = self.bn4(x)
        x = nn.functional.relu(x)

        x = self.fc2(x) # 512
        x = self.bn5(x)
        x = nn.functional.relu(x)

        x = self.fc3(x) # 10
        return x
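Writing the layer shapes out by hand shows how little FedBN changes the communication volume. The pure-Python sketch below (helper name hypothetical, shapes transcribed from `CNNModel`, keys listed in a different order than `state_dict()`) counts the elements a FedAvg client uploads, i.e. the whole `state_dict`, and compares them with the FedBN upload, which skips every key containing "bn", including `running_mean`, `running_var` and `num_batches_tracked`. With PyTorch available, `sum(v.numel() for v in CNNModel().state_dict().values())` should give the same total.

```python
from math import prod

def cnn_state_dict_shapes(num_classes=10):
    """(key, shape) pairs with the same keys and shapes as CNNModel.state_dict()."""
    shapes = []
    for name, c_in, c_out in [("conv1", 3, 64), ("conv2", 64, 64), ("conv3", 64, 128)]:
        shapes += [(f"{name}.weight", (c_out, c_in, 5, 5)), (f"{name}.bias", (c_out,))]
    for name, d_in, d_out in [("fc1", 128 * 7 * 7, 2048), ("fc2", 2048, 512), ("fc3", 512, num_classes)]:
        shapes += [(f"{name}.weight", (d_out, d_in)), (f"{name}.bias", (d_out,))]
    for name, c in [("bn1", 64), ("bn2", 64), ("bn3", 128), ("bn4", 2048), ("bn5", 512)]:
        # weight (gamma), bias (beta) and the two running statistics, one value per channel
        shapes += [(f"{name}.{p}", (c,)) for p in ("weight", "bias", "running_mean", "running_var")]
        shapes.append((f"{name}.num_batches_tracked", ()))  # 0-d counter: prod(()) == 1
    return shapes

shapes = cnn_state_dict_shapes()
fedavg_upload = sum(prod(s) for _, s in shapes)                   # every entry is sent
fedbn_upload = sum(prod(s) for k, s in shapes if "bn" not in k)   # BN entries stay local
print(fedavg_upload, fedbn_upload, fedavg_upload - fedbn_upload)
```

The BN entries amount to 11,269 of 14,224,847 values, under 0.1% of the payload, so FedBN's benefit comes from keeping the BN statistics local rather than from reducing traffic.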

Neural network training function:

DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # Use a GPU if one is available

def train(model, traindata, epochs, l_r) -> Tuple[float, float]:
    """Train the network; return the mean batch loss and the accuracy over all local epochs."""
    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=l_r)

    # Train the network
    model.to(DEVICE) # Move the model to the GPU if one is available
    model.train() # Training mode: BN layers use batch statistics and update their running statistics
    total_loss = 0.0
    total = 0
    correct = 0
    for _ in range(epochs):  # loop over the dataset multiple times
        for images, labels in traindata:
            images, labels = images.to(DEVICE), labels.to(DEVICE) # Move inputs and labels to the model's device

            # zero the parameter gradients
            optimizer.zero_grad()

            outputs = model(images) # Forward pass
            loss = criterion(outputs, labels) # Compute the loss
            loss.backward() # Compute the gradients of the loss w.r.t. the weights
            optimizer.step() # Update the model weights

            # Accumulate statistics
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Mean loss over every batch of every epoch; accuracy over every sample seen
    return total_loss / (epochs * len(traindata)), correct / total

Neural network test function:

def test(model, testdata) -> Tuple[float, float]:
    """Validate the network on the entire test set."""
    # Define loss and metrics
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    loss = 0.0

    # Evaluate the network
    model.to(DEVICE)
    model.eval() # Evaluation mode: BN layers use their running statistics instead of batch statistics
    with torch.no_grad():
        for data in testdata:
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            outputs = model(images) 
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    loss = loss / len(testdata)
    return loss, accuracy

6. Client Configuration

FedAvg client class:

from torch.utils.data import DataLoader
import torch
from collections import OrderedDict
from typing import Dict
import flwr as fl
from flwr.common.typing import NDArrays, Scalar
from pathlib import Path
import pickle

class FlowerClient(fl.client.NumPyClient):
    def __init__(
            self,
            model: CNNModel, 
            trainloader: DataLoader,
            testloader: DataLoader,
            l_r: float, 
            dataset_name: str
        ):

        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader
        self.dataset_name = dataset_name
        self.l_r = l_r

    def get_parameters(self, config) -> NDArrays:
        """Return all model parameters, including the BN layers, as a list of NumPy ndarrays."""

        # Return all model parameters as a list of NumPy ndarrays
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters: NDArrays) -> None:
        """Set all model parameters, including the BN layers, from a list of NumPy ndarrays."""

        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Tuple[NDArrays, int, Dict]:
        """Set model parameters, train model, return updated model parameters."""

        self.set_parameters(parameters)

        # Evaluate the received global model on the local training set, before local training
        pre_train_loss, pre_train_acc = test(
            self.model, self.trainloader
        )

        # Train model on local dataset
        loss, acc = train(
            self.model,
            self.trainloader,
            epochs=1,
            l_r=self.l_r
        )

        # Construct metrics to return to server
        fl_round = config["round"]
        metrics = {
            "dataset_name": self.dataset_name,
            "round": fl_round,
            "accuracy": acc,
            "loss": loss,
            "pre_train_loss": pre_train_loss,
            "pre_train_acc": pre_train_acc,
        }

        return (
            self.get_parameters({}),
            len(self.trainloader.dataset),
            metrics,
        )

    def evaluate(
        self, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Tuple[float, int, Dict]:
        """Set model parameters, evaluate model on local test dataset, return result."""
        self.set_parameters(parameters)

        loss, accuracy = test(self.model, self.testloader)
        return (
            float(loss),
            len(self.testloader.dataset),
            {"loss": loss, "accuracy": accuracy, "dataset_name": self.dataset_name},
        )

FedBN client class:

class FedBNFlowerClient(FlowerClient):
    """Similar to FlowerClient but this is used by FedBN clients."""

    def __init__(self, save_path: Path, client_id: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # FedBN clients must persist the state of their BN layers across rounds.
        # In simulation, clients are stateless: anything that is not sent to the
        # server (as is the case for the BN parameters of FedBN clients) is lost
        # once a client finishes its training, so it is saved to disk instead.
        # An upcoming version of Flower supports stateful clients.
        bn_state_dir = save_path / "bn_states"
        bn_state_dir.mkdir(exist_ok=True)
        self.bn_state_pkl = bn_state_dir / f"client_{client_id}.pkl"

    def _save_bn_statedict(self) -> None:
        """Save contents of state_dict related to BN layers."""

        bn_state = {
            name: val.cpu().numpy()
            for name, val in self.model.state_dict().items()
            if "bn" in name
        }

        with open(self.bn_state_pkl, "wb") as handle:
            pickle.dump(bn_state, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def _load_bn_statedict(self) -> Dict[str, torch.Tensor]:
        """Load pickle with BN state_dict and return as dict."""

        with open(self.bn_state_pkl, "rb") as handle:
            data = pickle.load(handle)
        bn_state_dict = {k: torch.tensor(v) for k, v in data.items()}
        return bn_state_dict

    def get_parameters(self, config) -> NDArrays:
        """Return the model parameters as a list of NumPy ndarrays, excluding the BN layers (which stay local)."""

        # First update bn_state_dir
        self._save_bn_statedict()
        # Excluding parameters of BN layers when using FedBN
        parameters = [
            val.cpu().numpy()
            for name, val in self.model.state_dict().items()
            if "bn" not in name
        ] 

        return parameters

    def set_parameters(self, parameters: NDArrays) -> None:
        """Set the non-BN model parameters from a list of NumPy ndarrays, then restore the local BN state if available."""

        keys = [k for k in self.model.state_dict().keys() if "bn" not in k]
        params_dict = zip(keys, parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=False)

        # Now also load from bn_state_dir
        if self.bn_state_pkl.exists():  # It won't exist in the first round
            bn_state_dict = self._load_bn_statedict()
            self.model.load_state_dict(bn_state_dict, strict=False)
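Because simulated clients are recreated in every round, the BN state has to make a round trip through disk. The sketch below reproduces that round trip without PyTorch, using NumPy arrays in place of tensors; `save_bn_state` and `restore_round` are hypothetical stand-ins for `_save_bn_statedict` and `set_parameters`, and a temporary directory replaces `./bn_states`.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

def save_bn_state(state, path):
    """Persist only the entries whose name contains "bn" (what _save_bn_statedict keeps)."""
    bn_state = {k: v for k, v in state.items() if "bn" in k}
    with open(path, "wb") as handle:
        pickle.dump(bn_state, handle, protocol=pickle.HIGHEST_PROTOCOL)

def restore_round(global_params, keys, path):
    """Rebuild a client's state at the start of a round: shared (non-BN) weights
    come from the server in key order, BN entries from disk if they were saved."""
    state = dict(zip([k for k in keys if "bn" not in k], global_params))
    if Path(path).exists():  # absent in the client's first round
        with open(path, "rb") as handle:
            state.update(pickle.load(handle))
    return state

keys = ["conv1.weight", "bn1.weight", "bn1.running_mean", "fc3.bias"]
server_params = [np.zeros(2), np.ones(2)]  # averaged conv1.weight and fc3.bias

with tempfile.TemporaryDirectory() as tmp:
    pkl = Path(tmp) / "client_0.pkl"
    first_round = restore_round(server_params, keys, pkl)           # no file yet
    local = {k: np.full(2, float(i)) for i, k in enumerate(keys)}   # pretend local training
    save_bn_state(local, pkl)                                        # end of the round
    state = restore_round(server_params, keys, pkl)                  # start of the next round
```

In the first round no BN file exists, so the freshly created model keeps its initial BN values; from then on each client resumes from its own BN state.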

Client configuration:

def gen_client_fn(client_data, learning_rate, client_type):
    """Return a function that will be called to instantiate the cid-th client."""

    def client_fn(cid: str) -> FlowerClient:
        """Create a Flower client representing a single organization."""

        net = CNNModel()

        # Note: each client gets a different trainloader/valloader, so each client
        # will train and evaluate on their own unique data
        trainloader, valloader, dataset_name = client_data[int(cid)]

        flower_client = None

        if client_type == "FedBN":

            flower_client = FedBNFlowerClient(
                save_path=Path("./"),
                client_id=int(cid),
                model=net,
                trainloader=trainloader,
                testloader=valloader,
                l_r=learning_rate,
                dataset_name=dataset_name
            )

        elif client_type == "FedAvg":

            flower_client = FlowerClient(
                model=net,
                trainloader=trainloader,
                testloader=valloader,
                l_r=learning_rate,
                dataset_name=dataset_name
            )

        return flower_client

    return client_fn

7. Server Configuration

from typing import Dict, List, Tuple

from flwr.common.typing import Metrics

def get_on_fit_config(server_round: int) -> Dict[str, int]:
    """Return a config (a dict) to be sent to clients during fit()."""
    fit_config = {"round": server_round}  # Add round info
    return fit_config

def get_metrics_aggregation_fn():
    """Return a function that computes the weighted average of metrics."""

    def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
        """Compute the weighted average of metrics."""
        # Initialize dictionaries to store totals and weighted sums
        totals: Dict[str, int] = {}
        accuracies: Dict[str, float] = {}
        losses: Dict[str, float] = {}
        pre_train_accuracies: Dict[str, float] = {}
        pre_train_losses: Dict[str, float] = {}

        # Aggregate metrics from all clients
        for num_examples, metric in metrics:
            dataset_name = metric.get("dataset_name", "default")

            # Update totals and weighted sums
            totals[dataset_name] = totals.get(dataset_name, 0) + num_examples
            accuracies[dataset_name] = accuracies.get(dataset_name, 0) + num_examples * metric["accuracy"]
            losses[dataset_name] = losses.get(dataset_name, 0) + num_examples * metric["loss"]

            # Handle optional metrics
            if "pre_train_acc" in metric:
                pre_train_accuracies[dataset_name] = pre_train_accuracies.get(dataset_name, 0) + num_examples * metric["pre_train_acc"]
            if "pre_train_loss" in metric:
                pre_train_losses[dataset_name] = pre_train_losses.get(dataset_name, 0) + num_examples * metric["pre_train_loss"]

        # Normalize by the total number of examples
        accuracies = {k: v / totals[k] for k, v in accuracies.items()}
        losses = {k: v / totals[k] for k, v in losses.items()}

        aggregated_metrics = {"accuracy": accuracies, "loss": losses}

        # Include optional metrics
        if pre_train_accuracies:
            aggregated_metrics["pre_train_accuracy"] = {k: v / totals[k] for k, v in pre_train_accuracies.items()}
        if pre_train_losses:
            aggregated_metrics["pre_train_loss"] = {k: v / totals[k] for k, v in pre_train_losses.items()}

        return aggregated_metrics

    return weighted_average

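class Servidor(fl.server.strategy.FedAvg):
    """Plain FedAvg strategy used for both experiments: FedBN only changes which
    parameters the clients send, so no server-side change is needed."""

    def __init__(self, num_clients):
        self.num_clients = num_clients
        print("Initializing Servidor...")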

        super().__init__(
            fraction_fit=FRACTION_FIT,
            fraction_evaluate=0.0,
            min_available_clients=num_clients,
            on_fit_config_fn=get_on_fit_config,
            evaluate_metrics_aggregation_fn=get_metrics_aggregation_fn(),
            fit_metrics_aggregation_fn=get_metrics_aggregation_fn()
        )
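`weighted_average` weights each client's metric by its number of training examples, separately for each `dataset_name`. A quick check of that arithmetic with made-up numbers (the helper `weighted_mean` is hypothetical): two clients holding 600 and 200 samples of the same dataset, with accuracies 0.9 and 0.5, average to (600·0.9 + 200·0.5) / 800 = 0.8. In this notebook each dataset is held by a single client, so each per-dataset value simply equals that client's metric.

```python
def weighted_mean(pairs):
    """Example-weighted mean of (num_examples, value) pairs."""
    total = sum(n for n, _ in pairs)
    return sum(n * v for n, v in pairs) / total

print(weighted_mean([(600, 0.9), (200, 0.5)]))
print(weighted_mean([(743, 0.7)]))   # a single client: the mean is its own value
```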

8. Flower Simulation

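import shutil

CLIENT_TYPES = ["FedAvg", "FedBN"]

save_path = Path("./results")

save_path.mkdir(exist_ok=True)

# Remove BN states left over from a previous run (stored in ./bn_states by the FedBN
# clients); otherwise FedBN clients would start from stale BN statistics
shutil.rmtree("./bn_states", ignore_errors=True)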

client_data_loaders = get_data()

strategy = Servidor(NUM_CLIENTES)

for client_type in CLIENT_TYPES:

    client_fn = gen_client_fn(
        client_data=client_data_loaders,
        learning_rate=LEARNING_RATE,
        client_type=client_type
    )

    history = fl.simulation.start_simulation(
            client_fn=client_fn,
            num_clients=NUM_CLIENTES,
            client_resources={
                "num_cpus": 2.0,
                "num_gpus": 0.0
            },
            config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
            strategy=strategy,
        )

    data = {"history": history}

    print(data)

    history_path = f"{str(save_path)}/history_{client_type}.pkl"
    with open(history_path, "wb") as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

9. Visualizing the Results

def save_fig(name, fig):
    """Save matplotlib plot."""
    fig.savefig(
        name,
        dpi=None,
        facecolor=fig.get_facecolor(),
        edgecolor="none",
        orientation="portrait",
        format="png",
        transparent=False,
        bbox_inches="tight",
        pad_inches=0.2,
        metadata=None,
    )

# Read and process pickle files
def _read_pickle(path_to_pickle):
    with open(path_to_pickle, "rb") as handle:
        data = pickle.load(handle)
    return data

# Aggregate losses and accuracy by dataset
def _fuse_by_dataset(losses):
    """Turn [(round, {dataset: value}), ...] into {dataset: [value, ...]}."""
    fused_losses = {}

    for _, loss_dict in losses:
        for k, val in loss_dict.items():
            fused_losses.setdefault(k, []).append(val)
    return fused_losses

# Plot final results 
def quick_plot(pickle_paths, metric):
    """Plot a per-dataset training metric, comparing FedAvg and FedBN."""

    y_labels = {"pre_train_loss": "Train Loss", "pre_train_accuracy": "Train Accuracy"}
    if metric not in y_labels:
        raise ValueError(f"Unsupported metric: {metric}")
    y_variable = y_labels[metric]

    # Read and process pickle files
    data_dict = {}
    for path in pickle_paths:
        data = _read_pickle(path)
        losses = data["history"].metrics_distributed_fit[metric]
        model_name = Path(path).stem.split("_")[1]  # Extract FedAvg or FedBN from filename
        data_dict[model_name] = _fuse_by_dataset(losses)

    # Validate that both pickle files contain the same datasets
    datasets = sorted(list(data_dict["FedAvg"].keys()))
    assert datasets == sorted(list(data_dict["FedBN"].keys())), "Datasets must match in both pickle files"

    # Number of datasets
    num_datasets = len(datasets)

    # Create a single row of subplots with one column per dataset
    fig, axes = plt.subplots(1, num_datasets, figsize=(18, 5), sharey=True)

    # Flatten axes array for easier indexing
    if num_datasets == 1:
        axes = [axes]  # Ensure axes is iterable when there's only one dataset

    for i, (ax, dataset) in enumerate(zip(axes, datasets)):
        # Plot FedAvg
        ax.plot(range(len(data_dict["FedAvg"][dataset])), data_dict["FedAvg"][dataset], label="FedAvg", color="blue")

        # Plot FedBN
        ax.plot(range(len(data_dict["FedBN"][dataset])), data_dict["FedBN"][dataset], label="FedBN", color="orange")

        # Configure the plot
        ax.set_title(f"{dataset}")
        ax.grid(True)
        ax.set_xlabel("Round")
        if i == 0:  # Add Y-axis label only for the first subplot
            ax.set_ylabel(y_variable)
        ax.legend()

    # Adjust layout and save the figure
    plt.tight_layout()
    save_fig(f"./results/comparison_{metric}_by_dataset.png", fig)
    plt.show()
metrics = ["pre_train_loss", "pre_train_accuracy"]

for metric in metrics :
    quick_plot(["./results/history_FedAvg.pkl", "./results/history_FedBN.pkl"], metric=metric)
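`metrics_distributed_fit` in the saved `History` maps each metric name to a list of `(round, value)` pairs, and here each value is itself the per-dataset dict built by `weighted_average`. `_fuse_by_dataset` discards the round numbers, which is why the x-axis starts at 0. A variant that keeps them, so the plot can use the actual round numbers (1 to NUM_ROUNDS), could look like the hypothetical `fuse_with_rounds` below; the values in the example are dummy numbers, not results.

```python
def fuse_with_rounds(history_entries):
    """Turn [(round, {dataset: value}), ...] into {dataset: ([rounds], [values])}."""
    fused = {}
    for rnd, per_dataset in history_entries:
        for name, value in per_dataset.items():
            rounds, values = fused.setdefault(name, ([], []))
            rounds.append(rnd)
            values.append(value)
    return fused

# Dummy entries with the same shape as history.metrics_distributed_fit["pre_train_loss"]
entries = [(1, {"MNIST": 1.0, "SVHN": 2.0}), (2, {"MNIST": 0.5, "SVHN": 1.5})]
fused = fuse_with_rounds(entries)
print(fused["MNIST"])   # could then be plotted with ax.plot(*fused["MNIST"], label="FedBN")
```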

Train Loss:

Figure: train loss per dataset, FedAvg vs FedBN.

Train Accuracy:

Figure: train accuracy per dataset, FedAvg vs FedBN.