 logger = logging.getLogger(__name__)
 
 class LEAP:
-    def __init__(self, system_prompt: str, client, model: str, request_id: str = None):
+    def __init__(self, system_prompt: str, client, model: str, request_config: dict = None, request_id: str = None):
         self.system_prompt = system_prompt
         self.client = client
         self.model = model
@@ -19,6 +19,11 @@ def __init__(self, system_prompt: str, client, model: str, request_id: str = None):
         self.high_level_principles = []
         self.leap_completion_tokens = 0
 
+        # Extract max_tokens from request_config with default
+        self.max_tokens = 4096
+        if request_config:
+            self.max_tokens = request_config.get('max_tokens', self.max_tokens)
+
     def extract_output(self, text: str) -> str:
         match = re.search(r'<output>(.*?)(?:</output>|$)', text, re.DOTALL)
         return match.group(1).strip() if match else ""
@@ -29,7 +34,7 @@ def extract_examples_from_query(self, initial_query: str) -> List[Tuple[str, str
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -83,7 +88,7 @@ def generate_mistakes(self, examples: List[Tuple[str, str]]) -> List[Tuple[str,
8388 # Prepare request for logging
8489 provider_request = {
8590 "model" : self .model ,
86- "max_tokens" : 4096 ,
91+ "max_tokens" : self . max_tokens ,
8792 "messages" : [
8893 {"role" : "system" , "content" : self .system_prompt },
8994 {"role" : "user" , "content" : f"""
@@ -116,7 +121,7 @@ def generate_low_level_principles(self, mistakes: List[Tuple[str, str, str, str]
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -152,7 +157,7 @@ def generate_high_level_principles(self) -> List[str]:
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -185,7 +190,7 @@ def apply_principles(self, query: str) -> str:
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -220,6 +225,6 @@ def solve(self, initial_query: str) -> str:
 
         return self.apply_principles(initial_query)
 
-def leap(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
-    leap_solver = LEAP(system_prompt, client, model, request_id)
+def leap(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None) -> Tuple[str, int]:
+    leap_solver = LEAP(system_prompt, client, model, request_config=request_config, request_id=request_id)
     return leap_solver.solve(initial_query), leap_solver.leap_completion_tokens
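
The net effect of this change: callers can pass a request_config dict whose max_tokens value replaces the previously hard-coded 4096 in every provider request the class builds. A minimal usage sketch of the new entry point, assuming an OpenAI-compatible client; the module path, model name, and config values below are illustrative assumptions, not part of this commit:

# Illustrative usage of the updated leap() signature (module path, model
# name, and config values are assumptions, not part of this commit).
from openai import OpenAI

from optillm.leap import leap  # assumed module path

client = OpenAI()  # any OpenAI-compatible client works here

answer, completion_tokens = leap(
    system_prompt="You are a careful math tutor.",
    initial_query="What is 12 * 7 + 5?",
    client=client,
    model="gpt-4o-mini",
    request_config={"max_tokens": 1024},  # overrides the 4096 default
)
print(answer, completion_tokens)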