diff --git a/libs/labelbox/src/labelbox/schema/tool_building/classification.py b/libs/labelbox/src/labelbox/schema/tool_building/classification.py index 62fee2dda..a6f4ebd1d 100644 --- a/libs/labelbox/src/labelbox/schema/tool_building/classification.py +++ b/libs/labelbox/src/labelbox/schema/tool_building/classification.py @@ -76,14 +76,15 @@ class UIMode(Enum): None # How this classification should be answered (e.g. hotkeys / autocomplete, etc) ) attributes: Optional[FeatureSchemaAttributes] = None + is_likert_scale: bool = False def __post_init__(self): if self.name is None: msg = ( - "When creating the Classification feature, please use “name” " + 'When creating the Classification feature, please use "name" ' "for the classification schema name, which will be used when " "creating annotation payload for Model-Assisted Labeling " - "Import and Label Import. “instructions” is no longer " + 'Import and Label Import. "instructions" is no longer ' "supported to specify classification schema name." ) if self.instructions is not None: @@ -119,6 +120,7 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> "Classification": ] if dictionary.get("attributes") else None, + is_likert_scale=dictionary.get("isLikertScale", False), ) def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: @@ -138,6 +140,9 @@ def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: if self.attributes is not None else None, } + if self.class_type == self.Type.RADIO and self.is_likert_scale: + # is_likert_scale is only applicable to RADIO classifications + classification["isLikertScale"] = self.is_likert_scale if ( self.class_type == self.Type.RADIO or self.class_type == self.Type.CHECKLIST @@ -159,6 +164,9 @@ def add_option(self, option: "Option") -> None: f"Duplicate option '{option.value}' " f"for classification '{self.name}'." ) + # Auto-assign position if not set + if option.position is None: + option.position = len(self.options) self.options.append(option) @@ -178,6 +186,7 @@ class Option: schema_id: (str) feature_schema_id: (str) options: (list) + position: (int) - Position of the option, auto-assigned starting from 0 """ value: Union[str, int] @@ -187,6 +196,7 @@ class Option: options: Union[ List["Classification"], List["PromptResponseClassification"] ] = field(default_factory=list) + position: Optional[int] = None def __post_init__(self): if self.label is None: @@ -203,16 +213,20 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> "Option": Classification.from_dict(o) for o in dictionary.get("options", []) ], + position=dictionary.get("position", None), ) def asdict(self) -> Dict[str, Any]: - return { + result = { "schemaNodeId": self.schema_id, "featureSchemaId": self.feature_schema_id, "label": self.label, "value": self.value, "options": [o.asdict(is_subclass=True) for o in self.options], } + if self.position is not None: + result["position"] = self.position + return result def add_option( self, option: Union["Classification", "PromptResponseClassification"] @@ -268,10 +282,10 @@ class PromptResponseClassification: def __post_init__(self): if self.name is None: msg = ( - "When creating the Classification feature, please use “name” " + 'When creating the Classification feature, please use "name" ' "for the classification schema name, which will be used when " "creating annotation payload for Model-Assisted Labeling " - "Import and Label Import. “instructions” is no longer " + 'Import and Label Import. "instructions" is no longer ' "supported to specify classification schema name." ) if self.instructions is not None: diff --git a/libs/labelbox/tests/unit/test_unit_ontology.py b/libs/labelbox/tests/unit/test_unit_ontology.py index 137e4fd1f..ca0984aa2 100644 --- a/libs/labelbox/tests/unit/test_unit_ontology.py +++ b/libs/labelbox/tests/unit/test_unit_ontology.py @@ -294,3 +294,33 @@ def test_classification_using_instructions_instead_of_name_shows_warning(): def test_classification_without_name_raises_error(): with pytest.raises(ValueError): Classification(class_type=Classification.Type.TEXT) + + +@pytest.mark.parametrize( + "class_type, is_likert_scale, should_include", + [ + (Classification.Type.RADIO, True, True), + (Classification.Type.RADIO, False, False), + (Classification.Type.CHECKLIST, True, False), + (Classification.Type.TEXT, True, False), + ], +) +def test_is_likert_scale_serialization( + class_type, is_likert_scale, should_include +): + c = Classification( + class_type=class_type, name="test", is_likert_scale=is_likert_scale + ) + if class_type in Classification._REQUIRES_OPTIONS: + c.add_option(Option(value="option1")) + result = c.asdict() + assert ("isLikertScale" in result) == should_include + + +def test_option_position_auto_assignment(): + c = Classification(class_type=Classification.Type.RADIO, name="test") + o1, o2 = Option(value="first"), Option(value="second") + c.add_option(o1) + c.add_option(o2) + assert o1.position == 0 and o2.position == 1 + assert c.asdict()["options"][0]["position"] == 0