• Stars
    star
    195
  • Rank 198,371 (Top 4 %)
  • Language
    Python
  • License
    MIT License
  • Created almost 4 years ago
  • Updated about 1 year ago

Reviews

There are no reviews yet. Be the first to send feedback to the community and the maintainers!

Repository Details

Summarization module based on KoBART

KoBART-summarization

Load KoBART

Download binary

import torch
from transformers import PreTrainedTokenizerFast
from transformers import BartForConditionalGeneration

tokenizer = PreTrainedTokenizerFast.from_pretrained('digit82/kobart-summarization')
model = BartForConditionalGeneration.from_pretrained('digit82/kobart-summarization')

text = """
1일 μ˜€ν›„ 9μ‹œκΉŒμ§€ μ΅œμ†Œ 20만3220λͺ…이 μ½”λ‘œλ‚˜19에 μ‹ κ·œ 확진됐닀. λ˜λ‹€μ‹œ λ™μ‹œκ°„λŒ€ μ΅œλ‹€ 기둝으둜, 사상 처음 20만λͺ…λŒ€μ— μ§„μž…ν–ˆλ‹€.
λ°©μ—­ λ‹Ήκ΅­κ³Ό μ„œμšΈμ‹œ λ“± 각 μ§€λ°©μžμΉ˜λ‹¨μ²΄μ— λ”°λ₯΄λ©΄ 이날 0μ‹œλΆ€ν„° μ˜€ν›„ 9μ‹œκΉŒμ§€ μ „κ΅­ μ‹ κ·œ ν™•μ§„μžλŠ” 총 20만3220λͺ…μœΌλ‘œ 집계됐닀.
κ΅­λ‚΄ μ‹ κ·œ ν™•μ§„μž μˆ˜κ°€ 20만λͺ…λŒ€λ₯Ό λ„˜μ–΄μ„  것은 이번이 μ²˜μŒμ΄λ‹€.
λ™μ‹œκ°„λŒ€ μ΅œλ‹€ 기둝은 μ§€λ‚œ 23일 μ˜€ν›„ 9μ‹œ κΈ°μ€€ 16만1389λͺ…μ΄μ—ˆλŠ”λ°, 이λ₯Ό 무렀 4만1831λͺ…μ΄λ‚˜ μ›ƒλŒμ•˜λ‹€. μ „λ‚  같은 μ‹œκ°„ κΈ°λ‘ν•œ 13만3481λͺ…보닀도 6만9739λͺ… λ§Žλ‹€.
ν™•μ§„μž 폭증은 3μ‹œκ°„ 전인 μ˜€ν›„ 6μ‹œ μ§‘κ³„μ—μ„œλ„ μ˜ˆκ²¬λλ‹€.
μ˜€ν›„ 6μ‹œκΉŒμ§€ μ΅œμ†Œ 17만8603λͺ…이 μ‹ κ·œ 확진돼 λ™μ‹œκ°„λŒ€ μ΅œλ‹€ 기둝(24일 13만8419λͺ…)을 κ°ˆμ•„μΉ˜μš΄ 데 이어 이미 직전 0μ‹œ κΈ°μ€€ μ—­λŒ€ μ΅œλ‹€ 기둝도 λ„˜μ–΄μ„°λ‹€. μ—­λŒ€ μ΅œλ‹€ 기둝은 μ§€λ‚œ 23일 0μ‹œ κΈ°μ€€ 17만1451λͺ…μ΄μ—ˆλ‹€.
17개 μ§€μžμ²΄λ³„λ‘œ 보면 μ„œμšΈ 4만6938λͺ…, κ²½κΈ° 6만7322λͺ…, 인천 1만985λͺ… λ“± μˆ˜λ„κΆŒμ΄ 12만5245λͺ…μœΌλ‘œ μ „μ²΄μ˜ 61.6%λ₯Ό μ°¨μ§€ν–ˆλ‹€. μ„œμšΈκ³Ό κ²½κΈ°λŠ” λͺ¨λ‘ λ™μ‹œκ°„λŒ€ κΈ°μ€€ μ΅œλ‹€λ‘œ, 처음으둜 각각 4만λͺ…κ³Ό 6만λͺ…을 λ„˜μ–΄μ„°λ‹€.
λΉ„μˆ˜λ„κΆŒμ—μ„œλŠ” 7만7975λͺ…(38.3%)이 λ°œμƒν–ˆλ‹€. 제주λ₯Ό μ œμ™Έν•œ λ‚˜λ¨Έμ§€ μ§€μ—­μ—μ„œ λͺ¨λ‘ λ™μ‹œκ°„λŒ€ μ΅œλ‹€λ₯Ό μƒˆλ‘œ 썼닀.
λΆ€μ‚° 1만890λͺ…, 경남 9909λͺ…, λŒ€κ΅¬ 6900λͺ…, 경뢁 6977λͺ…, 좩남 5900λͺ…, λŒ€μ „ 5292λͺ…, 전뢁 5150λͺ…, μšΈμ‚° 5141λͺ…, κ΄‘μ£Ό 5130λͺ…, 전남 4996λͺ…, 강원 4932λͺ…, 좩뢁 3845λͺ…, 제주 1513λͺ…, μ„Έμ’… 1400λͺ…이닀.
집계λ₯Ό λ§ˆκ°ν•˜λŠ” μžμ •κΉŒμ§€ μ‹œκ°„μ΄ λ‚¨μ•„μžˆλŠ” 만큼 2일 0μ‹œ κΈ°μ€€μœΌλ‘œ λ°œν‘œλ  μ‹ κ·œ ν™•μ§„μž μˆ˜λŠ” 이보닀 더 λŠ˜μ–΄λ‚  수 μžˆλ‹€. 이에 따라 μ΅œμ’… μ§‘κ³„λ˜λŠ” ν™•μ§„μž μˆ˜λŠ” 21만λͺ… μ•ˆνŒŽμ„ 기둝할 수 μžˆμ„ 전망이닀.
ν•œνŽΈ μ „λ‚  ν•˜λ£¨ μ„ λ³„μ§„λ£Œμ†Œμ—μ„œ 이뀄진 κ²€μ‚¬λŠ” 70만8763건으둜 검사 μ–‘μ„±λ₯ μ€ 40.5%λ‹€. μ–‘μ„±λ₯ μ΄ 40%λ₯Ό λ„˜μ€ 것은 이번이 μ²˜μŒμ΄λ‹€. ν™•μ‚°μ„Έκ°€ 계속 κ±°μ„Έμ§ˆ 수 μžˆλ‹€λŠ” μ–˜κΈ°λ‹€.
이날 0μ‹œ κΈ°μ€€ μ‹ κ·œ ν™•μ§„μžλŠ” 13만8993λͺ…μ΄μ—ˆλ‹€. 이틀 연속 13만λͺ…λŒ€λ₯Ό 이어갔닀.
"""

