Python and Apache Spark integration
Overview
This page describes how to develop a simple ingest and search application with the Aerospike Vector Search (AVS) Python client. The client integrates AVS with Apache Spark to support large-scale data ingestion and search tasks.
Prerequisites
- A Spark cluster with PySpark and Python version 3.9 or later.
- A running AVS cluster that is reachable from the Spark master and worker nodes.
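Before running a job, you can confirm that a node can reach the AVS cluster by opening and closing a client connection. The following is a minimal sketch, assuming the client constructor raises when the seed host is unreachable; the host and port values are placeholders for your deployment.
from aerospike_vector_search import Client, types
# Placeholder seed address; replace with your AVS host and port.
client = Client(seeds=types.HostPort(host="10.0.0.1", port=5000), is_loadbalancer=True)
client.close()
print("AVS cluster is reachable")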
Develop an ingest and search application
- Install the AVS Python client. Add the following snippet to the Spark cluster's initialization script to install the client and ensure it is available on all nodes.
#!/bin/bash
# Initialization action to install a Python package on all nodes of a Google Cloud Dataproc cluster
python3 -m pip install aerospike_vector_search==1.0.1
- The script runs when the Spark cluster is created; see the Dataproc example below.
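For example, on Google Cloud Dataproc (the target of the script above), the initialization action is attached when the cluster is created. The cluster name, region, and bucket path below are placeholders:
gcloud dataproc clusters create my-avs-cluster \
    --region=us-central1 \
    --initialization-actions=gs://my-bucket/install-avs-client.sh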
Sample application
Copy the following example as a starting point for your own ingest and search application. It outlines a framework for integrating the AVS Python client into a large-scale data processing pipeline.
import argparse
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType
from aerospike_vector_search import AdminClient, Client, types

# Usage:
# spark-submit test.py --host "34.41.1.43" --port 5000 --read_path "gs://avs-parquet-sample"

# The schema is encoded in the Parquet files, so it is not needed in the script;
# it is shown here only to illustrate the shape of the data.
schema = StructType([
    StructField("productId", StringType(), True),
    StructField("partnerId", StringType(), True),
    StructField("image_embs", ArrayType(FloatType(), containsNull=False), True),
    StructField("text_embs", ArrayType(FloatType(), containsNull=False), True)
])

img_set = "img_set_1"
img_idx = "img_idx_1"
txt_set = "text_set_1"
txt_idx = "text_idx_1"

def process_row(row, client):
    # Write one image record and one text record per input row, keyed by productId.
    img_doc = {"productId": row.productId, "partnerId": row.partnerId, "image_embs": row.image_embs}
    txt_doc = {"productId": row.productId, "text_embs": row.text_embs}
    client.upsert(namespace="test", set_name=img_set, key=img_doc["productId"], record_data=img_doc)
    client.upsert(namespace="test", set_name=txt_set, key=txt_doc["productId"], record_data=txt_doc)

def init_client(host, port):
    return Client(seeds=types.HostPort(host=host, port=port), is_loadbalancer=True)

def process_partition(partition, host, port):
    # Each Spark task creates, uses, and closes its own client connection.
    client = init_client(host, port)
    try:
        for row in partition:
            process_row(row, client)
    finally:
        client.close()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PySpark AVS integration')
    parser.add_argument('--host', required=True, help='AVS hostname')
    parser.add_argument('--port', type=int, required=True, help='AVS port')
    parser.add_argument('--read_path', required=True, help='GCS path to read Parquet data from')
    args = parser.parse_args()

    # Initialize the Spark session
    spark = SparkSession.builder.appName("AVS Python client integration with Spark").getOrCreate()

    # Read Parquet data from GCS using the read_path argument
    df_parquet = spark.read.parquet(args.read_path)

    # Create the vector indexes with the AVS admin client
    avs_admin_client = AdminClient(seeds=types.HostPort(host=args.host, port=args.port), is_loadbalancer=True)
    try:
        avs_admin_client.index_create(namespace="test", name=img_idx, sets=img_set, vector_field="image_embs", dimensions=256)
        avs_admin_client.index_create(namespace="test", name=txt_idx, sets=txt_set, vector_field="text_embs", dimensions=256)
    finally:
        avs_admin_client.close()

    # Ingest the data: each partition is processed in parallel on the workers
    df_parquet.rdd.foreachPartition(lambda partition: process_partition(partition, args.host, args.port))

    # Fetch the first row and run a vector search against it as an example
    first_row = df_parquet.first()
    image_embs = first_row['image_embs']

    client_local = Client(seeds=types.HostPort(host=args.host, port=args.port), is_loadbalancer=True)
    try:
        # Wait for the ingested data to be indexed before searching
        client_local.wait_for_index_completion(namespace="test", name=img_idx)
        results = client_local.vector_search(namespace="test", index_name=img_idx, query=image_embs, limit=10, field_names=["productId", "partnerId"])
        print(f"Number of vector results: {len(results)}")
        for result in results:
            print(f"{result.key} -> {result.fields}")
    finally:
        client_local.close()
        spark.stop()
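After saving the application as test.py and uploading your Parquet data to a GCS bucket, submit the job from the Spark master as shown in the usage comment. The host, port, and bucket path below are the sample values from the script; replace them with your own:
spark-submit test.py --host "34.41.1.43" --port 5000 --read_path "gs://avs-parquet-sample"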
Read the Docs
For details about using the Python client, visit our Read the Docs page.