且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

tensorflow 相当于torch.gather

更新时间:2022-04-01 05:11:18

对于 2D 情况,有一个方法可以做到:

For 2D case,there is a method to do it:

# a.shape (16L, 10L)
# idx.shape (16L,1)
idx = tf.stack([tf.range(tf.shape(idx)[0]),idx[:,0]],axis=-1)
b = tf.gather_nd(a,idx)

然而,对于 ND 情况,这种方法可能非常复杂

However,For ND case,this method maybe very complex