From 2b54a45737ee1cdeb2119fd73f2cb383d798d35e Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Fri, 12 Jul 2024 14:09:07 -0500 Subject: [PATCH 01/13] added project and ontology creation for prompt response projects --- libs/labelbox/src/labelbox/__init__.py | 2 + libs/labelbox/src/labelbox/client.py | 119 ++++++++-- libs/labelbox/src/labelbox/schema/ontology.py | 219 ++++++++++++++++-- .../src/labelbox/schema/ontology_kind.py | 28 ++- libs/labelbox/src/labelbox/schema/project.py | 11 +- libs/labelbox/tests/integration/conftest.py | 134 ++++++++++- ...test_prompt_response_generation_project.py | 164 +++++++++++++ .../test_response_creation_project.py | 24 ++ 8 files changed, 670 insertions(+), 31 deletions(-) create mode 100644 libs/labelbox/tests/integration/test_prompt_response_generation_project.py create mode 100644 libs/labelbox/tests/integration/test_response_creation_project.py diff --git a/libs/labelbox/src/labelbox/__init__.py b/libs/labelbox/src/labelbox/__init__.py index c2e26f64b..bcd1f9fee 100644 --- a/libs/labelbox/src/labelbox/__init__.py +++ b/libs/labelbox/src/labelbox/__init__.py @@ -23,6 +23,8 @@ from labelbox.schema.asset_attachment import AssetAttachment from labelbox.schema.webhook import Webhook from labelbox.schema.ontology import Ontology, OntologyBuilder, Classification, Option, Tool, FeatureSchema +from labelbox.schema.ontology import PromptResponseClassification +from labelbox.schema.ontology import ResponseOption from labelbox.schema.role import Role, ProjectRole from labelbox.schema.invite import Invite, InviteLimit from labelbox.schema.data_row_metadata import DataRowMetadataOntology, DataRowMetadataField, DataRowMetadata, DeleteDataRowMetadata diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 1764ee034..afcfcbada 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -41,7 +41,7 @@ from labelbox.schema.model_config import ModelConfig from labelbox.schema.model_run import ModelRun from labelbox.schema.ontology import Ontology, DeleteFeatureFromOntologyResult -from labelbox.schema.ontology import Tool, Classification, FeatureSchema +from labelbox.schema.ontology import Tool, Classification, FeatureSchema, PromptResponseClassification from labelbox.schema.organization import Organization from labelbox.schema.project import Project from labelbox.schema.quality_mode import QualityMode, BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS, \ @@ -874,6 +874,99 @@ def create_offline_model_evaluation_project(self, **kwargs) -> Project: kwargs.pop("data_row_count", None) return self._create_project(**kwargs) + + @overload + def create_prompt_response_generation_project(self, + dataset_name: str, + dataset_id: str = None, + data_row_count: int = 100, + **kwargs) -> Project: + pass + + @overload + def create_prompt_response_generation_project(self, + dataset_id: str, + dataset_name: str = None, + data_row_count: int = 100, + **kwargs) -> Project: + pass + + def create_prompt_response_generation_project(self, + dataset_id: Optional[str] = None, + dataset_name: Optional[str] = None, + data_row_count: int = 100, + **kwargs) -> Project: + """ + Use this method exclusively to create a prompt and response generation project. 
+
+        Args:
+            dataset_name: When creating a new dataset, pass the name
+            dataset_id: When using an existing dataset, pass the id
+            data_row_count: The number of data row assets to use for the project
+            **kwargs: Additional parameters to pass to the create_project method
+        Returns:
+            Project: The created project
+
+        Examples:
+            >>> client.create_prompt_response_generation_project(name=project_name, dataset_name="new data set", media_type=MediaType.LLMPromptResponseCreation)
+            >>> This creates a new dataset with the default number of data rows (100), creates a new project, and assigns a batch of the newly created data rows to the project.
+
+            >>> client.create_prompt_response_generation_project(name=project_name, dataset_name="new data set", data_row_count=10, media_type=MediaType.LLMPromptCreation)
+            >>> This creates a new dataset with 10 data rows, creates a new project, and assigns a batch of the newly created data rows to the project.
+
+            >>> client.create_prompt_response_generation_project(name=project_name, dataset_id="clr00u8j0j0j0", media_type=MediaType.LLMPromptCreation)
+            >>> This creates a new project, adds 100 data rows to the dataset with id "clr00u8j0j0j0", and assigns a batch of the newly created data rows to the project.
+
+            >>> client.create_prompt_response_generation_project(name=project_name, dataset_id="clr00u8j0j0j0", data_row_count=10, media_type=MediaType.LLMPromptCreation)
+            >>> This creates a new project, adds 10 data rows to the dataset with id "clr00u8j0j0j0", and assigns a batch of the newly created data rows to the project.
+
+        """
+        if not dataset_id and not dataset_name:
+            raise ValueError(
+                "dataset_name or data_set_id must be present and not be an empty string."
+            )
+        if data_row_count <= 0:
+            raise ValueError("data_row_count must be a positive integer.")
+
+        if dataset_id:
+            append_to_existing_dataset = True
+            dataset_name_or_id = dataset_id
+        else:
+            append_to_existing_dataset = False
+            dataset_name_or_id = dataset_name
+
+        if "media_type" in kwargs and kwargs.get("media_type") not in [MediaType.LLMPromptCreation, MediaType.LLMPromptResponseCreation]:
+            raise ValueError(
+                "media_type must be either LLMPromptCreation or LLMPromptResponseCreation"
+            )
+
+        kwargs["dataset_name_or_id"] = dataset_name_or_id
+        kwargs["append_to_existing_dataset"] = append_to_existing_dataset
+        kwargs["data_row_count"] = data_row_count
+
+        kwargs.pop("editor_task_type", None)
+
+        return self._create_project(**kwargs)
+
+    def create_response_creation_project(self, **kwargs) -> Project:
+        """
+        Creates a project for response creation.
+ Args: + **kwargs: Additional parameters to pass see the create_project method + Returns: + Project: The created project + """ + kwargs[ + "media_type"] = MediaType.Text # Only Text is supported + kwargs[ + "editor_task_type"] = EditorTaskType.ResponseCreation.value # Special editor task type for offline model evaluation + + # The following arguments are not supported for offline model evaluation + kwargs.pop("dataset_name_or_id", None) + kwargs.pop("append_to_existing_dataset", None) + kwargs.pop("data_row_count", None) + + return self._create_project(**kwargs) def _create_project(self, **kwargs) -> Project: auto_audit_percentage = kwargs.get("auto_audit_percentage") @@ -1189,11 +1282,13 @@ def create_ontology_from_feature_schemas( name (str): Name of the ontology feature_schema_ids (List[str]): List of feature schema ids corresponding to top level tools and classifications to include in the ontology - media_type (MediaType or None): Media type of a new ontology. NOTE for chat evaluation, we currently foce media_type to Conversational + media_type (MediaType or None): Media type of a new ontology. ontology_kind (OntologyKind or None): set to OntologyKind.ModelEvaluation if the ontology is for chat evaluation, leave as None otherwise. Returns: The created Ontology + + NOTE for chat evaluation, we currently force media_type to Conversational and for response creation, we force media_type to Text. """ tools, classifications = [], [] for feature_schema_id in feature_schema_ids: @@ -1209,10 +1304,13 @@ def create_ontology_from_feature_schemas( f"Tool `{tool}` not in list of supported tools.") elif 'type' in feature_schema.normalized: classification = feature_schema.normalized['type'] - try: + if classification in Classification.Type._value2member_map_.keys(): Classification.Type(classification) classifications.append(feature_schema.normalized) - except ValueError: + elif classification in PromptResponseClassification.Type._value2member_map_.keys(): + PromptResponseClassification.Type(classification) + classifications.append(feature_schema.normalized) + else: raise ValueError( f"Classification `{classification}` not in list of supported classifications." ) @@ -1222,15 +1320,7 @@ def create_ontology_from_feature_schemas( ) normalized = {'tools': tools, 'classifications': classifications} - if ontology_kind and ontology_kind is OntologyKind.ModelEvaluation: - if media_type is None: - media_type = MediaType.Conversational - else: - if media_type is not MediaType.Conversational: - raise ValueError( - "For chat evaluation, media_type must be Conversational." - ) - + # validation for ontology_kind and media_type is done within self.create_ontology return self.create_ontology(name=name, normalized=normalized, media_type=media_type, @@ -1424,7 +1514,7 @@ def create_ontology(self, Returns: The created Ontology - NOTE caller of this method is expected to set media_type to Conversational if ontology_kind is ModelEvaluation + NOTE for chat evaluation, we currently force media_type to Conversational and for response creation, we force media_type to Text. 
""" media_type_value = None @@ -1435,6 +1525,7 @@ def create_ontology(self, raise get_media_type_validation_error(media_type) if ontology_kind and OntologyKind.is_supported(ontology_kind): + media_type = OntologyKind.evaluate_ontology_kind_with_media_type(ontology_kind, media_type) editor_task_type_value = EditorTaskTypeMapper.to_editor_task_type( ontology_kind, media_type).value elif ontology_kind: diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index f6c758faa..af3e0b6a3 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -221,6 +221,183 @@ def add_option(self, option: Option) -> None: f"Duplicate option '{option.value}' " f"for classification '{self.name}'.") self.options.append(option) + +@dataclass +class ResponseOption: + """ + An option is a possible answer within a PromptResponseClassification response object in + a Project's ontology. + + To instantiate, only the "value" parameter needs to be passed in. + + Example(s): + option = ResponseOption(value = "Response Option Example") + + Attributes: + value: (str) + schema_id: (str) + feature_schema_id: (str) + options: (list) + """ + value: Union[str, int] + label: Optional[Union[str, int]] = None + schema_id: Optional[str] = None + feature_schema_id: Optional[FeatureSchemaId] = None + options: List["Classification"] = field(default_factory=list) + + def __post_init__(self): + if self.label is None: + self.label = self.value + + @classmethod + def from_dict( + cls, + dictionary: Dict[str, + Any]) -> Dict[Union[str, int], Union[str, int]]: + return cls(value=dictionary["value"], + label=dictionary["label"], + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + options=[ + PromptResponseClassification.from_dict(o) + for o in dictionary.get("options", []) + ]) + + def asdict(self) -> Dict[str, Any]: + return { + "schemaNodeId": self.schema_id, + "featureSchemaId": self.feature_schema_id, + "label": self.label, + "value": self.value, + "options": [o.asdict(is_subclass=True) for o in self.options] + } + + def add_option(self, option: 'Classification') -> None: + if option.name in (o.name for o in self.options): + raise InconsistentOntologyException( + f"Duplicate nested classification '{option.name}' " + f"for option '{self.label}'") + self.options.append(option) + + +@dataclass +class PromptResponseClassification: + """ + + A PromptResponseClassification to be added to a Project's ontology. The + classification is dependent on the PromptResponseClassification Type. + + To instantiate, the "class_type" and "name" parameters must + be passed in. + + The "options" parameter holds a list of Response Option objects. This is not + necessary for some Classification types, such as RESPONSE_TEXT or PROMPT. To see which + types require options, look at the "_REQUIRES_OPTIONS" class variable. 
+ + Example(s): + >>> classification = PromptResponseClassification( + >>> class_type = PromptResponseClassification.Type.Prompt, + >>> character_min = 1, + >>> character_max = 1 + >>> name = "Prompt Classification Example") + + >>> classification_two = PromptResponseClassification( + >>> class_type = PromptResponseClassification.Type.RESPONSE_RADIO, + >>> name = "Second Example") + + >>> classification_two.add_option(ResponseOption( + >>> value = "Option Example")) + + Attributes: + class_type: (Classification.Type) + name: (str) + instructions: (str) + required: (bool) + options: (list) + character_min: (int) + character_max: (int) + ui_mode: (Classification.UIMode) + schema_id: (str) + feature_schema_id: (str) + scope: (Classification.Scope) + """ + + def __post_init__(self): + if self.name is None: + msg = ( + "When creating the Classification feature, please use “name” " + "for the classification schema name, which will be used when " + "creating annotation payload for Model-Assisted Labeling " + "Import and Label Import. “instructions” is no longer " + "supported to specify classification schema name.") + if self.instructions is not None: + self.name = self.instructions + warnings.warn(msg) + else: + raise ValueError(msg) + else: + if self.instructions is None: + self.instructions = self.name + + class Type(Enum): + PROMPT = "prompt" + RESPONSE_TEXT= "response-text" + RESPONSE_CHECKLIST = "response-checklist" + RESPONSE_RADIO = "response-radio" + + _REQUIRES_OPTIONS = {Type.RESPONSE_CHECKLIST, Type.RESPONSE_RADIO} + + class_type: Type + name: Optional[str] = None + instructions: Optional[str] = None + required: bool = True + options: List[ResponseOption] = field(default_factory=list) + character_min: Optional[int] = None + character_max: Optional[int] = None + schema_id: Optional[str] = None + feature_schema_id: Optional[str] = None + + @classmethod + def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: + return cls(class_type=cls.Type(dictionary["type"]), + name=dictionary["name"], + instructions=dictionary["instructions"], + required=True, # always required + options=[ResponseOption.from_dict(o) for o in dictionary["options"]], + character_min=dictionary.get("minCharacters", None), + character_max=dictionary.get("maxCharacters", None), + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None)) + + def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: + if self.class_type in self._REQUIRES_OPTIONS \ + and len(self.options) < 1: + raise InconsistentOntologyException( + f"Response Classification '{self.name}' requires options.") + classification = { + "type": self.class_type.value, + "instructions": self.instructions, + "name": self.name, + "required": True, # always required + "options": [o.asdict() for o in self.options], + "schemaNodeId": self.schema_id, + "featureSchemaId": self.feature_schema_id + } + if (self.class_type == self.Type.PROMPT or self.class_type == self.Type.RESPONSE_TEXT): + if self.character_min: + classification["minCharacters"] = self.character_min + if self.character_max: + classification["maxCharacters"] = self.character_max + if is_subclass: + return classification + return classification + + def add_option(self, option: ResponseOption) -> None: + if option.value in (o.value for o in self.options): + raise InconsistentOntologyException( + f"Duplicate option '{option.value}' " + f"for response classification '{self.name}'.") + self.options.append(option) @dataclass @@ -338,7 +515,7 @@ 
class Ontology(DbObject): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._tools: Optional[List[Tool]] = None - self._classifications: Optional[List[Classification]] = None + self._classifications: Optional[Union[List[Classification],List[PromptResponseClassification]]] = None def tools(self) -> List[Tool]: """Get list of tools (AKA objects) in an Ontology.""" @@ -348,13 +525,15 @@ def tools(self) -> List[Tool]: ] return self._tools - def classifications(self) -> List[Classification]: + def classifications(self) -> List[Union[Classification, PromptResponseClassification]]: """Get list of classifications in an Ontology.""" if self._classifications is None: - self._classifications = [ - Classification.from_dict(classification) - for classification in self.normalized['classifications'] - ] + self._classifications = [] + for classification in self.normalized["classifications"]: + if "type" in classification and classification["type"] in PromptResponseClassification.Type._value2member_map_.keys(): + self._classifications.append(PromptResponseClassification.from_dict(classification)) + else: + self._classifications.append(Classification.from_dict(classification)) return self._classifications @@ -384,21 +563,35 @@ class OntologyBuilder: """ tools: List[Tool] = field(default_factory=list) - classifications: List[Classification] = field(default_factory=list) + classifications: List[Union[Classification, PromptResponseClassification]] = field(default_factory=list) @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: + classifications = [] + for c in dictionary["classifications"]: + if ["type"] in c and c["type"] in PromptResponseClassification.Type: + classifications.append(PromptResponseClassification.from_dict(c)) + else: + classifications.append(Classification.from_dict(c)) return cls(tools=[Tool.from_dict(t) for t in dictionary["tools"]], - classifications=[ - Classification.from_dict(c) - for c in dictionary["classifications"] - ]) + classifications=classifications) def asdict(self) -> Dict[str, Any]: self._update_colors() + classifications = [] + prompts = 0 + for c in self.classifications: + if hasattr(c, "class_type") and c.class_type in PromptResponseClassification.Type: + if c.class_type == PromptResponseClassification.Type.PROMPT: + prompts += 1 + if prompts > 1: + raise ValueError("Only one prompt is allowed per ontology") + classifications.append(PromptResponseClassification.asdict(c)) + else: + classifications.append(Classification.asdict(c)) return { "tools": [t.asdict() for t in self.tools], - "classifications": [c.asdict() for c in self.classifications] + "classifications": classifications } def _update_colors(self): @@ -426,7 +619,7 @@ def add_tool(self, tool: Tool) -> None: f"Duplicate tool name '{tool.name}'. ") self.tools.append(tool) - def add_classification(self, classification: Classification) -> None: + def add_classification(self, classification: Union[Classification, PromptResponseClassification]) -> None: if classification.name in (c.name for c in self.classifications): raise InconsistentOntologyException( f"Duplicate classification name '{classification.name}'. 
") diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py index d31feda12..b0809d093 100644 --- a/libs/labelbox/src/labelbox/schema/ontology_kind.py +++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py @@ -7,9 +7,9 @@ class OntologyKind(Enum): """ OntologyKind is an enum that represents the different types of ontologies - At the moment it is only limited to ModelEvaluation """ ModelEvaluation = "MODEL_EVALUATION" + ResponseCreation = "RESPONSE_CREATION" Missing = None @classmethod @@ -21,6 +21,30 @@ def get_ontology_kind_validation_error(cls, ontology_kind): return TypeError(f"{ontology_kind}: is not a valid ontology kind. Use" f" any of {OntologyKind.__members__.items()}" " from OntologyKind.") + + @staticmethod + def evaluate_ontology_kind_with_media_type(ontology_kind, + media_type: MediaType = None) -> MediaType: + + if ontology_kind and ontology_kind is OntologyKind.ModelEvaluation: + if media_type is None: + media_type = MediaType.Conversational + else: + if media_type is not MediaType.Conversational: + raise ValueError( + "For chat evaluation, media_type must be Conversational." + ) + + elif ontology_kind == OntologyKind.ResponseCreation: + if media_type is None: + media_type = MediaType.Text + else: + if media_type is not MediaType.Text: + raise ValueError( + "For response creation, media_type must be Text." + ) + + return media_type class EditorTaskType(Enum): @@ -69,6 +93,8 @@ def map_to_editor_task_type(onotology_kind: OntologyKind, media_type: MediaType) -> EditorTaskType: if onotology_kind == OntologyKind.ModelEvaluation and media_type == MediaType.Conversational: return EditorTaskType.ModelChatEvaluation + elif onotology_kind == OntologyKind.ResponseCreation and media_type == MediaType.Text: + return EditorTaskType.ResponseCreation else: return EditorTaskType.Missing diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index ee3a29c41..c9215a61f 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -149,6 +149,13 @@ def is_chat_evaluation(self) -> bool: True if this project is a live chat evaluation project, False otherwise """ return self.media_type == MediaType.Conversational and self.editor_task_type == EditorTaskType.ModelChatEvaluation + + def is_prompt_response(self) -> bool: + """ + Returns: + True if this project is a prompt response project, False otherwise + """ + return self.media_type in [MediaType.LLMPromptResponseCreation, MediaType.LLMPromptCreation] or EditorTaskType.ResponseCreation def is_auto_data_generation(self) -> bool: return (self.upload_type == UploadType.Auto) # type: ignore @@ -829,9 +836,9 @@ def setup(self, labeling_frontend, labeling_frontend_options) -> None: "labeling_frontend parameter will not be used to create a new labeling frontend." ) - if self.is_chat_evaluation(): + if self.is_chat_evaluation() or self.is_prompt_response(): warnings.warn(""" - This project is a live chat evaluation project. + This project is a live chat evaluation project or prompt and response generation project. Editor was setup automatically. 
""") return diff --git a/libs/labelbox/tests/integration/conftest.py b/libs/labelbox/tests/integration/conftest.py index b639b20df..6333be300 100644 --- a/libs/labelbox/tests/integration/conftest.py +++ b/libs/labelbox/tests/integration/conftest.py @@ -17,7 +17,7 @@ from labelbox import Dataset, DataRow from labelbox import LabelingFrontend -from labelbox import OntologyBuilder, Tool, Option, Classification, MediaType +from labelbox import OntologyBuilder, Tool, Option, Classification, MediaType, PromptResponseClassification, ResponseOption from labelbox.orm import query from labelbox.pagination import PaginatedCollection from labelbox.schema.annotation_import import LabelImport @@ -342,6 +342,129 @@ def _upload_invalid_data_rows_for_dataset(dataset: Dataset): return _upload_invalid_data_rows_for_dataset +@pytest.fixture +def prompt_response_generation_project_with_new_dataset(client: Client, rand_gen, request): + """fixture is parametrize and needs project_type in request""" + media_type = request.param + prompt_response_project = client.create_prompt_response_generation_project(name=f"{media_type.value}-{rand_gen(str)}", + dataset_name=f"{media_type.value}-{rand_gen(str)}", + data_row_count=1, + media_type=media_type) + + yield prompt_response_project + + prompt_response_project.delete() + +@pytest.fixture +def prompt_response_generation_project_with_dataset_id(client: Client, dataset, rand_gen, request): + """fixture is parametrized and needs project_type in request""" + media_type = request.param + prompt_response_project = client.create_prompt_response_generation_project(name=f"{media_type.value}-{rand_gen(str)}", + dataset_id=dataset.uid, + data_row_count=1, + media_type=media_type) + + yield prompt_response_project + + prompt_response_project.delete() + +@pytest.fixture +def response_creation_project(client: Client, rand_gen): + project_name = f"response-creation-project-{rand_gen(str)}" + project = client.create_response_creation_project(name=project_name) + + yield project + + project.delete() + +@pytest.fixture +def prompt_response_features(rand_gen): + + prompt_text = PromptResponseClassification(class_type=PromptResponseClassification.Type.PROMPT, + name=f"{rand_gen(str)}-prompt text") + + response_radio = PromptResponseClassification(class_type=PromptResponseClassification.Type.RESPONSE_RADIO, + name=f"{rand_gen(str)}-response radio classification", + options=[ + ResponseOption(value=f"{rand_gen(str)}-first radio option answer"), + ResponseOption(value=f"{rand_gen(str)}-second radio option answer"), + ]) + + response_checklist = PromptResponseClassification(class_type=PromptResponseClassification.Type.RESPONSE_CHECKLIST, + name=f"{rand_gen(str)}-response checklist classification", + options=[ + ResponseOption(value=f"{rand_gen(str)}-first checklist option answer"), + ResponseOption(value=f"{rand_gen(str)}-second checklist option answer"), + ]) + + response_text_with_char = PromptResponseClassification(class_type=PromptResponseClassification.Type.RESPONSE_TEXT, + name=f"{rand_gen(str)}-response text with character min and max", + character_min = 1, + character_max = 10) + + response_text = PromptResponseClassification(class_type=PromptResponseClassification.Type.RESPONSE_TEXT, + name=f"{rand_gen(str)}-response text") + + nested_response_radio = PromptResponseClassification(class_type=PromptResponseClassification.Type.RESPONSE_RADIO, + name=f"{rand_gen(str)}-nested response radio classification", + options=[ + ResponseOption(f"{rand_gen(str)}-first_radio_answer", + options=[ + 
PromptResponseClassification( + class_type=PromptResponseClassification.Type.RESPONSE_RADIO, + name=f"{rand_gen(str)}-sub_radio_question", + options=[ResponseOption(f"{rand_gen(str)}-first_sub_radio_answer")]) + ]) + ]) + yield { + "prompts": [prompt_text], + "responses": [response_text, response_radio, response_checklist, response_text_with_char, nested_response_radio] + } + +@pytest.fixture +def prompt_response_ontology(client: Client, rand_gen, prompt_response_features, request): + """fixture is parametrize and needs project_type in request""" + + project_type = request.param + if project_type == MediaType.LLMPromptCreation: + ontology_builder = OntologyBuilder( + tools=[], + classifications=prompt_response_features["prompts"]) + elif project_type == MediaType.LLMPromptResponseCreation: + ontology_builder = OntologyBuilder( + tools=[], + classifications=prompt_response_features["prompts"] + prompt_response_features["responses"]) + else: + ontology_builder = OntologyBuilder( + tools=[], + classifications=prompt_response_features["responses"] + ) + + ontology_name = f"prompt-response-{rand_gen(str)}" + + if project_type in MediaType: + ontology = client.create_ontology( + ontology_name, + ontology_builder.asdict(), + media_type=project_type) + else: + ontology = client.create_ontology( + ontology_name, + ontology_builder.asdict(), + media_type=MediaType.Text, + ontology_kind=OntologyKind.ResponseCreation + ) + yield ontology + + featureSchemaIds = [feature["featureSchemaId"] for feature in ontology.normalized["classifications"]] + + try: + client.delete_unused_ontology(ontology.uid) + for featureSchemaId in featureSchemaIds: + client.delete_unused_feature_schema(featureSchemaId) + except Exception as e: + print(f"Failed to delete ontology {ontology.uid}: {str(e)}") + @pytest.fixture def chat_evaluation_ontology(client, rand_gen): @@ -559,6 +682,15 @@ def offline_conversational_data_row(initial_dataset): return data_row +@pytest.fixture +def response_data_row(initial_dataset): + text_asset = { + "row_data": "response sample text" + } + data_row = initial_dataset.create_data_row(text_asset) + + return data_row + @pytest.fixture() def conversation_data_row(initial_dataset, rand_gen): diff --git a/libs/labelbox/tests/integration/test_prompt_response_generation_project.py b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py new file mode 100644 index 000000000..20d42d92c --- /dev/null +++ b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py @@ -0,0 +1,164 @@ +import pytest +from unittest.mock import patch + +from labelbox import MediaType +from labelbox.schema.ontology_kind import OntologyKind +from labelbox.exceptions import MalformedQueryException + +@pytest.mark.parametrize( + "prompt_response_ontology, prompt_response_generation_project_with_new_dataset", + [ + (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), + (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + ], + indirect=True +) +def test_prompt_response_generation_ontology_project( + client, prompt_response_ontology, + prompt_response_generation_project_with_new_dataset, + response_data_row, rand_gen): + + ontology = prompt_response_ontology + + assert ontology + assert ontology.name + + for classification in ontology.classifications(): + assert classification.schema_id + assert classification.feature_schema_id + + project = prompt_response_generation_project_with_new_dataset + + project.connect_ontology(ontology) + + assert 
project.labeling_frontend().name == "Editor" + assert project.ontology().name == ontology.name + + with pytest.raises( + ValueError, + match="Cannot create batches for auto data generation projects"): + project.create_batch( + rand_gen(str), + [response_data_row.uid], # sample of data row objects + ) + + with pytest.raises( + ValueError, + match="Cannot create batches for auto data generation projects"): + with patch('labelbox.schema.project.MAX_SYNC_BATCH_ROW_COUNT', + new=0): # force to async + + project.create_batch( + rand_gen(str), + [response_data_row.uid + ], # sample of data row objects + ) + +@pytest.mark.parametrize( + "prompt_response_ontology, prompt_response_generation_project_with_dataset_id", + [ + (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), + (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + ], + indirect=True +) +def test_prompt_response_generation_ontology_project_with_existing_dataset( + prompt_response_ontology, + prompt_response_generation_project_with_dataset_id): + ontology = prompt_response_ontology + + project = prompt_response_generation_project_with_dataset_id + assert project + project.connect_ontology(ontology) + + assert project.labeling_frontend().name == "Editor" + assert project.ontology().name == ontology.name + + +@pytest.fixture +def classification_json(): + classifications = [{ + 'featureSchemaId': None, + 'kind': 'Prompt', + 'minCharacters': 2, + 'maxCharacters': 10, + 'name': 'prompt text', + 'instructions': 'prompt text', + 'required': True, + 'schemaNodeId': None, + "scope": "global", + 'type': 'prompt', + 'options': [] + }, { + 'featureSchemaId': None, + 'kind': 'ResponseCheckboxQuestion', + 'name': 'response checklist', + 'instructions': 'response checklist', + 'options': [{'featureSchemaId': None, + 'kind': 'ResponseCheckboxOption', + 'label': 'response checklist option', + 'schemaNodeId': None, + 'position': 0, + 'value': 'option_1'}], + 'required': True, + 'schemaNodeId': None, + "scope": "global", + 'type': 'response-checklist' + }, { + 'featureSchemaId': None, + 'kind': 'ResponseText', + 'maxCharacters': 10, + 'minCharacters': 1, + 'name': 'response text', + 'instructions': 'response text', + 'required': True, + 'schemaNodeId': None, + "scope": "global", + 'type': 'response-text', + 'options': [] + } + ] + + return classifications + + +@pytest.fixture +def features_from_json(client, classification_json): + classifications = classification_json + features = {client.create_feature_schema(t) for t in classifications if t} + + yield features + + for f in features: + client.delete_unused_feature_schema(f.uid) + + +@pytest.fixture +def ontology_from_feature_ids(client, features_from_json): + feature_ids = {f.uid for f in features_from_json} + ontology = client.create_ontology_from_feature_schemas( + name="test-prompt_response_creation{rand_gen(str)}", + feature_schema_ids=feature_ids, + media_type=MediaType.LLMPromptResponseCreation + ) + + yield ontology + + client.delete_unused_ontology(ontology.uid) + + +def test_ontology_create_feature_schema(ontology_from_feature_ids, + features_from_json, classification_json): + created_ontology = ontology_from_feature_ids + feature_schema_ids = {f.uid for f in features_from_json} + classifications_normalized = created_ontology.normalized['classifications'] + classifications = classification_json + + for classification in classifications: + generated_tool = next( + c for c in classifications_normalized if c['name'] == classification['name']) + assert 
generated_tool['schemaNodeId'] is not None + assert generated_tool['featureSchemaId'] in feature_schema_ids + assert generated_tool['type'] == classification['type'] + assert generated_tool['name'] == classification['name'] + assert generated_tool['required'] == classification['required'] diff --git a/libs/labelbox/tests/integration/test_response_creation_project.py b/libs/labelbox/tests/integration/test_response_creation_project.py new file mode 100644 index 000000000..76ba12d54 --- /dev/null +++ b/libs/labelbox/tests/integration/test_response_creation_project.py @@ -0,0 +1,24 @@ +from labelbox.schema.project import Project +import pytest + +from labelbox.schema.ontology_kind import OntologyKind + +@pytest.mark.parametrize("prompt_response_ontology", [OntologyKind.ResponseCreation], indirect=True) +def test_create_response_creation_project(client, rand_gen, + response_creation_project, + prompt_response_ontology, + response_data_row): + project: Project = response_creation_project + assert project + + ontology = prompt_response_ontology + project.connect_ontology(ontology) + + assert project.labeling_frontend().name == "Editor" + assert project.ontology().name == ontology.name + + batch = project.create_batch( + rand_gen(str), + [response_data_row.uid], # sample of data row objects + ) + assert batch \ No newline at end of file From 61411bcdb1c6c20fc59563ce5712c2a8deebc773 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Fri, 12 Jul 2024 14:21:25 -0500 Subject: [PATCH 02/13] test_fixes --- libs/labelbox/src/labelbox/schema/ontology.py | 2 +- libs/labelbox/src/labelbox/schema/ontology_kind.py | 4 ++-- libs/labelbox/src/labelbox/schema/project.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index af3e0b6a3..144594ef0 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -569,7 +569,7 @@ class OntologyBuilder: def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: classifications = [] for c in dictionary["classifications"]: - if ["type"] in c and c["type"] in PromptResponseClassification.Type: + if "type" in c and c["type"] in PromptResponseClassification.Type: classifications.append(PromptResponseClassification.from_dict(c)) else: classifications.append(Classification.from_dict(c)) diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py index b0809d093..27764f784 100644 --- a/libs/labelbox/src/labelbox/schema/ontology_kind.py +++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import Optional, Union from labelbox.schema.media_type import MediaType @@ -24,7 +24,7 @@ def get_ontology_kind_validation_error(cls, ontology_kind): @staticmethod def evaluate_ontology_kind_with_media_type(ontology_kind, - media_type: MediaType = None) -> MediaType: + media_type: Union[MediaType, None]) -> MediaType: if ontology_kind and ontology_kind is OntologyKind.ModelEvaluation: if media_type is None: diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index c9215a61f..a07d8e19c 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -155,7 +155,7 @@ def is_prompt_response(self) -> bool: Returns: True if this project is a prompt 
response project, False otherwise """ - return self.media_type in [MediaType.LLMPromptResponseCreation, MediaType.LLMPromptCreation] or EditorTaskType.ResponseCreation + return self.media_type == MediaType.LLMPromptResponseCreation or self.media_type == MediaType.LLMPromptCreation or self.editor_task_type == EditorTaskType.ResponseCreation def is_auto_data_generation(self) -> bool: return (self.upload_type == UploadType.Auto) # type: ignore From fcaa75a3f2d2a784f0817f5d9c6ad29e5e34d342 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Fri, 12 Jul 2024 14:25:09 -0500 Subject: [PATCH 03/13] test_fixes --- libs/labelbox/src/labelbox/schema/ontology.py | 2 +- libs/labelbox/src/labelbox/schema/ontology_kind.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index 144594ef0..cd8a0480c 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -569,7 +569,7 @@ class OntologyBuilder: def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: classifications = [] for c in dictionary["classifications"]: - if "type" in c and c["type"] in PromptResponseClassification.Type: + if "type" in c and c["type"] in PromptResponseClassification.Type._value2member_map_.keys(): classifications.append(PromptResponseClassification.from_dict(c)) else: classifications.append(Classification.from_dict(c)) diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py index 27764f784..ae223a6fa 100644 --- a/libs/labelbox/src/labelbox/schema/ontology_kind.py +++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py @@ -24,7 +24,7 @@ def get_ontology_kind_validation_error(cls, ontology_kind): @staticmethod def evaluate_ontology_kind_with_media_type(ontology_kind, - media_type: Union[MediaType, None]) -> MediaType: + media_type: Union[MediaType, None]) -> Union[MediaType, None]: if ontology_kind and ontology_kind is OntologyKind.ModelEvaluation: if media_type is None: From c6c26d929f791a250a35117f95c716db553dbcc8 Mon Sep 17 00:00:00 2001 From: Gabe <33893811+Gabefire@users.noreply.github.com> Date: Wed, 17 Jul 2024 09:41:08 -0500 Subject: [PATCH 04/13] Typo --- libs/labelbox/src/labelbox/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index afcfcbada..19a3f3f0e 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -959,9 +959,9 @@ def create_response_creation_project(self, **kwargs) -> Project: kwargs[ "media_type"] = MediaType.Text # Only Text is supported kwargs[ - "editor_task_type"] = EditorTaskType.ResponseCreation.value # Special editor task type for offline model evaluation + "editor_task_type"] = EditorTaskType.ResponseCreation.value # Special editor task type for response creation projects - # The following arguments are not supported for offline model evaluation + # The following arguments are not supported for response creation projects kwargs.pop("dataset_name_or_id", None) kwargs.pop("append_to_existing_dataset", None) kwargs.pop("data_row_count", None) From 7b533eae680d58206c392b12c4348b297c9ebf70 Mon Sep 17 00:00:00 2001 From: Gabe <33893811+Gabefire@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:44:20 -0500 Subject: [PATCH 05/13] Update libs/labelbox/src/labelbox/schema/ontology_kind.py 

Co-authored-by: Val Brodsky
---
 libs/labelbox/src/labelbox/schema/ontology_kind.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py
index ae223a6fa..074dc7db0 100644
--- a/libs/labelbox/src/labelbox/schema/ontology_kind.py
+++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py
@@ -24,7 +24,7 @@ def get_ontology_kind_validation_error(cls, ontology_kind):
 
     @staticmethod
     def evaluate_ontology_kind_with_media_type(ontology_kind,
-                                               media_type: Union[MediaType, None]) -> Union[MediaType, None]:
+                                               media_type: Optional[MediaType]) -> Union[MediaType, None]:
 
         if ontology_kind and ontology_kind is OntologyKind.ModelEvaluation:
             if media_type is None:

From b1a45f99b1464a53d5da02f24309fb370aec6270 Mon Sep 17 00:00:00 2001
From: Gabefire <33893811+Gabefire@users.noreply.github.com>
Date: Tue, 23 Jul 2024 17:50:18 -0500
Subject: [PATCH 06/13] added suggestions

---
 libs/labelbox/src/labelbox/client.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py
index 19a3f3f0e..bdd3b1e81 100644
--- a/libs/labelbox/src/labelbox/client.py
+++ b/libs/labelbox/src/labelbox/client.py
@@ -885,8 +885,8 @@ def create_prompt_response_generation_project(self,
 
     @overload
     def create_prompt_response_generation_project(self,
-                                       dataset_id: str,
                                        dataset_name: str = None,
+                                       dataset_id: str,
                                        data_row_count: int = 100,
                                        **kwargs) -> Project:
         pass
@@ -903,7 +903,7 @@ def create_prompt_response_generation_project(self,
             dataset_name: When creating a new dataset, pass the name
             dataset_id: When using an existing dataset, pass the id
             data_row_count: The number of data row assets to use for the project
-            **kwargs: Additional parameters to pass to the create_project method
+            **kwargs: Additional parameters to pass; see the create_project method
         Returns:
             Project: The created project
 
@@ -917,7 +917,7 @@ def create_prompt_response_generation_project(self,
             >>> client.create_prompt_response_generation_project(name=project_name, dataset_id="clr00u8j0j0j0", media_type=MediaType.LLMPromptCreation)
             >>> This creates a new project, adds 100 data rows to the dataset with id "clr00u8j0j0j0", and assigns a batch of the newly created data rows to the project.
 
-            >>> client.create_prompt_response_generation_project(name=project_name, dataset_id="clr00u8j0j0j0", data_row_count=10, media_type=MediaType.LLMPromptCreation)
+            >>> client.create_prompt_response_generation_project(name=project_name, dataset_id="clr00u8j0j0j0", data_row_count=10, media_type=MediaType.LLMPromptResponseCreation)
             >>> This creates a new project, adds 10 data rows to the dataset with id "clr00u8j0j0j0", and assigns a batch of the newly created data rows to the project.
""" From 47a5099760e71401d499536db7a0023d40c7739f Mon Sep 17 00:00:00 2001 From: Gabe <33893811+Gabefire@users.noreply.github.com> Date: Tue, 23 Jul 2024 19:50:47 -0500 Subject: [PATCH 07/13] Update client.py --- libs/labelbox/src/labelbox/client.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index bdd3b1e81..4e1e1eb1d 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -875,21 +875,6 @@ def create_offline_model_evaluation_project(self, **kwargs) -> Project: return self._create_project(**kwargs) - @overload - def create_prompt_response_generation_project(self, - dataset_name: str, - dataset_id: str = None, - data_row_count: int = 100, - **kwargs) -> Project: - pass - - @overload - def create_prompt_response_generation_project(self, - dataset_name: str = None, - dataset_id: str, - data_row_count: int = 100, - **kwargs) -> Project: - pass def create_prompt_response_generation_project(self, dataset_id: Optional[str] = None, From c04a3a00c47f9e60e0f2e6dbdef581a4bd0a10c6 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Tue, 23 Jul 2024 19:56:42 -0500 Subject: [PATCH 08/13] provide validation for dataset_name and id --- libs/labelbox/src/labelbox/client.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 4e1e1eb1d..53083bed4 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -891,7 +891,9 @@ def create_prompt_response_generation_project(self, **kwargs: Additional parameters to pass see the create_project method Returns: Project: The created project - + + NOTE: Only a dataset_name or dataset_id should be included + Examples: >>> client.create_prompt_response_generation_project(name=project_name, dataset_name="new data set", project_kind=MediaType.LLMPromptResponseCreation) >>> This creates a new dataset with a default number of rows (100), creates new project and assigns a batch of the newly created datarows to the project. @@ -908,8 +910,14 @@ def create_prompt_response_generation_project(self, """ if not dataset_id and not dataset_name: raise ValueError( - "dataset_name or data_set_id must be present and not be an empty string." + "dataset_name or dataset_id must be present and not be an empty string." ) + + if dataset_id and dataset_name: + raise ValueError( + "Only provide a dataset_name or dataset_id, not both." 
+ ) + if data_row_count <= 0: raise ValueError("data_row_count must be a positive integer.") From 17de92ec33ff68084482b6010fb4363da212f1e4 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 24 Jul 2024 09:02:28 -0500 Subject: [PATCH 09/13] added suggestion --- .../src/labelbox/schema/ontology_kind.py | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py index 074dc7db0..607b60f33 100644 --- a/libs/labelbox/src/labelbox/schema/ontology_kind.py +++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py @@ -26,23 +26,18 @@ def get_ontology_kind_validation_error(cls, ontology_kind): def evaluate_ontology_kind_with_media_type(ontology_kind, media_type: Optional[MediaType]) -> Union[MediaType, None]: - if ontology_kind and ontology_kind is OntologyKind.ModelEvaluation: - if media_type is None: - media_type = MediaType.Conversational - else: - if media_type is not MediaType.Conversational: - raise ValueError( - "For chat evaluation, media_type must be Conversational." - ) - - elif ontology_kind == OntologyKind.ResponseCreation: + ontology_to_media = { + OntologyKind.ModelEvaluation: (MediaType.Conversational, "For chat evaluation, media_type must be Conversational."), + OntologyKind.ResponseCreation: (MediaType.Text, "For response creation, media_type must be Text.") + } + + if ontology_kind in ontology_to_media: + expected_media_type, error_message = ontology_to_media[ontology_kind] + if media_type is None: - media_type = MediaType.Text + media_type = expected_media_type else: - if media_type is not MediaType.Text: - raise ValueError( - "For response creation, media_type must be Text." 
- ) + raise ValueError(error_message) return media_type From bc43a403aa66e88e5e2c882835c8728e64adfb0f Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 24 Jul 2024 09:04:51 -0500 Subject: [PATCH 10/13] added small adjustment --- libs/labelbox/src/labelbox/schema/ontology_kind.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py index 607b60f33..e8b4475ae 100644 --- a/libs/labelbox/src/labelbox/schema/ontology_kind.py +++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py @@ -34,7 +34,7 @@ def evaluate_ontology_kind_with_media_type(ontology_kind, if ontology_kind in ontology_to_media: expected_media_type, error_message = ontology_to_media[ontology_kind] - if media_type is None: + if media_type is None or media_type == expected_media_type: media_type = expected_media_type else: raise ValueError(error_message) From 0f17a5fded9439e97252ebf1104deebab4a53317 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 24 Jul 2024 09:10:39 -0500 Subject: [PATCH 11/13] added feedback --- libs/labelbox/src/labelbox/schema/ontology.py | 42 +------------------ 1 file changed, 2 insertions(+), 40 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index cd8a0480c..fd96e235f 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -223,7 +223,7 @@ def add_option(self, option: Option) -> None: self.options.append(option) @dataclass -class ResponseOption: +class ResponseOption(Option): """ An option is a possible answer within a PromptResponseClassification response object in a Project's ontology. 
@@ -239,45 +239,7 @@ class ResponseOption: feature_schema_id: (str) options: (list) """ - value: Union[str, int] - label: Optional[Union[str, int]] = None - schema_id: Optional[str] = None - feature_schema_id: Optional[FeatureSchemaId] = None - options: List["Classification"] = field(default_factory=list) - - def __post_init__(self): - if self.label is None: - self.label = self.value - - @classmethod - def from_dict( - cls, - dictionary: Dict[str, - Any]) -> Dict[Union[str, int], Union[str, int]]: - return cls(value=dictionary["value"], - label=dictionary["label"], - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - options=[ - PromptResponseClassification.from_dict(o) - for o in dictionary.get("options", []) - ]) - - def asdict(self) -> Dict[str, Any]: - return { - "schemaNodeId": self.schema_id, - "featureSchemaId": self.feature_schema_id, - "label": self.label, - "value": self.value, - "options": [o.asdict(is_subclass=True) for o in self.options] - } - - def add_option(self, option: 'Classification') -> None: - if option.name in (o.name for o in self.options): - raise InconsistentOntologyException( - f"Duplicate nested classification '{option.name}' " - f"for option '{self.label}'") - self.options.append(option) + pass @dataclass From 78ccbe5e93ccb9071d9d83f82a9af8123f8e4aba Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 24 Jul 2024 09:48:08 -0500 Subject: [PATCH 12/13] fixed one method --- libs/labelbox/src/labelbox/schema/ontology.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index fd96e235f..5896d8e68 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -53,7 +53,7 @@ class Option: label: Optional[Union[str, int]] = None schema_id: Optional[str] = None feature_schema_id: Optional[FeatureSchemaId] = None - options: List["Classification"] = field(default_factory=list) + options: Union[List["Classification"], List["PromptResponseClassification"]] = field(default_factory=list) def __post_init__(self): if self.label is None: @@ -82,7 +82,7 @@ def asdict(self) -> Dict[str, Any]: "options": [o.asdict(is_subclass=True) for o in self.options] } - def add_option(self, option: 'Classification') -> None: + def add_option(self, option: Union["Classification", "PromptResponseClassification"]) -> None: if option.name in (o.name for o in self.options): raise InconsistentOntologyException( f"Duplicate nested classification '{option.name}' " @@ -239,7 +239,20 @@ class ResponseOption(Option): feature_schema_id: (str) options: (list) """ - pass + + @classmethod + def from_dict( + cls, + dictionary: Dict[str, + Any]) -> Dict[Union[str, int], Union[str, int]]: + return cls(value=dictionary["value"], + label=dictionary["label"], + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + options=[ + PromptResponseClassification.from_dict(o) + for o in dictionary.get("options", []) + ]) @dataclass From e9f263bae9ebcd9cb4bf6ef2fbe84ff5545d5240 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 24 Jul 2024 09:49:06 -0500 Subject: [PATCH 13/13] fix typo on doc string --- libs/labelbox/src/labelbox/schema/ontology.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/ontology.py 
b/libs/labelbox/src/labelbox/schema/ontology.py index 5896d8e68..985405059 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -291,10 +291,8 @@ class PromptResponseClassification: options: (list) character_min: (int) character_max: (int) - ui_mode: (Classification.UIMode) schema_id: (str) feature_schema_id: (str) - scope: (Classification.Scope) """ def __post_init__(self):
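
Editor's note, not part of the patch series: the sketch below strings together the pieces these patches introduce, the same way the integration tests above do. It builds a prompt/response ontology with OntologyBuilder, PromptResponseClassification, and ResponseOption, creates a prompt and response generation project, and connects the two. The API key, names, and row count are placeholders, and the snippet assumes the final state of the series (media_type-based validation and auto-generated data rows).

from labelbox import (Client, MediaType, OntologyBuilder,
                      PromptResponseClassification, ResponseOption)

client = Client(api_key="<YOUR_API_KEY>")  # placeholder credential

# Build an ontology with one prompt and one radio-style response.
# OntologyBuilder.asdict() allows at most one PROMPT classification
# per ontology, so a single prompt entry is used here.
builder = OntologyBuilder(
    tools=[],
    classifications=[
        PromptResponseClassification(
            class_type=PromptResponseClassification.Type.PROMPT,
            name="prompt text",
            character_min=1,
            character_max=50),
        PromptResponseClassification(
            class_type=PromptResponseClassification.Type.RESPONSE_RADIO,
            name="response radio",
            options=[
                ResponseOption(value="good response"),
                ResponseOption(value="bad response"),
            ]),
    ])

# Prompt and response classifications can share one ontology when the
# media type is LLMPromptResponseCreation; the integration tests pass
# the project's media type directly when creating the ontology.
ontology = client.create_ontology(
    "prompt-response-ontology",  # placeholder name
    builder.asdict(),
    media_type=MediaType.LLMPromptResponseCreation)

# The project auto-generates data rows into a new dataset and batches
# them to itself, so no manual create_batch call is needed (the tests
# show create_batch raising for auto data generation projects).
project = client.create_prompt_response_generation_project(
    name="prompt-response-project",          # placeholder name
    dataset_name="prompt-response-dataset",  # placeholder name
    data_row_count=10,
    media_type=MediaType.LLMPromptResponseCreation)

project.connect_ontology(ontology)

For response-only labeling, the analogous flow exercised in test_response_creation_project.py is client.create_response_creation_project(name=...) together with an ontology created with media_type=MediaType.Text and ontology_kind=OntologyKind.ResponseCreation.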