Source code for fedlib.aggregators.aggregators

from typing import List, Optional

import numpy as np
import torch


def _mean(inputs: List[torch.Tensor]):
    inputs_tensor = torch.stack(inputs, dim=0)
    return inputs_tensor.mean(dim=0)


def _median(inputs: List[torch.Tensor]):
    inputs_tensor = torch.stack(inputs, dim=0)
    values_upper, _ = inputs_tensor.median(dim=0)
    values_lower, _ = (-inputs_tensor).median(dim=0)
    return (values_upper - values_lower) / 2


class Mean(object):
    r"""Aggregate client updates by their element-wise ``sample mean``."""

    def __call__(self, inputs: List[torch.Tensor]):
        # Stack the updates along a new leading "client" axis and average
        # that axis away; the result has the shape of a single update.
        stacked = torch.stack(inputs, dim=0)
        return stacked.mean(dim=0)
class Median(object):
    r"""Coordinate-wise median aggregator.

    A robust aggregator from paper `"Byzantine-robust distributed learning:
    Towards optimal statistical rates"
    <https://proceedings.mlr.press/v80/yin18a>`_.

    Every coordinate of the result is the median of that coordinate over
    all client updates.
    """

    def __call__(self, inputs: List[torch.Tensor]):
        stacked = torch.stack(inputs, dim=0)
        # ``median`` along a dim returns one existing element (the lower
        # middle one for an even count); the negated stack gives the negated
        # upper middle one, so the halved difference is the midpoint.
        lower_middle = stacked.median(dim=0).values
        negated_upper_middle = (-stacked).median(dim=0).values
        return (lower_middle - negated_upper_middle) / 2
class Trimmedmean(object):
    r"""Coordinate-wise trimmed-mean aggregator.

    A robust aggregator from paper `"Byzantine-robust distributed learning:
    Towards optimal statistical rates"
    <https://proceedings.mlr.press/v80/yin18a>`_.

    It computes the coordinate-wise trimmed average of the client updates:

    .. math::
        trmean := TrimmedMean( \{ \Delta_k : k \in [K] \} ),

    where the :math:`i`-th coordinate is
    :math:`trmean_i = \frac{1}{(1-2\beta)K} \sum_{x \in U_i} x` and
    :math:`U_i` is the set of :math:`i`-th coordinates left after removing
    the largest and smallest :math:`\beta` fraction of them.

    In this implementation the number of values dropped at *each* end is
    ``int(filter_frac * num_byzantine)``.

    :param num_byzantine: Assumed number of Byzantine clients; must be a
        non-negative ``int``.
    :param filter_frac: Fraction in ``[0.0, 1.0]`` of ``num_byzantine`` that
        is trimmed from each end. Default 1.0.
    :raises ValueError: If ``filter_frac`` is outside ``[0.0, 1.0]`` or
        ``num_byzantine`` is not a non-negative integer.
    """

    def __init__(self, num_byzantine: int, *, filter_frac=1.0):
        if filter_frac > 1.0 or filter_frac < 0.0:
            raise ValueError(
                f"filter_frac should be in [0.0, 1.0], got {filter_frac}."
            )
        if not isinstance(num_byzantine, int):
            raise ValueError(
                f"num_byzantine should be an integer, got {num_byzantine}."
            )
        if num_byzantine < 0:
            raise ValueError(
                f"num_byzantine should be non-negative, got {num_byzantine}."
            )

        # Truncate towards zero: a fractional client cannot be excluded.
        self.num_excluded = int(filter_frac * num_byzantine)

    def __call__(self, inputs: List[torch.Tensor]):
        """Return the coordinate-wise trimmed mean of ``inputs``.

        :param inputs: Equally shaped client updates.
        :return: A tensor shaped like one update, holding for every
            coordinate the average of the values that remain after dropping
            the ``num_excluded`` largest and ``num_excluded`` smallest.
        :raises ValueError: If fewer than ``2 * num_excluded + 1`` inputs
            are given, i.e. nothing would be left after trimming.
        """
        num_inputs = len(inputs)
        if num_inputs <= self.num_excluded * 2:
            raise ValueError(
                f"Not enough inputs to compute trimmed mean, "
                f"got {num_inputs} inputs but need at least "
                f"{self.num_excluded * 2 + 1} inputs."
            )

        stacked = torch.stack(inputs, dim=0)
        # Sort each coordinate across clients and keep only the middle
        # slice.  Summing everything and subtracting the extremes afterwards
        # would let huge outliers wipe out the benign values through
        # floating-point cancellation, which defeats the robustness goal.
        ordered, _ = torch.sort(stacked, dim=0)
        kept = ordered[self.num_excluded : num_inputs - self.num_excluded]
        # Out-of-place true division also works for integer tensors.
        return kept.sum(dim=0) / kept.shape[0]
