MLflow is an open-source platform for "model development, deployment, and management". On Alvis, you can use MLflow for:

  • Experiment tracking;
  • Hyperparameter tuning;
  • Model registry.

The following MLflow features have not been tested on Alvis and will likely not work fully:

  • Deployment server (AI gateway);
  • LLM tracking system.

Tracking models with MLflow

As an example, we train a few vision transformers (ViTs) on the CIFAR-10 dataset with different hyperparameters. We start from a typical PyTorch training script and add MLflow tracking to it:

""" Sample training script with PyTroch and ViT, tracked by MLFlow.

Usage:  python [--option value] ...

Where available option/values are:
- arch: model architecture (str), see:
- lr: learning rate (float)
- tf32: use tf32 matmul (int), 0 or 1

def train_model(args):
    import mlflow
    from mlflow.models import infer_signature

    with mlflow.start_run():
        # We start by logging a tag and some parameters; MLflow can also log
        # system metrics within this scope:
        mlflow.set_tag("Training Info", "Vision Transformers w/wo TF32 matmul")
        mlflow.log_params({"arch": args.arch, "lr": args.lr, "tf32": args.tf32})
        import torch, time
        import torchvision.models as models
        from torch.nn.utils import clip_grad_norm_
        from torch.utils.data import DataLoader
        from torchvision.datasets import CIFAR10
        from torchvision.transforms import Compose, ToTensor, Resize

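        if args.tf32:
            # allow TensorFloat-32 in float32 matrix multiplications (Ampere GPUs such as the A100)
            torch.backends.cuda.matmul.allow_tf32 = True
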
        dataroot = "/mimer/NOBACKUP/Datasets/CIFAR"
        batch_size = 64
        image_size = 224
        transforms = Compose([Resize((image_size, image_size)), ToTensor()])
        train_data = CIFAR10(root=dataroot, train=True, transform=transforms)
        valid_data = CIFAR10(root=dataroot, train=False, transform=transforms)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=True)

        model = models.__dict__[args.arch]().to(torch.device("cuda"))

        optimizer = torch.optim.AdamW(
            model.parameters(), lr=args.lr, weight_decay=0.3)
        loss_fn = torch.nn.CrossEntropyLoss()

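        # standard training & validation loop
        EPOCHS = 10
        t0 = time.time()  # reference time for the "time" metric logged below
        for epoch in range(EPOCHS):
            for i, data in enumerate(torch.hub.tqdm(train_loader)):
                images, labels = map(lambda x: x.to(torch.device("cuda")), data)
                optimizer.zero_grad()
                logits = model(images)
                loss = loss_fn(logits, labels)
                loss.backward()
                clip_grad_norm_(model.parameters(), 0.1)
                optimizer.step()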

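            model.eval()  # evaluation mode for validation
            with torch.no_grad():
                vloss = 0.0
                for i, data in enumerate(valid_loader):
                    vinputs, vlabels = map(lambda x: x.to(torch.device("cuda")), data)
                    vlogits = model(vinputs)
                    vloss += loss_fn(vlogits, vlabels).item()
                vloss /= (i+1)
            model.train()  # back to training mode for the next epoch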

            # Metrics are recommended to be tracked per epoch (via step=epoch):
            mlflow.log_metric("vloss", vloss, step=epoch)
            mlflow.log_metric("time", time.time()-t0, step=epoch)

        # At the end of training, log the model so that it can be used later.
        # The signature is inferred from the last training batch (detached from the graph).
        signature = infer_signature(images.cpu().numpy(), logits.detach().cpu().numpy())
        mlflow.pytorch.log_model(model, f"{args.arch}-{args.lr}-{args.tf32}",
                                 signature=signature)

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Train a ViT on CIFAR-10 with MLflow tracking")
    parser.add_argument("--arch", type=str, help="torchvision model architecture, e.g. vit_b_16")
    parser.add_argument("--lr", type=float, help="Learning rate")
    parser.add_argument("--tf32", type=int, help="Enable TF32 matmul (0 or 1)")
    args = parser.parse_args()
    train_model(args)

We can perform a grid scan of the hyperparameters to see how the architecture, learning rate, and numerical precision affect our results, and then visualize the outcome with MLflow.

#!/usr/bin/env bash
#SBATCH --time=00:30:00
#SBATCH --gres=gpu:A100:1
#SBATCH --array=0-15

arch=(vit_b_16 vit_b_32 vit_l_16 vit_l_32) # vit_h_14
lr=(0.01 0.001)
tf32=(0 1)

ml MLflow/2.10.2-gfbf-2023a PyTorch-bundle/2.1.2-foss-2023a-CUDA-12.1.1

# a simple grid search that derives the parameters from SLURM_ARRAY_TASK_ID
id=${SLURM_ARRAY_TASK_ID}
python ./ --arch=${arch[id%4]} --lr=${lr[id/4%2]} --tf32=${tf32[id/8%2]}
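The array-index arithmetic in the job script enumerates every combination exactly once: `id%4` selects the architecture, `id/4%2` the learning rate, and `id/8%2` the TF32 flag, so task IDs 0 to 15 cover all 4 × 2 × 2 = 16 combinations. The sketch below reproduces that mapping in Python (using integer division, as Bash arithmetic does) so you can check which array task runs which configuration; the function name `params_for_task` is our own.

```python
from itertools import product

ARCH = ["vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32"]
LR = [0.01, 0.001]
TF32 = [0, 1]


def params_for_task(task_id: int) -> tuple[str, float, int]:
    """Mirror the Bash indexing ${arch[id%4]}, ${lr[id/4%2]}, ${tf32[id/8%2]}."""
    return ARCH[task_id % 4], LR[task_id // 4 % 2], TF32[task_id // 8 % 2]


if __name__ == "__main__":
    combos = [params_for_task(i) for i in range(16)]  # matches --array=0-15
    # Every combination of the three lists appears exactly once.
    assert sorted(combos) == sorted(product(ARCH, LR, TF32))
    for i, combo in enumerate(combos):
        print(i, *combo)
```

If you add a value to one of the lists, the divisors and the `--array` range have to change accordingly.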

In this example we perform a plain grid search over the parameters; with a bit more effort, you can use libraries such as Ray or Dask to do this more efficiently.

Inspecting the results

You can inspect your training runs with the MLflow web UI by running the MLflow server on the login node. Note that your server will be visible to other Alvis users, so we recommend setting up simple password authentication to keep out eavesdroppers. To do so, start the MLflow server with the basic-auth app. The authentication functionality is at an early stage; at the time of writing, MLflow provides a Python snippet to update the password.

ml MLflow/2.10.2-gfbf-2023a
mlflow server --port 50000 --app-name basic-auth # default user: admin; password: password.
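The password-update snippet mentioned above looks roughly like the following sketch. It assumes MLflow 2.x, where `mlflow.server.get_app_client("basic-auth", ...)` returns an authentication client with an `update_user_password` method, that the server started above is reachable at `http://localhost:50000`, and that the client authenticates through the `MLFLOW_TRACKING_USERNAME`/`MLFLOW_TRACKING_PASSWORD` environment variables. The helper name `update_admin_password` is hypothetical; check the MLflow authentication documentation for your version before relying on it.

```python
import os


def update_admin_password(new_password: str,
                          tracking_uri: str = "http://localhost:50000",
                          old_password: str = "password") -> None:
    """Change the basic-auth admin password (sketch, assumes MLflow 2.x basic-auth)."""
    # Imported lazily so this module can be loaded without MLflow installed.
    from mlflow.server import get_app_client

    # The auth client reads its credentials from these environment variables.
    os.environ["MLFLOW_TRACKING_USERNAME"] = "admin"
    os.environ["MLFLOW_TRACKING_PASSWORD"] = old_password
    auth_client = get_app_client("basic-auth", tracking_uri=tracking_uri)
    auth_client.update_user_password(username="admin", password=new_password)
    # Subsequent MLflow calls in this process must use the new password.
    os.environ["MLFLOW_TRACKING_PASSWORD"] = new_password
```

Run it once on the login node while the server is up, for example `update_admin_password(getpass.getpass())`, so that the new password does not end up in your shell history.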

You can access the web interface via a desktop session on the login node, or with an SSH port forward:

ssh -L50000:localhost:50000

You can then open http://localhost:50000 in your browser and use the web interface to visualize and compare models. The example below shows a typical workflow, where one visualizes 1) the correlation between parameters and metrics using a parallel coordinates chart and 2) the validation losses, in order to determine favorable choices of hyperparameters.

mlflow web ui example

Finally, you can load saved models through MLflow:

import mlflow
import numpy as np

logged_model = 'runs:/4d6e610aa0034ac8a1cde9f0ea166e0e/vit_b_32-0.001-1'
loaded_model = mlflow.pyfunc.load_model(logged_model)

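# a random batch of three 3x224x224 images, float32 to match the logged signature
data = np.random.rand(3, 3, 224, 224).astype(np.float32)
predictions = loaded_model.predict(data)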

You can find the URI of a logged model in the web UI, or use the model registry to manage different versions of your model.

Reference and notes

This example is based on the CIFAR-10 dataset and the VisionTransformer models implemented in torchvision; see their respective documentation for more information. The training script is a trimmed-down version borrowed from torchvision. For realistic tasks and good performance, you may want to include the training-procedure tweaks implemented in their reference scripts.