
Commit 131169d
Merge branch 'main' into comet-ml-tracker-update
2 parents: d78720d + a702364

File tree: 102 files changed (+3661, -721 lines)


.github/workflows/build_docker_images.yml (+1, -1)

```diff
@@ -105,6 +105,6 @@ jobs:
       - name: Build and Push GPU
         uses: docker/build-push-action@v4
         with:
-          file: benchmarks/fp8/Dockerfile
+          file: benchmarks/fp8/transformer_engine/Dockerfile
           push: true
           tags: huggingface/accelerate:gpu-fp8-transformerengine-nightly-${{ env.date }}
```

.github/workflows/gaudi1.yml (new file, +77)

```yaml
name: Gaudi1 tests (scheduled)

on:
  workflow_dispatch:
  schedule:
    - cron: "0 2 * * *"

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  run_gaudi1_tests:
    name: Test on Gaudi1
    runs-on:
      group: aws-dl1-24xlarge

    container:
      image: docker://vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
      options: --runtime=habana --shm-size=64G --cap-add=sys_nice --env HABANA_VISIBLE_DEVICES=0,1
      env:
        OMPI_MCA_btl_vader_single_copy_mechanism: none
        PT_ENABLE_INT64_SUPPORT: 1
        PT_HPU_LAZY_MODE: 0
        RUN_SLOW: 1

    steps:
      - name: HL-SMI (1)
        run: |
          hl-smi
          echo "HABANA_VISIBLE_DEVICES=${HABANA_VISIBLE_DEVICES}"
          echo "HABANA_VISIBLE_MODULES=${HABANA_VISIBLE_MODULES}"

      - name: Extract HPU visible modules
        id: add-modules
        run: |
          export HABANA_VISIBLE_MODULES=$(hl-smi -Q module_id -f csv,noheader | tr '\n' ',' | sed 's/,$//')
          echo "HABANA_VISIBLE_MODULES=${HABANA_VISIBLE_MODULES}" >> $GITHUB_ENV

      - name: HL-SMI (2)
        run: |
          hl-smi
          echo "HABANA_VISIBLE_DEVICES=${HABANA_VISIBLE_DEVICES}"
          echo "HABANA_VISIBLE_MODULES=${HABANA_VISIBLE_MODULES}"

      - name: Checkout to Accelerate
        uses: actions/checkout@v4

      - name: Install Accelerate with Transformers & DeepSpeed
        run: |
          pip install -e .[testing] \
            git+https://github.com/HabanaAI/[email protected] \
            git+https://github.com/huggingface/transformers.git@hpu-support

      - name: Run CLI tests
        run: |
          make test_cli

      - name: Run Core tests
        run: |
          make test_core

      - name: Run Big Modeling tests
        run: |
          make test_big_modeling

      - name: Run FSDP integration tests
        run: |
          make test_fsdp

      - name: Run DeepSpeed integration tests
        run: |
          make test_deepspeed

      - name: Run Examples tests
        run: |
          make test_examples
```
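For context, the "Extract HPU visible modules" step above builds a comma-separated module list from `hl-smi` output and exports it for the later steps. Below is a minimal Python sketch of the same pipeline, assuming `hl-smi -Q module_id -f csv,noheader` prints one module ID per line; it is an illustration, not part of the commit:

```python
# Sketch of the "Extract HPU visible modules" step, assuming
# `hl-smi -Q module_id -f csv,noheader` emits one module ID per line.
import os
import subprocess

raw = subprocess.run(
    ["hl-smi", "-Q", "module_id", "-f", "csv,noheader"],
    capture_output=True, text=True, check=True,
).stdout
# Equivalent of `tr '\n' ',' | sed 's/,$//'`: join the IDs with commas,
# dropping blank lines and the trailing separator.
modules = ",".join(line.strip() for line in raw.splitlines() if line.strip())
os.environ["HABANA_VISIBLE_MODULES"] = modules
print(f"HABANA_VISIBLE_MODULES={modules}")
```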

Makefile (+24, -2)

```diff
@@ -28,7 +28,7 @@ test_big_modeling:
 
 test_core:
 	python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py --ignore=./tests/deepspeed --ignore=./tests/test_big_modeling.py \
-		--ignore=./tests/fsdp --ignore=./tests/test_cli.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_core.log",)
+		--ignore=./tests/fsdp --ignore=./tests/tp --ignore=./tests/test_cli.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_core.log",)
 
 test_cli:
 	python -m pytest -s -v ./tests/test_cli.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_cli.log",)
@@ -39,6 +39,9 @@ test_deepspeed:
 test_fsdp:
 	python -m pytest -s -v ./tests/fsdp $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_fsdp.log",)
 
+test_tp:
+	python -m pytest -s -v ./tests/tp $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_tp.log",)
+
 # Since the new version of pytest will *change* how things are collected, we need `deepspeed` to
 # run after test_core and test_cli
 test:
@@ -47,13 +50,14 @@ test:
 	$(MAKE) test_big_modeling
 	$(MAKE) test_deepspeed
 	$(MAKE) test_fsdp
+	$(MAKE) test_tp
 
 test_examples:
 	python -m pytest -s -v ./tests/test_examples.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_examples.log",)
 
 # Broken down example tests for the CI runners
 test_integrations:
-	python -m pytest -s -v ./tests/deepspeed ./tests/fsdp $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_integrations.log",)
+	python -m pytest -s -v ./tests/deepspeed ./tests/fsdp ./tests/tp $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_integrations.log",)
 
 test_example_differences:
 	python -m pytest -s -v ./tests/test_examples.py::ExampleDifferenceTests $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_example_diff.log",)
@@ -70,3 +74,21 @@ test_prod:
 
 test_rest:
 	python -m pytest -s -v ./tests/test_examples.py::FeatureExamplesTests -k "not by_step and not by_epoch" $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_rest.log",)
+
+# For developers to prepare a release
+prepare_release:
+	rm -rf dist build
+	python setup.py bdist_wheel sdist
+
+# Make sure this is run in a fresh venv of some form
+install_test_release:
+	pip uninstall accelerate -y
+	pip install -i https://testpypi.python.org/pypi --extra-index-url https://pypi.org/simple accelerate$(if $(version),==$(version),)
+
+# Run as `make target=testpypi upload_release`
+upload_release:
+	@if [ "$(target)" != "testpypi" ] && [ "$(target)" != "pypi" ]; then \
+		echo "Error: target must be either 'testpypi' or 'pypi'"; \
+		exit 1; \
+	fi
+	twine upload dist/* -r $(target)
```

benchmarks/fp8/torchao/Dockerfile (new file, +12)

```dockerfile
FROM nvcr.io/nvidia/pytorch:24.07-py3

RUN pip install transformers evaluate datasets
RUN git clone https://github.com/huggingface/accelerate.git

RUN cd accelerate && \
    pip install -e . && \
    cd benchmarks/fp8

RUN /bin/bash
```

benchmarks/fp8/torchao/README.md (new file, +32)

````markdown
# FP8 Benchmarks

Comparing and running [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8) FP8 with accelerate.

## Overview

This repo provides scripts which compare native `torchao` model training against `accelerate`'s own integration. Each modeling type is segmented out via a script, supporting the following:

* Single GPU training (`non_distributed.py`)
* Multi-GPU training via DistributedDataParallelism (`ddp.py`)
* Fully Sharded Data Parallelism (`fsdp.py`)
* DeepSpeed ZeRO 1-3 (`deepspeed.py`)

To run them, it's recommended to use a docker image (see the attached `Dockerfile`) rather than installing `torchao` manually.

## Running:

Official Docker images are available at `huggingface/accelerate:gpu-fp8-torchao-nightly`.

You can run all scripts using the core `accelerate launch` command without any `accelerate config` being needed.

For single GPU, run it via `python`:

```bash
python non_distributed.py
```

For the rest, run it via `accelerate launch`:

```bash
accelerate launch ddp.py  # or fsdp.py, deepspeed.py
```
````
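For context, the `accelerate` side of this comparison boils down to passing `AORecipeKwargs` when constructing the `Accelerator` (see `train_integration` in `ddp.py` below). A minimal sketch, with a placeholder model and optimizer that are not part of the benchmark:

```python
# Minimal sketch of the accelerate-side torchao FP8 path; the toy model and
# optimizer here are placeholders, not the benchmark's BERT setup.
import torch
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs

accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)
# From here, training proceeds as usual: forward pass, accelerator.backward(loss),
# optimizer.step(), optimizer.zero_grad().
```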

benchmarks/fp8/torchao/ddp.py (new file, +158)

```python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `torchao`.

This particular script verifies this for DDP training.
"""

from functools import partial

import evaluate
import torch
from fp8_utils import get_training_utilities
from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.float8 import convert_to_float8_training

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import AORecipeKwargs, set_seed


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")


def evaluate_model(model, dataloader, metric, accelerator=None):
    "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on"
    model.eval()
    for step, batch in enumerate(dataloader):
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        if accelerator is not None and accelerator.num_processes > 1:
            predictions, references = accelerator.gather_for_metrics((predictions, references))
        metric.add_batch(predictions=predictions, references=references)
    return metric.compute()


def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):
    if isinstance(module, torch.nn.Linear):
        if module.in_features % 16 != 0 or module.out_features % 16 != 0:
            return False
    # For stability reasons, we skip the first and last linear layers
    # Otherwise can lead to the model not training or converging properly
    if fqn in (first_layer_name, last_layer_name):
        return False
    return True


def train_baseline():
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
    first_linear = None
    last_linear = None
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if first_linear is None:
                first_linear = name
            last_linear = name
    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
    accelerator = Accelerator()
    device = accelerator.device
    model.to(device)

    convert_to_float8_training(model, module_filter_fn=func)

    # Convert the model to DDP
    device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index
    model = DDP(model, device_ids=device_ids, output_device=output_device)

    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for batch in train_dataloader:
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            batch = batch.to(device)
            outputs = model(**batch)
            loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert (
        trained_model_results["accuracy"] > base_model_results["accuracy"]
    ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
    assert (
        trained_model_results["f1"] > base_model_results["f1"]
    ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'

    return base_model_results, trained_model_results


def train_integration():
    AcceleratorState()._reset_state(True)
    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
    set_seed(42)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    model, optimizer = accelerator.prepare(model, optimizer)
    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()

    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

    assert (
        trained_model_results["accuracy"] > base_model_results["accuracy"]
    ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
    assert (
        trained_model_results["f1"] > base_model_results["f1"]
    ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'

    return base_model_results, trained_model_results


if __name__ == "__main__":
    baseline_not_trained, baseline_trained = train_baseline()
    accelerator_not_trained, accelerator_trained = train_integration()

    assert (
        baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
    ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
    assert (
        baseline_not_trained["f1"] == accelerator_not_trained["f1"]
    ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
    assert (
        baseline_trained["accuracy"] == accelerator_trained["accuracy"]
    ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
    assert (
        baseline_trained["f1"] == accelerator_trained["f1"]
    ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'

    torch.distributed.destroy_process_group()
```
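To make the layer-filtering rule in `filter_linear_layers` concrete, here is a hypothetical toy module run through it; the model and expected outputs are illustrative only and not part of the commit:

```python
# Hypothetical example of the filtering rule above; `filter_linear_layers`
# is the function defined in ddp.py.
from functools import partial

import torch

toy = torch.nn.Sequential(
    torch.nn.Linear(32, 64),  # first linear layer: skipped for stability
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),  # dims divisible by 16 and interior: converted
    torch.nn.Linear(64, 30),  # out_features % 16 != 0 (and last layer): skipped
)
names = [n for n, m in toy.named_modules() if isinstance(m, torch.nn.Linear)]
fn = partial(filter_linear_layers, first_layer_name=names[0], last_layer_name=names[-1])
for name, module in toy.named_modules():
    if isinstance(module, torch.nn.Linear):
        print(name, fn(module, name))  # "0" False, "2" True, "3" False
```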