text = text.replace('\n', ' ')

raw_input_ids = tokenizer.encode(text)
input_ids = [tokenizer.bos_token_id] + raw_input_ids + [tokenizer.eos_token_id]

summary_ids = model.generate(torch.tensor([input_ids]),  num_beams=4,  max_length=512,  eos_token_id=1)
tokenizer.decode(summary_ids.squeeze().tolist(), skip_special_tokens=True)

'1일 0 9μ‹œκΉŒμ§€ μ΅œμ†Œ 20만3220λͺ…이 μ½”λ‘œλ‚˜19에 μ‹ κ·œ ν™•μ§„λ˜μ–΄ μ—­λŒ€ μ΅œλ‹€ 기둝을 κ°ˆμ•„μΉ˜μ› λ‹€.'

Requirements

pytorch>=1.10.0
transformers==4.16.2
pytorch-lightning==1.5.10
streamlit==1.2.0

Data

  • Dacon ν•œκ΅­μ–΄ λ¬Έμ„œ μƒμ„±μš”μ•½ AI κ²½μ§„λŒ€νšŒ 의 ν•™μŠ΅ 데이터λ₯Ό ν™œμš©ν•¨
  • ν•™μŠ΅ λ°μ΄ν„°μ—μ„œ μž„μ˜λ‘œ Train / Test 데이터λ₯Ό 생성함
  • 데이터 탐색에 μš©μ΄ν•˜κ²Œ tsv ν˜•νƒœλ‘œ 데이터λ₯Ό λ³€ν™˜ν•¨
  • Data ꡬ쑰
    • Train Data : 34,242
    • Test Data : 8,501
  • default둜 data/train.tsv, data/test.tsv ν˜•νƒœλ‘œ μ €μž₯함
news summary
λ‰΄μŠ€μ›λ¬Έ μš”μ•½λ¬Έ

How to Train

  • KoBART summarization fine-tuning
pip install -r requirements.txt

[use gpu]
python train.py  --gradient_clip_val 1.0  \
                 --max_epochs 50 \
                 --default_root_dir logs \
                 --gpus 1 \
                 --batch_size 4 \
                 --num_workers 4

