Skip to content

Commit ce73a2c

Browse files
Ilya-Fradlinrwightman
authored andcommitted
Refactor layer_scale to use init_values for gamma
1 parent 47c18f4 commit ce73a2c

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

timm/layers/layer_scale.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ def __init__(
1414
dtype=None,
1515
) -> None:
1616
super().__init__()
17+
self.init_values = init_values
1718
self.inplace = inplace
18-
self.gamma = nn.Parameter(init_values * torch.empty(dim, device=device, dtype=dtype))
19+
self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
1920

2021
self.reset_parameters()
2122

2223
def reset_parameters(self):
23-
torch.nn.init.ones_(self.gamma)
24+
torch.nn.init.constant_(self.gamma, self.init_values)
2425

2526
def forward(self, x: torch.Tensor) -> torch.Tensor:
2627
return x.mul_(self.gamma) if self.inplace else x * self.gamma
@@ -38,13 +39,14 @@ def __init__(
3839
dtype=None,
3940
):
4041
super().__init__()
42+
self.init_values = init_values
4143
self.inplace = inplace
42-
self.gamma = nn.Parameter(init_values * torch.empty(dim, device=device, dtype=dtype))
44+
self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
4345

4446
self.reset_parameters()
4547

4648
def reset_parameters(self):
47-
torch.nn.init.ones_(self.gamma)
49+
torch.nn.init.constant_(self.gamma, self.init_values)
4850

4951
def forward(self, x):
5052
gamma = self.gamma.view(1, -1, 1, 1)

0 commit comments

Comments
 (0)