PyTorch
Training
"""
Example taken from https://github.com/mlflow/mlflow/blob/master/mlflow/pytorch/__init__.py
"""
import torch
import mlflow
import mlflow.pytorch
# Toy training set for a one-variable linear fit.
# Inputs: a column of three samples, one feature each.
x_data = torch.Tensor([[value] for value in (1.0, 2.0, 3.0)])
# Labels: the expected output for each input row (twice the input).
y_data = torch.Tensor([[label] for label in (2.0, 4.0, 6.0)])
# Partial Model example modified from Sung Kim
# https://github.com/hunkim/PyTorchZeroToAll
class Model(torch.nn.Module):
    """A single fully connected layer: ``y = x @ W.T + b``.

    With the defaults (one input feature, one output feature) this is the
    one-variable linear regression trained in this example.

    Args:
        in_features: Number of input features per sample (default 1).
        out_features: Number of outputs per sample (default 1).
    """

    def __init__(self, in_features: int = 1, out_features: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer.

        ``x`` has shape ``(..., in_features)``; the result has shape
        ``(..., out_features)``.
        """
        y_pred = self.linear(x)
        return y_pred
# Number of passes over the training data. It is also logged to MLflow below,
# so the recorded parameter always matches what the loop actually ran.
EPOCHS = 500

# our model
model = Model()
# reduction="sum" replaces the deprecated size_average=False. Like the old
# argument, it sums the squared errors over the batch instead of averaging.
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Training loop
for epoch in range(EPOCHS):
    # Forward pass: compute predicted y by passing x to the model
    y_pred = model(x_data)
    # Compute loss
    loss = criterion(y_pred, y_data)
    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# After training: report predictions for a few inputs outside the training
# set. no_grad() skips building an autograd graph, which inference doesn't need.
with torch.no_grad():
    for hv in [4.0, 5.0, 6.0]:
        hour_var = torch.Tensor([[hv]])
        y_pred = model(hour_var)
        print("predict (after training)", hv, y_pred.item())

mlflow.set_tracking_uri("http://localhost:5000")
experiment_name = "test_pytorch"
# Look up (or create) the experiment and use its real id. The tracking server
# assigns ids, so a hard-coded id can point at the wrong experiment, or at
# one that doesn't exist.
experiment = mlflow.get_experiment_by_name(experiment_name)
if experiment is None:
    experiment_id = mlflow.create_experiment(experiment_name)
else:
    experiment_id = experiment.experiment_id

# log the model
with mlflow.start_run(experiment_id=experiment_id) as run:
    mlflow.log_param("epochs", EPOCHS)
    mlflow.pytorch.log_model(
        model, "model", registered_model_name="pytorch_model"
    )
To run it (the script logs to an MLflow tracking server that must be listening on http://localhost:5000):
python3 -m examples.training.pytorch
Serving
from serveml.api import ApiBuilder
from serveml.inputs import BasicInput
from serveml.loader import load_mlflow_model
from serveml.predictions import GenericPrediction
# MLflow model-registry URI: the registered model "pytorch_model", version 1.
MODEL_URI = "models:/pytorch_model/1"
# MLflow tracking server that the registry URI is resolved against.
TRACKING_URI = "http://localhost:5000"

# load model
model = load_mlflow_model(MODEL_URI, TRACKING_URI)
# Implement deserializer for input data.
# Passed to ApiBuilder below as the input type for the prediction API.
class LinearRegression(BasicInput):
    # The single float feature supplied by the caller; presumably the value
    # forwarded to the trained one-feature model (TODO confirm in serveml).
    input_prediction: float
# implement application: wrap the loaded model in serveml's generic predictor
# and pair it with the LinearRegression input type. The resulting `app` is
# what `uvicorn examples.serving.pytorch:app` serves.
_api_builder = ApiBuilder(GenericPrediction(model), LinearRegression)
app = _api_builder.build_api()
To run it:
uvicorn examples.serving.pytorch:app --host 0.0.0.0