
MLflow


MLflow is an open-source platform designed to manage the machine learning lifecycle, including experimentation, reproducibility, and deployment.

MLflow Tracking lets you log and visualize your machine learning experiments, keeping track of the metrics, parameters, artifacts, and models associated with each run.

As an alternative to logging local data as artifacts, or to using an MLflow Tracking Server to track datasets stored in remote object storage, you can log a reference to Arraylake repo data (the repo name and commit) to an MLflow run. This tracks your Arraylake datasets alongside your machine learning experiments.

If you haven't already, install MLflow and Arraylake by following their installation docs.

To log an Arraylake dataset as an MLflow dataset, you will first need to define a custom ArraylakeDatasetSource class so that MLflow can build an MLflow Dataset object from an Arraylake source.

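To do this, subclass the mlflow.data.dataset_source.DatasetSource abstract class, store the arraylake_repo and arraylake_ref attributes, and implement the _get_source_type, load, to_dict, and from_dict methods. The abstract _can_resolve and _resolve methods must also be defined; here they simply opt out of automatic source resolution: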

from typing import Any, Dict

import arraylake as al
import xarray as xr
from mlflow.data.dataset_source import DatasetSource
from mlflow.exceptions import MlflowException


class ArraylakeDatasetSource(DatasetSource):
    """
    Represents the source of an Arraylake dataset.

    Args:
        arraylake_repo:
            The Arraylake repo associated with the dataset, as 'organization/repo'.
        arraylake_ref:
            The Arraylake commit hash version of the dataset.
    """

    def __init__(self, arraylake_repo: str, arraylake_ref: str):
        self.arraylake_repo = arraylake_repo
        self.arraylake_ref = arraylake_ref

    @staticmethod
    def _get_source_type() -> str:
        # String that identifies this kind of source to MLflow
        return "arraylake"

    def load(self, group: str) -> xr.Dataset:
        """
        Loads the given Zarr group from the Arraylake repo into
        an Xarray dataset.
        """
        al_client = al.Client()
        # Open the repo without an automatic checkout, then check out
        # the exact commit recorded for this dataset
        repo = al_client.get_repo(self.arraylake_repo, checkout=False)
        repo.checkout(ref=self.arraylake_ref)
        return repo.to_xarray(group)

    def to_dict(self) -> Dict[Any, Any]:
        # Everything needed to rebuild this source later
        return {
            "arraylake_repo": self.arraylake_repo,
            "arraylake_ref": self.arraylake_ref,
        }

    @classmethod
    def from_dict(cls, source_dict: Dict[Any, Any]) -> "ArraylakeDatasetSource":
        arraylake_repo = source_dict.get("arraylake_repo")
        if arraylake_repo is None:
            raise MlflowException('Failed to parse ArraylakeDatasetSource. Missing expected key: "arraylake_repo"')

        arraylake_ref = source_dict.get("arraylake_ref")
        if arraylake_ref is None:
            raise MlflowException('Failed to parse ArraylakeDatasetSource. Missing expected key: "arraylake_ref"')

        return cls(arraylake_repo=arraylake_repo, arraylake_ref=arraylake_ref)

    @staticmethod
    def _can_resolve(raw_source: Any) -> bool:
        # This source is always built explicitly (e.g. via from_dict),
        # never resolved automatically from a raw path or URI
        return False

    @classmethod
    def _resolve(cls, raw_source: str) -> "ArraylakeDatasetSource":
        raise NotImplementedError
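
As we understand it, MLflow stores a dataset source as JSON built from the output of to_dict and later rebuilds the object with from_dict, so those two methods must round-trip exactly. The minimal sketch below checks that property with a hypothetical stand-in class, SketchArraylakeSource, that copies the serialization logic above but needs only the standard library. ValueError replaces MlflowException, and "abc123" is a placeholder, not a real Arraylake commit ID.

```python
import json
from typing import Any, Dict


class SketchArraylakeSource:
    """Hypothetical stand-in mirroring ArraylakeDatasetSource's serialization,
    with no MLflow or Arraylake dependency."""

    def __init__(self, arraylake_repo: str, arraylake_ref: str):
        self.arraylake_repo = arraylake_repo
        self.arraylake_ref = arraylake_ref

    def to_dict(self) -> Dict[str, Any]:
        return {"arraylake_repo": self.arraylake_repo, "arraylake_ref": self.arraylake_ref}

    @classmethod
    def from_dict(cls, source_dict: Dict[str, Any]) -> "SketchArraylakeSource":
        for key in ("arraylake_repo", "arraylake_ref"):
            if source_dict.get(key) is None:
                # ValueError stands in for MlflowException in this sketch
                raise ValueError(f'Missing expected key: "{key}"')
        return cls(source_dict["arraylake_repo"], source_dict["arraylake_ref"])


def round_trip(source: SketchArraylakeSource) -> SketchArraylakeSource:
    """Serialize to JSON text and rebuild, as a tracking store would."""
    return SketchArraylakeSource.from_dict(json.loads(json.dumps(source.to_dict())))


# "abc123" is a placeholder ref, not a real commit
original = SketchArraylakeSource("earthmover/mlflow-demo", "abc123")
restored = round_trip(original)
print(restored.to_dict() == original.to_dict())  # True
```

If a new attribute is ever added to the real class, it needs to appear in both to_dict and from_dict, or it will silently be lost after logging.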

Register the Arraylake data source with MLflow:

from mlflow.data import get_registered_sources
from mlflow.data.dataset_source_registry import register_dataset_source

# Register the custom data source class
register_dataset_source(ArraylakeDatasetSource)

# Confirm that the data source successfully registered
get_registered_sources()
[mlflow.data.artifact_dataset_sources.LocalArtifactDatasetSource,
mlflow.data.artifact_dataset_sources.LocalArtifactDatasetSource,
mlflow.data.artifact_dataset_sources.S3ArtifactDatasetSource,
...
__main__.ArraylakeDatasetSource]
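
Registration matters because a logged dataset records its source type as a plain string ("arraylake", from _get_source_type) alongside the source JSON; to rebuild the source object later, MLflow needs a registered class that claims that type. The toy registry below sketches that lookup under this assumption. HypotheticalSourceRegistry and DemoSource are illustrative names, not MLflow APIs, and MLflow's real registry also handles cases (such as resolving raw URIs) that are omitted here.

```python
import json
from typing import Any, Dict, Type


class HypotheticalSourceRegistry:
    """Toy registry keyed by source type; an illustration, not MLflow's class."""

    def __init__(self) -> None:
        self._sources: Dict[str, Type[Any]] = {}

    def register(self, source_cls: Type[Any]) -> None:
        # Key each class by the string its _get_source_type() returns
        self._sources[source_cls._get_source_type()] = source_cls

    def from_json(self, source_json: str, source_type: str) -> Any:
        # Find the class registered for this type, then rebuild the source
        if source_type not in self._sources:
            raise KeyError(f"No registered dataset source for type {source_type!r}")
        return self._sources[source_type].from_dict(json.loads(source_json))


class DemoSource:
    """Hypothetical minimal source used to exercise the registry."""

    def __init__(self, arraylake_repo: str, arraylake_ref: str):
        self.arraylake_repo = arraylake_repo
        self.arraylake_ref = arraylake_ref

    @staticmethod
    def _get_source_type() -> str:
        return "arraylake"

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "DemoSource":
        return cls(d["arraylake_repo"], d["arraylake_ref"])


registry = HypotheticalSourceRegistry()
registry.register(DemoSource)
# Placeholder ref, not a real Arraylake commit
stored = json.dumps({"arraylake_repo": "earthmover/mlflow-demo", "arraylake_ref": "abc123"})
rebuilt = registry.from_json(stored, "arraylake")
print(type(rebuilt).__name__, rebuilt.arraylake_ref)  # DemoSource abc123
```

Without the register call, the lookup fails, which is why the class has to be registered in every session that reads the logged dataset back, not just the one that logged it.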

Open the Arraylake repo containing the dataset that you would like to track:

import arraylake as al

# Connect to Arraylake
al_client = al.Client()

# Open the repo by specifying 'organization/repo'
repo = al_client.get_repo("earthmover/mlflow-demo")

Create an MLflow Dataset object with the Arraylake dataset source:

from mlflow.data.meta_dataset import MetaDataset

# Instantiate ArraylakeDatasetSource with the repo name and current commit
source = ArraylakeDatasetSource.from_dict({
    "arraylake_repo": repo.repo_name,
    "arraylake_ref": str(repo.session.base_commit),
})

# Create a MetaDataset from this source, using the commit as the digest
arraylake_dataset = MetaDataset(
    source=source,
    name="arraylake_dataset",
    digest=source.arraylake_ref,
)

We recommend using a MetaDataset to represent your MLflow dataset, because the goal is to log only dataset metadata (the repo name and commit), not the data itself.
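
To make the metadata-only point concrete, here is a standard-library sketch of roughly what gets recorded for this dataset: a name, a digest, the source type string, and the source JSON. The field names follow MLflow's dataset entity as we understand it, dataset_record is a hypothetical helper, and the commit IDs are placeholders. Because the commit doubles as the digest, two runs on the same commit share a digest and a new commit yields a new one.

```python
import json
from typing import Dict


def dataset_record(name: str, repo: str, ref: str) -> Dict[str, str]:
    """Hypothetical helper building a dict shaped like a logged dataset's metadata.

    Only the repo name and commit are stored; no array data is serialized.
    """
    source = {"arraylake_repo": repo, "arraylake_ref": ref}
    return {
        "name": name,
        "digest": ref,  # the commit doubles as the digest, as in the MetaDataset above
        "source_type": "arraylake",
        "source": json.dumps(source),
    }


# Placeholder commit IDs, not real Arraylake commits
run_a = dataset_record("arraylake_dataset", "earthmover/mlflow-demo", "abc123")
run_b = dataset_record("arraylake_dataset", "earthmover/mlflow-demo", "abc123")
run_c = dataset_record("arraylake_dataset", "earthmover/mlflow-demo", "def456")

print(run_a["digest"] == run_b["digest"])  # True: same commit, same digest
print(run_a["digest"] == run_c["digest"])  # False: new commit, new digest
print(len(json.dumps(run_a)) < 1024)  # True: the record stays small whatever the data size
```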

Configure a new MLflow experiment:

import mlflow
mlflow.set_experiment("Arraylake Integration")

Log the dataset to an MLflow run within this experiment:

# Start an MLflow run
with mlflow.start_run() as run:
    mlflow.log_input(arraylake_dataset, context="demo")
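
Used as a context manager, mlflow.start_run() ends the run when the block exits, marking it finished on a clean exit and failed if an exception escapes, while the run variable remains usable afterwards; the retrieval step later on relies on run.info.run_id for exactly that reason. The sketch below models those semantics with a hypothetical fake_start_run and FakeRun in place of MLflow's objects; it does not talk to a tracking server.

```python
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass
class FakeRun:
    """Hypothetical stand-in for an MLflow run; not MLflow's Run class."""
    run_id: str
    status: str = "RUNNING"
    inputs: List[str] = field(default_factory=list)


@contextmanager
def fake_start_run() -> Iterator[FakeRun]:
    run = FakeRun(run_id=uuid.uuid4().hex)
    try:
        yield run
    except Exception:
        run.status = "FAILED"  # an error inside the block marks the run failed
        raise
    else:
        run.status = "FINISHED"  # a clean exit marks the run finished


with fake_start_run() as run:
    run.inputs.append("arraylake_dataset")  # stands in for mlflow.log_input

# The run object outlives the block, so its ID is still available here
print(run.status, run.inputs)  # FINISHED ['arraylake_dataset']
```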

You can view the logged Arraylake dataset in the MLflow UI, which you can start with the command mlflow ui in your terminal. By default, it runs on http://localhost:5000.

[Figure: Dataset in MLflow]

You can now retrieve the dataset logged to that run and open its data with Xarray:

# Get the previous run
logged_run = mlflow.get_run(run.info.run_id)

# Get the dataset logged to the run
logged_dataset = logged_run.inputs.dataset_inputs[0].dataset

# Get the dataset source object
dataset_source = mlflow.data.get_source(logged_dataset)

# Open the repo data with Xarray
group = "air_temperature"
dataset_source.load(group=group)
<xarray.Dataset> Size: 31MB
Dimensions: (time: 2920, lat: 25, lon: 53)
Coordinates:
* lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
* lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
* time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
air (time, lat, lon) float64 31MB ...
Attributes:
Conventions: COARDS
description: Data is from NMC initialized reanalysis\n(4x/day). These a...
platform: Model
references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
title: 4x daily NMC reanalysis (1948)

That's it! You can now log Arraylake data to MLflow and open it directly from an MLflow run 🚀