FedBN: Federated Learning on Non-IID Features via Local Batch Normalization
Practical Seminar - MO809A (Federated Learning)
Name: Leonardo dos Santos Marcondes
RA: 291206
Dataset: Dataset download link
Notebook: Notebook download link
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install "flwr[simulation]"  # the Flower FL framework is packaged as "flwr"; the simulation extra installs the Ray backend
1. The Problem of Non-IID Data
Non-IID (non-independent and identically distributed) data are data that violate the common assumption of being independent and identically distributed (IID):
- Independent: each data point is independent of the others, i.e., there is no relationship between data points.
- Identically distributed: all data points come from the same statistical distribution.
Non-IID data slow the convergence of ML algorithms, increase energy consumption, and require more communication bandwidth.
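The toy sketch below (made-up data, not one of the benchmark datasets) illustrates feature-skew non-IID partitions: every client holds the same classes, but the feature distributions differ across clients, which is exactly the setting FedBN targets.
import numpy as np
rng = np.random.default_rng(0)
# IID: every client samples images from the same distribution
iid_clients = [rng.normal(0.0, 1.0, size=(1000, 28, 28)) for _ in range(5)]
# Feature-skew non-IID: each client's images come from a shifted distribution
# (e.g., different brightness/contrast), mimicking MNIST vs. SVHN vs. USPS domains
shifts = [0.0, 0.5, 1.0, 1.5, 2.0]
noniid_clients = [rng.normal(mu, 1.0, size=(1000, 28, 28)) for mu in shifts]
print("IID client means:    ", [round(c.mean(), 2) for c in iid_clients])
print("Non-IID client means:", [round(c.mean(), 2) for c in noniid_clients])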
2. Batch Normalization and Non-IID Data
Batch Normalization (BN) is a fundamental technique for training deep neural networks, introduced by Ioffe and Szegedy in 2015, whose goal is to accelerate and stabilize the training process.
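For reference, BN normalizes each activation over the mini-batch and then applies a learned affine transform (the standard formulation from the 2015 paper):
$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta$$
where $\mu_\mathcal{B}$ and $\sigma_\mathcal{B}^2$ are the mini-batch mean and variance and $\gamma$, $\beta$ are learned parameters. Under non-IID features, each client's BN statistics and affine parameters capture its local feature distribution, which is why FedBN keeps them local.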
The federated learning approach proposed in the paper, FedBN, performs local updates and averages the local models while excluding the batch-normalization parameters from the averaging step. As we will see, this simple modification yields significant improvements in non-IID scenarios.
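A minimal server-side sketch of this idea (not the Flower-based implementation used later in this notebook): average every parameter except those belonging to BN layers, identified here by "bn" in the state_dict key, as in the CNNModel defined in Section 5.
import torch
def fedbn_average(global_model, client_models):
    """Average all client parameters except BatchNorm entries (sketch)."""
    avg_state = {}
    for name in global_model.state_dict():
        if "bn" in name:
            continue  # BN weights, biases and running statistics stay local
        avg_state[name] = torch.stack(
            [m.state_dict()[name].float() for m in client_models]
        ).mean(dim=0)
    # strict=False leaves the BN entries of the global model untouched
    global_model.load_state_dict(avg_state, strict=False)
    return global_model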
3. Data Loading and Visualization
Five digit-image datasets (digits 0-9) are used to compare the effect of batch normalization on the performance of federated learning systems that suffer from non-IID feature heterogeneity. The datasets are:
- MNIST: handwritten digits.
- MNIST-M: MNIST digits randomly colorized.
- SVHN: house numbers collected from Street View.
- USPS: handwritten digits from envelopes of the United States Postal Service.
- SynthDigits: digits synthesized from Windows TM fonts, varying blur level, orientation, and stroke color.
Next, we present the table with the characteristics of each dataset and sample images showing the data profile of each one.
import matplotlib.pyplot as plt
from random import sample
import seaborn as sns
import joblib
NUM_PARTITIONS = 10
DATASETS = ["MNIST", "MNIST_M", "SVHN", "SynthDigits", "USPS"]
BATCH_SIZE = 32
def display_samples(x_train, y_train, dataset):
sns.set(style="whitegrid")
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i in range(10):
        # Display each sampled image with its class as the title
        ax = axes[i // 5, i % 5]
        ax.imshow(x_train[i], cmap='gray')
        ax.axis('off')
        ax.set_title(f'Class {y_train[i]}')
    plt.suptitle(f"Sample classes from {dataset}")
plt.tight_layout()
plt.show()
for dataset in DATASETS :
for partition in range(NUM_PARTITIONS) :
x_train, y_train = joblib.load(f'./data/{dataset}/partitions/train_part{partition}.pkl')
if partition == 0 :
indices = sample(range(0, len(x_train)), 10)
display_samples(x_train=x_train[indices], y_train=y_train[indices], dataset=dataset)
Dataset class:
from torch.utils.data import DataLoader, Dataset
from typing import List, Optional, Tuple
from PIL import Image  # PIL is used to convert the raw arrays into images
import numpy as np
import os
class DigitsDataset(Dataset):
"""Split datasets."""
total_partitions: int = 10
def __init__( # pylint: disable=too-many-arguments
self,
data_path: str,
channels: int,
train: bool,
partitions: Optional[List[int]] = None,
transform=None,
):
if train and partitions is not None:
# Construct dataset by loading one or more partitions
self.images, self.labels = joblib.load( # Using joblib to load .pkl files
os.path.join(data_path, f"partitions/train_part{partitions[0]}.pkl")
)
for part in partitions[1:]:
images, labels = joblib.load(
os.path.join(data_path, f"partitions/train_part{part}.pkl")
)
self.images = np.concatenate([self.images, images], axis=0)
self.labels = np.concatenate([self.labels, labels], axis=0)
else:
self.images, self.labels = joblib.load(
os.path.join(data_path, "test.pkl")
)
self.transform = transform
self.channels = channels
self.labels = self.labels.squeeze()
def __len__(self) -> int:
"""Return number of images."""
return self.images.shape[0]
def __getitem__(self, idx):
"""Return a transformed example of the dataset."""
image = self.images[idx]
label = self.labels[idx]
if self.channels == 1:
image = Image.fromarray(image, mode="L")
elif self.channels == 3:
image = Image.fromarray(image, mode="RGB")
else:
raise ValueError(f"{self.channels} channel is not allowed.")
if self.transform is not None:
image = self.transform(image)
return image, label
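A quick sanity check of the class (assuming the ./data/MNIST/partitions layout used throughout this notebook):
from torchvision import transforms
ds = DigitsDataset(
    data_path="./data/MNIST",
    channels=1,
    train=True,
    partitions=[0],
    transform=transforms.ToTensor(),
)
image, label = ds[0]
print(len(ds), image.shape, label)  # e.g. <partition size> torch.Size([1, 28, 28]) <digit>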
Function to load the partitions and preprocess the images:
from torchvision import transforms
def load_partition(
dataset: str, path_to_data: str, partition_indx: List[int], batch_size: int
) -> Tuple[DataLoader, DataLoader]:
data_path = os.path.join(path_to_data, dataset)
if dataset == "MNIST":
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Mean and standard deviation for each channel
])
trainset = DigitsDataset(data_path, channels=1, partitions=partition_indx, train=True, transform=transform)
testset = DigitsDataset(data_path, channels=1, train=False, transform=transform)
elif dataset == "SVHN":
transform = transforms.Compose([
            transforms.Resize([28, 28]),  # Resize from 32x32 to 28x28
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = DigitsDataset(data_path, channels=3, partitions=partition_indx, train=True, transform=transform)
testset = DigitsDataset(data_path, channels=3, train=False, transform=transform)
elif dataset == "USPS":
transform = transforms.Compose([
            transforms.Resize([28, 28]),  # Resize from 16x16 to 28x28
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = DigitsDataset(data_path, channels=1, partitions=partition_indx, train=True, transform=transform)
testset = DigitsDataset(data_path, channels=1, train=False, transform=transform)
elif dataset == "SynthDigits":
transform = transforms.Compose([
            transforms.Resize([28, 28]),  # Resize from 32x32 to 28x28
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = DigitsDataset(data_path, channels=3, partitions=partition_indx, train=True, transform=transform)
testset = DigitsDataset(data_path, channels=3, train=False, transform=transform)
elif dataset == "MNIST_M":
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = DigitsDataset(data_path, channels=3, partitions=partition_indx, train=True, transform=transform)
testset = DigitsDataset(data_path, channels=3, train=False, transform=transform)
else:
raise NotImplementedError(f"dataset: {dataset} is not available")
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
return trainloader, testloader
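An example call (assuming the ./data directory layout described above):
trainloader, testloader = load_partition(
    dataset="MNIST", path_to_data="./data", partition_indx=[0], batch_size=BATCH_SIZE
)
images, labels = next(iter(trainloader))
print(images.shape, labels.shape)  # expected: torch.Size([32, 3, 28, 28]) torch.Size([32])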
Function to build the dataloaders:
from random import shuffle
def get_data() :
client_data = []
for dataset_name in DATASETS :
parts = list(range(NUM_PARTITIONS))
        shuffle(parts)  # Pick one random partition for this client
parts_for_client = parts[0:1]
trainloader, testloader = load_partition(
dataset_name,
path_to_data="./data",
partition_indx=parts_for_client,
batch_size=BATCH_SIZE,
)
client_data.append((trainloader, testloader, dataset_name))
return client_data
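Each client is assigned one random partition of one dataset, so the call below returns five (trainloader, testloader, dataset_name) tuples, one per domain:
client_data = get_data()
for _, _, name in client_data:
    print(name)  # MNIST, MNIST_M, SVHN, SynthDigits, USPS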
4. Computational Resources and Settings
A graphics processing unit (GPU) is not required to run reduced versions of the experiment (limiting each client to 10% of its dataset). Any machine with at least 8 CPU cores can run 100 communication rounds of FedAvg and FedBN. The hyperparameters used for training are:
NUM_CLIENTES = 5      # Number of clients
NUM_ROUNDS = 10       # Number of communication rounds
LEARNING_RATE = 0.01  # Model learning rate
FRACTION_FIT = 1.0    # Fraction of clients that participate in each round
5. Model Definition and Training/Testing Functions
import torch
from torch import nn # Neural Network module
from typing import Tuple
class CNNModel(nn.Module):
"""CNN model proposed in the FedBN paper."""
def __init__(self, num_classes=10):
super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 5, 1, 2)  # 3 input channels, 64 output channels, 5x5 kernel, stride 1, padding 2
self.bn1 = nn.BatchNorm2d(64)
self.conv2 = nn.Conv2d(64, 64, 5, 1, 2)
self.bn2 = nn.BatchNorm2d(64)
self.conv3 = nn.Conv2d(64, 128, 5, 1, 2)
self.bn3 = nn.BatchNorm2d(128)
self.fc1 = nn.Linear(6272, 2048)
self.bn4 = nn.BatchNorm1d(2048)
self.fc2 = nn.Linear(2048, 512)
self.bn5 = nn.BatchNorm1d(512)
self.fc3 = nn.Linear(512, num_classes)
def forward(self, x):
"""Forward pass."""
x = nn.functional.relu(self.bn1(self.conv1(x))) # 64 x 28 x 28
        x = nn.functional.max_pool2d(x, 2)  # 2x2 max pooling halves the spatial dimensions
x = nn.functional.relu(self.bn2(self.conv2(x))) # 64 x 14 x 14
x = nn.functional.max_pool2d(x, 2) # 64 x 7 x 7
x = nn.functional.relu(self.bn3(self.conv3(x))) # 128 x 7 x 7
x = x.view(x.shape[0], -1) # 6272
x = self.fc1(x) # 2048
x = self.bn4(x)
x = nn.functional.relu(x)
x = self.fc2(x) # 512
x = self.bn5(x)
x = nn.functional.relu(x)
x = self.fc3(x) # 10
return x
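A quick shape check of the model (a batch of four 28x28 RGB images should map to 10 logits):
model = CNNModel()
dummy = torch.randn(4, 3, 28, 28)
print(model(dummy).shape)  # expected: torch.Size([4, 10])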
Neural network training function:
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")  # Use the GPU if one is available
def train(model, traindata, epochs, l_r) -> Tuple[float, float]:
"""Train the network."""
# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=l_r)
# Train the network
    model.to(DEVICE)  # Move the model to the GPU if available
model.train()
total_loss = 0.0
for _ in range(epochs): # loop over the dataset multiple times
total = 0.0
correct = 0
for _i, data in enumerate(traindata, 0):
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)  # Move inputs and labels to the device
            # zero the parameter gradients
            optimizer.zero_grad()
            outputs = model(images)  # Forward pass
            loss = criterion(outputs, labels)  # Calculate loss
            loss.backward()  # Compute gradients of the weights w.r.t. the loss
            optimizer.step()  # Update the model weights
# print statistics
total_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return total_loss / len(traindata), correct / total
Neural network test function:
def test(model, testdata) -> Tuple[float, float]:
"""Validate the network on the entire test set."""
# Define loss and metrics
criterion = nn.CrossEntropyLoss()
correct = 0
total = 0
loss = 0.0
# Evaluate the network
model.to(DEVICE)
    model.eval()  # Put the model in evaluation mode
with torch.no_grad():
for data in testdata:
images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
outputs = model(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
loss = loss / len(testdata)
return loss, accuracy
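A quick smoke test of both functions on a single client's loaders (one local epoch; this re-loads the data just for inspection):
trainloader, testloader, name = get_data()[0]
model = CNNModel()
train_loss, train_acc = train(model, trainloader, epochs=1, l_r=LEARNING_RATE)
test_loss, test_acc = test(model, testloader)
print(f"{name}: train acc {train_acc:.3f}, test acc {test_acc:.3f}")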
6. Client Configuration
FedAvg client class:
from torch.utils.data import DataLoader
import torch
from collections import OrderedDict
from typing import Dict
import flwr as fl
from flwr.common.typing import NDArrays, Scalar
from pathlib import Path
import pickle
class FlowerClient(fl.client.NumPyClient):
def __init__(
self,
model: CNNModel,
trainloader: DataLoader,
testloader: DataLoader,
l_r: float,
dataset_name: str
):
self.model = model
self.trainloader = trainloader
self.testloader = testloader
self.dataset_name = dataset_name
self.l_r = l_r
def get_parameters(self, config) -> NDArrays:
"""Return model parameters as a list of NumPy ndarrays w or w/o.
using BNlayers.
"""
# Return all model parameters as a list of NumPy ndarrays
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
def set_parameters(self, parameters: NDArrays) -> None:
"""Set model parameters from a list of NumPy ndarrays.
using BNlayers.
"""
params_dict = zip(self.model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
self.model.load_state_dict(state_dict, strict=True)
def fit(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[NDArrays, int, Dict]:
"""Set model parameters, train model, return updated model parameters."""
self.set_parameters(parameters)
# Evaluate the state of the global model on the train set;
pre_train_loss, pre_train_acc = test(
self.model, self.trainloader
)
# Train model on local dataset
loss, acc = train(
self.model,
self.trainloader,
epochs=1,
l_r=self.l_r
)
# Construct metrics to return to server
fl_round = config["round"]
metrics = {
"dataset_name": self.dataset_name,
"round": fl_round,
"accuracy": acc,
"loss": loss,
"pre_train_loss": pre_train_loss,
"pre_train_acc": pre_train_acc,
}
return (
self.get_parameters({}),
len(self.trainloader.dataset),
metrics,
)
def evaluate(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[float, int, Dict]:
"""Set model parameters, evaluate model on local test dataset, return result."""
self.set_parameters(parameters)
loss, accuracy = test(self.model, self.testloader)
return (
float(loss),
len(self.testloader.dataset),
{"loss": loss, "accuracy": accuracy, "dataset_name": self.dataset_name},
)
FedBN client class:
class FedBNFlowerClient(FlowerClient):
"""Similar to FlowerClient but this is used by FedBN clients."""
def __init__(self, save_path: Path, client_id: int, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
        # For FedBN clients we need to persist the state of the BN
        # layers across rounds. In simulation, clients are stateless,
        # so everything not communicated to the server (as is the case
        # with the BN-layer parameters of FedBN clients) is lost once a
        # client completes its training. An upcoming version of Flower
        # supports stateful clients.
bn_state_dir = save_path / "bn_states"
bn_state_dir.mkdir(exist_ok=True)
self.bn_state_pkl = bn_state_dir / f"client_{client_id}.pkl"
def _save_bn_statedict(self) -> None:
"""Save contents of state_dict related to BN layers."""
bn_state = {
name: val.cpu().numpy()
for name, val in self.model.state_dict().items()
if "bn" in name
}
with open(self.bn_state_pkl, "wb") as handle:
pickle.dump(bn_state, handle, protocol=pickle.HIGHEST_PROTOCOL)
def _load_bn_statedict(self) -> Dict[str, torch.tensor]:
"""Load pickle with BN state_dict and return as dict."""
with open(self.bn_state_pkl, "rb") as handle:
data = pickle.load(handle)
        bn_state_dict = {k: torch.tensor(v) for k, v in data.items()}
        return bn_state_dict
def get_parameters(self, config) -> NDArrays:
"""Return model parameters as a list of NumPy ndarrays w or w/o using BN.
layers.
"""
# First update bn_state_dir
self._save_bn_statedict()
# Excluding parameters of BN layers when using FedBN
parameters = [
val.cpu().numpy()
for name, val in self.model.state_dict().items()
if "bn" not in name
]
return parameters
def set_parameters(self, parameters: NDArrays) -> None:
"""Set model parameters from a list of NumPy ndarrays Exclude the bn layer if.
available.
"""
keys = [k for k in self.model.state_dict().keys() if "bn" not in k]
params_dict = zip(keys, parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
self.model.load_state_dict(state_dict, strict=False)
# Now also load from bn_state_dir
if self.bn_state_pkl.exists(): # It won't exist in the first round
bn_state_dict = self._load_bn_statedict()
self.model.load_state_dict(bn_state_dict, strict=False)
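To see exactly which entries FedBN keeps local, we can list the state_dict keys of the model that contain "bn" (weights, biases and running statistics of the BatchNorm layers):
bn_keys = [k for k in CNNModel().state_dict().keys() if "bn" in k]
print(bn_keys)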
Client factory function:
def gen_client_fn(client_data, learning_rate, client_type) :
"""Return a function that will be called to instantiate the cid-th client."""
def client_fn(cid :str) -> FlowerClient :
"""Create a Flower client representing a single organization."""
net = CNNModel()
# Note: each client gets a different trainloader/valloader, so each client
# will train and evaluate on their own unique data
trainloader, valloader, dataset_name = client_data[int(cid)]
flower_client = None
if client_type == "FedBN":
flower_client = FedBNFlowerClient(
save_path=Path("./"),
client_id=int(cid),
model=net,
trainloader=trainloader,
testloader=valloader,
l_r=learning_rate,
dataset_name=dataset_name
)
elif client_type == "FedAvg":
flower_client = FlowerClient(
model=net,
trainloader=trainloader,
testloader=valloader,
l_r=learning_rate,
dataset_name=dataset_name
)
return flower_client
return client_fn
7. Server Configuration
from typing import Dict, List, Tuple
from flwr.common.typing import Metrics
def get_on_fit_config(server_round: int) -> Dict[str, int]:
"""Return a config (a dict) to be sent to clients during fit()."""
fit_config = {"round": server_round} # Add round info
return fit_config
def get_metrics_aggregation_fn():
"""Return a function that computes the weighted average of metrics."""
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
"""Compute the weighted average of metrics."""
# Initialize dictionaries to store totals and weighted sums
totals: Dict[str, int] = {}
accuracies: Dict[str, float] = {}
losses: Dict[str, float] = {}
pre_train_accuracies: Dict[str, float] = {}
pre_train_losses: Dict[str, float] = {}
# Aggregate metrics from all clients
for num_examples, metric in metrics:
dataset_name = metric.get("dataset_name", "default")
# Update totals and weighted sums
totals[dataset_name] = totals.get(dataset_name, 0) + num_examples
accuracies[dataset_name] = accuracies.get(dataset_name, 0) + num_examples * metric["accuracy"]
losses[dataset_name] = losses.get(dataset_name, 0) + num_examples * metric["loss"]
# Handle optional metrics
if "pre_train_acc" in metric:
pre_train_accuracies[dataset_name] = pre_train_accuracies.get(dataset_name, 0) + num_examples * metric["pre_train_acc"]
if "pre_train_loss" in metric:
pre_train_losses[dataset_name] = pre_train_losses.get(dataset_name, 0) + num_examples * metric["pre_train_loss"]
# Normalize by the total number of examples
accuracies = {k: v / totals[k] for k, v in accuracies.items()}
losses = {k: v / totals[k] for k, v in losses.items()}
aggregated_metrics = {"accuracy": accuracies, "loss": losses}
# Include optional metrics
if pre_train_accuracies:
aggregated_metrics["pre_train_accuracy"] = {k: v / totals[k] for k, v in pre_train_accuracies.items()}
if pre_train_losses:
aggregated_metrics["pre_train_loss"] = {k: v / totals[k] for k, v in pre_train_losses.items()}
return aggregated_metrics
return weighted_average
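An illustrative call with made-up numbers (two MNIST clients and one SVHN client) shows how the metrics are aggregated per dataset:
weighted_average = get_metrics_aggregation_fn()
example = [
    (100, {"dataset_name": "MNIST", "accuracy": 0.90, "loss": 0.30}),
    (300, {"dataset_name": "MNIST", "accuracy": 0.80, "loss": 0.50}),
    (200, {"dataset_name": "SVHN", "accuracy": 0.60, "loss": 0.90}),
]
print(weighted_average(example))
# MNIST accuracy: (100*0.90 + 300*0.80) / 400 = 0.825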
class Servidor(fl.server.strategy.FedAvg):
def __init__(self, num_clients):
self.num_clients = num_clients
print("Initializing Servidor...")
super().__init__(
fraction_fit=FRACTION_FIT,
fraction_evaluate=0.0,
min_available_clients=num_clients,
on_fit_config_fn=get_on_fit_config,
evaluate_metrics_aggregation_fn=get_metrics_aggregation_fn(),
fit_metrics_aggregation_fn=get_metrics_aggregation_fn()
)
8. Flower Simulation
CLIENT_TYPES = ["FedAvg", "FedBN"]
save_path = Path("./results")
save_path.mkdir(exist_ok=True)
client_data_loaders = get_data()
for client_type in CLIENT_TYPES:
    # Create a fresh strategy instance per run so nothing is shared between
    # the FedAvg and FedBN simulations
    strategy = Servidor(NUM_CLIENTES)
    client_fn = gen_client_fn(
        client_data=client_data_loaders,
        learning_rate=LEARNING_RATE,
        client_type=client_type
    )
history = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=NUM_CLIENTES,
client_resources={
"num_cpus": 2.0,
"num_gpus": 0.0
},
config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
strategy=strategy,
)
data = {"history": history}
print(data)
history_path = f"{str(save_path)}/history_{client_type}.pkl"
with open(history_path, "wb") as handle:
pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
9. Results Visualization
def save_fig(name, fig):
"""Save matplotlib plot."""
fig.savefig(
name,
dpi=None,
facecolor=fig.get_facecolor(),
edgecolor="none",
orientation="portrait",
format="png",
transparent=False,
bbox_inches="tight",
pad_inches=0.2,
metadata=None,
)
# Read and process pickle files
def _read_pickle(path_to_pickle):
with open(path_to_pickle, "rb") as handle:
data = pickle.load(handle)
return data
# Aggregate losses and accuracy by dataset
def _fuse_by_dataset(losses):
    """Group a list of (round, {dataset: value}) pairs into {dataset: [values]}."""
    fused = {}
    for _, loss_dict in losses:
        for k, val in loss_dict.items():
            fused.setdefault(k, []).append(val)
    return fused
# Plot final results
def quick_plot(pickle_paths, metric):
"""Plot training loss for each dataset, comparing FedAvg and FedBN."""
    if metric == "pre_train_loss":
        y_variable = "Train Loss"
    elif metric == "pre_train_accuracy":
        y_variable = "Train Accuracy"
    else:
        y_variable = metric
# Read and process pickle files
data_dict = {}
for path in pickle_paths:
data = _read_pickle(path)
losses = data["history"].metrics_distributed_fit[metric]
model_name = Path(path).stem.split("_")[1] # Extract FedAvg or FedBN from filename
data_dict[model_name] = _fuse_by_dataset(losses)
# Validate that both pickle files contain the same datasets
datasets = sorted(list(data_dict["FedAvg"].keys()))
assert datasets == sorted(list(data_dict["FedBN"].keys())), "Datasets must match in both pickle files"
# Number of datasets
num_datasets = len(datasets)
# Create a single row of subplots with one column per dataset
fig, axes = plt.subplots(1, num_datasets, figsize=(18, 5), sharey=True)
# Flatten axes array for easier indexing
if num_datasets == 1:
axes = [axes] # Ensure axes is iterable when there's only one dataset
for i, (ax, dataset) in enumerate(zip(axes, datasets)):
# Plot FedAvg
ax.plot(range(len(data_dict["FedAvg"][dataset])), data_dict["FedAvg"][dataset], label="FedAvg", color="blue")
# Plot FedBN
ax.plot(range(len(data_dict["FedBN"][dataset])), data_dict["FedBN"][dataset], label="FedBN", color="orange")
# Configure the plot
ax.set_title(f"{dataset}")
ax.grid(True)
ax.set_xlabel("Round")
if i == 0: # Add Y-axis label only for the first subplot
ax.set_ylabel(y_variable)
ax.legend()
# Adjust layout and save the figure
plt.tight_layout()
save_fig(f"./results/comparison_{metric}_by_dataset.png", fig)
plt.show()
metrics = ["pre_train_loss", "pre_train_accuracy"]
for metric in metrics :
quick_plot(["./results/history_FedAvg.pkl", "./results/history_FedBN.pkl"], metric=metric)



