Hyperparameter tuning with Optuna and the FEDn Python API

2024-09-13 by Benjamin Åstrand

In machine learning, hyperparameter tuning plays a crucial role in optimizing model performance. In federated learning, the need to tune hyperparameters arises not only on the client side, when training on local data, but can also arise on the server side, depending on which server-side aggregator is used. FEDn supports several server-side aggregators through the FEDn Python API. If you want to learn more about how to use different server-side optimizers with FEDn, please see this notebook on our GitHub. And if you want to gain a deeper understanding of how these optimizers work, please see this paper, which proposes the FedOpt algorithm, of which FedAdam is a special case. Here, we will specifically showcase how to tune the learning rate of FedAdam using the Optuna package. The code is available in this notebook on GitHub.

This blog post assumes that you are already familiar with the fundamentals of how to use FEDn and the Python API. If you haven’t already, please see our tutorials on Getting started with FEDn and Using the Python API.

What is Optuna?

Optuna is a hyperparameter optimization framework designed to automate the tedious process of finding the best hyperparameters. It uses Bayesian optimization to choose new hyperparameter values to try, given the performance achieved with previous values. It is efficient, easy to use, and supports a variety of machine learning frameworks, which makes it well suited for hyperparameter tuning with FEDn.
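To give a feel for the workflow, here is a minimal, self-contained sketch of the Optuna API (independent of FEDn) that minimizes a toy function. The same pattern of an objective function, a study, and a call to optimize carries over to the federated setting below:

```python
import optuna

# Toy objective: minimize (x - 2)^2 over x in [-10, 10].
def objective(trial):
    x = trial.suggest_float("x", -10, 10)  # Optuna proposes a value for x
    return (x - 2) ** 2                    # score that Optuna tries to minimize

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)

print(study.best_params)  # should be close to {'x': 2.0}
```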

Using Optuna to tune the server-side learning rate of FedAdam

The Optuna framework expects the user to define an objective function, which is used to evaluate the model for a given set of hyperparameter values. This blog post is based on an existing example on the FEDn GitHub, where we use a simple PyTorch model on the MNIST handwritten digit dataset. To evaluate the performance for different hyperparameter values, we treat the accuracy on the test set as the validation accuracy, and we want to find the learning rate that maximizes this metric.

Defining the objective function

For each choice of hyperparameter values, we start a new FEDn session with a given number of rounds and train the global model using those values. When the session has finished, we evaluate the performance attained in it. This is where the objective function comes into play! The objective function should follow these steps:

  1. Set a range for each hyperparameter to tune using the trial object in Optuna.
  2. Train the model, using the hyperparameters suggested by Optuna.
  3. Calculate and return an evaluation metric.

The code below shows how we can complete these three steps with FEDn. The range in which Optuna will look for hyperparameter values is defined in step 1. Note that we are only tuning the learning rate of FedAdam in this example to keep things simple. Step 2 entails starting a session and waiting for it to finish before evaluating the resulting model.
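Since the notebook itself is not embedded here, the sketch below illustrates one way the objective function can look. It assumes a FEDn APIClient connected to a running deployment, that start_session accepts the aggregator and aggregator_kwargs arguments for configuring FedAdam (argument names may vary between FEDn versions), and a helper evaluate_session, which we define in the next snippet:

```python
import time

from fedn import APIClient

# Adjust host/port to your FEDn deployment.
client = APIClient(host="localhost", port=8092)

def objective(trial):
    # Step 1: set the range in which Optuna suggests a server-side
    # learning rate (a log scale suits learning rates well).
    learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True)

    # Step 2: train the global model by starting a FEDn session with the
    # suggested learning rate and waiting for it to finish.
    session_id = f"optuna-trial-{trial.number}"
    client.start_session(
        id=session_id,
        rounds=50,
        aggregator="fedopt",
        aggregator_kwargs={"serveropt": "adam", "learning_rate": learning_rate},
    )
    while not client.session_is_finished(session_id):
        time.sleep(5)

    # Step 3: calculate and return the evaluation metric for the session.
    return evaluate_session(session_id, eval_method="highest")
```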

In step 3, we can choose to calculate the evaluation metrics however we want. Below are two suggested methods for evaluating the performance attained in a session:

  • Highest score - select the highest achieved test accuracy out of all rounds in the session.
  • Average final few rounds - compute the average test accuracy over the final few (e.g. 5) rounds to account for the stochastic nature of the test accuracy score.

…and how to implement them using FEDn, where the parameter eval_method determines which of the two methods to use.
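A possible implementation is sketched below. It assumes that the validations reported during a session can be fetched through client.get_validations and that each record carries a JSON payload with a test_accuracy field; both the filtering arguments and the payload format depend on your FEDn version and on how the client's validate entry point reports its metrics:

```python
import json

def get_test_accuracies(session_id):
    # Fetch the validations reported during the session and extract the
    # test accuracy from each record. With several clients you may want
    # to average per round first; here we keep it simple.
    validations = client.get_validations(session_id=session_id)
    return [json.loads(v["data"])["test_accuracy"] for v in validations["result"]]

def evaluate_session(session_id, eval_method="highest"):
    accuracies = get_test_accuracies(session_id)
    if eval_method == "highest":
        # Highest score: best test accuracy out of all rounds.
        return max(accuracies)
    elif eval_method == "average":
        # Average final few rounds: mean over the last 5 rounds to smooth
        # out the stochasticity in the test accuracy.
        final = accuracies[-5:]
        return sum(final) / len(final)
    raise ValueError(f"Unknown eval_method: {eval_method}")
```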

Creating, running and analyzing an Optuna study

It’s time to create and run our study to let Optuna find the optimal server-side learning rate for FedAdam. At this stage, all that is left to do is to tell Optuna in which direction to optimize the objective function and how many hyperparameter values we want to try. We create an Optuna study object and run the optimize method, passing the objective function we defined earlier as a parameter.
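In code, this amounts to a couple of lines; the trial budget below is just an illustrative choice:

```python
# We maximize, since the objective returns a test accuracy.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
```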

The parameter n_trials determines the number of hyperparameter values that Optuna will try. Since we start a new session for each trial, the number of sessions equals n_trials. Now we can access the results through the study object, for example the best learning rate:
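```python
best_lr = study.best_params["learning_rate"]
print(f"Best learning rate: {best_lr}")
print(f"Best test accuracy: {study.best_value}")
```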

…and visualize the optimization process, for example with Optuna's built-in visualization module (which uses plotly under the hood):
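```python
from optuna.visualization import plot_optimization_history

# Interactive plot of the objective value per trial, together with the
# best value found so far.
fig = plot_optimization_history(study)
fig.show()
```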

Conclusion

In this post, we showed how to integrate Optuna with FEDn for hyperparameter tuning, using the example of tuning the server-side learning rate of FedAdam. By defining an objective function and leveraging Optuna's efficient optimization, we automated the search for the learning rate that maximizes test accuracy. And FEDn's Python API let us choose how to evaluate the performance of each session, whether by selecting the highest accuracy or by averaging over the final rounds.