How do I get a dataframe or database write from TFX BulkInferrer?

I'm very new to TFX, but I have an apparently working ML pipeline that is to be used via BulkInferrer. BulkInferrer seems to produce output exclusively in Protobuf format, but since I'm running bulk inference, I want to pipe the results to a database instead. (DB output seems like it should be the default for bulk inference, since both bulk inference and DB access take advantage of parallelization, but Protobuf is a per-record, serialized format.)

I assume I could use something like Parquet-Avro-Protobuf to do the conversion (though that's in Java, and the rest of the pipeline is in Python). Alternatively, I could write something myself to consume the protobuf messages one by one, convert them to JSON, deserialize the JSON into a list of dicts, and load the dicts into a Pandas DataFrame, or store them as a set of key-value pairs that I treat like a single-use DB. But that sounds like a lot of work, with parallelization and optimization headaches, for a very common use case. The top-level Protobuf message is TensorFlow Serving's PredictionLog.
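The do-it-yourself route is shorter than it sounds if protobuf's own JSON mapping does the message-to-dict step: `json_format.MessageToDict` goes straight to Python dicts, skipping the JSON-string round trip, and `pandas.json_normalize` flattens nested fields into dotted column names. A minimal sketch, assuming `protobuf` and `pandas` are installed; `protos_to_dataframe` is a hypothetical helper name, and for BulkInferrer output `message_cls` would be `prediction_log_pb2.PredictionLog`. Under the protobuf JSON mapping, `bytes` fields (such as a TensorProto's `string_val`) come out base64-encoded and 64-bit integers come out as strings, so some post-processing is usually still needed. It also runs in a single process, so it does not address the parallelization concern.

```python
from typing import Iterable, Type

import pandas as pd
from google.protobuf import json_format
from google.protobuf.message import Message


def protos_to_dataframe(serialized_records: Iterable[bytes],
                        message_cls: Type[Message]) -> pd.DataFrame:
    """Decode serialized protobuf records and flatten them into a DataFrame.

    Hypothetical helper: each record is parsed into `message_cls`, converted
    to a dict via protobuf's JSON mapping, and nested message fields are
    flattened into dotted column names (e.g. "predict_log.response...").
    """
    rows = []
    for record in serialized_records:
        message = message_cls()
        message.ParseFromString(record)
        # Keep the .proto field names (snake_case) instead of lowerCamelCase.
        rows.append(json_format.MessageToDict(message,
                                              preserving_proto_field_name=True))
    # Lists (repeated fields) stay as list-valued cells; dicts are flattened.
    return pd.json_normalize(rows)
```

The serialized records could come from the same iterator Solution 2 uses, `tf.data.TFRecordDataset(files, compression_type="GZIP").as_numpy_iterator()`.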

This must be a common use case, because some TensorFlow Model Analysis (TFMA) functions consume Pandas DataFrames. I'd rather write directly to a DB (preferably Google BigQuery) or to a Parquet file (since Parquet / Spark seems to parallelize better than Pandas). Again, those seem like they should be common use cases, but I haven't found any examples. Maybe I'm using the wrong search terms?

I also looked at the PredictExtractor, since "extracting predictions" sounds close to what I want... but the official documentation appears silent on how that class is supposed to be used. I thought TFTransformOutput sounded like a promising verb, but instead it's a noun.

I'm clearly missing something fundamental here. Is there a reason no one wants to store BulkInferrer results in a database? Is there a configuration option that allows me to write the results to a DB? Maybe I want to add a ParquetIO or BigQueryIO instance to the TFX pipeline? (The TFX docs say it uses Beam "under the hood", but that doesn't say much about how I should use the two together.) But the syntax in those documents looks different enough from my TFX code that I'm not sure they're compatible.

Help?



Solution 1:[1]

Answering my own question here to document what we did, even though I think @Hamza Tahir's answer below is objectively better. This may provide an option for other situations where it's necessary to change the operation of an out-of-the-box TFX component. It's hacky though:

We copied and edited the file tfx/components/bulk_inferrer/executor.py, replacing this transform in the _run_model_inference() method's internal pipeline:

| 'WritePredictionLogs' >> beam.io.WriteToTFRecord(
             os.path.join(inference_result.uri, _PREDICTION_LOGS_FILE_NAME),
             file_name_suffix='.gz',
             coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)))

with this one:

| 'WritePredictionLogsBigquery' >> beam.io.WriteToBigQuery(
           'our_project:namespace.TableName',
           schema='SCHEMA_AUTODETECT',
           write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
           create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
           custom_gcs_temp_location='gs://our-storage-bucket/tmp',
           temp_file_format='NEWLINE_DELIMITED_JSON',
           ignore_insert_ids=True,
       ))

(This works because when you import the BulkInferrer component, the per-node work gets farmed out to these executors running on the worker nodes, and TFX copies its own library onto those nodes. It doesn't copy everything from user-space libraries, though, which is why we couldn't just subclass BulkInferrer and import our custom version.)

We had to make sure the table at 'our_project:namespace.TableName' had a schema compatible with the model's output, but didn't have to translate that schema into JSON / AVRO.
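Beam's `WriteToBigQuery` expects each element to be a Python dict (one per table row), so if the PCollection reaching the write still holds `PredictionLog` protos, a conversion step in front of it is probably needed. A minimal sketch, assuming a predict-style model whose labels come back as strings under a hypothetical output key (`output_2`, borrowed from the model in Solution 2); `prediction_log_to_row` and the `predictions` column name are illustrative, not part of TFX or Beam.

```python
from typing import Any, Dict, List

# Hypothetical output key: use the name from your model's serving signature.
# ('output_2' is the key the model in Solution 2 happens to use.)
OUTPUT_KEY = "output_2"


def prediction_log_to_row(log: Any, top_k: int = 10) -> Dict[str, List[str]]:
    """Convert one PredictionLog (predict_log variant) into a row dict.

    Only attribute and map access are used, so this module does not import
    tensorflow_serving itself; `log` is expected to be a
    prediction_log_pb2.PredictionLog whose response carries string labels.
    """
    tensor = log.predict_log.response.outputs[OUTPUT_KEY]
    # string_val is a repeated bytes field; decode so the row is JSON-serializable.
    labels = [value.decode("utf-8") for value in tensor.string_val[:top_k]]
    # One key per table column; here a single repeated STRING column.
    return {"predictions": labels}
```

In the edited executor this would sit just before the write step, e.g. `| 'ToRows' >> beam.Map(prediction_log_to_row)`, with the target table's `predictions` column declared as a REPEATED STRING field.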

In theory, my group would like to make a pull request to TFX built around this, but for now we're hard-coding a couple of key parameters, and we don't have the time to get this to a real public / production state.

Solution 2:[2]

I'm a little late to this party, but this is some code I use for this task:

import tensorflow as tf
from tensorflow_serving.apis import prediction_log_pb2
import pandas as pd
from typing import List, Text


def parse_prediction_logs(inference_filenames: List[Text]) -> pd.DataFrame:
    """
    Args:
        inference_filenames: paths to the BulkInferrer output files,
            e.g. tf.io.gfile.glob(Inferrer artifact uri)
    Returns:
        a dataframe of userids, predictions, and features
    """

    def parse_log(pbuf):
        # parse the protobuf
        message = prediction_log_pb2.PredictionLog()
        message.ParseFromString(pbuf)
        # my model produces scores and classes and I extract the topK classes
        predictions = [x.decode() for x in (message
                                            .predict_log
                                            .response
                                            .outputs['output_2']
                                            .string_val
                                            )[:10]]
        # here I parse the input tf.train.Example proto
        inputs = tf.train.Example()
        inputs.ParseFromString(message
                               .predict_log
                               .request
                               .inputs['input_1'].string_val[0]
                               )

        # you can pull out individual features like this         
        uid = inputs.features.feature["userId"].bytes_list.value[0].decode()

        feature1 = [
            x.decode() for x in inputs.features.feature["feature1"].bytes_list.value
        ]

        feature2 = [
            x.decode() for x in inputs.features.feature["feature2"].bytes_list.value
        ]

        return (uid, predictions, feature1, feature2)

    return pd.DataFrame(
        [parse_log(x) for x in
         tf.data.TFRecordDataset(inference_filenames, compression_type="GZIP").as_numpy_iterator()
        ], columns = ["userId", "predictions", "feature1", "feature2"]
    )
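The code above uses TensorFlow (`tf.data.TFRecordDataset`) just to read the files. If the environment that loads the results into a DataFrame or database doesn't have TensorFlow installed, the TFRecord framing (a little-endian uint64 length, a uint32 masked CRC32C of the length, the payload, then a uint32 masked CRC32C of the payload) is simple enough to read with the Python standard library. A minimal sketch; `read_tfrecords` is a hypothetical helper, and it skips the checksums rather than verifying them (CRC32C isn't in the standard library), so it will not detect corrupted data. It treats a `.gz` suffix as gzip, matching the `file_name_suffix='.gz'` in the executor code from Solution 1.

```python
import gzip
import struct
from typing import Iterator


def read_tfrecords(path: str) -> Iterator[bytes]:
    """Yield raw record payloads from a TFRecord file (gzip if path ends in .gz).

    Record layout: uint64 length, uint32 masked CRC32C of the length,
    `length` payload bytes, uint32 masked CRC32C of the payload.
    The CRC fields are read and discarded, NOT verified.
    """
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                return  # clean end of file
            if len(header) < 8:
                raise ValueError("truncated TFRecord length header")
            (length,) = struct.unpack("<Q", header)
            f.read(4)  # masked CRC of the length field (skipped)
            data = f.read(length)
            if len(data) < length:
                raise ValueError("truncated TFRecord payload")
            f.read(4)  # masked CRC of the payload (skipped)
            yield data
```

Each yielded payload is one serialized message, so `prediction_log_pb2.PredictionLog.FromString(payload)` (or the same parsing logic as `parse_log` above) can decode it.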

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Sarah Messer
Solution 2 Tim