From 852aa4fca7305857bf8fdee440491020ca25d8df Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Tue, 17 Sep 2024 20:24:16 +0000
Subject: [PATCH] Add profiler annotation for the decoderonly example

---
 examples/decoder_only_model.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/examples/decoder_only_model.py b/examples/decoder_only_model.py
index 712423d79ad..5a050aad33d 100644
--- a/examples/decoder_only_model.py
+++ b/examples/decoder_only_model.py
@@ -6,6 +6,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+import torch_xla.debug.profiler as xp
 
 # the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core.
 @dataclass
@@ -44,6 +45,7 @@ def __init__(self, hidden_size, eps=1e-6):
     self.weight = nn.Parameter(torch.ones(hidden_size))
     self.variance_epsilon = eps
 
+  @xp.trace_me("RMSNorm")
   def forward(self, hidden_states):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
@@ -79,6 +81,7 @@ def __init__(self, config: DecoderOnlyConfig):
         self.num_heads * self.head_dim, self.hidden_size, bias=False)
     self.flash_attention_impl = None
 
+  @xp.trace_me("attention")
   def forward(
       self,
       hidden_states: torch.Tensor,
@@ -153,6 +156,7 @@ def __init__(self, config: DecoderOnlyConfig):
         self.intermediate_size, self.hidden_size, bias=False)
     self.act_fn = F.silu
 
+  @xp.trace_me("MLP")
   def forward(self, x):
     # [B, S, H] -> [B, S, I]
     up_proj = self.up_proj(x)
@@ -173,6 +177,7 @@ def __init__(self, config: DecoderOnlyConfig):
     self.input_layernorm = RMSNorm(config.hidden_size)
     self.post_attention_layernorm = RMSNorm(config.hidden_size)
 
+  @xp.trace_me("DecoderLayer")
   def forward(
       self,
       hidden_states: torch.Tensor,
@@ -209,6 +214,7 @@ def __init__(self, config: DecoderOnlyConfig):
     self.norm = RMSNorm(config.hidden_size)
     self.output = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
 
+  @xp.trace_me("DecoderOnlyModel")
   def forward(
       self,
       input_ids: torch.LongTensor = None,
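
Usage note (not part of the patch): a minimal sketch of how the @xp.trace_me annotations above can show up in a captured profile. The import path for DecoderOnlyConfig/DecoderOnlyModel, the server port, the log directory, and the batch/sequence sizes are illustrative assumptions, not something the patch defines.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

# Assumed import path for the example model in examples/decoder_only_model.py.
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel

# Start the profiler server so an external capture can attach (port is arbitrary).
server = xp.start_server(9012)

device = xm.xla_device()
config = DecoderOnlyConfig()
model = DecoderOnlyModel(config).to(device)

# Dummy batch; shapes are illustrative and vocab_size is assumed to be a config field.
input_ids = torch.randint(
    0, config.vocab_size, (2, 128), dtype=torch.long, device=device)

for step in range(10):
  # StepTrace marks a step boundary; the trace_me scopes added by the patch
  # (RMSNorm, attention, MLP, DecoderLayer, DecoderOnlyModel) nest under it.
  with xp.StepTrace('decoder_step', step_num=step):
    logits = model(input_ids)

# To capture a trace while the loop runs (e.g. from another process):
# xp.trace('localhost:9012', logdir='/tmp/decoder_profile', duration_ms=5000)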