Custom PyFunc Models

Skill: databricks-model-serving

You can package arbitrary Python logic — preprocessing pipelines, ensemble models, external API wrappers, anything — into a single deployable unit that Databricks Model Serving knows how to load and run. The PythonModel base class gives you two hooks: load_context for initialization and predict for inference. Everything else is your code.

“Create a custom PyFunc model that loads a pickled preprocessor and sklearn model, runs preprocessing on input, and returns predictions. Use Python.”

import mlflow
import pandas as pd

class MyCustomModel(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        """Load artifacts when the model is loaded on the serving endpoint."""
        import pickle

        with open(context.artifacts["preprocessor"], "rb") as f:
            self.preprocessor = pickle.load(f)
        with open(context.artifacts["model"], "rb") as f:
            self.model = pickle.load(f)

    def predict(self, context, model_input: pd.DataFrame) -> pd.DataFrame:
        """Run prediction with preprocessing."""
        processed = self.preprocessor.transform(model_input)
        predictions = self.model.predict(processed)
        return pd.DataFrame({"prediction": predictions})

with mlflow.start_run():
    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=MyCustomModel(),
        artifacts={
            "preprocessor": "artifacts/preprocessor.pkl",
            "model": "artifacts/model.pkl",
        },
        pip_requirements=["scikit-learn==1.3.0", "pandas==2.0.0"],
        registered_model_name="main.models.custom_model",
    )

Key decisions:

  • load_context runs once when the endpoint starts — put expensive initialization here, not in predict
  • artifacts dict maps logical names to local file paths. MLflow uploads them alongside the model and makes them available via context.artifacts at serving time
  • Pin exact package versions in pip_requirements to avoid dependency resolution failures on the serving endpoint
  • Return a DataFrame from predict — the serving infrastructure serializes it to JSON for the response
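The load-once / predict-per-request lifecycle can be simulated locally without MLflow at all. A stdlib-only sketch: `SimpleNamespace` stands in for the real `PythonModelContext`, a made-up `Doubler` class stands in for the pickled sklearn model, and plain lists stand in for DataFrames.

```python
import pickle
import tempfile
from types import SimpleNamespace

class Doubler:
    """Stand-in 'model' so there is something to pickle."""
    def predict(self, xs):
        return [x * 2 for x in xs]

class ModelSketch:
    # Same two hooks as mlflow.pyfunc.PythonModel, minus the base class.
    def load_context(self, context):
        with open(context.artifacts["model"], "rb") as f:
            self.model = pickle.load(f)

    def predict(self, context, model_input):
        return self.model.predict(model_input)

# Write a pickled artifact, as MLflow would materialize one at load time.
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
    pickle.dump(Doubler(), f)
    path = f.name

ctx = SimpleNamespace(artifacts={"model": path})  # stub context
m = ModelSketch()
m.load_context(ctx)                # runs once at endpoint startup
print(m.predict(ctx, [1, 2, 3]))   # runs on every request
```

The serving runtime follows the same sequence: download artifacts, call load_context once, then call predict per request.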

“Define input and output schemas for my custom model so the serving endpoint validates requests. Use Python.”

from mlflow.models import infer_signature, ModelSignature
from mlflow.types.schema import Schema, ColSpec

# Option 1: Infer from sample data
signature = infer_signature(
    model_input=X_sample,
    model_output=predictions_sample,
)

# Option 2: Define explicitly
input_schema = Schema([
    ColSpec("double", "age"),
    ColSpec("double", "income"),
    ColSpec("string", "category"),
])
output_schema = Schema([
    ColSpec("double", "probability"),
    ColSpec("string", "class"),
])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)

mlflow.pyfunc.log_model(
    artifact_path="model",
    python_model=MyModel(),
    signature=signature,
    input_example={"age": 25, "income": 50000, "category": "A"},
    registered_model_name="main.models.my_model",
)

Explicit signatures catch malformed requests before they reach your predict method. The input_example also documents expected payload format in the MLflow UI.
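A request that satisfies this signature can be built with the stdlib alone. MLflow serving endpoints accept tabular input as either `dataframe_records` or `dataframe_split` JSON; the workspace URL in the final comment is a placeholder.

```python
import json

# Rows matching the explicit input schema above: age, income, category.
rows = [{"age": 25.0, "income": 50000.0, "category": "A"}]

# Two equivalent payload shapes for the same data:
records_payload = {"dataframe_records": rows}
split_payload = {
    "dataframe_split": {
        "columns": ["age", "income", "category"],
        "data": [[25.0, 50000.0, "A"]],
    }
}

print(json.dumps(records_payload))
# POST either payload to https://<workspace>/serving-endpoints/<name>/invocations
```

If a request omits a declared column or sends the wrong type, the endpoint rejects it before your predict method runs.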

“Log my custom model using the file-based pattern from MLflow 3 so it’s compatible with the agent deployment workflow. Use Python.”

my_model.py

import mlflow
from mlflow.pyfunc import PythonModel

class MyModel(PythonModel):
    def predict(self, context, model_input):
        return model_input * 2

mlflow.models.set_model(MyModel())

log_model.py

import mlflow

mlflow.set_registry_uri("databricks-uc")

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        name="my-model",
        python_model="my_model.py",  # File path, not instance
        pip_requirements=["mlflow>=3.0"],
        registered_model_name="main.models.my_model",
    )

The file-based pattern (python_model="my_model.py") is MLflow 3’s preferred approach. It lets MLflow capture the entire module, including imports and helpers, rather than pickling a single class instance.

“Log a model that depends on utility modules I wrote. Use Python.”

mlflow.pyfunc.log_model(
    artifact_path="model",
    python_model=MyModel(),
    code_paths=["src/utils.py", "src/preprocessing.py"],
    pip_requirements=["scikit-learn==1.3.0"],
    registered_model_name="main.models.my_model",
)

Files in code_paths get added to the Python path when the model loads on the serving endpoint. Use this for helper modules that your model imports but that are not installable packages.
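The effect of code_paths can be approximated locally: the listed files travel with the model, and their directory ends up on the import path before the model module loads. A stdlib-only simulation (the module name utils_demo and its clean helper are made up for illustration):

```python
import pathlib
import sys
import tempfile

# Write a throwaway helper module, standing in for src/utils.py.
tmp = tempfile.mkdtemp()
pathlib.Path(tmp, "utils_demo.py").write_text(
    "def clean(s):\n    return s.strip().lower()\n"
)

# code_paths works similarly: the files are copied alongside the model
# and made importable before the model's own code runs.
sys.path.insert(0, tmp)
from utils_demo import clean

print(clean("  Hello  "))
```

If you forget the code_paths entry, this import succeeds on your machine (where src/ exists) and fails on the endpoint.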

“Test my logged model locally before deploying to a serving endpoint. Use Python.”

# Load from the run and test
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
test_input = pd.DataFrame({"age": [25], "income": [50000]})
result = loaded_model.predict(test_input)
print(result)

# Full environment validation (installs deps in isolation)
mlflow.models.predict(
    model_uri=model_info.model_uri,
    input_data={"age": 25, "income": 50000},
    env_manager="uv",
)

mlflow.models.predict with env_manager="uv" creates an isolated environment matching your pip_requirements, then runs prediction. This catches missing dependencies before you wait 15 minutes for a serving endpoint to fail.

  • Loose dependency versions — pip_requirements=["scikit-learn", "pandas"] invites resolution failures on the serving endpoint. Pin exact versions: "scikit-learn==1.3.0", "pandas==2.0.0".
  • Heavy work in predict — loading files, importing large modules, or initializing clients belongs in load_context. The predict method runs on every request; load_context runs once at startup.
  • Missing input_example — without it, the MLflow UI cannot show the expected payload format, and MLflow cannot infer a signature on your behalf if you omit that too. Always pass one.
  • Forgetting code_paths — if your model imports from local modules, those files will not be on the serving endpoint unless you list them in code_paths. You will get ModuleNotFoundError at query time, not at logging time.
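One way to avoid the loose-versions pitfall is to derive pins from the environment where the model was trained. A sketch using only the stdlib: pinned is a hypothetical helper, not an MLflow API.

```python
from importlib.metadata import version, PackageNotFoundError

def pinned(*packages):
    """Return exact-pin requirement strings for locally installed packages."""
    reqs = []
    for pkg in packages:
        try:
            reqs.append(f"{pkg}=={version(pkg)}")
        except PackageNotFoundError:
            print(f"warning: {pkg} not installed locally; pin it manually")
    return reqs

# Pass the result straight to log_model's pip_requirements argument.
pip_requirements = pinned("scikit-learn", "pandas")
print(pip_requirements)
```

Because the pins come from the environment that actually produced the model, the serving endpoint resolves the same versions you tested against.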