Prediction pipeline
You have a trained model and a system for fast feature retrieval. Now you’ll connect them into a single function that goes from a driver ID to a decline risk prediction. The goal is a minimal serving path: one key lookup to Aerospike, one model call, one prediction result.
Load the trained model
- Run Cell 22 to load the Logistic Regression model you saved in Part 2.
Cell 22: Load the trained model for serving
```python
from pyspark.ml.classification import LogisticRegressionModel

# Load the Logistic Regression model saved in Part 2
lr_model = LogisticRegressionModel.load("./models/trip_decline_risk_lr")
print("Loaded trip decline risk model")
```

```
Loaded trip decline risk model
```

Build the prediction function
The predict_decline_risk function encapsulates the full serving flow:
- Retrieve the driver’s features from Aerospike using get_feature_vector().
- Assemble the features into the vector format the model expects.
- Run the model to get a prediction and probability.
Keeping this path short matters. If you add extra lookups or broad queries per request, retrieval latency can become the bottleneck before model inference even starts.
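You can sketch the three steps listed above without Spark or Aerospike to see how little work sits on the request path. This is a hypothetical, minimal sketch. A plain dict (`FEATURE_STORE`) stands in for the Aerospike read done by `Entity.get_feature_vector()`. `WEIGHTS` and `INTERCEPT` are made-up coefficients, not those of the trained model. The scoring line applies the logistic function, which is what a binary logistic regression applies to its linear margin.

```python
import math

# Hypothetical stand-in for Aerospike: one record per driver_id.
FEATURE_STORE = {
    "driver_005": {"ds_decl_rate": 0.02, "ds_avg_rating": 4.9, "da_trips_today": 7},
    "driver_087": {"ds_decl_rate": 0.41, "ds_avg_rating": 3.8, "da_trips_today": 2},
}

# Fixed column order: vector positions must match the order used in training.
FEATURE_COLUMNS = ["ds_decl_rate", "ds_avg_rating", "da_trips_today"]

# Hypothetical coefficients, for illustration only.
WEIGHTS = [9.0, -1.5, -0.05]
INTERCEPT = 4.0


def predict_decline_risk_sketch(driver_id):
    # Step 1: a single key lookup.
    record = FEATURE_STORE.get(driver_id)
    if record is None:
        return {"error": f"Driver {driver_id} not found"}
    # Step 2: assemble an ordered float vector.
    vector = [float(record[name]) for name in FEATURE_COLUMNS]
    # Step 3: one model call, here a logistic regression score.
    margin = INTERCEPT + sum(w * x for w, x in zip(WEIGHTS, vector))
    p_decline = 1.0 / (1.0 + math.exp(-margin))
    return {
        "driver_id": driver_id,
        "decline_probability": round(p_decline, 4),
        "prediction": "higher risk" if p_decline > 0.5 else "typical risk",
    }


print(predict_decline_risk_sketch("driver_087"))
```

Keeping the column order in one list (`FEATURE_COLUMNS`) plays the same role as passing feature_columns to both the Aerospike lookup and VectorAssembler. Each vector position has to line up with the coefficient learned for that feature.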
- Run Cell 23 to define the predict_decline_risk function.
Cell 23: Define predict_decline_risk function
```python
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.types import StructType, StructField, DoubleType

# Assumes as_client (Aerospike client), Entity, spark and lr_model
# were defined in earlier cells.
def predict_decline_risk(driver_id):
    feature_columns = ["ds_decl_rate", "ds_avg_rating", "da_trips_today"]

    # One key lookup: fetch this driver's features from Aerospike
    features = Entity.get_feature_vector(as_client, "driver", driver_id, feature_columns)
    if features is None:
        return {"error": f"Driver {driver_id} not found"}

    # All three features are declared as doubles
    schema = StructType([
        StructField("ds_decl_rate", DoubleType(), True),
        StructField("ds_avg_rating", DoubleType(), True),
        StructField("da_trips_today", DoubleType(), True),
    ])

    # Cast to float so every value matches its DoubleType field
    feature_row = [(
        float(features["ds_decl_rate"]),
        float(features["ds_avg_rating"]),
        float(features["da_trips_today"]),
    )]
    feature_df = spark.createDataFrame(feature_row, schema)

    # Build the "features" vector column that the model reads
    assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
    vector_df = assembler.transform(feature_df)

    # One model call: score the single row
    prediction_df = lr_model.transform(vector_df)
    result = prediction_df.select("probability", "prediction").collect()[0]

    return {
        "driver_id": driver_id,
        "decline_probability": round(float(result["probability"][1]), 4),
        "prediction": "higher risk" if result["prediction"] == 1.0 else "typical risk",
    }
```

The function casts every value, including the integer count da_trips_today, to float so that it matches the DoubleType fields in the schema. By default, createDataFrame verifies the schema and rejects a Python int for a DoubleType field. VectorAssembler itself accepts any numeric input columns. The probability field from Spark’s logistic regression model is a vector of [P(class_0), P(class_1)], and index 1 is the decline probability.
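This worked sketch shows where the two entries of probability come from in the binary case. It assumes the raw margin z = intercept + coefficients · features has already been computed. P(class_1) is the logistic function of z, and P(class_0) is its complement. The predicted label is 1.0 when P(class_1) is above the threshold, which is 0.5 by default in Spark. The function names and margin values are illustrative.

```python
import math


def probability_vector(margin):
    """Return [P(class_0), P(class_1)] for a binary logistic regression margin."""
    p1 = 1.0 / (1.0 + math.exp(-margin))
    return [1.0 - p1, p1]


def predict_label(probability, threshold=0.5):
    """1.0 (decline) when P(class_1) is strictly above the threshold, else 0.0."""
    return 1.0 if probability[1] > threshold else 0.0


for margin in (-3.0, 0.0, 2.0):
    prob = probability_vector(margin)
    print(margin, [round(p, 4) for p in prob], predict_label(prob))
```

The sketch uses a strict `>` comparison. A margin of 0 gives [0.5, 0.5], which therefore maps to label 0.0.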
Test with known drivers
- Run Cell 24 to run predictions for test drivers.
Cell 24: Run predictions for test drivers
```python
test_drivers = ["driver_005", "driver_042", "driver_087"]

for driver_id in test_drivers:
    result = predict_decline_risk(driver_id)
    # Report missing drivers instead of failing with a KeyError
    if "error" in result:
        print(result["error"])
        continue
    print(f"{result['driver_id']}: {result['prediction']} "
          f"(decline_prob={result['decline_probability']})")
```

```
driver_005: typical risk (decline_prob=0.0312)
driver_042: typical risk (decline_prob=0.0876)
driver_087: higher risk (decline_prob=0.8934)
```

The model predicts a low decline probability for drivers with low historical decline rates and high ratings. It predicts a higher probability for drivers with elevated decline rates and lower ratings. This matches the patterns in the training data.
Smart dispatch in action
- Run Cell 25 to rank driver candidates by decline probability.
Cell 25: Rank driver candidates by decline probability
```python
import random

random.seed(99)

# Sample 10 candidate IDs from driver_001..driver_100, then drop duplicates
candidate_drivers = [f"driver_{random.randint(1, 100):03d}" for _ in range(10)]
candidate_drivers = list(set(candidate_drivers))

# Score each candidate, skipping drivers missing from Aerospike
results = []
for driver_id in candidate_drivers:
    result = predict_decline_risk(driver_id)
    if "error" in result:
        continue
    results.append(result)

if not results:
    raise ValueError("No valid candidate drivers were found. Re-run the candidate generation cell.")

# Lowest decline probability first
results.sort(key=lambda r: r["decline_probability"])

print("Smart dispatch ranking (lowest decline risk first):")
print(f"{'Driver':<14} {'Prediction':<15} {'Decline Prob'}")
print("-" * 44)
for r in results:
    print(f"{r['driver_id']:<14} {r['prediction']:<15} {r['decline_probability']}")
```

```
Smart dispatch ranking (lowest decline risk first):
Driver         Prediction      Decline Prob
--------------------------------------------
driver_051     typical risk    0.0156
driver_073     typical risk    0.0234
driver_029     typical risk    0.0298
driver_014     typical risk    0.0445
driver_068     typical risk    0.0612
driver_091     typical risk    0.0789
driver_036     higher risk     0.7823
driver_017     higher risk     0.8456
```

The dispatch system would assign the ride request to driver_051, the candidate with the lowest predicted decline risk.
Your ranked drivers and probabilities may differ from this example.
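The ranking step does not depend on Spark, so you can test its logic on its own. This is a sketch. `score()` is a hypothetical stand-in for predict_decline_risk, and its probabilities are illustrative values copied from the example output. The sketch deduplicates with `dict.fromkeys`, which keeps first-seen order and is repeatable across runs, whereas iteration order over a set of strings can change between Python processes. It then drops error results and sorts ascending. When only the top few candidates matter, it uses `heapq.nsmallest`.

```python
import heapq
from operator import itemgetter

# Hypothetical stand-in for predict_decline_risk(); probabilities are illustrative.
HYPOTHETICAL_SCORES = {"driver_051": 0.0156, "driver_073": 0.0234, "driver_036": 0.7823}


def score(driver_id):
    if driver_id not in HYPOTHETICAL_SCORES:
        return {"error": f"Driver {driver_id} not found"}
    return {"driver_id": driver_id, "decline_probability": HYPOTHETICAL_SCORES[driver_id]}


def rank_candidates(candidate_ids, k=None):
    """Return valid candidates ordered by ascending decline probability.

    With k set, return only the k lowest-risk candidates via a heap.
    """
    # Deduplicate while keeping first-seen order.
    unique_ids = list(dict.fromkeys(candidate_ids))
    # Score each candidate once and drop drivers that were not found.
    valid = [r for r in map(score, unique_ids) if "error" not in r]
    if not valid:
        raise ValueError("No valid candidate drivers were found.")
    by_risk = itemgetter("decline_probability")
    if k is None:
        return sorted(valid, key=by_risk)
    return heapq.nsmallest(k, valid, key=by_risk)


candidates = ["driver_036", "driver_051", "driver_999", "driver_073", "driver_051"]
ranking = rank_candidates(candidates)
assignee = ranking[0]["driver_id"]  # lowest predicted decline risk
top_two = rank_candidates(candidates, k=2)
print(assignee, [r["driver_id"] for r in top_two])
```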
Where this goes next
You now have a working serving path from driver_id to decline probability.
This tutorial still uses 100 drivers and 3 features. In the next two sections, you’ll first look at how this pattern generalizes, then scale both dimensions and measure whether retrieval latency stays sub-millisecond.
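Percentiles reveal tail behavior that an average would hide, so they are a better check on whether per-request latency stays sub-millisecond. Here is a minimal timing sketch that accepts any zero-argument callable. The `measure_latency` name, the p50/p99 choice and the dict lookup in the usage line are illustrative assumptions. The dict lookup stands in for a real Aerospike read and is not one of the tutorial's later cells.

```python
import statistics
import time


def measure_latency(fn, iterations=1000, warmup=50):
    """Call fn repeatedly and return p50/p99/max latency in milliseconds."""
    if iterations < 2:
        raise ValueError("need at least two timed iterations for percentiles")
    for _ in range(warmup):
        fn()  # let any caches or connections warm up before timing
    samples_ms = []
    for _ in range(iterations):
        start = time.perf_counter_ns()
        fn()
        samples_ms.append((time.perf_counter_ns() - start) / 1e6)
    # n=100 yields 99 cut points: index 49 is p50, index 98 is p99.
    # "inclusive" interpolates only between observed samples.
    cuts = statistics.quantiles(samples_ms, n=100, method="inclusive")
    return {"p50_ms": cuts[49], "p99_ms": cuts[98], "max_ms": max(samples_ms)}


store = {f"driver_{i:03d}": {"ds_decl_rate": 0.1} for i in range(1, 101)}
print(measure_latency(lambda: store.get("driver_042")))
```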