'RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x726 and 1000x1000)

I'm trying to measure the latent space clustering but the error raised.

class AutoEncoder(nn.Module):

    def __init__(self, input_dim1, input_dim2, hidden_dims, agg, sep_decode):
        super(AutoEncoder, self).__init__()

        self.agg = agg
        self.sep_decode = sep_decode

        print("hidden_dims:", hidden_dims)
        self.encoder_layers = []
        self.encoder2_layers = []
        dims = [[input_dim1, input_dim2]] + hidden_dims
        for i in range(len(dims) - 1):
            if i == 0:
                layer = nn.Sequential(nn.Linear(dims[i][0], dims[i+1]), nn.ReLU())
                layer2 = nn.Sequential(nn.Linear(dims[i][1], dims[i+1]), nn.ReLU())
            elif i != 0 and i < len(dims) - 2:
                layer = nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.ReLU())
                layer2 = nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.ReLU())
            else:
                layer = nn.Linear(dims[i], dims[i+1])
                layer2 = nn.Linear(dims[i], dims[i+1])
            self.encoder_layers.append(layer)
            self.encoder2_layers.append(layer2)
        self.encoder = nn.Sequential(*self.encoder_layers)
        self.encoder2 = nn.Sequential(*self.encoder2_layers)

        self.decoder_layers = []
        self.decoder2_layers = []
        hidden_dims.reverse()
        dims = hidden_dims + [[input_dim1, input_dim2]]
        if self.agg == "concat" and not self.sep_decode:
            dims[0] = 2 * dims[0]
        for i in range(len(dims) - 1):
            if i < len(dims) - 2:
                layer = nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.ReLU())
                layer2 = nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.ReLU())
            else:
                layer = nn.Linear(dims[i], dims[i+1][0])
                layer2 = nn.Linear(dims[i], dims[i+1][1])
            self.decoder_layers.append(layer)
            self.decoder2_layers.append(layer2)
        self.decoder = nn.Sequential(*self.decoder_layers)
        self.decoder2 = nn.Sequential(*self.decoder2_layers)

    def forward(self, x1, x2):
        z1 = self.encoder(x1)
        z2 = self.encoder2(x2)

        if self.agg == "max":
            z = torch.max(z1, z2)
        elif self.agg == "multi":
            z = z1 * z2
        elif self.agg == "sum":
            z = z1 + z2
        elif self.agg == "concat":
            z = torch.cat([z1, z2], dim=1)

        if self.sep_decode:
            x_bar1 = self.decoder(z1)
            x_bar1 = F.normalize(x_bar1, dim=-1)
            x_bar2 = self.decoder2(z2)
            x_bar2 = F.normalize(x_bar2, dim=-1)
        else:
            x_bar1 = self.decoder(z)
            x_bar1 = F.normalize(x_bar1, dim=-1)
            x_bar2 = self.decoder2(z)
            x_bar2 = F.normalize(x_bar2, dim=-1)

        return x_bar1, x_bar2, z


