Skip to content

Commit

Permalink
add gail cost tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BY571 committed Jul 11, 2024
1 parent 956567f commit 8e7713f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 58 deletions.
85 changes: 49 additions & 36 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -10438,6 +10438,8 @@ def _create_mock_data_gail(self, batch=2, obs_dim=3, action_dim=4, device="cpu")
source={
"observation": obs,
"action": action,
"collector_action": action,
"collector_observation": obs,
},
device=device,
)
Expand All @@ -10455,6 +10457,8 @@ def _create_seq_mock_data_gail(
source={
"observation": obs,
"action": action,
"collector_action": action,
"collector_observation": obs,
},
device=device,
)
Expand All @@ -10478,23 +10482,26 @@ def test_gail_tensordict_keys(self):
)

@pytest.mark.parametrize("device", get_default_devices())
def test_gail_notensordict(self, device):
@pytest.mark.parametrize("use_grad_penalty", [True, False])
@pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
def test_gail_notensordict(self, device, use_grad_penalty, gp_lambda):
torch.manual_seed(self.seed)
discriminator = self._create_mock_discriminator(device=device)
loss_fn = DTLoss(discriminator)

expert_td = self._create_mock_data_gail(device=device)
collector_td = self._create_mock_data_gail(device=device)
expert_td.set(
loss_fn.tensor_keys.collector_observation, collector_td["observation"]
loss_fn = GAILLoss(
discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda
)
expert_td.set(loss_fn.tensor_keys.collector_action, collector_td["action"])

tensordict = self._create_mock_data_gail(device=device)

in_keys = self._flatten_in_keys(loss_fn.in_keys)
kwargs = dict(expert_td.flatten_keys("_").select(*in_keys))
kwargs = dict(tensordict.flatten_keys("_").select(*in_keys))

loss_val_td = loss_fn(tensordict)
if use_grad_penalty:
loss_val, _ = loss_fn(**kwargs)
else:
loss_val = loss_fn(**kwargs)

loss_val_td = loss_fn(expert_td)
loss_val = loss_fn(**kwargs)
torch.testing.assert_close(loss_val_td.get("loss"), loss_val)
# test select
loss_fn.select_out_keys("loss")
Expand All @@ -10510,25 +10517,27 @@ def test_gail_notensordict(self, device):
assert loss_discriminator == loss_val_td["loss"]

@pytest.mark.parametrize("device", get_available_devices())
def test_dt(self, device):
@pytest.mark.parametrize("use_grad_penalty", [True, False])
@pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
def test_gail(self, device, use_grad_penalty, gp_lambda):
torch.manual_seed(self.seed)
td = self._create_mock_data_dt(device=device)
td = self._create_mock_data_gail(device=device)

actor = self._create_mock_actor(device=device)
discriminator = self._create_mock_discriminator(device=device)

loss_fn = DTLoss(actor)
loss_fn = GAILLoss(
discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda
)
loss = loss_fn(td)
loss_transformer = loss["loss"]
loss_transformer.backward(retain_graph=True)
named_parameters = loss_fn.named_parameters()

for name, p in named_parameters:
if p.grad is not None and p.grad.norm() > 0.0:
assert "actor" in name
assert "alpha" not in name
assert "discriminator" in name
if p.grad is None:
assert "actor" not in name
assert "alpha" in name
assert "discriminator" not in name
loss_fn.zero_grad()

sum([loss_transformer]).backward()
Expand All @@ -10542,36 +10551,38 @@ def test_dt(self, device):
assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"

@pytest.mark.parametrize("device", get_available_devices())
def test_dt_state_dict(self, device):
def test_gail_state_dict(self, device):
torch.manual_seed(self.seed)

actor = self._create_mock_actor(device=device)
discriminator = self._create_mock_discriminator(device=device)

loss_fn = DTLoss(actor)
loss_fn = GAILLoss(discriminator)
sd = loss_fn.state_dict()
loss_fn2 = DTLoss(actor)
loss_fn2 = GAILLoss(discriminator)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("device", get_available_devices())
def test_seq_dt(self, device):
@pytest.mark.parametrize("use_grad_penalty", [True, False])
@pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
def test_seq_gail(self, device, use_grad_penalty, gp_lambda):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_dt(device=device)
td = self._create_seq_mock_data_gail(device=device)

actor = self._create_mock_actor(device=device)
discriminator = self._create_mock_discriminator(device=device)

loss_fn = DTLoss(actor)
loss_fn = GAILLoss(
discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda
)
loss = loss_fn(td)
loss_transformer = loss["loss"]
loss_transformer.backward(retain_graph=True)
named_parameters = loss_fn.named_parameters()

for name, p in named_parameters:
if p.grad is not None and p.grad.norm() > 0.0:
assert "actor" in name
assert "alpha" not in name
assert "discriminator" in name
if p.grad is None:
assert "actor" not in name
assert "alpha" in name
assert "discriminator" not in name
loss_fn.zero_grad()

sum([loss_transformer]).backward()
Expand All @@ -10585,19 +10596,21 @@ def test_seq_dt(self, device):
assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_dt_reduction(self, reduction):
@pytest.mark.parametrize("use_grad_penalty", [True, False])
@pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
def test_gail_reduction(self, reduction, use_grad_penalty, gp_lambda):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_mock_data_dt(device=device)
actor = self._create_mock_actor(device=device)
loss_fn = DTLoss(actor, reduction=reduction)
td = self._create_mock_data_gail(device=device)
discriminator = self._create_mock_discriminator(device=device)
loss_fn = GAILLoss(discriminator, reduction=reduction)
loss = loss_fn(td)
if reduction == "none":
assert loss["loss"].shape == td["action"].shape
assert loss["loss"].shape == (td["observation"].shape[0], 1)
else:
assert loss["loss"].shape == torch.Size([])

Expand Down
65 changes: 43 additions & 22 deletions torchrl/objectives/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ class _AcceptedKeys:

discriminator_network: TensorDictModule
discriminator_network_params: TensorDictParams
target_discriminator_network: TensorDictModule
target_discriminator_network_params: TensorDictParams

out_keys = [
"loss",
"gp_loss",
]

def __init__(
self,
Expand All @@ -84,7 +91,7 @@ def __init__(
"discriminator_network",
create_target_params=False,
)
self.loss_function = torch.nn.BCELoss()
self.loss_function = torch.nn.BCELoss(reduction="none")
self.use_grad_penalty = use_grad_penalty
self.gp_lambda = gp_lambda

Expand All @@ -95,6 +102,8 @@ def _set_in_keys(self):
keys = set(keys)
keys.add(self.tensor_keys.expert_observation)
keys.add(self.tensor_keys.expert_action)
keys.add(self.tensor_keys.collector_observation)
keys.add(self.tensor_keys.collector_action)
self._in_keys = sorted(keys, key=str)

def _forward_value_estimator_keys(self, **kwargs) -> None:
Expand All @@ -114,6 +123,8 @@ def in_keys(self, values):
def out_keys(self):
if self._out_keys is None:
keys = ["loss"]
if self.use_grad_penalty:
keys.append("gp_loss")
self._out_keys = keys
return self._out_keys

Expand All @@ -126,9 +137,19 @@ def forward(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
"""Compute the GAIL discriminator loss."""
"""The forward method.
Computes the discriminator loss and, if `use_grad_penalty` is set to True, adds the gradient penalty to it. In that case, the detached gradient penalty loss is also returned for logging purposes.
To see what keys are expected in the input tensordict and what keys are expected as output, check the
class's `"in_keys"` and `"out_keys"` attributes.
"""
device = self.discriminator_network.device
tensordict = tensordict.clone(False)
batch_size = tensordict.batch_size[0]
shape = tensordict.shape
if len(shape) > 1:
batch_size, seq_len = shape
else:
batch_size = shape[0]
collector_obs = tensordict.get(self.tensor_keys.collector_observation)
collector_act = tensordict.get(self.tensor_keys.collector_action)

Expand All @@ -144,15 +165,20 @@ def forward(
self.tensor_keys.expert_action: combined_act_inputs,
},
batch_size=[2 * batch_size],
device=device,
)

# create labels
fake_labels = torch.zeros((batch_size, 1), dtype=torch.float32).to(
tensordict.device
)
real_labels = torch.ones((batch_size, 1), dtype=torch.float32).to(
tensordict.device
)
# create labels
if len(shape) > 1:
fake_labels = torch.zeros((batch_size, seq_len, 1), dtype=torch.float32).to(
device
)
real_labels = torch.ones((batch_size, seq_len, 1), dtype=torch.float32).to(
device
)
else:
fake_labels = torch.zeros((batch_size, 1), dtype=torch.float32).to(device)
real_labels = torch.ones((batch_size, 1), dtype=torch.float32).to(device)

with self.discriminator_network_params.to_module(self.discriminator_network):
d_logits = self.discriminator_network(combined_inputs).get(
Expand All @@ -167,22 +193,18 @@ def forward(
collection_loss = self.loss_function(collection_preds, fake_labels)

loss = expert_loss + collection_loss
out = {"loss": loss}
out = {}
if self.use_grad_penalty:
obs = tensordict.get(self.tensor_keys.collector_observation)
acts = tensordict.get(self.tensor_keys.collector_action)
obs_e = tensordict.get(self.tensor_keys.expert_observation)
acts_e = tensordict.get(self.tensor_keys.expert_action)

obss_noise = (
torch.distributions.Uniform(0.0, 1.0)
.sample(obs_e.shape)
.to(tensordict.device)
torch.distributions.Uniform(0.0, 1.0).sample(obs_e.shape).to(device)
)
acts_noise = (
torch.distributions.Uniform(0.0, 1.0)
.sample(acts_e.shape)
.to(tensordict.device)
torch.distributions.Uniform(0.0, 1.0).sample(acts_e.shape).to(device)
)
obss_mixture = obss_noise * obs + (1 - obss_noise) * obs_e
acts_mixture = acts_noise * acts + (1 - acts_noise) * acts_e
Expand All @@ -195,6 +217,7 @@ def forward(
self.tensor_keys.expert_action: acts_mixture,
},
[],
device=device,
)

with self.discriminator_network_params.to_module(
Expand All @@ -208,9 +231,7 @@ def forward(
autograd.grad(
outputs=d_logits_mixture,
inputs=(obss_mixture, acts_mixture),
grad_outputs=torch.ones(
d_logits_mixture.size(), device=tensordict.device
),
grad_outputs=torch.ones(d_logits_mixture.size(), device=device),
create_graph=True,
retain_graph=True,
only_inputs=True,
Expand All @@ -223,8 +244,8 @@ def forward(
)

loss += gp_loss
out["gp_loss"] = gp_loss
out["gp_loss"] = gp_loss.detach()
loss = _reduce(loss, reduction=self.reduction)

out["loss"] = loss
td_out = TensorDict(out, [])
return td_out

0 comments on commit 8e7713f

Please sign in to comment.