File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -14,13 +14,14 @@ def __init__(
1414 dtype = None ,
1515 ) -> None :
1616 super ().__init__ ()
17+ self .init_values = init_values
1718 self .inplace = inplace
18- self .gamma = nn .Parameter (init_values * torch .empty (dim , device = device , dtype = dtype ))
19+ self .gamma = nn .Parameter (torch .empty (dim , device = device , dtype = dtype ))
1920
2021 self .reset_parameters ()
2122
2223 def reset_parameters (self ):
23- torch .nn .init .ones_ (self .gamma )
24+ torch .nn .init .constant_ (self .gamma , self . init_values )
2425
2526 def forward (self , x : torch .Tensor ) -> torch .Tensor :
2627 return x .mul_ (self .gamma ) if self .inplace else x * self .gamma
@@ -38,13 +39,14 @@ def __init__(
3839 dtype = None ,
3940 ):
4041 super ().__init__ ()
42+ self .init_values = init_values
4143 self .inplace = inplace
42- self .gamma = nn .Parameter (init_values * torch .empty (dim , device = device , dtype = dtype ))
44+ self .gamma = nn .Parameter (torch .empty (dim , device = device , dtype = dtype ))
4345
4446 self .reset_parameters ()
4547
4648 def reset_parameters (self ):
47- torch .nn .init .ones_ (self .gamma )
49+ torch .nn .init .constant_ (self .gamma , self . init_values )
4850
4951 def forward (self , x ):
5052 gamma = self .gamma .view (1 , - 1 , 1 , 1 )
You can’t perform that action at this time.
0 commit comments