  • Stars: 170
  • Rank: 223,357 (Top 5 %)
  • Language: Python
  • License: MIT License
  • Created over 4 years ago
  • Updated about 2 years ago

Repository Details

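1. Description

This is the first-place solution to the Plant Pathology 2020 - CVPR-FGVC7 Competition hosted on Kaggle. The solution comes from yelanlan and was reimplemented by nick. It differs from the version submitted during the competition in two ways: (1) how the training procedure selects the final model files, and (2) the ensemble strategy. As a result, the scores produced here differ slightly from the competition scores, but they are still high enough to take first place.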

2. How to Reproduce

2.1 Environment and Data Preparation

Install the dependencies:

pip install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple

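Download the original competition dataset and extract it into the data folder.

2.2 Running the Code

Step-1: Train the model with five-fold cross-validation.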

python train.py --train_batch_size 32 --gpus 0 1

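Step-2: Generate the soft labels used later for self-distillation training. This step also produces a file named "submission.csv"; submitting it achieves fourth place, with scores of public: 0.97988 | private: 0.98108.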

python generate_soft_labels.py

Step-3: Train the model again with five-fold cross-validation, this time using both the soft labels and the hard labels.

python train.py --train_batch_size 32 --gpus 0 1 --soft_labels_filename soft_labels.csv --log_dir logs_submit_distill

Step-4: Generate predictions from the self-distilled models. This step produces a file named "submission_distill.csv"; submitting it achieves third place, with scores of public: 0.98422 | private: 0.98135.

python generate_distlled_submission.py

Step-5: Blend the results. Submitting the blended file achieves first place, with scores of public: 0.98354 | private: 0.98253.

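import pandas as pd

# Load the Step-2 submission and the Step-4 self-distilled submission.
s1 = pd.read_csv("submission.csv")
s2 = pd.read_csv("submission_distill.csv")
# Average the prediction columns; the first column (the image id) is left unchanged.
s3 = s1.copy()
s3.iloc[:, 1:] = (s1.iloc[:, 1:] + s2.iloc[:, 1:]) / 2
s3.to_csv("submission_mix.csv", index=False)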

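3. Technical Approach

3.1 Background

Dataset

The competition dataset contains 1,821 training images and 1,821 test images. Each image carries one of four labels (healthy, multiple diseases, rust, scab), in a ratio of roughly 6:1:6:6, so the classes are imbalanced, and some of the labels in the dataset are inaccurate. To cope with the small amount of data and the inaccurate labels, we used data augmentation and knowledge distillation.

Evaluation Metric

The competition uses mean column-wise ROC AUC to measure model performance. The metric is the average of the ROC AUC values computed separately for each label column.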
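To make the metric concrete, here is a minimal sketch in plain Python. It computes each column's ROC AUC with the rank-sum (Mann-Whitney U) formula and then averages over columns. The function names and the toy data are illustrative assumptions, not code from the repository; in practice a library routine such as scikit-learn's roc_auc_score would normally be used per column.

```python
def roc_auc(y_true, y_score):
    """ROC AUC for one binary column via the rank-sum (Mann-Whitney U) formula."""
    pairs = sorted(zip(y_score, y_true))
    n = len(pairs)
    ranks = [0.0] * n
    i = 0
    while i < n:
        j = i
        # Tied scores share the average of the ranks they span.
        while j + 1 < n and pairs[j + 1][0] == pairs[i][0]:
            j += 1
        avg_rank = (i + j) / 2 + 1  # ranks are 1-based
        for k in range(i, j + 1):
            ranks[k] = avg_rank
        i = j + 1
    n_pos = sum(1 for _, t in pairs if t == 1)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("ROC AUC needs both positive and negative samples")
    pos_rank_sum = sum(r for r, (_, t) in zip(ranks, pairs) if t == 1)
    return (pos_rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def mean_columnwise_roc_auc(y_true, y_pred):
    """Average the per-column ROC AUC over all label columns."""
    n_cols = len(y_true[0])
    aucs = [
        roc_auc([row[c] for row in y_true], [row[c] for row in y_pred])
        for c in range(n_cols)
    ]
    return sum(aucs) / n_cols


# Toy data: 4 samples, 2 label columns (values are made up for illustration).
labels = [[1, 0], [0, 1], [1, 0], [0, 1]]
preds = [[0.9, 0.2], [0.1, 0.8], [0.4, 0.3], [0.6, 0.7]]
score = mean_columnwise_roc_auc(labels, preds)
```

Because each column is scored independently, a model can rank one class perfectly and still lose points on another, which is why the rare class matters as much as the common ones.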

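3.2 Data Preprocessing

Data Augmentation

Because the competition dataset is relatively small, training directly on the raw data risks overfitting. To make the model more robust, we applied a series of augmentations to expand the original dataset:

  • Random brightness
  • Random contrast
  • Vertical flip
  • Horizontal flip
  • Random rotation and scaling

We also applied augmentations that are hard to notice by eye, such as Gaussian blur. Together these operations greatly enrich the training set and let the model learn a wider range of features, which improves its generalization.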

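import cv2
from albumentations import (
    Compose,
    Resize,
    OneOf,
    RandomBrightness,
    RandomContrast,
    MotionBlur,
    MedianBlur,
    GaussianBlur,
    VerticalFlip,
    HorizontalFlip,
    ShiftScaleRotate,
    Normalize,
)

# Note: newer albumentations releases deprecate RandomBrightness and RandomContrast
# in favor of RandomBrightnessContrast.
# image_size is a (height, width) pair defined elsewhere in the training code.
train_transform = Compose(
    [
        Resize(height=image_size[0], width=image_size[1]),
        # Apply either a brightness or a contrast change.
        OneOf([RandomBrightness(limit=0.1, p=1), RandomContrast(limit=0.1, p=1)]),
        # Half of the time, apply one of three mild blurs.
        OneOf([MotionBlur(blur_limit=3), MedianBlur(blur_limit=3), GaussianBlur(blur_limit=3)], p=0.5),
        VerticalFlip(p=0.5),
        HorizontalFlip(p=0.5),
        # Always shift, scale, and rotate, filling borders by reflection.
        ShiftScaleRotate(
            shift_limit=0.2,
            scale_limit=0.2,
            rotate_limit=20,
            interpolation=cv2.INTER_LINEAR,
            border_mode=cv2.BORDER_REFLECT_101,
            p=1,
        ),
        # Normalize with the ImageNet channel means and standard deviations.
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
    ]
)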

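3.3 Model Selection

For the model we used the seresnext50 architecture. The "se" prefix stands for squeeze-and-excitation: the block learns a scale for each channel, amplifying important features and suppressing unimportant ones, on the same principle as attention. The goal is to make the extracted features more discriminative, so that the network can better recognize the subtle details that matter in FGVC (Fine-Grained Visual Categorization) tasks.

[Figure: se_resnext50_32x4d]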
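The squeeze-and-excitation idea can be sketched in a few lines. The version below is written in plain Python on nested lists so that it runs without a deep learning framework; it omits biases, and the function name, weights, and toy input are hand-picked assumptions for illustration rather than anything taken from the repository or a trained network.

```python
import math


def se_block(feature_map, w_reduce, w_expand):
    """Apply squeeze-and-excitation to a feature map shaped [C][H][W] (nested lists)."""
    # Squeeze: global average pooling gives one descriptor per channel.
    squeezed = [
        sum(sum(row) for row in channel) / (len(channel) * len(channel[0]))
        for channel in feature_map
    ]
    # Excitation, step 1: bottleneck dense layer followed by ReLU.
    hidden = [
        max(0.0, sum(w * s for w, s in zip(weights, squeezed)))
        for weights in w_reduce
    ]
    # Excitation, step 2: dense layer back to C channels, then sigmoid gates in (0, 1).
    gates = [
        1.0 / (1.0 + math.exp(-sum(w * h for w, h in zip(weights, hidden))))
        for weights in w_expand
    ]
    # Scale: multiply every value in a channel by that channel's gate.
    scaled = [
        [[value * gate for value in row] for row in channel]
        for channel, gate in zip(feature_map, gates)
    ]
    return scaled, gates


# Toy input: 2 channels of 2x2 activations, with hand-picked (not learned) weights.
features = [[[1.0, 1.0], [1.0, 1.0]], [[2.0, 2.0], [2.0, 2.0]]]
w_reduce = [[1.0, 0.0]]      # C=2 -> 1 hidden unit
w_expand = [[2.0], [-2.0]]   # 1 hidden unit -> C=2 gates
scaled, gates = se_block(features, w_reduce, w_expand)
```

With these weights the first channel receives a gate above 0.5 and the second a gate below 0.5, so the first is kept mostly intact while the second is damped. In a real network the two weight matrices are learned along with the rest of the model.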

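3.4 Training Strategy

We trained the model with Adam plus a cyclical learning rate. This strategy usually does not overfit much and does not require careful hyperparameter tuning, so we strongly recommend trying it.

[Figure: cycle_learning_rate]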
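As an illustration of what a cyclical schedule looks like, the sketch below computes a triangular schedule, one common form in which the learning rate rises linearly from a base value to a maximum and then falls back. The text does not say which variant or which values the solution used, so the function name, the shape, and the numbers here are all assumptions.

```python
def triangular_cyclic_lr(step, base_lr, max_lr, half_cycle_steps):
    """Learning rate at `step` for a triangular cyclical schedule."""
    cycle_len = 2 * half_cycle_steps
    pos = step % cycle_len
    # Fraction of the way from base_lr to max_lr: rises 0 -> 1, then falls 1 -> 0.
    if pos <= half_cycle_steps:
        frac = pos / half_cycle_steps
    else:
        frac = (cycle_len - pos) / half_cycle_steps
    return base_lr + (max_lr - base_lr) * frac


# One full cycle of 8 steps plus the first step of the next cycle.
schedule = [
    triangular_cyclic_lr(s, base_lr=1e-4, max_lr=1e-3, half_cycle_steps=4)
    for s in range(9)
]
```

In PyTorch, torch.optim.lr_scheduler.CyclicLR and torch.optim.lr_scheduler.OneCycleLR provide schedules of this family, so a hand-written function like this is rarely needed in training code.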

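3.5 Error Analysis

Error analysis is an essential part of iterating on a deep learning model. Once a model has been trained, knowing how to improve it is the key to a higher score. We used heatmaps to highlight the regions the model relies on for each label, which makes it clear what the model looked at when it assigned an image to a class. Analyzing the misclassified images in this way shows where the data augmentation and the network training can be improved.

[Figures: error_analysis_1, error_analysis_2, error_analysis_3]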

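3.6 Self-Distillation

Some diseases are hard to tell apart, so some of the labels are inaccurate. This makes training harder, because the model can easily be misled by the inaccurate labels. To counter this we used self-distillation: we trained five-fold models, combined the validation-set predictions of the five folds into an out-of-fold file, and then mixed the out-of-fold predictions with the ground-truth labels at a ratio of 3:7 to form the labels for training a new model. Put simply, this softens the original labels by giving each class a probability, which makes training easier.

[Figure: self_distill]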
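The 3:7 mixing step can be sketched as follows. This is a minimal illustration in plain Python, assuming one-hot ground-truth rows and per-class out-of-fold probabilities; the function name and the sample values are hypothetical, and the repository's own script may organize the data differently.

```python
def mix_labels(hard_labels, oof_predictions, soft_weight=0.3):
    """Blend one-hot labels with out-of-fold predictions, row by row."""
    hard_weight = 1.0 - soft_weight
    return [
        [hard_weight * h + soft_weight * p for h, p in zip(hard_row, oof_row)]
        for hard_row, oof_row in zip(hard_labels, oof_predictions)
    ]


# Columns follow the competition order: healthy, multiple_diseases, rust, scab.
hard = [[0, 0, 1, 0]]
oof = [[0.1, 0.2, 0.6, 0.1]]  # made-up out-of-fold probabilities
mixed = mix_labels(hard, oof)
```

The labeled class keeps most of the mass (0.88 in this example), while the remaining classes receive small nonzero targets, so a single wrong label pulls the model less strongly than a hard one-hot target would.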

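3.7 Model Prediction

For the final submissions we used TTA (Test Time Augmentation): each test sample was augmented in several ways and the predictions on the augmented copies were averaged, which gave our results a further boost.

[Figure: tta]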
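A minimal sketch of the TTA idea, assuming equal-weight averaging over the original image and two flipped copies. The text does not list which augmentations or weights were used at test time, so the transforms, the helper names, and the stand-in model below are illustrative assumptions.

```python
def hflip(image):
    """Mirror a 2D image (list of rows) left to right."""
    return [row[::-1] for row in image]


def vflip(image):
    """Mirror a 2D image top to bottom."""
    return image[::-1]


def predict_with_tta(predict_fn, image, transforms):
    """Average predict_fn's outputs over the original image and each transformed copy."""
    views = [image] + [t(image) for t in transforms]
    outputs = [predict_fn(view) for view in views]
    n_classes = len(outputs[0])
    return [sum(out[c] for out in outputs) / len(outputs) for c in range(n_classes)]


def toy_model(image):
    """Hypothetical stand-in for a trained network: scores depend on the top-left pixel."""
    p = image[0][0]
    return [p, 1.0 - p]


image = [[0.2, 0.4], [0.6, 0.8]]
averaged = predict_with_tta(toy_model, image, [hflip, vflip])
```

The stand-in model is deliberately sensitive to orientation, so the three views produce different scores and the average smooths them out, which is the effect TTA aims for with a real network.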
