Source code for fedlib.aggregators.clippedclustering

from typing import List

import numpy as np
import torch
from numpy import inf
from sklearn.cluster import AgglomerativeClustering

from fedlib.utils import torch_utils
from .aggregators import Mean, Median


[docs]class Clippedclustering(object): r"""A robust aggregator from paper `"An Experimental Study of Byzantine-Robust Aggregation Schemes in Federated Learning" <https://ieeexplore.ieee.org/abstract/document/10018261>`_. It separates the client population into two groups based on the cosine similarities. Args: tau (float): threshold of clipping norm. If it is not given, updates are clipped according the median of historical norm. """ def __init__(self, agg="mean", max_tau=1e5, linkage="average") -> None: assert linkage in ["average", "single"] self.tau = max_tau self.linkage = linkage self.l2norm_his = [] if agg == "mean": self.agg = Mean() elif agg == "median": self.agg = Median() else: raise NotImplementedError(f"{agg} is not supported yet.") def __call__(self, inputs: List[torch.Tensor]): # Clip updates based on L2 norm updates = self._clip_updates(inputs) # Compute pairwise cosine similarity dis_max = self._compute_cosine_similarity(updates) # Cluster updates using AgglomerativeClustering selected_idxs = self._cluster_updates(dis_max) # Compute final values using selected updates values = self._compute_values(selected_idxs, updates) return values def _clip_updates(self, inputs: List[torch.Tensor]): l2norms = [torch.norm(update).item() for update in inputs] self.l2norm_his.extend(l2norms) threshold = np.median(self.l2norm_his) threshold = min(threshold, self.tau) clipped_updates = [] for idx, l2 in enumerate(l2norms): if l2 > threshold: clipped_updates.append( torch_utils.clip_tensor_norm_(inputs[idx], threshold) ) else: clipped_updates.append(inputs[idx]) return torch.stack(clipped_updates, dim=0) def _compute_cosine_similarity(self, updates): num = len(updates) dis_max = np.zeros((num, num)) for i in range(num): for j in range(i + 1, num): dis_max[i, j] = 1 - torch.nn.functional.cosine_similarity( updates[i, :], updates[j, :], dim=0 ) dis_max[j, i] = dis_max[i, j] dis_max[dis_max == -inf] = 0 dis_max[dis_max == inf] = 2 dis_max[np.isnan(dis_max)] = 2 return dis_max def _cluster_updates(self, dis_max): clustering = AgglomerativeClustering( metric="precomputed", linkage=self.linkage, n_clusters=2 ) clustering.fit(dis_max) flag = 1 if np.sum(clustering.labels_) > len(dis_max) // 2 else 0 selected_idxs = [ idx for idx, label in enumerate(clustering.labels_) if label == flag ] return selected_idxs def _compute_values(self, selected_idxs, updates): benign_updates = [] for idx in selected_idxs: benign_updates.append(updates[idx]) values = self.agg(benign_updates) return values