Amazon SageMaker: User Input Data Validation in an Inference Endpoint

I have successfully built a SageMaker endpoint using a TensorFlow model. The pre- and post-processing is done inside "inference.py", which calls a handler function based on this tutorial: https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/using_tf.html#how-to-implement-the-pre-and-or-post-processing-handler-s

My questions are:

  • What is a good way to validate user input data within inference.py?
  • If such validation fails (e.g. wrong data types or values outside the allowed range), how can I return appropriate error messages with status codes to the user?
  • How does this work with the API Gateway placed in front of the endpoint?
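For the first bullet, a minimal sketch, assuming the request body is a JSON object with a numeric field `input_1` (the field name and the 25000 limit come from the commented check in the code below; the `SCHEMA` table, the lower bound of 0 and `validate_input` itself are hypothetical). Keeping the checks in a plain function that collects every problem, rather than stopping at the first one, lets them be unit-tested without the serving container.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical schema: field name -> (accepted types, minimum, maximum).
# Only "input_1" and the 25000 limit come from the question.
SCHEMA: Dict[str, Tuple[tuple, float, float]] = {
    "input_1": ((int, float), 0, 25000),
}


def validate_input(payload: Any) -> List[str]:
    """Return a list of human-readable problems; an empty list means valid."""
    if not isinstance(payload, dict):
        return ["request body must be a JSON object"]

    errors: List[str] = []
    for field, (types, low, high) in SCHEMA.items():
        if field not in payload:
            errors.append(f"missing field '{field}'")
            continue
        value = payload[field]
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, types):
            errors.append(f"'{field}' must be a number, got {type(value).__name__}")
        elif not low <= value <= high:
            errors.append(f"'{field}' must be between {low} and {high}, got {value}")
    return errors
```

An empty list means the payload can go on to preprocessing; a non-empty list is ready to be joined into an error message for the client.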

Here is the structure of inference.py, with the desired validation check shown as a comment:

import json
import requests


def handler(data, context):
    """Handle request.
    Args:
        data (obj): the request data
        context (Context): an object containing request and configuration details
    Returns:
        (bytes, string): data to return to client, (optional) response content type
    """
    processed_input = _process_input(data, context)
    response = requests.post(context.rest_uri, data=processed_input)
    return _process_output(response, context)


def _process_input(data, context):
    if context.request_content_type == 'application/json':
        # pass through json (assumes it's correctly formed)
        d = data.read().decode('utf-8')
        # parse the decoded string, not the original stream object
        data_dict = json.loads(d)


        # ----->   if data_dict['input_1'] > 25000:
        # ----->       return an error-specific message with status code 123


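        # some_preprocessing_function is the author's own helper (not shown);
        # it should return the JSON string that is posted to TensorFlow Serving
        return some_preprocessing_function(data_dict)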

    raise ValueError('{{"error": "unsupported content type {}"}}'.format(
        context.request_content_type or "unknown"))


def _process_output(data, context):
    if data.status_code != 200:
        raise ValueError(data.content.decode('utf-8'))

    response_content_type = context.accept_header
    prediction = data.content
    return prediction, response_content_type

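For the API Gateway question, one common arrangement (an assumption, not the only option) is a Lambda function behind an API Gateway Lambda-proxy integration that calls the endpoint with boto3 and turns an error payload into the HTTP status code API Gateway returns. This is also a reason to put errors in the body: if the container itself answers with a 4xx/5xx, `invoke_endpoint` raises a `ModelError` instead of returning a body. The sketch assumes the endpoint reports errors as a JSON object with `error` and `status_code` keys; `ENDPOINT_NAME`, `to_proxy_response` and the fallback of 500 are hypothetical.

```python
import json
import os

# Hypothetical environment variable holding the SageMaker endpoint name.
ENDPOINT_NAME = os.environ.get("ENDPOINT_NAME", "my-endpoint")


def to_proxy_response(model_body: bytes) -> dict:
    """Map the endpoint's JSON body to a Lambda-proxy response for API Gateway."""
    payload = json.loads(model_body)
    status = 200
    if isinstance(payload, dict) and "error" in payload:
        # Assumes errors look like {"error": ..., "status_code": ...}.
        status = int(payload.get("status_code", 500))
    return {
        "statusCode": status,
        "headers": {"Content-Type": "application/json"},
        "body": json.dumps(payload),
    }


def lambda_handler(event, context):
    # Imported here so to_proxy_response can be tested without boto3 installed.
    import boto3

    runtime = boto3.client("sagemaker-runtime")
    response = runtime.invoke_endpoint(
        EndpointName=ENDPOINT_NAME,
        ContentType="application/json",
        Body=event["body"],
    )
    return to_proxy_response(response["Body"].read())
```

Simple schema checks (required fields, types, ranges) can also be rejected by API Gateway itself, before the request reaches the endpoint, by attaching a request validator with a JSON Schema model to the method; the checks in inference.py then only need to cover what a schema cannot express.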

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
