Parametric Feature Transfer: One-shot Federated Learning with Foundation Models
Resumo
O artigo apresenta o FedPFT, um método inovador para aprendizado federado de uma rodada que melhora a precisão e a eficiência de comunicação utilizando modelos fundamentais. No aprendizado federado de uma rodada, um modelo global é treinado com apenas uma rodada de comunicação entre os clientes, um método que geralmente sofre de menor precisão. O FedPFT usa modelos de mistura Gaussiana (GMMs) para extrair e compartilhar recursos paramétricos de cada cliente sem enviar dados reais. Essa abordagem não só melhora a precisão, mas também oferece garantias de privacidade diferencial.
O FedPFT usa modelos fundamentais pré-treinados para extrair características que são modeladas usando GMMs. Os clientes transmitem esses modelos paramétricos ao servidor, permitindo que ele gere recursos sintéticos e treine um classificador global.
Explicação do código
main.py
O arquivo main.py é responsável por iniciar o treinamento federado. Ele carrega as configurações do treinamento e inicia o treinamento federado.
# Prepare dataset
trainloaders, testloaders = instantiate(
cfg.dataset,
transform=cfg.model.transform,
image_input_size=cfg.model.image_input_size,
).get_loaders()
# Define clients
client_fn = generate_client_fn(
client_cfg=cfg.client,
trainloaders=trainloaders,
testloaders=testloaders,
feature_extractor=instantiate(cfg.model.feature_extractor),
num_classes=cfg.dataset.num_classes,
device=device,
)
# Setup strategy
strategy = instantiate(cfg.strategy)
# Start simulation
history = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=cfg.num_clients,
config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
strategy=strategy,
client_resources={"num_cpus": cfg.num_cpus, "num_gpus": cfg.num_gpus},
)
dataset.py
O arquivo dataset.py é responsável por carregar o dataset de treinamento. Ele carrega o dataset e divide em partições para os clientes. É inicializado com os seguintes parâmetros:
dataset: Nome ou path do dataset a ser carregado do HuggingFace.num_clients: Número de clientes.batch_size: Tamanho dobatch.dirichlet_alpha: Parâmetro alpha da distribuição de Dirichlet.partition_by: Coluna do dataset a ser utilizada para particionar os dados.image_column_name: Nome da coluna que contém as imagens.transform: Transformação a ser aplicada nas imagens.image_input_size: Tamanho da imagem.seed: Semente para reprodução dos resultados.split_size: Tamanho do conjunto de treinamento.
A seguinte função é responsável por carregar o dataset e dividir em partições para os clientes:
def get_loaders(self):
"""Partition the datasets and return a list of dataloaders."""
partitioner = DirichletPartitioner(
num_partitions=self.num_clients,
partition_by=self.partition_by,
alpha=self.dirichlet_alpha,
min_partition_size=10,
self_balancing=True,
)
fds = FederatedDataset(
dataset=self.dataset, partitioners={"train": partitioner}
)
# Create train/val for each partition and wrap it into DataLoader
trainloaders, testloaders = [], []
for partition_id in range(self.num_clients):
partition = fds.load_partition(partition_id)
partition = partition.with_transform(self.apply_batch_transforms())
partition = partition.train_test_split(
train_size=self.split_size, seed=self.seed
)
trainloaders.append(
DataLoader(partition["train"], batch_size=self.batch_size)
)
testloaders.append(
DataLoader(partition["test"], batch_size=self.batch_size)
)
return trainloaders, testloaders
client.py
O arquivo client.py é responsável por treinar o modelo local. Ele recebe o modelo global e treina o modelo local. É inicializado com os seguintes parâmetros:
trainloaders: Dataloaders de treinamento.testloaders: Dataloaders de teste.feature_extractor: Modelo usado para extrair características do cliente.num_classes: Número de classes.device: Dispositivo de treinamento.
A seguinte função é responsável por treinar o modelo local:
def fit(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[NDArrays, int, Dict]:
"""Fit a GMM on features and return GMM parameters."""
# Extracting features
features, labels = extract_features(
dataloader=self.trainloader,
feature_extractor=self.feature_extractor,
device=self.device,
)
# Learning GMM
gmm_list = learn_gmm(
features=features,
labels=labels,
n_mixtures=int(config["n_mixtures"]),
cov_type=str(config["cov_type"]),
seed=int(config["seed"]),
tol=float(config["tol"]),
max_iter=int(config["max_iter"]),
)
# Reshaping GMM parameters into an NDArray
return [array for gmm in gmm_list for array in gmmparam_to_ndarrays(gmm)], 0, {}
A seguinte função é responsável por avaliar o modelo local:
def evaluate(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[float, int, Dict]:
"""Evaluate `classifier_head` on the test data."""
self.set_parameters(parameters)
loss, acc = test(
classifier_head=self.classifier_head,
dataloader=self.testloader,
feature_extractor=self.feature_extractor,
device=self.device,
)
return loss, len(self.testloader.dataset), {"accuracy": acc}
Os pontos de interesse são:
extract_features: Extrai características do modelo.
def extract_features(
dataloader: DataLoader, feature_extractor: torch.nn.Module, device: torch.device
) -> Tuple[NDArray, NDArray]:
"""Extract features and labels from images using feature extractor.
Parameters
----------
dataloader : DataLoader
Dataloader containing {'img': img, 'label': label}
dicts to be extracted.
feature_extractor : torch.nn.Module
Model for extracting features.
device : torch.device
Device for loading `feature_extractor`.
Returns
-------
features : NDArray
2D array containing features extracted from `feature_extractor`.
labels : NDArray
2D array containing labels of `features`.
"""
feature_extractor.to(device)
features, labels = [], []
for sample in dataloader:
batch_samples = sample["img"].to(device)
batch_label = sample["label"].to(device)
with torch.no_grad():
feature = feature_extractor(batch_samples)
features.append(feature.cpu().detach().numpy())
labels.append(batch_label.cpu().detach().numpy())
# reshape feauturs and labels into a single numpy array
features_np = np.concatenate(features, axis=0).astype("float64")
labels_np = np.concatenate(labels)
return features_np, labels_np
learn_gmm: Aprende um modelo de mistura Gaussiana.
def learn_gmm(
features: NDArray,
labels: NDArray,
n_mixtures: int,
cov_type: str,
seed: int,
tol: float = 1e-12,
max_iter: int = 1000,
) -> List[GMMParameters]:
"""Learn a list of 16-bits GMMs for each label.
Parameters
----------
features : NDArray
A 2-d array with size (n_samples, feature_dimension) containing
extracted features for all the samples.
labels : NDArray
An array with size (n_samples) containing labels associated for
each sample in `features`.
n_mixtures : int
Number of mixtures in each Gaussian Mixture.
cov_type : str
Covariance type of Gaussian Mixtures, e.g. spherical.
seed: int
Seed for learning and sampling from Gaussian Mixtures.
tol: float
Tolerance of Gaussian Mixtures.
max_iter: int
Number of maximum iterations to learn the Gaussian Mixtures.
Returns
-------
List[GMMParameters]
Returns a list containing the GMMParameters for each class.
"""
gmm_list = []
for label in np.unique(labels):
cond_features = features[label == labels]
if (
len(cond_features) > n_mixtures
): # number of samples should be larger than `n_mixtures`.
gmm = GaussianMixture(
n_components=n_mixtures,
covariance_type=cov_type,
random_state=seed,
tol=tol,
max_iter=max_iter,
)
gmm.fit(cond_features)
gmm_list.append(
GMMParameters(
label=np.array(label),
means=gmm.means_.astype("float16"),
weights=gmm.weights_.astype("float16"),
covariances=gmm.covariances_.astype("float16"),
num_samples=np.array(len(cond_features)),
)
)
return gmm_list
gmmparam_to_ndarrays: Converte os parâmetros do modelo de mistura Gaussiana em um array.
def gmmparam_to_ndarrays(gmm: GMMParameters) -> List[NDArray]:
"""Convert gmm object to NumPy ndarrays."""
return [gmm.label, gmm.means, gmm.weights, gmm.covariances, gmm.num_samples]
strategy.py
O arquivo strategy.py é responsável por definir a estratégia de treinamento. Ele é inicializado com os seguintes parâmetros:
num_classes: Número de classes.feature_dimension: Dimensão das características.server_opt: Otimizador do servidor.server_batch_size: Tamanho do lote do servidor.num_epochs: Número de épocas.
A seguinte função é responsável por agregar os resultados de treinamento de vários clientes em um servidor central. Ela pode ser dividida em duas partes:
- Verifica se houve falhas durante o treinamento dos clientes.
- Constroi um dataset sintético a partir dos modelos de mistura Gaussiana (GMMs) dos clientes.
- Treina um classificador com o dataset sintético.
- Envia o classificador treinado para os clientes.
def aggregate_fit(
self,
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Learn a classifier head by generating samples from the GMMs."""
# Do not aggregate if there are failures.
if not self.accept_failures and failures:
raise Exception("there are failures and failures are not accepted")
assert self.on_fit_config_fn is not None
config = self.on_fit_config_fn(server_round)
# Sample from the GMMs to create synthetic feature dataset
synthetic_features_dataset: List[Union[Dict, Tuple]] = []
for _, fit_res in results:
# Convert byte parameters into ndarrays and GMMParameters
ndarray = parameters_to_ndarrays(fit_res.parameters)
all_gmm_parameters = [
ndarrays_to_gmmparam(array) for array in chunks(ndarray, 5)
]
# Sample from GMM_label pairs to create synthetic features
for gmm_parameter in all_gmm_parameters:
gmm = GMM(
n_components=int(config["n_mixtures"]),
covariance_type=config["cov_type"],
random_state=int(config["seed"]),
tol=float(config["tol"]),
max_iter=int(config["max_iter"]),
)
# Set values of the GMMs
gmm.means_ = gmm_parameter.means.astype("float32")
gmm.weights_ = gmm_parameter.weights.astype("float32")
gmm.covariances_ = gmm_parameter.covariances.astype("float32")
# Sample features
syn_features, _ = gmm.sample(gmm_parameter.num_samples)
syn_features = torch.tensor(syn_features, dtype=torch.float32)
gmm_labels = torch.tensor(
[int(gmm_parameter.label)] * int(gmm_parameter.num_samples)
)
# Add to train data
synthetic_features_dataset += list(zip(syn_features, gmm_labels))
# Train a classifier head
synthetic_features_dataset = [
{"img": img, "label": label} for img, label in synthetic_features_dataset
]
synthetic_loader = DataLoader(
synthetic_features_dataset,
batch_size=self.server_batch_size,
shuffle=True,
)
classifier_head = torch.nn.Linear(self.feature_dimension, self.num_classes)
opt = torch.optim.AdamW(
params=classifier_head.parameters(), lr=self.server_opt.lr
)
train(
classifier_head=classifier_head,
dataloader=synthetic_loader,
device=self.device,
num_epochs=self.num_epochs,
opt=opt,
verbose=True,
)
# Send the classifier head to clients
classifier_ndarray = [
val.cpu().numpy() for _, val in classifier_head.state_dict().items()
]
return ndarrays_to_parameters(classifier_ndarray), {}
server.py
O arquivo server.py contém as funções auxiliares para o servidor, são elas:
fedpft_get_on_fit_config_fn: Retorna a configuração de treinamento.weighted_average: Calcula a média ponderada dos resultados de treinamento dos clientes.
Colab
Install dependencies
!pip install flwr
!pip install numpy
!pip install omegaconf
!pip install --upgrade pip setuptools wheel
!pip install hydra-core --upgrade
!pip install torch
!pip install -U "flwr[simulation]"
!pip install torchvision
!pip install transformers
!pip install flwr_datasets
Default configuration
O funcionamento do treinamento é descrito pelas seguintes variáveis:
strategy: Estratégia de treinamento.client: Cliente de treinamento.model: Modelo de rede neural.dataset: Dataset de treinamento.num_clients: Número de clientes.dirichlet_alpha: Parâmetro alpha da distribuição de Dirichlet.num_rounds: Número de rodadas de treinamento.num_cpus: Número de CPUs.num_gpus: Número de GPUs.batch_size: Tamanho do lote.device: Dispositivo de treinamento.
Paridade entre modelo e dataset:
model=resnet50edataset=CIFAR100model=clipedataset=Caltech101
strategy:
_target_: fedpft.strategy.FedPFT
fraction_fit: 1
fraction_evaluate: 1
accept_failures: false
num_classes: ${dataset.num_classes}
feature_dimension: ${model.hidden_dimension}
device: ${device}
server_batch_size: 32
num_epochs: 50
server_opt:
lr: 0.0001
on_fit_config_fn:
_target_: fedpft.server.fedpft_get_on_fit_config_fn
n_mixtures: 1
cov_type: spherical
seed: 0
tol: 1.0e-12
max_iter: 10000
evaluate_metrics_aggregation_fn:
_target_: fedpft.server.weighted_average
_partial_: true
client:
_target_: fedpft.client.FedPFTClient
model:
feature_extractor:
_target_: fedpft.models.resnet50
transform:
_target_: fedpft.models.transform
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
image_input_size: 224
hidden_dimension: 2048
dataset:
_target_: fedpft.dataset.Dataset
name: cifar100
dataset: CIFAR100
num_classes: 100
image_column_name: img
partition_by: fine_label
num_clients: ${num_clients}
dirichlet_alpha: ${dirichlet_alpha}
batch_size: ${batch_size}
num_clients: 50
dirichlet_alpha: 0.1
num_rounds: 1
num_cpus: 2
num_gpus: 0
batch_size: 64
device: cpu
Run simulation
!python -m fedpft.main
Tempo de execução: 520 segundos.
FedPFT com CIFAR100
!python -m fedpft.main dataset=CIFAR100 model=resnet50
Tempo de execução: 540 segundos. A acurácia do modelo é de 0.521.
FedPFT com Caltech101
!python -m fedpft.main dataset=Caltech101 model=clip
Tempo de execução: 97 segundos. A acurácia do modelo é de 0.905.
FedAvg com CIFAR100
São realizados 10 rounds de treinamento com 1 época por rodada.
!python -m fedpft.main strategy=fedavg client=fedavg dataset=CIFAR100 model=resnet50 num_rounds=10 strategy.on_fit_config_fn.num_epochs=1
Tempo de execução: 3906 segundos. A acurácia do modelo é de 0.441.
FedAvg com Caltech101
São realizados 10 rounds de treinamento com 1 época por rodada.
!python -m fedpft.main strategy=fedavg client=fedavg dataset=Caltech101 model=clip num_rounds=10
Tempo de execução: 1328 segundos. A acurácia do modelo é de 0.815.
Resultados
O FedPFT obteve uma acurácia de 0.521 com CIFAR100 e 0.905 com Caltech101. O FedAvg obteve uma acurácia de 0.441 com CIFAR100 e 0.815 com Caltech101. O FedPFT obteve uma acurácia superior ao FedAvg em ambos os datasets.

