How to gather the logits according to a 2D index?
Assume the sequence [[2, 1, 4], [3, 4, 2]] is generated by a pre-trained LSTM. Its shape is (2, 3), meaning batch size = 2 and 3 time steps in each sample.
Then, for example, suppose there are 5 features in total, so the logits, with shape (2, 3, 5), may be:
[[[0.2, 0.2, 0.1, 0.1, 0.2],
[0.3, 0.2, 0.1, 0.1, 0.1],
[0.1, 0.2, 0.1, 0.1, 0.3]],
[[0.2, 0.2, 0.1, 0.1, 0.2],
[0.2, 0.2, 0.1, 0.1, 0.2],
[0.2, 0.2, 0.1, 0.1, 0.2]]]
I want to use the sequence as an index into the logits to get the corresponding probability for each sample and each time step. For the example above, the final result I want is:
[[0.1, 0.2, 0.3],[0.1, 0.2, 0.1]]
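For reference, the intended lookup can be checked with a minimal NumPy sketch (NumPy is an assumption used only for illustration; the question itself is about TensorFlow). np.take_along_axis picks, for every (sample, time step) pair, the entry of the last axis named by the sequence:

```python
import numpy as np

# Index sequence from the question: shape (batch=2, time=3)
seq = np.array([[2, 1, 4], [3, 4, 2]])

# Logits from the question: shape (batch=2, time=3, features=5)
logits = np.array([
    [[0.2, 0.2, 0.1, 0.1, 0.2],
     [0.3, 0.2, 0.1, 0.1, 0.1],
     [0.1, 0.2, 0.1, 0.1, 0.3]],
    [[0.2, 0.2, 0.1, 0.1, 0.2],
     [0.2, 0.2, 0.1, 0.1, 0.2],
     [0.2, 0.2, 0.1, 0.1, 0.2]],
])

# Give seq a trailing axis -> (2, 3, 1), pick one feature per position,
# then drop the trailing axis again -> (2, 3)
result = np.take_along_axis(logits, seq[..., None], axis=-1)[..., 0]
print(result)
```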
I think I probably need tf.stack(), but I'm confused about how to handle the dimensions. Any help is appreciated!
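One way to use tf.stack here, sketched under the assumption of TensorFlow 2.x with eager execution (variable names are illustrative): build a [sample, time step, feature] index triple for every position, stack the three coordinate tensors along a new last axis, and pass the result to tf.gather_nd. tf.gather with batch_dims=2 should give the same values in a single call.

```python
import tensorflow as tf

# Index sequence: shape (batch, time); int32 by default
seq = tf.constant([[2, 1, 4], [3, 4, 2]])

# Logits: shape (batch, time, features); float32 by default
logits = tf.constant([
    [[0.2, 0.2, 0.1, 0.1, 0.2],
     [0.3, 0.2, 0.1, 0.1, 0.1],
     [0.1, 0.2, 0.1, 0.1, 0.3]],
    [[0.2, 0.2, 0.1, 0.1, 0.2],
     [0.2, 0.2, 0.1, 0.1, 0.2],
     [0.2, 0.2, 0.1, 0.1, 0.2]],
])

batch_size = tf.shape(seq)[0]
num_steps = tf.shape(seq)[1]

# Sample and time-step coordinates for every position, each of shape (batch, time)
b_idx, t_idx = tf.meshgrid(tf.range(batch_size), tf.range(num_steps), indexing="ij")

# Stack into (batch, time, 3) triples: [sample, time step, feature index]
indices = tf.stack([b_idx, t_idx, seq], axis=-1)

# Look up logits[b, t, seq[b, t]] for every triple -> shape (batch, time)
picked = tf.gather_nd(logits, indices)

# Equivalent: treat the first two axes as batch dimensions and gather on the last
picked_alt = tf.gather(logits, seq, batch_dims=2)

print(picked.numpy())
```

The tf.gather_nd version makes the coordinate bookkeeping explicit; the batch_dims version is shorter but relies on a TensorFlow release whose tf.gather accepts batch_dims.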
1 Answer
I think I found a way:
tf.losses.sparse_softmax_cross_entropy(labels=None, logits=None)
where labels is the index sequence and logits is the output of the model.
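A caveat worth checking with this approach: a sparse softmax cross-entropy does not return the raw values at the indices. Per element it computes -log(softmax(logits)[b, t, seq[b, t]]), and tf.losses.sparse_softmax_cross_entropy additionally reduces the result to a scalar by default. The minimal sketch below assumes TensorFlow 2.x and uses the unreduced tf.nn.sparse_softmax_cross_entropy_with_logits in place of the TF1 tf.losses function, to show what the loss actually corresponds to:

```python
import tensorflow as tf

seq = tf.constant([[2, 1, 4], [3, 4, 2]])  # (batch, time)
logits = tf.constant([                      # (batch, time, features)
    [[0.2, 0.2, 0.1, 0.1, 0.2],
     [0.3, 0.2, 0.1, 0.1, 0.1],
     [0.1, 0.2, 0.1, 0.1, 0.3]],
    [[0.2, 0.2, 0.1, 0.1, 0.2],
     [0.2, 0.2, 0.1, 0.1, 0.2],
     [0.2, 0.2, 0.1, 0.1, 0.2]],
])

# Unreduced per-element loss, shape (batch, time):
# -log(softmax(logits)[b, t, seq[b, t]])
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=seq, logits=logits)

# Undoing the -log recovers the softmax probability at each index
prob_from_loss = tf.exp(-loss)

# The same probabilities computed directly: softmax over features, then gather
prob_direct = tf.gather(tf.nn.softmax(logits, axis=-1), seq, batch_dims=2)

# The raw gathered values the question asks for, for comparison
raw = tf.gather(logits, seq, batch_dims=2)

print(prob_from_loss.numpy())
print(raw.numpy())
```

So the loss is a function of the indexed entries after a softmax over the features; if the raw values themselves are needed, a direct gather is the more literal fit.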