class TopicCluster(nn.Module):

    def __init__(self, args):
        super(TopicCluster, self).__init__()
        self.alpha = 1.0
        self.dataset_path = args.dataset_path
        self.args = args
        self.device = args.device
        self.temperature = args.temperature
        self.distribution = args.distribution
        self.agg_method = args.agg_method
        self.sep_decode = (args.sep_decode == 1)

        input_dim1 = args.input_dim1
        input_dim2 = args.input_dim2
        hidden_dims = eval(args.hidden_dims)
        self.model = AutoEncoder(input_dim1, input_dim2, hidden_dims, self.agg_method, self.sep_decode)
        if self.agg_method == "concat":
            self.topic_emb = Parameter(torch.Tensor(args.n_clusters, 2*hidden_dims[-1]))
        else:
            self.topic_emb = Parameter(torch.Tensor(args.n_clusters, hidden_dims[-1]))
        torch.nn.init.xavier_normal_(self.topic_emb.data)

    def pretrain(self, input_data, pretrain_epoch=200):
        pretrained_path = os.path.join(self.dataset_path, f"pretrained_{args.suffix}.pt")
        if os.path.exists(pretrained_path) and self.args.load_pretrain:
            # load pretrain weights
            print(f"loading pretrained model from {pretrained_path}")
            self.model.load_state_dict(torch.load(pretrained_path))
        else:
            train_loader = DataLoader(input_data, batch_size=self.args.batch_size, shuffle=True)
            optimizer = Adam(self.model.parameters(), lr=self.args.lr)
            for epoch in range(pretrain_epoch):
                total_loss = 0
                for batch_idx, (x1, x2, _, weight) in enumerate(train_loader):
                    x1 = x1.to(self.device)
                    x2 = x2.to(self.device)
                    weight = weight.to(self.device)
                    optimizer.zero_grad()
                    x_bar1, x_bar2, z = self.model(x1, x2)
                    loss = cosine_dist(x_bar1, x1) + cosine_dist(x_bar2, x2) #, weight)
                    total_loss += loss.item()
                    loss.backward()
                    optimizer.step()
                print(f"epoch {epoch}: loss = {total_loss / (batch_idx+1):.4f}")
            torch.save(self.model.state_dict(), pretrained_path)
            print(f"model saved to {pretrained_path}")

    def cluster_assign(self, z):
        if self.distribution == 'student':
            p = 1.0 / (1.0 + torch.sum(
                torch.pow(z.unsqueeze(1) - self.topic_emb, 2), 2) / self.alpha)
            p = p.pow((self.alpha + 1.0) / 2.0)
            p = (p.t() / torch.sum(p, 1)).t()
        else:
            self.topic_emb.data = F.normalize(self.topic_emb.data, dim=-1)
            z = F.normalize(z, dim=-1)
            sim = torch.matmul(z, self.topic_emb.t()) / self.temperature
            p = F.softmax(sim, dim=-1)
        return p
    
    def forward(self, x1, x2):
        x_bar1, x_bar2, z = self.model(x1, x2)
        p = self.cluster_assign(z)
        return x_bar1, x_bar2, z, p

    def target_distribution(self, x1, x2, freq, method='all', top_num=0):
        _, _, z = self.model(x1, x2)
        p = self.cluster_assign(z).detach()
        if method == 'all':
            q = p**2 / (p * freq.unsqueeze(-1)).sum(dim=0)
            q = (q.t() / q.sum(dim=1)).t()
        elif method == 'top':
            assert top_num > 0
            q = p.clone()
            sim = torch.matmul(self.topic_emb, z.t())
            _, selected_idx = sim.topk(k=top_num, dim=-1)
            for i, topic_idx in enumerate(selected_idx):
                q[topic_idx] = 0
                q[topic_idx, i] = 1
        return p, q


def cosine_dist(x_bar, x, weight=None):
    if weight is None:
        weight = torch.ones(x.size(0), device=x.device)
    cos_sim = (x_bar * x).sum(-1)
    cos_dist = 1 - cos_sim
    cos_dist = (cos_dist * weight).sum() / weight.sum()
    return cos_dist


