Using tf.image.pyramids during training to create downsampled feature maps
I'm trying to use tf.image.pyramids.downsample from tensorflow_graphics in every Down (encoder) block of an auto-encoder model, so that the resulting pyramid levels can be passed as skip connections to the Up (decoder) blocks.
class DownConv(Model):
    n = 0

    def __init__(self, kernel_size, filters, initializer, n_lower_levels):
        super(DownConv, self).__init__(name=f"DownConv_{DownConv.n}")
        DownConv.n += 1
        self.pad = tf.constant([[0, 0], [kernel_size // 2, kernel_size // 2], [kernel_size // 2, kernel_size // 2], [0, 0]])
        self.conv = L.Conv2D(filters, kernel_size, strides=2, kernel_initializer=initializer)
        self.pyramid = None
        self.filters = filters
        self.n_lower_levels = n_lower_levels

    def call(self, input_t):
        logger.debug(f"Received {input_t.shape} in {self.name}")
        x = tf.pad(input_t, self.pad, "SYMMETRIC")
        x = self.conv(x)
        p = tf.Variable(x)
        self.pyramid = downsample(p, self.n_lower_levels)
        pyramods = ", ".join([str(p.shape) for p in self.pyramid])
        logger.debug(f"Received {input_t.shape} in {self.name}")
        logger.debug(f"Generated pyramids: {pyramods}")
        return tf.nn.selu(x)
However, the logging shows that this doesn't work. Only the very first pyramid level (the first step of the downsample) has a defined channel dimension; the remaining levels have None for channels.
self.pyramid[0].shape yields the correct (None, 256, 256, 64), but self.pyramid[1].shape yields (None, 256, 256, None) during a training step. Note that the batch dimension (axis 0) is correctly None here; an unknown batch size is normal TensorFlow behavior.
Because of this, the training step fails in my Up blocks when it tries to concatenate the two feature maps:
ValueError: The channel dimension of the inputs should be defined. The input_shape received is (None, 32, 32, None), where axis -1 (0-based) is the channel dimension, which found to be `None`.
Call arguments received:
• input_t=tf.Tensor(shape=(None, 32, 32, 256), dtype=float32)
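To make the shape mismatch easier to check, here is a minimal sketch of the static shape each pyramid level would be expected to have. The helper name `expected_pyramid_shapes` is hypothetical. It assumes that level 0 is the unchanged input, that each further level halves height and width with stride-2 "SAME" padding (rounding up), and that batch and channel dimensions are left untouched. These are assumptions about the library's behavior, not something confirmed by the post.

```python
import math


def expected_pyramid_shapes(input_shape, num_levels):
    """Return the static (batch, height, width, channels) shape expected per level.

    Hypothetical helper: assumes level 0 is the input itself and every
    further level halves height and width, rounding up, while batch and
    channels stay the same. Height and width must be known integers.
    """
    batch, height, width, channels = input_shape
    shapes = [(batch, height, width, channels)]
    for _ in range(num_levels):
        height = math.ceil(height / 2)
        width = math.ceil(width / 2)
        shapes.append((batch, height, width, channels))
    return shapes


# Shapes for the (None, 256, 256, 64) feature map from the question, 3 levels.
shapes = expected_pyramid_shapes((None, 256, 256, 64), 3)
for level, shape in enumerate(shapes):
    print(level, shape)
```

If those assumptions hold for the installed version, each level could be passed through tf.ensure_shape with the matching shape before concatenation, which would give the channel dimension a static value again.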
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
