diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index 985405059..c9c2fffd4 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -53,7 +53,9 @@ class Option: label: Optional[Union[str, int]] = None schema_id: Optional[str] = None feature_schema_id: Optional[FeatureSchemaId] = None - options: Union[List["Classification"], List["PromptResponseClassification"]] = field(default_factory=list) + options: Union[List["Classification"], + List["PromptResponseClassification"]] = field( + default_factory=list) def __post_init__(self): if self.label is None: @@ -82,7 +84,9 @@ def asdict(self) -> Dict[str, Any]: "options": [o.asdict(is_subclass=True) for o in self.options] } - def add_option(self, option: Union["Classification", "PromptResponseClassification"]) -> None: + def add_option( + self, option: Union["Classification", + "PromptResponseClassification"]) -> None: if option.name in (o.name for o in self.options): raise InconsistentOntologyException( f"Duplicate nested classification '{option.name}' " @@ -140,7 +144,7 @@ class Type(Enum): class Scope(Enum): GLOBAL = "global" INDEX = "index" - + class UIMode(Enum): HOTKEY = "hotkey" SEARCHABLE = "searchable" @@ -155,7 +159,8 @@ class UIMode(Enum): schema_id: Optional[str] = None feature_schema_id: Optional[str] = None scope: Scope = None - ui_mode: Optional[UIMode] = None # How this classification should be answered (e.g. hotkeys / autocomplete, etc) + ui_mode: Optional[ + UIMode] = None # How this classification should be answered (e.g. hotkeys / autocomplete, etc) def __post_init__(self): if self.class_type == Classification.Type.DROPDOWN: @@ -187,7 +192,8 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: instructions=dictionary["instructions"], required=dictionary.get("required", False), options=[Option.from_dict(o) for o in dictionary["options"]], - ui_mode=cls.UIMode(dictionary["uiMode"]) if "uiMode" in dictionary else None, + ui_mode=cls.UIMode(dictionary["uiMode"]) + if "uiMode" in dictionary else None, schema_id=dictionary.get("schemaNodeId", None), feature_schema_id=dictionary.get("featureSchemaId", None), scope=cls.Scope(dictionary.get("scope", cls.Scope.GLOBAL))) @@ -206,7 +212,8 @@ def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: "schemaNodeId": self.schema_id, "featureSchemaId": self.feature_schema_id } - if (self.class_type == self.Type.RADIO or self.class_type == self.Type.CHECKLIST) and self.ui_mode: + if (self.class_type == self.Type.RADIO or + self.class_type == self.Type.CHECKLIST) and self.ui_mode: # added because this key does nothing for text so no point of including classification["uiMode"] = self.ui_mode.value if is_subclass: @@ -221,7 +228,8 @@ def add_option(self, option: Option) -> None: f"Duplicate option '{option.value}' " f"for classification '{self.name}'.") self.options.append(option) - + + @dataclass class ResponseOption(Option): """ @@ -239,7 +247,7 @@ class ResponseOption(Option): feature_schema_id: (str) options: (list) """ - + @classmethod def from_dict( cls, @@ -294,7 +302,7 @@ class PromptResponseClassification: schema_id: (str) feature_schema_id: (str) """ - + def __post_init__(self): if self.name is None: msg = ( @@ -314,7 +322,7 @@ def __post_init__(self): class Type(Enum): PROMPT = "prompt" - RESPONSE_TEXT= "response-text" + RESPONSE_TEXT = "response-text" RESPONSE_CHECKLIST = "response-checklist" RESPONSE_RADIO = "response-radio" @@ -332,15 +340,18 @@ class Type(Enum): @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: - return cls(class_type=cls.Type(dictionary["type"]), - name=dictionary["name"], - instructions=dictionary["instructions"], - required=True, # always required - options=[ResponseOption.from_dict(o) for o in dictionary["options"]], - character_min=dictionary.get("minCharacters", None), - character_max=dictionary.get("maxCharacters", None), - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None)) + return cls( + class_type=cls.Type(dictionary["type"]), + name=dictionary["name"], + instructions=dictionary["instructions"], + required=True, # always required + options=[ + ResponseOption.from_dict(o) for o in dictionary["options"] + ], + character_min=dictionary.get("minCharacters", None), + character_max=dictionary.get("maxCharacters", None), + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None)) def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: if self.class_type in self._REQUIRES_OPTIONS \ @@ -351,12 +362,13 @@ def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: "type": self.class_type.value, "instructions": self.instructions, "name": self.name, - "required": True, # always required + "required": True, # always required "options": [o.asdict() for o in self.options], "schemaNodeId": self.schema_id, "featureSchemaId": self.feature_schema_id } - if (self.class_type == self.Type.PROMPT or self.class_type == self.Type.RESPONSE_TEXT): + if (self.class_type == self.Type.PROMPT or + self.class_type == self.Type.RESPONSE_TEXT): if self.character_min: classification["minCharacters"] = self.character_min if self.character_max: @@ -488,7 +500,8 @@ class Ontology(DbObject): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._tools: Optional[List[Tool]] = None - self._classifications: Optional[Union[List[Classification],List[PromptResponseClassification]]] = None + self._classifications: Optional[Union[ + List[Classification], List[PromptResponseClassification]]] = None def tools(self) -> List[Tool]: """Get list of tools (AKA objects) in an Ontology.""" @@ -498,15 +511,20 @@ def tools(self) -> List[Tool]: ] return self._tools - def classifications(self) -> List[Union[Classification, PromptResponseClassification]]: + def classifications( + self) -> List[Union[Classification, PromptResponseClassification]]: """Get list of classifications in an Ontology.""" if self._classifications is None: self._classifications = [] for classification in self.normalized["classifications"]: - if "type" in classification and classification["type"] in PromptResponseClassification.Type._value2member_map_.keys(): - self._classifications.append(PromptResponseClassification.from_dict(classification)) + if "type" in classification and classification[ + "type"] in PromptResponseClassification.Type._value2member_map_.keys( + ): + self._classifications.append( + PromptResponseClassification.from_dict(classification)) else: - self._classifications.append(Classification.from_dict(classification)) + self._classifications.append( + Classification.from_dict(classification)) return self._classifications @@ -536,14 +554,22 @@ class OntologyBuilder: """ tools: List[Tool] = field(default_factory=list) - classifications: List[Union[Classification, PromptResponseClassification]] = field(default_factory=list) + classifications: List[Union[Classification, + PromptResponseClassification]] = field( + default_factory=list) + + def foo(self): + pass @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: classifications = [] for c in dictionary["classifications"]: - if "type" in c and c["type"] in PromptResponseClassification.Type._value2member_map_.keys(): - classifications.append(PromptResponseClassification.from_dict(c)) + if "type" in c and c[ + "type"] in PromptResponseClassification.Type._value2member_map_.keys( + ): + classifications.append( + PromptResponseClassification.from_dict(c)) else: classifications.append(Classification.from_dict(c)) return cls(tools=[Tool.from_dict(t) for t in dictionary["tools"]], @@ -554,11 +580,13 @@ def asdict(self) -> Dict[str, Any]: classifications = [] prompts = 0 for c in self.classifications: - if hasattr(c, "class_type") and c.class_type in PromptResponseClassification.Type: + if hasattr(c, "class_type" + ) and c.class_type in PromptResponseClassification.Type: if c.class_type == PromptResponseClassification.Type.PROMPT: prompts += 1 if prompts > 1: - raise ValueError("Only one prompt is allowed per ontology") + raise ValueError( + "Only one prompt is allowed per ontology") classifications.append(PromptResponseClassification.asdict(c)) else: classifications.append(Classification.asdict(c)) @@ -592,7 +620,10 @@ def add_tool(self, tool: Tool) -> None: f"Duplicate tool name '{tool.name}'. ") self.tools.append(tool) - def add_classification(self, classification: Union[Classification, PromptResponseClassification]) -> None: + def add_classification( + self, classification: Union[Classification, + PromptResponseClassification] + ) -> None: if classification.name in (c.name for c in self.classifications): raise InconsistentOntologyException( f"Duplicate classification name '{classification.name}'. ")