Skip to content

Commit

Permalink
Use Tuple instead of tuple in scan.py (#8067)
Browse files Browse the repository at this point in the history
  • Loading branch information
tengyifei committed Sep 25, 2024
1 parent 3c7daa2 commit 1717e25
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torch_xla/experimental/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from typing import Callable, TypeVar
from typing import Callable, TypeVar, Tuple

import torch
from torch.utils._pytree import tree_map, tree_iter
Expand All @@ -15,10 +15,10 @@


def scan(
fn: Callable[[Carry, X], tuple[Carry, Y]],
fn: Callable[[Carry, X], Tuple[Carry, Y]],
init: Carry,
xs: X,
) -> tuple[Carry, Y]:
) -> Tuple[Carry, Y]:
"""Apply a function over leading dimension of tensors while carrying along state.
This is similar to the JAX `jax.lax.scan` function found in [1].
Expand Down

0 comments on commit 1717e25

Please sign in to comment.