
Commit 34960c9

Jialin authored and xuebwang-amd committed
[Perf] Introduce FlattenLogprobs to store logprobs results to reduce GC overhead (vllm-project#28171)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent f9694a0 commit 34960c9
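
The FlattenLogprobs container itself lives in vllm/logprobs.py, which is among the six changed files but omitted from this excerpt. Below is a rough, hypothetical sketch of the layout the new tests imply: field and method names are taken from tests/test_logprobs.py further down; everything else is an assumption, not the committed implementation.

# Hypothetical sketch of the flattened layout implied by the tests below.
# Field and method names mirror tests/test_logprobs.py; the actual
# vllm/logprobs.py implementation (not shown in this excerpt) may differ.
from dataclasses import dataclass, field


@dataclass
class Logprob:
    logprob: float
    rank: int
    decoded_token: str


# One position maps token id -> Logprob, as in the original representation.
LogprobsOnePosition = dict[int, Logprob]


@dataclass
class FlattenLogprobs:
    # Position i owns the half-open slice [start_indices[i], end_indices[i])
    # of the parallel arrays below, so the container holds a handful of flat
    # primitive lists instead of one dict of Logprob objects per position.
    start_indices: list[int] = field(default_factory=list)
    end_indices: list[int] = field(default_factory=list)
    token_ids: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)
    ranks: list[int] = field(default_factory=list)
    decoded_tokens: list[str] = field(default_factory=list)

    def append(self, position_logprobs: LogprobsOnePosition) -> None:
        # Record where this position starts, flatten its entries, then
        # record where it ends.
        self.start_indices.append(len(self.token_ids))
        for token_id, lp in position_logprobs.items():
            self.token_ids.append(token_id)
            self.logprobs.append(lp.logprob)
            self.ranks.append(lp.rank)
            self.decoded_tokens.append(lp.decoded_token)
        self.end_indices.append(len(self.token_ids))

    def __len__(self) -> int:
        return len(self.start_indices)

    def __getitem__(self, position: int) -> LogprobsOnePosition:
        # Rebuild the per-position dict view on demand.
        start, end = self.start_indices[position], self.end_indices[position]
        return {
            self.token_ids[i]: Logprob(
                self.logprobs[i], self.ranks[i], self.decoded_tokens[i]
            )
            for i in range(start, end)
        }

The point of the flattening is visible in the field list: per-token data sits in a few long flat lists rather than one dict of Logprob objects per position, so the garbage collector has far fewer container objects to traverse.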

File tree: 6 files changed (+534, -125 lines)


tests/samplers/test_logprobs.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm import SamplingParams
from vllm.logprobs import FlattenLogprobs

MODELS = ["distilbert/distilgpt2"]
MAX_TOKENS = 5
NUM_TOP_LOGPROBS = 5
NUM_PROMPT_LOGPROBS = 7
MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("greedy", [True, False])
@pytest.mark.parametrize("flatten_logprobs", [True, False])
def test_ranks(
    vllm_runner,
    model,
    dtype,
    greedy,
    flatten_logprobs,
    example_prompts,
    monkeypatch: pytest.MonkeyPatch,
):
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
    with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
        tokenizer = vllm_model.llm.get_tokenizer()
        example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
        sampling_params = SamplingParams(
            temperature=0.0 if greedy else 1.0,
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            logprobs=NUM_TOP_LOGPROBS,
            prompt_logprobs=NUM_PROMPT_LOGPROBS,
        )
        results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)

    assert len(results) == len(example_prompt_tokens)
    for i, (result, prompt_tokens) in enumerate(zip(results, example_prompt_tokens)):
        decode_tokens, _, decode_logprobs, prompt_logprobs = result

        # Ensure the return type of logprobs is accurate
        assert isinstance(
            prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
        )
        assert isinstance(
            decode_logprobs, FlattenLogprobs if flatten_logprobs else list
        )

        ########################
        # Check prompt logprobs
        ########################
        assert len(prompt_tokens) == len(prompt_logprobs)
        # No logprob for the first prompt token
        assert not prompt_logprobs[0]
        for position, (token, logprobs) in enumerate(
            zip(prompt_tokens[1:], prompt_logprobs[1:]), start=1
        ):
            # Ensure the logprob of the prompt token is always returned
            logprob = logprobs.get(token)
            assert logprob is not None
            assert logprob.rank >= 1
            # The number of returned logprobs should be either
            # NUM_PROMPT_LOGPROBS or NUM_PROMPT_LOGPROBS + 1
            assert NUM_PROMPT_LOGPROBS <= len(logprobs) <= NUM_PROMPT_LOGPROBS + 1
            # Ensure the top NUM_PROMPT_LOGPROBS ranks are always extracted
            assert set(range(1, NUM_PROMPT_LOGPROBS + 1)).issubset(
                {logprob.rank for logprob in logprobs.values()}
            )

        ########################
        # Check sample logprobs
        ########################
        assert len(decode_tokens) == len(decode_logprobs)
        for position, (token, logprobs) in enumerate(
            zip(decode_tokens, decode_logprobs)
        ):
            # Ensure the logprob of the chosen token is always returned
            logprob = logprobs.get(token)
            assert logprob is not None
            if greedy:
                # For greedy sampling, the chosen token should always be rank 1
                assert logprob.rank == 1
            else:
                assert logprob.rank >= 1
            # The number of returned logprobs should be either
            # NUM_TOP_LOGPROBS or NUM_TOP_LOGPROBS + 1
            assert NUM_TOP_LOGPROBS <= len(logprobs) <= NUM_TOP_LOGPROBS + 1
            # Ensure the top NUM_TOP_LOGPROBS ranks are always extracted
            assert set(range(1, NUM_TOP_LOGPROBS + 1)).issubset(
                {logprob.rank for logprob in logprobs.values()}
            )
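
A side note on the k-or-k+1 bound asserted above: each position always records the top-k candidates, and the prompt or sampled token adds an extra entry only when its rank falls outside that top-k. A tiny illustration of the counting, with k standing in for NUM_TOP_LOGPROBS / NUM_PROMPT_LOGPROBS:

# Illustration only: the chosen token is one extra entry iff its rank > k.
k = 5
for chosen_rank, expected_entries in [(1, 5), (3, 5), (11, 6)]:
    assert k + (0 if chosen_rank <= k else 1) == expected_entries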

tests/samplers/test_ranks.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

tests/test_logprobs.py

Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest

from vllm.logprobs import (
    FlattenLogprobs,
    Logprob,
    LogprobsOnePosition,
    append_logprobs_for_next_position,
    create_prompt_logprobs,
    create_sample_logprobs,
)


def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")

    prompt_logprobs = create_prompt_logprobs()
    assert isinstance(prompt_logprobs, list)
    # Ensure first prompt position logprobs is None
    assert len(prompt_logprobs) == 1
    assert prompt_logprobs[0] is None

    sample_logprobs = create_sample_logprobs()
    assert isinstance(sample_logprobs, list)
    assert len(sample_logprobs) == 0


def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")

    prompt_logprobs = create_prompt_logprobs()
    assert isinstance(prompt_logprobs, FlattenLogprobs)
    assert prompt_logprobs.start_indices == [0]
    assert prompt_logprobs.end_indices == [0]
    assert len(prompt_logprobs.token_ids) == 0
    assert len(prompt_logprobs.logprobs) == 0
    assert len(prompt_logprobs.ranks) == 0
    assert len(prompt_logprobs.decoded_tokens) == 0
    # Ensure first prompt position logprobs is empty
    assert len(prompt_logprobs) == 1
    assert prompt_logprobs[0] == dict()

    sample_logprobs = create_sample_logprobs()
    assert isinstance(sample_logprobs, FlattenLogprobs)
    assert len(sample_logprobs.start_indices) == 0
    assert len(sample_logprobs.end_indices) == 0
    assert len(sample_logprobs.token_ids) == 0
    assert len(sample_logprobs.logprobs) == 0
    assert len(sample_logprobs.ranks) == 0
    assert len(sample_logprobs.decoded_tokens) == 0
    assert len(sample_logprobs) == 0


def test_append_logprobs_for_next_position_none_flatten(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
    logprobs = create_sample_logprobs()
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[1],
        logprobs=[0.1],
        decoded_tokens=["1"],
        rank=10,
        num_logprobs=-1,
    )
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[2, 3],
        logprobs=[0.2, 0.3],
        decoded_tokens=["2", "3"],
        rank=11,
        num_logprobs=-1,
    )
    assert isinstance(logprobs, list)
    assert logprobs == [
        {1: Logprob(logprob=0.1, rank=10, decoded_token="1")},
        {
            2: Logprob(logprob=0.2, rank=11, decoded_token="2"),
            3: Logprob(logprob=0.3, rank=1, decoded_token="3"),
        },
    ]


def test_append_logprobs_for_next_position_flatten(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
    logprobs = create_sample_logprobs()
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[1],
        logprobs=[0.1],
        decoded_tokens=["1"],
        rank=10,
        num_logprobs=-1,
    )
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[2, 3],
        logprobs=[0.2, 0.3],
        decoded_tokens=["2", "3"],
        rank=11,
        num_logprobs=-1,
    )
    assert isinstance(logprobs, FlattenLogprobs)
    assert logprobs.start_indices == [0, 1]
    assert logprobs.end_indices == [1, 3]
    assert logprobs.token_ids == [1, 2, 3]
    assert logprobs.logprobs == [0.1, 0.2, 0.3]
    assert logprobs.ranks == [10, 11, 1]
    assert logprobs.decoded_tokens == ["1", "2", "3"]


LOGPROBS_ONE_POSITION_0: LogprobsOnePosition = {
    1: Logprob(logprob=0.1, rank=10, decoded_token="10")
}
LOGPROBS_ONE_POSITION_1: LogprobsOnePosition = {
    2: Logprob(logprob=0.2, rank=20, decoded_token="20"),
    3: Logprob(logprob=0.3, rank=30, decoded_token="30"),
}
LOGPROBS_ONE_POSITION_2: LogprobsOnePosition = {
    4: Logprob(logprob=0.4, rank=40, decoded_token="40"),
    5: Logprob(logprob=0.5, rank=50, decoded_token="50"),
    6: Logprob(logprob=0.6, rank=60, decoded_token="60"),
}


def test_flatten_logprobs_append() -> None:
    logprobs = FlattenLogprobs()
    logprobs.append(LOGPROBS_ONE_POSITION_0)
    logprobs.append(LOGPROBS_ONE_POSITION_1)
    assert logprobs.start_indices == [0, 1]
    assert logprobs.end_indices == [1, 3]
    assert logprobs.token_ids == [1, 2, 3]
    assert logprobs.logprobs == [0.1, 0.2, 0.3]
    assert logprobs.ranks == [10, 20, 30]
    assert logprobs.decoded_tokens == ["10", "20", "30"]

    logprobs.append(LOGPROBS_ONE_POSITION_2)
    assert logprobs.start_indices == [0, 1, 3]
    assert logprobs.end_indices == [1, 3, 6]
    assert logprobs.token_ids == [1, 2, 3, 4, 5, 6]
    assert logprobs.logprobs == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
    assert logprobs.ranks == [10, 20, 30, 40, 50, 60]
    assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]


def test_flatten_logprobs_extend() -> None:
    logprobs = FlattenLogprobs()
    # Extend with list[LogprobsOnePosition]
    logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
    assert logprobs.start_indices == [0, 3]
    assert logprobs.end_indices == [3, 4]
    assert logprobs.token_ids == [4, 5, 6, 1]
    assert logprobs.logprobs == [0.4, 0.5, 0.6, 0.1]
    assert logprobs.ranks == [40, 50, 60, 10]
    assert logprobs.decoded_tokens == ["40", "50", "60", "10"]

    other_logprobs = FlattenLogprobs()
    other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
    # Extend with another FlattenLogprobs
    logprobs.extend(other_logprobs)
    assert logprobs.start_indices == [0, 3, 4, 6]
    assert logprobs.end_indices == [3, 4, 6, 7]
    assert logprobs.token_ids == [4, 5, 6, 1, 2, 3, 1]
    assert logprobs.logprobs == [0.4, 0.5, 0.6, 0.1, 0.2, 0.3, 0.1]
    assert logprobs.ranks == [40, 50, 60, 10, 20, 30, 10]
    assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]


def test_flatten_logprobs_access() -> None:
    logprobs = FlattenLogprobs()
    logprobs.extend(
        [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
    )
    assert logprobs.start_indices == [0, 2, 5]
    assert logprobs.end_indices == [2, 5, 6]
    assert logprobs.token_ids == [2, 3, 4, 5, 6, 1]
    assert logprobs.logprobs == [0.2, 0.3, 0.4, 0.5, 0.6, 0.1]
    assert logprobs.ranks == [20, 30, 40, 50, 60, 10]
    assert logprobs.decoded_tokens == ["20", "30", "40", "50", "60", "10"]

    # Test __len__
    assert len(logprobs) == 3

    # Test __iter__
    for actual_logprobs, expected_logprobs in zip(
        logprobs,
        [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0],
    ):
        assert actual_logprobs == expected_logprobs

    # Test __getitem__: single item
    assert logprobs[0] == LOGPROBS_ONE_POSITION_1
    assert logprobs[1] == LOGPROBS_ONE_POSITION_2
    assert logprobs[2] == LOGPROBS_ONE_POSITION_0

    # Test __getitem__: slice
    logprobs02 = logprobs[:2]
    assert len(logprobs02) == 2
    assert logprobs02[0] == LOGPROBS_ONE_POSITION_1
    assert logprobs02[1] == LOGPROBS_ONE_POSITION_2
    assert logprobs02.start_indices == [0, 2]
    assert logprobs02.end_indices == [2, 5]
    assert logprobs02.token_ids == [2, 3, 4, 5, 6]
    assert logprobs02.logprobs == [0.2, 0.3, 0.4, 0.5, 0.6]
    assert logprobs02.ranks == [20, 30, 40, 50, 60]
    assert logprobs02.decoded_tokens == ["20", "30", "40", "50", "60"]
    logprobs_last2 = logprobs[-2:]
    assert len(logprobs_last2) == 2
    assert logprobs_last2[0] == LOGPROBS_ONE_POSITION_2
    assert logprobs_last2[1] == LOGPROBS_ONE_POSITION_0
    assert logprobs_last2.start_indices == [0, 3]
    assert logprobs_last2.end_indices == [3, 4]
    assert logprobs_last2.token_ids == [4, 5, 6, 1]
    assert logprobs_last2.logprobs == [0.4, 0.5, 0.6, 0.1]
    assert logprobs_last2.ranks == [40, 50, 60, 10]
    assert logprobs_last2.decoded_tokens == ["40", "50", "60", "10"]
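
The slice assertions above pin down one more behavior: slicing returns a new FlattenLogprobs whose indices are rebased to zero. Below is a possible implementation consistent with those expectations, reusing the hypothetical sketch at the top of this page; the committed code may differ.

# Hypothetical slice helper matching the expectations in
# test_flatten_logprobs_access; not the committed implementation.
def flatten_logprobs_slice(src: FlattenLogprobs, key: slice) -> FlattenLogprobs:
    # Assumes a contiguous slice (step 1), which is all the tests exercise.
    positions = range(len(src))[key]  # normalizes negative/open bounds
    out = FlattenLogprobs()
    if len(positions) == 0:
        return out
    lo = src.start_indices[positions[0]]  # first flat index that survives
    hi = src.end_indices[positions[-1]]   # one past the last surviving index
    for pos in positions:
        # Rebase per-position offsets so the new container starts at 0.
        out.start_indices.append(src.start_indices[pos] - lo)
        out.end_indices.append(src.end_indices[pos] - lo)
    out.token_ids = src.token_ids[lo:hi]
    out.logprobs = src.logprobs[lo:hi]
    out.ranks = src.ranks[lo:hi]
    out.decoded_tokens = src.decoded_tokens[lo:hi]
    return out

With the data from test_flatten_logprobs_access, flatten_logprobs_slice(logprobs, slice(-2, None)) reproduces the start_indices == [0, 3] and end_indices == [3, 4] asserted above.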

vllm/envs.py

Lines changed: 6 additions & 0 deletions
@@ -220,6 +220,7 @@
     VLLM_GC_DEBUG: str = ""
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
+    VLLM_FLATTEN_LOGPROBS: bool = False


 def get_default_cache_root():
@@ -1463,6 +1464,11 @@ def get_vllm_port() -> int | None:
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
+    # Flag to enable FlattenLogprobs, whose GC overhead is significantly
+    # smaller than that of the original list[dict[int, Logprob]] approach.
+    # When enabled, PromptLogprobs and SampleLogprobs are populated as
+    # FlattenLogprobs.
+    "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
 }

 # --8<-- [end:env-vars-definition]
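
For anyone trying the flag out, here is a minimal sketch of enabling it offline. The model name and sampling parameters are borrowed from the tests above; the exact engine arguments are an assumption, not part of this diff.

# Minimal sketch: enable VLLM_FLATTEN_LOGPROBS for an offline run.
import os

# The env var is parsed with bool(int(...)), so use "0" or "1". Set it
# before the engine is created so the flag is seen when the logprobs
# containers are built.
os.environ["VLLM_FLATTEN_LOGPROBS"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="distilbert/distilgpt2", max_logprobs=7)  # assumed engine args
outputs = llm.generate(
    ["Hello, world"],
    SamplingParams(max_tokens=5, logprobs=5, prompt_logprobs=7),
)
# With the flag on, prompt and sample logprobs arrive as FlattenLogprobs
# rather than list[dict[int, Logprob]]; both support len(), iteration,
# indexing, and slicing, so most callers should not notice the change.
print(type(outputs[0].prompt_logprobs))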
