Final Presentation: Model-Contrastive Federated Learning
MOON
Federated learning allows multiple parties to collaboratively train a machine learning model without communicating their local data. A key challenge in federated learning is handling the heterogeneity of the local data distributions across parties. Although many studies have been proposed to address this challenge, we find that they fail to achieve high performance on image datasets with deep learning models. In this paper, we propose MOON: model-contrastive federated learning. MOON is a simple and effective federated learning framework. The key idea of MOON is to use the similarity between model representations to correct the local training of individual parties, i.e., to conduct contrastive learning at the model level, thereby outperforming other algorithms on various image classification tasks.
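For reference, the local objective each client optimizes in MOON (Algorithm 1 of the paper) adds a model-contrastive term to the usual supervised cross-entropy loss. With z, z_glob, and z_prev the projected representations produced by the local, global, and previous local models for an input x, sim(·,·) the cosine similarity, τ the temperature, and μ a weighting hyperparameter:

ℓ_con = −log( exp(sim(z, z_glob)/τ) / ( exp(sim(z, z_glob)/τ) + exp(sim(z, z_prev)/τ) ) )

ℓ = ℓ_sup(x, y) + μ · ℓ_con

These two terms appear in the training loop below as loss1 and loss2.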
This article uses the Flower framework.
Flower Website
MOON GitHub
Installing the libraries
!pip install -q "flwr[simulation]" flwr-datasets
!pip install omegaconf
!pip install -U ray
!pip install --upgrade --index-url https://pypi.ngc.nvidia.com nvidia-tensorrt
Imports
Importing the required libraries and setting up execution logging.
import logging
import os
import copy
import numpy as np
import torch
import random
import shutil
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from collections import OrderedDict
from typing import Callable, Dict, List, Tuple, Optional
import flwr as fl
from flwr.common.typing import NDArrays, Scalar
from omegaconf import DictConfig, OmegaConf
from pathlib import Path
import matplotlib.pyplot as plt
from flwr.server.history import History
log()
Setup used to generate logs for the runs.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
Output()
Below are the output functions used to compute accuracy and to plot the results.
def compute_accuracy(model, dataloader, device="cpu", multiloader=False):
"""Compute accuracy."""
was_training = False
if model.training:
model.eval()
was_training = True
true_labels_list, pred_labels_list = np.array([]), np.array([])
correct, total = 0, 0
if device == "cpu":
criterion = nn.CrossEntropyLoss()
elif "cuda" in device.type:
criterion = nn.CrossEntropyLoss().cuda()
loss_collector = []
if multiloader:
for loader in dataloader:
with torch.no_grad():
for _, (x, target) in enumerate(loader):
if device != "cpu":
x, target = x.cuda(), target.to(dtype=torch.int64).cuda()
_, _, out = model(x)
                    loss = criterion(out, target)
_, pred_label = torch.max(out.data, 1)
loss_collector.append(loss.item())
total += x.data.size()[0]
correct += (pred_label == target.data).sum().item()
if device == "cpu":
pred_labels_list = np.append(
pred_labels_list, pred_label.numpy()
)
true_labels_list = np.append(
true_labels_list, target.data.numpy()
)
else:
pred_labels_list = np.append(
pred_labels_list, pred_label.cpu().numpy()
)
true_labels_list = np.append(
true_labels_list, target.data.cpu().numpy()
)
avg_loss = sum(loss_collector) / len(loss_collector)
else:
with torch.no_grad():
for _, (x, target) in enumerate(dataloader):
if device != "cpu":
x, target = x.cuda(), target.to(dtype=torch.int64).cuda()
_, _, out = model(x)
loss = criterion(out, target)
_, pred_label = torch.max(out.data, 1)
loss_collector.append(loss.item())
total += x.data.size()[0]
correct += (pred_label == target.data).sum().item()
if device == "cpu":
pred_labels_list = np.append(pred_labels_list, pred_label.numpy())
true_labels_list = np.append(true_labels_list, target.data.numpy())
else:
pred_labels_list = np.append(
pred_labels_list, pred_label.cpu().numpy()
)
true_labels_list = np.append(
true_labels_list, target.data.cpu().numpy()
)
avg_loss = sum(loss_collector) / len(loss_collector)
if was_training:
model.train()
return correct / float(total), avg_loss
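A minimal smoke test for compute_accuracy (a sketch; _StubTriple is a hypothetical stand-in for a model that, like ModelMOON defined below, returns a (features, projection, logits) triple):

class _StubTriple(nn.Module):
    """Toy model whose forward mimics ModelMOON's (h, x, y) output."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        out = self.fc(x)
        return x, x, out  # (features, projection, logits)

xs = torch.randn(8, 4)
ys = torch.randint(0, 3, (8,))
loader = DataLoader(data.TensorDataset(xs, ys), batch_size=4)
acc, loss = compute_accuracy(_StubTriple(), loader, device="cpu")
print(f"accuracy={acc:.2f}, loss={loss:.3f}")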
def plot_metric_from_history(
hist: History,
save_plot_path: Path,
suffix: Optional[str] = "",
) -> None:
"""Plot data from Flower server History.
Parameters
----------
hist : History
Object containing evaluation for all rounds.
save_plot_path : Path
Folder to save the plot to.
suffix: Optional[str]
Optional string to add at the end of the filename for the plot.
"""
metric_type = "centralized"
metric_dict = (
hist.metrics_centralized
if metric_type == "centralized"
else hist.metrics_distributed
)
rounds, values = zip(*metric_dict["accuracy"])
    # Plot the curve
    plt.figure(figsize=(10, 6))
    plt.plot(rounds, values, label="centralized accuracy")
    plt.xlabel("#round")
    plt.ylabel("Test accuracy")
    plt.legend()
    # Save before calling show(); show() can clear the figure in some backends
    plt.savefig(Path(save_plot_path) / Path(f"{metric_type}_metrics{suffix}.png"))
    plt.show()
    plt.close()
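A quick way to sanity-check the plotting helper is to feed it a synthetic History (a sketch; it assumes flwr's History.add_metrics_centralized API, and the accuracy values are made up):

os.makedirs("./output", exist_ok=True)
hist = History()
for rnd, acc in enumerate([0.10, 0.35, 0.52], start=1):
    hist.add_metrics_centralized(rnd, {"accuracy": acc})
plot_metric_from_history(hist, Path("./output"), suffix="_demo")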
create_model()
This cell creates the model that will be used in federated training. It is a simple CNN for CIFAR-10 that receives the flattened feature dimension as a parameter to create its first fully connected layer.
class SimpleCNNHeader(nn.Module):
"""Simple CNN model."""
def __init__(self, input_dim, hidden_dims):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(input_dim, hidden_dims[0])
self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
def forward(self, x):
"""Forward."""
x = self.pool(self.relu(self.conv1(x)))
x = self.pool(self.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
return x
class ModelMOON(nn.Module):
"""Model for MOON."""
    def __init__(self, base_model, out_dim, n_classes):
        super().__init__()
        # only the simple CNN backbone is supported in this notebook
        if base_model != "simple-cnn":
            raise ValueError("Invalid model name.")
        self.features = SimpleCNNHeader(
            input_dim=(16 * 5 * 5), hidden_dims=[120, 84]
        )
        num_ftrs = 84
        # projection MLP
        self.l1 = nn.Linear(num_ftrs, num_ftrs)
        self.l2 = nn.Linear(num_ftrs, out_dim)
        # last layer
        self.l3 = nn.Linear(out_dim, n_classes)
def forward(self, x):
"""Forward."""
h = self.features(x)
h = h.squeeze()
x = self.l1(h)
x = F.relu(x)
x = self.l2(x)
y = self.l3(x)
return h, x, y
def init_net(dataset, model, output_dim, device="cpu"):
"""Initialize model."""
if dataset == "cifar10":
n_classes = 10
net = ModelMOON(model, output_dim, n_classes)
if device == "cpu":
net.to(device)
else:
net = net.cuda()
return net
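A quick shape check (a sketch): ModelMOON's forward returns the backbone features h, the projection z used by the contrastive loss, and the classification logits y.

net = init_net("cifar10", "simple-cnn", 256)
dummy = torch.randn(2, 3, 32, 32)
h, z, y = net(dummy)
print(h.shape, z.shape, y.shape)  # torch.Size([2, 84]) torch.Size([2, 256]) torch.Size([2, 10])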
def train_moon(
net,
global_net,
previous_net,
train_dataloader,
epochs,
lr,
mu,
temperature,
device="cpu",
):
"""Training function for MOON."""
net.to(device)
global_net.to(device)
previous_net.to(device)
train_acc, _ = compute_accuracy(net, train_dataloader, device=device)
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, net.parameters()),
lr=lr,
momentum=0.9,
weight_decay=1e-5,
)
    # keep the loss on the same device as the models (works on CPU as well)
    criterion = nn.CrossEntropyLoss().to(device)
    previous_net.eval()
    for param in previous_net.parameters():
        param.requires_grad = False
    previous_net.to(device)
cnt = 0
cos = torch.nn.CosineSimilarity(dim=-1)
for epoch in range(epochs):
epoch_loss_collector = []
epoch_loss1_collector = []
epoch_loss2_collector = []
for _, (x, target) in enumerate(train_dataloader):
x, target = x.to(device), target.to(device)
optimizer.zero_grad()
x.requires_grad = False
target.requires_grad = False
target = target.long()
# pro1 is the representation by the current model (Line 14 of Algorithm 1)
_, pro1, out = net(x)
# pro2 is the representation by the global model (Line 15 of Algorithm 1)
_, pro2, _ = global_net(x)
# posi is the positive pair
posi = cos(pro1, pro2)
logits = posi.reshape(-1, 1)
previous_net.to(device)
# pro 3 is the representation by the previous model (Line 16 of Algorithm 1)
_, pro3, _ = previous_net(x)
# nega is the negative pair
nega = cos(pro1, pro3)
logits = torch.cat((logits, nega.reshape(-1, 1)), dim=1)
previous_net.to("cpu")
logits /= temperature
            labels = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
# compute the model-contrastive loss (Line 17 of Algorithm 1)
loss2 = mu * criterion(logits, labels)
# compute the cross-entropy loss (Line 13 of Algorithm 1)
loss1 = criterion(out, target)
# compute the loss (Line 18 of Algorithm 1)
loss = loss1 + loss2
loss.backward()
optimizer.step()
cnt += 1
epoch_loss_collector.append(loss.item())
epoch_loss1_collector.append(loss1.item())
epoch_loss2_collector.append(loss2.item())
epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
epoch_loss1 = sum(epoch_loss1_collector) / len(epoch_loss1_collector)
epoch_loss2 = sum(epoch_loss2_collector) / len(epoch_loss2_collector)
print(
"Epoch: %d Loss: %f Loss1: %f Loss2: %f"
% (epoch, epoch_loss, epoch_loss1, epoch_loss2)
)
previous_net.to("cpu")
train_acc, _ = compute_accuracy(net, train_dataloader, device=device)
print(">> Training accuracy: %f" % train_acc)
net.to("cpu")
global_net.to("cpu")
print(" ** Training complete **")
return net
def test(net, test_dataloader, device="cpu"):
"""Test function."""
net.to(device)
test_acc, loss = compute_accuracy(net, test_dataloader, device=device)
print(">> Test accuracy: %f" % test_acc)
net.to("cpu")
return test_acc, loss
Client Implementation
After understanding how the framework works, we will implement federated training on CIFAR-10 for a configurable number of clients with IID and non-IID partitioning. We begin with the client code, a class that extends fl.client.NumPyClient. Together with this class, the following methods are implemented:
- __init__: the class constructor, which initializes the required attributes
- get_parameters: collects the model weights during initialization
- load_data: loads each client's data partition
- create_model: instantiates the model that will be trained
- fit: performs local training of the model
- evaluate: evaluates the trained model
init()
class FlowerClient(fl.client.NumPyClient):
"""Standard Flower client for CNN training."""
def __init__(
self,
# net: torch.nn.Module,
net_id: int,
dataset: str,
model: str,
output_dim: int,
trainloader: DataLoader,
valloader: DataLoader,
device: torch.device,
num_epochs: int,
learning_rate: float,
mu: float,
temperature: float,
model_dir: str,
alg: str,
): # pylint: disable=too-many-arguments
self.net = init_net(dataset, model, output_dim)
self.net_id = net_id
self.dataset = dataset
self.model = model
self.output_dim = output_dim
self.trainloader = trainloader
self.valloader = valloader
self.device = device
self.num_epochs = num_epochs
self.learning_rate = learning_rate
self.mu = mu # pylint: disable=invalid-name
self.temperature = temperature
self.model_dir = model_dir
self.alg = alg
def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
"""Return the parameters of the current net."""
return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
def set_parameters(self, parameters: NDArrays) -> None:
"""Change the parameters of the model using the given ones."""
params_dict = zip(self.net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.from_numpy(v) for k, v in params_dict})
self.net.load_state_dict(state_dict, strict=True)
def fit(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[NDArrays, int, Dict]:
"""Implement distributed fit function for a given client."""
self.set_parameters(parameters)
prev_net = init_net(self.dataset, self.model, self.output_dim)
if not os.path.exists(os.path.join(self.model_dir, str(self.net_id))):
prev_net = copy.deepcopy(self.net)
else:
# load previous model from model_dir
prev_net.load_state_dict(
torch.load(
os.path.join(self.model_dir, str(self.net_id), "prev_net.pt")
)
)
global_net = init_net(self.dataset, self.model, self.output_dim)
global_net.load_state_dict(self.net.state_dict())
if self.alg == "moon":
train_moon(
self.net,
global_net,
prev_net,
self.trainloader,
self.num_epochs,
self.learning_rate,
self.mu,
self.temperature,
self.device,
)
if not os.path.exists(os.path.join(self.model_dir, str(self.net_id))):
os.makedirs(os.path.join(self.model_dir, str(self.net_id)))
torch.save(
self.net.state_dict(),
os.path.join(self.model_dir, str(self.net_id), "prev_net.pt"),
)
return self.get_parameters({}), len(self.trainloader), {"is_straggler": False}
def evaluate(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[float, int, Dict]:
"""Implement distributed evaluation for a given client."""
self.set_parameters(parameters)
# skip evaluation in the client-side
loss = 0.0
accuracy = 0.0
return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
def gen_client_fn(
trainloaders: List[DataLoader],
testloaders: List[DataLoader],
cfg: DictConfig,
) -> Callable[[str], FlowerClient]:
"""Generate the client function that creates the Flower Clients."""
def client_fn(cid: str) -> FlowerClient:
"""Create a Flower client representing a single organization."""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainloader = trainloaders[int(cid)]
testloader = testloaders[int(cid)]
return FlowerClient(
int(cid),
cfg.dataset,
cfg.modelname,
cfg.outputdim,
trainloader,
testloader,
device,
cfg.num_epochs,
cfg.learning_rate,
cfg.mu,
cfg.temperature,
cfg.modeldir,
cfg.alg,
)
return client_fn
get_parameters()
This method returns the model weights. It is important because it is the first method requested by the Strategy in order to obtain the initial parameters for all clients.
def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
"""Return the parameters of the current net."""
return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
data_preparation()
The following method prepares the data from the CIFAR-10 dataset.
def load_cifar10_data(datadir):
"""Load CIFAR10 dataset."""
transform = transforms.Compose([transforms.ToTensor()])
cifar10_train_ds = CIFAR10Sub(
datadir, train=True, download=True, transform=transform
)
cifar10_test_ds = CIFAR10Sub(
datadir, train=False, download=True, transform=transform
)
X_train, y_train = cifar10_train_ds.data, cifar10_train_ds.target
X_test, y_test = cifar10_test_ds.data, cifar10_test_ds.target
return (X_train, y_train, X_test, y_test)
data_partition()
Splits the training data into per-client partitions keyed by the client identifier cid. Based on the partition argument, the method decides whether the split is IID ("homo"/"iid", an even random split) or non-IID ("noniid-labeldir"/"noniid", with per-class proportions drawn from a Dirichlet distribution with concentration beta). Finally, the data arrays and the per-client index map are returned.
def partition_data(dataset, datadir, partition, num_clients, beta):
"""Partition data into train and test sets for IID and non-IID experiments."""
if dataset == "cifar10":
X_train, y_train, X_test, y_test = load_cifar10_data(datadir)
n_train = y_train.shape[0]
if partition in ("homo", "iid"):
idxs = np.random.permutation(n_train)
batch_idxs = np.array_split(idxs, num_clients)
net_dataidx_map = {i: batch_idxs[i] for i in range(num_clients)}
elif partition in ("noniid-labeldir", "noniid"):
min_size = 0
min_require_size = 10
K = 10
N = y_train.shape[0]
net_dataidx_map = {}
while min_size < min_require_size:
idx_batch = [[] for _ in range(num_clients)]
for k in range(K):
idx_k = np.where(y_train == k)[0]
np.random.shuffle(idx_k)
proportions = np.random.dirichlet(np.repeat(beta, num_clients))
proportions = np.array(
[
p * (len(idx_j) < N / num_clients)
for p, idx_j in zip(proportions, idx_batch)
]
)
proportions = proportions / proportions.sum()
proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
idx_batch = [
idx_j + idx.tolist()
for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))
]
min_size = min([len(idx_j) for idx_j in idx_batch])
for j in range(num_clients):
np.random.shuffle(idx_batch[j])
net_dataidx_map[j] = idx_batch[j]
return (X_train, y_train, X_test, y_test, net_dataidx_map)
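A short sketch to inspect the non-IID split (it relies on the CIFAR10Sub class defined in the data_load() section below, so run it after that cell; CIFAR-10 is downloaded into ./data/moon/ on first use). With beta=0.5 each client's label histogram is visibly skewed, while the IID split is roughly uniform:

_, y_train, _, _, idx_map = partition_data(
    dataset="cifar10",
    datadir="./data/moon/",
    partition="noniid",
    num_clients=2,
    beta=0.5,
)
for cid, idxs in idx_map.items():
    counts = np.bincount(y_train[np.array(idxs)], minlength=10)
    print(f"client {cid}: {len(idxs)} samples, label counts = {counts.tolist()}")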
data_load()
This cell defines the CIFAR-10 dataset wrapper and loads the data.
IMG_EXTENSIONS = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
class CIFAR10Sub(data.Dataset):
"""CIFAR-10 dataset with idxs."""
def __init__(
self,
root,
dataidxs=None,
train=True,
transform=None,
target_transform=None,
download=False,
):
self.root = root
self.dataidxs = dataidxs
self.train = train
self.transform = transform
self.target_transform = target_transform
self.download = download
self.data, self.target = self.__build_sub_dataset__()
def __build_sub_dataset__(self):
"""Build sub dataset given idxs."""
cifar_dataobj = CIFAR10(
self.root, self.train, self.transform, self.target_transform, self.download
)
if torchvision.__version__ == "0.2.1":
if self.train:
# pylint: disable=redefined-outer-name
data, target = cifar_dataobj.train_data, np.array(
cifar_dataobj.train_labels
)
else:
# pylint: disable=redefined-outer-name
data, target = cifar_dataobj.test_data, np.array(
cifar_dataobj.test_labels
)
else:
data = cifar_dataobj.data
target = np.array(cifar_dataobj.targets)
if self.dataidxs is not None:
data = data[self.dataidxs]
target = target[self.dataidxs]
return data, target
def __getitem__(self, index):
"""Get item by index.
Args:
index (int): Index.
Returns
-------
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.target[index]
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
"""Length.
Returns
-------
int: length of data
"""
return len(self.data)
def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None, noise_level=0):
"""Get dataloader for a given dataset."""
if dataset == "cifar10":
dl_obj = CIFAR10Sub
normalize = transforms.Normalize(
mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
)
transform_train = transforms.Compose(
[
transforms.ToTensor(),
transforms.Lambda(
lambda x: F.pad(
Variable(x.unsqueeze(0), requires_grad=False),
(4, 4, 4, 4),
mode="reflect",
).data.squeeze()
),
transforms.ToPILImage(),
transforms.ColorJitter(brightness=noise_level),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
)
    # data prep for test set
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    # download only if CIFAR-10 is not already on disk
    download = not os.path.isdir(os.path.join(datadir, "cifar-10-batches-py"))
train_ds = dl_obj(
datadir,
dataidxs=dataidxs,
train=True,
transform=transform_train,
download=download,
)
test_ds = dl_obj(datadir, train=False, transform=transform_test, download=download)
train_dl = data.DataLoader(
dataset=train_ds, batch_size=train_bs, drop_last=True, shuffle=True
)
test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False)
return train_dl, test_dl, train_ds, test_ds
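A usage sketch for the loaders (it assumes idx_map from the partitioning sketch above; shapes reflect train_bs=64 and the 32×32 RGB images):

train_dl, test_dl, _, _ = get_dataloader(
    "cifar10", "./data/moon/", train_bs=64, test_bs=32, dataidxs=idx_map[0]
)
x, y = next(iter(train_dl))
print(x.shape, y.shape)  # torch.Size([64, 3, 32, 32]) torch.Size([64])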
fit()
fit is executed only for the MOON algorithm (self.alg == "moon").
def fit(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[NDArrays, int, Dict]:
"""Implement distributed fit function for a given client."""
self.set_parameters(parameters)
prev_net = init_net(self.dataset, self.model, self.output_dim)
if not os.path.exists(os.path.join(self.model_dir, str(self.net_id))):
prev_net = copy.deepcopy(self.net)
else:
# load previous model from model_dir
prev_net.load_state_dict(
torch.load(
os.path.join(self.model_dir, str(self.net_id), "prev_net.pt")
)
)
global_net = init_net(self.dataset, self.model, self.output_dim)
global_net.load_state_dict(self.net.state_dict())
if self.alg == "moon":
train_moon(
self.net,
global_net,
prev_net,
self.trainloader,
self.num_epochs,
self.learning_rate,
self.mu,
self.temperature,
self.device,
)
if not os.path.exists(os.path.join(self.model_dir, str(self.net_id))):
os.makedirs(os.path.join(self.model_dir, str(self.net_id)))
torch.save(
self.net.state_dict(),
os.path.join(self.model_dir, str(self.net_id), "prev_net.pt"),
)
return self.get_parameters({}), len(self.trainloader), {"is_straggler": False}
evaluate()
def evaluate(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[float, int, Dict]:
"""Implement distributed evaluation for a given client."""
self.set_parameters(parameters)
# skip evaluation in the client-side
loss = 0.0
accuracy = 0.0
return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
Server Implementation
The following method generates the centralized evaluation function that the server's strategy runs after each round.
def gen_evaluate_fn(
testloader: DataLoader,
device: torch.device,
cfg: DictConfig,
) -> Callable[
[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]
]:
"""Generate the function for centralized evaluation."""
def evaluate(
server_round: int, parameters_ndarrays: NDArrays, config: Dict[str, Scalar]
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
# pylint: disable=unused-argument
        net = init_net(cfg.dataset, cfg.modelname, cfg.outputdim)
params_dict = zip(net.state_dict().keys(), parameters_ndarrays)
state_dict = OrderedDict({k: torch.from_numpy(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
net.to(device)
accuracy, loss = test(net, testloader, device=device)
return loss, {"accuracy": accuracy}
return evaluate
Configuration Parameters
class getArgsClass:
def __init__(self, num_clients=2, num_epochs=10, fraction_fit=1.0, batch_size=64,
learning_rate=0.01, mu=5, temperature=0.5, alg='moon', seed=0,
server_device='cuda', num_rounds=100, num_cpus=4, num_gpus=0.5,
dataset='cifar10', datadir='./data/moon/', dspartition='noniid',
dsbeta=0.5, modelname='simple-cnn', outputdim=256, modeldir='./client_states/moon/cifar10/'):
self.num_clients = num_clients
self.num_epochs = num_epochs
self.fraction_fit = fraction_fit
self.batch_size = batch_size
self.learning_rate = learning_rate
self.mu = mu
self.temperature = temperature
self.alg = alg
self.seed = seed
self.server_device = server_device
self.num_rounds = num_rounds
self.num_cpus = num_cpus
self.num_gpus = num_gpus
self.dataset = dataset
self.datadir = datadir
self.dspartition = dspartition
self.dsbeta = dsbeta
self.modelname = modelname
self.outputdim = outputdim
self.modeldir = modeldir
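For a quick smoke test before a full run, the defaults can be overridden at construction time (these lighter values are illustrative, not the experiment settings):

cfg = getArgsClass(
    num_clients=2,    # two simulated clients
    num_rounds=2,     # two federated rounds
    num_epochs=1,     # one local epoch per round
    server_device="cpu",
    num_gpus=0.0,     # no GPU share per client actor
)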
Running the Federated Training
# @hydra.main(config_path="conf", config_name="cifar10", version_base=None)
def get_args():
cfg = getArgsClass()
return cfg
def main(cfg) -> None:
"""Run the baseline.
Parameters
----------
cfg : DictConfig
An omegaconf object that stores the hydra config.
"""
# Clean the model directory to save models for MOON
if cfg.alg == "moon":
if os.path.exists(cfg.modeldir):
shutil.rmtree(cfg.modeldir)
    # 1. Print parsed config
    # print(OmegaConf.to_yaml(cfg))
# 2. Prepare your dataset
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(cfg.seed)
random.seed(cfg.seed)
(
_,
_,
_,
_,
net_dataidx_map,
) = partition_data(
dataset=cfg.dataset,
datadir=cfg.datadir,
partition=cfg.dspartition,
num_clients=cfg.num_clients,
beta=cfg.dsbeta,
)
_, test_global_dl, _, _ = get_dataloader(
dataset=cfg.dataset,
datadir=cfg.datadir,
train_bs=cfg.batch_size,
test_bs=32,
)
trainloaders = []
testloaders = []
for idx in range(cfg.num_clients):
train_dl, test_dl, _, _ = get_dataloader(
cfg.dataset, cfg.datadir, cfg.batch_size, 32, net_dataidx_map[idx]
)
trainloaders.append(train_dl)
testloaders.append(test_dl)
# 3. Define your clients
# Define a function that returns another function that will be used during
# simulation to instantiate each individual client
client_fn = gen_client_fn(
trainloaders=trainloaders,
testloaders=testloaders,
cfg=cfg,
)
    # get function that will be executed by the strategy's evaluate() method
# Set server's device
device = (
torch.device("cuda:0")
if torch.cuda.is_available() and cfg.server_device == "cuda"
else "cpu"
)
evaluate_fn = gen_evaluate_fn(test_global_dl, device=device, cfg=cfg)
# 4. Define your strategy
strategy = fl.server.strategy.FedAvg(
# Clients in MOON do not perform federated evaluation
# (see the client's evaluate())
fraction_fit=cfg.fraction_fit,
fraction_evaluate=0.0,
evaluate_fn=evaluate_fn,
)
# 5. Start Simulation
# history = fl.simulation.start_simulation(<arguments for simulation>)
history = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=cfg.num_clients,
config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
client_resources={
"num_cpus": cfg.num_cpus,
"num_gpus": cfg.num_gpus,
},
strategy=strategy,
)
# remove saved models
if cfg.alg == "moon":
shutil.rmtree(cfg.modeldir)
# 6. Save your results
# Experiment completed. Now we save the results and
# generate plots using the `history`
print("................")
print(history)
# Hydra automatically creates an output directory
# Let's retrieve it and save some results there
    # save_path = HydraConfig.get().runtime.output_dir
    save_path = "./output"
    os.makedirs(save_path, exist_ok=True)
# plot results and include them in the readme
strategy_name = strategy.__class__.__name__
file_suffix: str = (
f"_{strategy_name}"
f"{'_dataset' if cfg.dataset.name else ''}"
f"_C={cfg.num_clients}"
f"_B={cfg.batch_size}"
f"_E={cfg.num_epochs}"
f"_R={cfg.num_rounds}"
f"_mu={cfg.mu}"
)
plot_metric_from_history(
history,
Path(save_path),
(file_suffix),
)
if __name__ == "__main__":
cfg = get_args()
main(cfg)