Support eager mode for multi-process training (#7327)
JackCaoG authored Jun 24, 2024
1 parent dcbd929 commit 222bbd8
Showing 8 changed files with 93 additions and 7 deletions.
26 changes: 26 additions & 0 deletions examples/eager/train_decoder_only_eager_multi_process.py
@@ -0,0 +1,26 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_decoder_only_base import TrainDecoderOnlyBase

import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm


class TrainDecoderXLADDP(TrainDecoderOnlyBase):

  def run_optimizer(self):
    # optimizer_step will call `optimizer.step()` and all_reduce the gradients
    xm.optimizer_step(self.optimizer)


def _mp_fn(index):
  import torch_xla
  torch_xla.experimental.eager_mode(True)
  xla_ddp = TrainDecoderXLADDP()
  xla_ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
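
For context, a hedged sketch of what `xm.optimizer_step` amounts to in this data-parallel setup. This is not the actual implementation: the real helper in `torch_xla.core.xla_model` accepts extra arguments (barrier, groups, pin_layout), but it is built around the public `xm.reduce_gradients` call shown here; the function name `optimizer_step_sketch` is hypothetical.

import torch_xla.core.xla_model as xm

def optimizer_step_sketch(optimizer):
  # Hypothetical sketch, not torch_xla's implementation: all-reduce every
  # parameter's gradient across replicas in place. Under eager mode this
  # in-place all_reduce is what the rest of this commit batches into a
  # single graph execution.
  xm.reduce_gradients(optimizer)
  # Every replica then applies the identical, reduced gradients locally.
  optimizer.step()

Each process spawned by `xmp.spawn` drives one device, so the example above can be launched directly with `python examples/eager/train_decoder_only_eager_multi_process.py`.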
2 changes: 1 addition & 1 deletion examples/train_decoder_only_base.py
@@ -20,7 +20,7 @@ def __init__(self):
    self.config = DecoderOnlyConfig()
    self.batch_size = 16
    self.seq_len = 512
-   self.num_steps = 300
+   self.num_steps = 200
    self.num_epochs = 1
    self.train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    # For the purpose of this example, we are going to use fake data.
40 changes: 40 additions & 0 deletions test/eager/test_eager_all_reduce_in_place.py
@@ -0,0 +1,40 @@
import torch
import torch_xla

import torch_xla.core.xla_model as xm
import torch_xla.debug
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.debug.metrics as met


def _mp_fn(index):
  import torch_xla
  torch_xla.experimental.eager_mode(True)

  device = torch_xla.device()

  if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'):
    return

  ordinal_tensor_1 = torch.tensor([index], dtype=torch.float).to(device)
  ordinal_tensor_2 = torch.tensor([index], dtype=torch.int32).to(device)
  xm.wait_device_ops()
  met.clear_all()

  # all_reduce with a list of tensors as input will be an in-place op. This is
  # the path used by optimizer_step.
  xm.all_reduce(xm.REDUCE_SUM, [ordinal_tensor_1, ordinal_tensor_2])

  xm.wait_device_ops()
  assert met.metric_data("EagerOpExecuteTime")[0] == 1

  num_device = torch_xla.runtime.global_runtime_device_count()
  expected_sum = (num_device - 1) * num_device / 2
  expected_1 = torch.tensor([(expected_sum)], dtype=torch.float)
  expected_2 = torch.tensor([(expected_sum)], dtype=torch.int32)
  assert torch.allclose(expected_1, ordinal_tensor_1.cpu())
  assert torch.allclose(expected_2, ordinal_tensor_2.cpu())


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
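
A side note on the two calling conventions of `xm.all_reduce`, since the test above relies on the in-place one. The snippet below is a hedged sketch (the names `t` and `reduced` are illustrative; it assumes a process started via `xmp.spawn` as in `_mp_fn` above):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()
t = torch.ones(4).to(device)

# Functional form: a single tensor in, a new reduced tensor out; `t` is left
# unchanged.
reduced = xm.all_reduce(xm.REDUCE_SUM, t)

# In-place form: every tensor in the list is overwritten with its reduced
# value. This is the form optimizer_step relies on, and the one this commit
# makes execute as a single graph under eager mode.
xm.all_reduce(xm.REDUCE_SUM, [t])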
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -207,6 +207,7 @@ function run_xla_op_tests2 {
  run_test "$CDIR/eager/test_eager.py"
  run_test "$CDIR/eager/test_eager_with_xla_compile.py"
  run_test "$CDIR/eager/test_eager_with_torch_compile.py"
+ run_test "$CDIR/eager/test_eager_all_reduce_in_place.py"
}

# All the new xla op tests should go to run_xla_op_tests3
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -43,4 +43,5 @@ TPU_VERSION=$(python -c "import sys; sys.path.remove(''); import torch_xla; prin
if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then
  python3 examples/eager/train_decoder_only_eager.py
  python3 examples/eager/train_decoder_only_eager_with_compile.py
+ python3 examples/eager/train_decoder_only_eager_multi_process.py
fi
5 changes: 3 additions & 2 deletions torch_xla/csrc/tensor.cpp
@@ -351,7 +351,8 @@ void XLATensor::SetIrValue(torch::lazy::Value ir_value, bool inplace) {
  data()->is_cloned = false;
}

- void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
+ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value,
+                                   bool delay_eager_executation) {
  auto xla_shape = shape();
  if (xla_shape.get().element_type() != GetXlaShape(ir_value).element_type()) {
    ir_value =
@@ -361,7 +362,7 @@ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();

  // in place update should also be triggered eagerly if configured
- if (graph_executor->UseEagerMode()) {
+ if (graph_executor->UseEagerMode() && !delay_eager_executation) {
    std::vector<XLATensorPtr> xtensors({c10::make_intrusive<XLATensor>(*this)});
    graph_executor->ApplyEagerSync(xtensors);
  }
3 changes: 2 additions & 1 deletion torch_xla/csrc/tensor.h
@@ -228,7 +228,8 @@ class XLATensor : public torch::lazy::LazyTensor {
  // TODO(alanwaketan): Reuse the upstream ones once Functionalization is done.
  torch::lazy::Value GetIrValue() const;
  void SetIrValue(torch::lazy::Value ir_value, bool inplace = true);
- void SetInPlaceIrValue(torch::lazy::Value ir_value);
+ void SetInPlaceIrValue(torch::lazy::Value ir_value,
+                        bool delay_eager_executation = false);

  // TODO(alanwaketan): Reuse the upstream one once Functionalization is done.
  std::optional<at::Tensor> CurrentTensorData() const;
22 changes: 19 additions & 3 deletions torch_xla/csrc/tensor_methods.cpp
@@ -370,10 +370,26 @@ void all_reduce(const std::vector<XLATensorPtr>& inputs,
      reduce_type, input_values, GetAllReduceToken(inputs.front()->GetDevice()),
      scale, std::move(groups), pin_layout);
  for (size_t i = 0; i < inputs.size(); ++i) {
-   inputs[i]->SetInPlaceIrValue(torch::lazy::Value(node, i));
+   // In eager mode we don't want to execute the IR for each tensor because
+   // that would execute the `all_reduce` once per tensor.
+   inputs[i]->SetInPlaceIrValue(torch::lazy::Value(node, i),
+                                /*delay_eager_executation=*/true);
  }

+ XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+ if (graph_executor->UseEagerMode()) {
+   // Execute the HLO that will run the `all_reduce` and in-place update all
+   // tensors in one graph.
+   graph_executor->ApplyEagerSync(
+       const_cast<std::vector<XLATensorPtr>&>(inputs));
+ } else {
+   // all_reduce_token is used to enforce the order of the cc ops. There is no
+   // point in setting it for eager mode since each cc op will be executed
+   // independently.
+   SetAllReduceToken(
+       inputs.front()->GetDevice(),
+       std::make_shared<torch::lazy::Value>(node, inputs.size()));
+ }
- SetAllReduceToken(inputs.front()->GetDevice(),
-                   std::make_shared<torch::lazy::Value>(node, inputs.size()));
}

std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
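
To summarize the effect of the `tensor_methods.cpp` change from the Python side: reducing a list of N tensors under eager mode is now expected to show up as one eager execution instead of N. The snippet below is a hypothetical sketch mirroring the test added above (`_check` and `grads` are illustrative names), not an additional test from this commit:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.xla_multiprocessing as xmp

def _check(index):
  torch_xla.experimental.eager_mode(True)
  device = torch_xla.device()
  grads = [torch.ones(8).to(device) for _ in range(4)]
  xm.wait_device_ops()
  met.clear_all()
  # Four tensors reduced in place; expected to count as a single graph
  # execution, matching the assertion in test_eager_all_reduce_in_place.py.
  xm.all_reduce(xm.REDUCE_SUM, grads)
  xm.wait_device_ops()
  assert met.metric_data("EagerOpExecuteTime")[0] == 1

if __name__ == '__main__':
  xmp.spawn(_check, args=())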