How to get the output class from a Transformers model?
I am new to Transformers and am learning it from the Hugging Face site. I was testing the "nlpaueb/legal-bert-base-uncased" model, and here is the code I tried:
```python
from transformers import AutoModel, AutoTokenizer

model_name = "nlpaueb/legal-bert-base-uncased"
model_obj = AutoModel.from_pretrained(model_name)
tokenizer_obj = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer_obj("The applicant submitted that her husband was subjected to treatment amounting to whilst in the custody of police.", return_tensors="pt")
outputs = model_obj(**inputs)
```
When I executed this code, I got tensor values like the ones below.
```
BaseModelOutputWithPoolingAndCrossAttentions(
    last_hidden_state=tensor([[[-0.3980, -0.4034, -0.0714,  ..., -0.3217,  0.0475,  0.2944],
         [-0.2333, -0.4582,  0.0377,  ...,  0.0586, -0.1767, -0.0554],
         ...,
         [-0.4008, -0.3992, -0.0599,  ..., -0.3235,  0.0489,  0.3189]]],
        grad_fn=<NativeLayerNormBackward0>),
    pooler_output=tensor([[-4.9280e-01,  3.1985e-01,  9.8638e-01,  ...,
          4.5661e-01, -1.3705e-01, -5.0187e-01]], grad_fn=<TanhBackward0>),
    hidden_states=None, past_key_values=None,
    attentions=None, cross_attentions=None)
```
How can I get the predicted output class from these tensors? Can anyone guide me?
Solution 1:[1]
You can use TFAutoModel instead of the default (PyTorch) AutoModel if you want TensorFlow tensors, but note that switching frameworks does not change the kind of output: both AutoModel and TFAutoModel return raw hidden states (last_hidden_state and pooler_output), because the base checkpoint has no classification head. To get a class prediction, load the checkpoint with a task-specific head such as AutoModelForSequenceClassification, fine-tune it on labeled data, and take the argmax of the returned logits.
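As a minimal sketch of the post-processing step, assuming a model that exposes classification logits (e.g. one loaded with AutoModelForSequenceClassification and fine-tuned; the logit values below are made up purely for illustration), the predicted class is the argmax over the logits:

```python
import torch

# Hypothetical logits of shape (batch_size, num_labels), as would be
# returned in outputs.logits by a model with a classification head.
logits = torch.tensor([[-1.2, 0.3, 2.1]])

probs = torch.softmax(logits, dim=-1)     # probability per class
pred_class = torch.argmax(probs, dim=-1)  # index of the most likely class
print(pred_class.item())  # -> 2
```

With a real fine-tuned checkpoint you would then map the predicted index to a human-readable label via model.config.id2label.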
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | karteek menda |