class GeoMed:
    r"""Geometric-median aggregator.

    A robust aggregator from paper `"Distributed Statistical Machine
    Learning in Adversarial Settings: Byzantine Gradient Descent"
    <https://arxiv.org/abs/1705.05491>`_.

    ``GeoMed`` aims to find a vector that minimizes the sum of its Euclidean
    distances to all the update vectors:

    .. math::
        GeoMed := \arg\min_{\boldsymbol{z}} \sum_{k \in [K]}
        \lVert \boldsymbol{z} - {\Delta}_k \rVert.

    There is no closed-form solution to the ``GeoMed`` problem. It is
    approximately solved using Weiszfeld's algorithm in this implementation.

    :param maxiter: Maximum number of Weiszfeld iterations. Default 100.
    :param eps: Smallest allowed value of denominator, to avoid divide by
        zero. Equivalently, this is a smoothing parameter. Default 1e-6.
    :param ftol: If objective value does not improve by at least this
        ``ftol`` fraction, terminate the algorithm. Default 1e-10.
    """

    def __init__(
        self,
        maxiter: int = 100,
        eps: float = 1e-6,
        ftol: float = 1e-10,
    ):
        self.maxiter = maxiter
        self.eps = eps
        self.ftol = ftol

    def __call__(self, inputs: List[torch.Tensor], weights=None):
        """Return the (weighted) geometric median of ``inputs``.

        :param inputs: Non-empty list of equally shaped client updates.
        :param weights: Optional per-client weights, either a tensor or any
            sequence accepted by ``torch.as_tensor``. Uniform weights are
            used when omitted.
        :return: A tensor shaped like a single update.
        :raises ValueError: If ``inputs`` is empty.
        """
        if len(inputs) == 0:
            raise ValueError("GeoMed requires at least one input tensor.")

        input_tensor = torch.stack(inputs, dim=0)
        if weights is None:
            weights = torch.ones(len(inputs)) / len(inputs)
        elif not torch.is_tensor(weights):
            weights = torch.as_tensor(weights, dtype=torch.get_default_dtype())
        # Caller-supplied weights may live on another device than the inputs.
        weights = weights.to(input_tensor.device)

        return self._geometric_median(
            input_tensor,
            weights=weights,
            maxiter=self.maxiter,
            eps=self.eps,
            ftol=self.ftol,
        )

    @staticmethod
    def _geometric_median(inputs, weights, eps=1e-6, maxiter=100, ftol=1e-20):
        """Run Weiszfeld iterations on the rows of ``inputs``.

        :param inputs: Tensor whose first dimension indexes the clients.
        :param weights: 1-D tensor holding one weight per client.
        :param eps: Lower clamp for the distances used as denominators.
        :param maxiter: Maximum number of iterations.
        :param ftol: Relative objective change at or below which the
            iteration stops.
        :return: The approximate weighted geometric median.
        """
        # Reshape weights to (K, 1, ..., 1) so they broadcast over every
        # non-client dimension, whatever the rank of a single update is.
        weight_shape = (-1,) + (1,) * (inputs.dim() - 1)
        weights_np = weights.detach().cpu().numpy()

        def weighted_average(point_weights):
            scaled = point_weights.reshape(weight_shape) * inputs
            return torch.sum(scaled, dim=0) / point_weights.sum()

        def distances(center):
            return torch.stack([torch.norm(p - center) for p in inputs])

        def objective(dists):
            # The weighted mean distance is accumulated by NumPy on Python
            # floats (double precision) rather than in the tensor dtype.
            return np.average(dists.tolist(), weights=weights_np)

        with torch.no_grad():
            median = weighted_average(weights)
            dists = distances(median)
            objective_value = objective(dists)

            # Weiszfeld iterations: re-weight every point by the inverse of
            # its distance to the current estimate.  The distances computed
            # for the objective are reused for the next re-weighting.
            for _ in range(maxiter):
                prev_obj_value = objective_value
                new_weights = weights / torch.clamp(dists, min=eps)
                median = weighted_average(new_weights)

                dists = distances(median)
                objective_value = objective(dists)
                if abs(prev_obj_value - objective_value) <= ftol * objective_value:
                    break

        return median
