-
-
Notifications
You must be signed in to change notification settings - Fork 327
Open
Description
In the implementation of MatMul, the forward pass uses dot, which does not match the behavior of matrix multiplication in Torch; in particular, it is not suitable for inputs with more than two dimensions. It should be replaced with np.matmul() or cupy.matmul(). Moreover, backpropagation has to follow the same semantics as matmul(), which multiplies over the last two axes and broadcasts over the leading ones: only the last two axes should be transposed, and gradients should be summed over any broadcast leading axes. Below is the approach I took in my own reimplementation; you can find my code at: this link.
dezero code:
class MatMul(Function):
    # Matrix-product function: forward is `x.dot(W)`; backward builds the two
    # input gradients from `gy` with `matmul` and `.T`.

    def forward(self, x, W):
        # NOTE(review): assuming x and W are NumPy/CuPy arrays, `dot` on
        # operands with more than two dimensions does not batch over the
        # leading axes the way `matmul` does -- confirm the intended ranks.
        y = x.dot(W)
        return y

    def backward(self, gy):
        # The forward inputs are read back from `self.inputs`.
        x, W = self.inputs
        # NOTE(review): if `.T` reverses every axis (as ndarray.T does), these
        # two expressions are only valid gradients for 2-D operands -- verify.
        gx = matmul(gy, W.T)
        gW = matmul(x.T, gy)
        return gx, gW
my code:
# Matrix Multiply
class MatMul(Function):
    """Matrix product over the last two axes, using ``xp.matmul``.

    The forward pass delegates to ``matmul`` of the array module that owns
    ``a``. The backward pass swaps only the last two axes of each input and
    then sums each gradient back to the shape of the input it belongs to, so
    operands whose leading axes were broadcast (missing, or of size 1) get
    gradients of the correct shape.

    Both inputs must have at least two axes: the axis swap indexes position
    ``-2`` and raises ``IndexError`` for a 1-D input.
    """

    def _forward(self, a, b):  # (N,A,B)@(B,C)=(N,A,C)
        """Return ``matmul(a, b)`` computed with the array module of ``a``."""
        xp = get_array_module(a)  # CUDA compatibility
        return xp.matmul(a, b)

    def _backward(self, grad):
        """Return ``(grad_a, grad_b)`` for the upstream gradient ``grad``."""
        a, b = self.inputs[0], self.inputs[1]

        # (N,A,C)@(C,B)=(N,A,B)
        grad_a = MatMul()(grad, self._swap_last_two_axes(b))
        grad_a = self._sum_to_shape(grad_a, a.shape)

        # (N,B,A)@(N,A,C)=(N,B,C) -> Sum() -> (B,C)
        grad_b = MatMul()(self._swap_last_two_axes(a), grad)
        grad_b = self._sum_to_shape(grad_b, b.shape)

        return grad_a, grad_b

    @staticmethod
    def _swap_last_two_axes(v):
        """Return ``v`` transposed with only its last two axes exchanged."""
        order = list(range(len(v.shape)))
        order[-1], order[-2] = order[-2], order[-1]
        return v.transpose(order)

    @staticmethod
    def _sum_to_shape(g, shape):
        """Sum ``g`` over broadcast axes so that its shape equals ``shape``."""
        # Leading axes that the input did not have at all are summed away.
        extra = len(g.shape) - len(shape)
        if extra > 0:
            g = Sum(axes=tuple(range(extra)), keepdims=False)(g)

        # Axes the input had with size 1 but that matmul stretched, e.g.
        # (1,A,B)@(N,B,C): sum them too, keeping the axis so the rank stays.
        stretched = tuple(
            i
            for i, (have, want) in enumerate(zip(g.shape, shape))
            if want == 1 and have != 1
        )
        if stretched:
            g = Sum(axes=stretched, keepdims=True)(g)
        return g
Metadata
Metadata
Assignees
Labels
No labels