Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scan and apply_layers #7901

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

scan and apply_layers #7901

wants to merge 1 commit into from

Conversation

tengyifei
Copy link
Collaborator

Add the lowering of scan to HLO While op.

Introduce apply_layers which can sequentially apply a bunch of layers
using scan underneath.

Beef up unit tests including linear layers and decoders.

@JackCaoG
Copy link
Collaborator

======================================================================
ERROR: test_decoder_model (__main__.ApplyLayersTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/__w/xla/xla/pytorch/xla/test/test_apply_layers.py", line 77, in test_decoder_model
    from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel  # type:ignore
ModuleNotFoundError: No module named 'decoder_only_model'

you can't just import it, you need to setup import dir correctly. Take a look at https://github.com/pytorch/xla/blob/master/test/dynamo/test_dynamo_dynamic_shape.py#L1-L6

@tengyifei
Copy link
Collaborator Author

@JackCaoG ty. i followed your example and got it working.

test/test_apply_layers.py Outdated Show resolved Hide resolved
test/test_apply_layers.py Outdated Show resolved Hide resolved
test/test_apply_layers.py Outdated Show resolved Hide resolved
test/test_operations.py Outdated Show resolved Hide resolved
import json
hlo_json = json.loads(ctx.hlo_json())
num_parameters = len(hlo_json["hostProgramShape"]["parameters"])
self.assertEqual(len(mapping), num_parameters)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so you expect both value to be 10?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately not. It looks like some integer values (e.g. values <= 2) are shared when you put multiple copies into the HLO, but values above 2 are not shared. So we don't necessarily get 10. In any case, the precise number of parameters seems to be an implementation detail that we can't reliably test.

@@ -1077,7 +1076,9 @@ class PyLoweringContext {
at::ScalarType dtype =
MaybeUpcastToHostTorchType(literal.shape().element_type());
at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype);
results[param_ids[i]] = input;
std::optional param_id = lowering_ctx.GetParameterId(device_data[i]);
XLA_CHECK(param_id.has_value());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when would it not has value?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When GetParameterId receives a BackendData that is not a parameter in this lowering context, it will return std::nullopt. However, this loop is only iterating over parameters (line 1071, const std::vector<torch::lazy::BackendDataPtr>& device_data = lowering_ctx.GetParametersData();), so we will expect all BackendData there to have an ID. Seems good to enforce this invariant.

return input_data

# Extract and stack the parameters into a pytree.
params = [_extract_weights_dict(layer) for layer in layers]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if it is a dropout layer that parameters are more than just tensors?

Copy link
Collaborator Author

@tengyifei tengyifei Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is a dropout layer that references tensors other than model parameters (for example, the dropout probability), then those tensors will be captured as an additional HLO parameter to the XlaComputation object. As implemented now, apply_layers and scan will trace the first layer, and then use the same captured tensor for subsequent layers. This will be a problem if the user passes different dropout probabilities for say a sequence of dropout layers -- we'll instead incorrectly just keep using the first dropout's probability. I'll have to dig deeper and find a solution for this.

If there's a layer that references things other than tensors, then either that thing (e.g. a bool field) will impact the traced HLO computation, in which case I need to add a verification that all layers trace to equivalent computations. Or that thing won't impact the traced computation, in which case it won't matter to us.

example_layer = deepcopy(next(iter(layers)))

# Hollow out the weights and biases in the example layer.
example_layer = example_layer.to_empty(device=None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this not going to impact the cloned arg?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify this question -- I thought to_empty is going to destroy the value inside example_layer, so I deepcopy it before to backup.

Comment on lines +187 to +190
fn_output_carry_pytree, fn_output_y_pytree = flat_fn(*(fake_carry + fake_x))

# Later we'll use `fn_output_carry_spec` etc to turn flattened outputs back to a PyTree.
fn_output_carry, fn_output_carry_spec = tree_flatten(fn_output_carry_pytree)
assert fn_output_carry_spec == carry_spec
fn_output_y, fn_output_y_spec = tree_flatten(fn_output_y_pytree)
flat_y_len = len(fn_output_y)
fn_outputs = fn_output_carry + fn_output_y
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if there are in place updates to the tensor but it is not being return from the function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested this and we give the wrong answer: https://github.com/tengyifei/playground/blob/master/scan_with_in_place_updates.ipynb

In the notebook, I wrote an approach to detect and prevent in place updates like that. TLDR is we'll have to trace every forward of each layer and verify that they're the same.


def step_fn(grad_carry, pytree: Tuple[torch.Tensor, torch.Tensor,
torch.Tensor]):
grad_y, carry, x = pytree
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a typo?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so -- pytree is a tuple of the output grad at current step (grad_y), carry at the current step (carry), and input at current step (x)

carry, carry_history, ys = _scan_impl(fn, init, xs)
flat_carry_history, carry_spec = tree_flatten(carry_history)
flat_xs, xs_spec = tree_flatten(xs)
ctx.save_for_backward(*flat_carry_history, *flat_xs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are flat_carry_history flat_xs always everything we need to save for the backward?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. If the user fn references other tensors, they'll be captured as additional inputs to the HLO computation. I'll update the documentation to mention that we'll explicitly checkpoint fn in this iteration, and add the right barriers.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the barriers won't be needed. This is the recommendation from JAX, which says you don't want to wrap inputs into a barrier if the checkpointed function is to be used in a scan.

Comment on lines +400 to +401
outputs = fn(*detached_inputs)
output_carry, output_y = outputs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait... you are retracing the fwd here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. I'm not sure if the CSE pass can combine this with the same fn in the fwd pass. If it can't, then the bwd of scan will be slower.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ehh, what you do here is pretty much force the gradident accumulation(through most likely cancel by the CSE since there is no optimization barrier), this sounds like a bad idea

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the barriers won't be needed. This is the recommendation from JAX, which says you don't want to wrap inputs into a barrier if the checkpointed function is to be used in a scan.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@miladm
Copy link
Collaborator

miladm commented Sep 20, 2024

@tengyifei is this PR a 2.5 candidate?

@tengyifei
Copy link
Collaborator Author

@miladm yes, I'd like to backport this to 2.5 after addressing the comments etc.

@tengyifei tengyifei removed the request for review from alanwaketan September 25, 2024 18:15
@tengyifei tengyifei force-pushed the yifeit/scan branch 2 times, most recently from ddd01a4 to ea640ab Compare September 25, 2024 19:25
@tengyifei tengyifei force-pushed the yifeit/scan branch 3 times, most recently from 348120e to 2f868fd Compare September 27, 2024 03:11
Add the lowering of scan to HLO While op.

Introduce apply_layers which can sequentially apply a bunch of layers
using scan underneath.

Beef up unit tests including linear layers and decoders.

add regression test for parameter_id_tensor_mapping

add test_apply_layers.py to test shell scripts

correctly import decoder model from examples
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants