"""Source code for fedlib.datasets.partitioners.iid_partitioner."""

from typing import List, Optional, Tuple

import torch
from torch.utils.data import Dataset, Subset

from .dataset_partitioner import DatasetPartitioner


class IIDPartitioner(DatasetPartitioner):
    """Partitioner that splits a dataset into IID subsets.

    The dataset indices are shuffled uniformly at random and then cut into
    ``self.num_clients`` contiguous, disjoint shards. Shard sizes differ by at
    most one sample: the first ``len(dataset) % self.num_clients`` shards
    receive one extra sample each.
    """

    def split_dataset(
        self,
        dataset: Dataset,
        *,
        generator: Optional[torch.Generator] = None,
    ) -> List[Subset]:
        """Shuffle ``dataset`` and split it into ``self.num_clients`` subsets.

        Args:
            dataset: Dataset to partition; must support ``len()`` and be
                indexable by integer position (as ``Subset`` requires).
            generator: Optional ``torch.Generator`` used for the shuffle.
                Pass a seeded generator to obtain a reproducible partition;
                when ``None`` the default torch RNG is used, as before.

        Returns:
            A list of ``self.num_clients`` disjoint ``Subset`` objects that
            together cover every index of ``dataset``. If the dataset has
            fewer samples than there are clients, the trailing subsets are
            empty.

        Raises:
            ValueError: If ``self.num_clients`` is not positive.
        """
        num_clients = self.num_clients
        # Guard explicitly: zero would otherwise surface as an opaque
        # ZeroDivisionError, and a negative value would silently yield [].
        if num_clients <= 0:
            raise ValueError(
                f"num_clients must be a positive integer, got {num_clients}"
            )

        num_samples = len(dataset)
        indices = torch.randperm(num_samples, generator=generator).tolist()

        # Every client gets ``base_size`` samples; the first ``remainder``
        # clients get one extra so that no sample is left unassigned.
        base_size, remainder = divmod(num_samples, num_clients)

        subsets = []
        start = 0
        for client_idx in range(num_clients):
            end = start + base_size + (1 if client_idx < remainder else 0)
            subsets.append(Subset(dataset, indices[start:end]))
            start = end
        return subsets

    def split_datasets(
        self,
        train_dataset: Dataset,
        test_dataset: Dataset,
        *,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[List[Subset], List[Subset]]:
        """Split the training and testing datasets independently.

        Each dataset is shuffled and partitioned with ``split_dataset``; the
        two permutations are drawn separately, so train and test shards are
        not index-aligned with each other.

        Args:
            train_dataset: Training dataset to partition.
            test_dataset: Testing dataset to partition.
            generator: Optional ``torch.Generator`` forwarded to both calls
                of ``split_dataset``. The training split draws from it first,
                so a seeded generator makes both partitions reproducible.

        Returns:
            ``(train_subsets, test_subsets)``, each a list of
            ``self.num_clients`` subsets as returned by ``split_dataset``.

        Raises:
            ValueError: If ``self.num_clients`` is not positive.
        """
        train_subsets = self.split_dataset(train_dataset, generator=generator)
        test_subsets = self.split_dataset(test_dataset, generator=generator)
        return train_subsets, test_subsets