Faster way to do multiple embeddings in PyTorch?
I'm working on a torch-based library for building autoencoders with tabular datasets.
One big feature is learning embeddings for categorical features.
In practice, however, training many embedding layers simultaneously slows training down. I build and apply the layers with for-loops, and I suspect that running the loop on every forward pass is what causes the slowdown.
When building the model, I associate embedding layers with each categorical feature in the user's dataset:
    for ft in self.categorical_fts:
        feature = self.categorical_fts[ft]
        n_cats = len(feature['cats']) + 1
        embed_dim = compute_embedding_size(n_cats)
        embed_layer = torch.nn.Embedding(n_cats, embed_dim)
        feature['embedding'] = embed_layer
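One thing worth checking in the snippet above: if self.categorical_fts is a plain Python dict, the nn.Embedding layers stored in it are not registered as submodules, so model.parameters() (and therefore the optimizer) and .to(device) will not see them. Here is a minimal sketch, assuming the model is a torch.nn.Module subclass, that registers the layers with torch.nn.ModuleDict instead. The class name CategoricalEmbeddings and the body of compute_embedding_size are hypothetical stand-ins, not the library's actual code.

```python
import torch
from torch import nn


def compute_embedding_size(n_cats: int) -> int:
    # Hypothetical heuristic for illustration only: about half the
    # cardinality, capped at 50.
    return min(50, (n_cats + 1) // 2)


class CategoricalEmbeddings(nn.Module):
    def __init__(self, categorical_fts: dict):
        super().__init__()
        # ModuleDict registers every embedding as a submodule, so its weights
        # show up in self.parameters() and move with .to(device).
        self.embeddings = nn.ModuleDict()
        for ft, feature in categorical_fts.items():
            n_cats = len(feature['cats']) + 1  # +1 as in the original snippet
            embed_dim = compute_embedding_size(n_cats)
            self.embeddings[ft] = nn.Embedding(n_cats, embed_dim)


fts = {'color': {'cats': ['red', 'green', 'blue']},
       'size': {'cats': ['S', 'M', 'L', 'XL']}}
module = CategoricalEmbeddings(fts)
n_params = sum(p.numel() for p in module.parameters())
print(n_params)  # 23 = 4*2 (color) + 5*3 (size)
```

ModuleDict keys must be strings without dots, so column names may need sanitizing before being used as keys.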
Then, in the forward pass (.forward()):
    embeddings = []
    for i, ft in enumerate(self.categorical_fts):
        feature = self.categorical_fts[ft]
        emb = feature['embedding'](codes[i])
        embeddings.append(emb)

    # num and bin are lists of numeric and binary feature tensors
    x = torch.cat(num + bin + embeddings, dim=1)
x is then passed through the dense layers.
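For reference, here is a self-contained sketch of the loop-based model described above. It assumes the num and bin lists are folded into a single numeric tensor and that codes is laid out as (batch, n_categorical) columns rather than indexed as codes[i]. The class name, layer sizes, and single hidden layer are hypothetical. It can also serve as a baseline to time faster variants against.

```python
import torch
from torch import nn


class LoopTabularModel(nn.Module):
    """Hypothetical minimal model: one nn.Embedding per categorical column,
    looked up in a Python loop, concatenated with the numeric features and
    passed to a dense layer."""

    def __init__(self, cardinalities, embed_dims, n_numeric, hidden=64):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [nn.Embedding(n, d) for n, d in zip(cardinalities, embed_dims)]
        )
        # Width of the concatenated vector fed to the dense layers.
        in_features = n_numeric + sum(embed_dims)
        self.dense = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())

    def forward(self, codes, numeric):
        # codes: (batch, n_categorical) int64, one column per feature
        # numeric: (batch, n_numeric) float, numeric and binary features combined
        embeddings = []
        for i, emb in enumerate(self.embeddings):
            embeddings.append(emb(codes[:, i]))  # one lookup per column
        x = torch.cat([numeric] + embeddings, dim=1)
        return self.dense(x)


torch.manual_seed(0)
cards, dims = [4, 10, 7], [2, 5, 4]
model = LoopTabularModel(cards, dims, n_numeric=3)
codes = torch.stack([torch.randint(n, (8,)) for n in cards], dim=1)
numeric = torch.randn(8, 3)
out = model(codes, numeric)
print(out.shape)  # torch.Size([8, 64])
```

Each loop iteration performs its own small embedding lookup (on a GPU, at least one separate kernel launch), so the per-step overhead grows roughly linearly with the number of categorical columns.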
This gets the job done, but running this for-loop on every forward pass really slows down training, especially when a dataset has tens or hundreds of categorical columns.
Does anybody know of a way of vectorizing something like this? Thanks!
UPDATE: To clarify, I made this sketch of how I'm feeding categorical features into the network. You can see that each categorical column has its own embedding matrix, while numeric features are concatenated directly to the embedding outputs before being passed into the feed-forward network.
Can we do this without iterating through each embedding matrix?
Just use simple indexing, though I'm not sure whether it is fast enough.
Here is a simplified version in which all features share the same vocab_size and embedding dimension, but it should extend to heterogeneous categorical features.
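    import numpy as np
    import torch

    xdim = 240        # number of categorical columns
    embed_dim = 8
    vocab_size = 64
    # one (vocab_size, embed_dim) table per column; wrap in torch.nn.Parameter to train it
    embedding_table = torch.randn(size=(xdim, vocab_size, embed_dim))

    batch_size = 32
    x = torch.randint(vocab_size, size=(batch_size, xdim))
    # arange(xdim) broadcasts against x, so column j of x indexes table j
    out = embedding_table[torch.arange(xdim), x]
    out.shape  # (batch_size, xdim, embed_dim)

    # unit test: compare one entry against a direct lookup
    i = np.random.randint(batch_size)
    j = np.random.randint(xdim)
    x_index = x[i][j]
    w = embedding_table[j]
    torch.allclose(w[x_index], out[i, j])  # True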
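One way to carry this over to columns with different vocabulary sizes is to stack every column's rows into a single nn.Embedding and shift each column's codes by an offset, so the whole batch is embedded in one lookup. This is a swapped-in technique, not the answer's code. The sketch below assumes all columns share one embedding dimension, and the class name OffsetEmbedding is hypothetical.

```python
import torch
from torch import nn


class OffsetEmbedding(nn.Module):
    """All categorical columns share one nn.Embedding; each column's codes
    are shifted by an offset so they index a disjoint block of rows.
    Assumes every column uses the same embedding dimension."""

    def __init__(self, cardinalities, embed_dim):
        super().__init__()
        cards = torch.tensor(cardinalities)
        # offsets[j] = total cardinality of the columns before column j
        offsets = torch.cat([torch.zeros(1, dtype=torch.long), cards.cumsum(0)[:-1]])
        self.register_buffer("offsets", offsets)  # moves with .to(device)
        self.table = nn.Embedding(int(cards.sum()), embed_dim)

    def forward(self, codes):
        # codes: (batch, n_cols) int64 -> (batch, n_cols, embed_dim) in one lookup
        return self.table(codes + self.offsets)


torch.manual_seed(0)
cards = [4, 10, 7]
emb = OffsetEmbedding(cards, embed_dim=8)
codes = torch.stack([torch.randint(n, (32,)) for n in cards], dim=1)
out = emb(codes)    # (32, 3, 8)
x = out.flatten(1)  # (32, 24), ready to concatenate with numeric features

# Check: column j reads only from its own block of rows.
with torch.no_grad():
    for j, n in enumerate(cards):
        start = int(emb.offsets[j])
        block = emb.table.weight[start:start + n]
        assert torch.equal(out[:, j], block[codes[:, j]])
```

When columns need different embedding sizes, one option is to group columns by embedding dimension and keep one such table per group, which turns a loop over hundreds of columns into a loop over a handful of groups. Whether this is actually faster depends on hardware and batch size, so it is worth timing against the loop-based baseline.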