Skip to content

Commit 7402b0e

Browse files
committed
yoco init
1 parent 50c5700 commit 7402b0e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+4513
-4
lines changed

YOCO/README.md

+166-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,168 @@
1-
# YOCO
1+
# You Only Cache Once: Decoder-Decoder Architectures for Large Language Models
22

3-
- May 2024: Code release
4-
- May 2024: release preprint [YOCO](https://arxiv.org/abs/)
3+
## Approach
4+
<div align="center">
5+
<img src="./imgs/arch.png" width=60%/>
6+
</div>
57

6-
## Getting Started
8+
<div align="center">
9+
<img src="./imgs/inference.png" width=50%/>
10+
</div>
11+
12+
## Performance
13+
### Harness Eval
14+
Training with 1T Tokens:
15+
| **Model** | **Arc-c** | **Arc-e** | **BoolQ** | **Hellaswag**$^*$ | **OBQA** | **PIQA** | **Winogrande** | **SciQ** | **Avg** |
16+
|----------------------------|-----------|-----------|-----------|-------------------|----------|----------|----------------|----------|---------|
17+
| OpenLLaMA-3B-v2 | 0.339 | 0.676 | 0.657 | **0.700** | 0.260 | 0.767 | 0.629 | 0.924 | 0.619 |
18+
| StableLM-base-alpha-3B-v2 | 0.324 | 0.673 | 0.646 | 0.686 | 0.264 | 0.760 | 0.621 | 0.921 | 0.612 |
19+
| StableLM-3B-4E1T | --- | 0.666 | --- | --- | --- | **0.768**| 0.632 | 0.914 | --- |
20+
| YOCO-3B | **0.379** | **0.731** | 0.645 | 0.689 | **0.298**| 0.763 | 0.639 | 0.924 | **0.634**|
21+
22+
Training with 1.6T Tokens:
23+
| **Model** | **Arc-c** | **Arc-e** | **BoolQ** | **Hellaswag**$^*$ | **OBQA** | **PIQA** | **Winogrande** | **SciQ** | **Avg** |
24+
|----------------------------|-----------|-----------|-----------|-------------------|----------|----------|----------------|----------|---------|
25+
| StableLM-3B-4E1T | --- | 0.688 | --- | --- | --- | 0.762 | 0.627 | 0.913 | --- |
26+
| YOCO-3B | 0.396 | 0.733 | **0.644** | 0.698 | 0.300 | 0.764 | 0.631 | 0.921 | 0.636 |
27+
| YOCO-3B-1M | **0.413** | **0.747** | 0.638 | **0.705** | 0.300 | **0.773**| **0.651** | **0.932**| **0.645**|
28+
### Needle In A Haystack
29+
<div align="center">
30+
<img src="./imgs/1m_retrieval.png"/>
31+
</div>
32+
33+
### Multi-Needle Eval
34+
| **Model** | **Size** | **N=1** | **N=2** | **N=4** | **N=8** |
35+
|-------------------------|----------|---------|---------|---------|---------|
36+
| GPT-4-128K | -- | 1.00 | 1.00 | 0.98 | 1.00 |
37+
| MiniCPM-128K | 2.4B | 1.00 | 1.00 | 0.54 | 0.56 |
38+
| ChatGLM3-128K | 6B | 0.94 | 0.72 | 0.52 | 0.44 |
39+
| YaRN-Mistral-128K | 7B | 0.02 | 0.12 | 0.08 | 0.20 |
40+
| LWM-1M-text | 7B | 1.00 | 0.90 | 0.76 | 0.62 |
41+
| YOCO-3B-1M | 3B | 0.98 | 0.98 | 0.84 | 0.56 |
42+
43+
## Setup
44+
45+
To install the required packages, use the following command:
46+
47+
```bash
48+
pip install -r requirements.txt
49+
```
50+
51+
Besides normal packages, [Apex](https://github.com/NVIDIA/apex) and [Flash-Attention](https://github.com/Dao-AILab/flash-attention) should be installed seperately following their offcial guidences.
52+
53+
## Harness Eval
54+
55+
To evaluate models in Harness-Eval, the script is as follows in ```scripts/eval_task.sh```:
56+
```bash
57+
cd fairseq/
58+
TASK='harness_boolq'
59+
60+
torchrun --master-port=29505 --nproc_per_node=1 validate.py \
61+
--data-dir ../harness_data/ \
62+
--criterion harness_eval \
63+
--task harness_eval \
64+
--batch-size 4 \
65+
--eval-data ${TASK} \
66+
--log-format simple --log-interval 10 \
67+
--bf16 \
68+
--tokenizer-pad-to-multiple 8 \
69+
--arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 4096
70+
```
71+
72+
## Needle In A Haystack Evaluation
73+
Our model uses city-number pairs for long sequence evaluation. To get the results at a certain maximal length, the script is as follows in ```scripts/eval_needle.sh```:
74+
```bash
75+
cd fairseq/
76+
torchrun --master-port=29504 --nproc_per_node=1 validate.py \
77+
--task pseudo \
78+
--criterion needle_haystack \
79+
--batch-size 1 \
80+
--max-epoch 1 \
81+
--no-save \
82+
--tiktoken-model cl100k_base \
83+
--bf16 \
84+
--arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /path_to_ckpt/YOCO-3B-1M/checkpoint.pth --yoco-model /path_to_ckpt/YOCO-3B-1M --tokens-per-sample 1048576 --interval 1048576
85+
```
86+
87+
To run Multi-Needle experiments, replace ```--criterion needle_haystack``` with ```--criterion multi_needle --needle-num {num}```.
88+
89+
## Pretraining From Scratch
90+
To support distributed training, our implementation is based on infinibatch to read data iteratively. The overall data directory should be organized as follows:
91+
```
92+
Data/
93+
├── json/
94+
│ ├── train.json
95+
│ └── CC.json
96+
│ └── StarCoder.json
97+
│ └── ...
98+
├── shard/
99+
│ ├── CC/
100+
│ │ ├── 00000.jsonl
101+
│ │ ├── 00001.jsonl
102+
│ │ └── ...
103+
│ └── StarCoder/
104+
│ ├── 00000.jsonl
105+
│ ├── 00001.jsonl
106+
│ └── ...
107+
```
108+
109+
We recommend that each sharded data files contains no more than 10K lines with one json dict per line, and jsonl file, such as ```Data/shard/CC/00000.jsonl```, should be in the format like this:
110+
```json
111+
{"text": "File 1 is here..."}
112+
{"text": "File 2 is here..."}
113+
...
114+
```
115+
116+
Then, for each source, a JSON file preserves all the paths of the jsonl files. Take ```Data/json/CC.json``` for example:
117+
```json
118+
[
119+
"/path_to_data/Data/shard/CC/00000.jsonl",
120+
"/path_to_data/Data/shard/CC/00001.jsonl",
121+
...
122+
]
123+
```
124+
125+
Finally, ```train.json``` records all sources' information and sampling ratio:
126+
```json
127+
[
128+
{
129+
"name": "CC",
130+
"weight": 0.5
131+
},
132+
{
133+
"name": "StarCoder",
134+
"weight": 0.2
135+
},
136+
...
137+
]
138+
```
139+
140+
```scripts/train.sh```:
141+
```bash
142+
cd fairseq/
143+
torchrun --nproc-per-node=1 train.py /path_to_data \
144+
--save-interval-updates 5000 \
145+
--no-epoch-checkpoints \
146+
--arch yoco_base \
147+
--criterion cross_entropy \
148+
--task gpt \
149+
--tokens-per-sample 2048 \
150+
--tokenizer-pad-to-multiple 8 \
151+
--pad-to-max-len \
152+
--optimizer adam --adam-betas "(0.9, 0.95)" \
153+
--adam-eps 1e-06 \
154+
--clip-norm 2.0 \
155+
--lr 0.00015 \
156+
--lr-scheduler polynomial_decay \
157+
--warmup-updates 50 \
158+
--weight-decay 0.05 \
159+
--batch-size 1 \
160+
--model-parallel-size 1 \
161+
--update-freq 1 \
162+
--batch-read-ahead 1000 \
163+
--total-num-update 300000 \
164+
--log-format simple --log-interval 10 --disable-validation \
165+
--tiktoken-model cl100k_base \
166+
--save-interval-updates 5000 \
167+
--bf16 # bf16 is encouraged in pre-training
168+
```

YOCO/imgs/1m_retrieval.png

42.1 KB
Loading

YOCO/imgs/arch.png

53.5 KB
Loading

YOCO/imgs/inference.png

98.3 KB
Loading

YOCO/requirements.txt

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
torch>=2.2.0
2+
triton>=2.2.0
3+
numpy==1.23.0
4+
fairscale
5+
tiktoken
6+
sentencepiece
7+
ninja
8+
boto3
9+
iopath
10+
git+https://github.com/sunyt32/fairseq.git@moe3#egg=fairseq
11+
git+https://github.com/shumingma/infinibatch.git#egg=infinibatch
12+
git+https://github.com/microsoft/torchscale.git#egg=torchscale

YOCO/scripts/eval_needle.sh

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
cd yoco/
2+
torchrun --master-port=29504 --nproc_per_node=1 validate.py \
3+
--task pseudo \
4+
--criterion multi_needle --needle-num 4 \
5+
--batch-size 1 \
6+
--max-epoch 1 \
7+
--no-save \
8+
--tiktoken-model cl100k_base \
9+
--bf16 \
10+
--arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /data/yutao/ckpt_opensource/YOCO-3B-1M/checkpoint.pth --yoco-model /data/yutao/ckpt_opensource/YOCO-3B-1M --tokens-per-sample 1048576 --interval 1048576
11+

YOCO/scripts/eval_task.sh

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
TASK='harness_boolq'
2+
# TASK='hendrycksTest-abstract_algebra'
3+
4+
cd yoco/
5+
torchrun --master-port=29505 --nproc_per_node=1 validate.py \
6+
--data-dir ../harness_data/ \
7+
--criterion harness_eval \
8+
--task harness_eval \
9+
--batch-size 4 \
10+
--eval-data ${TASK} \
11+
--log-format simple --log-interval 10 \
12+
--bf16 \
13+
--tokenizer-pad-to-multiple 8 \
14+
--arch yoco_3b_new --tiktoken-model cl100k_base --load-ckpt /data/yutao/ckpt_opensource/YOCO-3B-1M/checkpoint.pth --yoco-model /data/yutao/ckpt_opensource/YOCO-3B-1M --tokens-per-sample 4096
15+
# --arch llama_from_ckpt --llama-model /data/yutao/llama/llama-2-7b --load-ckpt /data/yutao/llama/llama-2-7b/consolidated.00.pth --tokens-per-sample 4096
16+
17+

YOCO/scripts/train.sh

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
cd yoco/
2+
torchrun --master-port=29501 --nproc-per-node=1 train.py /mnt/nlcredstone/shaohanh/data/redstone_v4_21_config \
3+
--save-interval-updates 5000 \
4+
--no-epoch-checkpoints \
5+
--arch yoco_base \
6+
--criterion cross_entropy \
7+
--task gpt \
8+
--tokens-per-sample 2048 \
9+
--tokenizer-pad-to-multiple 8 \
10+
--pad-to-max-len \
11+
--optimizer adam --adam-betas "(0.9, 0.95)" \
12+
--adam-eps 1e-06 \
13+
--clip-norm 2.0 \
14+
--lr 0.00015 \
15+
--lr-scheduler polynomial_decay \
16+
--warmup-updates 50 \
17+
--weight-decay 0.05 \
18+
--batch-size 1 \
19+
--model-parallel-size 1 \
20+
--update-freq 1 \
21+
--batch-read-ahead 1000 \
22+
--total-num-update 300000 \
23+
--log-format simple --log-interval 10 --disable-validation \
24+
--tiktoken-model cl100k_base \
25+
--no-save \
26+
--bf16 \
27+

YOCO/yoco/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) 2022 Microsoft
2+
# Licensed under The MIT License [see LICENSE for details]

YOCO/yoco/criterions/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import importlib
2+
import os
3+
4+
# automatically import any Python files in the criterions/ directory
5+
for file in sorted(os.listdir(os.path.dirname(__file__))):
6+
if file.endswith(".py") and not file.startswith("_"):
7+
file_name = file[: file.find(".py")]
8+
importlib.import_module("criterions." + file_name)

YOCO/yoco/criterions/harness_eval.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import torch
2+
import torch.nn.functional as F
3+
4+
from fairseq import metrics
5+
from fairseq.criterions import FairseqCriterion, register_criterion
6+
from fairseq.dataclass import FairseqDataclass
7+
8+
9+
@register_criterion("harness_eval", dataclass=FairseqDataclass)
10+
class HarnessEvalCriterion(FairseqCriterion):
11+
def __init__(self, cfg, task):
12+
super().__init__(task)
13+
14+
def forward(self, model, sample, reduce=True):
15+
"""Compute the loss for the given sample.
16+
17+
Returns a tuple with three elements:
18+
1) the loss
19+
2) the sample size, which is used as the denominator for the gradient
20+
3) logging outputs to display while training
21+
"""
22+
model.eval()
23+
net_output, _ = model(sample["net_input"]["src_tokens"])
24+
net_output = net_output[:, :-1, :]
25+
targets = sample["net_input"]["src_tokens"][:, 1:]
26+
loss_mask = sample["net_input"]["gpt_loss_mask"][:, 1:]
27+
label_length = sample["net_input"]["label_length"]
28+
loss = F.cross_entropy(
29+
net_output.float().reshape(-1, net_output.size(-1)),
30+
targets.reshape(-1),
31+
reduction="none",
32+
ignore_index=self.padding_idx,
33+
).reshape(targets.size(0), -1)
34+
loss = loss * loss_mask.int()
35+
loss_norm = loss.sum(-1) / label_length.float()
36+
loss = loss.sum(-1)
37+
38+
option_num = self.task.harness_task.class_num
39+
labels = sample["targets"].view(-1)
40+
41+
assert sample["targets"].size(0) % option_num == 0
42+
sample_size = sample["ntokens"]
43+
44+
pred_label = torch.argmin(loss.view(-1, option_num), dim=1)
45+
pred_norm_label = torch.argmin(loss_norm.view(-1, option_num), dim=1)
46+
target_label = labels.view(-1, option_num)[:, 0]
47+
48+
logging_output = {}
49+
50+
logging_output.update(
51+
{
52+
"loss": 0,
53+
"nsentences": pred_label.size(0),
54+
"sample_size": pred_label.size(0),
55+
"ncorrect": (pred_label == target_label).sum().item(),
56+
"ncorrect_norm": (pred_norm_label == target_label).sum().item(),
57+
}
58+
)
59+
60+
return loss, sample_size, logging_output
61+
62+
@staticmethod
63+
def reduce_metrics(logging_outputs) -> None:
64+
"""Aggregate logging outputs from data parallel training."""
65+
loss = sum(log.get("loss", 0) for log in logging_outputs)
66+
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
67+
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
68+
ncorrect_norm = sum(log.get("ncorrect_norm", 0) for log in logging_outputs)
69+
metrics.log_scalar(
70+
"loss", loss / nsentences, nsentences, round=3
71+
)
72+
metrics.log_scalar(
73+
"accuracy", 100.0 * ncorrect / nsentences, nsentences, round=2
74+
)
75+
metrics.log_scalar(
76+
"accuracy_norm", 100.0 * ncorrect_norm / nsentences, nsentences, round=2
77+
)
78+
79+
@staticmethod
80+
def logging_outputs_can_be_summed() -> bool:
81+
"""
82+
Whether the logging outputs returned by `forward` can be summed
83+
across workers prior to calling `reduce_metrics`. Setting this
84+
to True will improves distributed training speed.
85+
"""
86+
return True

0 commit comments

Comments
 (0)