ํญ๋ชฉ | ๋ด์ฉ |
---|---|
ํ๋ก์ ํธ ์ฃผ์ | ๋ฌธ์ฅ ์ ๋จ์ด(Entity)์ ๋ํ ์์ฑ๊ณผ ๊ด๊ณ๋ฅผ ์์ธกํ๋ ๊ด๊ณ ์ถ์ถ(Relation Extraction) ์ ๋ฌด๋ฅผ ์ํ |
ํ๋ก์ ํธ ๊ตฌํ ๋ด์ฉ | โข ๋จ์ด ๊ฐ ๊ด๊ณ๋ฅผ ์๋ฏธํ๋ 30๊ฐ ๋ผ๋ฒจ ๊ฐ๊ฐ์ ๋ํด, subject์ object๊ฐ ํด๋น ํด๋์ค์ ์ํ ํ๋ฅ ์ ์์ธก โข ํ๊ฐ ์งํ๋ก๋ 1) no-relation class๋ฅผ ์ ์ธํ micro F1 score, 2) ๋ชจ๋ ํด๋์ค์ ๋ํ AUPRC๊ฐ ์ฌ์ฉ |
์งํ ๊ธฐ๊ฐ | 2024๋ 1์ 3์ผ ~ 2024๋ 1์ 18์ผ |
Notion, Slack, Zoom ์ ํตํด ํ์๋ฅผ ์งํํ์ผ๋ฉฐ, ์ฝ๋์ ๊ฒฝ์ฐ๋ ๋ชจ๋ํํ์ฌ Makeํ์ผ๋ก ์๋ํํ์ฌ ๊ด๋ฆฌํ์ต๋๋ค. ์ด๋, Github์ ํตํด ์ฝ๋ ๊ณต์ ๋ฅผ ์งํํ์ผ๋ฉฐ, Wandb๋ฅผ ์ด์ฉํด ์ค์๊ฐ์ผ๋ก ์คํ์ ๊ด๋ฆฌํ์ต๋๋ค.
๊ตฌํฌ์ฐฌ | ๊น๋ฏผ์ | ์์ ๋ฆผ | ์คํ์ฐ | ์ด์์ | ์ต์์ง |
---|---|---|---|---|---|
์ ๋ฐ์ ์ธ ํ๋ก์ ํธ ๊ณผ์ ์ ๋ชจ๋ ๊ฒฝํํ ์ ์๋๋ก ๋ถ์ ์ ํ์ฌ ํ์ ์ ์งํํ์ผ๋ฉฐ, ์ด๊ธฐ ๊ฐ๋ฐ ํ๊ฒฝ ๊ตฌ์ถ ๋จ๊ณ๋ถํฐ Github์ ์ฐ๋ํ์ฌ ์ธ๋ถ task ๋ง๋ค issue์ branch๋ฅผ ์์ฑํ์ฌ ๊ฐ์์ task๋ฅผ ์ํํ๋ค. ์ด๋ฅผ ํตํด ์ฝ๋ ๋ฒ์ ๊ด๋ฆฌ์ ์ฝ๋ ๊ณต์ ๊ฐ ์ํํ๊ฒ ์ด๋ฃจ์ด์ง ์ ์์๋ค.
์ด๋ฆ | ์ญํ |
---|---|
๊ตฌํฌ์ฐฌ | ํ๊ฒฝ ์ค์ ๋ฐ ํ ํ๋ฆฟ ๊ด๋ฆฌ, ๋ฒ ์ด์ค๋ผ์ธ ์ฝ๋ ๋ฆฌํฉํ ๋ง, ๋ฐ์ดํฐ ๋ผ๋ฒจ ๊ฒ์, ๋ชจ๋ธ ์ํคํ ์ณ ๊ฐ์ (3-classification, LSTM layer), Github issue ๊ด๋ฆฌ ๋ฐ PR merge, ํ์ดํผํ๋ผ๋ฏธํฐ ํ๋ ์คํ, ๊ฒ์ฆ ๋ฐ์ดํฐ์ ๋น์จ ๋ณ๊ฒฝ ์คํ, ์ฌ์ฉ์ฑ ๊ฐ์ ๋ฐ ์ฌํ๋ถ์์ฉ ์ ํธ ์ ์ย |
๊น๋ฏผ์ย | ํ๊ฒฝ ์ค์ ๋ฐ ํ ํ๋ฆฟ ๊ด๋ฆฌ, ๋ฒ ์ด์ค๋ผ์ธ ๋ชจ๋ธ ํ์ ๋ฐ ์คํ, ๋ชจ๋ธ ์ํคํ ์ณ ๊ฐ์ (LSTM layer, Focal Loss ์ ์ฉ ๋ฐ ์คํ), ์ ์ฒ๋ฆฌ(query ์ถ๊ฐ), ๊ฒ์ฆ ๋ฐ์ดํฐ์ ์์ฑ ๋ฐ ์ฝ๋ ์์ฑ, ์ ๋ฐ์ ์ธ ๋ฐฉ๋ฒ๋ก ์ ๋ฆฌ, ์คํ ๊ฒฐ๊ณผ ์ฌํ ๋ถ์ย |
์์ ๋ฆผย | ๋ฒ ์ด์ค๋ผ์ธ ๋ชจ๋ธ ์คํ, ์ฌ์ ์กฐ์ฌ ๋ฐ ๋ฐฉ๋ฒ๋ก ์ ๋ฆฌ, ๋ชจ๋ธ ๊ฐ์ ๊ด๋ จ ๋ ผ๋ฌธ ๋ฐ์ (entity marking, TAPT), ์ ์ฒ๋ฆฌ ๋ฐ ์ฆ๊ฐ ์๋(entity marking, ์ญ๋ฒ์ญ, ํ์ ์ ๊ฑฐ, MLM), ๋ชจ๋ธ๋ง ์คํ(3-classification)ย |
์คํ์ฐย | ์ ๋ฐ์ ์ธ ๋ฐฉ๋ฒ๋ก ์ ๋ฆฌ ๋ฐ ๋ ธ์ ๊ด๋ฆฌ, WandB ๊ด๋ฆฌ, EDA ๋ฐ ๋ฐ์ดํฐ ๋ถ์, ์ ์ฒ๋ฆฌ ๋ฐ ์ฆ๊ฐ ์๋(label reverse, MLM์ฆ๊ฐ, ์์ค ์คํ์ ํ ํฐ ์ถ๊ฐ), tokenizer ๋ถ์, ๋ชจ๋ธ ์ํคํ ์ณ ๊ฐ์ (LSTM layer), ์คํ ๊ฒฐ๊ณผ ์ฌํ ๋ถ์ ์ฝ๋ ์์ฑ ๋ฐ ๊ฒฐ๊ณผ ๋น๊ตย |
์ด์์ย | ๋ฒ ์ด์ค๋ผ์ธ ๋ชจ๋ธ ์คํ ๋ฐ ๋น๊ต, ๋ชจ๋ธ ์ํคํ ์ณ ๊ฐ์ (3-classification, Huggingface Roberta ์ฝ๋ ๋ถ์ ๋ฐ LSTM layer ์ถ๊ฐ), ์ ์ฒ๋ฆฌ ๋ฐ ์ฆ๊ฐ ์๋(์ญ๋ฒ์ญ ๋ฐ entity ์ถ์ถ), ํ์ดํผํ๋ผ๋ฏธํฐ ํ๋ ์คํ, ๋ฐ์ดํฐ ๋ผ๋ฒจ ๋ถ์ ๋ฐ ์คํ ๊ฒฐ๊ณผ ์ฌํ ๋ถ์ย |
์ต์์งย | ๋ชจ๋ธ ์ํคํ ์ณ ๊ฐ์ (3-classification, Early Stopping, Trainer ์ฝ๋ ๋ถ์), ์ ์ฒ๋ฆฌ ๋ฐ ์ฆ๊ฐ ์๋(data copy, backtranslation), ๊ฒ์ฆ ๋ฐ์ดํฐ์ ์์ฑ ๋ฐ ์ฝ๋ ์์ฑ, ๋ฐ์ดํฐ ๋ถ์ ๋ฐ ์คํ ๊ฒฐ๊ณผ ์ฌํ ๋ถ์, ์์๋ธ ์ฝ๋ ์์ฑย |
- ์ ์ฒด ๋ฐ์ดํฐ์ ๋ํ ํต๊ณ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
train.csv
: ์ด 32470๊ฐtest_data.csv
: ์ด 7765๊ฐ (์ ๋ต ๋ผ๋ฒจ blind = 100์ผ๋ก ์์ ํํ)

