Bring back PartialState DeepSpeed (#22921) · githubhjs/transformers@8b12903 · GitHub

Commit 8b12903

Bring back PartialState DeepSpeed (huggingface#22921)
* Bring back deepspeed integration
* Branchname
* Self-scheduled
* newline
* Use deepspeed env var
* Remove comment
* Del env var after partialstate
1 parent 4331923 commit 8b12903
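
What the new DeepSpeed branch does, in short: it briefly sets the ACCELERATE_USE_DEEPSPEED flag, lets Accelerate's PartialState create the process group, and then removes the flag again. A minimal stand-alone sketch of that pattern (not part of this commit; it assumes `accelerate` and `deepspeed` are installed and the script is started with a distributed launcher such as `deepspeed` or `torchrun`; the 1800-second timeout simply mirrors the TrainingArguments default for ddp_timeout):

    import os
    from datetime import timedelta

    from accelerate import PartialState

    # Ask Accelerate to route process-group setup through DeepSpeed ...
    os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
    state = PartialState(timeout=timedelta(seconds=1800))
    # ... then drop the flag so later PartialState/Accelerator uses in this
    # process are not forced onto the DeepSpeed path.
    del os.environ["ACCELERATE_USE_DEEPSPEED"]

    print(state.distributed_type, state.device, state.local_process_index)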

File tree

1 file changed: +48 -62 lines changed

src/transformers/training_args.py

Lines changed: 48 additions & 62 deletions
@@ -1544,84 +1544,70 @@ def _setup_devices(self) -> "torch.device":
             self._n_gpu = 1
             torch.cuda.set_device(device)
         elif self.deepspeed:
-            # deepspeed inits torch.distributed internally
-            from .deepspeed import is_deepspeed_available
-
-            if not is_deepspeed_available():
-                raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
-            import deepspeed
-
-            deepspeed.init_distributed(timeout=timedelta(seconds=self.ddp_timeout))
-
-            # workaround for setups like notebooks where the launcher can't be used,
-            # but deepspeed requires a dist env.
-            # env LOCAL_RANK could be set manually by the user, or via init_distributed if mpi4py is installed
-            self.local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
-
-            device = torch.device("cuda", self.local_rank)
+            # Need to do similar for Accelerator init
+            os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+            self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
+            del os.environ["ACCELERATE_USE_DEEPSPEED"]
             self._n_gpu = 1
         else:
             self.distributed_state = PartialState(backend=self.xpu_backend)
             self._n_gpu = 1
-        if not is_sagemaker_mp_enabled() and not self.deepspeed:
+        if not is_sagemaker_mp_enabled():
             device = self.distributed_state.device
             self.local_rank = self.distributed_state.local_process_index
         if (
             torch.distributed.is_available()
             and torch.distributed.is_initialized()
-            and hasattr(self, "distributed_state")
             and self.distributed_state.distributed_type == DistributedType.NO
         ):
             logger.warning(
                 "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
                 "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
             )
-        if not self.deepspeed:
-            if is_torch_tpu_available():
-                device = self.distributed_state.device
-                self._n_gpu = 0
-            elif is_sagemaker_dp_enabled():
-                self._n_gpu = 1
-            elif self.distributed_state.distributed_type == DistributedType.NO:
-                if self.use_mps_device:
-                    if not torch.backends.mps.is_available():
-                        if not torch.backends.mps.is_built():
-                            raise AssertionError(
-                                "MPS not available because the current PyTorch install was not "
-                                "built with MPS enabled. Please install torch version >=1.12.0 on "
-                                "your Apple silicon Mac running macOS 12.3 or later with a native "
-                                "version (arm64) of Python"
-                            )
-                        else:
-                            raise AssertionError(
-                                "MPS not available because the current MacOS version is not 12.3+ "
-                                "and/or you do not have an MPS-enabled device on this machine."
-                            )
-                    else:
-                        if not version.parse(version.parse(torch.__version__).base_version) > version.parse("1.12.0"):
-                            warnings.warn(
-                                "We strongly recommend to install PyTorch >= 1.13 (nightly version at the time of writing)"
-                                " on your MacOS machine. It has major fixes related to model correctness and performance"
-                                " improvements for transformer based models. Please refer to"
-                                " https://github.com/pytorch/pytorch/issues/82707 for more details."
-                            )
-                        device = torch.device("mps")
-                        self._n_gpu = 1
-
-                else:
-                    # if n_gpu is > 1 we'll use nn.DataParallel.
-                    # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
-                    # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
-                    # trigger an error that a device index is missing. Index 0 takes into account the
-                    # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
-                    # will use the first GPU in that env, i.e. GPU#1
-                    # device = self.distributed_state.device
-                    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-                    # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
-                    # the default value.
-                    self._n_gpu = torch.cuda.device_count()
-                    if device.type == "cuda":
-                        torch.cuda.set_device(device)
+        if is_torch_tpu_available():
+            device = self.distributed_state.device
+            self._n_gpu = 0
+        elif is_sagemaker_dp_enabled():
+            self._n_gpu = 1
+        elif self.distributed_state.distributed_type == DistributedType.NO:
+            if self.use_mps_device:
+                if not torch.backends.mps.is_available():
+                    if not torch.backends.mps.is_built():
+                        raise AssertionError(
+                            "MPS not available because the current PyTorch install was not "
+                            "built with MPS enabled. Please install torch version >=1.12.0 on "
+                            "your Apple silicon Mac running macOS 12.3 or later with a native "
+                            "version (arm64) of Python"
+                        )
+                    else:
+                        raise AssertionError(
+                            "MPS not available because the current MacOS version is not 12.3+ "
+                            "and/or you do not have an MPS-enabled device on this machine."
+                        )
+                else:
+                    if not version.parse(version.parse(torch.__version__).base_version) > version.parse("1.12.0"):
+                        warnings.warn(
+                            "We strongly recommend to install PyTorch >= 1.13 (nightly version at the time of writing)"
+                            " on your MacOS machine. It has major fixes related to model correctness and performance"
+                            " improvements for transformer based models. Please refer to"
+                            " https://github.com/pytorch/pytorch/issues/82707 for more details."
+                        )
+                    device = torch.device("mps")
+                    self._n_gpu = 1
+
+            else:
+                # if n_gpu is > 1 we'll use nn.DataParallel.
+                # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
+                # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
+                # trigger an error that a device index is missing. Index 0 takes into account the
+                # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
+                # will use the first GPU in that env, i.e. GPU#1
+                device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+                # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
+                # the default value.
+                self._n_gpu = torch.cuda.device_count()
+                if device.type == "cuda":
+                    torch.cuda.set_device(device)
         return device

     @property
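
With the Accelerate state in place, the hand-rolled LOCAL_RANK parsing and torch.device("cuda", local_rank) construction deleted above become redundant: device and local_rank are now read back from self.distributed_state on every non-SageMaker-MP path. A rough equivalence sketch (hypothetical code, not from the repository; it assumes a multi-GPU launch via torchrun or deepspeed so that LOCAL_RANK is set by the launcher):

    import os

    import torch
    from accelerate import PartialState

    # What the deleted code did: trust the launcher-provided env var directly.
    local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
    manual_device = torch.device("cuda", local_rank)

    # What the new code relies on: PartialState owns the setup and exposes the result.
    state = PartialState()
    assert state.local_process_index == local_rank
    assert state.device == manual_device
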
@@ -1664,7 +1650,7 @@ def parallel_mode(self):
             return ParallelMode.SAGEMAKER_MODEL_PARALLEL
         elif is_sagemaker_dp_enabled():
             return ParallelMode.SAGEMAKER_DATA_PARALLEL
-        elif self.deepspeed or self.distributed_state.distributed_type != DistributedType.NO:
+        elif hasattr(self, "distributed_state") and self.distributed_state.distributed_type != DistributedType.NO:
             return ParallelMode.DISTRIBUTED
         elif self.n_gpu > 1:
             return ParallelMode.NOT_DISTRIBUTED
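
The parallel_mode hunk above drops the self.deepspeed shortcut and instead asks the Accelerate state, guarded by hasattr: the SageMaker model-parallel branch of _setup_devices never assigns distributed_state, so an unconditional attribute read would raise AttributeError there. A small sketch of the same guard shape (hypothetical toy class, not from the repository):

    from accelerate import PartialState
    from accelerate.utils import DistributedType


    class Args:
        """Toy stand-in for TrainingArguments, only to show the guarded check."""

        def setup(self, create_state: bool) -> None:
            if create_state:
                # Mirrors the branches of _setup_devices that build an Accelerate state.
                self.distributed_state = PartialState()
            # A SageMaker-MP-style path would assign nothing here.

        @property
        def is_distributed(self) -> bool:
            # Same shape as the new parallel_mode check: the attribute may be
            # missing, so test for it before dereferencing it.
            return (
                hasattr(self, "distributed_state")
                and self.distributed_state.distributed_type != DistributedType.NO
            )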

0 commit comments
