Scaling Relationship on Learning Mathematical Reasoning with Large Language Models

This repository contains the code and data used to reproduce the results of "Scaling Relationship on Learning Mathematical Reasoning with Large Language Models" and "Query and Response Augmentation Cannot Help Out-of-domain Math Reasoning Generalization".

  • [2023.10] We have a new paper that investigates the scaling of in-domain and out-of-domain generalization on augmented math problems.
  • [2023.9] Paper updated with more details on 65B and 70B models.

Setting     7B         7B-2       13B        13B-2      33B        65B        70B-2
ICL-8shot   11.0/18.1  14.6/-     17.8/29.3  28.7/-     35.6/53.1  50.9/69.7  56.8/-
SFT         35.9/48.7  41.6/55.4  43.0/55.2  50.0/61.7  54.6/-     59.3/-     63.2/-
RFT k=100   41.7/52.7  47.5/58.7  49.1/59.9  54.8/65.4  54.5/-     -          -
RFT-U13B    49.3/61.8  50.3/65.6  52.1/66.2  55.4/69.1  56.5/-     59.0/-     62.3/-
RFT-U33B    49.1/61.6  51.2/64.1  51.4/66.3  55.3/69.1  57.9/-     59.7/-     64.8/-

Each cell reports maj1@1 / maj1@100. A dash marks a result that is not reported.

Findings from the paper

[Figure: fig1] [Figure: fig2]

SFT Training

If you cannot reproduce our results, try Transformers <= 4.29 and test with a batch size of 1.

Use train_xb.sh for fine-tuning LLaMA and LLaMA-2.

bash train_xb.sh ./data/train_use.jsonl SAVE_PATH 3

RFT Inference

LLaMA 7B / 13B

bash group_sample_7b_13b.sh SAVE_PATH

LLaMA 30B (the 33B model)

bash group_sample_30b.sh SAVE_PATH

Filter reasoning paths

python collect_rejection_sampling.py
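
The filtering step can be illustrated with a minimal sketch. This is not the code in collect_rejection_sampling.py; it assumes GSM8K-style reasoning paths in which calculations appear as `<<48/2=24>>` annotations and the final answer follows a `####` marker. The function names (`final_answer`, `equations`, `filter_paths`) are hypothetical. The idea is to keep only sampled paths whose final answer matches the gold answer, and to keep one path per distinct list of equations.

```python
import re


def final_answer(path):
    """Return the number after the '####' marker, or None if it is missing."""
    match = re.search(r"####\s*(-?[\d,]*\.?\d+)", path)
    if match is None:
        return None
    return float(match.group(1).replace(",", ""))


def equations(path):
    """Return the '<<...>>' calculator annotations, with spaces removed."""
    return tuple(eq.replace(" ", "") for eq in re.findall(r"<<(.*?)>>", path))


def filter_paths(samples, gold):
    """Keep correct paths, one per distinct equation list (assumed criterion)."""
    kept = {}
    for path in samples:
        answer = final_answer(path)
        if answer is None or abs(answer - gold) > 1e-6:
            continue  # reject paths with a missing or wrong final answer
        kept.setdefault(equations(path), path)  # first path per equation list
    return list(kept.values())


samples = [
    "Half is 48/2 = <<48/2=24>>24. Total 48+24 = <<48+24=72>>72.\n#### 72",
    "She sold <<48/2=24>>24 in May, so <<48+24=72>>72 in total.\n#### 72",
    "48*2 = <<48*2=96>>96.\n#### 96",
    "May: <<48/2=24>>24. Sum: <<24+48=72>>72.\n#### 72",
]
unique_correct = filter_paths(samples, gold=72.0)
```

Here the second sample is dropped as a duplicate of the first (same equation list) and the third is rejected for a wrong answer, so two paths remain.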

RFT Training

To run RFT on samples generated by LLaMA-7B/7B-2/13B/13B-2/33B with k=100:

bash train_xb.sh ./data/rft/llama_yb.jsonl SAVE_PATH 3

To run RFT on the U13B data:

bash train_xb.sh ./data/rft/u13b.jsonl SAVE_PATH 3

To run RFT on the U33B data:

bash train_xb.sh ./data/rft/u33b.jsonl SAVE_PATH 3

Evaluation

We use greedy decoding on the test set.

To evaluate 7B/13B models:

bash test_7b_13b.sh SAVE_PATH

To evaluate 30B models:

bash single_test_30b.sh SAVE_PATH 0 ./data/test_jsonl.sh

To evaluate 65B / 70B models:

bash single_test_65b.sh SAVE_PATH 0,1 ./data/test_jsonl.sh

Use eval.py to compute the scores; it also supports maj1@K.
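
For readers unfamiliar with the metric, maj1@K can be sketched as follows. This is an illustration, not the logic of eval.py: it assumes each question comes with a list of already-extracted final answers (one per sampled reasoning path, `None` when extraction failed), and the function names are hypothetical. The majority answer among the first K samples is compared with the gold answer.

```python
from collections import Counter


def majority_answer(answers):
    """Return the most frequent non-None answer, or None if there is none."""
    votes = Counter(a for a in answers if a is not None)
    if not votes:
        return None
    # Ties resolve to the answer that appeared first among the samples.
    return votes.most_common(1)[0][0]


def maj1_at_k(sampled_answers, gold_answers, k):
    """Fraction of questions whose majority answer over k samples is correct."""
    correct = 0
    for answers, gold in zip(sampled_answers, gold_answers):
        if majority_answer(answers[:k]) == gold:
            correct += 1
    return correct / len(gold_answers)


# Three questions with three sampled answers each (toy data).
sampled = [[72, 72, 70], [10, 12, 12], [5, None, 6]]
gold = [72, 12, 6]
score_k1 = maj1_at_k(sampled, gold, k=1)
score_k3 = maj1_at_k(sampled, gold, k=3)
```

With k=1 only the first sample of each question is used; with k=3 the second question is recovered by voting, while the third remains wrong because its tie resolves to the first-seen answer.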

GPU Usage

Number of GPUs used:

                   7B / 13B  33B  65B / 70B
SFT / RFT          8         16   32
Minimal Inference  1         1    2
Group Inference    8         8    8

Checkpoints

  • RFT k = 100
    7B: OFA-Sys/gsm8k-rft-llama7b-sample100
  • RFT U13B
    7B: OFA-Sys/gsm8k-rft-llama7b-u13b
    7B2: OFA-Sys/gsm8k-rft-llama7b2-u13b
    13B: OFA-Sys/gsm8k-rft-llama13b-u13b
    13B2: OFA-Sys/gsm8k-rft-llama13b2-u13b
  • RFT U33B
    33B: OFA-Sys/gsm8k-rft-llama33b-u33b

Citation

@misc{yuan2023scaling,
      title={Scaling Relationship on Learning Mathematical Reasoning with Large Language Models}, 
      author={Zheng Yuan and Hongyi Yuan and Chengpeng Li and Guanting Dong and Keming Lu and Chuanqi Tan and Chang Zhou and Jingren Zhou},
      year={2023},
      eprint={2308.01825},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
@misc{li2023query,
      title={Query and Response Augmentation Cannot Help Out-of-domain Math Reasoning Generalization}, 
      author={Chengpeng Li and Zheng Yuan and Guanting Dong and Keming Lu and Jiancan Wu and Chuanqi Tan and Xiang Wang and Chang Zhou},
      year={2023},
      eprint={2310.05506},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}