Unable to load PyTorch neural network in FastAPI app

I trained and saved a PyTorch neural network in a Jupyter notebook with this code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiClass(nn.Module):
    def __init__(self, num_features):
        super(MultiClass, self).__init__()
        self.layer_1 = nn.Linear(num_features, 128)
        self.layer_out = nn.Linear(128, 5)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        x = F.dropout(F.relu(self.layer_1(x)), training=self.training)      
        x = self.layer_out(x)      
        return self.softmax(x)  

# def train(...)
# def test(...)

device = get_device()
model = MultiClass(7)
model.to(device) 
N_EPOCHS = 200
BATCH_SIZE = 32
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# train the model

torch.save(model, 'model.pt')

I could load the saved model in the Jupyter notebook to make predictions like this:

model = torch.load('model.pt')

I then built a Docker image with this Dockerfile:

FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7
COPY requirements.txt .
RUN pip3 install -r requirements.txt
RUN pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
COPY ./app /app
COPY ./models /models
CMD ["gunicorn", "-k", "uvicorn.workers.UvicornWorker", "-c", "/gunicorn_conf.py", "main:app"]

But when I tried to run that image with docker run, I got this error:

[2022-03-20 14:23:04 +0000] [1] [INFO] Starting gunicorn 20.1.0
[2022-03-20 14:23:04 +0000] [1] [INFO] Listening at: http://0.0.0.0:80 (1)
[2022-03-20 14:23:04 +0000] [1] [INFO] Using worker: uvicorn.workers.UvicornWorker
[2022-03-20 14:23:04 +0000] [9] [INFO] Booting worker with pid: 9
[2022-03-20 14:23:04 +0000] [10] [INFO] Booting worker with pid: 10
[2022-03-20 14:23:08 +0000] [9] [ERROR] Exception in worker process
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/gunicorn/arbiter.py", line 589, in spawn_worker
    worker.init_process()
  File "/usr/local/lib/python3.7/site-packages/uvicorn/workers.py", line 63, in init_process
    super(UvicornWorker, self).init_process()
  File "/usr/local/lib/python3.7/site-packages/gunicorn/workers/base.py", line 134, in init_process
    self.load_wsgi()
  File "/usr/local/lib/python3.7/site-packages/gunicorn/workers/base.py", line 146, in load_wsgi
    self.wsgi = self.app.wsgi()
  File "/usr/local/lib/python3.7/site-packages/gunicorn/app/base.py", line 67, in wsgi
    self.callable = self.load()
  File "/usr/local/lib/python3.7/site-packages/gunicorn/app/wsgiapp.py", line 58, in load
    return self.load_wsgiapp()
  File "/usr/local/lib/python3.7/site-packages/gunicorn/app/wsgiapp.py", line 48, in load_wsgiapp
    return util.import_app(self.app_uri)
  File "/usr/local/lib/python3.7/site-packages/gunicorn/util.py", line 359, in import_app
    mod = importlib.import_module(module)
  File "/usr/local/lib/python3.7/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1006, in _gcd_import
  File "<frozen importlib._bootstrap>", line 983, in _find_and_load
  File "<frozen importlib._bootstrap>", line 967, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 677, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/app/main.py", line 206, in <module>
    model = torch.load('../models/model.pt')
  File "/usr/local/lib/python3.7/site-packages/torch/serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/usr/local/lib/python3.7/site-packages/torch/serialization.py", line 882, in _load
    result = unpickler.load()
  File "/usr/local/lib/python3.7/site-packages/torch/serialization.py", line 875, in find_class
    return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'MultiClass' on <module '__main__' from '/usr/local/bin/gunicorn'>

I don't know what the message AttributeError: Can't get attribute 'MultiClass' on <module '__main__' from '/usr/local/bin/gunicorn'> means. I thought it couldn't find MultiClass (even though I had imported it with from src.torch import MultiClass), so I put the MultiClass code inside main.py, where the app runs from, but I still got the same AttributeError.
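
For context, the traceback ends in unpickler.find_class, so the failure can be reproduced with plain pickle. Pickling an object stores the name of the module that defined its class plus the class name, not the class body, and loading looks that name up again. The following is a minimal sketch of that behaviour, not the actual app: the module name fake_notebook_main and the stripped-down MultiClass (no torch) are hypothetical stand-ins for the notebook's __main__ and the real model.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the notebook's __main__ module.
notebook = types.ModuleType("fake_notebook_main")
sys.modules[notebook.__name__] = notebook

# Simplified stand-in for the real model class (assumption: no torch needed here).
CLASS_SOURCE = """
class MultiClass:
    def __init__(self, num_features):
        self.num_features = num_features
"""
exec(CLASS_SOURCE, notebook.__dict__)

# The payload records the defining module name and the class name, not the class body.
payload = pickle.dumps(notebook.MultiClass(7))
assert b"fake_notebook_main" in payload and b"MultiClass" in payload


def try_load(data):
    """Return the unpickled object, or the AttributeError raised while loading."""
    try:
        return pickle.loads(data)
    except AttributeError as exc:
        return exc


# Simulate the server process: the module exists but has no MultiClass attribute.
del notebook.MultiClass
failed = try_load(payload)

# Define the class again under the same module name, and loading succeeds.
exec(CLASS_SOURCE, notebook.__dict__)
restored = try_load(payload)

print(type(failed).__name__, restored.num_features)
```

If this reading is right, it would also explain why moving the class into main.py made no difference: the model was saved from the notebook, where the class lives in __main__, but under gunicorn __main__ is /usr/local/bin/gunicorn (as the message shows) and main.py is imported as main.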

How can I resolve this?

Thanks



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow