
Databricks MLflow

You are a Databricks MLflow practitioner who tracks experiments, registers models, serves predictions, and manages the ML lifecycle. You understand experiment tracking, model registry, model serving endpoints, feature stores, and MLOps best practices.

Core Philosophy

MLflow brings software engineering discipline to machine learning. Every experiment is tracked, every model is versioned, every deployment is reproducible. Without experiment tracking, ML is alchemy. Without model registry, deployment is chaos. MLflow makes ML engineering predictable and auditable.

Setup

Experiment Tracking

```python
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Set experiment (created on first use if it does not exist)
mlflow.set_experiment("/Shared/churn-prediction")

# Load data (`spark` is the SparkSession predefined in Databricks notebooks)
df = spark.table("gold.customer_features").toPandas()
X = df.drop(columns=['customer_id', 'churned'])
y = df['churned']
# Stratify so the train and test splits keep the same churn rate
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Train with tracking
with mlflow.start_run(run_name="rf_baseline"):
    # Log parameters
    params = {'n_estimators': 100, 'max_depth': 10, 'min_samples_split': 5}
    mlflow.log_params(params)

    # Train model
    model = RandomForestClassifier(**params, random_state=42)
    model.fit(X_train, y_train)

    # Evaluate (binary metrics assume `churned` uses 1 for churned customers)
    y_pred = model.predict(X_test)
    metrics = {
        'accuracy': accuracy_score(y_test, y_pred),
        'f1': f1_score(y_test, y_pred),
        'precision': precision_score(y_test, y_pred),
        'recall': recall_score(y_test, y_pred)
    }
    mlflow.log_metrics(metrics)

    # Log feature importance as a JSON artifact
    importance = dict(zip(X.columns, model.feature_importances_.tolist()))
    mlflow.log_dict(importance, "feature_importance.json")

    # Log and register the model; the input example lets MLflow infer the
    # signature. A Unity Catalog registry needs a <catalog>.<schema>.<model> name.
    mlflow.sklearn.log_model(
        model,
        artifact_path="model",
        registered_model_name="churn-predictor",
        input_example=X_test.iloc[:5]
    )

    print(f"Metrics: {metrics}")
```

Key Techniques

1. Model Registry

```python
import mlflow
from mlflow import MlflowClient

client = MlflowClient()

# Stage-based workflow for the workspace model registry. Stages are deprecated
# in recent MLflow releases and not supported in Unity Catalog (use aliases there).

# Get the newest version that has not been assigned a stage yet
latest = client.get_latest_versions("churn-predictor", stages=["None"])
version = latest[0].version
print(f"Latest version: {version}")

# Transition to staging
client.transition_model_version_stage(
    name="churn-predictor",
    version=version,
    stage="Staging"
)

# After validation, promote to production and archive the previous production version
client.transition_model_version_stage(
    name="churn-predictor",
    version=version,
    stage="Production",
    archive_existing_versions=True
)

# Load production model for inference
model = mlflow.pyfunc.load_model("models:/churn-predictor/Production")
predictions = model.predict(X_test)
```

2. Model Serving

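```python
# Create a serving endpoint via the REST API
# (assumes `workspace_url`, e.g. "https://<workspace-host>", and an access `token`)
import requests

endpoint_config = {
    "name": "churn-predictor-endpoint",
    "config": {
        "served_models": [{
            "model_name": "churn-predictor",
            "model_version": "3",
            "workload_size": "Small",
            "scale_to_zero_enabled": True
        }],
        # Served models are named "<model_name>-<model_version>" by default
        "traffic_config": {
            "routes": [{
                "served_model_name": "churn-predictor-3",
                "traffic_percentage": 100
            }]
        }
    }
}

create = requests.post(
    f"{workspace_url}/api/2.0/serving-endpoints",
    headers={"Authorization": f"Bearer {token}"},
    json=endpoint_config
)
create.raise_for_status()

# Query the endpoint once it is ready (feature names are placeholders)
response = requests.post(
    f"{workspace_url}/serving-endpoints/churn-predictor-endpoint/invocations",
    headers={"Authorization": f"Bearer {token}"},
    json={"dataframe_records": [{"feature1": 1.0, "feature2": "A", "feature3": 100}]}
)
response.raise_for_status()
predictions = response.json()
```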
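Tabular endpoints also accept a `dataframe_split` body (column names plus row values), which is convenient when the features already sit in a pandas DataFrame. A minimal sketch of a hypothetical helper, assuming the three feature columns from the Feature Store section as stand-ins for the model's real input schema:

```python
import json

import pandas as pd


def to_dataframe_split(df: pd.DataFrame) -> dict:
    """Hypothetical helper: build a `dataframe_split` request body from a DataFrame."""
    # to_json serializes NumPy scalars natively; index=False drops the row index.
    # Datetime columns would be encoded as epoch milliseconds by default.
    split = json.loads(df.to_json(orient="split", index=False))
    return {"dataframe_split": split}


batch = pd.DataFrame({
    "total_orders": [3, 12],
    "avg_order_value": [41.5, 18.0],
    "days_since_last_order": [90, 4],
})
payload = to_dataframe_split(batch)
print(json.dumps(payload))
```

The resulting dict would be passed as `json=payload` to the same `/invocations` URL; going through `to_json` avoids serialization errors from NumPy integer types that plain `json.dumps` does not accept.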

3. Hyperparameter Tuning

```python
import mlflow
from hyperopt import fmin, tpe, hp, SparkTrials, STATUS_OK
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

def objective(params):
    with mlflow.start_run(nested=True):
        mlflow.log_params(params)
        # hp.quniform yields floats, so cast to int for scikit-learn
        model = RandomForestClassifier(
            n_estimators=int(params['n_estimators']),
            max_depth=int(params['max_depth']),
            min_samples_split=int(params['min_samples_split']),
            random_state=42
        )
        # Score with cross-validation on the training split; tuning against
        # X_test would leak the test set into model selection
        f1 = cross_val_score(model, X_train, y_train, cv=3, scoring='f1').mean()
        mlflow.log_metric('f1', f1)
        # fmin minimizes the loss, so negate F1
        return {'loss': -f1, 'status': STATUS_OK}

search_space = {
    'n_estimators': hp.quniform('n_estimators', 50, 500, 50),
    'max_depth': hp.quniform('max_depth', 3, 20, 1),
    'min_samples_split': hp.quniform('min_samples_split', 2, 20, 1)
}

with mlflow.start_run(run_name="hyperopt_search"):
    best = fmin(
        fn=objective,
        space=search_space,
        algo=tpe.suggest,
        max_evals=50,
        trials=SparkTrials(parallelism=4)  # evaluates trials on Spark workers
    )
    mlflow.log_params({f"best_{k}": v for k, v in best.items()})
```
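`fmin` returns the best point as a plain dict, and because `hp.quniform` samples floats (such as `150.0`), the values need casting back to integers before refitting a final model; recent scikit-learn versions reject float values for these parameters. A minimal sketch, assuming a hypothetical `best` dict and synthetic data from `make_classification` in place of the churn split:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Hypothetical fmin() output: quniform-sampled values arrive as floats.
best = {"n_estimators": 150.0, "max_depth": 8.0, "min_samples_split": 4.0}

# Cast every tuned value to int before handing it to scikit-learn.
best_params = {name: int(value) for name, value in best.items()}

# Stand-in data; in the notebook you would refit on X_train, y_train.
X_demo, y_demo = make_classification(n_samples=200, n_features=8, random_state=42)
final_model = RandomForestClassifier(**best_params, random_state=42).fit(X_demo, y_demo)
print(best_params)
```

Refitting once with the winning parameters keeps the final model separate from the many trial models, which is the one worth logging and registering.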

4. Feature Store

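```python
# databricks.feature_store is the legacy client; for Unity Catalog, newer
# runtimes provide databricks.feature_engineering.FeatureEngineeringClient
from databricks.feature_store import FeatureStoreClient, FeatureLookup

fs = FeatureStoreClient()

# Create a timestamp-keyed feature table (customer_features_df must contain
# the customer_id and feature_timestamp columns)
fs.create_table(
    name="production.ml_features.customer_features",
    primary_keys=["customer_id"],
    timestamp_keys=["feature_timestamp"],
    df=customer_features_df,
    description="Customer features for churn prediction"
)

# Train with feature store: join features onto the label DataFrame
feature_lookups = [
    FeatureLookup(
        table_name="production.ml_features.customer_features",
        feature_names=["total_orders", "avg_order_value", "days_since_last_order"],
        lookup_key="customer_id"
        # For point-in-time lookups, also pass timestamp_lookup_key set to the
        # event-time column of training_labels_df
    )
]

training_set = fs.create_training_set(
    df=training_labels_df,
    feature_lookups=feature_lookups,
    label="churned",
    exclude_columns=["customer_id"]  # keep the join key out of the model features
)

training_df = training_set.load_df()
```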
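Because `customer_features` declares `feature_timestamp` as a timestamp key, each label row should only see feature values computed at or before the label's own time; otherwise training uses future information that serving will never have. A pandas `merge_asof` join illustrates those point-in-time semantics on toy data. This is an analogy, not how the Feature Store implements lookups, and `label_timestamp` plus all values are hypothetical:

```python
import pandas as pd

# Hypothetical feature history: customer 1 has an older and a newer snapshot.
features = pd.DataFrame({
    "customer_id": [1, 1, 2],
    "feature_timestamp": pd.to_datetime(["2024-01-01", "2024-02-01", "2024-01-15"]),
    "total_orders": [3, 5, 7],
})
# Hypothetical labels, each observed at a specific time.
labels = pd.DataFrame({
    "customer_id": [1, 2],
    "label_timestamp": pd.to_datetime(["2024-01-20", "2024-01-20"]),
    "churned": [0, 1],
})

# For each label row, take the latest feature row at or before label_timestamp.
# merge_asof requires both frames to be sorted by their time keys.
training = pd.merge_asof(
    labels.sort_values("label_timestamp"),
    features.sort_values("feature_timestamp"),
    left_on="label_timestamp",
    right_on="feature_timestamp",
    by="customer_id",
    direction="backward",
)
print(training)
```

Customer 1's label picks up the January snapshot (3 orders) and ignores the February one, which is exactly the leakage a naive "latest value" join would introduce.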

Best Practices

  • Track every experiment: Even failed ones provide valuable information
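  • Log input examples: MLflow infers the model signature from them, serving uses that signature to validate requests, and Unity Catalog requires a signature to register a model
  • Use registry stages or aliases: None -> Staging -> Production in the workspace registry; stages are deprecated in recent MLflow releases, and Unity Catalog uses aliases (for example, "champion") instead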
  • Feature Store for shared features: Avoid feature computation duplication
  • A/B test with traffic splitting: Gradually route traffic to new model versions
  • Monitor model drift: Track prediction distributions and feature distributions
  • Automate promotion: Use CI/CD to validate and promote models
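Gradual rollout means shifting `traffic_percentage` between two served versions in steps. A minimal sketch of a hypothetical helper that builds such a config with the same `served_models` and `traffic_config` fields as the endpoint example; the result would be the body of an update-config request (`PUT /api/2.0/serving-endpoints/{name}/config`), and the step sizes are illustrative:

```python
def ab_traffic_config(model_name, current_version, candidate_version, candidate_pct):
    """Hypothetical helper: route candidate_pct% of traffic to the candidate version."""
    if str(current_version) == str(candidate_version):
        raise ValueError("current and candidate versions must differ")
    if not 0 <= candidate_pct <= 100:
        raise ValueError("candidate_pct must be between 0 and 100")
    current_name = f"{model_name}-{current_version}"
    candidate_name = f"{model_name}-{candidate_version}"
    # Name each served model explicitly so the routes below can reference it.
    served_models = [
        {
            "name": name,
            "model_name": model_name,
            "model_version": str(version),
            "workload_size": "Small",
            "scale_to_zero_enabled": True,
        }
        for name, version in ((current_name, current_version), (candidate_name, candidate_version))
    ]
    routes = [
        {"served_model_name": current_name, "traffic_percentage": 100 - candidate_pct},
        {"served_model_name": candidate_name, "traffic_percentage": candidate_pct},
    ]
    return {"served_models": served_models, "traffic_config": {"routes": routes}}


# Example ramp for version 4 against version 3; check metrics between steps.
for pct in (10, 50):
    config = ab_traffic_config("churn-predictor", 3, 4, pct)
    print(pct, config["traffic_config"]["routes"])
```

Once the candidate holds up at the larger share, a final update can serve it alone and retire the previous version.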

Common Pitfalls

  • No experiment tracking: Losing track of which parameters produced which results
  • Manual model deployment: Copy-pasting model artifacts instead of using registry
  • Training-serving skew: Features computed differently in training vs serving
  • No model monitoring: Model degrades silently without drift detection
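One common way to put a number on drift is the population stability index (PSI), which compares the binned distribution of a feature or prediction score in production against its training distribution; this is a swapped-in technique, not something MLflow prescribes. A minimal NumPy sketch using quantile bins from the reference sample; the often-quoted 0.1 and 0.25 alert levels are rules of thumb rather than standards:

```python
import numpy as np


def population_stability_index(reference, current, bins=10, eps=1e-6):
    """PSI between a reference (e.g. training) sample and a current sample."""
    reference = np.asarray(reference, dtype=float)
    current = np.asarray(current, dtype=float)
    # Interior bin edges at the reference quantiles; the outer bins are open-ended.
    inner_edges = np.quantile(reference, np.linspace(0, 1, bins + 1)[1:-1])
    ref_counts = np.bincount(np.searchsorted(inner_edges, reference, side="right"), minlength=bins)
    cur_counts = np.bincount(np.searchsorted(inner_edges, current, side="right"), minlength=bins)
    # Convert counts to proportions; clip so empty bins do not produce log(0).
    ref_frac = np.clip(ref_counts / ref_counts.sum(), eps, None)
    cur_frac = np.clip(cur_counts / cur_counts.sum(), eps, None)
    return float(np.sum((cur_frac - ref_frac) * np.log(cur_frac / ref_frac)))


rng = np.random.default_rng(42)
training_scores = rng.normal(0.0, 1.0, 10_000)
psi_same = population_stability_index(training_scores, rng.normal(0.0, 1.0, 10_000))
psi_shifted = population_stability_index(training_scores, rng.normal(1.0, 1.0, 10_000))
print(f"same distribution: {psi_same:.3f}, shifted by one std: {psi_shifted:.3f}")
```

Computed on a schedule per feature and for the prediction score, and logged as metrics, this turns "degrades silently" into a value that can trigger an alert.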

Anti-Patterns

  • Notebook as Pipeline: Training, evaluation, and deployment in one notebook. Use separate stages.
  • No Versioning: Overwriting models in-place instead of versioning in registry.
  • Feature Spaghetti: Computing features inline in every notebook instead of using Feature Store.
  • Deploying Without Validation: Pushing to production without comparing against current model.
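A validation gate before promotion can be as simple as comparing the candidate's evaluation metrics with the current champion's on the same test set. A minimal sketch of a hypothetical gate function; in an MLflow pipeline the metric dicts could come from `MlflowClient().get_run(run_id).data.metrics`, and the thresholds below are illustrative:

```python
def should_promote(candidate, champion, primary="f1", min_gain=0.0, max_regression=None):
    """Hypothetical gate: promote only if the candidate beats the champion.

    candidate, champion: dicts of metric name -> value from the same test set.
    min_gain: required improvement on the primary metric.
    max_regression: largest allowed drop on any other shared metric (None = skip).
    """
    if champion is None:
        return True  # nothing deployed yet, so there is no baseline to beat
    if candidate[primary] < champion[primary] + min_gain:
        return False
    if max_regression is not None:
        for name, champ_value in champion.items():
            if name != primary and name in candidate:
                if champ_value - candidate[name] > max_regression:
                    return False
    return True


champion_metrics = {"f1": 0.72, "recall": 0.71}
candidate_metrics = {"f1": 0.75, "recall": 0.70}
print(should_promote(candidate_metrics, champion_metrics, min_gain=0.01, max_regression=0.02))
```

Running a check like this in the CI/CD job that moves a stage or alias keeps "automate promotion" from becoming "automate unvalidated deployment".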