class DnC(object):
    r"""A robust aggregator from paper `"Manipulating the Byzantine:
    Optimizing Model Poisoning Attacks and Defenses for Federated Learning"
    <https://par.nsf.gov/servlets/purl/10286354>`_.

    For ``num_iters`` rounds, a random subset of at most ``sub_dim``
    coordinates is drawn, the updates restricted to those coordinates are
    centered, and every update is scored by its squared projection onto the
    top right-singular vector of the centered matrix.  The
    ``int(filter_frac * num_byzantine)`` highest-scoring updates are
    discarded in each round.  The result is the mean of the updates that
    survive *every* round, or an all-zero tensor if none does.

    :param num_byzantine: Assumed number of Byzantine clients.
    :param sub_dim: Maximum number of coordinates sampled per round.
        Default 10000.
    :param num_iters: Number of filtering rounds, at least 1. Default 5.
    :param filter_frac: Multiplier applied to ``num_byzantine`` to obtain
        the number of updates discarded per round. Default 1.0.
    :raises ValueError: If ``num_iters`` is smaller than 1.
    """

    def __init__(
        self, num_byzantine, *, sub_dim=10000, num_iters=5, filter_frac=1.0
    ) -> None:
        if num_iters < 1:
            raise ValueError(f"num_iters should be at least 1, got {num_iters}.")

        self.num_byzantine = num_byzantine
        self.sub_dim = sub_dim
        self.num_iters = num_iters
        self.filter_frac = filter_frac

    @property
    def fliter_frac(self):
        """Alias kept for code that still uses the old misspelled name."""
        return self.filter_frac

    @fliter_frac.setter
    def fliter_frac(self, value):
        self.filter_frac = value

    def __call__(self, inputs: List[torch.Tensor]):
        """Aggregate ``inputs`` after filtering suspected outliers.

        :param inputs: Equally sized 1-D client updates.
        :return: Mean of the updates kept in every round, or zeros shaped
            like one update when no update is kept in every round.
        """
        updates = torch.stack(inputs, dim=0)
        num_updates, dim = updates.shape[0], updates.shape[1]

        # Loop-invariant; clamped so that asking to remove more updates than
        # exist keeps nothing instead of producing a negative slice bound.
        num_removed = int(self.filter_frac * self.num_byzantine)
        num_kept = max(num_updates - num_removed, 0)

        benign_ids = None
        for _ in range(self.num_iters):
            indices = torch.randperm(dim)[: self.sub_dim]
            sub_updates = updates[:, indices]
            centered = sub_updates - sub_updates.mean(dim=0)

            # Top right-singular vector = direction of largest variance.
            top_direction = torch.linalg.svd(centered, full_matrices=False)[2][0, :]
            scores = (centered @ top_direction) ** 2

            # Keep the updates with the smallest outlier scores this round
            # and intersect with the survivors of the previous rounds.
            kept = set(torch.sort(scores).indices[:num_kept].tolist())
            benign_ids = kept if benign_ids is None else benign_ids & kept

        if not benign_ids:
            return torch.zeros_like(inputs[0])

        # Sorted for a deterministic row order.
        return updates[sorted(benign_ids), :].mean(dim=0)