Correct axes to use dot product to evaluate the final output of a listwise learning to rank model

I can't find the correct configuration to pass to tf.keras.layers.Dot to compute a pairwise dot product when each entry holds a list of vectors, as in the output of a listwise learning to rank model. For instance, suppose:

repeated_query_vector = [
  [[1, 2], [1, 2]],
  [[3, 4], [3, 4]]
]

document_vectors = [
  [[5, 6], [7, 8]],
  [[9, 10], [11, 12]],
]

When calling tf.keras.layers.Dot(axes=??)([repeated_query_vector, document_vectors]), I want the output to be:

[
  [1*5 + 2*6, 1*7 + 2*8],
  [3*9 + 4*10, 3*11 + 4*12]
]

All examples I found in the documentation have one dimension less than my use case. What would be the correct value of axes for this call?
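
For reference, the target values can be pinned down independently of Dot. The sketch below is not a Dot configuration; it assumes NumPy as a stand-in for the tensor math and computes the row-wise dot product by multiplying matching entries and summing over the last axis. The index letters in the einsum string (b for the query, i for the document slot, k for the embedding dimension) are labels chosen here for illustration. The same two expressions should carry over to TensorFlow as tf.reduce_sum(q * d, axis=-1) or tf.einsum("bik,bik->bi", q, d), which could be wrapped in a Lambda layer if a layer is required.

```python
import numpy as np

repeated_query_vector = np.array([
    [[1, 2], [1, 2]],
    [[3, 4], [3, 4]],
])
document_vectors = np.array([
    [[5, 6], [7, 8]],
    [[9, 10], [11, 12]],
])

# Multiply matching entries, then sum over the last (embedding) axis.
scores = np.sum(repeated_query_vector * document_vectors, axis=-1)

# The same contraction as an einsum: b = query, i = document slot, k = embedding.
scores_einsum = np.einsum("bik,bik->bi", repeated_query_vector, document_vectors)

print(scores)
# [[17 23]
#  [67 81]]
```

As far as I can tell, Dot with a single integer axis pairs every row of the first input with every row of the second, so on these 3-D inputs it would produce a (batch, 2, 2) result rather than the (batch, 2) one above; that is an expectation worth checking against the installed Keras version, not a confirmed answer.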



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