def train(args, emb_dict):
    # ipdb.set_trace()
    inv_vocab = {k: " ".join(v) for k, v in emb_dict["inv_vocab"].items()}
    vocab = {" ".join(k):v for k, v in emb_dict["vocab"].items()}
    print(f"Vocab size: {len(vocab)}")
    embs = F.normalize(torch.tensor(emb_dict["vs_emb"]), dim=-1)
    embs2 = F.normalize(torch.tensor(emb_dict["oh_emb"]), dim=-1)
    freq = np.array(emb_dict["tuple_freq"])
    if not args.use_freq:
        freq = np.ones_like(freq)

    input_data = TensorDataset(embs, embs2, torch.arange(embs.size(0)), torch.tensor(freq))
    topic_cluster = TopicCluster(args).to(args.device)
    topic_cluster.pretrain(input_data, args.pretrain_epoch)
    train_loader = DataLoader(input_data, batch_size=args.batch_size, shuffle=False)
    optimizer = Adam(topic_cluster.parameters(), lr=args.lr)

    # topic embedding initialization
    embs = embs.to(args.device)
    embs2 = embs2.to(args.device)
    x_bar1, x_bar2, z = topic_cluster.model(embs, embs2)
    z = F.normalize(z, dim=-1)

    print(f"Running K-Means for initialization")
    kmeans = KMeans(n_clusters=args.n_clusters, n_init=5)
    if args.use_freq:
        y_pred = kmeans.fit_predict(z.data.cpu().numpy(), sample_weight=freq)
    else:
        y_pred = kmeans.fit_predict(z.data.cpu().numpy())
    print(f"Finish K-Means")

    freq = torch.tensor(freq).to(args.device)
    
    y_pred_last = y_pred
    topic_cluster.topic_emb.data = torch.tensor(kmeans.cluster_centers_).to(args.device)

    topic_cluster.train()
    i = 0
    for epoch in range(50):
        if epoch % 5 == 0:
            _, _, z, p = topic_cluster(embs, embs2)
            z = F.normalize(z, dim=-1)
            topic_cluster.topic_emb.data = F.normalize(topic_cluster.topic_emb.data, dim=-1)
            if not os.path.exists(os.path.join(args.dataset_path, f"clusters_{args.suffix}")):
                os.makedirs(os.path.join(args.dataset_path, f"clusters_{args.suffix}"))
            embed_save_path = os.path.join(args.dataset_path, f"clusters_{args.suffix}/embed_{epoch}.pt")
            torch.save({
                "inv_vocab": emb_dict['inv_vocab'],
                "embed": z.detach().cpu().numpy(),
                "topic_embed": topic_cluster.topic_emb.detach().cpu().numpy(),
            }, embed_save_path)
            f = open(os.path.join(args.dataset_path, f"clusters_{args.suffix}/{epoch}.txt"), 'w')
            pred_cluster = p.argmax(-1)
            result_strings = []
            for j in range(args.n_clusters):
                if args.sort_method == 'discriminative':
                    word_idx = torch.arange(embs.size(0))[pred_cluster == j]
                    sorted_idx = torch.argsort(p[pred_cluster == j][:, j], descending=True)
                    word_idx = word_idx[sorted_idx]
                else:
                    sim = torch.matmul(topic_cluster.topic_emb[j], z.t())
                    _, word_idx = sim.topk(k=30, dim=-1)
                word_cluster = []
                freq_sum = 0
                for idx in word_idx:
                    freq_sum += freq[idx].item()
                    if inv_vocab[idx.item()] not in word_cluster:
                        word_cluster.append(inv_vocab[idx.item()])
                        if len(word_cluster) >= 10:
                            break
                result_strings.append((freq_sum, f"Topic {j} ({freq_sum}): " + ', '.join(word_cluster)+'\n'))
            result_strings = sorted(result_strings, key=lambda x: x[0], reverse=True)
            for result_string in result_strings:
                f.write(result_string[1])
                 
        for x1, x2, idx, weight in train_loader:
            
            if i % args.update_interval == 0:
                p, q = topic_cluster.target_distribution(embs, embs2, freq.clone().fill_(1), method='all', top_num=epoch+1)

                y_pred = p.cpu().numpy().argmax(1)
                delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / y_pred.shape[0]
                y_pred_last = y_pred

                if i > 0 and delta_label < args.tol:
                    print(f'delta_label {delta_label:.4f} < tol ({args.tol})')
                    print('Reached tolerance threshold. Stopping training.')
                    return None

            i += 1
            x1 = x1.to(args.device)
            x2 = x2.to(args.device)
            idx = idx.to(args.device)
            weight = weight.to(args.device)

            x_bar1, x_bar2, _, p = topic_cluster(x1, x2)
            reconstr_loss = cosine_dist(x_bar1, x1) + cosine_dist(x_bar2, x2) #, weight)
            kl_loss = F.kl_div(p.log(), q[idx], reduction='none').sum(-1)
            kl_loss = (kl_loss * weight).sum() / weight.sum()
            loss = args.gamma * kl_loss + reconstr_loss
            if i % args.update_interval == 0:
                print(f"KL loss: {kl_loss}; Reconstruction loss: {reconstr_loss}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return None


if __name__ == "__main__":
    # CUDA_VISIBLE_DEVICES=0 python3 latent_space_clustering.py --dataset_path ./pandemic --input_emb_name po_tuple_features_all_svos.pk
    parser = argparse.ArgumentParser(
        description='train',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--dataset_path', type=str)
    parser.add_argument('--input_emb_name', type=str)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--n_clusters', default=30, type=int)
    parser.add_argument('--input_dim1', default=1000, type=int)
    parser.add_argument('--input_dim2', default=1000, type=int)
    parser.add_argument('--agg_method', default="multi", choices=["sum", "multi", "concat", "attend"], type=str)
    parser.add_argument('--sep_decode', default=0, choices=[0, 1], type=int)
    parser.add_argument('--pretrain_epoch', default=100, type=int)
    parser.add_argument('--load_pretrain', default=False, action='store_true')
    parser.add_argument('--temperature', default=0.1, type=float)
    parser.add_argument('--sort_method', default='generative', choices=['generative', 'discriminative'])
    parser.add_argument('--distribution', default='softmax', choices=['softmax', 'student'])
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--use_freq', default=False, action='store_true')
    parser.add_argument('--hidden_dims', default='[1000, 2000, 1000, 100]', type=str)
    parser.add_argument('--suffix', type=str, default='')
    parser.add_argument('--gamma', default=5, type=float, help='weight of clustering loss')
    parser.add_argument('--update_interval', default=100, type=int)
    parser.add_argument('--tol', default=0.001, type=float)
    args = parser.parse_args()
    args.cuda = torch.cuda.is_available()
    print("use cuda: {}".format(args.cuda))
    args.device = torch.device("cuda" if args.cuda else "cpu")
    print(args)
    with open(os.path.join(args.dataset_path, args.input_emb_name), "rb") as fin:
        emb_dict = pk.load(fin)

    candidate_idx = train(args, emb_dict)
    print(candidate_idx)

The error I'm getting is: RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x726 and 1000x1000). I cannot figure out which part is the problem. Please help me.. Thank you so much

for the images runtime error like enter image description here



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source