Unable to load PyTorch neural network in FastAPI app
I trained and saved a PyTorch neural network in a Jupyter notebook with this code:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiClass(nn.Module):
    def __init__(self, num_features):
        super(MultiClass, self).__init__()
        self.layer_1 = nn.Linear(num_features, 128)
        self.layer_out = nn.Linear(128, 5)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = F.dropout(F.relu(self.layer_1(x)), training=self.training)
        x = self.layer_out(x)
        return self.softmax(x)

# def train(...): ...
# def test(...): ...

device = get_device()  # helper defined elsewhere in the notebook (not shown)
model = MultiClass(7)
model.to(device)
N_EPOCHS = 200
BATCH_SIZE = 32
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
# ... train the model ...
torch.save(model, 'model.pt')  # pickles the whole module object
I could load the saved model and make predictions in the Jupyter notebook like this:
model = torch.load('model.pt')
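Loading succeeds here because the notebook session still defines MultiClass in `__main__`. PyTorch's "Saving and Loading Models" guide recommends saving `model.state_dict()` (a dictionary of tensors) instead, then rebuilding the model from importable code and calling `load_state_dict`. A minimal sketch, assuming the app can import the same MultiClass definition (copied here so the block runs on its own) and using an in-memory buffer in place of a model file:

```python
import io

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiClass(nn.Module):
    # Same architecture as in the question.
    def __init__(self, num_features):
        super().__init__()
        self.layer_1 = nn.Linear(num_features, 128)
        self.layer_out = nn.Linear(128, 5)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = F.dropout(F.relu(self.layer_1(x)), training=self.training)
        return self.softmax(self.layer_out(x))


# Notebook side: save only the learned tensors, not the pickled module object.
model = MultiClass(7)
buffer = io.BytesIO()  # stands in for a file such as 'model.pt'
torch.save(model.state_dict(), buffer)

# App side: rebuild the architecture from importable code, then load the weights.
buffer.seek(0)
restored = MultiClass(7)
restored.load_state_dict(torch.load(buffer, map_location="cpu"))
restored.eval()  # turn off dropout for inference

with torch.no_grad():
    probs = restored(torch.randn(2, 7))
print(probs.shape)  # torch.Size([2, 5])
```

In the real setup, the notebook would write something like `torch.save(model.state_dict(), 'model_state.pt')` (file name hypothetical), and main.py would import MultiClass (e.g. from src.torch), construct `MultiClass(7)`, and load that file. Because the saved dictionary holds only tensors, it does not refer to `__main__` at all.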
I then built a Docker image with this Dockerfile:
FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7
COPY requirements.txt .
RUN pip3 install -r requirements.txt
RUN pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
COPY ./app /app
COPY ./models /models
CMD ["gunicorn", "-k", "uvicorn.workers.UvicornWorker", "-c", "/gunicorn_conf.py", "main:app"]
But when I ran a container from that image with docker run, I got this error:
[2022-03-20 14:23:04 +0000] [1] [INFO] Starting gunicorn 20.1.0
[2022-03-20 14:23:04 +0000] [1] [INFO] Listening at: http://0.0.0.0:80 (1)
[2022-03-20 14:23:04 +0000] [1] [INFO] Using worker: uvicorn.workers.UvicornWorker
[2022-03-20 14:23:04 +0000] [9] [INFO] Booting worker with pid: 9
[2022-03-20 14:23:04 +0000] [10] [INFO] Booting worker with pid: 10
[2022-03-20 14:23:08 +0000] [9] [ERROR] Exception in worker process
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/gunicorn/arbiter.py", line 589, in spawn_worker
    worker.init_process()
  File "/usr/local/lib/python3.7/site-packages/uvicorn/workers.py", line 63, in init_process
    super(UvicornWorker, self).init_process()
  File "/usr/local/lib/python3.7/site-packages/gunicorn/workers/base.py", line 134, in init_process
    self.load_wsgi()
  File "/usr/local/lib/python3.7/site-packages/gunicorn/workers/base.py", line 146, in load_wsgi
    self.wsgi = self.app.wsgi()
  File "/usr/local/lib/python3.7/site-packages/gunicorn/app/base.py", line 67, in wsgi
    self.callable = self.load()
  File "/usr/local/lib/python3.7/site-packages/gunicorn/app/wsgiapp.py", line 58, in load
    return self.load_wsgiapp()
  File "/usr/local/lib/python3.7/site-packages/gunicorn/app/wsgiapp.py", line 48, in load_wsgiapp
    return util.import_app(self.app_uri)
  File "/usr/local/lib/python3.7/site-packages/gunicorn/util.py", line 359, in import_app
    mod = importlib.import_module(module)
  File "/usr/local/lib/python3.7/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1006, in _gcd_import
  File "<frozen importlib._bootstrap>", line 983, in _find_and_load
  File "<frozen importlib._bootstrap>", line 967, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 677, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/app/main.py", line 206, in <module>
    model = torch.load('../models/model.pt')
  File "/usr/local/lib/python3.7/site-packages/torch/serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/usr/local/lib/python3.7/site-packages/torch/serialization.py", line 882, in _load
    result = unpickler.load()
  File "/usr/local/lib/python3.7/site-packages/torch/serialization.py", line 875, in find_class
    return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'MultiClass' on <module '__main__' from '/usr/local/bin/gunicorn'>
I don't know what the message `AttributeError: Can't get attribute 'MultiClass' on <module '__main__' from '/usr/local/bin/gunicorn'>` means. I assumed it couldn't find the MultiClass class (even though I had imported it with `from src.torch import MultiClass`), so I put the MultiClass code directly in main.py, where the app runs from, but I still got the same AttributeError.
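The traceback ends in `find_class`, where pickle resolves the class reference that was stored at save time. Because the model was pickled from a notebook, that reference is `__main__.MultiClass`. Under gunicorn, `__main__` is the `/usr/local/bin/gunicorn` launcher, while main.py is imported as the module `main`, so neither importing MultiClass nor defining it in main.py puts it where the unpickler looks. The stdlib sketch below reproduces that lookup and one workaround: exposing the class under the module name the pickle recorded (in the app that would be `sys.modules['__main__']`, set before `torch.load`). The module names `notebook` and `app_models` are hypothetical stand-ins for the notebook's `__main__` and for `src.torch`. The sketch shows the plain pickle mechanism that `torch.load` relies on, not a PyTorch API.

```python
import pickle
import sys
import types


def make_module_with_class(module_name):
    """Register a new module (hypothetical stand-in) that defines MultiClass."""
    module = types.ModuleType(module_name)
    exec(
        "class MultiClass:\n"
        "    def __init__(self, num_features):\n"
        "        self.num_features = num_features\n",
        module.__dict__,
    )
    sys.modules[module_name] = module
    return module


# Save time: MultiClass lives in "notebook" (stand-in for the notebook's __main__).
notebook = make_module_with_class("notebook")
payload = pickle.dumps(notebook.MultiClass(7))
# The pickle stores only the module and class names, not the class body.
print(b"notebook" in payload, b"MultiClass" in payload)  # True True

# Load time: the module exists but has no MultiClass (like gunicorn's __main__).
del notebook.MultiClass
error = None
try:
    pickle.loads(payload)
except AttributeError as exc:
    error = exc
print(error)

# The app defines the class in its own module (stand-in for src.torch).
app_models = make_module_with_class("app_models")

# Workaround: expose the app's class under the name the pickle recorded.
# In the FastAPI app the target would be sys.modules["__main__"], before torch.load.
setattr(notebook, "MultiClass", app_models.MultiClass)
restored = pickle.loads(payload)
print(type(restored).__module__, restored.num_features)  # app_models 7
```

This patching keeps the existing model.pt usable, but it ties loading to the module name used at save time; saving a state_dict instead removes that dependency.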
How could I resolve this?
Thanks
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow