このプロジェクトは、ResNetモデルを使用して積雪画像の分類タスクを行います。半教師あり学習方法を採用しており、まずは教師ありの三分類を行い、その後、各クラス内でより細かな予測を行います。
model.py
: ResNetモデルの定義を含み、複数のResNetバリアント(ResNet18, 34, 50, 101など)をサポートしますtrain.py
: トレーニングスクリプトで、データの読み込み、モデルのトレーニングなどの主要機能を含みますpredict.py
: 単一画像予測スクリプトload_weights.py
: モデルの重みをロードするツールdata_set/
: データセットディレクトリharmo/
: 3つのクラスを含む積雪画像(未公開されましたので、こちらで削除いたしました。)- No Snow
- Snow Coverage <50%
- Snow Coverage ≥50%
- Python 3.x
- PyTorch
- torchvision
- tqdm
- CUDA(or mps, for acceleration)
-
dataset:
- データセットをクラス別に
data_set/harmo
置いといてください。
- データセットをクラス別に
-
Supervised learning part:
python train.py --num_classes 3 --vis-interval 5 --device mps --weights checkpoint/resnet34-pre.pth
-
Semi-supervised learning part:
python train.py --num_classes 9 --vis-interval 5 --device mps --weights checkpoint/resnet34-XX.pth
Note: こちらの
checkpoint/resnet34-XX.pth
は、前の教師あり学習で得られたweightを使った方がいいです。 例えばpython train.py --num_classes 9 --vis-interval 5 --device mps --weights checkpoint/ResNet34-6-v3.pth
-
Prediction:
- Single image prediction:
あるいは
python predict.py --weights [path-to-weight] --img-path [path-to-image] --device mps --visualize
Note: --img-dir for batch prediction or --img-path for single image predictionpython predict.py --weights [path-to-weight] --img-path [path-to-image] --device mps
- Single image prediction:
このプロジェクトでは、ResNet34をベースモデルとして使用し、transfer learningを使いました:
- Load the pre-trained weights
- Replace the final fully connected layer to fit the 3-class classification task
- Fine-tune the model
- --num_classes: モデルの分類クラス数を指定します(デフォルトは3クラス)。
- --epochs: トレーニングのエポック数(デフォルトは10)。
- --batch-size: バッチサイズ(デフォルトは32)。
- --lr: 学習率(デフォルトは1e-4)。
- --wd: 重み減衰(デフォルトは5e-2)。
- --version: モデルのバージョン(デフォルトは2)。
- --data-path: トレーニングデータのパス(デフォルトはdata_set/harmo)。
- --weights: 初期重みのパス(デフォルトはcheckpoint/resnet34-pre.pth)。
- --freeze-layers: 最後の層以外を凍結するかどうか(デフォルトはFalse)。
- --device: 使用するデバイス(CPU、CUDA、MPS)。
- --semi-supervised: 半教師あり学習を有効にするかどうか(オプション)。
- --unlabeled-data-path: 半教師あり学習用のラベルなしデータのパス(オプション)。
- --consistency-weight: 一貫性損失の重み(オプション)。
- --visualize: 可視化を有効にするかどうか(オプション)。
- --vis-interval: 可視化の間隔(エポック単位)。
- トレーニングで、学習率やバッチサイズなどのハイパーパラメータを調整できます。
- モデルの重みは
checkpoint
に保存されます。