PyG Temporal: Creating datasets from small synthetic data

General

In contrast with the examples cited in the pytorch geometric temporal docs, where large time series are used, I am trying to train a model on a large number of static networks with a small lag (generally < 10). So far I've succeeded in training a model, but the way I've done it seems hack-ish, I run into memory issues, and I've also been unable to train on a GPU.

My data consists of a set of (super simple, majority-rule for now) dynamics, run on a randomly generated Erdős–Rényi network for a set number of "timesteps" and stored in an array of shape (num_nodes, lag). I then transform each set of data into a PyG-compatible signal and pass it through the model.
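For context, here is a minimal numpy-only sketch of the kind of data I mean: binary majority-rule dynamics on a random Erdős–Rényi graph, producing one (num_nodes, lag) array. The names and parameters here are illustrative, not my exact code:

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, lag, p = 50, 5, 0.1

# Erdős–Rényi adjacency matrix (symmetric, no self-loops)
A = rng.random((num_nodes, num_nodes)) < p
A = np.triu(A, 1)
A = A | A.T

# Random binary initial state
x = rng.integers(0, 2, size=num_nodes)

# Majority-rule dynamics: a node adopts the majority state of its
# neighbours, keeping its own state on ties (or when isolated)
data = np.empty((num_nodes, lag), dtype=int)
data[:, 0] = x
deg = A.sum(axis=1)
for t in range(1, lag):
    ones = A @ x  # number of neighbours currently in state 1
    x = np.where(ones * 2 > deg, 1, np.where(ones * 2 < deg, 0, x))
    data[:, t] = x

print(data.shape)  # (50, 5)
```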

Implementation

My main entrypoint for the program is the Experiment class.

Data generation & transforms

Here Dataset is just a dataclass which holds both the node-feature data and graph-level info (the actual graph), used later when generating signals.

Data transformation is handled by the SignalTransform class, which returns a list of pytorch geometric temporal's StaticGraphTemporalSignal objects (class here).
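A StaticGraphTemporalSignal is built from a static edge_index/edge_weight pair plus per-snapshot lists of feature and target arrays. As a rough numpy-only sketch (the one-step-ahead windowing here is illustrative, not my exact SignalTransform), slicing one (nodes, lag) sample into those lists looks like:

```python
import numpy as np

rng = np.random.default_rng(1)
num_nodes, lag = 6, 5
sample = rng.integers(0, 2, size=(num_nodes, lag))  # one dynamics run

# One snapshot per lagstep: features are the state at time t
# (as a node-feature column), the target is the state at time t + 1.
features = [sample[:, t].reshape(-1, 1).astype(np.float32) for t in range(lag - 1)]
targets = [sample[:, t + 1].astype(np.float32) for t in range(lag - 1)]

# These lists, together with a static edge_index / edge_weight, are what
# StaticGraphTemporalSignal(edge_index, edge_weight, features, targets)
# expects -- one array per snapshot.
print(len(features), features[0].shape, targets[0].shape)
```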

class Experiment:
    [...]
    def generate_data(self) -> None:
        """Populates the `self.dataset` attr with `Dataset` objects
        containing the dynamics run on each network. Datasets have shape
        (samples, nodes, lagsteps) and also contain network info.
        """
        self.setup()
        for nw in range(self.cfg.exp.num_networks):
            dataset = Dataset(self.networks[nw], self.cfg)
            self.set_dynamics(self.networks[nw])
            for sample in range(self.cfg.exp.num_samples):
                x = self.dynamics.initial_state()
                dataset.data[sample, :, 0] = x
                for lagstep in range(1, self.cfg.exp.lag):
                    if lagstep % self.cfg.exp.lagstep == 0:
                        dataset.data[sample, :, lagstep] = x = self.dynamics.step(x)
            self.dataset.append(dataset)

    def generate_signals(self) -> None:
        self.signals = dict()
        for nw_idx in range(self.cfg.exp.num_networks):
            self.signals[nw_idx] = SignalTransform(self.dataset, nw_idx).get_signals()

    [...]

Model

I've pretty much used the model and methodology described in the docs and this notebook for training. The file containing all my training methodology can be found here.

In short, I've created a Temporal GNN with (hopefully) valid parameters for my case, and I run the actual training and evaluation of the model in the TestModel class:

class TestModel:
    [...]
    def train(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01)
        self.model.train()
        for epoch in range(10):
            loss = 0
            step = 0
            for signal in self.train_set:
                for snapshot in signal:
                    snapshot = snapshot.to(self.device)
                    y_hat = self.model(snapshot.x, snapshot.edge_index)
                    loss = loss + torch.mean((y_hat - snapshot.y) ** 2)
                    step += 1
            loss = loss / step  # average over all snapshots seen this epoch
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print("Epoch {} train MSE: {:.4f}".format(epoch, loss.item()))
    [...]

Problems and End-goals

I would like to have some sort of batching, so that I could train on larger networks / sets of data. The current implementation, on a set of 1000 networks with 500 nodes and 5 lagsteps each, takes more than 10 GB of RAM. I'm quite unsure how to proceed (or even whether that's possible; I am quite new to ML, to be honest). The only way I can think of is actually "unpacking" the snapshots as pytorch_geometric.Data objects and creating datasets and batches that way. I would also like to be able to train using CUDA.
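One thing I suspect contributes to the memory usage: accumulating the loss over every snapshot before calling backward() keeps the entire epoch's autograd graph alive. A toy torch-only sketch (with a stand-in nn.Linear instead of my actual GNN, and fake tensors for the snapshots) of stepping the optimizer per snapshot instead, so each graph is freed immediately:

```python
import torch

torch.manual_seed(0)

# Stand-in for the temporal GNN: any torch module works for this sketch
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Fake "snapshots": (features, target) pairs
snapshots = [(torch.randn(10, 4), torch.randn(10, 1)) for _ in range(20)]

model.train()
for epoch in range(3):
    epoch_loss = 0.0
    for x, y in snapshots:
        optimizer.zero_grad()
        y_hat = model(x)
        loss = torch.mean((y_hat - y) ** 2)
        loss.backward()            # frees the autograd graph for this snapshot
        optimizer.step()
        epoch_loss += loss.item()  # .item() detaches, so nothing accumulates
    print(f"Epoch {epoch} train MSE: {epoch_loss / len(snapshots):.4f}")
```

Whether per-snapshot updates are appropriate for the dynamics I'm modelling is a separate question, but memory-wise it avoids holding every snapshot's graph at once.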

My end goal would be to do graph-level classification (i.e. whether the input network reaches a "stable"/consensus state), but I think that won't be hard once the above issues are resolved.

Any help would be greatly appreciated - thanks :)

P.S. Sorry for the number of external links, but I tried to keep this question as short as possible. Not sure if that's ok.



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
