Batch Prediction Job non-blocking

I am running a Vertex AI batch prediction using the Python API. The function I am using is from the Google Cloud docs:

def create_batch_prediction_job_dedicated_resources_sample(
    key_path,
    project: str,
    location: str,
    model_display_name: str,
    job_display_name: str,
    gcs_source: Union[str, Sequence[str]],
    gcs_destination: str,
    machine_type: str = "n1-standard-2",
    sync: bool = True,
):
    """Start a batch prediction job for the model named ``model_display_name``.

    The function then calls ``wait()`` on the job before printing its details,
    so the caller is held until that call returns.

    Args:
        key_path: Path of the service-account key file used to build credentials.
        project: Passed to ``aiplatform.init`` as ``project``.
        location: Passed to ``aiplatform.init`` as ``location``.
        model_display_name: Display name used to filter ``aiplatform.Model.list``;
            the first match is used.
        job_display_name: Forwarded to ``batch_predict`` as ``job_display_name``.
        gcs_source: Forwarded to ``batch_predict`` as ``gcs_source``.
        gcs_destination: Forwarded to ``batch_predict`` as
            ``gcs_destination_prefix``.
        machine_type: Forwarded to ``batch_predict`` as ``machine_type``.
        sync: Forwarded to ``batch_predict`` as ``sync``.

    Returns:
        The object returned by ``my_model.batch_predict(...)``.

    Raises:
        IndexError: If the model listing is empty (``models[0]``) -- presumably
            when no model has the given display name.
    """
    credentials = service_account.Credentials.from_service_account_file(key_path)

    # Initialize the aiplatform SDK with the project, location and credentials
    aiplatform.init(project=project, location=location, credentials=credentials)

    # Get a list of Models by Model name
    models = aiplatform.Model.list(filter=f'display_name="{model_display_name}"')
    model_resource_name = models[0].resource_name

    # Get the model
    my_model = aiplatform.Model(model_resource_name)

    batch_prediction_job = my_model.batch_predict(
        job_display_name=job_display_name,
        gcs_source=gcs_source,
        gcs_destination_prefix=gcs_destination,
        machine_type=machine_type,
        sync=sync,
    )

    # Alternative that was also tried here:
    # batch_prediction_job.wait_for_resource_creation()
    # NOTE: with this wait() in place the program keeps logging
    # JOB_STATE_RUNNING and does not exit while the job is running.
    batch_prediction_job.wait()

    print(batch_prediction_job.display_name)
    print(batch_prediction_job.resource_name)
    print(batch_prediction_job.state)
    return batch_prediction_job

# Credentials and target project/region.
key_path = 'vertex_key.json'
project = 'my_project'
location = 'asia-south1'

# Model to run; model_name is defined but not passed to the call below.
model_display_name = 'test_model'
model_name = '1234'

# Input file and output prefix on Cloud Storage.
gcs_source = 'gs://my_bucket/Cleaned_Data/user_item_pairs.jsonl'
gcs_destination = 'gs://my_bucket/prediction'

# The job display name is suffixed with the current local timestamp.
datetime_today = datetime.datetime.now()
job_display_name = f'batch_prediction_{datetime_today}'

create_batch_prediction_job_dedicated_resources_sample(
    key_path=key_path,
    project=project,
    location=location,
    model_display_name=model_display_name,
    job_display_name=job_display_name,
    gcs_source=gcs_source,
    gcs_destination=gcs_destination,
)

OUTPUT:

92 current state:
JobState.JOB_STATE_RUNNING
INFO:google.cloud.aiplatform.jobs:BatchPredictionJob projects/my_project/locations/asia-south1/batchPredictionJobs/37737350127597649

The above output is printed on the terminal over and over, every few seconds.

The issue that I have is that the python program calling this function keeps on running until it is force stopped. I have tried both batch_prediction_job.wait() & batch_prediction_job.wait_for_resource_creation() with the same results.

How do I start a batch_prediction_job without waiting for it to complete, and terminate the program just after the job has been created?



