Skip to content

Commit 0ae37c5

Browse files
authored
Merge pull request #47 from codelion/fix-bugs
Track and persist training history per class
2 parents f68b3c9 + b2baef0 commit 0ae37c5

File tree

9 files changed

+623
-41
lines changed

9 files changed

+623
-41
lines changed

.github/workflows/test.yml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
name: Run Tests
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
13+
steps:
14+
- uses: actions/checkout@v3
15+
16+
- name: Set up Python 3.12
17+
uses: actions/setup-python@v4
18+
with:
19+
python-version: '3.12'
20+
21+
- name: Install dependencies
22+
run: |
23+
python -m pip install --upgrade pip
24+
pip install -e .
25+
pip install pytest pytest-cov psutil
26+
27+
- name: Run tests
28+
run: |
29+
pytest tests/ -v --cov=adaptive_classifier --cov-report=xml --cov-report=term
30+
31+
- name: Upload coverage to Codecov
32+
uses: codecov/codecov-action@v3
33+
with:
34+
file: ./coverage.xml
35+
flags: unittests
36+
name: codecov-umbrella
37+
fail_ci_if_error: false

README.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,55 @@ This real-world evaluation demonstrates that adaptive classification can signifi
388388
- [RAGTruth: A Hallucination Corpus for Developing Trustworthy Retrieval-Augmented Language Models](https://arxiv.org/abs/2401.00396)
389389
- [LettuceDetect: A Hallucination Detection Framework for RAG Applications](https://arxiv.org/abs/2502.17125)
390390

391+
## Order Dependency in Online Learning
392+
393+
When using the adaptive classifier for true online learning (adding examples incrementally), be aware that the order in which examples are added can affect predictions. This is inherent to incremental neural network training.
394+
395+
### The Challenge
396+
397+
```python
398+
# These two scenarios may produce slightly different models:
399+
400+
# Scenario 1
401+
classifier.add_examples(["fish example"], ["aquatic"])
402+
classifier.add_examples(["bird example"], ["aerial"])
403+
404+
# Scenario 2
405+
classifier.add_examples(["bird example"], ["aerial"])
406+
classifier.add_examples(["fish example"], ["aquatic"])
407+
```
408+
409+
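While new labels are assigned IDs in sorted order within each `add_examples` call to reduce this effect, labels introduced in separate calls (as in the scenarios above) still receive IDs in call order, and the neural network component still learns incrementally, so some order-dependent behavior remains.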
410+
411+
### Solution: Prototype-Only Predictions
412+
413+
For applications requiring strict order independence, you can configure the classifier to rely solely on prototype-based predictions:
414+
415+
```python
416+
# Configure to use only prototypes (order-independent)
417+
config = {
418+
'prototype_weight': 1.0, # Use only prototypes
419+
'neural_weight': 0.0 # Disable neural network contribution
420+
}
421+
422+
classifier = AdaptiveClassifier("bert-base-uncased", config=config)
423+
```
424+
425+
With this configuration:
426+
- Predictions are based solely on similarity to class prototypes (mean embeddings)
427+
- Results are completely order-independent
428+
- Trade-off: May have slightly lower accuracy than the hybrid approach
429+
430+
### Best Practices
431+
432+
1. **For maximum consistency**: Use prototype-only configuration
433+
2. **For maximum accuracy**: Accept some order dependency with the default hybrid approach
434+
3. **For production systems**: Consider batching updates and retraining periodically if strict consistency is required
435+
4. **Model selection matters**: Some models (e.g., `google-bert/bert-large-cased`) may produce poor embeddings for single words. For better results with short inputs, consider:
436+
- `bert-base-uncased`
437+
- `sentence-transformers/all-MiniLM-L6-v2`
438+
- Or any model specifically trained for semantic similarity
439+
391440
## Citation
392441

393442
If you use this library in your research, please cite:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
setup(
1717
name="adaptive-classifier",
18-
version="0.0.14",
18+
version="0.0.15",
1919
author="codelion",
2020
author_email="codelion@okyasoft.com",
2121
description="A flexible, adaptive classification system for dynamic text classification",

src/adaptive_classifier/classifier.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666

6767
# Statistics
6868
self.train_steps = 0
69+
self.training_history = {} # Track cumulative training examples per class
6970

7071
# Strategic classification components
7172
self.strategic_cost_function = None
@@ -87,19 +88,24 @@ def add_examples(self, texts: List[str], labels: List[str]):
8788
new_classes = set(labels) - set(self.label_to_id.keys())
8889
is_adding_new_classes = len(new_classes) > 0
8990

90-
# Update label mappings
91-
for label in new_classes:
91+
# Update label mappings - sort new classes alphabetically for consistent IDs
92+
for label in sorted(new_classes):
9293
idx = len(self.label_to_id)
9394
self.label_to_id[label] = idx
9495
self.id_to_label[idx] = label
9596

9697
# Get embeddings for all texts
9798
embeddings = self._get_embeddings(texts)
9899

99-
# Add examples to memory
100+
# Add examples to memory and update training history
100101
for text, embedding, label in zip(texts, embeddings, labels):
101102
example = Example(text, label, embedding)
102103
self.memory.add_example(example, label)
104+
105+
# Update training history
106+
if label not in self.training_history:
107+
self.training_history[label] = 0
108+
self.training_history[label] += 1
103109

104110
# Special handling for new classes
105111
if is_adding_new_classes:
@@ -118,6 +124,9 @@ def add_examples(self, texts: List[str], labels: List[str]):
118124
# Strategic training step if enabled
119125
if self.strategic_mode and self.train_steps % self.config.strategic_training_frequency == 0:
120126
self._perform_strategic_training()
127+
128+
# Ensure FAISS index is up to date after adding examples
129+
self.memory._rebuild_index()
121130

122131
def _train_new_classes(self, old_head: Optional[nn.Module], new_classes: Set[str]):
123132
"""Train the model with focus on new classes while preserving old class knowledge."""
@@ -317,17 +326,21 @@ def _predict_regular(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
317326
# Combine predictions with adjusted weights
318327
combined_scores = {}
319328

320-
# Use neural predictions more for recent classes
329+
# Use training history to determine weights
321330
for label, score in proto_preds:
322-
if label in self.memory.examples and len(self.memory.examples[label]) < 10:
323-
# For newer classes (fewer examples), trust neural predictions more
331+
# Check training history instead of current storage
332+
trained_examples = self.training_history.get(label, 0)
333+
if trained_examples < 10:
334+
# For newer classes (fewer training examples), trust neural predictions more
324335
weight = 0.3 # Lower prototype weight for new classes
325336
else:
326337
weight = 0.7 # Higher prototype weight for established classes
327338
combined_scores[label] = score * weight
328339

329340
for label, score in head_preds:
330-
if label in self.memory.examples and len(self.memory.examples[label]) < 10:
341+
# Use training history for neural weights too
342+
trained_examples = self.training_history.get(label, 0)
343+
if trained_examples < 10:
331344
weight = 0.7 # Higher neural weight for new classes
332345
else:
333346
weight = 0.3 # Lower neural weight for established classes
@@ -414,6 +427,7 @@ def _save_pretrained(
414427
'label_to_id': self.label_to_id,
415428
'id_to_label': {str(k): v for k, v in self.id_to_label.items()},
416429
'train_steps': self.train_steps,
430+
'training_history': self.training_history, # Save cumulative training counts
417431
'config': self.config.to_dict()
418432
}
419433

@@ -569,6 +583,9 @@ def _from_pretrained(
569583
int(k): v for k, v in config_dict['id_to_label'].items()
570584
}
571585
classifier.train_steps = config_dict['train_steps']
586+
587+
# Restore training history with backward compatibility
588+
classifier.training_history = config_dict.get('training_history', {})
572589

573590
# Load tensors
574591
tensors = load_file(model_path / "model.safetensors")
@@ -600,6 +617,13 @@ def _from_pretrained(
600617
classifier._initialize_adaptive_head()
601618
classifier.adaptive_head.load_state_dict(adaptive_head_params)
602619

620+
# Backward compatibility: estimate training history if not present
621+
if not classifier.training_history:
622+
for label, examples in saved_examples.items():
623+
# Estimate based on saved examples (default saves 5, typical training uses 100+)
624+
# Using 20x multiplier as reasonable estimate
625+
classifier.training_history[label] = len(examples) * 20
626+
603627
return classifier
604628

605629
def _generate_model_card(self) -> str:
@@ -754,20 +778,15 @@ def _initialize_adaptive_head(self):
754778
).to(self.device)
755779

756780
def _get_embeddings(self, texts: List[str]) -> List[torch.Tensor]:
757-
"""Get embeddings for input texts with improved caching."""
758-
# Sort texts for consistent tokenization
759-
sorted_indices = list(range(len(texts)))
760-
sorted_indices.sort(key=lambda i: texts[i])
761-
sorted_texts = [texts[i] for i in sorted_indices]
762-
781+
"""Get embeddings for input texts."""
763782
# Temporarily set model to eval mode
764783
was_training = self.model.training
765784
self.model.eval()
766785

767786
# Get embeddings
768787
with torch.no_grad():
769788
inputs = self.tokenizer(
770-
sorted_texts,
789+
texts,
771790
max_length=self.config.max_length,
772791
truncation=True,
773792
padding=True,
@@ -784,12 +803,8 @@ def _get_embeddings(self, texts: List[str]) -> List[torch.Tensor]:
784803
if was_training:
785804
self.model.train()
786805

787-
# Restore original order
788-
original_order = [0] * len(texts)
789-
for i, idx in enumerate(sorted_indices):
790-
original_order[idx] = embeddings[i].cpu()
791-
792-
return original_order
806+
# Return embeddings as list
807+
return [emb.cpu() for emb in embeddings]
793808

794809
def get_example_statistics(self) -> Dict[str, Any]:
795810
"""Get statistics about stored examples and model state."""

0 commit comments

Comments
 (0)