syft.frameworks.torch.fl.utils
Module Contents
-
syft.frameworks.torch.fl.utils.logger
-
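syft.frameworks.torch.fl.utils.extract_batches_per_worker(federated_train_loader: sy.FederatedDataLoader)
Extract the batches from federated_train_loader and store them in a dictionary keyed by data.location, i.e. by the worker that holds each batch.
- Parameters
federated_train_loader (sy.FederatedDataLoader) – the federated data loader whose batches are extracted.
- Returns
a dictionary whose keys are the batch locations (data.location) and whose values are the batches extracted for each location.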
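The grouping step can be illustrated without PySyft. In this minimal sketch, each batch is a (data, target) pair whose data carries a location attribute. The FakePointer class, the group_batches_by_location helper and the worker names "alice" and "bob" are hypothetical stand-ins for PySyft pointer tensors and workers; they are not the library's API.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, List, Tuple


@dataclass
class FakePointer:
    """Hypothetical stand-in for a remote tensor: a payload plus the worker holding it."""
    values: list
    location: str


Batch = Tuple[FakePointer, FakePointer]


def group_batches_by_location(loader: Iterable[Batch]) -> Dict[str, List[Batch]]:
    """Collect (data, target) batches into a dict keyed by data.location."""
    batches: Dict[str, List[Batch]] = defaultdict(list)
    for data, target in loader:
        # Each batch is filed under the worker that holds its data.
        batches[data.location].append((data, target))
    return dict(batches)


# A plain list plays the role of the federated loader in this sketch.
loader = [
    (FakePointer([1, 2], "alice"), FakePointer([0, 1], "alice")),
    (FakePointer([3, 4], "bob"), FakePointer([1, 1], "bob")),
    (FakePointer([5, 6], "alice"), FakePointer([1, 0], "alice")),
]
grouped = group_batches_by_location(loader)
print({worker: len(b) for worker, b in grouped.items()})  # {'alice': 2, 'bob': 1}
```

Grouping up front lets a training loop send each worker only its own batches, instead of interleaving workers batch by batch.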
-
syft.frameworks.torch.fl.utils.add_model(dst_model, src_model)
Add the parameters of src_model to those of dst_model.
- Parameters
dst_model (torch.nn.Module) – the model to which the src_model will be added.
src_model (torch.nn.Module) – the model to be added to dst_model.
- Returns
the model resulting from the addition.
- Return type
torch.nn.Module
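The addition is element-wise over matching parameters. This framework-free sketch assumes a model is represented as a plain dict mapping parameter names to flat lists of floats, standing in for the named parameters of a torch.nn.Module. The add_params helper is hypothetical; matching by name and skipping names absent from dst are choices of the sketch, not documented behavior of add_model.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def add_params(dst: Params, src: Params) -> Params:
    """Add src's parameters into dst element-wise, matching parameters by name.

    The dst dict is updated in place and also returned.
    """
    for name, src_values in src.items():
        if name in dst:
            dst[name] = [d + s for d, s in zip(dst[name], src_values)]
    return dst


dst = {"weight": [1.0, 2.0], "bias": [0.5]}
src = {"weight": [0.5, -1.0], "bias": [1.5]}
print(add_params(dst, src))  # {'weight': [1.5, 1.0], 'bias': [2.0]}
```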
-
syft.frameworks.torch.fl.utils.scale_model(model, scale)
Scale the parameters of a model.
- Parameters
model (torch.nn.Module) – the model whose parameters will be scaled.
scale (float) – the scaling factor.
- Returns
the module with scaled parameters.
- Return type
torch.nn.Module
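Scaling multiplies every parameter value by the same factor. As a minimal sketch under the same assumption as before (parameters as a dict of name to flat list of floats, not real torch modules), with scale_params as a hypothetical helper:

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def scale_params(params: Params, scale: float) -> Params:
    """Multiply every parameter value by scale, updating params in place."""
    for name, values in list(params.items()):
        params[name] = [v * scale for v in values]
    return params


params = {"weight": [1.5, 1.0], "bias": [2.0]}
print(scale_params(params, 0.5))  # {'weight': [0.75, 0.5], 'bias': [1.0]}
```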
-
syft.frameworks.torch.fl.utils.federated_avg(models: List[torch.nn.Module]) → torch.nn.Module
Calculate the federated average of a list of models.
- Parameters
models (List[torch.nn.Module]) – the models of which the federated average is calculated.
- Returns
the module with averaged parameters.
- Return type
torch.nn.Module
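A federated average can be built from the two operations above: sum the parameters of all models, then scale the sum by 1/len(models). The sketch below assumes parameter dicts with identical keys and computes an unweighted mean, i.e. it does not weight models by how much data each worker holds. The federated_avg_params helper is hypothetical, and copying the first model so the inputs stay unchanged is a choice of this sketch.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def federated_avg_params(models: List[Params]) -> Params:
    """Unweighted element-wise mean of several parameter dicts with identical keys."""
    if not models:
        raise ValueError("need at least one model")
    n = len(models)
    # Start from a copy of the first model so the inputs are left unchanged.
    total = {name: list(values) for name, values in models[0].items()}
    for model in models[1:]:
        for name, values in model.items():
            total[name] = [t + v for t, v in zip(total[name], values)]
    # Scaling the sum by 1/n turns it into the mean.
    return {name: [t / n for t in values] for name, values in total.items()}


models = [
    {"weight": [1.0, 2.0], "bias": [0.0]},
    {"weight": [3.0, 4.0], "bias": [1.0]},
    {"weight": [5.0, 6.0], "bias": [2.0]},
]
print(federated_avg_params(models))  # {'weight': [3.0, 4.0], 'bias': [1.0]}
```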
-
syft.frameworks.torch.fl.utils.accuracy(pred_softmax, target)
Calculate the accuracy of a given prediction.
This function assumes that the final prediction is obtained from pred_softmax by taking the argmax over the nr_classes values of each element.
- Parameters
pred_softmax – array of floats, providing nr_classes values per element in target.
target – array of ints, the correct classes, taking values in the range [0, nr_classes).
- Returns
accuracy, the fraction of correct predictions.
- Return type
float
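The argmax-then-compare logic can be sketched in plain Python, assuming nested lists in place of tensors (one row of nr_classes scores per element of target). The accuracy_from_scores name is hypothetical and is not the library function.

```python
from typing import Sequence


def accuracy_from_scores(pred_softmax: Sequence[Sequence[float]], target: Sequence[int]) -> float:
    """Fraction of rows whose argmax matches the target class."""
    if len(pred_softmax) != len(target):
        raise ValueError("pred_softmax and target must have the same length")
    correct = 0
    for scores, label in zip(pred_softmax, target):
        # argmax: index of the largest class score (the first one wins on ties)
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += predicted == label
    return correct / len(target)


pred = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.3, 0.3, 0.4], [0.2, 0.5, 0.3]]
labels = [1, 0, 1, 1]
print(accuracy_from_scores(pred, labels))  # 0.75
```

Here the third row predicts class 2 while its target is 1, so 3 of 4 predictions are correct.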
-
syft.frameworks.torch.fl.utils.create_gaussian_mixture_toy_data(nr_samples: int)
Create a simple toy dataset for binary classification.
The data is drawn from two normal distributions:
target = 1: mu = 2, sigma = 1
target = 0: mu = 0, sigma = 1
The dataset is balanced, with an equal number of positive and negative samples.
- Parameters
nr_samples – number of samples to generate
- Returns
data, targets
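The sampling scheme can be sketched with the standard library. The feature dimension is not stated above, so dim (default 2) and the seed argument are hypothetical parameters of this sketch; each feature of a class-0 sample is drawn from N(0, 1) and each feature of a class-1 sample from N(2, 1). The gaussian_mixture_toy_data name is likewise hypothetical.

```python
import random
from typing import List, Tuple


def gaussian_mixture_toy_data(
    nr_samples: int, dim: int = 2, seed: int = 0
) -> Tuple[List[List[float]], List[int]]:
    """Balanced binary toy data: class 0 ~ N(0, 1), class 1 ~ N(2, 1) per feature."""
    rng = random.Random(seed)
    half = nr_samples // 2  # equal number of samples per class
    data, targets = [], []
    for label, mu in ((0, 0.0), (1, 2.0)):
        for _ in range(half):
            data.append([rng.gauss(mu, 1.0) for _ in range(dim)])
            targets.append(label)
    return data, targets


data, targets = gaussian_mixture_toy_data(1000)
print(len(data), sum(targets))  # 1000 500
```

With an odd nr_samples, this sketch returns one sample fewer than requested in order to keep the classes exactly balanced.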
-
syft.frameworks.torch.fl.utils.iris_data_partial()
- Returns
30 samples from the iris data set: https://archive.ics.uci.edu/ml/datasets/iris