Dice/Jaccard Coefficient Optimization in Tensorflow




I am trying to optimize my network with either Dice's or Jaccard's coefficient. My problem is image segmentation, so my output is a tensor of shape (1, 256, 256, 11). To calculate the intersection of my output and the ground-truth image, I take


tf.argmax(output, axis=3)



which returns an integer tensor (int64 by default). TensorFlow optimizers (specifically AdamOptimizer) don't seem to accept that, so I then convert it to a float with




tf.cast(tf.argmax(output, axis=3), tf.float32)



However, there doesn't seem to be a gradient defined for tf.cast (or for tf.argmax, for that matter). Has anybody been able to successfully implement Dice or Jaccard optimization in TensorFlow?
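For reference, here is roughly what the question computes, sketched in plain Python rather than TensorFlow so it runs anywhere: an argmax over the class axis, then per-class Dice 2|A∩B| / (|A| + |B|) and Jaccard |A∩B| / |A∪B| on the resulting hard masks. The helper names and the toy 4-pixel, 3-class example are hypothetical; a real network would produce a (1, 256, 256, 11) tensor instead.

```python
def argmax(values):
    """Index of the largest entry (first index wins on ties)."""
    best = 0
    for i, v in enumerate(values):
        if v > values[best]:
            best = i
    return best


def hard_dice_jaccard(logits, labels, cls):
    """Dice and Jaccard for one class, using hard (argmax) predictions.

    logits: list of per-pixel class-score lists; labels: list of int class ids.
    """
    pred = [argmax(p) == cls for p in logits]  # predicted mask A
    true = [t == cls for t in labels]          # ground-truth mask B
    inter = sum(p and t for p, t in zip(pred, true))
    union = sum(p or t for p, t in zip(pred, true))
    size_sum = sum(pred) + sum(true)
    dice = 2 * inter / size_sum if size_sum else 1.0
    jaccard = inter / union if union else 1.0
    return dice, jaccard


# Four pixels, three classes (hypothetical values).
logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.4, 0.3, 0.9], [1.1, 1.0, 0.2]]
labels = [0, 1, 2, 1]
print(hard_dice_jaccard(logits, labels, cls=1))  # Dice 2/3, Jaccard 1/2
```

Every step after the argmax works on integer or boolean masks, which is why casting the argmax result to float does not give the optimizer anything to differentiate.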




2 Answers



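The operation tf.argmax() is not differentiable; that's why no gradient is implemented for it. You cannot optimize the Jaccard index directly, because computed on argmax labels it is not differentiable: small changes to the network outputs leave it unchanged, and it only jumps when a pixel's predicted class flips.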
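A quick numerical check of this, in plain Python with hypothetical helper names and toy values: nudging one logit by a small epsilon leaves the argmax, and therefore the Jaccard index, unchanged, so a central finite-difference estimate of the gradient is exactly zero. The value only changes when the logit crosses the point where the argmax flips, and then it jumps.

```python
def argmax(values):
    best = 0
    for i, v in enumerate(values):
        if v > values[best]:
            best = i
    return best


def hard_jaccard(logits, labels, cls):
    """Jaccard index |A∩B| / |A∪B| for one class, using argmax predictions."""
    pred = [argmax(p) == cls for p in logits]
    true = [t == cls for t in labels]
    inter = sum(p and t for p, t in zip(pred, true))
    union = sum(p or t for p, t in zip(pred, true))
    return inter / union if union else 1.0


def finite_diff(f, logits, pixel, k, eps=1e-4):
    """Central-difference estimate of df / d logits[pixel][k]."""
    plus = [row[:] for row in logits]
    minus = [row[:] for row in logits]
    plus[pixel][k] += eps
    minus[pixel][k] -= eps
    return (f(plus) - f(minus)) / (2 * eps)


logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.4, 0.3, 0.9], [1.1, 1.0, 0.2]]
labels = [0, 1, 2, 1]


def jaccard_cls1(lg):
    return hard_jaccard(lg, labels, cls=1)


# Pixel 3 is labelled class 1 but predicted class 0 (1.1 > 1.0).
# A tiny increase of its class-1 logit does not flip the argmax.
print(finite_diff(jaccard_cls1, logits, pixel=3, k=1))  # 0.0

# Pushing the class-1 logit past 1.1 flips the argmax: the index jumps.
bumped = [row[:] for row in logits]
bumped[3][1] = 1.2
print(jaccard_cls1(logits), jaccard_cls1(bumped))  # 0.5 1.0
```

A function that is flat almost everywhere and discontinuous elsewhere gives gradient descent no direction to move in, which is the practical meaning of "not differentiable" here.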





The same happens with accuracy: when you train a classifier, you optimize a differentiable loss even if you only care about accuracy, because accuracy itself is not differentiable.



So there is no direct solution to your problem: you will have to use a loss function that you can differentiate, and optimize that loss instead of the Jaccard index.
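As a contrast, here is a sketch of the usual differentiable surrogate, per-pixel softmax cross-entropy, in plain Python (helper names and values are hypothetical). Its analytic gradient with respect to the logits is softmax(z) - onehot(y), and a finite-difference check agrees with it, so unlike an argmax-based score it always gives the optimizer a non-zero signal for a pixel that is not yet perfectly classified.

```python
import math


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def cross_entropy(z, y):
    """-log softmax(z)[y] for one pixel with logits z and integer label y."""
    return -math.log(softmax(z)[y])


def analytic_grad(z, y):
    """d cross_entropy / d z = softmax(z) - onehot(y)."""
    p = softmax(z)
    return [p_k - (1.0 if k == y else 0.0) for k, p_k in enumerate(p)]


def numeric_grad(z, y, eps=1e-6):
    """Central finite differences, one logit at a time."""
    grad = []
    for k in range(len(z)):
        plus, minus = z[:], z[:]
        plus[k] += eps
        minus[k] -= eps
        grad.append((cross_entropy(plus, y) - cross_entropy(minus, y)) / (2 * eps))
    return grad


# A pixel labelled class 1 whose largest logit belongs to class 0.
z, y = [1.1, 1.0, 0.2], 1
print(analytic_grad(z, y))
print(numeric_grad(z, y))
```

The negative entry for class 1 tells gradient descent to raise that logit, which is exactly the direction the hard Jaccard index cannot provide.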



I had the same problem today and realized that it is not possible to calculate the (hard) Dice loss without the tf.argmax() function, which has no gradient.





One possible solution in similar situations is to use the generalized Dice coefficient, i.e. to use the softmax predictions together with a one-hot encoded ground-truth image as inputs.
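A plain-Python sketch of this idea: replace the argmax mask with softmax probabilities and the label mask with a one-hot encoding, so the intersection becomes a sum of products and the whole expression stays differentiable. The per-class weights 1 / (sum of r)^2 follow the generalized Dice loss of Sudre et al. (2017); that weighting, giving classes absent from the ground truth a weight of 0, the eps smoothing term, and all helper names are my assumptions rather than details from this answer. In TensorFlow the same sums would be written with tf.nn.softmax, tf.one_hot (depth 11 for the question) and tf.reduce_sum over the batch and spatial axes.

```python
import math


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def one_hot(y, num_classes):
    return [1.0 if k == y else 0.0 for k in range(num_classes)]


def generalized_dice_loss(logits, labels, eps=1e-7):
    """1 - generalized Dice, computed from soft predictions (no argmax).

    logits: list of per-pixel class-score lists; labels: list of int class ids.
    """
    num_classes = len(logits[0])
    probs = [softmax(z) for z in logits]               # p[n][c]
    truth = [one_hot(y, num_classes) for y in labels]  # r[n][c]
    numer = 0.0
    denom = 0.0
    for c in range(num_classes):
        r_sum = sum(r[c] for r in truth)
        # Assumption: absent classes get weight 0 instead of an infinite weight.
        w = 1.0 / (r_sum * r_sum) if r_sum > 0 else 0.0
        numer += w * sum(r[c] * p[c] for r, p in zip(truth, probs))  # soft intersection
        denom += w * sum(r[c] + p[c] for r, p in zip(truth, probs))  # soft |A| + |B|
    return 1.0 - 2.0 * numer / (denom + eps)


logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.4, 0.3, 0.9], [1.1, 1.0, 0.2]]
labels = [0, 1, 2, 1]
loss = generalized_dice_loss(logits, labels)

# Unlike the argmax version, a tiny change to one logit changes the loss.
bumped = [row[:] for row in logits]
bumped[3][1] += 1e-3
print(loss, generalized_dice_loss(bumped, labels))  # second value is slightly smaller
```

Raising the correct-class logit of a misclassified pixel lowers the loss even though its argmax has not flipped yet, which is the gradient signal the hard coefficient lacks.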





