feature(whl): add SIL policy by kxzxvbk · Pull Request #675 · opendilab/DI-engine

Open · wants to merge 27 commits into base: main
Commit: add test file
whl committed Jun 6, 2023
commit 25b94143f074ea23c1b8df31ea896b88824a9298
26 changes: 26 additions & 0 deletions ding/rl_utils/tests/test_sil.py
@@ -0,0 +1,26 @@
import pytest
import torch
from ding.rl_utils import sil_data, sil_error

random_weight = torch.rand(4) + 1
weight_args = [None, random_weight]


@pytest.mark.unittest
@pytest.mark.parametrize('weight', weight_args)
def test_sil(weight):
    B, N = 4, 32
    logit = torch.randn(B, N).requires_grad_(True)
    action = torch.randint(0, N, size=(B, ))
    value = torch.randn(B).requires_grad_(True)
    adv = torch.rand(B)
    return_ = torch.randn(B) * 2
    data = sil_data(logit, action, value, adv, return_, weight)
    loss = sil_error(data)
    # Each loss term should be a scalar; no gradients before backward().
    assert all([l.shape == tuple() for l in loss])
    assert logit.grad is None
    assert value.grad is None
    total_loss = sum(loss)
    total_loss.backward()
    assert isinstance(logit.grad, torch.Tensor)
    assert isinstance(value.grad, torch.Tensor)
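For context, the `sil_error` under test presumably computes the self-imitation learning (SIL) losses of Oh et al. (2018): a policy loss weighted by the positively clipped advantage `(R - V)+` and a value loss that regresses `V` toward `R` only on those same transitions. The sketch below is illustrative only; the helper name `sil_losses` and its exact signature are assumptions, not DI-engine's actual API.

```python
import torch
import torch.nn.functional as F


def sil_losses(logit, action, value, return_, weight=None):
    """Hypothetical sketch of the SIL loss terms (Oh et al., 2018).

    Only transitions whose return exceeded the current value estimate
    contribute, via the clipped advantage (R - V)+.
    """
    # Clipped advantage, detached so the policy term does not
    # backpropagate into the value head.
    clipped_adv = (return_ - value).clamp(min=0).detach()
    # Log-probability of the taken action under the current policy.
    log_prob = F.log_softmax(logit, dim=-1) \
        .gather(1, action.unsqueeze(1)).squeeze(1)
    if weight is None:
        weight = torch.ones_like(log_prob)
    # Policy loss: imitate past good actions.
    policy_loss = -(log_prob * clipped_adv * weight).mean()
    # Value loss: 0.5 * (R - V)+^2, gradient flows through V.
    value_loss = 0.5 * ((return_ - value).clamp(min=0).pow(2) * weight).mean()
    return policy_loss, value_loss
```

Both terms are scalars, so a test like the one above can sum them and call `backward()` to check that gradients reach both the logits and the value head.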