---
title: "Prediction pipeline"
description: "Load the trained model and build an end-to-end prediction function."
---

# Prediction pipeline

You have a trained model and a system for fast feature retrieval. Now you’ll connect them into a single function that goes from a driver ID to a decline risk prediction. The goal is a minimal serving path: one key lookup to Aerospike, one model call, one prediction result.

## Load the trained model

1.  Run `Cell 22` to load the Logistic Regression model you saved in Part 2.

Cell 22: Load the trained model for serving

```python
from pyspark.ml.classification import LogisticRegressionModel

lr_model = LogisticRegressionModel.load("./models/trip_decline_risk_lr")

print("Loaded trip decline risk model")
```

Expected output

```plaintext
Loaded trip decline risk model
```
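If `Cell 22` fails because the path is wrong, a quick existence check gives a clearer error. This can happen, for example, when the notebook's working directory has changed since Part 2. The sketch below assumes the standard Spark ML layout, in which a saved model is a directory containing `metadata` and `data` subdirectories. `model_dir_looks_valid` is a hypothetical helper, not part of PySpark.

```python
from pathlib import Path


def model_dir_looks_valid(model_path: str) -> bool:
    """Hypothetical helper: True if model_path contains the metadata/ and
    data/ subdirectories that a Spark ML model save normally writes."""
    root = Path(model_path)
    return (root / "metadata").is_dir() and (root / "data").is_dir()


# Usage: check the path before calling LogisticRegressionModel.load(...)
path = "./models/trip_decline_risk_lr"
if not model_dir_looks_valid(path):
    print(f"No saved model found at {Path(path).resolve()}; re-run the save step from Part 2.")
```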

## Build the prediction function

The `predict_decline_risk` function encapsulates the full serving flow:

1.  Retrieve the driver’s features from Aerospike using `get_feature_vector()`.
2.  Assemble the features into the vector format the model expects.
3.  Run the model to get a prediction and probability.

Keeping this path short matters. If you add extra lookups or broad queries per request, retrieval latency can become the bottleneck before model inference even starts.
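The three steps above can also be sketched without Spark or Aerospike, which makes it easier to see what each stage contributes. This sketch rests on stated assumptions. A plain dict (`FEATURE_STORE`) stands in for the Aerospike `driver` records, and the hypothetical `WEIGHTS` and `BIAS` stand in for the trained model. None of these values come from the tutorial's model, and `predict_sketch` is an illustrative name.

```python
import math

FEATURE_COLUMNS = ["ds_decl_rate", "ds_avg_rating", "da_trips_today"]

# Stand-in for Aerospike: one record per driver key (hypothetical values)
FEATURE_STORE = {
    "driver_005": {"ds_decl_rate": 0.05, "ds_avg_rating": 4.9, "da_trips_today": 6},
    "driver_087": {"ds_decl_rate": 0.60, "ds_avg_rating": 3.8, "da_trips_today": 2},
}

# Hypothetical model parameters, one weight per feature in FEATURE_COLUMNS order
WEIGHTS = [8.0, -2.0, 0.05]
BIAS = 6.0


def predict_sketch(driver_id):
    # Step 1: a single key lookup; a missing key means no prediction
    record = FEATURE_STORE.get(driver_id)
    if record is None:
        return {"error": f"Driver {driver_id} not found"}

    # Step 2: assemble the vector in the same column order used for training
    vector = [float(record[c]) for c in FEATURE_COLUMNS]

    # Step 3: score with the logistic function and label with a 0.5 cutoff
    margin = BIAS + sum(w * x for w, x in zip(WEIGHTS, vector))
    p_decline = 1.0 / (1.0 + math.exp(-margin))
    return {
        "driver_id": driver_id,
        "decline_probability": round(p_decline, 4),
        "prediction": "higher risk" if p_decline > 0.5 else "typical risk",
    }


print(predict_sketch("driver_087"))
```

The column order in `FEATURE_COLUMNS` has to match the training order. A vector assembled in a different order would silently apply each weight to the wrong feature.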

1.  Run `Cell 23` to define the `predict_decline_risk` function.

Cell 23: Define predict\_decline\_risk function

```python
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.types import StructType, StructField, DoubleType

# `spark`, `as_client`, and `Entity` come from earlier cells; `lr_model` from Cell 22.

def predict_decline_risk(driver_id):
    # Must match the column order the model was trained on
    feature_columns = ["ds_decl_rate", "ds_avg_rating", "da_trips_today"]

    # One key lookup to Aerospike for this driver's features
    features = Entity.get_feature_vector(as_client, "driver", driver_id, feature_columns)
    if features is None:
        return {"error": f"Driver {driver_id} not found"}

    # Every feature is declared DoubleType; cast to float so each value matches the schema
    schema = StructType([
        StructField("ds_decl_rate", DoubleType(), True),
        StructField("ds_avg_rating", DoubleType(), True),
        StructField("da_trips_today", DoubleType(), True),
    ])
    feature_row = [(
        float(features["ds_decl_rate"]),
        float(features["ds_avg_rating"]),
        float(features["da_trips_today"]),
    )]
    feature_df = spark.createDataFrame(feature_row, schema)

    # Pack the columns into the single "features" vector column the model reads
    assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
    vector_df = assembler.transform(feature_df)

    # Score the single row and bring the result back to the driver
    prediction_df = lr_model.transform(vector_df)
    result = prediction_df.select("probability", "prediction").collect()[0]

    return {
        "driver_id": driver_id,
        # probability is [P(class_0), P(class_1)]; index 1 is the decline probability
        "decline_probability": round(float(result["probability"][1]), 4),
        "prediction": "higher risk" if result["prediction"] == 1.0 else "typical risk",
    }
```

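The function casts every feature value to `float` and declares every column as `DoubleType`, including the trip count `da_trips_today`. The cast matters because `spark.createDataFrame` verifies rows against the schema by default and rejects a Python `int` in a `DoubleType` field. The `probability` column from Spark’s logistic regression model is a vector of `[P(class_0), P(class_1)]`, and index 1 is the decline probability.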
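For a binary logistic regression model, Spark first computes a margin, `intercept + coefficients · features`. It passes the margin through the logistic function to get `P(class_1)` and fills the two-element probability vector as `[1 - P(class_1), P(class_1)]`. The `prediction` column is 1.0 only when `P(class_1)` is strictly greater than the model's `threshold` parameter, which defaults to 0.5. The plain-Python sketch below reproduces that rule. The margins are hypothetical, and the helper names are illustrative rather than Spark APIs.

```python
import math


def probability_vector(margin):
    """[P(class_0), P(class_1)] for a binary logistic model with raw margin `margin`."""
    p1 = 1.0 / (1.0 + math.exp(-margin))
    return [1.0 - p1, p1]


def predicted_label(probability, threshold=0.5):
    """1.0 when P(class_1) exceeds the threshold, else 0.0 (mirrors the binary rule)."""
    return 1.0 if probability[1] > threshold else 0.0


for margin in (-3.0, 0.0, 2.0):  # hypothetical margins
    prob = probability_vector(margin)
    print(margin, [round(p, 4) for p in prob], predicted_label(prob))
```

A margin of exactly 0 gives `[0.5, 0.5]`. Because the comparison is strict, that margin yields a prediction of 0.0, which is why index 1 (not the label alone) is the more informative value for ranking.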

## Test with known drivers

1.  Run `Cell 24` to generate predictions for three test drivers.

Cell 24: Run predictions for test drivers

```python
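test_drivers = ["driver_005", "driver_042", "driver_087"]

for driver_id in test_drivers:
    result = predict_decline_risk(driver_id)
    if "error" in result:
        # Driver missing from Aerospike: report it instead of raising KeyError below
        print(result["error"])
        continue
    print(f"{result['driver_id']}: {result['prediction']} "
          f"(decline_prob={result['decline_probability']})")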
```

Expected output

```plaintext
driver_005: typical risk (decline_prob=0.0312)
driver_042: typical risk (decline_prob=0.0876)
driver_087: higher risk (decline_prob=0.8934)
```

The model predicts low decline probability for drivers with low historical decline rates and high ratings, and higher probability for drivers with elevated rates and lower ratings. This matches the patterns in the training data.
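You can check that trend against the model itself by looking at the sign of each learned weight. In a logistic regression, a positive weight raises the decline probability as the feature grows. `exp(weight)` is the factor by which the odds of declining change per one-unit increase in that feature, holding the others fixed. In the notebook, the weights are in `lr_model.coefficients` (in the same order as the assembler's `inputCols`) and the bias is in `lr_model.intercept`. The sketch below uses hypothetical weights so it runs without Spark, and `describe_weights` is an illustrative helper.

```python
import math

FEATURE_COLUMNS = ["ds_decl_rate", "ds_avg_rating", "da_trips_today"]
# Hypothetical weights; in the notebook, use list(lr_model.coefficients)
weights = [8.0, -2.0, 0.05]


def describe_weights(columns, coefficients):
    """Return (feature, direction, odds multiplier per +1 unit) for each weight."""
    rows = []
    for name, w in zip(columns, coefficients):
        if w > 0:
            direction = "raises risk"
        elif w < 0:
            direction = "lowers risk"
        else:
            direction = "no effect"
        rows.append((name, direction, round(math.exp(w), 3)))
    return rows


for name, direction, odds in describe_weights(FEATURE_COLUMNS, weights):
    print(f"{name:<16} {direction:<12} odds x{odds} per +1 unit")
```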

## Smart dispatch in action

1.  Run `Cell 25` to rank driver candidates by decline probability.

Cell 25: Rank driver candidates by decline probability

```python
import random

random.seed(99)

# Draw 10 candidate IDs, then drop duplicates
candidate_drivers = [f"driver_{random.randint(1, 100):03d}" for _ in range(10)]
candidate_drivers = list(set(candidate_drivers))

results = []
for driver_id in candidate_drivers:
    result = predict_decline_risk(driver_id)
    if "error" in result:
        continue  # skip drivers with no features in Aerospike
    results.append(result)

if not results:
    raise ValueError("No valid candidate drivers were found. Re-run the candidate generation cell.")

# Lowest decline probability first
results.sort(key=lambda r: r["decline_probability"])

print("Smart dispatch ranking (lowest decline risk first):")
print(f"{'Driver':<14} {'Prediction':<15} {'Decline Prob'}")
print("-" * 44)
for r in results:
    print(f"{r['driver_id']:<14} {r['prediction']:<15} {r['decline_probability']}")
```

Expected output

```plaintext
Smart dispatch ranking (lowest decline risk first):
Driver         Prediction      Decline Prob
--------------------------------------------
driver_051     typical risk    0.0156
driver_073     typical risk    0.0234
driver_029     typical risk    0.0298
driver_014     typical risk    0.0445
driver_068     typical risk    0.0612
driver_091     typical risk    0.0789
driver_036     higher risk     0.7823
driver_017     higher risk     0.8456
```

The dispatch system would assign the ride request to `driver_051`, the candidate with the lowest predicted decline risk. Your ranked drivers and probabilities may differ from this example.
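The assignment step itself reduces to taking the minimum over the valid results. The sketch below assumes results shaped like the dictionaries `predict_decline_risk` returns. `pick_driver` is a hypothetical helper, not part of the tutorial code, and the sample entries reuse values from the example ranking.

```python
def pick_driver(results):
    """Hypothetical helper: return the driver_id with the lowest decline
    probability, ignoring entries that carry an "error" key. None if empty."""
    valid = [r for r in results if "error" not in r]
    if not valid:
        return None
    return min(valid, key=lambda r: r["decline_probability"])["driver_id"]


sample = [
    {"driver_id": "driver_036", "prediction": "higher risk", "decline_probability": 0.7823},
    {"driver_id": "driver_051", "prediction": "typical risk", "decline_probability": 0.0156},
    {"error": "Driver driver_999 not found"},
]
print(pick_driver(sample))  # driver_051
```

When only the top candidate matters, `min` avoids sorting the whole list. Keep the full sort when you want the ranking as a fallback order in case the first driver declines.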

## Where this goes next

You now have a working serving path from `driver_id` to decline probability.

This tutorial still uses 100 drivers and 3 features. In the next two sections, you’ll first look at how this pattern generalizes, then scale both dimensions and measure whether retrieval latency stays sub-millisecond.
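When you measure retrieval latency at scale, an average can hide slow outliers, so percentiles such as p50 and p99 are more informative. The sketch below assumes `lookup` is any zero-argument callable that performs one feature retrieval, for example `lambda: Entity.get_feature_vector(as_client, "driver", "driver_005", ["ds_decl_rate", "ds_avg_rating", "da_trips_today"])` in the notebook. `measure_latency_ms` is a hypothetical helper. The dict lookup in the usage line only exercises the code and says nothing about Aerospike's latency.

```python
import statistics
import time


def measure_latency_ms(lookup, runs=1000):
    """Hypothetical helper: time `runs` calls of `lookup` and return
    p50, p99, and max latency in milliseconds."""
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        lookup()
        samples.append((time.perf_counter() - start) * 1000.0)
    # quantiles(n=100) returns the 99 cut points p1..p99
    cuts = statistics.quantiles(samples, n=100)
    return {"p50": cuts[49], "p99": cuts[98], "max": max(samples)}


# Usage with a dict stand-in, only to show the call shape
store = {"driver_005": {"ds_decl_rate": 0.05}}
stats = measure_latency_ms(lambda: store.get("driver_005"))
print({k: round(v, 4) for k, v in stats.items()})
```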

Before moving on, check that:

-   You can load the trained model and run predictions.
-   You have a working `predict_decline_risk()` function.
