Skip to content

Commit d53daec

Browse files
authored
avoid importing optuna in dspy import (#8258)
1 parent b5d66a6 commit d53daec

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

dspy/teleprompt/mipro_optimizer_v2.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
import sys
55
import textwrap
66
import time
7+
from typing import TYPE_CHECKING
78
from collections import defaultdict
89
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
910

1011
import numpy as np
11-
import optuna
12-
from optuna.distributions import CategoricalDistribution
1312

1413
import dspy
1514
from dspy.evaluate.evaluate import Evaluate
@@ -26,6 +25,9 @@
2625
set_signature,
2726
)
2827

28+
if TYPE_CHECKING:
29+
import optuna
30+
2931
logger = logging.getLogger(__name__)
3032

3133
# Constants
@@ -496,6 +498,7 @@ def _optimize_prompt_parameters(
496498
minibatch_full_eval_steps: int,
497499
seed: int,
498500
) -> Optional[Any]:
501+
import optuna
499502
# Run optimization
500503
optuna.logging.set_verbosity(optuna.logging.WARNING)
501504
logger.info("==> STEP 3: FINDING OPTIMAL PROMPT PARAMETERS <==")
@@ -727,7 +730,7 @@ def _select_and_insert_instructions_and_demos(
727730
candidate_program: Any,
728731
instruction_candidates: Dict[int, List[str]],
729732
demo_candidates: Optional[List],
730-
trial: optuna.trial.Trial,
733+
trial: "optuna.trial.Trial",
731734
trial_logs: Dict,
732735
trial_num: int,
733736
) -> List[str]:
@@ -756,6 +759,7 @@ def _select_and_insert_instructions_and_demos(
756759
return chosen_params, raw_chosen_params
757760

758761
def _get_param_distributions(self, program, instruction_candidates, demo_candidates):
762+
from optuna.distributions import CategoricalDistribution
759763
param_distributions = {}
760764

761765
for i in range(len(instruction_candidates)):
@@ -780,10 +784,11 @@ def _perform_full_evaluation(
780784
score_data,
781785
best_score: float,
782786
best_program: Any,
783-
study: optuna.Study,
787+
study: "optuna.Study",
784788
instruction_candidates: List,
785789
demo_candidates: List,
786790
):
791+
import optuna
787792
logger.info(f"===== Trial {trial_num + 1} / {adjusted_num_trials} - Full Evaluation =====")
788793

789794
# Identify best program to evaluate fully

dspy/teleprompt/teleprompt_optuna.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import optuna
2-
31
from dspy.evaluate.evaluate import Evaluate
42
from dspy.teleprompt.teleprompt import Teleprompter
53

@@ -55,6 +53,7 @@ def objective(self, trial):
5553
return score
5654

5755
def compile(self, student, *, teacher=None, max_demos, trainset, valset=None):
56+
import optuna
5857
self.trainset = trainset
5958
self.valset = valset or trainset
6059
self.student = student.reset_copy()

0 commit comments

Comments
 (0)