diff --git a/train.py b/train.py index 1401ccb969b4..fa04089d4c0c 100644 --- a/train.py +++ b/train.py @@ -409,7 +409,7 @@ def lf(x): imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) # Forward - with torch.cuda.amp.autocast(amp): + with torch.amp.autocast("cuda", amp): pred = model(imgs) # forward loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size if RANK != -1: