Commit
Infer type of tensor for uniform noise
shariqiqbal2810 committed Feb 28, 2018
1 parent 7c29b11 commit b57dc2b
Showing 2 changed files with 6 additions and 10 deletions.
algorithms/maddpg.py (2 changes: 1 addition & 1 deletion)

@@ -140,7 +140,7 @@ def update(self, sample, agent_i, parallel=False, logger=None):
             # correct since it removes the assumption of a deterministic policy for
             # DDPG. Regardless, discrete policies don't seem to learn properly without it.
             curr_pol_out = curr_agent.policy(obs[agent_i])
-            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True, cuda=(self.pol_dev == 'gpu'))
+            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
         else:
             curr_pol_out = curr_agent.policy(obs[agent_i])
             curr_pol_vf_in = curr_pol_out
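For context, `hard=True` feeds the critic a one-hot action while keeping the policy differentiable via the straight-through trick (`y = (y_hard - y).detach() + y` in utils/misc.py below). A minimal sketch of that identity, using hypothetical tensors and the same Variable-era API as this repo:

```python
import torch
from torch.autograd import Variable

# Hypothetical soft sample (e.g. a Gumbel-Softmax output) with gradients enabled.
y = Variable(torch.FloatTensor([[0.2, 0.7, 0.1]]), requires_grad=True)
y_hard = Variable(torch.FloatTensor([[0., 1., 0.]]))  # argmax of y as one-hot

# Forward value equals y_hard; the detached term contributes no gradient,
# so d(out)/d(y) is the identity and gradients pass "straight through".
out = (y_hard - y).detach() + y
out.sum().backward()
print(out.data)     # [[0., 1., 0.]]
print(y.grad.data)  # [[1., 1., 1.]]
```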
utils/misc.py (14 changes: 5 additions & 9 deletions)
@@ -62,23 +62,19 @@ def onehot_from_logits(logits, eps=0.0):
                         enumerate(torch.rand(logits.shape[0]))])
 
 # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
-def sample_gumbel(shape, eps=1e-20, cuda=True):
+def sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor):
     """Sample from Gumbel(0, 1)"""
-    if cuda:
-        tens_type = torch.cuda.FloatTensor
-    else:
-        tens_type = torch.FloatTensor
     U = Variable(tens_type(*shape).uniform_(), requires_grad=False)
     return -torch.log(-torch.log(U + eps) + eps)
 
 # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
-def gumbel_softmax_sample(logits, temperature, cuda=True):
+def gumbel_softmax_sample(logits, temperature):
     """ Draw a sample from the Gumbel-Softmax distribution"""
-    y = logits + sample_gumbel(logits.shape, cuda=cuda)
+    y = logits + sample_gumbel(logits.shape, tens_type=type(logits.data))
     return F.softmax(y / temperature, dim=1)
 
 # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
-def gumbel_softmax(logits, temperature=1.0, hard=False, cuda=True):
+def gumbel_softmax(logits, temperature=1.0, hard=False):
     """Sample from the Gumbel-Softmax distribution and optionally discretize.
     Args:
       logits: [batch_size, n_class] unnormalized log-probs
@@ -89,7 +85,7 @@ def gumbel_softmax(logits, temperature=1.0, hard=False, cuda=True):
     If hard=True, then the returned sample will be one-hot, otherwise it will
     be a probability distribution that sums to 1 across classes
     """
-    y = gumbel_softmax_sample(logits, temperature, cuda=cuda)
+    y = gumbel_softmax_sample(logits, temperature)
     if hard:
         y_hard = onehot_from_logits(y)
         y = (y_hard - y).detach() + y
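The key change: instead of threading a `cuda` flag through every call, `gumbel_softmax_sample` now reads the tensor class off the logits themselves via `type(logits.data)`, so the Gumbel noise is allocated on whatever device the logits already live on. A minimal sketch of why this works under the pre-0.4 PyTorch Variable API used here (the commented-out `.cuda()` lines assume a GPU is available):

```python
import torch
from torch.autograd import Variable

logits = Variable(torch.randn(4, 3))
print(type(logits.data))  # <class 'torch.FloatTensor'> on CPU

# On a CUDA tensor the same expression yields torch.cuda.FloatTensor,
# so tens_type(*shape) allocates the noise on the GPU automatically:
# logits = logits.cuda()
# print(type(logits.data))  # <class 'torch.cuda.FloatTensor'>

tens_type = type(logits.data)
U = Variable(tens_type(*logits.shape).uniform_(), requires_grad=False)
noise = -torch.log(-torch.log(U + 1e-20) + 1e-20)  # Gumbel(0, 1) via inverse CDF
print(noise.shape)  # torch.Size([4, 3])
```

On PyTorch 0.4+ the same effect would fall out of `torch.rand_like(logits)`, but the Variable-based code here predates that API.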
