How to deploy an image classifier with a ResNet50 model on an AWS endpoint and predict without the worker dying?

I created an image classifier built on ResNet50 to identify dog breeds, using SageMaker Studio. Tuning and training are done and I deployed the model, but when I try to predict on the endpoint, it fails. I believe this is related to the worker's PID, because that is the first warning I see. The CloudWatch log output below says the worker PID is not available yet, and soon after that the worker dies.

timestamp,message,logStreamName
1648240674535,"2022-03-25 20:37:54,107 [INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager - Initializing plugins manager...",AllTraffic/i-055c5d00e53e84b93
1648240674535,"2022-03-25 20:37:54,188 [INFO ] main org.pytorch.serve.ModelServer - ",AllTraffic/i-055c5d00e53e84b93
1648240674535,Torchserve version: 0.4.0,AllTraffic/i-055c5d00e53e84b93
1648240674535,TS Home: /opt/conda/lib/python3.6/site-packages,AllTraffic/i-055c5d00e53e84b93
1648240674535,Current directory: /,AllTraffic/i-055c5d00e53e84b93
1648240674535,Temp directory: /home/model-server/tmp,AllTraffic/i-055c5d00e53e84b93
1648240674535,Number of GPUs: 0,AllTraffic/i-055c5d00e53e84b93
1648240674535,Number of CPUs: 1,AllTraffic/i-055c5d00e53e84b93
1648240674535,Max heap size: 6838 M,AllTraffic/i-055c5d00e53e84b93
1648240674535,Python executable: /opt/conda/bin/python3.6,AllTraffic/i-055c5d00e53e84b93
1648240674535,Config file: /etc/sagemaker-ts.properties,AllTraffic/i-055c5d00e53e84b93
1648240674535,Inference address: http://0.0.0.0:8080,AllTraffic/i-055c5d00e53e84b93
1648240674535,Management address: http://0.0.0.0:8080,AllTraffic/i-055c5d00e53e84b93
1648240674535,Metrics address: http://127.0.0.1:8082,AllTraffic/i-055c5d00e53e84b93
1648240674535,Model Store: /.sagemaker/ts/models,AllTraffic/i-055c5d00e53e84b93
1648240674535,Initial Models: model.mar,AllTraffic/i-055c5d00e53e84b93
1648240674535,Log dir: /logs,AllTraffic/i-055c5d00e53e84b93
1648240674535,Metrics dir: /logs,AllTraffic/i-055c5d00e53e84b93
1648240674535,Netty threads: 0,AllTraffic/i-055c5d00e53e84b93
1648240674535,Netty client threads: 0,AllTraffic/i-055c5d00e53e84b93
1648240674535,Default workers per model: 1,AllTraffic/i-055c5d00e53e84b93
1648240674535,Blacklist Regex: N/A,AllTraffic/i-055c5d00e53e84b93
1648240674535,Maximum Response Size: 6553500,AllTraffic/i-055c5d00e53e84b93
1648240674536,Maximum Request Size: 6553500,AllTraffic/i-055c5d00e53e84b93
1648240674536,Prefer direct buffer: false,AllTraffic/i-055c5d00e53e84b93
1648240674536,Allowed Urls: [file://.*|http(s)?://.*],AllTraffic/i-055c5d00e53e84b93
1648240674536,Custom python dependency for model allowed: false,AllTraffic/i-055c5d00e53e84b93
1648240674536,Metrics report format: prometheus,AllTraffic/i-055c5d00e53e84b93
1648240674536,Enable metrics API: true,AllTraffic/i-055c5d00e53e84b93
1648240674536,Workflow Store: /.sagemaker/ts/models,AllTraffic/i-055c5d00e53e84b93
1648240674536,"2022-03-25 20:37:54,195 [INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager -  Loading snapshot serializer plugin...",AllTraffic/i-055c5d00e53e84b93
1648240675536,"2022-03-25 20:37:54,217 [INFO ] main org.pytorch.serve.ModelServer - Loading initial models: model.mar",AllTraffic/i-055c5d00e53e84b93
1648240675536,"2022-03-25 20:37:55,505 [INFO ] main org.pytorch.serve.wlm.ModelManager - Model model loaded.",AllTraffic/i-055c5d00e53e84b93
1648240675786,"2022-03-25 20:37:55,515 [INFO ] main org.pytorch.serve.ModelServer - Initialize Inference server with: EpollServerSocketChannel.",AllTraffic/i-055c5d00e53e84b93
1648240675786,"2022-03-25 20:37:55,569 [INFO ] main org.pytorch.serve.ModelServer - Inference API bind to: http://0.0.0.0:8080",AllTraffic/i-055c5d00e53e84b93
1648240675786,"2022-03-25 20:37:55,569 [INFO ] main org.pytorch.serve.ModelServer - Initialize Metrics server with: EpollServerSocketChannel.",AllTraffic/i-055c5d00e53e84b93
1648240675786,"2022-03-25 20:37:55,569 [INFO ] main org.pytorch.serve.ModelServer - Metrics API bind to: http://127.0.0.1:8082",AllTraffic/i-055c5d00e53e84b93
1648240675786,Model server started.,AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,727 [WARN ] pool-2-thread-1 org.pytorch.serve.metrics.MetricCollector - worker pid is not available yet.",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,812 [INFO ] pool-2-thread-1 TS_METRICS - CPUUtilization.Percent:100.0|#Level:Host|#hostname:container-0.local,timestamp:1648240675",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,813 [INFO ] pool-2-thread-1 TS_METRICS - DiskAvailable.Gigabytes:38.02598190307617|#Level:Host|#hostname:container-0.local,timestamp:1648240675",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,813 [INFO ] pool-2-thread-1 TS_METRICS - DiskUsage.Gigabytes:12.715518951416016|#Level:Host|#hostname:container-0.local,timestamp:1648240675",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,814 [INFO ] pool-2-thread-1 TS_METRICS - DiskUtilization.Percent:25.1|#Level:Host|#hostname:container-0.local,timestamp:1648240675",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,815 [INFO ] pool-2-thread-1 TS_METRICS - MemoryAvailable.Megabytes:29583.98046875|#Level:Host|#hostname:container-0.local,timestamp:1648240675",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,815 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUsed.Megabytes:1355.765625|#Level:Host|#hostname:container-0.local,timestamp:1648240675",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,816 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUtilization.Percent:5.7|#Level:Host|#hostname:container-0.local,timestamp:1648240675",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,994 [INFO ] W-9000-model_1-stdout MODEL_LOG - Listening on port: /home/model-server/tmp/.ts.sock.9000",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,994 [INFO ] W-9000-model_1-stdout MODEL_LOG - [PID]48",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,994 [INFO ] W-9000-model_1-stdout MODEL_LOG - Torch worker started.",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,994 [INFO ] W-9000-model_1-stdout MODEL_LOG - Python runtime: 3.6.13",AllTraffic/i-055c5d00e53e84b93
1648240676036,"2022-03-25 20:37:55,999 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Connecting to: /home/model-server/tmp/.ts.sock.9000",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,006 [INFO ] W-9000-model_1-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9000.",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,111 [INFO ] W-9000-model_1-stdout MODEL_LOG - Backend worker process died.",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,111 [INFO ] W-9000-model_1-stdout MODEL_LOG - Traceback (most recent call last):",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,111 [INFO ] W-9000-model_1-stdout MODEL_LOG -   File ""/opt/conda/lib/python3.6/site-packages/ts/model_service_worker.py"", line 182, in <module>",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,111 [INFO ] W-9000-model_1-stdout MODEL_LOG -     worker.run_server()",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,111 [INFO ] W-9000-model_1-stdout MODEL_LOG -   File ""/opt/conda/lib/python3.6/site-packages/ts/model_service_worker.py"", line 154, in run_server",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,111 [INFO ] W-9000-model_1-stdout MODEL_LOG -     self.handle_connection(cl_socket)",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,112 [INFO ] W-9000-model_1-stdout MODEL_LOG -   File ""/opt/conda/lib/python3.6/site-packages/ts/model_service_worker.py"", line 116, in handle_connection",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,112 [INFO ] W-9000-model_1-stdout MODEL_LOG -     service, result, code = self.load_model(msg)",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,112 [INFO ] epollEventLoopGroup-5-1 org.pytorch.serve.wlm.WorkerThread - 9000 Worker disconnected. WORKER_STARTED",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,112 [INFO ] W-9000-model_1-stdout MODEL_LOG -   File ""/opt/conda/lib/python3.6/site-packages/ts/model_service_worker.py"", line 89, in load_model",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,112 [INFO ] W-9000-model_1-stdout MODEL_LOG -     service = model_loader.load(model_name, model_dir, handler, gpu, batch_size, envelope)",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,112 [INFO ] W-9000-model_1-stdout MODEL_LOG -   File ""/opt/conda/lib/python3.6/site-packages/ts/model_loader.py"", line 110, in load",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,112 [INFO ] W-9000-model_1-stdout MODEL_LOG -     initialize_fn(service.context)",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,113 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.BatchAggregator - Load model failed: model, error: Worker died.",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,113 [INFO ] W-9000-model_1-stdout MODEL_LOG -   File ""/home/model-server/tmp/models/23b30361031647d08792d32672910688/handler_service.py"", line 51, in initialize",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,113 [INFO ] W-9000-model_1-stdout MODEL_LOG -     super().initialize(context)",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,113 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stderr",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,113 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stdout",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,113 [INFO ] W-9000-model_1-stdout MODEL_LOG -   File ""/opt/conda/lib/python3.6/site-packages/sagemaker_inference/default_handler_service.py"", line 66, in initialize",AllTraffic/i-055c5d00e53e84b93
1648240676286,"2022-03-25 20:37:56,113 [INFO ] W-9000-model_1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stdout",AllTraffic/i-055c5d00e53e84b93
1648240676536,"2022-03-25 20:37:56,114 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Retry worker: 9000 in 1 seconds.",AllTraffic/i-055c5d00e53e84b93
1648240676536,"2022-03-25 20:37:56,416 [INFO ] W-9000-model_1-stderr org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stderr",AllTraffic/i-055c5d00e53e84b93
1648240676536,"2022-03-25 20:37:56,461 [INFO ] W-9000-model_1 ACCESS_LOG - /169.254.178.2:39848 ""GET /ping HTTP/1.1"" 200 9",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:56,461 [INFO ] W-9000-model_1 TS_METRICS - Requests2XX.Count:1|#Level:Host|#hostname:container-0.local,timestamp:null",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,567 [INFO ] W-9000-model_1-stdout MODEL_LOG - Listening on port: /home/model-server/tmp/.ts.sock.9000",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,568 [INFO ] W-9000-model_1-stdout MODEL_LOG - [PID]86",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,568 [INFO ] W-9000-model_1-stdout MODEL_LOG - Torch worker started.",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,568 [INFO ] W-9000-model_1-stdout MODEL_LOG - Python runtime: 3.6.13",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,568 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Connecting to: /home/model-server/tmp/.ts.sock.9000",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,569 [INFO ] W-9000-model_1-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9000.",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,642 [INFO ] W-9000-model_1-stdout MODEL_LOG - Backend worker process died.",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,642 [INFO ] W-9000-model_1-stdout MODEL_LOG - Traceback (most recent call last):",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,642 [INFO ] epollEventLoopGroup-5-2 org.pytorch.serve.wlm.WorkerThread - 9000 Worker disconnected. WORKER_STARTED",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,642 [INFO ] W-9000-model_1-stdout MODEL_LOG -   File ""/opt/conda/lib/python3.6/site-packages/ts/model_service_worker.py"", line 182, in <module>",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,643 [INFO ] W-9000-model_1-stdout MODEL_LOG -     worker.run_server()",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,643 [INFO ] W-9000-model_1-stdout MODEL_LOG -   File ""/opt/conda/lib/python3.6/site-packages/ts/model_service_worker.py"", line 154, in run_server",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,643 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.BatchAggregator - Load model failed: model, error: Worker died.",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,643 [INFO ] W-9000-model_1-stdout MODEL_LOG -     self.handle_connection(cl_socket)",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,643 [INFO ] W-9000-model_1-stdout MODEL_LOG -   File ""/opt/conda/lib/python3.6/site-packages/ts/model_service_worker.py"", line 116, in handle_connection",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,643 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stderr",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,643 [INFO ] W-9000-model_1-stdout MODEL_LOG -     service, result, code = self.load_model(msg)",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,643 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stdout",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,643 [INFO ] W-9000-model_1-stdout MODEL_LOG -   File ""/opt/conda/lib/python3.6/site-packages/ts/model_service_worker.py"", line 89, in load_model",AllTraffic/i-055c5d00e53e84b93
1648240677787,"2022-03-25 20:37:57,643 [INFO ] W-9000-model_1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stdout",AllTraffic/i-055c5d00e53e84b93
1648240678037,"2022-03-25 20:37:57,643 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Retry worker: 9000 in 1 seconds.",AllTraffic/i-055c5d00e53e84b93
1648240679288,"2022-03-25 20:37:57,991 [INFO ] W-9000-model_1-stderr org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stderr",AllTraffic/i-055c5d00e53e84b93
1648240679288,"2022-03-25 20:37:59,096 [INFO ] W-9000-model_1-stdout MODEL_LOG - Listening on port: /home/model-server/tmp/.ts.sock.9000",AllTraffic/i-055c5d00e53e84b93
1648240679288,"2022-03-25 20:37:59,097 [INFO ] W-9000-model_1-stdout MODEL_LOG - [PID]114",AllTraffic/i-055c5d00e53e84b93

Model tuning and training went fine, so I'm not sure why prediction fails. Someone mentioned that it might be due to the entry point script, but I don't know what would cause it to fail at prediction time after deployment if the model can predict fine during training.

Entry point script:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import json

import copy
import argparse
import os
import logging
import sys
from tqdm import tqdm
from PIL import ImageFile
import smdebug.pytorch as smd

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger=logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))

def test(model, test_loader, criterion, hook):
    model.eval()
    running_loss=0
    running_corrects=0
    hook.set_mode(smd.modes.EVAL)
    
    
    for inputs, labels in test_loader:
        outputs=model(inputs)
        loss=criterion(outputs, labels)
        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    ##total_loss = running_loss // len(test_loader)
    ##total_acc = running_corrects.double() // len(test_loader)
    
    ##logger.info(f"Testing Loss: {total_loss}")
    ##logger.info(f"Testing Accuracy: {total_acc}")
    logger.info("New test acc")
    logger.info(f'Test set: Accuracy: {running_corrects}/{len(test_loader.dataset)} = {100*(running_corrects/len(test_loader.dataset))}%)')

def train(model, train_loader, validation_loader, criterion, optimizer, hook):
    epochs=50
    best_loss=1e6
    image_dataset={'train':train_loader, 'valid':validation_loader}
    loss_counter=0
    hook.set_mode(smd.modes.TRAIN)
    
    for epoch in range(epochs):
        logger.info(f"Epoch: {epoch}")
        for phase in ['train', 'valid']:
            if phase=='train':
                model.train()
                logger.info("Model Trained")
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in image_dataset[phase]:
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                if phase=='train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    logger.info("Model Optimized")

                _, preds = torch.max(outputs, 1)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss // len(image_dataset[phase])
            epoch_acc = running_corrects // len(image_dataset[phase])
            
            
            if phase=='valid':
                logger.info("Model Validating")
                if epoch_loss<best_loss:
                    best_loss=epoch_loss
                else:
                    loss_counter+=1

            logger.info(loss_counter)
            '''logger.info('{} loss: {:.4f}, acc: {:.4f}, best loss: {:.4f}'.format(phase,
                                                                                 epoch_loss,
                                                                                 epoch_acc,
                                                                                 best_loss))'''
            
            if phase=="train":
                logger.info("New epoch acc for Train:")
                logger.info(f"Epoch {epoch}: Loss {loss_counter/len(train_loader.dataset)}, Accuracy {100*(running_corrects/len(train_loader.dataset))}%")
            if phase=="valid":
                logger.info("New epoch acc for Valid:")
                logger.info(f"Epoch {epoch}: Loss {loss_counter/len(train_loader.dataset)}, Accuracy {100*(running_corrects/len(train_loader.dataset))}%")
            
        ##if loss_counter==1:
        ##    break
        ##if epoch==0:
        ##    break
    return model
    
def net():
    model = models.resnet50(pretrained=True)

    for param in model.parameters():
        param.requires_grad = False   

    model.fc = nn.Sequential(
                   nn.Linear(2048, 128),
                   nn.ReLU(inplace=True),
                   nn.Linear(128, 133))
    return model

def create_data_loaders(data, batch_size):
    train_data_path = os.path.join(data, 'train')
    test_data_path = os.path.join(data, 'test')
    validation_data_path=os.path.join(data, 'valid')

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        ])

    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        ])

    train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=train_transform)
    train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

    test_data = torchvision.datasets.ImageFolder(root=test_data_path, transform=test_transform)
    test_data_loader  = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)

    validation_data = torchvision.datasets.ImageFolder(root=validation_data_path, transform=test_transform)
    validation_data_loader  = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, shuffle=True) 
    
    return train_data_loader, test_data_loader, validation_data_loader

def main(args):
    logger.info(f'Hyperparameters are LR: {args.lr}, Batch Size: {args.batch_size}')
    logger.info(f'Data Paths: {args.data}')

    
    train_loader, test_loader, validation_loader=create_data_loaders(args.data, args.batch_size)
    model=net()
    
    hook = smd.Hook.create_from_json_file()
    hook.register_hook(model)
    
    criterion = nn.CrossEntropyLoss(ignore_index=133)
    optimizer = optim.Adam(model.fc.parameters(), lr=args.lr)
    
    logger.info("Starting Model Training")
    model=train(model, train_loader, validation_loader, criterion, optimizer, hook)
    
    logger.info("Testing Model")
    test(model, test_loader, criterion, hook)
    
    logger.info("Saving Model")
    torch.save(model.cpu().state_dict(), os.path.join(args.model_dir, "model.pth"))
    
if __name__=='__main__':
    parser=argparse.ArgumentParser()
    '''
    TODO: Specify any training args that you might need
    '''
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=5,
        metavar="N",
        help="number of epochs to train (default: 10)",
    )
    parser.add_argument(
        "--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)"
    )
    parser.add_argument(
        "--momentum", type=float, default=0.5, metavar="M", help="SGD momentum (default: 0.5)"
    )

    # Container environment
    parser.add_argument("--hosts", type=list, default=json.loads(os.environ["SM_HOSTS"]))
    parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"])
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--data", type=str, default=os.environ["SM_CHANNEL_TRAINING"])
    parser.add_argument("--num-gpus", type=int, default=os.environ["SM_NUM_GPUS"])
    args=parser.parse_args()
    
    main(args)

To test the model on the endpoint I sent over an image using the following code:

from sagemaker.serializers import IdentitySerializer
import base64

predictor.serializer = IdentitySerializer("image/png")
with open("Akita_00282.jpg", "rb") as f:
    payload = f.read()

    
response = predictor.predict(payload)


Solution 1:[1]

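The model-serving workers are dying either because they cannot load your model or because they cannot deserialize the payload you send them. The traceback in your log passes through load_model and initialize, which points to the model-loading step rather than to the earlier "worker pid is not available yet" warning.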

Note that you have to provide a model_fn implementation; the entry point script above does not define one. The SageMaker PyTorch documentation and the related AWS blog post explain how to adapt inference scripts for SageMaker deployment. If you do not want to override the input_fn, predict_fn, and/or output_fn handlers, the serving container falls back to their default implementations.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
