
Commit b8eafe4

Merge pull request #280 from algorithmicsuperintelligence/fix-max-tokens
Fix max tokens
2 parents 05f9557 + 2f15dc6 commit b8eafe4

File tree

14 files changed: +128 -66 lines changed


optillm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # Version information
-__version__ = "0.3.8"
+__version__ = "0.3.9"
 
 # Import from server module
 from .server import (

optillm/bon.py

Lines changed: 8 additions & 3 deletions
@@ -4,9 +4,14 @@
 
 logger = logging.getLogger(__name__)
 
-def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3, request_id: str = None) -> str:
+def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3, request_config: dict = None, request_id: str = None) -> str:
     bon_completion_tokens = 0
 
+    # Extract max_tokens from request_config with default
+    max_tokens = 4096
+    if request_config:
+        max_tokens = request_config.get('max_tokens', max_tokens)
+
     messages = [{"role": "system", "content": system_prompt},
                 {"role": "user", "content": initial_query}]
 
@@ -17,7 +22,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
     provider_request = {
         "model": model,
         "messages": messages,
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "n": n,
         "temperature": 1
     }
@@ -50,7 +55,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
     provider_request = {
         "model": model,
         "messages": messages,
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
        "temperature": 1
     }
     response = client.chat.completions.create(**provider_request)
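
The bon.py hunks show the pattern this PR repeats across approaches: accept an optional request_config dict, read max_tokens from it, and fall back to the previous hard-coded 4096. A minimal standalone sketch of that resolution logic (the helper name is illustrative, not part of the diff):

# Sketch of the max_tokens resolution used in bon.py and most other modules
# touched by this PR; the helper name is hypothetical.
def resolve_max_tokens(request_config: dict = None, default: int = 4096) -> int:
    max_tokens = default
    if request_config:
        max_tokens = request_config.get('max_tokens', max_tokens)
    return max_tokens

# Fallback behaviour:
assert resolve_max_tokens(None) == 4096                    # no config: old default
assert resolve_max_tokens({'temperature': 0.2}) == 4096    # config without max_tokens
assert resolve_max_tokens({'max_tokens': 1024}) == 1024    # client limit is honoured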

optillm/leap.py

Lines changed: 13 additions & 8 deletions
@@ -10,7 +10,7 @@
 logger = logging.getLogger(__name__)
 
 class LEAP:
-    def __init__(self, system_prompt: str, client, model: str, request_id: str = None):
+    def __init__(self, system_prompt: str, client, model: str, request_config: dict = None, request_id: str = None):
         self.system_prompt = system_prompt
         self.client = client
         self.model = model
@@ -19,6 +19,11 @@ def __init__(self, system_prompt: str, client, model: str, request_id: str = Non
         self.high_level_principles = []
         self.leap_completion_tokens = 0
 
+        # Extract max_tokens from request_config with default
+        self.max_tokens = 4096
+        if request_config:
+            self.max_tokens = request_config.get('max_tokens', self.max_tokens)
+
     def extract_output(self, text: str) -> str:
         match = re.search(r'<output>(.*?)(?:</output>|$)', text, re.DOTALL)
         return match.group(1).strip() if match else ""
@@ -29,7 +34,7 @@ def extract_examples_from_query(self, initial_query: str) -> List[Tuple[str, str
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -83,7 +88,7 @@ def generate_mistakes(self, examples: List[Tuple[str, str]]) -> List[Tuple[str,
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -116,7 +121,7 @@ def generate_low_level_principles(self, mistakes: List[Tuple[str, str, str, str]
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -152,7 +157,7 @@ def generate_high_level_principles(self) -> List[str]:
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -185,7 +190,7 @@ def apply_principles(self, query: str) -> str:
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -220,6 +225,6 @@ def solve(self, initial_query: str) -> str:
 
         return self.apply_principles(initial_query)
 
-def leap(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
-    leap_solver = LEAP(system_prompt, client, model, request_id)
+def leap(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None) -> str:
+    leap_solver = LEAP(system_prompt, client, model, request_config=request_config, request_id=request_id)
     return leap_solver.solve(initial_query), leap_solver.leap_completion_tokens

optillm/mcts.py

Lines changed: 11 additions & 6 deletions
@@ -26,7 +26,7 @@ def __init__(self, state: DialogueState, parent=None):
         self.value = 0
 
 class MCTS:
-    def __init__(self, simulation_depth, exploration_weight, client, model, request_id=None):
+    def __init__(self, simulation_depth, exploration_weight, client, model, request_config=None, request_id=None):
         self.simulation_depth = simulation_depth
         self.exploration_weight = exploration_weight
         self.root = None
@@ -37,6 +37,11 @@ def __init__(self, simulation_depth, exploration_weight, client, model, request_
         self.completion_tokens = 0
         self.request_id = request_id
 
+        # Extract max_tokens from request_config with default
+        self.max_tokens = 4096
+        if request_config:
+            self.max_tokens = request_config.get('max_tokens', self.max_tokens)
+
     def select(self, node: MCTSNode) -> MCTSNode:
         logger.debug(f"Selecting node. Current node visits: {node.visits}, value: {node.value}")
         if not node.children:
@@ -117,7 +122,7 @@ def generate_actions(self, state: DialogueState) -> List[str]:
         provider_request = {
             "model": self.model,
             "messages": messages,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "n": n,
             "temperature": 1
         }
@@ -151,7 +156,7 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
         provider_request = {
             "model": self.model,
             "messages": messages,
-            "max_tokens": 1024,
+            "max_tokens": min(self.max_tokens, 1024),
             "n": 1,
             "temperature": 1
         }
@@ -220,11 +225,11 @@ def evaluate_state(self, state: DialogueState) -> float:
             logger.warning("Failed to parse evaluation score. Using default value 0.5")
             return 0.5  # Default to a neutral score if parsing fails
 
-def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, num_simulations: int = 2, exploration_weight: float = 0.2,
-                   simulation_depth: int = 1, request_id: str = None) -> str:
+def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, num_simulations: int = 2, exploration_weight: float = 0.2,
+                   simulation_depth: int = 1, request_config: dict = None, request_id: str = None) -> str:
     logger.info("Starting chat with MCTS")
     logger.info(f"Parameters: num_simulations={num_simulations}, exploration_weight={exploration_weight}, simulation_depth={simulation_depth}")
-    mcts = MCTS(simulation_depth=simulation_depth, exploration_weight=exploration_weight, client=client, model=model, request_id=request_id)
+    mcts = MCTS(simulation_depth=simulation_depth, exploration_weight=exploration_weight, client=client, model=model, request_config=request_config, request_id=request_id)
     initial_state = DialogueState(system_prompt, [], initial_query)
     logger.info(f"Initial query: {initial_query}")
     final_state = mcts.search(initial_state, num_simulations)
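
One detail worth calling out in mcts.py: apply_action previously hard-coded 1024 tokens for each simulated dialogue turn, and the new code keeps that ceiling while also honouring a smaller client limit via min(self.max_tokens, 1024). A small sketch of the resulting behaviour (values are illustrative):

# Per-step budget in MCTS.apply_action after this change: never more than the
# old 1024-token cap, and never more than what the client asked for.
def per_step_limit(max_tokens: int) -> int:
    return min(max_tokens, 1024)

assert per_step_limit(4096) == 1024   # default config keeps the original cap
assert per_step_limit(512) == 512     # a tighter client limit wins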

optillm/moa.py

Lines changed: 11 additions & 5 deletions
@@ -4,9 +4,15 @@
 
 logger = logging.getLogger(__name__)
 
-def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
+def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None) -> str:
     logger.info(f"Starting mixture_of_agents function with model: {model}")
     moa_completion_tokens = 0
+
+    # Extract max_tokens from request_config with default
+    max_tokens = 4096
+    if request_config:
+        max_tokens = request_config.get('max_tokens', max_tokens)
+
     completions = []
 
     logger.debug(f"Generating initial completions for query: {initial_query}")
@@ -19,7 +25,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": initial_query}
         ],
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "n": 3,
         "temperature": 1
     }
@@ -59,7 +65,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": initial_query}
         ],
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "temperature": 1
     }
 
@@ -182,14 +188,14 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
     """
 
     logger.debug("Generating final response")
-
+
     provider_request = {
         "model": model,
         "messages": [
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": final_prompt}
         ],
-        "max_tokens": 8192,
+        "max_tokens": max_tokens,
         "n": 1,
         "temperature": 0.1
     }
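
In moa.py the final synthesis request previously had a larger budget (8192) than the per-agent completions (4096); after this change all stages share the client-supplied max_tokens, with 4096 as the default. A hedged usage sketch follows; how request_config is assembled from the incoming request happens in the server code, which is outside the hunks shown here, so the call site below is hypothetical:

# Hypothetical direct call, assuming the caller forwards the client's
# max_tokens inside request_config (server-side plumbing not shown in this diff).
from openai import OpenAI
from optillm.moa import mixture_of_agents

client = OpenAI()
result = mixture_of_agents(
    system_prompt="You are a helpful assistant.",
    initial_query="Summarize the trade-offs of mixture-of-agents sampling.",
    client=client,
    model="gpt-4o-mini",
    request_config={"max_tokens": 2048},  # now caps every stage, including the final synthesis
)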

optillm/plansearch.py

Lines changed: 12 additions & 7 deletions
@@ -6,13 +6,18 @@
 logger = logging.getLogger(__name__)
 
 class PlanSearch:
-    def __init__(self, system_prompt: str, client, model: str, request_id: str = None):
+    def __init__(self, system_prompt: str, client, model: str, request_config: dict = None, request_id: str = None):
         self.system_prompt = system_prompt
         self.client = client
         self.model = model
         self.request_id = request_id
         self.plansearch_completion_tokens = 0
 
+        # Extract max_tokens from request_config with default
+        self.max_tokens = 4096
+        if request_config:
+            self.max_tokens = request_config.get('max_tokens', self.max_tokens)
+
     def generate_observations(self, problem: str, num_observations: int = 3) -> List[str]:
         prompt = f"""You are an expert Python programmer. You will be given a competitive programming question
 (problem specification). You will return several useful, non-obvious, and correct observations
@@ -27,7 +32,7 @@ def generate_observations(self, problem: str, num_observations: int = 3) -> List
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": prompt}
@@ -71,7 +76,7 @@ def generate_derived_observations(self, problem: str, observations: List[str], n
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": prompt}
@@ -113,7 +118,7 @@ def generate_solution(self, problem: str, observations: List[str]) -> str:
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": prompt}
@@ -155,7 +160,7 @@ def implement_solution(self, problem: str, solution: str) -> str:
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": prompt}
@@ -204,6 +209,6 @@ def solve_multiple(self, problem: str, n: int, num_initial_observations: int = 3
             solutions.append(python_implementation)
         return solutions
 
-def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1, request_id: str = None) -> List[str]:
-    planner = PlanSearch(system_prompt, client, model, request_id)
+def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1, request_config: dict = None, request_id: str = None) -> List[str]:
+    planner = PlanSearch(system_prompt, client, model, request_config=request_config, request_id=request_id)
     return planner.solve_multiple(initial_query, n), planner.plansearch_completion_tokens

optillm/pvg.py

Lines changed: 12 additions & 7 deletions
@@ -8,7 +8,7 @@
 
 pvg_completion_tokens = 0
 
-def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7, request_id: str = None) -> List[str]:
+def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7, max_tokens: int = 4096, request_id: str = None) -> List[str]:
     global pvg_completion_tokens
     role = "sneaky" if is_sneaky else "helpful"
     logger.info(f"Generating {num_solutions} {role} solutions")
@@ -36,7 +36,7 @@ def generate_solutions(client, system_prompt: str, query: str, model: str, num_s
         "model": model,
         "messages": messages,
         "n": num_solutions,
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "temperature": temperature,
     }
     response = client.chat.completions.create(**provider_request)
@@ -151,10 +151,15 @@ def extract_answer(final_state: str) -> Tuple[str, float]:
         logger.warning("No answer found in the state.")
         return "", 0.0
 
-def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3, request_id: str = None) -> str:
+def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3, request_config: dict = None, request_id: str = None) -> str:
     global pvg_completion_tokens
     logger.info(f"Starting inference-time PV game with {num_rounds} rounds and {num_solutions} solutions per round")
-
+
+    # Extract max_tokens from request_config with default
+    max_tokens = 4096
+    if request_config:
+        max_tokens = request_config.get('max_tokens', max_tokens)
+
     best_solution = ""
     best_score = -1
 
@@ -163,8 +168,8 @@ def inference_time_pv_game(system_prompt: str, initial_query: str, client, model
 
         temperature = max(0.2, 0.7 - (round * 0.1))
 
-        helpful_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, temperature=temperature, request_id=request_id)
-        sneaky_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, is_sneaky=True, temperature=temperature, request_id=request_id)
+        helpful_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, temperature=temperature, max_tokens=max_tokens, request_id=request_id)
+        sneaky_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, is_sneaky=True, temperature=temperature, max_tokens=max_tokens, request_id=request_id)
         all_solutions = helpful_solutions + sneaky_solutions
 
         scores = verify_solutions(client, system_prompt, initial_query, all_solutions, model, request_id=request_id)
@@ -198,7 +203,7 @@ def inference_time_pv_game(system_prompt: str, initial_query: str, client, model
     provider_request = {
         "model": model,
        "messages": messages,
-        "max_tokens": 1024,
+        "max_tokens": min(max_tokens, 1024),
         "temperature": 0.5,
     }
     response = client.chat.completions.create(**provider_request)
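
pvg.py takes a slightly different route from the other modules: the module-level helper generate_solutions receives max_tokens as a plain keyword argument (default 4096) rather than the whole request_config, and inference_time_pv_game resolves the value once per request and threads it into each call; the final refinement request keeps its tighter budget through min(max_tokens, 1024). A toy sketch of that resolve-once-then-thread pattern (names mirror the diff, but the bodies are stand-ins, not optillm code):

# Toy illustration only: resolve max_tokens once at the entry point, then pass
# it as an explicit keyword to each helper, as pvg.py does after this change.
def generate_solutions(num_solutions: int, max_tokens: int = 4096) -> list:
    return [f"solution under a {max_tokens}-token budget"] * num_solutions

def inference_time_pv_game(request_config: dict = None) -> list:
    max_tokens = 4096
    if request_config:
        max_tokens = request_config.get('max_tokens', max_tokens)
    helpful = generate_solutions(2, max_tokens=max_tokens)
    sneaky = generate_solutions(2, max_tokens=max_tokens)
    return helpful + sneaky

print(inference_time_pv_game({"max_tokens": 2048}))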

optillm/reread.py

Lines changed: 11 additions & 3 deletions
@@ -4,22 +4,28 @@
 
 logger = logging.getLogger(__name__)
 
-def re2_approach(system_prompt, initial_query, client, model, n=1, request_id: str = None):
+def re2_approach(system_prompt, initial_query, client, model, n=1, request_config: dict = None, request_id: str = None):
     """
     Implement the RE2 (Re-Reading) approach for improved reasoning in LLMs.
-
+
     Args:
         system_prompt (str): The system prompt to be used.
         initial_query (str): The initial user query.
         client: The OpenAI client object.
         model (str): The name of the model to use.
         n (int): Number of completions to generate.
-
+        request_config (dict): Optional configuration including max_tokens.
+
     Returns:
         str or list: The generated response(s) from the model.
     """
     logger.info("Using RE2 approach for query processing")
     re2_completion_tokens = 0
+
+    # Extract max_tokens from request_config if provided
+    max_tokens = None
+    if request_config:
+        max_tokens = request_config.get('max_tokens')
 
     # Construct the RE2 prompt
     re2_prompt = f"{initial_query}\nRead the question again: {initial_query}"
@@ -35,6 +41,8 @@ def re2_approach(system_prompt, initial_query, client, model, n=1, request_id: s
         "messages": messages,
         "n": n
     }
+    if max_tokens is not None:
+        provider_request["max_tokens"] = max_tokens
     response = client.chat.completions.create(**provider_request)
 
     # Log provider call
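
reread.py is the one module that does not fall back to 4096: max_tokens starts as None and is only added to provider_request when the client actually supplied a value, so otherwise the provider's own default applies. A minimal sketch of that conditional inclusion (the builder function is hypothetical; the logic mirrors the diff):

# Sketch of reread.py's variant: only set max_tokens when the client sent one,
# otherwise leave it out and let the provider default apply. Hypothetical helper.
def build_request(model: str, messages: list, n: int = 1, request_config: dict = None) -> dict:
    provider_request = {"model": model, "messages": messages, "n": n}
    max_tokens = request_config.get('max_tokens') if request_config else None
    if max_tokens is not None:
        provider_request["max_tokens"] = max_tokens
    return provider_request

assert "max_tokens" not in build_request("gpt-4o-mini", [])
assert build_request("gpt-4o-mini", [], request_config={"max_tokens": 256})["max_tokens"] == 256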
