Prophet

Training

"""
Example taken from https://github.com/mlflow/mlflow/blob/master/examples/prophet/train.py
"""
import warnings
import sys

import pandas as pd
import numpy as np

import mlflow
import mlflow.pyfunc

import cloudpickle

import fbprophet
from fbprophet import Prophet
from fbprophet.diagnostics import cross_validation
from fbprophet.diagnostics import performance_metrics

import logging

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)


class FbProphetWrapper(mlflow.pyfunc.PythonModel):
    def __init__(self, model):
        self.model = model
        super(FbProphetWrapper, self).__init__()

    def load_context(self, context):
        from fbprophet import Prophet

        return

    def predict(self, context, model_input):
        future = self.model.make_future_dataframe(
            periods=model_input["periods"][0]
        )
        return self.model.predict(future)


conda_env = {
    "channels": ["defaults", "conda-forge"],
    "dependencies": [
        "fbprophet={}".format(fbprophet.__version__),
        "cloudpickle={}".format(cloudpickle.__version__),
    ],
    "name": "fbp_env",
}


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    np.random.seed(40)

    csv_url = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "https://raw.githubusercontent.com/facebook/prophet/e21a05f4f9290649255a2a306855e8b4620816d7/examples/example_wp_log_peyton_manning.csv"
    )
    rolling_window = float(sys.argv[2]) if len(sys.argv) > 2 else 0.1

    # Read the csv file from the URL
    try:
        df = pd.read_csv(csv_url)
    except Exception as e:
        logger.exception(
            "Unable to download training & test CSV, check your internet connection. Error: %s",
            e,
        )

    mlflow.set_tracking_uri("http://localhost:5000")

    experiment_name = "test_prophet"
    if mlflow.get_experiment_by_name(experiment_name) is None:
        mlflow.create_experiment(experiment_name)

    # Useful for multiple runs (only doing one run in this sample notebook)
    with mlflow.start_run(experiment_id=6):
        m = Prophet()
        m.fit(df)

        # Evaluate Metrics
        df_cv = cross_validation(
            m, initial="730 days", period="180 days", horizon="365 days"
        )
        df_p = performance_metrics(df_cv, rolling_window=rolling_window)

        # Print out metrics
        print("Prophet model (rolling_window=%f):" % (rolling_window))
        print("  CV: \n%s" % df_cv.head())
        print("  Perf: \n%s" % df_p.head())

        # Log parameter, metrics, and model to MLflow
        mlflow.log_param("rolling_window", rolling_window)
        mlflow.log_metric("rmse", df_p.loc[0, "rmse"])

        mlflow.pyfunc.log_model(
            "model",
            conda_env=conda_env,
            python_model=FbProphetWrapper(m),
            registered_model_name="prophet_model",
        )
        print(
            "Logged model with URI: runs:/{run_id}/model".format(
                run_id=mlflow.active_run().info.run_id
            )
        )

To run it :

python3 -m examples.training.prophet

Serving

from serveml.api import ApiBuilder
from serveml.inputs import BasicInput
from serveml.loader import load_mlflow_model
from serveml.predictions import GenericPrediction

# load model
model = load_mlflow_model(
    # MlFlow model path
    "models:/prophet_model/1",
    # MlFlow Tracking URI
    "http://localhost:5000",
)


# Implement deserializer for input data
class PeriodPrediction(BasicInput):
    periods: int


# implement application
app = ApiBuilder(GenericPrediction(model), PeriodPrediction).build_api()

To run it :

uvicorn examples.serving.prophet:app --host 0.0.0.0