Solution 1:[1]

I gave you the wrong instruction in the comments. Change the parameter to sync=False and the function should return right after it is called.

Whether this function call should be synchronous (wait for pipeline run to finish before terminating) or asynchronous (return immediately)

sync=False

def create_batch_prediction_job_dedicated_resources_sample(
# ...
    sync: bool = False,
):

UPDATE - Adding more details:

Check my notebook code here, where I tested it and it's working. You have to set sync=False AND remove/comment out the following wait and print lines:

#batch_prediction_job.wait()
#print(batch_prediction_job.display_name)
#print(batch_prediction_job.resource_name)
#print(batch_prediction_job.state)

Your code edited:

def create_batch_prediction_job_dedicated_resources_sample(
    key_path,
    project: str,
    location: str,
    model_display_name: str,
    job_display_name: str,
    gcs_source: Union[str, Sequence[str]],
    gcs_destination: str,
    machine_type: str = "n1-standard-2",
    sync: bool = False,
):
    """Start a batch prediction job for the model named ``model_display_name``.

    Unlike the version in the question, this one neither calls ``wait()`` on
    the job nor reads its attributes; it returns the job object right after
    ``batch_predict`` is called.

    Args:
        key_path: Path of the service-account key file used to build credentials.
        project: Passed to ``aiplatform.init`` as ``project``.
        location: Passed to ``aiplatform.init`` as ``location``.
        model_display_name: Display name used to filter ``aiplatform.Model.list``;
            the first match is used.
        job_display_name: Forwarded to ``batch_predict`` as ``job_display_name``.
        gcs_source: Forwarded to ``batch_predict`` as ``gcs_source``.
        gcs_destination: Forwarded to ``batch_predict`` as
            ``gcs_destination_prefix``.
        machine_type: Forwarded to ``batch_predict`` as ``machine_type``.
        sync: Forwarded to ``batch_predict`` as ``sync``; defaults to ``False``
            so the call is asynchronous.

    Returns:
        The object returned by ``my_model.batch_predict(...)``.

    Raises:
        IndexError: If the model listing is empty (``models[0]``) -- presumably
            when no model has the given display name.
    """
    credentials = service_account.Credentials.from_service_account_file(key_path)

    # Initialize the aiplatform SDK with the project, location and credentials
    aiplatform.init(project=project, location=location, credentials=credentials)

    # Get a list of Models by Model name
    models = aiplatform.Model.list(filter=f'display_name="{model_display_name}"')
    model_resource_name = models[0].resource_name

    # Get the model
    my_model = aiplatform.Model(model_resource_name)

    batch_prediction_job = my_model.batch_predict(
        job_display_name=job_display_name,
        gcs_source=gcs_source,
        gcs_destination_prefix=gcs_destination,
        machine_type=machine_type,
        sync=sync,
    )

    # No wait() and no attribute prints here, so nothing blocks on the job.
    return batch_prediction_job

# Credentials and target project/region (replace the <...> placeholders).
key_path = 'vertex_key.json'
project = '<my_project_name>'
location = 'asia-south1'

# Model to run; model_name is defined but not passed to the call below.
model_display_name = 'test_model'
model_name = '1234'

# Input file and output prefix on Cloud Storage.
gcs_source = 'gs://<my_bucket_name>/Cleaned_Data/user_item_pairs.jsonl'
gcs_destination = 'gs://<my_bucket_name>/prediction'

# The job display name is suffixed with the current local timestamp.
datetime_today = datetime.datetime.now()
job_display_name = f'batch_prediction_{datetime_today}'

# sync=False is passed explicitly so the call does not wait for the job.
create_batch_prediction_job_dedicated_resources_sample(
    key_path=key_path,
    project=project,
    location=location,
    model_display_name=model_display_name,
    job_display_name=job_display_name,
    gcs_source=gcs_source,
    gcs_destination=gcs_destination,
    sync=False,
)

Results sync=False: enter image description here Results sync=True: enter image description here

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1