Regarding the Differences in MatMul Matrix Multiplication Behavior #44

@owenliang

In the MatMul implementation, the forward pass uses dot, which does not match the behavior of matrix multiplication in Torch and is not suitable for inputs with more than two dimensions. It should be replaced with np.matmul() or cupy.matmul(). The backward pass also needs to be adapted, because matmul() multiplies over the last two axes and broadcasts the leading ones. Below is the approach I took in my own reimplementation; you can find my code at: this link.

DeZero code:

class MatMul(Function):
    def forward(self, x, W):
        y = x.dot(W)
        return y

    def backward(self, gy):
        x, W = self.inputs
        gx = matmul(gy, W.T)
        gW = matmul(x.T, gy)
        return gx, gW
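
To make the mismatch in the forward pass concrete, here is a small sketch using plain NumPy (the shapes are chosen only for illustration). When both operands have more than two dimensions, dot pairs every batch of the first operand with every batch of the second, while matmul multiplies batch by batch:

```python
import numpy as np

a = np.ones((2, 3, 4))  # a batch of two 3x4 matrices
b = np.ones((2, 4, 5))  # a batch of two 4x5 matrices

# dot contracts the last axis of a with the second-to-last axis of b
# and keeps all remaining axes of both operands.
dot_out = np.dot(a, b)

# matmul treats the leading axes as batch axes and multiplies
# the last two axes pairwise.
matmul_out = np.matmul(a, b)

print(dot_out.shape)     # (2, 3, 2, 5)
print(matmul_out.shape)  # (2, 3, 5)
```

Note that for an (N,A,B) input multiplied by a 2-D (B,C) weight, dot and matmul give the same result; the difference shows up once the second operand also has batch axes.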

my code:

# Matrix Multiply
class MatMul(Function):
    def _forward(self, a, b):  # (N,A,B)@(B,C)=(N,A,C)
        xp = get_array_module(a)  # CUDA compatibility
        return xp.matmul(a, b)

    def _backward(self, grad):
        # grad_a = grad @ b^T, where b^T swaps only the last two axes of b
        transpose_idx = list(range(0, len(self.inputs[1].shape)))
        transpose_idx[-1], transpose_idx[-2] = transpose_idx[-2], transpose_idx[-1]
        grad_a = MatMul()(grad, self.inputs[1].transpose(transpose_idx))  # (N,A,C)@(C,B)=(N,A,B)
        # If a has fewer dimensions than grad_a, sum over the extra leading axes
        if len(self.inputs[0].shape) != len(grad_a.shape):
            grad_a = Sum(axes=tuple(range(0, len(grad_a.shape) - len(self.inputs[0].shape))), keepdims=False)(grad_a)

        # grad_b = a^T @ grad, where a^T swaps only the last two axes of a
        transpose_idx = list(range(0, len(self.inputs[0].shape)))
        transpose_idx[-1], transpose_idx[-2] = transpose_idx[-2], transpose_idx[-1]
        grad_b = MatMul()(self.inputs[0].transpose(transpose_idx), grad)  # (N,B,A)@(N,A,C)=(N,B,C) -> Sum() -> (B,C)
        # If b has fewer dimensions than grad_b, sum over the extra leading axes
        if len(self.inputs[1].shape) != len(grad_b.shape):
            grad_b = Sum(axes=tuple(range(0, len(grad_b.shape) - len(self.inputs[1].shape))), keepdims=False)(grad_b)

        return grad_a, grad_b
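
The same gradient rule can be checked outside the framework. The following is a minimal NumPy-only sketch, not the code from either repository: `matmul_backward` is a hypothetical helper name, and it assumes both inputs have at least two dimensions and that broadcasting only adds leading axes (size-1 batch axes are not handled).

```python
import numpy as np

def matmul_backward(a, b, grad):
    """Hypothetical helper: gradients of np.matmul(a, b) given the upstream grad."""
    # Transpose only the last two axes, matching matmul's convention.
    grad_a = np.matmul(grad, np.swapaxes(b, -1, -2))
    grad_b = np.matmul(np.swapaxes(a, -1, -2), grad)
    # Sum out leading axes that were introduced by broadcasting.
    if grad_a.ndim > a.ndim:
        grad_a = grad_a.sum(axis=tuple(range(grad_a.ndim - a.ndim)))
    if grad_b.ndim > b.ndim:
        grad_b = grad_b.sum(axis=tuple(range(grad_b.ndim - b.ndim)))
    return grad_a, grad_b

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))  # (N,A,B)
b = rng.standard_normal((4, 5))     # (B,C)
grad = np.ones((2, 3, 5))           # upstream gradient of y.sum()

grad_a, grad_b = matmul_backward(a, b, grad)

# Finite-difference check on a single entry of b.
eps = 1e-6
b_pert = b.copy()
b_pert[1, 2] += eps
numeric = (np.matmul(a, b_pert).sum() - np.matmul(a, b).sum()) / eps

print(grad_a.shape, grad_b.shape)
print(np.isclose(numeric, grad_b[1, 2], atol=1e-4))
```

The gradient shapes match the input shapes, and the analytic entry of `grad_b` agrees with the finite-difference estimate, which is the behavior the backward pass above is meant to reproduce.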
