Skip to content

Commit

Permalink
Update export.py (#6576)
Browse files Browse the repository at this point in the history
  • Loading branch information
wang2yn84 committed Feb 21, 2024
1 parent ee341ee commit 8c65f09
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions experimental/torch_xla2/torch_xla2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,10 @@ def _extract_states_from_exported_program(exported_model):
for name in exported_model.graph_signature.lifted_tensor_constants:
param_buffer_values.append(exported_model.tensor_constants[name])

return param_buffer_values
return param_and_buffer_keys, param_buffer_values


def exported_program_to_jax(exported_program):
def exported_program_to_jax(exported_program, export_raw: bool = False):
"""returns a pytree of jax arrays(state), and
a callable(func) that is jax function.
Expand All @@ -207,7 +207,7 @@ def exported_program_to_jax(exported_program):
if DEBUG:
print(exported_program.graph_module.code)

states = _extract_states_from_exported_program(exported_program)
names, states = _extract_states_from_exported_program(exported_program)

def _extract_args(args, kwargs):
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
Expand All @@ -227,5 +227,8 @@ def func(states, inputs):
res = res[num_mutations:]
return res

if export_raw:
return names, states, func

states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
return states, func

0 comments on commit 8c65f09

Please sign in to comment.