TensorFlow (cumulative) counting: how do I count each label individually?

I am very new to Python, but I need to work with it for a project during my studies. We are trying to build a camera scanning system. I found a great project by Tanner Gilbert, https://github.com/TannerGilbert/Tensorflow-2-Object-Counting, that does exactly what we want it to do, but to make it fully functional for our project we want a separate counter for a few individual object classes. So we would love to have output like:

  • Apples: 1
  • Bananas: 3
  • Carrots: 0

But since I don't fully understand Python yet, I can't figure out how to get these counters working, i.e. how to count each label individually. So I thought, why not ask here! Hopefully you can help me out. (I have already looked at the official TensorFlow counting API on GitHub, but I still can't figure it out.) Thanks in advance for helping me and my group out!
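
For reference, this is how we run the script (assuming it is saved as cumulative_object_counting.py; the flags are the ones defined in the argparse section of the code below, and the label names are just placeholders):

python cumulative_object_counting.py -m path/to/saved_model -l path/to/labelmap.pbtxt -v video.mp4 -la apple banana carrot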

This is the main part of the code:

import cv2
import numpy as np
import argparse
import tensorflow as tf
import dlib

from object_detection.utils import label_map_util
from object_detection.utils import ops as utils_ops

from trackable_object import TrackableObject
from centroidtracker import CentroidTracker

# patch tf1 into `utils.ops`
utils_ops.tf = tf.compat.v1

# Patch the location of gfile
tf.gfile = tf.io.gfile


def load_model(model_path):
    tf.keras.backend.clear_session()
    model = tf.saved_model.load(model_path)
    return model


def run_inference_for_single_image(model, image):
    image = np.asarray(image)
    # The input needs to be a tensor, convert it using `tf.convert_to_tensor`.
    input_tensor = tf.convert_to_tensor(image)
    # The model expects a batch of images, so add an axis with `tf.newaxis`.
    input_tensor = input_tensor[tf.newaxis, ...]

    # Run inference
    output_dict = model(input_tensor)

    # All outputs are batch tensors.
    # Convert to numpy arrays, and take index [0] to remove the batch dimension.
    # We're only interested in the first num_detections.
    num_detections = int(output_dict.pop('num_detections'))
    output_dict = {key: value[0, :num_detections].numpy()
                   for key, value in output_dict.items()}
    output_dict['num_detections'] = num_detections

    # detection_classes should be ints.
    output_dict['detection_classes'] = output_dict['detection_classes'].astype(
        np.int64)

    # Handle models with masks:
    if 'detection_masks' in output_dict:
        # Reframe the bbox masks to the image size.
        detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
            output_dict['detection_masks'], output_dict['detection_boxes'],
            image.shape[0], image.shape[1])
        detection_masks_reframed = tf.cast(
            detection_masks_reframed > 0.5, tf.uint8)
        output_dict['detection_masks_reframed'] = detection_masks_reframed.numpy()

    return output_dict


def run_inference(model, category_index, cap, labels, roi_position=0.6, threshold=0.5, x_axis=True, skip_frames=20, save_path='', show=True):
    counter = [0, 0, 0, 0]  # left, right, up, down
    total_frames = 0

    ct = CentroidTracker(maxDisappeared=40, maxDistance=50)
    trackers = []
    trackableObjects = {}

    # Check if results should be saved
    if save_path:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        out = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(
            'M', 'J', 'P', 'G'), fps, (width, height))

    while cap.isOpened():
        ret, image_np = cap.read()
        if not ret:
            break

        height, width, _ = image_np.shape
        rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)

        status = "Waiting"
        rects = []

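        # Run the expensive detection model only every `skip_frames` frames;
        # in between, follow the previously detected boxes with the much
        # cheaper dlib correlation trackers.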
        if total_frames % skip_frames == 0:
            status = "Detecting"
            trackers = []

            # Actual detection.
            output_dict = run_inference_for_single_image(model, image_np)

            for i, (y_min, x_min, y_max, x_max) in enumerate(output_dict['detection_boxes']):
                if output_dict['detection_scores'][i] > threshold and (labels is None or category_index[output_dict['detection_classes'][i]]['name'] in labels):
                    tracker = dlib.correlation_tracker()
                    rect = dlib.rectangle(
                        int(x_min * width), int(y_min * height), int(x_max * width), int(y_max * height))
                    tracker.start_track(rgb, rect)
                    trackers.append(tracker)
        else:
            status = "Tracking"
            for tracker in trackers:
                # update the tracker and grab the updated position
                tracker.update(rgb)
                pos = tracker.get_position()

                # unpack the position object
                x_min, y_min, x_max, y_max = int(pos.left()), int(
                    pos.top()), int(pos.right()), int(pos.bottom())

                # add the bounding box coordinates to the rectangles list
                rects.append((x_min, y_min, x_max, y_max))

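        # The centroid tracker matches the current boxes to existing object
        # IDs by centroid distance, so each object keeps a stable ID
        # across frames.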
        objects = ct.update(rects)

        for (objectID, centroid) in objects.items():
            to = trackableObjects.get(objectID, None)

            if to is None:
                to = TrackableObject(objectID, centroid)
            else:
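                # An object is counted at most once, when its centroid
                # crosses the ROI line: `direction` is the current position
                # relative to the mean of its past positions, and the mean
                # itself must still be on the far side of the line.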
                if x_axis and not to.counted:
                    x = [c[0] for c in to.centroids]
                    direction = centroid[0] - np.mean(x)

                    if centroid[0] > roi_position*width and direction > 0 and np.mean(x) < roi_position*width:
                        counter[1] += 1
                        to.counted = True
                    elif centroid[0] < roi_position*width and direction < 0 and np.mean(x) > roi_position*width:
                        counter[0] += 1
                        to.counted = True

                elif not x_axis and not to.counted:
                    y = [c[1] for c in to.centroids]
                    direction = centroid[1] - np.mean(y)

                    if centroid[1] > roi_position*height and direction > 0 and np.mean(y) < roi_position*height:
                        counter[3] += 1
                        to.counted = True
                    elif centroid[1] < roi_position*height and direction < 0 and np.mean(y) > roi_position*height:
                        counter[2] += 1
                        to.counted = True

                to.centroids.append(centroid)

            trackableObjects[objectID] = to

            text = "ID {}".format(objectID)
            cv2.putText(image_np, text, (centroid[0] - 10, centroid[1] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
            cv2.circle(
                image_np, (centroid[0], centroid[1]), 4, (255, 255, 255), -1)

        # Draw ROI line
        if x_axis:
            cv2.line(image_np, (int(roi_position*width), 0),
                     (int(roi_position*width), height), (0xFF, 0, 0), 5)
        else:
            cv2.line(image_np, (0, int(roi_position*height)),
                     (width, int(roi_position*height)), (0xFF, 0, 0), 5)

        # display count and status
        font = cv2.FONT_HERSHEY_SIMPLEX
        if x_axis:
            cv2.putText(image_np, f'Left: {counter[0]}; Right: {counter[1]}',
                        (10, 35), font, 0.8, (0, 0xFF, 0xFF), 2, cv2.LINE_AA)
        else:
            cv2.putText(image_np, f'Up: {counter[2]}; Down: {counter[3]}',
                        (10, 35), font, 0.8, (0, 0xFF, 0xFF), 2, cv2.LINE_AA)
        cv2.putText(image_np, 'Status: ' + status, (10, 70), font,
                    0.8, (0, 0xFF, 0xFF), 2, cv2.LINE_AA)

        if show:
            cv2.imshow('cumulative_object_counting', image_np)
            if cv2.waitKey(25) & 0xFF == ord('q'):
                break

        if save_path:
            out.write(image_np)

        total_frames += 1

    cap.release()
    if save_path:
        out.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Detect objects inside webcam videostream')
    parser.add_argument('-m', '--model', type=str,
                        required=True, help='Model Path')
    parser.add_argument('-l', '--labelmap', type=str,
                        required=True, help='Path to Labelmap')
    parser.add_argument('-v', '--video_path', type=str, default='',
                        help='Path to video file. If empty, the webcam is used')
    parser.add_argument('-t', '--threshold', type=float,
                        default=0.5, help='Detection threshold')
    parser.add_argument('-roi', '--roi_position', type=float,
                        default=0.6, help='ROI Position (0-1)')
    parser.add_argument('-la', '--labels', nargs='+', type=str,
                        help='Label names to count (default: all labels)')
    parser.add_argument('-a', '--axis', default=True, action="store_false",
                        help='Count along the y axis instead of the default x axis')
    parser.add_argument('-s', '--skip_frames', type=int, default=20,
                        help='Number of frames to skip between using object detection model')
    parser.add_argument('-sh', '--show', default=True,
                        action="store_false", help='Hide the output window')
    parser.add_argument('-sp', '--save_path', type=str, default='',
                        help='Path to save the output. If empty, the output is not saved')
    args = parser.parse_args()

    detection_model = load_model(args.model)
    category_index = label_map_util.create_category_index_from_labelmap(
        args.labelmap, use_display_name=True)

    if args.video_path != '':
        cap = cv2.VideoCapture(args.video_path)
    else:
        cap = cv2.VideoCapture(0)

    if not cap.isOpened():
        print("Error opening video stream or file")

    run_inference(detection_model, category_index, cap, labels=args.labels, threshold=args.threshold,
                  roi_position=args.roi_position, x_axis=args.axis, skip_frames=args.skip_frames, save_path=args.save_path, show=args.show)
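
For the per-label output we are after, this is roughly the direction I was thinking of: keep one cumulative count per class name instead of the single left/right/up/down list. The sketch below is only an idea and is independent of the tracker plumbing; it assumes the class name gets stored on each tracked object (for example a hypothetical label attribute on TrackableObject, filled from category_index when the dlib tracker is created), and that wiring is exactly the part I can't figure out:

class LabelCounter:
    """One cumulative count per class name."""

    def __init__(self, labels):
        # Pre-seed so classes that never cross the line still print as 0.
        self.counts = {label: 0 for label in labels}

    def count(self, label):
        self.counts[label] = self.counts.get(label, 0) + 1

    def summary(self):
        return '; '.join(f'{label}: {n}' for label, n in self.counts.items())


# The idea would be to replace the `counter[...] += 1` lines in
# run_inference with `label_counter.count(to.label)` (where `to.label`
# is the assumed attribute) and to draw `label_counter.summary()` with
# cv2.putText instead of the Left/Right string.
label_counter = LabelCounter(['apple', 'banana', 'carrot'])
label_counter.count('banana')
print(label_counter.summary())  # apple: 0; banana: 1; carrot: 0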

