Skip to content

Parametric Feature Transfer: One-shot Federated Learning with Foundation Models

Resumo

O artigo apresenta o FedPFT, um método inovador para aprendizado federado de uma rodada que melhora a precisão e a eficiência de comunicação utilizando modelos fundamentais. No aprendizado federado de uma rodada, um modelo global é treinado com apenas uma rodada de comunicação entre os clientes, um método que geralmente sofre de menor precisão. O FedPFT usa modelos de mistura Gaussiana (GMMs) para extrair e compartilhar recursos paramétricos de cada cliente sem enviar dados reais. Essa abordagem não só melhora a precisão, mas também oferece garantias de privacidade diferencial.

Diagrama

O FedPFT usa modelos fundamentais pré-treinados para extrair características que são modeladas usando GMMs. Os clientes transmitem esses modelos paramétricos ao servidor, permitindo que ele gere recursos sintéticos e treine um classificador global.

Explicação do código

main.py

O arquivo main.py é responsável por iniciar o treinamento federado. Ele carrega as configurações do treinamento e inicia o treinamento federado.

# Prepare dataset
trainloaders, testloaders = instantiate(
    cfg.dataset,
    transform=cfg.model.transform,
    image_input_size=cfg.model.image_input_size,
).get_loaders()

# Define clients
client_fn = generate_client_fn(
    client_cfg=cfg.client,
    trainloaders=trainloaders,
    testloaders=testloaders,
    feature_extractor=instantiate(cfg.model.feature_extractor),
    num_classes=cfg.dataset.num_classes,
    device=device,
)

# Setup strategy
strategy = instantiate(cfg.strategy)

# Start simulation
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=cfg.num_clients,
    config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
    strategy=strategy,
    client_resources={"num_cpus": cfg.num_cpus, "num_gpus": cfg.num_gpus},
)

dataset.py

O arquivo dataset.py é responsável por carregar o dataset de treinamento. Ele carrega o dataset e divide em partições para os clientes. É inicializado com os seguintes parâmetros:

  • dataset: Nome ou path do dataset a ser carregado do HuggingFace.
  • num_clients: Número de clientes.
  • batch_size: Tamanho do batch.
  • dirichlet_alpha: Parâmetro alpha da distribuição de Dirichlet.
  • partition_by: Coluna do dataset a ser utilizada para particionar os dados.
  • image_column_name: Nome da coluna que contém as imagens.
  • transform: Transformação a ser aplicada nas imagens.
  • image_input_size: Tamanho da imagem.
  • seed: Semente para reprodução dos resultados.
  • split_size: Tamanho do conjunto de treinamento.

A seguinte função é responsável por carregar o dataset e dividir em partições para os clientes:

def get_loaders(self):
    """Partition the datasets and return a list of dataloaders."""
    partitioner = DirichletPartitioner(
        num_partitions=self.num_clients,
        partition_by=self.partition_by,
        alpha=self.dirichlet_alpha,
        min_partition_size=10,
        self_balancing=True,
    )

    fds = FederatedDataset(
        dataset=self.dataset, partitioners={"train": partitioner}
    )
    # Create train/val for each partition and wrap it into DataLoader
    trainloaders, testloaders = [], []
    for partition_id in range(self.num_clients):
        partition = fds.load_partition(partition_id)
        partition = partition.with_transform(self.apply_batch_transforms())
        partition = partition.train_test_split(
            train_size=self.split_size, seed=self.seed
        )
        trainloaders.append(
            DataLoader(partition["train"], batch_size=self.batch_size)
        )
        testloaders.append(
            DataLoader(partition["test"], batch_size=self.batch_size)
        )

    return trainloaders, testloaders

client.py

O arquivo client.py é responsável por treinar o modelo local. Ele recebe o modelo global e treina o modelo local. É inicializado com os seguintes parâmetros:

  • trainloaders: Dataloaders de treinamento.
  • testloaders: Dataloaders de teste.
  • feature_extractor: Modelo usado para extrair características do cliente.
  • num_classes: Número de classes.
  • device: Dispositivo de treinamento.

A seguinte função é responsável por treinar o modelo local:

