
PyTorch implementation of some text classification models (HAN, fastText, BiLSTM-Attention, TextCNN, Transformer) | Text Classification

Text Classification

PyTorch re-implementation of some text classification models.

 

Supported Models

Train any of the following models by setting the model_name item in your config file (example config files are provided in the repository):

  • Hierarchical Attention Network (HAN)
  • fastText
  • Bi-LSTM + Attention
  • TextCNN
  • Transformer

 

Requirements

First, make sure your environment has:

  • Python >= 3.5

Then install requirements:

pip install -r requirements.txt

 

Dataset

Currently, the following datasets, proposed in Character-level Convolutional Networks for Text Classification (Zhang et al., 2015), are supported:

  • AG News
  • DBpedia
  • Yelp Review Polarity
  • Yelp Review Full
  • Yahoo Answers
  • Amazon Review Full
  • Amazon Review Polarity

All of them can be downloaded here (Google Drive). Click here for details of these datasets.

Download and unzip them first, then set their path (dataset_path) in your config file. To use other datasets, you will probably need to store them in the same format as the datasets above.
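As far as I know, the datasets above are distributed as header-less CSV files whose first column is a 1-based class index and whose remaining quoted columns hold the text (for example title and description in AG News). The following is a minimal sketch of reading a file in that layout, assuming that format; `read_csv_dataset` is a hypothetical helper, not the repository's loader, and the sample rows are made up.

```python
import csv
import io
from typing import Iterable, List, Tuple


def read_csv_dataset(lines: Iterable[str]) -> Tuple[List[str], List[int]]:
    """Parse rows of the form "label","field1","field2",... (no header row).

    Assumes 1-based class indices in the first column and shifts them to
    0-based, so they can be used directly as targets for a cross-entropy loss.
    The remaining text fields (e.g. title and description) are joined by a space.
    """
    texts, labels = [], []
    for row in csv.reader(lines):
        if not row:
            continue  # skip blank lines
        labels.append(int(row[0]) - 1)
        texts.append(" ".join(field.strip() for field in row[1:]))
    return texts, labels


# Toy rows for illustration; for a real file pass
# open(path, newline="", encoding="utf-8") instead of the StringIO object.
sample = io.StringIO(
    '"3","Stocks rally","Markets rose today."\n'
    '"1","Talks resume","Negotiators met on Monday."\n'
)
texts, labels = read_csv_dataset(sample)
print(labels)    # [2, 0]
print(texts[1])  # Talks resume Negotiators met on Monday.
```

Shifting labels to start at 0 is a design choice for PyTorch losses; the repository may handle label offsets elsewhere.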

 

Pre-trained Word Embeddings

If you would like to use pre-trained word embeddings (like GloVe), set emb_pretrain to True and specify the path to the pre-trained vectors (emb_folder and emb_filename) in your config file. You can also choose whether to fine-tune the word embeddings by editing the fine_tune_embeddings item.

Alternatively, to randomly initialize the embedding layer's weights, set emb_pretrain to False and specify the embedding size (embed_size).
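To illustrate how pre-trained and random initialization can coexist, here is a minimal sketch that builds an embedding matrix aligned with a word2ix map from a GloVe-style text file ("word v1 v2 ... vd" per line). Words missing from the file keep a random uniform initialization in [-sqrt(3 / embed_size), sqrt(3 / embed_size)]; that range, the function name `build_embedding_matrix`, and the toy vocabulary are assumptions, not the repository's code.

```python
import io
import math
import random
from typing import Dict, Iterable, List


def build_embedding_matrix(
    glove_lines: Iterable[str],
    word2ix: Dict[str, int],
    embed_size: int,
    seed: int = 0,
) -> List[List[float]]:
    """Return a len(word2ix) x embed_size matrix; row i belongs to the word
    whose index is i (word2ix indices are assumed to be 0..len-1)."""
    rng = random.Random(seed)
    bound = math.sqrt(3.0 / embed_size)
    # Start with random rows, used for words not covered by the file.
    matrix = [[rng.uniform(-bound, bound) for _ in range(embed_size)]
              for _ in range(len(word2ix))]
    for line in glove_lines:
        parts = line.rstrip().split(" ")
        word, values = parts[0], parts[1:]
        # Overwrite the row with the pre-trained vector when one exists.
        if word in word2ix and len(values) == embed_size:
            matrix[word2ix[word]] = [float(v) for v in values]
    return matrix


word2ix = {"<pad>": 0, "<unk>": 1, "good": 2, "movie": 3}
glove = io.StringIO("good 0.1 0.2 0.3\nbad -0.1 -0.2 -0.3\nmovie 0.5 0.5 0.5\n")
emb = build_embedding_matrix(glove, word2ix, embed_size=3)
print(emb[2])  # [0.1, 0.2, 0.3]
```

In PyTorch, such a matrix could then be loaded with `nn.Embedding.from_pretrained(torch.tensor(emb), freeze=not fine_tune_embeddings)`, so that fine_tune_embeddings decides whether the vectors are updated during training.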

 

Preprocess

Although torchtext makes preprocessing easy, it loads all the data at once, which occupies a lot of memory and slows down training, especially on large datasets.

Therefore, I preprocess the data manually and store it locally first (where configs/example.yaml is the path to your config file):

python preprocess.py --config configs/example.yaml 

Then I load the data dynamically during training using PyTorch's DataLoader (see datasets/dataloader.py).

Preprocessing includes encoding and padding sentences and building the word2ix map. This takes a little time, but afterwards training occupies less memory (so a larger batch size fits) and runs faster. For example, training a fastText model on the Yahoo Answers dataset for one epoch (on an RTX 2080 Ti) takes 4.6 minutes with torchtext, but only 41 seconds with DataLoader.
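The three preprocessing steps above can be sketched as follows for the non-HAN models, assuming already tokenized input and hypothetical `<pad>` (id 0) and `<unk>` (id 1) tokens. This is an illustration, not the repository's preprocess.py, which may tokenize, filter rare words, or order the vocabulary differently.

```python
from collections import Counter
from typing import Dict, List, Tuple

PAD, UNK = "<pad>", "<unk>"  # hypothetical special tokens


def build_word2ix(docs: List[List[str]], min_freq: int = 1) -> Dict[str, int]:
    """Map every word seen at least min_freq times to an integer id.
    Id 0 is reserved for padding and id 1 for unknown words."""
    counts = Counter(word for doc in docs for word in doc)
    word2ix = {PAD: 0, UNK: 1}
    # most_common() sorts by frequency; ties keep first-seen order.
    for word, freq in counts.most_common():
        if freq >= min_freq:
            word2ix[word] = len(word2ix)
    return word2ix


def encode_and_pad(doc: List[str], word2ix: Dict[str, int],
                   word_limit: int) -> Tuple[List[int], int]:
    """Encode a tokenized document, truncate it to word_limit words and pad it
    with the <pad> id. Also return the unpadded length (useful for masking)."""
    ids = [word2ix.get(w, word2ix[UNK]) for w in doc[:word_limit]]
    length = len(ids)
    return ids + [word2ix[PAD]] * (word_limit - length), length


docs = [["a", "great", "movie"], ["a", "bad", "movie", "indeed"]]
word2ix = build_word2ix(docs)
print(encode_and_pad(["a", "terrible", "movie"], word2ix, word_limit=5))
# ([2, 1, 3, 0, 0], 3)
```

Doing this once and saving the fixed-length id lists is what lets each training batch be read cheaply instead of re-tokenizing the raw text every epoch.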

torchtext.py is the script for loading data via torchtext; try it if you are interested.

 

Train

To train a model, just run:

python train.py --config configs/example.yaml

If you have enabled TensorBoard (tensorboard: True in your config file), you can visualize the losses and accuracies during training with:

tensorboard --logdir=<your_log_dir>

 

Test

To test a checkpoint and compute its accuracy on the test set, run:

python test.py --config configs/example.yaml

 

Classify

To predict the category of a specific sentence, first edit the following items in classify.py:

checkpoint_path = 'str: path_to_your_checkpoint'

# pad limits
# only makes sense when model_name == 'han'
sentence_limit_per_doc = 15
word_limit_per_sentence = 20
# only makes sense when model_name != 'han'
word_limit = 200

Then, run:

python classify.py
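The pad limits in classify.py correspond to two input layouts: non-HAN models read a single sequence of up to word_limit words, while HAN reads a document as up to sentence_limit_per_doc sentences of up to word_limit_per_sentence words each. Below is a minimal sketch of the hierarchical case, assuming already tokenized sentences and padding/unknown ids 0 and 1; `encode_hierarchical` is a hypothetical helper, not the repository's implementation.

```python
from typing import Dict, List, Tuple

PAD_ID, UNK_ID = 0, 1  # assumed ids for <pad> and <unk>


def encode_hierarchical(
    sentences: List[List[str]],
    word2ix: Dict[str, int],
    sentence_limit_per_doc: int = 15,
    word_limit_per_sentence: int = 20,
) -> Tuple[List[List[int]], int, List[int]]:
    """Turn a document (list of tokenized sentences) into a fixed
    sentence_limit_per_doc x word_limit_per_sentence grid of word ids.

    Returns the grid, the number of real sentences, and the real length of
    each row (0 for padding rows). Empty sentences are dropped first."""
    sentences = [s[:word_limit_per_sentence] for s in sentences if s]
    sentences = sentences[:sentence_limit_per_doc]
    grid, lengths = [], []
    for sent in sentences:
        ids = [word2ix.get(w, UNK_ID) for w in sent]
        lengths.append(len(ids))
        grid.append(ids + [PAD_ID] * (word_limit_per_sentence - len(ids)))
    n_sentences = len(grid)
    # Fill the remaining rows with all-padding sentences.
    for _ in range(sentence_limit_per_doc - n_sentences):
        grid.append([PAD_ID] * word_limit_per_sentence)
        lengths.append(0)
    return grid, n_sentences, lengths


word2ix = {"<pad>": 0, "<unk>": 1, "great": 2, "movie": 3, "loved": 4, "it": 5}
doc = [["great", "movie"], ["loved", "it", "a", "lot"]]
grid, n, lens = encode_hierarchical(doc, word2ix, 3, 3)
print(grid, n, lens)  # [[2, 3, 0], [4, 5, 1], [0, 0, 0]] 2 [2, 3, 0]
```

With the default limits above, every document becomes a 15 x 20 grid, so batches have a fixed shape; the per-sentence lengths let the model ignore padding.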

 

Performance

Here I report the test accuracy (%) of each model on several datasets, with the training time per epoch (on an RTX 2080 Ti) in parentheses. Model parameters were not carefully tuned, so some parameter tuning may yield better performance.

Model                           AG News       DBpedia       Yahoo Answers
Hierarchical Attention Network  92.7 (45s)    98.2 (70s)    74.5 (2.7m)
fastText                        91.6 (8s)     97.9 (25s)    66.7 (41s)
Bi-LSTM + Attention             92.0 (50s)    99.0 (105s)   73.5 (3.4m)
TextCNN                         92.2 (24s)    98.5 (100s)   72.8 (4m)
Transformer                     92.2 (60s)    98.6 (8.2m)   72.5 (14.5m)

 

Notes

  • The load_embeddings method (in utils/embedding.py) tries to create a cache of the loaded embeddings under the dataset_output_path folder, which dramatically speeds up loading the next time.
  • Only the encoder part of the Transformer is used.
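The caching behaviour described for load_embeddings follows a common "build once, then reload" pattern. The sketch below shows that pattern with pickle, assuming a hypothetical cache file name and a `build` callback standing in for the slow embedding-file parse; it is not the code in utils/embedding.py.

```python
import os
import pickle
import tempfile
from typing import Callable, Dict


def load_with_cache(cache_path: str, build: Callable[[], Dict]) -> Dict:
    """Return the pickled object at cache_path if it exists; otherwise call
    build() (e.g. parse a large embedding file), save the result to
    cache_path, and return it."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    result = build()
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    with open(cache_path, "wb") as f:
        pickle.dump(result, f)
    return result


calls = []


def slow_build() -> Dict:
    calls.append(1)  # count how often the expensive step actually runs
    return {"good": [0.1, 0.2, 0.3]}


with tempfile.TemporaryDirectory() as out_dir:
    path = os.path.join(out_dir, "embeddings_cache.pkl")  # hypothetical name
    first = load_with_cache(path, slow_build)
    second = load_with_cache(path, slow_build)  # served from the cache
print(first == second, len(calls))  # True 1
```

A cache like this does not notice when the embedding file or the vocabulary changes, so it should be deleted whenever either one does.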

 

License

MIT

 

Acknowledgement

This project is based on sgrvinod/a-PyTorch-Tutorial-to-Text-Classification.
