-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Make Teleprompt.compile more typing friendly.
#9013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| import logging | ||
| import random | ||
| from typing import Callable | ||
| from typing import Callable, TypeVar | ||
|
|
||
| import dspy | ||
| from dspy.primitives.example import Example | ||
|
|
@@ -14,20 +14,22 @@ | |
| ) | ||
| from dspy.teleprompt.random_search import BootstrapFewShotWithRandomSearch | ||
| from dspy.teleprompt.teleprompt import Teleprompter | ||
| from dspy.primitives import Module | ||
|
|
||
| M = TypeVar("M", bound=Module) | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class BetterTogether(Teleprompter): | ||
|
|
||
| STRAT_SEP = " -> " | ||
|
|
||
| def __init__(self, | ||
| def __init__( | ||
| self, | ||
| metric: Callable, | ||
| prompt_optimizer: Teleprompter | None = None, | ||
| weight_optimizer: Teleprompter | None = None, | ||
| seed: int | None = None, | ||
| ): | ||
| ): | ||
| if not dspy.settings.experimental: | ||
| raise ValueError("This is an experimental optimizer. Set `dspy.settings.experimental` to `True` to use it.") | ||
|
|
||
|
|
@@ -37,7 +39,9 @@ def __init__(self, | |
| # a BootstrapFinetune without a metric, say, if there aren't labels | ||
| # available for the training data. Should this be noted somewhere? | ||
| # TODO: We should re-consider if the metric should be required. | ||
| self.prompt_optimizer = prompt_optimizer if prompt_optimizer else BootstrapFewShotWithRandomSearch(metric=metric) | ||
| self.prompt_optimizer = ( | ||
| prompt_optimizer if prompt_optimizer else BootstrapFewShotWithRandomSearch(metric=metric) | ||
| ) | ||
| self.weight_optimizer = weight_optimizer if weight_optimizer else BootstrapFinetune(metric=metric) | ||
|
|
||
| is_supported_prompt = isinstance(self.prompt_optimizer, BootstrapFewShotWithRandomSearch) | ||
|
|
@@ -52,11 +56,11 @@ def __init__(self, | |
|
|
||
| def compile( | ||
| self, | ||
| student: Module, | ||
| student: M, | ||
| trainset: list[Example], | ||
| strategy: str = "p -> w -> p", | ||
| valset_ratio = 0.1, | ||
| ) -> Module: | ||
| valset_ratio=0.1, | ||
| ) -> M: | ||
| # TODO: We could record acc on a different valset to pick the best | ||
| # strategy within the provided strategy | ||
| logger.info("Validating the strategy") | ||
|
|
@@ -91,10 +95,9 @@ def _run_strategies(self, parsed_strategy, student, trainset, valset_ratio) -> M | |
| launched_flag = False | ||
|
|
||
| for ind, step_code in enumerate(parsed_strategy): | ||
| current_strategy = self.STRAT_SEP.join(parsed_strategy[:ind + 1]) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's not include unrelated changes
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Must be my formatter. Let me see if I can undo the other formatting changes. |
||
| current_strategy = self.STRAT_SEP.join(parsed_strategy[: ind + 1]) | ||
| logger.info( | ||
| f"\n########## Step {ind + 1} of {len(parsed_strategy)} - Strategy " | ||
| f"'{current_strategy}' ##########" | ||
| f"\n########## Step {ind + 1} of {len(parsed_strategy)} - Strategy " f"'{current_strategy}' ##########" | ||
| ) | ||
|
|
||
| logger.info("Shuffling the trainset...") | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,6 +1,6 @@ | ||||||||||||||||||||
| import logging | ||||||||||||||||||||
| from collections import defaultdict | ||||||||||||||||||||
| from typing import Any, Callable | ||||||||||||||||||||
| from typing import Any, Callable, TypeVar | ||||||||||||||||||||
|
|
||||||||||||||||||||
| import dspy | ||||||||||||||||||||
| from dspy.adapters.base import Adapter | ||||||||||||||||||||
|
|
@@ -16,6 +16,8 @@ | |||||||||||||||||||
|
|
||||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| M = TypeVar("M", bound=Module) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| class FinetuneTeleprompter(Teleprompter): | ||||||||||||||||||||
| def __init__( | ||||||||||||||||||||
|
|
@@ -57,9 +59,7 @@ def __init__( | |||||||||||||||||||
| self.exclude_demos = exclude_demos | ||||||||||||||||||||
| self.num_threads = num_threads | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def compile( | ||||||||||||||||||||
| self, student: Module, trainset: list[Example], teacher: Module | list[Module] | None = None | ||||||||||||||||||||
| ) -> Module: | ||||||||||||||||||||
| def compile(self, student: M, trainset: list[Example], teacher: Module | list[Module] | None = None) -> M: | ||||||||||||||||||||
| # TODO: Print statements can be converted to logger.info if we ensure | ||||||||||||||||||||
| # that the default DSPy logger logs info level messages in notebook | ||||||||||||||||||||
| # environments. | ||||||||||||||||||||
|
Comment on lines
+62
to
65
|
||||||||||||||||||||
| def compile(self, student: M, trainset: list[Example], teacher: Module | list[Module] | None = None) -> M: | |
| # TODO: Print statements can be converted to logger.info if we ensure | |
| # that the default DSPy logger logs info level messages in notebook | |
| # environments. | |
| def compile(self, student: M, trainset: list[Example], **kwargs) -> M: | |
| # TODO: Print statements can be converted to logger.info if we ensure | |
| # that the default DSPy logger logs info level messages in notebook | |
| # environments. | |
| teacher = kwargs.get('teacher', None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method requires at least 3 positional arguments, whereas overridden Teleprompter.compile requires 2.