diff --git a/ovos_padatious/intent.py b/ovos_padatious/intent.py index 588ac7b..804de91 100644 --- a/ovos_padatious/intent.py +++ b/ovos_padatious/intent.py @@ -74,7 +74,7 @@ def train(self, train_data): tokens = set([token for sent in train_data.my_sents(self.name) for token in sent if token.startswith('{')]) self.pos_intents = [PosIntent(i, self.name) for i in tokens] - - self.simple_intent.train(train_data) + success = self.simple_intent.train(train_data) for i in self.pos_intents: i.train(train_data) + return success diff --git a/ovos_padatious/intent_manager.py b/ovos_padatious/intent_manager.py index 9c11ac1..28e34e9 100644 --- a/ovos_padatious/intent_manager.py +++ b/ovos_padatious/intent_manager.py @@ -48,6 +48,14 @@ def calc_intents(self, query: str, entity_manager) -> List[MatchData]: List[MatchData]: A list of matches sorted by confidence. """ sent = tokenize(query) + if not self.objects: + return [] + if len(self.objects) == 1: + try: + match = self.objects[0].match(sent, entity_manager).detokenize() + return [match] + except: + return [] def match_intent(intent): start_time = time.monotonic() diff --git a/ovos_padatious/opm.py b/ovos_padatious/opm.py index 0a6aa96..f8adfc1 100644 --- a/ovos_padatious/opm.py +++ b/ovos_padatious/opm.py @@ -258,7 +258,7 @@ class PadatiousPipeline(ConfidenceMatcherPipeline): def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, config: Optional[Dict] = None, - engine_class: Optional[PadatiousEngine] = IntentContainer): + engine_class: Optional[PadatiousEngine] = None): super().__init__(bus, config) self.lock = RLock() @@ -273,8 +273,8 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, self.conf_med = self.config.get("conf_med") or 0.8 self.conf_low = self.config.get("conf_low") or 0.5 - if engine_class is None and self.config.get("domain_engine"): - engine_class = DomainIntentContainer + engine_class = engine_class or DomainIntentContainer if self.config.get("domain_engine") else IntentContainer + LOG.info(f"Padatious class: {engine_class.__name__}") self.remove_punct = self.config.get("cast_to_ascii", False) use_stemmer = self.config.get("stem", False) diff --git a/ovos_padatious/pos_intent.py b/ovos_padatious/pos_intent.py index 2b407b2..4374a9d 100644 --- a/ovos_padatious/pos_intent.py +++ b/ovos_padatious/pos_intent.py @@ -83,6 +83,7 @@ def from_file(cls, prefix, token): i.load(prefix) return self - def train(self, train_data): + def train(self, train_data) -> bool: for i in self.edges: i.train(train_data) + return True diff --git a/ovos_padatious/simple_intent.py b/ovos_padatious/simple_intent.py index d03117d..6d20946 100644 --- a/ovos_padatious/simple_intent.py +++ b/ovos_padatious/simple_intent.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os.path from fann2 import libfann as fann from ovos_utils.log import LOG @@ -62,7 +63,7 @@ def configure_net(self): self.net.set_train_stop_function(fann.STOPFUNC_BIT) self.net.set_bit_fail_limit(0.1) - def train(self, train_data): + def train(self, train_data) -> bool: for sent in train_data.my_sents(self.name): self.ids.add_sent(sent) @@ -72,6 +73,10 @@ def train(self, train_data): n_pos = len(list(train_data.my_sents(self.name))) n_neg = len(list(train_data.other_sents(self.name))) + if not n_neg or not n_pos: + LOG.error(f"not enough samples to learn intent: pos {n_pos} / neg {n_neg}") + return False + def add(vec, out): inputs.append(self.vectorize(vec)) outputs.append([out]) @@ -126,15 +131,22 @@ def calc_weight(w): return pow(len(w), 3.0) if self.net.get_bit_fail() == 0: break LOG.debug(f"Training {self.name} finished!") + return True def save(self, prefix): - prefix += '.intent' + if not self.net: + raise RuntimeError(f"intent not yet trained! '{prefix}.net'") + if not prefix.endswith(".intent"): + prefix += '.intent' self.net.save(str(prefix + '.net')) # Must have str() self.ids.save(prefix) @classmethod def from_file(cls, name, prefix): - prefix += '.intent' + if not prefix.endswith(".intent"): + prefix += '.intent' + if not os.path.isfile(str(prefix + '.net')): + raise FileNotFoundError(f"intent not yet trained! '{prefix}.net'") self = cls(name) self.net = fann.neural_net() if not self.net.create_from_file(str(prefix + '.net')): # Must have str() diff --git a/ovos_padatious/trainable.py b/ovos_padatious/trainable.py index 37d70c4..994f24d 100644 --- a/ovos_padatious/trainable.py +++ b/ovos_padatious/trainable.py @@ -33,7 +33,7 @@ def save_hash(self, prefix): f.write(self.hash) @abstractmethod - def train(self, data): + def train(self, data) -> bool: pass @abstractmethod diff --git a/ovos_padatious/training_manager.py b/ovos_padatious/training_manager.py index bc99bad..560dd43 100644 --- a/ovos_padatious/training_manager.py +++ b/ovos_padatious/training_manager.py @@ -33,10 +33,12 @@ def _train_and_save(obj: Trainable, cache: str, data: TrainData, print_updates: data (TrainData): Training data. print_updates (bool): Whether to print updates during training. """ - obj.train(data) - obj.save(cache) - if print_updates: - LOG.debug(f'Saving {obj.name} to cache ({cache})') + if obj.train(data): + obj.save(cache) + if print_updates: + LOG.debug(f'Saving {obj.name} to cache ({cache})') + else: + LOG.debug(f'Failed to train {obj.name}') class TrainingManager: @@ -72,37 +74,43 @@ def add(self, name: str, lines: List[str], reload_cache: bool = False, must_trai reload_cache (bool): Whether to force reload of cache if it exists. must_train (bool): Whether training is required for the new intent/entity. """ + hash_fn = join(self.cache, name + '.hash') + min_ver = splitext(ovos_padatious.__version__)[0] + if not isfile(hash_fn): + must_train = True + if not must_train: - LOG.debug(f"Loading {name} from intent cache") - self.objects.append(self.cls.from_file(name=name, folder=self.cache)) + try: # .net file renamed/deleted for some reason + LOG.debug(f"Loading '{name}' from intent cache") + self.objects.append(self.cls.from_file(name=name, folder=self.cache)) + except: + LOG.debug(f"Regenerating cache for intent: {name}") + # general case: load resource (entity or intent) to training queue # or if no change occurred to memory data structures + old_hsh = None + new_hsh = lines_hash([min_ver] + lines) + + if isfile(hash_fn): + with open(hash_fn, 'rb') as g: + old_hsh = g.read() + if old_hsh != new_hsh: + LOG.debug(f"{name} training data changed! retraining") else: - hash_fn = join(self.cache, name + '.hash') - old_hsh = None - min_ver = splitext(ovos_padatious.__version__)[0] - new_hsh = lines_hash([min_ver] + lines) - - if isfile(hash_fn): - with open(hash_fn, 'rb') as g: - old_hsh = g.read() - if old_hsh != new_hsh: - LOG.debug(f"{name} training data changed! retraining") - else: - LOG.debug(f"First time training '{name}") - - retrain = reload_cache or old_hsh != new_hsh - if not retrain: - try: - LOG.debug(f"Loading {name} from intent cache") - self.objects.append(self.cls.from_file(name=name, folder=self.cache)) - except Exception as e: - LOG.error(f"Failed to load intent from cache: {name} - {str(e)}") - retrain = True - if retrain: - LOG.debug(f"Queuing {name} for training") - self.objects_to_train.append(self.cls(name=name, hsh=new_hsh)) - self.train_data.add_lines(name, lines) + LOG.debug(f"First time training '{name}'") + + retrain = reload_cache or old_hsh != new_hsh + if not retrain: + try: + LOG.debug(f"Loading {name} from intent cache") + self.objects.append(self.cls.from_file(name=name, folder=self.cache)) + except Exception as e: + LOG.error(f"Failed to load intent from cache: {name} - {str(e)}") + retrain = True + if retrain: + LOG.debug(f"Queuing {name} for training") + self.objects_to_train.append(self.cls(name=name, hsh=new_hsh)) + self.train_data.add_lines(name, lines) def load(self, name: str, file_name: str, reload_cache: bool = False) -> None: """