Distributed Training and Experiment Tracking with Ray Train, MLflow, and MinIO


Over the past few months, I have written about a number of different technologies (Ray Data, Ray Train, and MLflow). I thought it would make sense to pull them all together and deliver an easy-to-understand recipe for distributed data preprocessing and distributed training using a production-ready MLOps tool for tracking and model serving. This post integrates the code I presented in my Ray Train post, which distributes training across a cluster of workers, with a deployment of MLflow that uses MinIO under the hood for artifact storage and model checkpoints. While my code trains a model on the MNIST dataset, it is mostly boilerplate: replace the MNIST model with your model and the MNIST data access and preprocessing with your own, and you are ready to start training your model. A fully functioning sample containing all the code presented in this post can be found here.

The diagram below is a visualization of how distributed training, distributed preprocessing, and MLflow fit together. This is the diagram I presented in my Ray Train post with MLflow added. It represents a solid foundation for all your AI initiatives: MinIO for high-speed object storage, Ray for distributed training and data processing, and MLflow for MLOps.

Let’s start by revisiting the setup code I introduced for Ray Train and adding the MLflow setup to it.

Setting Up MLflow for Distributed Training

The code below is the setup for distributed training with the MLflow setup code added. The MLflow additions are the lines at the top of the function that configure MLflow and start a run, the two entries added to training_parameters, and the end of the run at the bottom. I’ll explain the additions to the training configuration parameter in the next section. When a run is complete, you need to let MLflow know; this is done at the bottom of the function. If you are new to MLflow Tracking, check out my post on MLflow Tracking with MinIO. You may also want to read Setting Up a Development Machine with MLFlow and MinIO if you want to install MLflow on your development machine.

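import os
from time import time

import mlflow
import ray
import torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# du (data utilities) and tu (model and test utilities) are helper modules from the
# sample project. train_func_per_worker is shown later in this post.

def distributed_training(training_parameters, num_workers: int, use_gpu: bool):
    logger = du.create_logger()

    # Set up MLflow to point to our tracking server.
    experiment_name = 'MLFlow - Ray test'
    run_name = 'Testing Epoch metrics'
    mlflow_base_url = 'http://localhost:5001'  # no trailing slash; log_metric() appends the API path
    mlflow.set_tracking_uri(mlflow_base_url)
    active_experiment = mlflow.set_experiment(experiment_name)
    active_run = mlflow.start_run(run_name=run_name)

    # Log parameters.
    mlflow.log_params(training_parameters)

    # Give the workers what they need to log to this run (explained in the next section).
    training_parameters['mlflow_base_url'] = mlflow_base_url
    training_parameters['run_id'] = active_run.info.run_id

    logger.info('Initializing Ray.')
    ray.init()

    train_data, test_data, load_time_sec = du.get_ray_dataset(training_parameters)

    # Scaling configuration
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

    # Initialize a Ray TorchTrainer. train_loop_config is handed to every worker.
    start_time = time()
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=training_parameters,
        datasets={'train': train_data},
        scaling_config=scaling_config,
    )
    result = trainer.fit()
    training_time_sec = time() - start_time

    logger.info(f'Load Time (in seconds) = {load_time_sec}')
    logger.info(f'Training Time (in seconds) = {training_time_sec}')

    # hidden_sizes and output_size are assumed config keys; match tu.MNISTModel's signature.
    model = tu.MNISTModel(training_parameters['input_size'],
                          training_parameters['hidden_sizes'],
                          training_parameters['output_size'])
    with result.checkpoint.as_directory() as checkpoint_dir:
        # map_location='cpu' lets GPU-trained weights load on a CPU-only driver.
        model.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'model.pt'),
                                         map_location='cpu'))
    tu.test_model(model, test_data)

    # Shut down Ray.
    ray.shutdown()
    # End the MLflow run.
    mlflow.end_run()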

The Problem with Tracking Distributed Experiments

The problem with using the MLflow Python library with distributed training is that its fluent functions, such as log_metric() and log_metrics(), use an active run that the library tracks internally for the current process; the run ID itself is not a parameter to these functions. So, when the Ray Train workers start, they will not have the run ID that was created when the controlling process started the run, because they run in different processes. Part of this problem is easy to fix: we can pass the run ID into the worker processes as part of the training configuration. However, that alone does not help with library functions that never take a run ID. Fortunately, MLflow has a REST API that accepts the run ID as a parameter on the calls that log to a run. It also requires the base URL of the MLflow tracking server. Below is a function that wraps the MLflow REST API for logging a metric. Check out the MLflow REST API samples for functions that wrap other MLflow features.

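from typing import Any, Dict

import mlflow
import requests


def log_metric(base_url: str, run_id: str, metric: Dict[str, Any]) -> int:
    '''Log one metric for the given run via the MLflow REST API.

    metric must contain 'key', 'value', and 'step'. Returns the HTTP status code.
    '''
    # rstrip('/') avoids a double slash if the base URL ends with one.
    url = f"{base_url.rstrip('/')}/api/2.0/mlflow/runs/log-metric"
    payload = {
        'run_id': run_id,
        'key': metric['key'],
        'value': metric['value'],
        'timestamp': mlflow.utils.time.get_current_time_millis(),  # milliseconds since the epoch
        'step': metric['step'],
    }
    r = requests.post(url, json=payload, timeout=10)
    return r.status_code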
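The log_metric() function above sends one metric per HTTP request. The MLflow REST API also has a runs/log-batch endpoint that takes a list of metrics in one request, which cuts round trips when a worker reports several metrics per epoch. Below is a minimal sketch of a wrapper for it, assuming an MLflow 2.x tracking server. It uses the standard library's urllib instead of requests, and the names build_log_batch_payload() and log_batch() are my own, not part of MLflow.

```python
import json
import time
import urllib.request
from typing import Dict, Optional


def build_log_batch_payload(run_id: str, metrics: Dict[str, float], step: int,
                            timestamp_ms: Optional[int] = None) -> dict:
    '''Build the JSON body for the runs/log-batch REST endpoint (hypothetical helper).'''
    if timestamp_ms is None:
        # The REST API expects timestamps in milliseconds since the Unix epoch.
        timestamp_ms = int(time.time() * 1000)
    # Every metric in this batch shares one timestamp and one step.
    return {
        'run_id': run_id,
        'metrics': [
            {'key': key, 'value': float(value), 'timestamp': timestamp_ms, 'step': step}
            for key, value in metrics.items()
        ],
    }


def log_batch(base_url: str, run_id: str, metrics: Dict[str, float], step: int,
              timeout_sec: float = 10.0) -> int:
    '''POST several metrics for one step in a single request and return the HTTP status.'''
    url = f"{base_url.rstrip('/')}/api/2.0/mlflow/runs/log-batch"
    body = json.dumps(build_log_batch_payload(run_id, metrics, step)).encode('utf-8')
    request = urllib.request.Request(url, data=body, method='POST',
                                     headers={'Content-Type': 'application/json'})
    # urlopen() raises urllib.error.HTTPError for 4xx/5xx responses.
    with urllib.request.urlopen(request, timeout=timeout_sec) as response:
        return response.status
```

Unlike requests.post(), urlopen() raises an exception for error responses, so log_batch() only returns a status code when the call succeeds. Inside the worker you could call it as log_batch(training_parameters['mlflow_base_url'], training_parameters['run_id'], metrics, step=epoch + 1).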

MLflow’s base URL and the run ID can be added to the training configuration variable using the snippet below. (The training configuration variable is a Python dictionary passed to TorchTrainer as train_loop_config; Ray hands it to each worker’s training function as that function’s only argument.)

training_parameters['mlflow_base_url'] = mlflow_base_url
training_parameters['run_id'] = active_run.info.run_id
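The driver and the worker function have to agree on these key names, and a mismatch only surfaces as a KeyError inside a remote worker. One way to guard against that, sketched below as an assumption rather than what the sample code does, is to define the keys once and wrap the driver-side write and the worker-side read in small helpers. MLFLOW_URL_KEY, RUN_ID_KEY, with_mlflow_config(), and read_mlflow_config() are hypothetical names.

```python
from typing import Any, Dict, Tuple

# Hypothetical key names, defined once so the driver and the workers agree.
MLFLOW_URL_KEY = 'mlflow_base_url'
RUN_ID_KEY = 'run_id'


def with_mlflow_config(training_parameters: Dict[str, Any], base_url: str,
                       run_id: str) -> Dict[str, Any]:
    '''Driver side: return a copy of the config that carries the MLflow connection info.'''
    config = dict(training_parameters)  # shallow copy; the caller's dict is left unchanged
    config[MLFLOW_URL_KEY] = base_url.rstrip('/')  # the REST wrapper appends the API path
    config[RUN_ID_KEY] = run_id
    return config


def read_mlflow_config(config: Dict[str, Any]) -> Tuple[str, str]:
    '''Worker side: return (base_url, run_id), failing with a clear message if either is missing.'''
    missing = [key for key in (MLFLOW_URL_KEY, RUN_ID_KEY) if key not in config]
    if missing:
        raise KeyError(f'training config is missing MLflow keys: {missing}')
    return config[MLFLOW_URL_KEY], config[RUN_ID_KEY]
```

Because Ray serializes the training configuration to ship it to each worker, keep its values to plain, picklable types such as strings and numbers; the base URL and run ID both qualify.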

We now have a way to send MLflow information to the distributed workers, and we have a function that can make RESTful calls to MLflow’s Tracking Server. The next step is to use that function from within the distributed workers' training loop.

Adding Experiment Tracking to Ray Train Workers

Adding tracking to the function that runs within each remote worker process requires minimal code. The complete function is shown below; the MLflow additions are the lines in the rank 0 block that build mlflow_metric and pass it to log_metric().

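import os
import tempfile

import ray.train.torch
import torch
from ray import train
from ray.train import Checkpoint
from torch import nn, optim

# tu is the sample project's model utility module; log_metric() is defined above.

def train_func_per_worker(training_parameters):
    # Train the model and log training metrics.
    # hidden_sizes and output_size are assumed config keys; match tu.MNISTModel's signature.
    model = tu.MNISTModel(training_parameters['input_size'],
                          training_parameters['hidden_sizes'],
                          training_parameters['output_size'])
    # Move the model to this worker's device and wrap it in DDP when there are multiple workers.
    model = ray.train.torch.prepare_model(model)
    device = ray.train.torch.get_device()

    # Get the dataset shard for the training worker.
    train_data_shard = train.get_dataset_shard('train')

    loss_func = nn.NLLLoss()
    # 'momentum' is an assumed config key.
    optimizer = optim.SGD(model.parameters(), lr=training_parameters['lr'],
                          momentum=training_parameters['momentum'])

    batch_size_per_worker = training_parameters['batch_size_per_worker']
    for epoch in range(training_parameters['epochs']):
        total_loss = 0
        batch_count = 0
        for batch in train_data_shard.iter_torch_batches(batch_size=batch_size_per_worker):
            # Get the images and labels from the batch.
            images, labels = batch['X'], batch['y']
            labels = labels.long()  # NLLLoss expects int64 class indices; .long() keeps the device
            images, labels = images.to(device), labels.to(device)

            # Flatten MNIST images into a 784-long vector.
            images = images.view(images.shape[0], -1)
            # Training pass
            output = model(images)
            loss = loss_func(output, labels)

            # This is where the model learns by backpropagating
            # and optimizes its weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

        metrics = {'training_loss': total_loss / batch_count}
        checkpoint = None
        if train.get_context().get_world_rank() == 0:
            # Save the weights to a fresh directory, unwrapping DDP if prepare_model applied it.
            checkpoint_dir = tempfile.mkdtemp()
            state_dict = (model.module if hasattr(model, 'module') else model).state_dict()
            torch.save(state_dict, os.path.join(checkpoint_dir, 'model.pt'))
            checkpoint = Checkpoint.from_directory(checkpoint_dir)

            # MLflow: log this epoch's average loss to the run created by the driver.
            mlflow_metric = {}
            mlflow_metric['key'] = 'training_loss'
            mlflow_metric['value'] = metrics['training_loss']
            mlflow_metric['step'] = epoch + 1
            log_metric(training_parameters['mlflow_base_url'], training_parameters['run_id'],
                       mlflow_metric)

        train.report(metrics, checkpoint=checkpoint)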

There are two things to know about this code. First, I am only logging metrics from one of the workers (rank 0). Each worker has its own copy of the model being trained, but DistributedDataParallel keeps the copies synchronized by averaging gradients across workers at every step, so the weights are identical on every worker. The loss rank 0 computes on its shard is therefore a representative view of training, and logging from every worker would fill MLflow with near-duplicate series. Second, this code still uses Ray Train reporting for metrics and checkpointing. If you wish, it is possible to move all reporting and checkpointing to MLflow.
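Logging from rank 0 means the recorded loss covers only rank 0's shard. If you would rather log a loss over every shard, each worker's total_loss and batch_count have to be combined, weighting by batch count. The sketch below shows only that arithmetic; global_mean_loss() is a hypothetical helper, and gathering the per-worker sums onto rank 0 (for example with torch.distributed.all_reduce, which is available because Ray Train sets up the process group) is left out.

```python
from typing import Sequence, Tuple


def global_mean_loss(per_worker: Sequence[Tuple[float, int]]) -> float:
    '''Combine (total_loss, batch_count) pairs from all workers into one epoch mean.

    Summing both parts before dividing weights each worker by how many batches it
    processed; averaging per-worker means would give a worker with few batches the
    same weight as a worker with many.
    '''
    total_loss = sum(loss for loss, _ in per_worker)
    total_batches = sum(count for _, count in per_worker)
    if total_batches == 0:
        raise ValueError('no batches were processed')
    return total_loss / total_batches
```

With two workers whose shards produced (6.0, 3) and (1.0, 1), the weighted mean is 7.0 / 4 = 1.75, whereas averaging the per-worker means (2.0 and 1.0) would report 1.5.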


In this post, I showed how to add MLflow Tracking to a machine learning pipeline that uses distributed training and distributed preprocessing. If you want to learn more about what you can do with MinIO, Ray Data, Ray Train, and MLflow, then check out the following related posts.

Distributed Training with Ray Train and MinIO

Distributed Data Processing with Ray Data and MinIO

Setting Up a Development Machine with MLFlow and MinIO

MLflow Tracking and MinIO

MLflow Model Registry and MinIO

Incorporating these technologies into your ML pipeline is the first step toward building a complete AI infrastructure. You will have:

  • MinIO - A high-performance Data Lake
  • Ray Data - Distributed preprocessing
  • Ray Train - Distributed Training
  • MLflow - MLOps

As a next step, consider adding a Modern Datalake to your infrastructure.

Feel free to reach out to us on our general Slack channel or at hello@min.io if you want to discuss further.
