Prediction pipeline

You have a trained model and a system for fast feature retrieval. Now you’ll connect them into a single function that goes from a driver ID to a decline risk prediction. The goal is a minimal serving path: one key lookup to Aerospike, one model call, one prediction result.

Load the trained model

  1. Run Cell 22 to load the Logistic Regression model you saved in Part 2.

Cell 22: Load the trained model for serving

from pyspark.ml.classification import LogisticRegressionModel
lr_model = LogisticRegressionModel.load("./models/trip_decline_risk_lr")
print("Loaded trip decline risk model")
Expected output
Loaded trip decline risk model

Build the prediction function

The predict_decline_risk function encapsulates the full serving flow:

  1. Retrieve the driver’s features from Aerospike using get_feature_vector().
  2. Assemble the features into the vector format the model expects.
  3. Run the model to get a prediction and probability.
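The three steps can be sketched in plain Python, without Spark or Aerospike, to make the data flow explicit. This is a minimal sketch, not the tutorial's implementation: `FEATURE_STORE`, `WEIGHTS`, `INTERCEPT`, and `predict` are hypothetical stand-ins for the Aerospike records and the trained model's parameters. The 0.5 threshold mirrors Spark's default for binary logistic regression.

```python
import math

# Hypothetical in-memory stand-in for the Aerospike "driver" records.
FEATURE_STORE = {
    "driver_005": {"ds_decl_rate": 0.05, "ds_avg_rating": 4.9, "da_trips_today": 7},
    "driver_087": {"ds_decl_rate": 0.60, "ds_avg_rating": 3.8, "da_trips_today": 2},
}
FEATURE_COLUMNS = ["ds_decl_rate", "ds_avg_rating", "da_trips_today"]

# Hypothetical model parameters (not the trained model's real values).
WEIGHTS = [8.0, -2.0, -0.1]
INTERCEPT = 6.0


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict(driver_id, threshold=0.5):
    # Step 1: one key lookup (a dict get stands in for Aerospike).
    features = FEATURE_STORE.get(driver_id)
    if features is None:
        return {"error": f"Driver {driver_id} not found"}
    # Step 2: assemble values in the fixed column order the model was trained on.
    x = [float(features[c]) for c in FEATURE_COLUMNS]
    # Step 3: linear score -> probability of decline (class 1).
    p_decline = sigmoid(INTERCEPT + sum(w * v for w, v in zip(WEIGHTS, x)))
    return {
        "driver_id": driver_id,
        "decline_probability": round(p_decline, 4),
        "prediction": "higher risk" if p_decline > threshold else "typical risk",
    }


print(predict("driver_005"))
print(predict("driver_999"))
```

The real function follows the same shape; only the lookup and the scoring call are swapped for Aerospike and Spark.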

Keeping this path short matters. If you add extra lookups or broad queries per request, retrieval latency can become the bottleneck before model inference even starts.
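One way to check which stage dominates is to time the lookup and the model call separately for each request and compare percentiles. A minimal sketch, assuming hypothetical `fake_lookup` and `fake_infer` stand-ins; in the notebook you would pass in a wrapper around `Entity.get_feature_vector` and the Spark model call instead.

```python
import statistics
import time


def fake_lookup(driver_id):
    """Hypothetical stand-in for one single-key Aerospike read."""
    return {"ds_decl_rate": 0.1, "ds_avg_rating": 4.7, "da_trips_today": 5}


def fake_infer(features):
    """Hypothetical stand-in for the model call."""
    return sum(float(v) for v in features.values())


def time_stages(driver_ids, lookup=fake_lookup, infer=fake_infer):
    """Time the lookup and inference stages separately; results in microseconds."""
    if len(driver_ids) < 2:
        raise ValueError("need at least two requests to compute percentiles")
    lookup_us, infer_us = [], []
    for driver_id in driver_ids:
        t0 = time.perf_counter()
        features = lookup(driver_id)
        t1 = time.perf_counter()
        infer(features)
        t2 = time.perf_counter()
        lookup_us.append((t1 - t0) * 1e6)
        infer_us.append((t2 - t1) * 1e6)
    # quantiles(n=100) returns 99 cut points; index 98 is the 99th percentile.
    return {
        "lookup_p50_us": statistics.median(lookup_us),
        "lookup_p99_us": statistics.quantiles(lookup_us, n=100)[98],
        "infer_p50_us": statistics.median(infer_us),
        "infer_p99_us": statistics.quantiles(infer_us, n=100)[98],
    }


stats = time_stages([f"driver_{i:03d}" for i in range(1, 101)])
for name, value in stats.items():
    print(f"{name}: {value:.1f}")
```

Tracking p99 alongside the median matters for serving, because a few slow lookups can dominate tail latency even when the typical request is fast.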

  1. Run Cell 23 to define the predict_decline_risk function.

Cell 23: Define predict_decline_risk function

from pyspark.ml.feature import VectorAssembler
from pyspark.sql.types import StructType, StructField, DoubleType

# Uses Entity, as_client, and spark from earlier parts of the tutorial and lr_model from Cell 22.
def predict_decline_risk(driver_id):
    feature_columns = ["ds_decl_rate", "ds_avg_rating", "da_trips_today"]
    # 1. One key lookup to Aerospike for this driver's features.
    features = Entity.get_feature_vector(as_client, "driver", driver_id, feature_columns)
    if features is None:
        return {"error": f"Driver {driver_id} not found"}
    # 2. Build a one-row DataFrame; every value is cast to float to match the DoubleType schema.
    schema = StructType([
        StructField("ds_decl_rate", DoubleType(), True),
        StructField("ds_avg_rating", DoubleType(), True),
        StructField("da_trips_today", DoubleType(), True),
    ])
    feature_row = [(
        float(features["ds_decl_rate"]),
        float(features["ds_avg_rating"]),
        float(features["da_trips_today"]),
    )]
    feature_df = spark.createDataFrame(feature_row, schema)
    # Pack the three columns into the single "features" vector column the model reads.
    assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
    vector_df = assembler.transform(feature_df)
    # 3. Score the row; "probability" holds [P(class_0), P(class_1)].
    prediction_df = lr_model.transform(vector_df)
    result = prediction_df.select("probability", "prediction").collect()[0]
    return {
        "driver_id": driver_id,
        "decline_probability": round(float(result["probability"][1]), 4),
        "prediction": "higher risk" if result["prediction"] == 1.0 else "typical risk",
    }

The function converts every value to a Python float, including da_trips_today, which holds a whole-number trip count, so that each value matches the explicit DoubleType schema passed to createDataFrame. The probability field from Spark’s logistic regression model is a vector of [P(class_0), P(class_1)], so index 1 is the decline probability.
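Both details in the paragraph above, the float conversion and the two-element probability vector, can be shown without Spark. A minimal sketch, assuming hypothetical coefficients and a hypothetical driver record; `to_feature_row` and `probability_vector` are illustrative helpers, not part of the tutorial or of PySpark. The `> 0.5` comparison mirrors Spark's default threshold for binary logistic regression.

```python
import math

FEATURE_COLUMNS = ["ds_decl_rate", "ds_avg_rating", "da_trips_today"]


def to_feature_row(features, columns=FEATURE_COLUMNS):
    """Cast every value to float so the row fits an all-DoubleType schema."""
    return tuple(float(features[c]) for c in columns)


def probability_vector(coefficients, intercept, row):
    """Binary logistic regression output: [P(class_0), P(class_1)]."""
    margin = intercept + sum(w * v for w, v in zip(coefficients, row))
    # Numerically stable sigmoid for P(class_1).
    if margin >= 0:
        p1 = 1.0 / (1.0 + math.exp(-margin))
    else:
        e = math.exp(margin)
        p1 = e / (1.0 + e)
    return [1.0 - p1, p1]


# Hypothetical record and hypothetical model parameters.
record = {"ds_decl_rate": 0.60, "ds_avg_rating": 3.8, "da_trips_today": 2}
row = to_feature_row(record)
probs = probability_vector([8.0, -2.0, -0.1], 6.0, row)
decline_probability = probs[1]  # index 1 = P(decline)
prediction = 1.0 if decline_probability > 0.5 else 0.0
print(row, [round(p, 4) for p in probs], prediction)
```

Because the two entries always sum to 1, reading index 1 is enough to rank drivers by decline risk.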

Test with known drivers

  1. Run Cell 24 to generate predictions for three test drivers.

Cell 24: Run predictions for test drivers

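test_drivers = ["driver_005", "driver_042", "driver_087"]
for driver_id in test_drivers:
    result = predict_decline_risk(driver_id)
    # A missing driver returns {"error": ...} without the other keys.
    if "error" in result:
        print(result["error"])
        continue
    print(f"{result['driver_id']}: {result['prediction']} "
          f"(decline_prob={result['decline_probability']})")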
Expected output
driver_005: typical risk (decline_prob=0.0312)
driver_042: typical risk (decline_prob=0.0876)
driver_087: higher risk (decline_prob=0.8934)

The model predicts low decline probability for drivers with low historical decline rates and high ratings, and higher probability for drivers with elevated rates and lower ratings. This matches the patterns in the training data.
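In a logistic regression model, this pattern comes from the signs of the coefficients: a positive weight raises the decline probability as the feature grows (holding the others fixed), and a negative weight lowers it. A minimal sketch with hypothetical coefficients; the trained Spark model exposes its actual values through `lr_model.coefficients` and `lr_model.intercept`, which you could substitute here.

```python
import math

FEATURE_COLUMNS = ["ds_decl_rate", "ds_avg_rating", "da_trips_today"]
# Hypothetical values; a trained model exposes its own via lr_model.coefficients
# and lr_model.intercept.
COEFFICIENTS = [8.0, -2.0, -0.1]
INTERCEPT = 6.0


def decline_probability(row):
    """P(decline) for one feature row, using a numerically stable sigmoid."""
    margin = INTERCEPT + sum(w * v for w, v in zip(COEFFICIENTS, row))
    if margin >= 0:
        return 1.0 / (1.0 + math.exp(-margin))
    e = math.exp(margin)
    return e / (1.0 + e)


# The sign of each coefficient says which way the feature pushes the probability,
# holding the other features fixed.
for name, w in zip(FEATURE_COLUMNS, COEFFICIENTS):
    effect = "raises" if w > 0 else "lowers" if w < 0 else "does not affect"
    print(f"{name}: {w:+.2f} -> higher values {effect} decline probability")

low_rate_high_rating = decline_probability([0.05, 4.9, 7.0])
high_rate_low_rating = decline_probability([0.60, 3.8, 2.0])
print(round(low_rate_high_rating, 4), round(high_rate_low_rating, 4))
```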

Smart dispatch in action

  1. Run Cell 25 to rank driver candidates by decline probability.

Cell 25: Rank driver candidates by decline probability

import random

random.seed(99)
# Draw 10 candidate IDs, then drop duplicates while keeping the draw order.
candidate_drivers = [f"driver_{random.randint(1, 100):03d}" for _ in range(10)]
candidate_drivers = list(dict.fromkeys(candidate_drivers))

results = []
for driver_id in candidate_drivers:
    result = predict_decline_risk(driver_id)
    if "error" in result:
        continue
    results.append(result)

if not results:
    raise ValueError("No valid candidate drivers were found. Check that driver features are loaded in Aerospike.")

# Lowest decline probability first.
results.sort(key=lambda r: r["decline_probability"])
print("Smart dispatch ranking (lowest decline risk first):")
print(f"{'Driver':<14} {'Prediction':<15} {'Decline Prob'}")
print("-" * 44)
for r in results:
    print(f"{r['driver_id']:<14} {r['prediction']:<15} {r['decline_probability']}")
Expected output
Smart dispatch ranking (lowest decline risk first):
Driver         Prediction      Decline Prob
--------------------------------------------
driver_051     typical risk    0.0156
driver_073     typical risk    0.0234
driver_029     typical risk    0.0298
driver_014     typical risk    0.0445
driver_068     typical risk    0.0612
driver_091     typical risk    0.0789
driver_036     higher risk     0.7823
driver_017     higher risk     0.8456

The dispatch system would assign the ride request to driver_051, the candidate with the lowest predicted decline risk. Your ranked drivers and probabilities may differ from this example.
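The assignment rule reduces to taking the minimum over valid candidates. A minimal sketch, assuming `results` has the dictionary shape `predict_decline_risk` returns; `pick_driver` is a hypothetical helper, and breaking ties on `driver_id` is an added assumption rather than something the tutorial specifies.

```python
def pick_driver(results):
    """Return the valid candidate with the lowest decline probability, or None."""
    valid = [r for r in results if "error" not in r]
    if not valid:
        return None
    # Break ties on driver_id so the choice does not depend on input order.
    return min(valid, key=lambda r: (r["decline_probability"], r["driver_id"]))


# Sample results in the shape predict_decline_risk returns; the error entry is illustrative.
sample = [
    {"driver_id": "driver_036", "decline_probability": 0.7823, "prediction": "higher risk"},
    {"error": "Driver driver_999 not found"},
    {"driver_id": "driver_051", "decline_probability": 0.0156, "prediction": "typical risk"},
]
print(pick_driver(sample)["driver_id"])
```

Using `min` avoids sorting the whole list when only the single best candidate is needed.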

Where this goes next

You now have a working serving path from driver_id to decline probability.

This tutorial still uses 100 drivers and 3 features. In the next two sections, you’ll first look at how this pattern generalizes, then scale both dimensions and measure whether retrieval latency stays sub-millisecond.