[use gpu]
python train.py  --gradient_clip_val 1.0  \
                 --max_epochs 50 \
                 --default_root_dir logs \
                 --strategy ddp \
                 --gpus 2 \
                 --batch_size 4 \
                 --num_workers 4

[use cpu]
python train.py  --gradient_clip_val 1.0  \
                 --max_epochs 50 \
                 --default_root_dir logs \
                 --strategy ddp \
                 --batch_size 4 \
                 --num_workers 4

Generation Sample

Text
1 Label νƒœμ™•μ˜ 'μ„±λ‹Ή νƒœμ™•μ•„λ„ˆμŠ€ λ©”νŠΈλ‘œ'λͺ¨λΈν•˜μš°μŠ€λŠ” μ΄ˆμ—­μ„ΈκΆŒ μž…μ§€μ™€ λ³€ν™”ν•˜λŠ” λΌμ΄ν”„μŠ€νƒ€μΌμ— 맞좘 ν˜μ‹ ν‰λ©΄μœΌλ‘œ μ˜€ν”ˆ 당일뢀터 κ΄€λžŒκ°μ˜ 쀄이 μ΄μ–΄μ§€λ©΄μ„œ κ΄€λžŒκ°μ˜ ν˜Έν‰μ„ λ°›μ•˜λ‹€.
1 koBART μ•„νŒŒνŠΈ λΆ„μ–‘μ‹œμž₯이 μ‹€μˆ˜μš”μž μ€‘μ‹¬μœΌλ‘œ λ°”λ€Œλ©΄μ„œ μ΄ˆμ—­μ„ΈκΆŒ μž…μ§€μ™€ λ³€ν™”ν•˜λŠ” λΌμ΄ν”„μŠ€νƒ€μΌμ— 맞좘 ν˜μ‹ ν‰λ©΄μ΄ μ•„νŒŒνŠΈ 선택에 λ―ΈμΉ˜λŠ” 영ν–₯λ ₯이 컀지고 μžˆλŠ” κ°€μš΄λ°, νƒœμ™•μ΄ μ§€λ‚œ 22일 κ³΅κ°œν•œ β€˜μ„±λ‹Ή νƒœμ™•μ•„λ„ˆμŠ€ λ©”νŠΈλ‘œβ€™ λͺ¨λΈν•˜μš°μŠ€λ₯Ό 찾은 방문객듀은 합리적인 뢄양가와 μ€‘λ„κΈˆλ¬΄μ΄μž λ“±μ˜ 뢄양쑰건도 μ‹€μˆ˜μš”μžμ—κ²Œ μœ λ¦¬ν•΄ 높은 μ²­μ•½κ²½μŸλ₯ μ„ κΈ°λŒ€ν–ˆλ‹€.
Text
2 Label 광주지방ꡭ세청은 'μƒμƒν•˜κ³  ν¬μš©ν•˜λŠ” μ„Έμ •κ΅¬ν˜„μ„ μœ„ν•œ' ν˜μ‹ μ„±μž₯ κΈ°μ—… 세정지원 μ„€λͺ…νšŒλ₯Ό μ—΄μ–΄ μ—¬λŸ¬ 세정지원 μ œλ„λ₯Ό μ•ˆλ‚΄ν•˜κ³  κΈ°μ—… ν˜„μž₯의 μ• λ‘œ, κ±΄μ˜μ‚¬ν•­μ„ κ²½μ²­ν•˜λ©° κΈ°μ—… λ§žμΆ€ν˜• μ„Έμ •μ„œλΉ„μŠ€λ₯Ό μ œκ³΅ν•  것을 μ•½μ†ν–ˆλ‹€.
2 koBART 17일 광주지방ꡭ세청은 정뢀광주지방합동청사 3μΈ΅ μ„Έλ―Έλ‚˜μ‹€μ—μ„œ ν˜μ‹ μ„±μž₯ κ²½μ œμ •μ±…μ„ μ„Έμ •μ°¨μ›μ—μ„œ λ’·λ°›μΉ¨ν•˜κΈ° μœ„ν•΄ λ‹€μ–‘ν•œ 세정지원 μ œλ„λ₯Ό μ•ˆλ‚΄ν•˜λŠ” λ™μ‹œμ— κΈ°μ—… ν˜„μž₯의 μ• λ‘œΒ·κ±΄μ˜μ‚¬ν•­μ„ κ²½μ²­ν•˜κΈ° μœ„ν•΄ β€˜μƒμƒν•˜κ³  ν¬μš©ν•˜λŠ” μ„Έμ •κ΅¬ν˜„μ„ μœ„ν•œβ€™ ν˜μ‹ μ„±μž₯ κΈ°μ—… 세정지원 μ„€λͺ…νšŒλ₯Ό μ—΄μ–΄ μ£Όλͺ©μ„ λŒμ—ˆλ‹€.'
Text
3 Label μ‹ μš©λ³΄μ¦κΈ°κΈˆ λ“± 3개 기관은 31일 μ„œμšΈ 쀑ꡬ 기업은행 λ³Έμ μ—μ„œ 졜근 κ²½μ˜μ— 어렀움을 κ²ͺλŠ” μ†Œμƒκ³΅μΈ λ“±μ˜ κΈˆμœ΅λΉ„μš© 뢀담을 쀄이고 μ„œλ―Όκ²½μ œμ— ν™œλ ₯을 μ£ΌκΈ° μœ„ν•΄ 'μ†Œμƒκ³΅μΈ. μžμ˜μ—…μž νŠΉλ³„ κΈˆμœ΅μ§€μ› μ—…λ¬΄ν˜‘μ•½'을 μ²΄κ²°ν–ˆλ‹€κ³  μ „ν–ˆμœΌλ©° μ§€μ›λŒ€μƒμ€ ν•„μš”ν•œ 쑰건을 κ°–μΆ˜ μˆ˜μΆœμ€‘μ†ŒκΈ°μ—…, μœ λ§μ°½μ—…κΈ°μ—… 등이닀.
3 koBART 졜근 κ²½μ˜μ• λ‘œλ₯Ό κ²ͺκ³  μžˆλŠ” μ†Œμƒκ³΅μΈκ³Ό μžμ˜μ—…μžμ˜ κΈˆμœ΅λΉ„μš© 뢀담을 μ™„ν™”ν•˜κ³  μ„œλ―Όκ²½μ œμ˜ ν™œλ ₯을 μ œκ³ ν•˜κΈ° μœ„ν•΄ μ‹ μš©λ³΄μ¦κΈ°κΈˆΒ·κΈ°μˆ λ³΄μ¦κΈ°κΈˆΒ·μ‹ μš©λ³΄μ¦μž¬λ‹¨ μ€‘μ•™νšŒΒ·κΈ°μ—…μ€ν–‰μ€ 31일 μ„œμšΈ 쀑ꡬ 기업은행 λ³Έμ μ—μ„œ β€˜μ†Œμƒκ³΅μΈΒ·μžμ˜μ—…μž νŠΉλ³„ κΈˆμœ΅μ§€μ› μ—…λ¬΄ν˜‘μ•½β€™μ„ μ²΄κ²°ν–ˆλ‹€.

