Hyperparameter tuning

Optuna

Optuna is a software framework for hyperparameter tuning. It is agnostic to what you are trying to tune, and it is well suited to distributed computing.

Briefly, it works as follows. Optuna sets up or loads an object called a Study, which can be tracked using various database backends - in this case we will use a journal file, as this is the most robust approach with the fewest points of failure. The Study has a Sampler attached, which contains an algorithm for picking hyperparameters - for example, a grid-based or random sampler. The Study can then be used to generate a number of Trials, where each Trial queries the Sampler for values of the hyperparameters to be tuned, based on the set of trials completed so far in the Study.

The Trial is passed as the argument to an objective function (a loss or acquisition function), in which the querying for hyperparameters takes place; the objective function must return a numerical value by which the success of the trial is evaluated. The Study automatically keeps track of all hyperparameters that have been tried, as well as the results of all trials that have come in; a variety of advanced usage, such as pruning unpromising trials, is also possible.
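
As a minimal sketch of these concepts - using the default in-memory storage and sampler, and a toy objective chosen purely for illustration - an optimization loop looks like this:

import optuna

def objective(trial):
    # The Trial queries the Sampler for a value of x in [-10, 10]
    x = trial.suggest_float('x', -10, 10)
    # The returned value is what the Study optimizes
    return (x - 2) ** 2

study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=20)
print(study.best_params)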

Optuna is a flexible framework, but we will show some examples suitable for use on a Slurm cluster. A complete study script, here using scikit-learn, can look like this:

import optuna
import random
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import datetime

class Objective:
    def __init__(self, hyperparameter_config: dict):
        self.hyperparameter_config = hyperparameter_config

    def __call__(self, trial):
        # We can modify the ranges of hyperparameters during the study
        hyperparameters = dict()
        for key, value in self.hyperparameter_config.items():
            trial_fun = getattr(trial, value['fun'])
            hyperparameters[key] = \
                trial_fun(key, value['min'],
                          value['max'], log=value['log'])

        clf = RandomForestClassifier(
            random_state=random.randint(0, 10**9),
            **hyperparameters)

        # Load the breast cancer dataset
        X, y = load_breast_cancer(return_X_y=True)
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=random.randint(0, 10**9))
        clf.fit(X_train, y_train)
        y_pred = clf.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)

        # Log additional information
        trial.set_user_attr(
            'timestamp',
            datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

        return accuracy

if __name__ == "__main__":
    # Define the hyperparameters to optimize
    hyperparameter_config = dict(
        n_estimators=dict(
            fun='suggest_int', min=10, max=100, log=False),
        max_depth=dict(
            fun='suggest_int', min=2, max=32, log=True),
        min_samples_split=dict(
            fun='suggest_int', min=2, max=10, log=False),
        min_samples_leaf=dict(
            fun='suggest_int', min=1, max=10, log=False))
    # Create journal storage
    file_path = "./db/journal.log"
    study_name = "my_study"
    sampler = optuna.samplers.RandomSampler(seed=random.randint(0, 10**9))
    storage = optuna.storages.JournalStorage(
        optuna.storages.JournalFileStorage(file_path))
    study = optuna.create_study(
        study_name=study_name,
        storage=storage,
        direction='maximize',
        sampler=sampler,
        load_if_exists=True)
    if len(study.trials) < 100:
        # Callback to constrain the total number of trials
        max_trial_callback = optuna.study.MaxTrialsCallback(
            100, states=(optuna.trial.TrialState.COMPLETE,))
        # Run the optimization
        study.optimize(Objective(hyperparameter_config),
                       n_trials=5,
                       callbacks=[max_trial_callback])
We save the study script as optuna_test.py and run it as a Slurm job array. A batch script, here called optuna-batch.sh, can look like this:

#!/bin/bash
#SBATCH -A YOUR_ACCOUNT -p vera
#SBATCH -n 2
#SBATCH -t 00:05:00
#SBATCH -o test_optuna_%j.txt

ml Optuna/3.5.0 scikit-learn/1.3.1
python3 optuna_test.py
We submit the batch script as a job array of 20 tasks, each running up to five trials:

$ sbatch --array=0-19 optuna-batch.sh
Submitted batch job 6445805
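
Each of the 20 array tasks runs up to five trials against the shared journal file, and the MaxTrialsCallback stops a task from starting new trials once 100 trials have completed across the whole study.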

We can see some of the output from the journal log (here lightly edited for readability):

$ tail -5 db/journal.log
{"op_code": 5, "worker_id": "ec5b2781-8724-4351-8e7e-7cad2cc457f0-2259702416",
 "trial_id": 100, "param_name": "min_samples_leaf", "param_value_internal": 1.0,
 "distribution": {"name": "IntDistribution",
                  "attributes": {
                    "log": false,
                    "step": 1,
                    "low": 1,
                    "high": 10}}}
{"op_code": 8, "worker_id": "c86b6df2-2b1c-4a53-825f-94d22a4a7604-2328590375",
 "trial_id": 98, "user_attr": {"timestamp": "2024-11-08 11:30:52"}}
{"op_code": 6, "worker_id": "c86b6df2-2b1c-4a53-825f-94d22a4a7604-2328590375",
 "trial_id": 98, "state": 1, "values": [0.9298245614035088],
 "datetime_complete": "2024-11-08T11:30:52.108614"}
{"op_code": 8, "worker_id": "ec5b2781-8724-4351-8e7e-7cad2cc457f0-2259702416",
 "trial_id": 100, "user_attr": {"timestamp": "2024-11-08 11:30:52"}}
{"op_code": 6, "worker_id": "ec5b2781-8724-4351-8e7e-7cad2cc457f0-2259702416",
 "trial_id": 100, "state": 1, "values": [0.9824561403508771],
 "datetime_complete": "2024-11-08T11:30:52.242400"}

We can get the best result from the study via the Optuna CLI, here in yaml format:

$ ml Optuna
$ optuna --storage db/journal.log --storage-class JournalFileStorage \
--study-name my_study best-trial -f yaml
datetime_complete: "2024-11-08 11:30:38"
datetime_start: "2024-11-08 11:30:38"
duration: "0:00:00.342559"
number: 21
params:
    max_depth: 3
    min_samples_leaf: 8
    min_samples_split: 8
    n_estimators: 76
state: COMPLETE
user_attrs:
    timestamp: "2024-11-08 11:30:38"
value: 0.9912280701754386
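
The trials subcommand lists every trial in the study in the same format, which the Services example below makes use of:

$ optuna --storage db/journal.log --storage-class JournalFileStorage \
--study-name my_study trials -f yaml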

In Optuna, we can use the best trial (or the best few trials) as a starting point for fine-tuning our hyperparameters. For example, you can use the number property to identify and save the state of each run, and then load the state of your best run in your next optimization to continue tuning. Alternatively, you can use your trial scores to generate a new set of constraints for further optimization - there are many options. You can even continue the same study and let the completed random-sampling trials inform one of the other samplers available in Optuna, or write your own sampler thanks to Optuna's object-oriented structure.
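
As a sketch of the constraint-narrowing approach - taking the five best trials, and the min/max they span, are arbitrary choices for illustration - we can reuse the hyperparameter_config format from the example above:

import optuna

storage = optuna.storages.JournalStorage(
    optuna.storages.JournalFileStorage("./db/journal.log"))
study = optuna.load_study(study_name="my_study", storage=storage)

# Pick the five best completed trials (the study maximizes accuracy) ...
completed = [t for t in study.trials
             if t.state == optuna.trial.TrialState.COMPLETE]
top = sorted(completed, key=lambda t: t.value, reverse=True)[:5]

# ... and narrow each integer hyperparameter range to the span they cover
new_config = {
    key: dict(fun='suggest_int',
              min=min(t.params[key] for t in top),
              max=max(t.params[key] for t in top),
              log=False)
    for key in top[0].params}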

Advanced usage

MPI

Using MPI with Optuna is straightforward: let the root rank manage the study, and broadcast the suggested hyperparameters to the other ranks.

...
import optuna
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

class Objective:
    def __init__(self, hyperparameter_config: dict):
        ...

    def __call__(self, trial):
        # Only the root rank queries the sampler for suggestions
        if rank == 0:
            # standard non-MPI code for querying hyperparameters
            hyperparameters = ...
        else:
            hyperparameters = None
        # All ranks receive the same hyperparameters
        hyperparameters = comm.bcast(hyperparameters, root=0)
        # fit, evaluate, and return the score as before
        ...

if __name__ == "__main__":
    n_trials = 5
    hyperparameter_config = ...
    if rank == 0:
        # standard non-MPI code for creating and optimizing the study
        ...
    else:
        # Worker ranks take part in each trial through a dummy
        # FixedTrial; they never query it, since only root asks
        # for suggestions
        for _ in range(n_trials):
            Objective(hyperparameter_config)(optuna.trial.FixedTrial({}))
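
A batch script can launch the MPI version with srun. The script name optuna_mpi.py is a placeholder, and the module line is an assumption - check ml spider for an mpi4py module matching your toolchain:

#!/bin/bash
#SBATCH -A YOUR_ACCOUNT -p vera
#SBATCH -n 4
#SBATCH -t 00:05:00

ml Optuna/3.5.0 scikit-learn/1.3.1
# an mpi4py module must be loaded as well; the exact name varies
srun python3 optuna_mpi.py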

Services

Given a set of suitably written optimization scripts, you can run a series of batch optimizations programmatically. For example, you can run a small service on the login node that uses the Optuna CLI to decide which jobs to submit:

#!/bin/bash

JOURNALPATH="$HOME/optuna/db/journal.log"

# If the journal file does not exist, run the first study.
if [ ! -f "$JOURNALPATH" ]; then
    sbatch --array=0-19 optuna_first_study.sh
    exit 0
fi

# Parse the number of the last trial and check if it's at least 100
TRIALS=$(optuna --storage "$JOURNALPATH" \
        --storage-class JournalFileStorage \
        --study-name my_study trials -f yaml |
        grep number | tail -n 1 | cut -d' ' -f4)

if [ "$TRIALS" -ge 100 ]; then
    sbatch --array=0-19 optuna_second_study.sh
    exit 0
fi

...

The script can then be run periodically as a systemd user service, for example in ~/.config/systemd/user/optuna-slurm.service; with Restart=always and RestartSec=900 the script is rerun every 15 minutes:
[Unit]
Description=File Monitor Service
After=network.target

[Service]
Type=simple
ExecStart=%h/optuna/optuna_slurm.sh
Restart=always
RestartSec=900

[Install]
WantedBy=default.target

Enable the service by running the following:

$ systemctl --user daemon-reload && \
systemctl --user enable optuna-slurm && \
systemctl --user start optuna-slurm

You can check the service by running systemctl --user status optuna-slurm. The advantage of such a service is that it can survive reboots of the login node, which matters for very long-running batches; depending on the system configuration, you may need loginctl enable-linger for the service to start without an active login session. The job of distributing the workload is then left to Slurm.