@@ -227,19 +227,22 @@ def __init__(
227227 hidden_act : str ,
228228 quant_config : QuantizationConfig | None = None ,
229229 bias : bool = False ,
230+ prefix : str = "" ,
230231 ) -> None :
231232 super ().__init__ ()
232233 self .gate_up_proj = MergedColumnParallelLinear (
233234 input_size = hidden_size ,
234235 output_sizes = [intermediate_size ] * 2 ,
235236 bias = bias ,
236237 quant_config = quant_config ,
238+ prefix = f"{ prefix } .gate_up_proj" ,
237239 )
238240 self .down_proj = RowParallelLinear (
239241 input_size = intermediate_size ,
240242 output_size = hidden_size ,
241243 bias = bias ,
242244 quant_config = quant_config ,
245+ prefix = f"{ prefix } .down_proj" ,
243246 )
244247 if hidden_act != "silu" :
245248 raise ValueError (
@@ -299,12 +302,14 @@ def __init__(
299302 total_num_kv_heads = self .total_num_kv_heads ,
300303 bias = bias ,
301304 quant_config = quant_config ,
305+ prefix = f"{ prefix } .qkv_proj" ,
302306 )
303307 self .o_proj = RowParallelLinear (
304308 input_size = self .total_num_heads * self .head_dim ,
305309 output_size = hidden_size ,
306310 bias = bias ,
307311 quant_config = quant_config ,
312+ prefix = f"{ prefix } .o_proj" ,
308313 )
309314 self .q_norm = ChameleonLayerNorm ((self .num_heads , self .head_dim ))
310315 self .k_norm = ChameleonLayerNorm ((self .num_kv_heads , self .head_dim ))
@@ -393,6 +398,7 @@ def __init__(
393398 hidden_act = config .hidden_act ,
394399 quant_config = quant_config ,
395400 bias = getattr (config , "mlp_bias" , False ),
401+ prefix = f"{ prefix } .mlp" ,
396402 )
397403 self .input_layernorm = RMSNorm (config .hidden_size , eps = config .rms_norm_eps )
398404 self .post_attention_layernorm = RMSNorm (
@@ -462,6 +468,7 @@ def __init__(
462468 hidden_act = config .hidden_act ,
463469 quant_config = quant_config ,
464470 bias = getattr (config , "mlp_bias" , False ),
471+ prefix = f"{ prefix } .mlp" ,
465472 )
466473 self .input_layernorm = RMSNorm (config .hidden_size , eps = config .rms_norm_eps )
467474 self .post_attention_layernorm = RMSNorm (
0 commit comments