diff --git a/tinker_cookbook/renderers.py b/tinker_cookbook/renderers.py
index 349c17c..3a11029 100644
--- a/tinker_cookbook/renderers.py
+++ b/tinker_cookbook/renderers.py
@@ -34,6 +34,7 @@ class Message(TypedDict):
     content: str
     tool_calls: NotRequired[list[ToolCall]]
     thinking: NotRequired[str]
+    trainable: NotRequired[bool]
 
 
 class TrainOnWhat(StrEnum):
@@ -42,6 +43,7 @@ class TrainOnWhat(StrEnum):
     ALL_MESSAGES = "all_messages"
     ALL_TOKENS = "all_tokens"
     ALL_USER_AND_SYSTEM_MESSAGES = "all_user_and_system_messages"
+    CUSTOMIZED = "customized"
 
 
 class Renderer:
@@ -101,6 +103,10 @@ def build_supervised_example(
         train_on_what: an enum that controls how the weights are assigned to the tokens.
             - TrainOnWhat.LAST_ASSISTANT_MESSAGE: only the last assistant message is used for training
             - TrainOnWhat.ALL_ASSISTANT_MESSAGES: all assistant messages are used for training
+            - TrainOnWhat.ALL_MESSAGES: all messages are used for training
+            - TrainOnWhat.ALL_TOKENS: all tokens are used for training
+            - TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES: all user and system messages are used for training
+            - TrainOnWhat.CUSTOMIZED: each message has a trainable field, and the weights are assigned based on the trainable field
         messages: a list of messages to render.
 
     Returns:
@@ -110,29 +116,51 @@ def build_supervised_example(
     """
     tokens_weights = [(token, 0) for token in start_tokens]
-    for idx, message in enumerate(messages[:-1]):
-        ob_part, action_part, action_tail = render_message(idx, message)
-        if train_on_what == TrainOnWhat.LAST_ASSISTANT_MESSAGE:
-            tokens_weights.extend([(token, 0) for token in ob_part + action_part])
-        elif train_on_what == TrainOnWhat.ALL_ASSISTANT_MESSAGES:
-            tokens_weights += [(token, 0) for token in ob_part]
-            # TODO: look at the previous action tail and its overlap with the current action part
-            # and put weight of 1 on those tokens too.
-            is_assistant = message["role"] == "assistant"
-            tokens_weights += [(token, int(is_assistant)) for token in action_part]
-        elif train_on_what == TrainOnWhat.ALL_MESSAGES:
-            tokens_weights += [(token, 0) for token in ob_part]
-            tokens_weights += [(token, 1) for token in action_part]
-        elif train_on_what == TrainOnWhat.ALL_TOKENS:
-            tokens_weights += [(token, 1) for token in ob_part + action_part]
-        elif train_on_what == TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES:
-            tokens_weights += [(token, 0) for token in ob_part]
-            is_user_or_system = message["role"] in ["user", "system"]
-            tokens_weights += [(token, int(is_user_or_system)) for token in action_part]
+    # Walk every message, including the final one: the last message is no longer
+    # rendered separately after the loop, so slicing it off would drop it entirely.
+    for idx, message in enumerate(messages):
+        if train_on_what == TrainOnWhat.CUSTOMIZED:
+            assert "trainable" in message, (
+                "When using CUSTOMIZED train_on_what, each message must have a trainable field: True if loss is applied on this message, False otherwise"
+            )
         else:
-            raise ValueError(f"Unknown train_on_what: {train_on_what}")
-    ob_part, action_part, action_tail = render_message(len(messages) - 1, messages[-1])
-    tokens_weights.extend([(token, 0) for token in ob_part])
-    tokens_weights.extend([(token, 1) for token in action_part + action_tail])
+            assert "trainable" not in message, (
+                "When using non-CUSTOMIZED train_on_what, each message must not have a trainable field. Either change train_on_what to CUSTOMIZED or remove the trainable field from the message"
+            )
+
+        is_last_message = idx == len(messages) - 1
+        is_assistant = message["role"] == "assistant"
+        is_user_or_system = message["role"] in ["user", "system"]
+
+        # only apply weight to observation part if train_on_what is ALL_TOKENS
+        ob_part, action_part, action_tail = render_message(idx, message)
+        ob_weight = int(train_on_what == TrainOnWhat.ALL_TOKENS)
+        tokens_weights += [(token, ob_weight) for token in ob_part]
+
+        # copy so that appending the tail below never mutates the list returned by render_message
+        action_tokens = list(action_part)
+        # action tail is effectively the stop_token and the start token for the next turn
+        # e.g. \n\nUser:
+        if is_last_message:
+            action_tokens += action_tail
+
+        match train_on_what:
+            case TrainOnWhat.LAST_ASSISTANT_MESSAGE:
+                action_has_weight = is_last_message and is_assistant
+            case TrainOnWhat.ALL_ASSISTANT_MESSAGES:
+                action_has_weight = is_assistant
+            case TrainOnWhat.ALL_MESSAGES:
+                action_has_weight = True
+            case TrainOnWhat.ALL_TOKENS:
+                action_has_weight = True
+            case TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES:
+                action_has_weight = is_user_or_system
+            case TrainOnWhat.CUSTOMIZED:
+                action_has_weight = message.get("trainable", False)
+            case _:
+                raise ValueError(f"Unknown train_on_what: {train_on_what}")
+
+        tokens_weights += [(token, int(action_has_weight)) for token in action_tokens]
+
     tokens, weights = zip(*tokens_weights, strict=True)
     return torch.tensor(tokens), torch.tensor(weights)
 