Skip to content

Commit

Permalink
fix type annotation for 3.8 (#7701)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG committed Jul 17, 2024
1 parent 2a20d1e commit 9952e38
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 48 deletions.
91 changes: 45 additions & 46 deletions benchmarks/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch
import torch._dynamo.utils as dynamo_utils
import tiers
import typing
from typing import Optional, Any, List, Dict, Sequence
import torch_xla.debug.metrics as met
from tqdm import tqdm
Expand Down Expand Up @@ -206,8 +207,8 @@ def generate_and_run_all_configs(self):

# TODO: Use `_unique_basename` instead.
def _get_config_fingerprint(
self, experiment_config: OrderedDict[str, Optional[StrOrBool]],
model_config: OrderedDict[str, Optional[StrOrBool]]) -> str:
self, experiment_config: typing.OrderedDict[str, Optional[StrOrBool]],
model_config: typing.OrderedDict[str, Optional[StrOrBool]]) -> str:
# Experiment `batch_size` may be altered by model in `set_up`, so we will
# ignore that.
return "-".join(
Expand Down Expand Up @@ -290,9 +291,9 @@ def run_single_config(self):
def run_once_and_gather_metrics(
self, benchmark_experiment: BenchmarkExperiment,
benchmark_model: BenchmarkModel,
experiment_config: OrderedDict[str, Optional[StrOrBool]],
model_config: OrderedDict[str,
Optional[StrOrBool]], repeat_iteration: int):
experiment_config: typing.OrderedDict[str, Optional[StrOrBool]],
model_config: typing.OrderedDict[str, Optional[StrOrBool]],
repeat_iteration: int):

# Prepare inputs.
self.reset_rng_state(benchmark_experiment)
Expand Down Expand Up @@ -463,8 +464,8 @@ def _synchronize(self, benchmark_experiment: BenchmarkExperiment):
##############################################################################

def _unique_basename(
self, experiment_config: OrderedDict[str, Optional[StrOrBool]],
model_config: OrderedDict[str, Optional[StrOrBool]]) -> str:
self, experiment_config: typing.OrderedDict[str, Optional[StrOrBool]],
model_config: typing.OrderedDict[str, Optional[StrOrBool]]) -> str:

def unique_basename_segment(x, max_len=32):
s = str(x).replace(" ", "")
Expand All @@ -483,14 +484,13 @@ def unique_basename_segment(x, max_len=32):
]
return "-".join(segments)