def fit(
    self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[NDArrays, int, Dict]:
    """Fit a GMM on features and return GMM parameters."""
    # Extracting features
    features, labels = extract_features(
        dataloader=self.trainloader,
        feature_extractor=self.feature_extractor,
        device=self.device,
    )

    # Learning GMM
    gmm_list = learn_gmm(
        features=features,
        labels=labels,
        n_mixtures=int(config["n_mixtures"]),
        cov_type=str(config["cov_type"]),
        seed=int(config["seed"]),
        tol=float(config["tol"]),
        max_iter=int(config["max_iter"]),
    )

    # Reshaping GMM parameters into an NDArray
    return [array for gmm in gmm_list for array in gmmparam_to_ndarrays(gmm)], 0, {}

A seguinte função é responsável por avaliar o modelo local:

def evaluate(
    self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[float, int, Dict]:
    """Evaluate `classifier_head` on the test data."""
    self.set_parameters(parameters)
    loss, acc = test(
        classifier_head=self.classifier_head,
        dataloader=self.testloader,
        feature_extractor=self.feature_extractor,
        device=self.device,
    )
    return loss, len(self.testloader.dataset), {"accuracy": acc}

Os pontos de interesse são:

  • extract_features: Extrai características do modelo.
def extract_features(
    dataloader: DataLoader, feature_extractor: torch.nn.Module, device: torch.device
) -> Tuple[NDArray, NDArray]:
    """Extract features and labels from images using feature extractor.

    Parameters
    ----------
    dataloader : DataLoader
        Dataloader containing {'img': img, 'label': label}
        dicts to be extracted.
    feature_extractor : torch.nn.Module
        Model for extracting features.
    device : torch.device
        Device for loading `feature_extractor`.

    Returns
    -------
    features : NDArray
        2D array containing features extracted from `feature_extractor`.
    labels : NDArray
        2D array containing labels of `features`.
    """
    feature_extractor.to(device)

    features, labels = [], []
    for sample in dataloader:
        batch_samples = sample["img"].to(device)
        batch_label = sample["label"].to(device)
        with torch.no_grad():
            feature = feature_extractor(batch_samples)
        features.append(feature.cpu().detach().numpy())
        labels.append(batch_label.cpu().detach().numpy())

    # reshape feauturs and labels into a single numpy array
    features_np = np.concatenate(features, axis=0).astype("float64")
    labels_np = np.concatenate(labels)

    return features_np, labels_np
  • learn_gmm: Aprende um modelo de mistura Gaussiana.
def learn_gmm(
    features: NDArray,
    labels: NDArray,
    n_mixtures: int,
    cov_type: str,
    seed: int,
    tol: float = 1e-12,
    max_iter: int = 1000,
) -> List[GMMParameters]:
    """Learn a list of 16-bits GMMs for each label.

    Parameters
    ----------
    features : NDArray
        A 2-d array with size (n_samples, feature_dimension) containing
        extracted features for all the samples.
    labels : NDArray
        An array with size (n_samples) containing labels associated for
        each sample in `features`.
    n_mixtures : int
        Number of mixtures in each Gaussian Mixture.
    cov_type : str
        Covariance type of Gaussian Mixtures, e.g. spherical.
    seed: int
        Seed for learning and sampling from Gaussian Mixtures.
    tol: float
        Tolerance of Gaussian Mixtures.
    max_iter: int
        Number of maximum iterations to learn the Gaussian Mixtures.

    Returns
    -------
    List[GMMParameters]
        Returns a list containing the GMMParameters for each class.
    """
    gmm_list = []
    for label in np.unique(labels):
        cond_features = features[label == labels]
        if (
            len(cond_features) > n_mixtures
        ):  # number of samples should be larger than `n_mixtures`.
            gmm = GaussianMixture(
                n_components=n_mixtures,
                covariance_type=cov_type,
                random_state=seed,
                tol=tol,
                max_iter=max_iter,
            )
            gmm.fit(cond_features)
            gmm_list.append(
                GMMParameters(
                    label=np.array(label),
                    means=gmm.means_.astype("float16"),
                    weights=gmm.weights_.astype("float16"),
                    covariances=gmm.covariances_.astype("float16"),
                    num_samples=np.array(len(cond_features)),
                )
            )
    return gmm_list
  • gmmparam_to_ndarrays: Converte os parâmetros do modelo de mistura Gaussiana em um array.
def gmmparam_to_ndarrays(gmm: GMMParameters) -> List[NDArray]:
    """Convert gmm object to NumPy ndarrays."""
    return [gmm.label, gmm.means, gmm.weights, gmm.covariances, gmm.num_samples]

strategy.py

O arquivo strategy.py é responsável por definir a estratégia de treinamento. Ele é inicializado com os seguintes parâmetros:

  • num_classes: Número de classes.
  • feature_dimension: Dimensão das características.
  • server_opt: Otimizador do servidor.
  • server_batch_size: Tamanho do lote do servidor.
  • num_epochs: Número de épocas.

A seguinte função é responsável por agregar os resultados de treinamento de vários clientes em um servidor central. Ela pode ser dividida em duas partes:

  1. Verifica se houve falhas durante o treinamento dos clientes.
  2. Constroi um dataset sintético a partir dos modelos de mistura Gaussiana (GMMs) dos clientes.
  3. Treina um classificador com o dataset sintético.
  4. Envia o classificador treinado para os clientes.
def aggregate_fit(
    self,
    server_round: int,
    results: List[Tuple[ClientProxy, FitRes]],
    failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
    """Learn a classifier head by generating samples from the GMMs."""
    # Do not aggregate if there are failures.
    if not self.accept_failures and failures:
        raise Exception("there are failures and failures are not accepted")

    assert self.on_fit_config_fn is not None
    config = self.on_fit_config_fn(server_round)

    # Sample from the GMMs to create synthetic feature dataset
    synthetic_features_dataset: List[Union[Dict, Tuple]] = []
    for _, fit_res in results:
        # Convert byte parameters into ndarrays and GMMParameters
        ndarray = parameters_to_ndarrays(fit_res.parameters)
        all_gmm_parameters = [
            ndarrays_to_gmmparam(array) for array in chunks(ndarray, 5)
        ]

        # Sample from GMM_label pairs to create synthetic features
        for gmm_parameter in all_gmm_parameters:
            gmm = GMM(
                n_components=int(config["n_mixtures"]),
                covariance_type=config["cov_type"],
                random_state=int(config["seed"]),
                tol=float(config["tol"]),
                max_iter=int(config["max_iter"]),
            )
            # Set values of the GMMs
            gmm.means_ = gmm_parameter.means.astype("float32")
            gmm.weights_ = gmm_parameter.weights.astype("float32")
            gmm.covariances_ = gmm_parameter.covariances.astype("float32")

            # Sample features
            syn_features, _ = gmm.sample(gmm_parameter.num_samples)
            syn_features = torch.tensor(syn_features, dtype=torch.float32)
            gmm_labels = torch.tensor(
                [int(gmm_parameter.label)] * int(gmm_parameter.num_samples)
            )

            # Add to train data
            synthetic_features_dataset += list(zip(syn_features, gmm_labels))

    # Train a classifier head
    synthetic_features_dataset = [
        {"img": img, "label": label} for img, label in synthetic_features_dataset
    ]
    synthetic_loader = DataLoader(
        synthetic_features_dataset,
        batch_size=self.server_batch_size,
        shuffle=True,
    )
    classifier_head = torch.nn.Linear(self.feature_dimension, self.num_classes)
    opt = torch.optim.AdamW(
        params=classifier_head.parameters(), lr=self.server_opt.lr
    )

    train(
        classifier_head=classifier_head,
        dataloader=synthetic_loader,
        device=self.device,
        num_epochs=self.num_epochs,
        opt=opt,
        verbose=True,
    )

    # Send the classifier head to clients
    classifier_ndarray = [
        val.cpu().numpy() for _, val in classifier_head.state_dict().items()
    ]

    return ndarrays_to_parameters(classifier_ndarray), {}

server.py

O arquivo server.py contém as funções auxiliares para o servidor, são elas:

  • fedpft_get_on_fit_config_fn: Retorna a configuração de treinamento.
  • weighted_average: Calcula a média ponderada dos resultados de treinamento dos clientes.

Colab

Install dependencies

!pip install flwr
!pip install numpy
!pip install omegaconf
!pip install --upgrade pip setuptools wheel
!pip install hydra-core --upgrade
!pip install torch
!pip install -U "flwr[simulation]"
!pip install torchvision
!pip install transformers
!pip install flwr_datasets

Default configuration

O funcionamento do treinamento é descrito pelas seguintes variáveis:

  • strategy: Estratégia de treinamento.
  • client: Cliente de treinamento.
  • model: Modelo de rede neural.
  • dataset: Dataset de treinamento.
  • num_clients: Número de clientes.
  • dirichlet_alpha: Parâmetro alpha da distribuição de Dirichlet.
  • num_rounds: Número de rodadas de treinamento.
  • num_cpus: Número de CPUs.
  • num_gpus: Número de GPUs.
  • batch_size: Tamanho do lote.
  • device: Dispositivo de treinamento.

Paridade entre modelo e dataset:

  • model=resnet50 e dataset=CIFAR100
  • model=clip e dataset=Caltech101
strategy:
  _target_: fedpft.strategy.FedPFT
  fraction_fit: 1
  fraction_evaluate: 1
  accept_failures: false
  num_classes: ${dataset.num_classes}
  feature_dimension: ${model.hidden_dimension}
  device: ${device}
  server_batch_size: 32
  num_epochs: 50
  server_opt:
    lr: 0.0001
  on_fit_config_fn:
    _target_: fedpft.server.fedpft_get_on_fit_config_fn
    n_mixtures: 1
    cov_type: spherical
    seed: 0
    tol: 1.0e-12
    max_iter: 10000
  evaluate_metrics_aggregation_fn:
    _target_: fedpft.server.weighted_average
    _partial_: true
client:
  _target_: fedpft.client.FedPFTClient
model:
  feature_extractor:
    _target_: fedpft.models.resnet50
  transform:
    _target_: fedpft.models.transform
    mean:
    - 0.485
    - 0.456
    - 0.406
    std:
    - 0.229
    - 0.224
    - 0.225
  image_input_size: 224
  hidden_dimension: 2048
dataset:
  _target_: fedpft.dataset.Dataset
  name: cifar100
  dataset: CIFAR100
  num_classes: 100
  image_column_name: img
  partition_by: fine_label
  num_clients: ${num_clients}
  dirichlet_alpha: ${dirichlet_alpha}
  batch_size: ${batch_size}
num_clients: 50
dirichlet_alpha: 0.1
num_rounds: 1
num_cpus: 2
num_gpus: 0
batch_size: 64
device: cpu

Run simulation

!python -m fedpft.main

Tempo de execução: 520 segundos.

FedPFT com CIFAR100
!python -m fedpft.main dataset=CIFAR100 model=resnet50

Tempo de execução: 540 segundos. A acurácia do modelo é de 0.521.

FedPFT com Caltech101
!python -m fedpft.main dataset=Caltech101 model=clip

Tempo de execução: 97 segundos. A acurácia do modelo é de 0.905.

FedAvg com CIFAR100

São realizados 10 rounds de treinamento com 1 época por rodada.

!python -m fedpft.main strategy=fedavg client=fedavg dataset=CIFAR100 model=resnet50 num_rounds=10 strategy.on_fit_config_fn.num_epochs=1

Tempo de execução: 3906 segundos. A acurácia do modelo é de 0.441.

FedAvg com Caltech101

São realizados 10 rounds de treinamento com 1 época por rodada.

!python -m fedpft.main strategy=fedavg client=fedavg dataset=Caltech101 model=clip num_rounds=10

Tempo de execução: 1328 segundos. A acurácia do modelo é de 0.815.

Resultados

O FedPFT obteve uma acurácia de 0.521 com CIFAR100 e 0.905 com Caltech101. O FedAvg obteve uma acurácia de 0.441 com CIFAR100 e 0.815 com Caltech101. O FedPFT obteve uma acurácia superior ao FedAvg em ambos os datasets.

Resultados