diff --git a/benchmarks/benchmark_experiment.py b/benchmarks/benchmark_experiment.py index d7159c0bb5b..a82490d373b 100644 --- a/benchmarks/benchmark_experiment.py +++ b/benchmarks/benchmark_experiment.py @@ -27,6 +27,7 @@ def list_experiment_configs(self): "torch_xla2": [None], # options only apply to torch_xla2 "test": ["eval", "train"], "keep_model_data_on_cuda": [False], + "enable_functionalization": [False], } # Apply command line choices. @@ -49,6 +50,10 @@ def list_experiment_configs(self): config_choices["keep_model_data_on_cuda"] = [ self._args.keep_model_data_on_cuda ] + if self._args.enable_functionalization: + config_choices["enable_functionalization"] = [ + self._args.enable_functionalization + ] # Expand experiment configs and add env vars. logger.debug(f"Expand experiment configs") @@ -136,6 +141,7 @@ def load_experiment(self, batch_size = experiment_config.get("batch_size", self._args.batch_size) torch_xla2 = experiment_config["torch_xla2"] keep_model_data_on_cuda = experiment_config["keep_model_data_on_cuda"] + enable_functionalization = experiment_config["enable_functionalization"] return BenchmarkExperiment( accelerator=accelerator, xla=xla, @@ -144,14 +150,17 @@ def load_experiment(self, torch_xla2=torch_xla2, keep_model_data_on_cuda=keep_model_data_on_cuda, test=test, - batch_size=batch_size) + batch_size=batch_size, + enable_functionalization=enable_functionalization, + ) class BenchmarkExperiment: def __init__(self, accelerator: str, xla: Optional[str], xla_flags: Optional[str], dynamo: str, torch_xla2: bool, - keep_model_data_on_cuda: bool, test: str, batch_size: str): + keep_model_data_on_cuda: bool, test: str, batch_size: str, + enable_functionalization: bool): self.accelerator = accelerator self.xla = xla self.xla_flags = xla_flags @@ -161,6 +170,7 @@ def __init__(self, accelerator: str, xla: Optional[str], self.test = test self.batch_size = batch_size self.accelerator_model = get_accelerator_model(self.accelerator) + 
self.enable_functionalization = enable_functionalization def update_process_env(self, process_env: Dict[str, str]): @@ -192,6 +202,9 @@ def update_process_env(self, process_env: Dict[str, str]): if self.xla_flags: process_env["XLA_FLAGS"] = self.xla_flags + if not self.enable_functionalization: + process_env["XLA_DISABLE_FUNCTIONALIZATION"] = "1" + def get_device(self): if self.torch_xla2: # Initiate the model in CPU first for xla2. We will move the model to jax device later. @@ -236,4 +249,5 @@ def to_dict(self): d["keep_model_data_on_cuda"] = self.keep_model_data_on_cuda d["test"] = self.test d["batch_size"] = self.batch_size + d["enable_functionalization"] = self.enable_functionalization return d diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py index 815b00ec8bc..42f5e248293 100644 --- a/benchmarks/experiment_runner.py +++ b/benchmarks/experiment_runner.py @@ -832,6 +832,11 @@ def __str__(self): help="""ID of the benchmark suite partition to be run. Used to divide CI tasks""", ) + parser.add_argument( + "--enable-functionalization", + action="store_true", + help="Enable the functionalization layer (disabled by default)", + ) parser.add_argument( "--dry-run", action="store_true",