Model Performance

  • Test Data κΈ°μ€€μœΌλ‘œ rouge scoreλ₯Ό μ‚°μΆœν•¨
  • Score μ‚°μΆœ 방법은 Dacon ν•œκ΅­μ–΄ λ¬Έμ„œ μƒμ„±μš”μ•½ AI κ²½μ§„λŒ€νšŒ metric을 ν™œμš©ν•¨
rouge-1 rouge-2 rouge-l
Precision 0.515 0.351 0.415
Recall 0.538 0.359 0.440
F1 0.505 0.340 0.415

Demo

  • ν•™μŠ΅ν•œ model binary μΆ”μΆœ μž‘μ—…μ΄ ν•„μš”ν•¨
    • pytorch-lightning binary --> huggingface binary둜 μΆ”μΆœ μž‘μ—… ν•„μš”
    • hparams의 κ²½μš°μ—λŠ” ./logs/tb_logs/default/version_0/hparams.yaml νŒŒμΌμ„ ν™œμš©
    • model_binary 의 κ²½μš°μ—λŠ” ./logs/kobart_summary-model_chp μ•ˆμ— μžˆλŠ” .ckpt νŒŒμΌμ„ ν™œμš©
    • λ³€ν™˜ μ½”λ“œλ₯Ό μ‹€ν–‰ν•˜λ©΄ ./kobart_summary 에 model binary κ°€ μΆ”μΆœ 됨
 python get_model_binary.py --hparams hparam_path --model_binary model_binary_path
  • streamlit을 ν™œμš©ν•˜μ—¬ Demo μ‹€ν–‰
streamlit run infer.py

drawing

Reference