# You Only Cache Once: Decoder-Decoder Architectures for Large Language Models

## Approach
<div align="center">
  <img src="./imgs/arch.png" width=60%/>
</div>

<div align="center">
  <img src="./imgs/inference.png" width=50%/>
</div>
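
To make the decoder-decoder layout concrete, below is a minimal PyTorch sketch of the idea, not the released implementation: a self-decoder with an efficient attention (here, a simple sliding window) produces hidden states, a single global KV cache is projected from them, and every cross-decoder layer reuses that one cache through cross-attention. All module names, sizes, and the sliding-window choice are illustrative assumptions.

```python
# Minimal decoder-decoder sketch (illustrative only, NOT the released YOCO code).
# The self-decoder uses cheap sliding-window attention; its output is projected
# into ONE global KV cache that every cross-decoder layer reuses.
import torch
import torch.nn as nn


class SelfDecoderLayer(nn.Module):
    """Self-attention restricted to a sliding window (stand-in for efficient attention)."""

    def __init__(self, dim: int, n_heads: int, window: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.window = window

    def forward(self, x):
        t = torch.arange(x.size(1), device=x.device)
        # block future positions and positions further back than the window
        mask = (t[None, :] > t[:, None]) | (t[:, None] - t[None, :] >= self.window)
        h = self.norm1(x)
        h, _ = self.attn(h, h, h, attn_mask=mask)
        x = x + h
        return x + self.ffn(self.norm2(x))


class CrossDecoderLayer(nn.Module):
    """Attends to the single global KV cache produced by the self-decoder."""

    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x, kv, causal_mask):
        h = self.norm1(x)
        # queries come from x, keys/values come from the shared cache
        h, _ = self.attn(h, kv, kv, attn_mask=causal_mask)
        x = x + h
        return x + self.ffn(self.norm2(x))


class TinyYOCO(nn.Module):
    def __init__(self, vocab=1000, dim=64, n_heads=4, n_self=2, n_cross=2, window=32):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.self_decoder = nn.ModuleList([SelfDecoderLayer(dim, n_heads, window) for _ in range(n_self)])
        self.to_kv = nn.Linear(dim, dim)  # the states cached once and shared by all cross layers
        self.cross_decoder = nn.ModuleList([CrossDecoderLayer(dim, n_heads) for _ in range(n_cross)])
        self.lm_head = nn.Linear(dim, vocab)

    def forward(self, tokens):
        x = self.embed(tokens)
        for layer in self.self_decoder:
            x = layer(x)
        kv = self.to_kv(x)  # "you only cache once": one KV tensor for all cross-decoder layers
        t = tokens.size(1)
        causal = torch.triu(torch.ones(t, t, dtype=torch.bool, device=tokens.device), diagonal=1)
        y = x
        for layer in self.cross_decoder:
            y = layer(y, kv, causal)
        return self.lm_head(y)


if __name__ == "__main__":
    model = TinyYOCO()
    logits = model(torch.randint(0, 1000, (2, 48)))
    print(logits.shape)  # torch.Size([2, 48, 1000])
```

The practical effect is that key-value states are stored once and shared across the cross-decoder layers, instead of once per layer as in a standard decoder-only Transformer, which is where the inference-memory savings in the figures above come from.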

## Performance
### Harness Eval
Training with 1T Tokens:

| **Model** | **Arc-c** | **Arc-e** | **BoolQ** | **Hellaswag**$^*$ | **OBQA** | **PIQA** | **Winogrande** | **SciQ** | **Avg** |
|----------------------------|-----------|-----------|-----------|-------------------|----------|----------|----------------|----------|----------|
| OpenLLaMA-3B-v2            | 0.339     | 0.676     | 0.657     | **0.700**         | 0.260    | 0.767    | 0.629          | 0.924    | 0.619    |
| StableLM-base-alpha-3B-v2  | 0.324     | 0.673     | 0.646     | 0.686             | 0.264    | 0.760    | 0.621          | 0.921    | 0.612    |
| StableLM-3B-4E1T           | ---       | 0.666     | ---       | ---               | ---      | **0.768**| 0.632          | 0.914    | ---      |
| YOCO-3B                    | **0.379** | **0.731** | 0.645     | 0.689             | **0.298**| 0.763    | 0.639          | 0.924    | **0.634**|

Training with 1.6T Tokens:

| **Model** | **Arc-c** | **Arc-e** | **BoolQ** | **Hellaswag**$^*$ | **OBQA** | **PIQA** | **Winogrande** | **SciQ** | **Avg** |
|----------------------------|-----------|-----------|-----------|-------------------|----------|----------|----------------|----------|----------|
| StableLM-3B-4E1T           | ---       | 0.688     | ---       | ---               | ---      | 0.762    | 0.627          | 0.913    | ---      |
| YOCO-3B                    | 0.396     | 0.733     | **0.644** | 0.698             | 0.300    | 0.764    | 0.631          | 0.921    | 0.636    |
| YOCO-3B-1M                 | **0.413** | **0.747** | 0.638     | **0.705**         | 0.300    | **0.773**| **0.651**      | **0.932**| **0.645**|

### Needle In A Haystack
<div align="center">
  <img src="./imgs/1m_retrieval.png"/>
</div>

### Multi-Needle Eval
| **Model**               | **Size** | **N=1** | **N=2** | **N=4** | **N=8** |
|-------------------------|----------|---------|---------|---------|---------|
| GPT-4-128K              | --       | 1.00    | 1.00    | 0.98    | 1.00    |
| MiniCPM-128K            | 2.4B     | 1.00    | 1.00    | 0.54    | 0.56    |
| ChatGLM3-128K           | 6B       | 0.94    | 0.72    | 0.52    | 0.44    |
| YaRN-Mistral-128K       | 7B       | 0.02    | 0.12    | 0.08    | 0.20    |
| LWM-1M-text             | 7B       | 1.00    | 0.90    | 0.76    | 0.62    |
| YOCO-3B-1M              | 3B       | 0.98    | 0.98    | 0.84    | 0.56    |

## Setup

To install the required packages, use the following command:

```bash
pip install -r requirements.txt
```

Besides the packages above, [Apex](https://github.com/NVIDIA/apex) and [Flash-Attention](https://github.com/Dao-AILab/flash-attention) should be installed separately by following their official installation guides.
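
If you want to verify the environment before launching anything heavy, a quick check like the one below can help; the import names `apex` and `flash_attn` are the usual ones for those packages, so adjust if your builds differ:

```python
# Quick sanity check that the optional dependencies resolved correctly.
import importlib

for name in ("torch", "apex", "flash_attn"):
    try:
        module = importlib.import_module(name)
        print(f"{name}: OK (version {getattr(module, '__version__', 'unknown')})")
    except ImportError as err:
        print(f"{name}: MISSING ({err})")
```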

## Harness Eval

To evaluate models on the Harness-Eval tasks, use the script in ```scripts/eval_task.sh```:
```bash
cd fairseq/
TASK='harness_boolq'

torchrun --master-port=29505 --nproc_per_node=1 validate.py \
    --data-dir ../harness_data/ \
    --criterion harness_eval \
    --task harness_eval \
    --batch-size 4 \
    --eval-data ${TASK} \
    --log-format simple --log-interval 10 \
    --bf16 \
    --tokenizer-pad-to-multiple 8 \
    --arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 4096
```
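
To sweep several tasks with the same checkpoint, one option is a small driver that re-invokes the command above once per task. Only `harness_boolq` is taken from the script; the rest of `TASKS` and the checkpoint path are placeholders you would fill in:

```python
# Hypothetical driver that re-runs the Harness-Eval command above for a list of tasks.
import subprocess

TASKS = ["harness_boolq"]  # extend with the other harness task names available in your checkout
CKPT_DIR = "/path_to_ckpt/YOCO-3B-1M"  # placeholder path, as in the script above

for task in TASKS:
    cmd = [
        "torchrun", "--master-port=29505", "--nproc_per_node=1", "validate.py",
        "--data-dir", "../harness_data/",
        "--criterion", "harness_eval",
        "--task", "harness_eval",
        "--batch-size", "4",
        "--eval-data", task,
        "--log-format", "simple", "--log-interval", "10",
        "--bf16",
        "--tokenizer-pad-to-multiple", "8",
        "--arch", "yoco_3b_new",
        "--tiktoken-model", "cl100k_base",
        "--load-ckpt", f"{CKPT_DIR}/checkpoint.pth",
        "--yoco-model", CKPT_DIR,
        "--tokens-per-sample", "4096",
    ]
    subprocess.run(cmd, cwd="fairseq", check=True)
```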

## Needle In A Haystack Evaluation
The needle test uses city-number pairs for long-sequence evaluation. To get results at a given maximum length, use the script in ```scripts/eval_needle.sh```:
```bash
cd fairseq/
torchrun --master-port=29504 --nproc_per_node=1 validate.py \
    --task pseudo \
    --criterion needle_haystack \
    --batch-size 1 \
    --max-epoch 1 \
    --no-save \
    --bf16 \
    --arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 1048576 --interval 1048576
```

To run Multi-Needle experiments, replace ```--criterion needle_haystack``` with ```--criterion multi_needle --needle-num {num}```.

## Pretraining From Scratch
To support distributed training, our implementation reads data iteratively with infinibatch. The overall data directory should be organized as follows:
```
Data/
├── json/
│   ├── train.json
│   ├── CC.json
│   ├── StarCoder.json
│   └── ...
└── shard/
    ├── CC/
    │   ├── 00000.jsonl
    │   ├── 00001.jsonl
    │   └── ...
    └── StarCoder/
        ├── 00000.jsonl
        ├── 00001.jsonl
        └── ...
```

We recommend that each sharded data file contain no more than 10K lines, with one JSON dict per line. Each jsonl file, such as ```Data/shard/CC/00000.jsonl```, should look like this:
```json
{"text": "File 1 is here..."}
{"text": "File 2 is here..."}
...
```
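
As an illustration, a raw corpus stored with one document per line can be converted into shards of this shape with a short script like the one below; the input and output paths are placeholders, and the 10K-line cap mirrors the recommendation above:

```python
# Hypothetical helper: split a large "one document per line" text file into
# jsonl shards of at most 10K lines, matching the layout described above.
import json
import os

SOURCE_TXT = "/path_to_raw/CC.txt"       # placeholder: one document per line
OUT_DIR = "/path_to_data/Data/shard/CC"  # shard directory for this source
MAX_LINES = 10_000

os.makedirs(OUT_DIR, exist_ok=True)
shard, shard_idx, line_idx = None, 0, 0
with open(SOURCE_TXT, encoding="utf-8") as src:
    for line in src:
        # open a new shard file when starting out or when the current one is full
        if shard is None or line_idx == MAX_LINES:
            if shard is not None:
                shard.close()
            shard = open(os.path.join(OUT_DIR, f"{shard_idx:05d}.jsonl"), "w", encoding="utf-8")
            shard_idx, line_idx = shard_idx + 1, 0
        shard.write(json.dumps({"text": line.rstrip("\n")}) + "\n")
        line_idx += 1
if shard is not None:
    shard.close()
```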

Then, for each source, a JSON file lists the paths of all its jsonl files. Take ```Data/json/CC.json``` as an example:
```json
[
    "/path_to_data/Data/shard/CC/00000.jsonl",
    "/path_to_data/Data/shard/CC/00001.jsonl",
    ...
]
```

Finally, ```train.json``` records every source's name and sampling weight:
```json
[
    {
        "name": "CC",
        "weight": 0.5
    },
    {
        "name": "StarCoder",
        "weight": 0.2
    },
    ...
]
```
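
Rather than writing these index files by hand, they can be generated from the shard directory. The sketch below assumes the layout above, uses a placeholder data root, and assigns uniform sampling weights that you would normally tune per source:

```python
# Hypothetical helper: build Data/json/<source>.json for every shard directory
# and a train.json with uniform sampling weights.
import glob
import json
import os

DATA_ROOT = "/path_to_data/Data"  # placeholder root, as in the layout above

sources = sorted(
    d for d in os.listdir(os.path.join(DATA_ROOT, "shard"))
    if os.path.isdir(os.path.join(DATA_ROOT, "shard", d))
)
os.makedirs(os.path.join(DATA_ROOT, "json"), exist_ok=True)

# one <source>.json per shard directory, listing all of its jsonl paths
for name in sources:
    shards = sorted(glob.glob(os.path.join(DATA_ROOT, "shard", name, "*.jsonl")))
    with open(os.path.join(DATA_ROOT, "json", f"{name}.json"), "w") as f:
        json.dump(shards, f, indent=4)

# uniform weights as a starting point; adjust the sampling ratio per source
train = [{"name": name, "weight": round(1.0 / len(sources), 4)} for name in sources]
with open(os.path.join(DATA_ROOT, "json", "train.json"), "w") as f:
    json.dump(train, f, indent=4)
```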

The training script is in ```scripts/train.sh```:
```bash
cd fairseq/
torchrun --nproc-per-node=1 train.py /path_to_data \
    --save-interval-updates 5000 \
    --no-epoch-checkpoints \
    --arch yoco_base \
    --criterion cross_entropy \
    --task gpt \
    --tokens-per-sample 2048 \
    --tokenizer-pad-to-multiple 8 \
    --pad-to-max-len \
    --optimizer adam --adam-betas "(0.9, 0.95)" \
    --adam-eps 1e-06 \
    --clip-norm 2.0 \
    --lr 0.00015 \
    --lr-scheduler polynomial_decay \
    --warmup-updates 50 \
    --weight-decay 0.05 \
    --batch-size 1 \
    --model-parallel-size 1 \
    --update-freq 1 \
    --batch-read-ahead 1000 \
    --total-num-update 300000 \
    --log-format simple --log-interval 10 --disable-validation \
    --tiktoken-model cl100k_base \
    --bf16 # bf16 is encouraged in pre-training
```