Grace Chen, Zoey Qu, Michael Zhao
New York University
This project implements a hierarchical news classification system built on BERT embeddings and a two-level cascade architecture. The system first assigns each news article to a broad category (Level 1), then to a specific subcategory within that category (Level 2). To limit the error propagation that cascaded classifiers are prone to, it manages prediction confidence explicitly and lets low-confidence Level-1 decisions be corrected. The goal is more accurate and robust news categorization, with better computational efficiency and more interpretable predictions.
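The cascade described above can be sketched as a simple routing function. This is a minimal illustration, not the project's code: the names `classify_hierarchical`, `toy_level1`, and `toy_sports_level2` are hypothetical, and the keyword-based classifiers are stand-ins for the fine-tuned BERT models.

```python
from typing import Callable, Dict, Tuple

# A classifier maps text to (label, confidence). In the real system these would
# wrap fine-tuned BERT models; here they are hypothetical stand-ins.
Classifier = Callable[[str], Tuple[str, float]]


def classify_hierarchical(
    text: str,
    level1: Classifier,
    level2_by_category: Dict[str, Classifier],
) -> Dict[str, object]:
    """Run Level 1, then route the text to the matching Level-2 classifier."""
    category, cat_conf = level1(text)
    level2 = level2_by_category.get(category)
    if level2 is None:
        # No subcategory model for this category: return the Level-1 result only.
        return {"category": category, "category_conf": cat_conf,
                "subcategory": None, "subcategory_conf": None}
    subcategory, sub_conf = level2(text)
    return {"category": category, "category_conf": cat_conf,
            "subcategory": subcategory, "subcategory_conf": sub_conf}


# Toy keyword classifiers, used only to make the sketch runnable.
def toy_level1(text: str) -> Tuple[str, float]:
    return ("sports", 0.9) if "match" in text.lower() else ("business", 0.6)


def toy_sports_level2(text: str) -> Tuple[str, float]:
    return ("tennis", 0.8) if "tennis" in text.lower() else ("football", 0.7)


result = classify_hierarchical(
    "Tennis match ends in upset", toy_level1, {"sports": toy_sports_level2}
)
```

Keeping one Level-2 classifier per Level-1 category means each subcategory model only has to separate the classes within its own branch.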
- Adaptive Confidence Thresholds: Dynamically adjusts confidence thresholds based on historical accuracy to mitigate error propagation
- Data Balancing Strategy: Implements intelligent upsampling for under-represented subcategories to improve classification of rare classes
- Hierarchical Error Correction: Introduces a feedback mechanism between classification levels, allowing low-confidence primary classifications to be corrected
- Differentiated Training Parameters: Optimizes hyperparameters separately for each classification level
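As a rough illustration of the first and third features, the sketch below pairs an accuracy-tracking threshold with a Level-2 feedback step. The update rule, the joint-score re-ranking, and the names `AdaptiveThreshold` and `correct_level1` are all assumptions made for this example; the actual logic in `thresholds.py` and `predict.py` may differ.

```python
from collections import deque


class AdaptiveThreshold:
    """Confidence threshold that tracks recent accuracy (assumed update rule)."""

    def __init__(self, base=0.7, window=100, min_t=0.5, max_t=0.95):
        self.base = base
        self.min_t = min_t
        self.max_t = max_t
        self.history = deque(maxlen=window)  # 1 = correct, 0 = wrong

    def record(self, was_correct):
        """Store whether the latest Level-1 prediction turned out to be correct."""
        self.history.append(1 if was_correct else 0)

    @property
    def value(self):
        if not self.history:
            return self.base
        accuracy = sum(self.history) / len(self.history)
        # Lower recent accuracy -> demand more confidence before trusting Level 1.
        raw = self.base + (1.0 - accuracy) * (self.max_t - self.base)
        return max(self.min_t, min(self.max_t, raw))


def correct_level1(l1_probs, l2_best_conf, threshold):
    """Pick the Level-1 category, consulting Level 2 when Level 1 is unsure.

    l1_probs: dict mapping category -> Level-1 probability.
    l2_best_conf: dict mapping category -> top Level-2 confidence in that category.
    """
    top = max(l1_probs, key=l1_probs.get)
    if l1_probs[top] >= threshold:
        return top  # confident enough: keep the Level-1 decision
    # Otherwise re-rank by the joint score P(category) * best Level-2 confidence.
    return max(l1_probs, key=lambda c: l1_probs[c] * l2_best_conf.get(c, 0.0))
```

With Level-1 probabilities of 0.48 for sports and 0.45 for business, a threshold of 0.7 triggers the fallback, so a strong business subcategory match can override the marginal Level-1 preference for sports.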
- config.py: Configuration parameters for models and training
- data.py: Data loading, preprocessing, and balancing functions
- thresholds.py: Implementation of adaptive threshold mechanism
- train_l1.py: Training script for Level-1 classifier
- train_l2.py: Training script for Level-2 classifiers
- predict.py: Model inference and hierarchical classification
- app.py: Streamlit web application for interactive classification
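The balancing step that data.py is responsible for can be approximated by random upsampling with replacement. The sketch below is an assumption about how such a step might look, written with the standard library only; the function name `upsample_rare` and the `min_count` parameter are hypothetical and not taken from the project.

```python
import random
from collections import defaultdict


def upsample_rare(examples, min_count, seed=0):
    """Duplicate random examples of rare labels until each has min_count items.

    examples: list of (text, label) pairs.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for text, label in examples:
        by_label[label].append((text, label))

    balanced = []
    for label, items in by_label.items():
        balanced.extend(items)
        shortfall = min_count - len(items)
        if shortfall > 0:
            # Sample with replacement from the existing examples of this label.
            balanced.extend(rng.choices(items, k=shortfall))
    rng.shuffle(balanced)  # avoid long runs of duplicated examples
    return balanced
```

Upsampling should be applied to the training split only, so that duplicated articles do not leak into validation or test data.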
- Python 3.7+
- PyTorch >= 1.9.0
- Transformers >= 4.12.0
- Pandas >= 1.3.0
- NumPy >= 1.20.0
- Scikit-learn >= 0.24.0
- Datasets >= 1.11.0
- Evaluate >= 0.2.0
- Streamlit >= 1.8.0
# Install dependencies
pip install -r requirements.txt
# Train Level-1 classifier
python train_l1.py
# Train Level-2 classifiers
python train_l2.py
# Run interactive prediction
python predict.py
# Launch the web interface
streamlit run app.py
The Streamlit application provides an interactive interface for classifying news articles and visualizing confidence scores for both levels.