def _get_results_file_path(self,
experiment_config: OrderedDict[
str, Optional[StrOrBool]],
model_config: OrderedDict[str,
Optional[StrOrBool]],
partial_name: str,
ext: Optional[str] = "txt",
sub_dirname: Optional[str] = None) -> str:
def _get_results_file_path(
self,
experiment_config: typing.OrderedDict[str, Optional[StrOrBool]],
model_config: typing.OrderedDict[str, Optional[StrOrBool]],
partial_name: str,
ext: Optional[str] = "txt",
sub_dirname: Optional[str] = None) -> str:
is_dir = ext is None
model_name = model_config["model_name"]
basename = self._unique_basename(experiment_config, model_config)
Expand All @@ -507,38 +507,38 @@ def _get_results_file_path(self,

return path

def _get_results_dir_path(self,
experiment_config: OrderedDict[str,
Optional[StrOrBool]],
model_config: OrderedDict[str, Optional[StrOrBool]],
partial_name: str,
sub_dirname: Optional[str] = None) -> str:
def _get_results_dir_path(
self,
experiment_config: typing.OrderedDict[str, Optional[StrOrBool]],
model_config: typing.OrderedDict[str, Optional[StrOrBool]],
partial_name: str,
sub_dirname: Optional[str] = None) -> str:
return self._get_results_file_path(
experiment_config,
model_config,
partial_name,
ext=None,
sub_dirname=sub_dirname)

def _save_results_file(self,
text: str,
experiment_config: OrderedDict[str,
Optional[StrOrBool]],
model_config: OrderedDict[str, Optional[StrOrBool]],
partial_name: str,
ext: str = "txt",
sub_dirname: Optional[str] = None,
mode: str = "w"):
def _save_results_file(
self,
text: str,
experiment_config: typing.OrderedDict[str, Optional[StrOrBool]],
model_config: typing.OrderedDict[str, Optional[StrOrBool]],
partial_name: str,
ext: str = "txt",
sub_dirname: Optional[str] = None,
mode: str = "w"):
path = self._get_results_file_path(experiment_config, model_config,
partial_name, ext, sub_dirname)
with open(path, mode, encoding="utf-8") as f:
f.write(text)

def _save_results(
self,
experiment_config: OrderedDict[str, Optional[StrOrBool]],
model_config: OrderedDict[str, Optional[StrOrBool]],
metrics: OrderedDict[str, Any],
experiment_config: typing.OrderedDict[str, Optional[StrOrBool]],
model_config: typing.OrderedDict[str, Optional[StrOrBool]],
metrics: typing.OrderedDict[str, Any],
verification_result: Optional[VerificationResult] = VerificationResult(
VerificationCode.CANNOT_PROCEED_WITH_VERIFICATION)):
results = OrderedDict()
Expand All @@ -558,11 +558,11 @@ def _save_results(
# Helpers to dump and analyze the PyTorch profile, PyTorch/XLA metrics, etc. #
##############################################################################

def _dump_pytorch_profile(self, profile: Optional[torch.profiler.profile],
experiment_config: OrderedDict[str,
Optional[StrOrBool]],
model_config: OrderedDict[str, Optional[StrOrBool]],
repeat_iteration: int):
def _dump_pytorch_profile(
self, profile: Optional[torch.profiler.profile],
experiment_config: typing.OrderedDict[str, Optional[StrOrBool]],
model_config: typing.OrderedDict[str, Optional[StrOrBool]],
repeat_iteration: int):
assert profile is not None, "Expect PyTorch profile"

# Dump PyTorch trace.
Expand Down Expand Up @@ -653,11 +653,10 @@ def get_xla_cpu_fallback_ops(met):
metrics["inductor_ops"] = dict()
metrics["inductor_ops"][op_name] = extract_prof_info(event)

def _dump_dynamo_counters(self,
experiment_config: OrderedDict[str,
Optional[StrOrBool]],
model_config: OrderedDict[str, Optional[StrOrBool]],
repeat_iteration: int):
def _dump_dynamo_counters(
self, experiment_config: typing.OrderedDict[str, Optional[StrOrBool]],
model_config: typing.OrderedDict[str, Optional[StrOrBool]],
repeat_iteration: int):
text = f"{json.dumps(dynamo_utils.counters)}\n"
self._save_results_file(
text,
Expand All @@ -667,9 +666,9 @@ def _dump_dynamo_counters(self,
sub_dirname=str(repeat_iteration))

def _dump_pytorch_xla_metrics(
self, experiment_config: OrderedDict[str, Optional[StrOrBool]],
model_config: OrderedDict[str,
Optional[StrOrBool]], repeat_iteration: int):
self, experiment_config: typing.OrderedDict[str, Optional[StrOrBool]],
model_config: typing.OrderedDict[str, Optional[StrOrBool]],
repeat_iteration: int):
text = met.metrics_report()
assert isinstance(text, str)
self._save_results_file(
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ def gmm_xla(
rhs: torch.Tensor,
group_sizes: torch.Tensor,
# pytorch custom op does not allow tuple type, use list instead
tiling: Optional[list[int]] = [512, 512, 512]):
tiling: Optional[List[int]] = [512, 512, 512]):
assert len(tiling) == 3, "tiling must be a list with 3 integers"
assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [k, m]"
assert rhs.dim(
Expand All @@ -932,7 +932,7 @@ def gmm_xla(
def gmm_non_xla(lhs: torch.Tensor,
rhs: torch.Tensor,
group_sizes: torch.Tensor,
tiling: Optional[list[int]] = [512, 512, 512]):
tiling: Optional[List[int]] = [512, 512, 512]):
# This will be called when dynamo use fake tensor to construct the fake output.
# We need to make sure output tensor's shape is correct.
if lhs.device != torch.device("meta"):
Expand Down

0 comments on commit 9952e38

Please sign in to comment.