- column 1: ์ํ ์์ id
- column 2: sentence.
- column 3: subject_entity
- column 4: object_entity
- column 5: label
- column 6: ์ํ ์ถ์ฒ
๋ถ๋ฅ | ๋ด์ฉ |
---|---|
๋ชจ๋ธ | โข ์คํํ ๋ชจ๋ธ : ์ต์ข
์ ์ผ๋ก klue/roberta-large ์ฌ์ฉklue/roberta-large , klue/roberta-base , klue/roberta-small , monologg/koelectra-base-v3-discriminator , snunlp/KR-ELECTRA-discriminator , beomi/kcbert , xlm/roberta-large , kykim/bert-kor-base , kykim/electra-kor-base โข LSTM layer ์ถ๊ฐ : Classification ๋จ๊ณ์์ LSTM layer๋ฅผ ์ถ๊ฐํด์ค์ผ๋ก์จ ์ผ๋ถ ํ ํฐ์ ๊ฒฐ๊ณผ ๋ฒกํฐ๋ง์ ์ฌ์ฉํ๋ ๊ธฐ์กด ๊ตฌ์กฐ ๊ฐ์ , ๋ฌธ์ฅ ์ ์ฒด ๋ฒกํฐ๋ฅผ ํ์ฉํ ์ ์๋ LSTM layer๋ฅผ ์ถ๊ฐ |
๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ | โข (Typed) Entity Marker : entity์ ์์น ์ ๋ณด๋ฅผ marker๋ก ์ ๊ณตํ๊ณ entity์ ์ ํ์ ์ ๊ณตํด์ ํ์ต ์ฑ๋ฅ ํฅ์์ ์๋ โข ๋ฐ์ดํฐ Query ์ถ๊ฐํ๊ธฐ : BERT์ QA Task ํ์ต ๋ฐฉ์์ ์ ์ฉํ๊ณ ์ ํจ- sentence ์ ๋ถ๋ถ์ ์ง๋ฌธ ํํ์ ์ฟผ๋ฆฌ ์ถ๊ฐ (์์ : [SUB]์ [OBJ]์ ๊ด๊ณ๋ ๋ฌด์์ธ๊ฐ? [SEP] [sentence] [SEP]) โข Source ์คํ์ ํ ํฐ ์ถ๊ฐ : ์์ค๋ณ ํ๊ฒ๊ฐ์ ๋ถํฌ๊ฐ ๋ค๋ฅธ ๊ฒ์ ํ์ธ, ์ฟผ๋ฆฌ๋ฌธ ์์ 3๊ฐ์ง ์์ค ์คํ์ ํ ํฐ์ ์ถ๊ฐํด์ค - [W_PED],[W_TR], [POL] โข ํ์ ์ ๊ฑฐ : ํ ํฐ ๊ฒฐ๊ณผ์ UNK ์ต์ํ๋ฅผ ์ํจ. ๊ฐ์ฅ ๋ง์ด UNK๋ก ํ ํฐํ๋์๋ ํ์์ด ์ ๊ฑฐ |
๋ฐ์ดํฐ ์ฆ๊ฐ ๋ฐ ์กฐ์ | โขย Label Reverse ์ฆ๊ฐ : ์๋ก ์์ถฉ๋๋ ์๋ฏธ์ ๋ผ๋ฒจ๊ณผ, subject์ object๋ฅผ ๋ฐ๊ฟ๋ ๊ด์ฐฎ์ ๋ผ๋ฒจ์ ๊ฒฝ์ฐ subject์ object๋ฅผ ๋ฐ๋๋ก swapํ์ฌ ๋ฐ์ดํฐ ์ฆ๊ฐ, 10939๊ฐ์ ๋ฐ์ดํฐ ์ฆ๊ฐ โข Back-Translation ์ฆ๊ฐ : GoogleTrans ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ํ์ฉํด ๋ฌธ์ฅ์ ์์ด๋ก ๋ฒ์ญํ ํ, ์ด๋ฅผ ๋ค์ ํ๊ตญ์ด๋ก ๋ฒ์ญํ์ฌ ๋ฐ์ดํฐ ์ฆ๊ฐ โข MLM ์ฆ๊ฐ : BERT ๊ธฐ๋ฐ ๋ชจ๋ธ๋ค์ MLM ํ์ต ๋ฐฉ์์์ ์ฐฉ์, [MASK] ๋ถ๋ถ์ด ๊ธฐ์กด ๋ฌธ์ฅ๊ณผ ๋ค๋ฅธ ์๋ก์ด token์ผ๋ก ํจ๋ฌํ๋ ์ด์ง ๋ ๊ฒ์์ ๊ฐ์ ํ๊ณ ,์ฆ๊ฐ์ ํ์ฉ |
์ํคํ ์ณ ๋ณด์ | 1. ๊ณผ์ ํฉ ๋ฐฉ์ง โข Early Stopping : patience ์กฐ์ โข Hyperparameter Tuning : epoch, learning_rate, batch_size, load_best_model ๋ฑ 2. ๋ถํฌ ๋ถ๊ท ํ ํด๊ฒฐ โข binning ๋ชจ๋ธ๋ง โข ํน์ ๋ผ๋ฒจ ์ฆ๊ฐ ์๋ ๋ฐ no_relation ๋ผ๋ฒจ undersamping โข source๋ณ ๋ถ๊ท ํ ํด์ ์๋ โข Loss Function ๋ณ๊ฒฝ (Focal Loss) |
๊ฒ์ฆ ์ ๋ต | โข 9:1, 8:2, 95:5 ๋น์จ๊ณผ random, stratify์ ๋ฐฉ์์ผ๋ก valid set ์์ฑํด์ ํ๊ฐ โข ์ต์ข ์ ์ผ๋ก ๋ฆฌ๋๋ณด๋์ ์ ์ถํ์ฌ ๋ชจ๋ธ ์ฑ๋ฅ ๊ฒ์ฆ โข Valid set์ ๋ํ predict ๊ฐ๊ณผ ์ ๋ต๊ฐ์ ๋น๊ตํ๋ difference.csv ํ์ผ ๋ฐ ํํธ๋งต์ ์์ฑํ์ฌ ์ ์ฑํ๊ฐ |
์์๋ธ ๋ฐฉ๋ฒ | โข ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ์ ๋ชจ๋ธ๋ง ๊ธฐ๋ฒ, ์ฆ๊ฐ ๋ฐ์ดํฐ ์ ์ฉ ํ ํ์ตํ ๋ชจ๋ธ ์ค ๊ฐ์ฅ ์ฑ๋ฅ์ด ์ข์ ๋ชจ๋ธ 10๊ฐ๋ฅผ ์ ์ ํ์ฌ soft voting ์์๋ธ์ ์งํ โข ์ฑ๋ฅ์ด ์ข์ ๋ชจ๋ธ๋ค ์ค ์ต๋ํ ๋ค์ํ b์ข ๋ฅ์ ๋ชจ๋ธ๊ณผ ์ฌ๋ฌ ๋ฐ์ดํฐ์ ์ด ์์ด๋๋ก Soft Voting, Weighted Voting ์งํ โข ์ฑ๋ฅ ๊ฐ์ : micro f1 75.1084(๋จ์ผ๋ชจ๋ธ ์ต๊ณ ) โ76.4576 (์์๋ธ) |
๐ฆlevel2-klue-nlp-04
โฃ code
โ โฃ custom_robertamodel.py
โ โฃ dict_label_to_num.pkl
โ โฃ dict_num_to_label.pkl
โ โฃ focal.py
โ โฃ focal_loss.py
โ โฃ heatmap.py
โ โฃ inference.py
โ โฃ load_data.py
โ โฃ metrics.py
โ โฃ modify_path.py
โ โฃ split_valid_random.py
โ โฃ split_valid_stratify.py
โ โฃ train.py
โ โ train_source.py
โฃ config
โ โ default_config.yaml
โฃ dataset
โ โฃ test
โ โ โ test_data.csv
โ โฃ train
โ โ โ train.csv
โฃ model_ensemble
โ โฃ ensemble.py
โ โฃ ensemble_model.py
โ โ utils.py
โฃ utils
โ โฃ add_query.py
โ โฃ add_source_token.py
โ โ preprocessing.py
โฃ prediction
โ โ sample_submission.csv
โฃ .gitignore
โฃ README.md
โฃ Makefile
โฃ read_config.sh
โฃ setup.cfg
โฃ pyproject.toml
โ requirements.txt
dataset
์ ํ์ ๋๋ ํ ๋ฆฌ์ธtest
์test.csv
,train
์train.csv
ํ์ผ์ ์ค๋นํ๋ค.code
๋๋ ํ ๋ฆฌ๋ก ์ด๋ํ๊ณsplit_valid_*.py
ํ์ผ์ ์คํํ์ฌ validation ๋ฐ์ดํฐ๋ฅผ ์์ฑํ๋ค.config
๋๋ ํ ๋ฆฌ์default_config.yaml
์ ๋ณต์ฌํ์ฌconfig.yaml
ํ์ผ์ ์์ฑํ๊ณ ์ํ๋ Hyperpatameter๋ฅผ ์ค์ ํ๋ค.- ์์ ๋๋ ํ ๋ฆฌ์ธ
level2-klue-nlp-04
์ผ๋ก ์ด๋ํ์ฌmake run
์ ์ ๋ ฅํ๋ฉด ํ์ต๊ณผ ํจ๊ป ์ถ๋ก ์ด ์๋ฃ๋๋ค.- โconfigโ๋ก ์์ํ๋ ํ์ผ์ ๋ง๋ค๊ณ
make all
์ ์ ๋ ฅํ๋ฉด ์ฌ๋ฌ config ํ์ผ์ ์์ํ๋ค.
-
์์
โฃ config โ โฃ config.yaml โ โฃ config2.yaml โ โฃ config3.yaml โ โฃ config4.yaml โ โ default_config.yaml
- โconfigโ๋ก ์์ํ๋ ํ์ผ์ ๋ง๋ค๊ณ