syft.frameworks.torch.fl.utils

Module Contents

syft.frameworks.torch.fl.utils.logger
syft.frameworks.torch.fl.utils.extract_batches_per_worker(federated_train_loader: sy.FederatedDataLoader)

Extracts the batches from the federated_train_loader and stores them in a dictionary (keys = data.location).

Args: federated_train_loader: the federated data loader from which the batches are extracted.
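The grouping step described above can be sketched in plain Python, without PySyft. This is only an illustration: `FakeBatch` and its `location` attribute are hypothetical stand-ins for the tensors a federated loader yields, and `group_batches_by_location` is an illustrative name, not a library function.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, NamedTuple, Tuple


class FakeBatch(NamedTuple):
    """Hypothetical stand-in for a remote tensor that knows where it lives."""

    values: List[float]
    location: str


Pair = Tuple[FakeBatch, FakeBatch]


def group_batches_by_location(loader: Iterable[Pair]) -> Dict[str, List[Pair]]:
    """Collect (data, target) pairs in a dictionary keyed by data.location."""
    batches = defaultdict(list)
    for data, target in loader:
        batches[data.location].append((data, target))
    return dict(batches)


# A plain list plays the role of the federated loader in this sketch.
loader = [
    (FakeBatch([1.0, 2.0], "alice"), FakeBatch([0.0], "alice")),
    (FakeBatch([3.0, 4.0], "bob"), FakeBatch([1.0], "bob")),
    (FakeBatch([5.0, 6.0], "alice"), FakeBatch([1.0], "alice")),
]
grouped = group_batches_by_location(loader)
print({worker: len(items) for worker, items in grouped.items()})
# {'alice': 2, 'bob': 1}
```

Keying on the location of the data (rather than of the target) mirrors the description above; the sketch assumes both live on the same worker.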

syft.frameworks.torch.fl.utils.add_model(dst_model, src_model)

Add the parameters of two models.

Parameters
  • dst_model (torch.nn.Module) – the model to which the src_model will be added.

  • src_model (torch.nn.Module) – the model to be added to dst_model.

Returns

the resulting model of the addition.

Return type

torch.nn.Module
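As a rough illustration of parameter-wise addition, the following sketch uses plain PyTorch only. The name `add_parameters`, the matching of parameters by name, and the in-place update of `dst_model` are assumptions made for this example, not a statement of how the library implements `add_model`.

```python
import torch
from torch import nn


def add_parameters(dst_model: nn.Module, src_model: nn.Module) -> nn.Module:
    """Add src_model's parameters onto dst_model in place, matched by name."""
    src_params = dict(src_model.named_parameters())
    with torch.no_grad():  # plain arithmetic on weights, no autograd tracking
        for name, dst_param in dst_model.named_parameters():
            if name in src_params:
                dst_param.add_(src_params[name])
    return dst_model


dst = nn.Linear(2, 1)
src = nn.Linear(2, 1)
with torch.no_grad():
    for p in dst.parameters():
        p.fill_(1.0)
    for p in src.parameters():
        p.fill_(2.0)

summed = add_parameters(dst, src)
print(summed.weight.tolist())  # [[3.0, 3.0]]
```

Because the sketch updates `dst` in place, `summed` and `dst` are the same object; copy the model first if the original weights are still needed.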

syft.frameworks.torch.fl.utils.scale_model(model, scale)

Scale the parameters of a model.

Parameters
  • model (torch.nn.Module) – the model whose parameters will be scaled.

  • scale (float) – the scaling factor.

Returns

the module with scaled parameters.

Return type

torch.nn.Module
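Scaling can be sketched the same way in plain PyTorch. Here `scale_parameters` is an illustrative name, and the in-place multiplication is an assumption of this example rather than documented library behavior.

```python
import torch
from torch import nn


def scale_parameters(model: nn.Module, scale: float) -> nn.Module:
    """Multiply every parameter of model by scale, in place."""
    with torch.no_grad():  # avoid recording the update in the autograd graph
        for param in model.parameters():
            param.mul_(scale)
    return model


model = nn.Linear(3, 2)
with torch.no_grad():
    for p in model.parameters():
        p.fill_(4.0)

scaled = scale_parameters(model, 0.5)
print(scaled.bias.tolist())  # [2.0, 2.0]
```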

syft.frameworks.torch.fl.utils.federated_avg(models: List[torch.nn.Module]) → torch.nn.Module

Calculate the federated average of a list of models.

Parameters

models (List[torch.nn.Module]) – the models of which the federated average is calculated.

Returns

the module with averaged parameters.

Return type

torch.nn.Module
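The federated average is the element-wise mean of the models' parameters. A minimal sketch in plain PyTorch might look as follows; `average_models` is a hypothetical helper, it assumes all models share the same architecture, and it returns a copy instead of modifying its inputs, which may differ from what the library does.

```python
import copy
from typing import List

import torch
from torch import nn


def average_models(models: List[nn.Module]) -> nn.Module:
    """Return a new model holding the element-wise mean of the input parameters."""
    if not models:
        raise ValueError("need at least one model to average")
    averaged = copy.deepcopy(models[0])  # leave the inputs untouched
    param_sets = [dict(m.named_parameters()) for m in models]
    with torch.no_grad():
        for name, param in averaged.named_parameters():
            stacked = torch.stack([params[name] for params in param_sets])
            param.copy_(stacked.mean(dim=0))
    return averaged


models = [nn.Linear(2, 1) for _ in range(3)]
with torch.no_grad():
    for value, model in zip([1.0, 2.0, 6.0], models):
        for p in model.parameters():
            p.fill_(value)

avg = average_models(models)
print(avg.weight.tolist())  # [[3.0, 3.0]]
```

Only parameters are averaged in this sketch; buffers such as batch-norm running statistics are copied from the first model unchanged.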

syft.frameworks.torch.fl.utils.accuracy(pred_softmax, target)

Calculate the accuracy of a given prediction.

This function assumes that pred_softmax is converted into the final prediction by taking the argmax.

Parameters
  • pred_softmax – array of floats, providing nr_classes values per element in target.

  • target – array of ints, the correct classes, taking values in the range [0, nr_classes).

Returns

the fraction of correct predictions.

Return type

float
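The argmax-then-compare computation can be illustrated without any tensor library. In this sketch plain Python lists stand in for the arrays, and `accuracy_from_scores` is an illustrative name, not the library function.

```python
from typing import Sequence


def accuracy_from_scores(
    pred_softmax: Sequence[Sequence[float]], target: Sequence[int]
) -> float:
    """Fraction of rows whose highest-scoring class equals the target class."""
    if len(pred_softmax) != len(target):
        raise ValueError("pred_softmax and target must have the same length")
    if len(target) == 0:
        raise ValueError("cannot compute the accuracy of an empty batch")
    correct = 0
    for scores, true_class in zip(pred_softmax, target):
        # argmax: index of the largest score in this row
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == true_class)
    return correct / len(target)


pred_softmax = [
    [0.7, 0.2, 0.1],  # argmax 0
    [0.1, 0.8, 0.1],  # argmax 1
    [0.3, 0.3, 0.4],  # argmax 2
    [0.6, 0.3, 0.1],  # argmax 0
]
target = [0, 1, 1, 0]
print(accuracy_from_scores(pred_softmax, target))  # 0.75
```

Three of the four rows match their target, hence 0.75.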

syft.frameworks.torch.fl.utils.create_gaussian_mixture_toy_data(nr_samples: int)

Create simple toy data for binary classification.

The data is drawn from two normal distributions:
  • target = 1: mu = 2, sigma = 1
  • target = 0: mu = 0, sigma = 1

The dataset is balanced, with an equal number of positive and negative samples.

Parameters

nr_samples – number of samples to generate

Returns

data, targets
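A comparable data set can be sketched with the standard library alone. Everything beyond the two distributions and the balance requirement is an assumption of this example: the name `make_gaussian_mixture`, the `seed` argument, the one-dimensional samples, the list return types, and the rejection of odd sample counts.

```python
import random
from typing import List, Tuple


def make_gaussian_mixture(nr_samples: int, seed: int = 0) -> Tuple[List[float], List[int]]:
    """Draw a balanced 1-D sample: target 1 ~ N(2, 1), target 0 ~ N(0, 1)."""
    if nr_samples % 2 != 0:
        raise ValueError("nr_samples must be even to keep the classes balanced")
    rng = random.Random(seed)  # local generator, so results are reproducible
    half = nr_samples // 2
    data = [rng.gauss(2.0, 1.0) for _ in range(half)]
    data += [rng.gauss(0.0, 1.0) for _ in range(half)]
    targets = [1] * half + [0] * half
    return data, targets


data, targets = make_gaussian_mixture(1000)
positives = [x for x, t in zip(data, targets) if t == 1]
negatives = [x for x, t in zip(data, targets) if t == 0]
# The two sample means should land close to 2 and 0 respectively.
print(sum(positives) / len(positives), sum(negatives) / len(negatives))
```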

syft.frameworks.torch.fl.utils.iris_data_partial()

Returns: 30 samples from the iris data set: https://archive.ics.uci.edu/ml/datasets/iris