Commit bfa7ea1

Default to functionalization disabled.
ysiraichi committed Sep 28, 2024
1 parent ecc0f5a commit bfa7ea1
Showing 2 changed files with 21 additions and 2 deletions.
18 changes: 16 additions & 2 deletions benchmarks/benchmark_experiment.py
@@ -27,6 +27,7 @@ def list_experiment_configs(self):
         "torch_xla2": [None],  # options only apply to torch_xla2
         "test": ["eval", "train"],
         "keep_model_data_on_cuda": [False],
+        "enable_functionalization": [False],
     }
 
     # Apply command line choices.
@@ -49,6 +50,10 @@ def list_experiment_configs(self):
       config_choices["keep_model_data_on_cuda"] = [
           self._args.keep_model_data_on_cuda
       ]
+    if self._args.enable_functionalization:
+      config_choices["enable_functionalization"] = [
+          self._args.enable_functionalization
+      ]
 
     # Expand experiment configs and add env vars.
     logger.debug(f"Expand experiment configs")
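
For context, the runner expands these per-key option lists into one experiment config per combination (the "Expand experiment configs" step above). A minimal sketch of that expansion, assuming a plain dict of option lists; expand_config_choices is a hypothetical name, not the repository's helper:

import itertools

def expand_config_choices(config_choices):
  # Yield one config dict per element of the Cartesian product of the lists.
  keys = list(config_choices)
  for values in itertools.product(*(config_choices[k] for k in keys)):
    yield dict(zip(keys, values))

Under the new default, "enable_functionalization" is [False], so every generated config carries enable_functionalization=False; passing --enable-functionalization swaps the list to [True].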
@@ -136,6 +141,7 @@ def load_experiment(self,
     batch_size = experiment_config.get("batch_size", self._args.batch_size)
     torch_xla2 = experiment_config["torch_xla2"]
     keep_model_data_on_cuda = experiment_config["keep_model_data_on_cuda"]
+    enable_functionalization = experiment_config["enable_functionalization"]
     return BenchmarkExperiment(
         accelerator=accelerator,
         xla=xla,
@@ -144,14 +150,17 @@ def load_experiment(self,
         torch_xla2=torch_xla2,
         keep_model_data_on_cuda=keep_model_data_on_cuda,
         test=test,
-        batch_size=batch_size)
+        batch_size=batch_size,
+        enable_functionalization=enable_functionalization,
+    )
 
 
 class BenchmarkExperiment:
 
   def __init__(self, accelerator: str, xla: Optional[str],
                xla_flags: Optional[str], dynamo: str, torch_xla2: bool,
-               keep_model_data_on_cuda: bool, test: str, batch_size: str):
+               keep_model_data_on_cuda: bool, test: str, batch_size: str,
+               enable_functionalization: bool):
@@ -161,6 +170,7 @@ def __init__(self, accelerator: str, xla: Optional[str],
     self.test = test
     self.batch_size = batch_size
     self.accelerator_model = get_accelerator_model(self.accelerator)
+    self.enable_functionalization = enable_functionalization
 
   def update_process_env(self, process_env: Dict[str, str]):
 
@@ -192,6 +202,9 @@ def update_process_env(self, process_env: Dict[str, str]):
     if self.xla_flags:
       process_env["XLA_FLAGS"] = self.xla_flags
 
+    if not self.enable_functionalization:
+      process_env["XLA_DISABLE_FUNCTIONALIZATION"] = "1"
+
   def get_device(self):
     if self.torch_xla2:
       # Initiate the model in CPU first for xla2. We will move the model to jax device later.
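
The new branch inverts the experiment flag into torch_xla's environment kill switch: setting XLA_DISABLE_FUNCTIONALIZATION=1 in the child process turns the functionalization layer off, so the default (False) disables it. A standalone sketch of the effect; child_env and the subprocess call are illustrative, not part of the diff:

import os
import subprocess

def child_env(enable_functionalization: bool) -> dict:
  env = os.environ.copy()
  if not enable_functionalization:
    # Functionalization is opt-in now: absent the flag, the benchmark
    # subprocess runs with the layer disabled.
    env["XLA_DISABLE_FUNCTIONALIZATION"] = "1"
  return env

# e.g. subprocess.run(cmd, env=child_env(False), check=True)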
@@ -236,4 +249,5 @@ def to_dict(self):
     d["keep_model_data_on_cuda"] = self.keep_model_data_on_cuda
     d["test"] = self.test
     d["batch_size"] = self.batch_size
+    d["enable_functionalization"] = self.enable_functionalization
     return d
5 changes: 5 additions & 0 deletions benchmarks/experiment_runner.py
@@ -832,6 +832,11 @@ def __str__(self):
       help="""ID of the benchmark suite partition to be run. Used to divide CI
       tasks""",
   )
+  parser.add_argument(
+      "--enable-functionalization",
+      action="store_true",
+      help="Enable the functionalization layer by default",
+  )
   parser.add_argument(
       "--dry-run",
       action="store_true",
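
Because the new option uses action="store_true", args.enable_functionalization defaults to False, which is what makes functionalization disabled by default; running python benchmarks/experiment_runner.py --enable-functionalization opts back in. A quick standalone check of that argparse behavior (a bare parser, not the runner's full one):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-functionalization", action="store_true")

assert parser.parse_args([]).enable_functionalization is False
assert parser.parse_args(["--enable-functionalization"]).enable_functionalization is True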
