Skip to content

Commit 9978130

Browse files
Merge pull request #26 from semiotic-ai/ip
ip
2 parents d303bac + e1979ab commit 9978130

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

graphdoc/graphdoc/eval/doc_generator_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def forward(self, database_schema: str) -> dict[str, Any]:
6161
trace=True,
6262
client=self.mlflow_helper.mlflow_client,
6363
expirement_name=self.mlflow_experiment_name,
64-
api_key="temp",
64+
logging_id="temp",
6565
)
6666
# TODO: let's decide if this is how we want to handle this in the future.
6767
# Alternatively, we could return the documented schema from forward,

graphdoc/graphdoc/modules/doc_generator_module.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,13 +279,21 @@ def document_full_schema(
279279
trace: bool = False,
280280
client: Optional[mlflow.MlflowClient] = None,
281281
expirement_name: Optional[str] = None,
282-
api_key: Optional[str] = None,
282+
logging_id: Optional[str] = None,
283283
) -> dspy.Prediction:
284284
"""Given a database schema, parse out the underlying components and document on
285285
a per-component basis.
286286
287287
:param database_schema: The database schema to generate documentation for.
288288
:type database_schema: str
289+
:param trace: Whether to trace the generation.
290+
:type trace: bool
291+
:param client: The mlflow client.
292+
:type client: mlflow.MlflowClient
293+
:param expirement_name: The name of the experiment.
294+
:type expirement_name: str
295+
:param logging_id: The id to use for logging. Maps back to the user request.
296+
:type logging_id: str
289297
:return: The generated documentation.
290298
:rtype: dspy.Prediction
291299
@@ -296,8 +304,8 @@ def document_full_schema(
296304
raise ValueError("client must be provided if trace is True")
297305
if expirement_name is None:
298306
raise ValueError("expirement_name must be provided if trace is True")
299-
if api_key is None:
300-
raise ValueError("api_key must be provided if trace is True")
307+
if logging_id is None:
308+
raise ValueError("logging_id must be provided if trace is True")
301309

302310
# check that the graphql is valid
303311
try:
@@ -323,7 +331,7 @@ def document_full_schema(
323331
# TODO: we should have better type handling, but we check at the top
324332
trace_name="document_full_schema",
325333
inputs={"database_schema": database_schema},
326-
attributes={"api_key": api_key},
334+
attributes={"logging_id": logging_id},
327335
)
328336
log.info("created trace: " + str(root_trace))
329337

0 commit comments

Comments
 (0)