
Commit 36651e8

MengqingCao authored and xuebwang-amd committed
[Misc][Model][Refactor] Pass the prefix into Linear layers (vllm-project#28259)
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 4424e75 commit 36651e8

26 files changed: +190 -25 lines
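Every touched file applies the same mechanical pattern: modules that build parallel Linear layers gain a prefix argument and pass each layer its fully qualified dotted name (for example prefix=f"{prefix}.gate_up_proj"), and callers thread their own prefix down when constructing those modules. The sketch below is a minimal illustration of the pattern; it reuses the Linear classes that appear in the diffs, but ExampleMLP, its sizes, and the example prefix are illustrative stand-ins rather than code from this commit. The fully qualified name is what lets a quantization config make per-layer decisions.

import torch.nn as nn

from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    RowParallelLinear,
)


class ExampleMLP(nn.Module):
    """Illustrative module showing the prefix-forwarding pattern (not from this commit)."""

    def __init__(self, hidden_size, intermediate_size, quant_config=None, prefix=""):
        super().__init__()
        # With prefix == "model.layers.0.mlp", this layer identifies itself to the
        # quantization config as "model.layers.0.mlp.gate_up_proj".
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )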

vllm/model_executor/models/arctic.py

Lines changed: 8 additions & 1 deletion
@@ -75,14 +75,19 @@ def __init__(
         )
 
         self.w13 = MergedColumnParallelLinear(
-            self.hidden_size, [self.ffn_dim] * 2, bias=False, quant_config=quant_config
+            self.hidden_size,
+            [self.ffn_dim] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.w13",
         )
         self.w2 = RowParallelLinear(
             self.ffn_dim,
             self.hidden_size,
             bias=False,
             reduce_results=reduce_results,
             quant_config=quant_config,
+            prefix=f"{prefix}.w2",
         )
         if config.hidden_act != "silu":
             raise ValueError(
@@ -297,13 +302,15 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             reduce_results=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.rotary_emb = get_rope(

vllm/model_executor/models/baichuan.py

Lines changed: 14 additions & 2 deletions
@@ -98,13 +98,22 @@ def __init__(
         intermediate_size: int,
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -152,12 +161,14 @@ def __init__(
             self.total_num_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.W_pack",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         # Create the alibi slopes and slice them.
         if self.position_embedding == "ALIBI":
@@ -235,6 +246,7 @@ def __init__(
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(

vllm/model_executor/models/bamba.py

Lines changed: 11 additions & 2 deletions
@@ -60,19 +60,22 @@ def __init__(
         config: BambaConfig,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=config.hidden_size,
             output_sizes=[config.intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if config.hidden_act != "silu":
             raise ValueError(
@@ -118,7 +121,9 @@ def __init__(
             prefix=f"{prefix}.mixer",
         )
 
-        self.feed_forward = BambaMLP(config, quant_config=quant_config)
+        self.feed_forward = BambaMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -202,12 +207,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.attn = Attention(
@@ -219,7 +226,9 @@ def __init__(
             prefix=f"{prefix}.attn",
         )
 
-        self.feed_forward = BambaMLP(config, quant_config=quant_config)
+        self.feed_forward = BambaMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
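Why the layers need their names: vLLM quantization configs can make per-layer decisions (for example, leaving specific layers unquantized), and those decisions are keyed on the layer's fully qualified name. The toy config below is a self-contained, hypothetical sketch of that idea, not vLLM's actual quantization API.

class ToyQuantConfig:
    """Hypothetical per-layer quantization config (illustration only)."""

    def __init__(self, ignored_layers: list[str]):
        self.ignored_layers = ignored_layers

    def get_quant_method(self, prefix: str) -> str:
        # Layers listed by name stay in full precision; everything else is quantized.
        if prefix in self.ignored_layers:
            return "unquantized"
        return "fp8"


cfg = ToyQuantConfig(ignored_layers=["model.layers.0.feed_forward.down_proj"])
print(cfg.get_quant_method("model.layers.0.feed_forward.down_proj"))    # unquantized
print(cfg.get_quant_method("model.layers.0.feed_forward.gate_up_proj"))  # fp8

Without a prefix, a Linear layer constructed deep inside an MLP has no way to tell which entry in such a list refers to it.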

vllm/model_executor/models/bloom.py

Lines changed: 6 additions & 1 deletion
@@ -108,12 +108,14 @@ def __init__(
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )
 
         # Create the alibi slopes and slice them.
@@ -152,19 +154,22 @@ def __init__(
         self,
         config: BloomConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
         self.dense_h_to_4h = ColumnParallelLinear(
             hidden_size,
             4 * hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.gelu_impl = get_act_fn("gelu")
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -192,7 +197,7 @@ def __init__(
         self.post_attention_layernorm = nn.LayerNorm(
             hidden_size, eps=config.layer_norm_epsilon
         )
-        self.mlp = BloomMLP(config, quant_config)
+        self.mlp = BloomMLP(config, quant_config, prefix=f"{prefix}.mlp")
         self.apply_residual_connection_post_layernorm = (
             config.apply_residual_connection_post_layernorm
         )

vllm/model_executor/models/chameleon.py

Lines changed: 7 additions & 0 deletions
@@ -227,19 +227,22 @@ def __init__(
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=hidden_size,
             output_sizes=[intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=intermediate_size,
             output_size=hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -299,12 +302,14 @@ def __init__(
             total_num_kv_heads=self.total_num_kv_heads,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             input_size=self.total_num_heads * self.head_dim,
             output_size=hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
         self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
@@ -393,6 +398,7 @@ def __init__(
             hidden_act=config.hidden_act,
             quant_config=quant_config,
             bias=getattr(config, "mlp_bias", False),
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -462,6 +468,7 @@ def __init__(
             hidden_act=config.hidden_act,
             quant_config=quant_config,
             bias=getattr(config, "mlp_bias", False),
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(

vllm/model_executor/models/dbrx.py

Lines changed: 2 additions & 0 deletions
@@ -209,12 +209,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.Wqkv",
         )
         self.out_proj = RowParallelLinear(
             self.d_model,
             self.d_model,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,

vllm/model_executor/models/deepseek.py

Lines changed: 8 additions & 1 deletion
@@ -85,14 +85,19 @@ def __init__(
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -242,13 +247,15 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
 
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.rotary_emb = get_rope(

vllm/model_executor/models/dots1.py

Lines changed: 2 additions & 0 deletions
@@ -240,13 +240,15 @@ def __init__(
             self.total_num_kv_heads,
             bias=attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
 
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.rotary_emb = get_rope(

vllm/model_executor/models/falcon.py

Lines changed: 6 additions & 1 deletion
@@ -137,6 +137,7 @@ def __init__(
             bias=config.bias,
             skip_bias_add=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -153,6 +154,7 @@ def __init__(
             skip_bias_add=True,
             quant_config=quant_config,
             reduce_results=self.reduce_row_parallel_results,
+            prefix=f"{prefix}.dense",
         )
 
         self.use_rotary = config.rotary
@@ -227,6 +229,7 @@ def __init__(
         self,
         config: FalconConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -237,6 +240,7 @@ def __init__(
             bias=config.bias,
             skip_bias_add=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.act = get_act_fn("gelu")
         self.reduce_row_parallel_results = not (
@@ -249,6 +253,7 @@ def __init__(
             skip_bias_add=True,
             reduce_results=self.reduce_row_parallel_results,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -275,7 +280,7 @@ def __init__(
         self.self_attention = FalconAttention(
             config, cache_config, quant_config, prefix=f"{prefix}.self_attention"
         )
-        self.mlp = FalconMLP(config, quant_config)
+        self.mlp = FalconMLP(config, quant_config, prefix=f"{prefix}.mlp")
         self.config = config
 
         if not hasattr(config, "num_ln_in_parallel_attn"):

vllm/model_executor/models/falcon_h1.py

Lines changed: 4 additions & 1 deletion
@@ -59,19 +59,22 @@ def __init__(
         config: FalconH1Config,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=config.hidden_size,
             output_sizes=[config.intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.tp_size = get_tensor_model_parallel_world_size()
         self.intermediate_size = config.intermediate_size
@@ -365,7 +368,7 @@ def __init__(
         self.attention_in_multiplier = config.attention_in_multiplier
         self.attn_out_multiplier = config.attention_out_multiplier
 
-        self.feed_forward = FalconH1MLP(config)
+        self.feed_forward = FalconH1MLP(config, prefix=f"{prefix}.feed_forward")
 
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
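The dotted prefixes composed this way mirror the module paths PyTorch itself assigns, which are also the names under which parameters appear in checkpoints. A quick plain-PyTorch illustration of those paths (not vLLM code, names are illustrative):

import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.feed_forward = nn.Linear(8, 8)


model = nn.ModuleDict({"layers": nn.ModuleList([Block()])})
print([name for name, _ in model.named_modules()])
# ['', 'layers', 'layers.0', 'layers.0.feed_forward']

Matching these paths is what lets name-keyed, per-layer settings land on the right Linear layer at construction time.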
