Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored May 30, 2024
1 parent 6726cff commit e863340
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,6 @@ class PyLoweringContext {
// needed in xlacomputation for fori_loop/while_loop.
void BuildForiLoop(std::vector<at::Tensor> tensors,
std::vector<at::Tensor> additional_inputs_list = {}) {

// Get the backing XLA tensors from the output torch tensor handles
std::vector<XLATensorPtr> xtensors =
GetXlaTensors(tensors, /*want_all=*/true);
Expand All @@ -936,12 +935,14 @@ class PyLoweringContext {
lowering_ctx.AddResult(root);
}

// add dummy parameter to cond xlacomputation's input for xla::while requriement
// add dummy parameter to cond xlacomputation's input for xla::while
// requriement
if (GetNameString() == "condctx") {
xla::XlaBuilder* local_builder = lowering_ctx.builder();
int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size();
int64_t parameter_idx =
local_builder->GetProgramShape()->parameters_size();
int64_t additional_inputs_list_size = additional_inputs_list.size();
for (int64_t i = parameter_idx; i < additional_inputs_list_size ; i++) {
for (int64_t i = parameter_idx; i < additional_inputs_list_size; i++) {
XLATensorPtr xtensor = bridge::GetXlaTensor(additional_inputs_list[i]);
xla::Shape shape = xtensor->shape().get();
xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
Expand All @@ -950,12 +951,14 @@ class PyLoweringContext {
}
}

// add dummy parameter to body xlacomputation's input for xla::while requriement
// add dummy parameter to body xlacomputation's input for xla::while
// requriement
if (GetNameString() == "bodyctx" && additional_inputs_list.size() != 0) {
xla::XlaBuilder* local_builder = lowering_ctx.builder();
int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size();
int64_t parameter_idx =
local_builder->GetProgramShape()->parameters_size();
int64_t additional_inputs_list_size = additional_inputs_list.size();
for (int64_t i = parameter_idx; i < additional_inputs_list_size ; i++) {
for (int64_t i = parameter_idx; i < additional_inputs_list_size; i++) {
XLATensorPtr xtensor = bridge::GetXlaTensor(additional_inputs_list[i]);
xla::Shape shape = xtensor->shape().get();
xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
Expand Down

0 comments on commit e863340

Please sign in to comment.