From c0ca34ed26401ef30d688840e9a0aa500e54d57b Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Tue, 11 Nov 2025 04:23:01 +0000 Subject: [PATCH 1/4] adding a renderer for human simulator --- tinker_cookbook/renderers.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tinker_cookbook/renderers.py b/tinker_cookbook/renderers.py index 349c17c..45b6ecb 100644 --- a/tinker_cookbook/renderers.py +++ b/tinker_cookbook/renderers.py @@ -42,6 +42,7 @@ class TrainOnWhat(StrEnum): ALL_MESSAGES = "all_messages" ALL_TOKENS = "all_tokens" ALL_USER_AND_SYSTEM_MESSAGES = "all_user_and_system_messages" + ALL_BUT_FIRST_USER_AND_SYSTEM_MESSAGES = "all_but_first_user_and_system_messages" class Renderer: @@ -109,7 +110,10 @@ def build_supervised_example( - weights: a tensor of weights """ tokens_weights = [(token, 0) for token in start_tokens] + first_user_turn_ended = False for idx, message in enumerate(messages[:-1]): + if message["role"] == "assistant": + first_user_turn_ended = True ob_part, action_part, action_tail = render_message(idx, message) if train_on_what == TrainOnWhat.LAST_ASSISTANT_MESSAGE: tokens_weights.extend([(token, 0) for token in ob_part + action_part]) @@ -128,6 +132,10 @@ def build_supervised_example( tokens_weights += [(token, 0) for token in ob_part] is_user_or_system = message["role"] in ["user", "system"] tokens_weights += [(token, int(is_user_or_system)) for token in action_part] + elif train_on_what == TrainOnWhat.ALL_BUT_FIRST_USER_AND_SYSTEM_MESSAGES: + tokens_weights += [(token, 0) for token in ob_part] + action_weights = int((message["role"] in ["user", "system"]) and first_user_turn_ended) + tokens_weights += [(token, action_weights) for token in action_part] else: raise ValueError(f"Unknown train_on_what: {train_on_what}") ob_part, action_part, action_tail = render_message(len(messages) - 1, messages[-1]) From c274d2ae84c91451bae00b774cdb487856a147a0 Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Wed, 12 Nov 2025 18:04:26 +0000 Subject: 
[PATCH 2/4] more general customization on what to train on --- tinker_cookbook/renderers.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/tinker_cookbook/renderers.py b/tinker_cookbook/renderers.py index 45b6ecb..528f273 100644 --- a/tinker_cookbook/renderers.py +++ b/tinker_cookbook/renderers.py @@ -34,6 +34,7 @@ class Message(TypedDict): content: str tool_calls: NotRequired[list[ToolCall]] thinking: NotRequired[str] + trainable: NotRequired[bool] class TrainOnWhat(StrEnum): @@ -42,7 +43,7 @@ class TrainOnWhat(StrEnum): ALL_MESSAGES = "all_messages" ALL_TOKENS = "all_tokens" ALL_USER_AND_SYSTEM_MESSAGES = "all_user_and_system_messages" - ALL_BUT_FIRST_USER_AND_SYSTEM_MESSAGES = "all_but_first_user_and_system_messages" + CUSTOMIZED = "customized" class Renderer: @@ -102,6 +103,10 @@ def build_supervised_example( train_on_what: an enum that controls how the weights are assigned to the tokens. - TrainOnWhat.LAST_ASSISTANT_MESSAGE: only the last assistant message is used for training - TrainOnWhat.ALL_ASSISTANT_MESSAGES: all assistant messages are used for training + - TrainOnWhat.ALL_MESSAGES: all messages are used for training + - TrainOnWhat.ALL_TOKENS: all tokens are used for training + - TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES: all user and system messages are used for training + - TrainOnWhat.CUSTOMIZED: each message has a trainable field, and the weights are assigned based on the trainable field messages: a list of messages to render. 
Returns: @@ -110,11 +115,14 @@ def build_supervised_example( - weights: a tensor of weights """ tokens_weights = [(token, 0) for token in start_tokens] - first_user_turn_ended = False for idx, message in enumerate(messages[:-1]): - if message["role"] == "assistant": - first_user_turn_ended = True ob_part, action_part, action_tail = render_message(idx, message) + + if train_on_what == TrainOnWhat.CUSTOMIZED: + assert "trainable" in message, "When using CUSTOMIZED train_on_what, each message must have a trainable field: True if loss is applied on this message, False otherwise" + else: + assert "trainable" not in message, "When using non-CUSTOMIZED train_on_what, each message must not have a trainable field. Either change train_on_what to CUSTOMIZED or remove the trainable field from the message" + if train_on_what == TrainOnWhat.LAST_ASSISTANT_MESSAGE: tokens_weights.extend([(token, 0) for token in ob_part + action_part]) elif train_on_what == TrainOnWhat.ALL_ASSISTANT_MESSAGES: @@ -132,10 +140,10 @@ def build_supervised_example( tokens_weights += [(token, 0) for token in ob_part] is_user_or_system = message["role"] in ["user", "system"] tokens_weights += [(token, int(is_user_or_system)) for token in action_part] - elif train_on_what == TrainOnWhat.ALL_BUT_FIRST_USER_AND_SYSTEM_MESSAGES: + elif train_on_what == TrainOnWhat.CUSTOMIZED: + message_weight = int(message["trainable"]) tokens_weights += [(token, 0) for token in ob_part] - action_weights = int((message["role"] in ["user", "system"]) and first_user_turn_ended) - tokens_weights += [(token, action_weights) for token in action_part] + tokens_weights += [(token, message_weight) for token in action_part] else: raise ValueError(f"Unknown train_on_what: {train_on_what}") ob_part, action_part, action_tail = render_message(len(messages) - 1, messages[-1]) From 3ba9f23b4ab897ffd646f5145d3897fc9c7cd82f Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Wed, 12 Nov 2025 18:11:05 +0000 Subject: [PATCH 3/4] wrap long assert messages in parentheses --- 
tinker_cookbook/renderers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tinker_cookbook/renderers.py b/tinker_cookbook/renderers.py index 528f273..f0854ca 100644 --- a/tinker_cookbook/renderers.py +++ b/tinker_cookbook/renderers.py @@ -119,9 +119,13 @@ def build_supervised_example( ob_part, action_part, action_tail = render_message(idx, message) if train_on_what == TrainOnWhat.CUSTOMIZED: - assert "trainable" in message, "When using CUSTOMIZED train_on_what, each message must have a trainable field: True if loss is applied on this message, False otherwise" + assert "trainable" in message, ( + "When using CUSTOMIZED train_on_what, each message must have a trainable field: True if loss is applied on this message, False otherwise" + ) else: - assert "trainable" not in message, "When using non-CUSTOMIZED train_on_what, each message must not have a trainable field. Either change train_on_what to CUSTOMIZED or remove the trainable field from the message" + assert "trainable" not in message, ( + "When using non-CUSTOMIZED train_on_what, each message must not have a trainable field. 
Either change train_on_what to CUSTOMIZED or remove the trainable field from the message" + ) if train_on_what == TrainOnWhat.LAST_ASSISTANT_MESSAGE: tokens_weights.extend([(token, 0) for token in ob_part + action_part]) From 408236c840e64b736dae87116dca3e3939947eed Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Wed, 12 Nov 2025 19:02:50 +0000 Subject: [PATCH 4/4] consolidate per-message weight assignment into one match-based loop --- NOTE(review): this patch leaves the loop header `for idx, message in enumerate(messages[:-1]):` unchanged while deleting the separate rendering of `messages[-1]`, so `is_last_message` (`idx == len(messages) - 1`) is never true and the final message contributes no tokens or weights; the loop should iterate over `enumerate(messages)`. tinker_cookbook/renderers.py | 61 +++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/tinker_cookbook/renderers.py b/tinker_cookbook/renderers.py index f0854ca..3a11029 100644 --- a/tinker_cookbook/renderers.py +++ b/tinker_cookbook/renderers.py @@ -116,8 +116,6 @@ def build_supervised_example( """ tokens_weights = [(token, 0) for token in start_tokens] for idx, message in enumerate(messages[:-1]): - ob_part, action_part, action_tail = render_message(idx, message) - if train_on_what == TrainOnWhat.CUSTOMIZED: assert "trainable" in message, ( "When using CUSTOMIZED train_on_what, each message must have a trainable field: True if loss is applied on this message, False otherwise" @@ -127,32 +125,39 @@ def build_supervised_example( "When using non-CUSTOMIZED train_on_what, each message must not have a trainable field. Either change train_on_what to CUSTOMIZED or remove the trainable field from the message" ) - if train_on_what == TrainOnWhat.LAST_ASSISTANT_MESSAGE: - tokens_weights.extend([(token, 0) for token in ob_part + action_part]) - elif train_on_what == TrainOnWhat.ALL_ASSISTANT_MESSAGES: - tokens_weights += [(token, 0) for token in ob_part] - # TODO: look at the previous action tail and its overlap with the current action part - # and put weight of 1 on those tokens too. 
- is_assistant = message["role"] == "assistant" - tokens_weights += [(token, int(is_assistant)) for token in action_part] - elif train_on_what == TrainOnWhat.ALL_MESSAGES: - tokens_weights += [(token, 0) for token in ob_part] - tokens_weights += [(token, 1) for token in action_part] - elif train_on_what == TrainOnWhat.ALL_TOKENS: - tokens_weights += [(token, 1) for token in ob_part + action_part] - elif train_on_what == TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES: - tokens_weights += [(token, 0) for token in ob_part] - is_user_or_system = message["role"] in ["user", "system"] - tokens_weights += [(token, int(is_user_or_system)) for token in action_part] - elif train_on_what == TrainOnWhat.CUSTOMIZED: - message_weight = int(message["trainable"]) - tokens_weights += [(token, 0) for token in ob_part] - tokens_weights += [(token, message_weight) for token in action_part] - else: - raise ValueError(f"Unknown train_on_what: {train_on_what}") - ob_part, action_part, action_tail = render_message(len(messages) - 1, messages[-1]) - tokens_weights.extend([(token, 0) for token in ob_part]) - tokens_weights.extend([(token, 1) for token in action_part + action_tail]) + is_last_message = idx == len(messages) - 1 + is_assistant = message["role"] == "assistant" + is_user_or_system = message["role"] in ["user", "system"] + + # only apply weight to observation part if train_on_what is ALL_TOKENS + ob_part, action_part, action_tail = render_message(idx, message) + ob_weight = int(train_on_what == TrainOnWhat.ALL_TOKENS) + tokens_weights += [(token, ob_weight) for token in ob_part] + + action_tokens = action_part + # action tail is effectively the stop_token and the start token for the next turn + # e.g. 
\n\nUser: + if is_last_message: + action_tokens += action_tail + + match train_on_what: + case TrainOnWhat.LAST_ASSISTANT_MESSAGE: + action_has_weight = is_last_message and is_assistant + case TrainOnWhat.ALL_ASSISTANT_MESSAGES: + action_has_weight = is_assistant + case TrainOnWhat.ALL_MESSAGES: + action_has_weight = True + case TrainOnWhat.ALL_TOKENS: + action_has_weight = True + case TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES: + action_has_weight = is_user_or_system + case TrainOnWhat.CUSTOMIZED: + action_has_weight = message.get("trainable", False) + case _: + raise ValueError(f"Unknown train_on_what: {train_on_what}") + + tokens_weights += [(token, int(action_has_weight)) for token in action_tokens] + tokens, weights = zip(*tokens_weights, strict=True) return torch.tensor(tokens), torch.tensor(weights)