import itertools
import random
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Tuple, Union
import numpy as np
import torch
from torch.utils.data import Dataset, Subset
from fedlib.datasets.clientdataset import ClientDataset
if TYPE_CHECKING:
from datasets import Dataset as HuggingFaceDataset
[docs]class DatasetPartitioner(ABC):
"""An abstract base class for dataset splitting strategies that considers
random states from both NumPy and PyTorch."""
def __init__(
self,
num_clients: int,
random_seed: int = 123,
client_id_generator: Callable[[], Iterator] = None,
):
"""Initializes the dataset partitioner with the number of clients and
an optional random seed.
Args:
num_clients: The number of clients to split the data for.
random_seed: An optional random seed for reproducibility.
client_id_generator: An optional generator for creating client IDs.
"""
self.num_clients = num_clients
self.random_seed = random_seed
if random_seed is not None:
np.random.seed(random_seed) # Set NumPy random seed
torch.manual_seed(random_seed) # Set PyTorch random seed
self.client_id_generator = (
client_id_generator or self._default_client_id_generator()
)
[docs] def generate_subsets(self, dataset: Dataset) -> Dict[str, Subset]:
"""Generates subsets from a single dataset.
Args:
dataset: The dataset to be split.
Returns:
A dictionary with client IDs as keys and corresponding subsets as values.
"""
subsets = self.split_dataset(
dataset
) # Pass None for test_dataset if only one dataset is provided
client_ids = self.generate_client_ids()
return dict(zip(client_ids, subsets))
[docs] def generate_paired_subsets(
self,
train_dataset: Union[Dataset, "HuggingFaceDataset"],
test_dataset: Union[Dataset, "HuggingFaceDataset"],
) -> Dict[str, Tuple[Subset, Subset]]:
"""Generates paired subsets from two keyconcepts that may interact with
each other.
Args:
train_dataset: The training dataset to be split.
test_dataset: The testing dataset to be split.
Returns:
A dictionary with client IDs as keys and tuples of corresponding training
and testing subsets as values.
"""
train_subsets, test_subsets = self.split_datasets(train_dataset, test_dataset)
client_ids = self.generate_client_ids()
return dict(zip(client_ids, zip(train_subsets, test_subsets)))
[docs] def generate_client_datasets(
self,
train_dataset: Union[Dataset, "HuggingFaceDataset"],
test_dataset: Union[Dataset, "HuggingFaceDataset"],
**kwargs,
) -> List[ClientDataset]:
"""Generates client keyconcepts from two keyconcepts that may interact
with each other.
Args:
train_dataset: The training dataset to be split.
test_dataset: The testing dataset to be split.
Returns:
A list of ClientDataset instances.
"""
client_datasets = []
paired_subsets = self.generate_paired_subsets(train_dataset, test_dataset)
for client_id, (train_subset, test_subset) in paired_subsets.items():
train_indices = train_subset.indices
test_indices = test_subset.indices
random.shuffle(train_indices)
random.shuffle(test_indices)
if not isinstance(train_dataset, Dataset):
shuffled_train_subset = train_dataset.select(train_indices)
shuffled_test_subset = test_dataset.select(test_indices)
else:
shuffled_train_subset = Subset(train_dataset, train_indices)
shuffled_test_subset = Subset(test_dataset, test_indices)
client_datasets.append(
ClientDataset(
uid=client_id,
train_set=shuffled_train_subset,
test_set=shuffled_test_subset,
**kwargs,
)
)
return client_datasets
[docs] @abstractmethod
def split_dataset(self, dataset: Dataset) -> List[Subset]:
"""Split a single dataset into multiple subsets, each keyed by a unique
client_id.
Args:
dataset (Dataset): The dataset to be split.
Returns:
Dict[str, Subset]: A dictionary where the key is a string client_id and the
value is a Subset.
"""
[docs] @abstractmethod
def split_datasets(
self, train_dataset: Dataset, test_dataset: Dataset
) -> List[Tuple[Subset, Subset]]:
"""Split two keyconcepts (e.g., training and testing keyconcepts) into
multiple pairs of subsets, each keyed by a unique client_id.
Args:
train_dataset (Dataset): The training dataset to be split.
test_dataset (Dataset): The testing dataset to be split.
Returns:
Dict[str, Tuple[Subset, Subset]]: A dictionary where the key is a string
client_id and the value is a tuple of two Subsets (training and testing).
"""
@staticmethod
def _default_client_id_generator():
"""A default generator for client IDs that yields sequential
numbers."""
return (f"client_{i}" for i in itertools.count(1))
[docs] def generate_client_ids(self) -> List[Any]:
"""Generate a list of client IDs using the specified client ID
generator."""
return [next(self.client_id_generator) for _ in range(self.num_clients)]