add annotations and clean up
chrisyrniu committed Apr 19, 2021
1 parent 676ebe4 commit 0e0fe26
Showing 9 changed files with 154 additions and 99 deletions.
160 changes: 106 additions & 54 deletions magic.py
@@ -24,42 +24,46 @@ def __init__(self, args):

         dropout = 0
         negative_slope = 0.2

+        # initialize sub-processors
         self.sub_processor1 = GraphAttention(args.hid_size, args.gat_hid_size, dropout=dropout, negative_slope=negative_slope, num_heads=args.gat_num_heads, self_loop_type=args.self_loop_type1, average=False, normalize=args.first_gat_normalize)
         self.sub_processor2 = GraphAttention(args.gat_hid_size*args.gat_num_heads, args.hid_size, dropout=dropout, negative_slope=negative_slope, num_heads=args.gat_num_heads_out, self_loop_type=args.self_loop_type2, average=True, normalize=args.second_gat_normalize)
-        if args.use_gconv_encoder:
-            self.gconv_encoder = GraphAttention(args.hid_size, args.gconv_encoder_out_size, dropout=dropout, negative_slope=negative_slope, num_heads=args.ge_num_heads, self_loop_type=1, average=True, normalize=args.gconv_gat_normalize)
+        # initialize the gat encoder for the Scheduler
+        if args.use_gat_encoder:
+            self.gat_encoder = GraphAttention(args.hid_size, args.gat_encoder_out_size, dropout=dropout, negative_slope=negative_slope, num_heads=args.ge_num_heads, self_loop_type=1, average=True, normalize=args.gat_encoder_normalize)

-        self.encoder = nn.Linear(args.obs_size, args.hid_size)
+        self.obs_encoder = nn.Linear(args.obs_size, args.hid_size)

        self.init_hidden(args.batch_size)
        self.lstm_cell= nn.LSTMCell(args.hid_size, args.hid_size)

+        # initialize mlp layers for the sub-schedulers
         if not args.first_graph_complete:
-            if args.use_gconv_encoder:
-                self.hard_attn1 = nn.Sequential(
-                    nn.Linear(args.gconv_encoder_out_size*2, args.gconv_encoder_out_size//2),
+            if args.use_gat_encoder:
+                self.sub_scheduler_mlp1 = nn.Sequential(
+                    nn.Linear(args.gat_encoder_out_size*2, args.gat_encoder_out_size//2),
                     nn.ReLU(),
-                    nn.Linear(args.gconv_encoder_out_size//2, args.gconv_encoder_out_size//2),
+                    nn.Linear(args.gat_encoder_out_size//2, args.gat_encoder_out_size//2),
                     nn.ReLU(),
-                    nn.Linear(args.gconv_encoder_out_size//2, 2))
+                    nn.Linear(args.gat_encoder_out_size//2, 2))
             else:
-                self.hard_attn1 = nn.Sequential(
+                self.sub_scheduler_mlp1 = nn.Sequential(
                     nn.Linear(self.hid_size*2, self.hid_size//2),
                     nn.ReLU(),
                     nn.Linear(self.hid_size//2, self.hid_size//8),
                     nn.ReLU(),
                     nn.Linear(self.hid_size//8, 2))

         if args.learn_second_graph and not args.second_graph_complete:
-            if args.use_gconv_encoder:
-                self.hard_attn2 = nn.Sequential(
-                    nn.Linear(args.gconv_encoder_out_size*2, args.gconv_encoder_out_size//2),
+            if args.use_gat_encoder:
+                self.sub_scheduler_mlp2 = nn.Sequential(
+                    nn.Linear(args.gat_encoder_out_size*2, args.gat_encoder_out_size//2),
                     nn.ReLU(),
-                    nn.Linear(args.gconv_encoder_out_size//2, args.gconv_encoder_out_size//2),
+                    nn.Linear(args.gat_encoder_out_size//2, args.gat_encoder_out_size//2),
                     nn.ReLU(),
-                    nn.Linear(args.gconv_encoder_out_size//2, 2))
+                    nn.Linear(args.gat_encoder_out_size//2, 2))
             else:
-                self.hard_attn2 = nn.Sequential(
+                self.sub_scheduler_mlp2 = nn.Sequential(
                     nn.Linear(self.hid_size*2, self.hid_size//2),
                     nn.ReLU(),
                     nn.Linear(self.hid_size//2, self.hid_size//8),
@@ -71,70 +75,100 @@ def __init__(self, args):
         if args.message_decoder:
             self.message_decoder = nn.Linear(args.hid_size, args.hid_size)

-        # initialise weights as 0
-        if args.comm_init == 'zeros':
+        # initialize weights as 0
+        if args.learn_second_graph and args.comm_init == 'zeros':
             if args.message_encoder:
                 self.message_encoder.weight.data.zero_()
             if args.message_decoder:
                 self.message_decoder.weight.data.zero_()
             if not args.first_graph_complete:
-                self.hard_attn1.apply(self.init_linear)
+                self.sub_scheduler_mlp1.apply(self.init_linear)
             if not args.second_graph_complete:
-                self.hard_attn2.apply(self.init_linear)
+                self.sub_scheduler_mlp2.apply(self.init_linear)

+        # initialize the action head (in practice, one action head is used)
         self.action_heads = nn.ModuleList([nn.Linear(2*args.hid_size, o)
                                            for o in args.naction_heads])

+        # initialize the value head
         self.value_head = nn.Linear(2 * self.hid_size, 1)

         self.tanh = nn.Tanh()


     def forward(self, x, info={}):
-        x, hidden_state, cell_state = self.forward_state_encoder(x)
+        """
+        Forward function of MAGIC (two rounds of communication)
+        Arguments:
+            x (list): a list for the input of the communication protocol [observations, (previous hidden states, previous cell states)]
+                observations (tensor): the observations for all agents [1 (batch_size) * n * obs_size]
+                previous hidden/cell states (tensor): the hidden/cell states from the previous time steps [n * hid_size]
+        Returns:
+            action_out (list): a list of tensors of size [1 (batch_size) * n * num_actions] that represent output policy distributions
+            value_head (tensor): estimated values [n * 1]
+            next hidden/cell states (tensor): next hidden/cell states [n * hid_size]
+        """

+        # n: number of agents

+        obs, extras = x

+        # encoded_obs: [1 (batch_size) * n * hid_size]
+        encoded_obs = self.obs_encoder(obs)
+        hidden_state, cell_state = extras

-        batch_size = x.size()[0]
+        batch_size = encoded_obs.size()[0]
         n = self.nagents

         num_agents_alive, agent_mask = self.get_agent_mask(batch_size, info)

-        hidden_state, cell_state = self.lstm_cell(x.squeeze(), (hidden_state, cell_state))
+        # if self.args.comm_mask_zero == True, block the communication (can also comment out the protocol to make training faster)
+        if self.args.comm_mask_zero:
+            agent_mask *= torch.zeros(n, 1)

+        hidden_state, cell_state = self.lstm_cell(encoded_obs.squeeze(), (hidden_state, cell_state))

+        # comm: [n * hid_size]
         comm = hidden_state
         if self.args.message_encoder:
             comm = self.message_encoder(comm)

-        # Mask communcation from dead agents (only effective in Traffic Junction)
+        # mask communication from dead agents (only effective in Traffic Junction)
         comm = comm * agent_mask
         comm_ori = comm.clone()

+        # sub-scheduler 1
+        # if args.first_graph_complete == True, sub-scheduler 1 will be disabled
         if not self.args.first_graph_complete:
-            if self.args.use_gconv_encoder:
+            if self.args.use_gat_encoder:
                 adj_complete = self.get_complete_graph(agent_mask)
-                encoded_state1 = self.gconv_encoder(comm, adj_complete)
-                adj1 = self.get_adj_matrix(self.hard_attn1, encoded_state1, agent_mask, self.args.directed)
+                encoded_state1 = self.gat_encoder(comm, adj_complete)
+                adj1 = self.sub_scheduler(self.sub_scheduler_mlp1, encoded_state1, agent_mask, self.args.directed)
             else:
-                adj1 = self.get_adj_matrix(self.hard_attn1, comm, agent_mask, self.args.directed)
+                adj1 = self.sub_scheduler(self.sub_scheduler_mlp1, comm, agent_mask, self.args.directed)
         else:
             adj1 = self.get_complete_graph(agent_mask)


+        # sub-processor 1
         comm = F.elu(self.sub_processor1(comm, adj1))

+        # sub-scheduler 2
         if self.args.learn_second_graph and not self.args.second_graph_complete:
-            if self.args.use_gconv_encoder:
+            if self.args.use_gat_encoder:
                 encoded_state2 = encoded_state1
-                adj2 = self.get_adj_matrix(self.hard_attn2, encoded_state2, agent_mask, self.args.directed)
+                adj2 = self.sub_scheduler(self.sub_scheduler_mlp2, encoded_state2, agent_mask, self.args.directed)
             else:
-                adj2 = self.get_adj_matrix(self.hard_attn2, comm_ori, agent_mask, self.args.directed)
+                adj2 = self.sub_scheduler(self.sub_scheduler_mlp2, comm_ori, agent_mask, self.args.directed)
         elif not self.args.learn_second_graph and not self.args.second_graph_complete:
             adj2 = adj1
         else:
             adj2 = self.get_complete_graph(agent_mask)

+        # sub-processor 2
         comm = self.sub_processor2(comm, adj2)

-        # Mask communication to dead agents (only effective in Traffic Junction)
+        # mask communication to dead agents (only effective in Traffic Junction)
         comm = comm * agent_mask

         if self.args.message_decoder:
@@ -149,6 +183,14 @@ def forward(self, x, info={}):
         return action_out, value_head, (hidden_state.clone(), cell_state.clone())
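For orientation, here is a minimal, self-contained sketch of the two rounds of communication the annotated forward pass performs. The sizes are made up, and plain linear layers stand in for the GraphAttention sub-processors; it is an illustration, not the repository's code:

    # toy sketch of the two-round protocol (hypothetical sizes, stand-in modules)
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    n, obs_size, hid_size = 5, 10, 32
    obs_encoder = nn.Linear(obs_size, hid_size)
    lstm_cell = nn.LSTMCell(hid_size, hid_size)
    sub_processor1 = nn.Linear(hid_size, hid_size)   # stand-in for the round-1 GAT
    sub_processor2 = nn.Linear(hid_size, hid_size)   # stand-in for the round-2 GAT

    obs = torch.randn(1, n, obs_size)                # [1 (batch_size), n, obs_size]
    hidden, cell = torch.zeros(n, hid_size), torch.zeros(n, hid_size)

    hidden, cell = lstm_cell(obs_encoder(obs).squeeze(0), (hidden, cell))
    comm = hidden                                    # messages, [n, hid_size]
    adj1 = torch.ones(n, n)                          # round-1 graph (complete here)
    comm = F.elu(sub_processor1(adj1 @ comm))        # crude neighbor aggregation
    adj2 = adj1                                      # round-2 graph (reused here)
    comm = sub_processor2(adj2 @ comm)               # comm then feeds the action/value heads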

     def get_agent_mask(self, batch_size, info):
+        """
+        Function to generate agent mask to mask out inactive agents (only effective in Traffic Junction)
+        Returns:
+            num_agents_alive (int): number of active agents
+            agent_mask (tensor): [n, 1]
+        """

         n = self.nagents

         if 'alive_mask' in info:
@@ -162,50 +204,60 @@ def get_agent_mask(self, batch_size, info):

         return num_agents_alive, agent_mask
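The new docstring pins down agent_mask as an [n, 1] tensor of 0/1 flags, so one broadcast multiply silences a dead agent's messages. A toy illustration with invented values:

    import torch

    comm = torch.randn(4, 8)                             # messages from n=4 agents
    agent_mask = torch.tensor([[1.], [0.], [1.], [1.]])  # agent 1 is dead
    comm = comm * agent_mask                             # row 1 zeroed by broadcasting
    assert torch.all(comm[1] == 0)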

-    def forward_state_encoder(self, x):
-        hidden_state, cell_state = None, None
-
-        x, extras = x
-        x = self.encoder(x)
-
-        hidden_state, cell_state = extras
-
-        return x, hidden_state, cell_state


     def init_linear(self, m):
+        """
+        Function to initialize the parameters in nn.Linear as 0
+        """
         if type(m) == nn.Linear:
             m.weight.data.fill_(0.)
             m.bias.data.fill_(0.)

     def init_hidden(self, batch_size):
-        # dim 0 = num of layers * num of direction
+        """
+        Function to initialize the hidden states and cell states
+        """
         return tuple(( torch.zeros(batch_size * self.nagents, self.hid_size, requires_grad=True),
                        torch.zeros(batch_size * self.nagents, self.hid_size, requires_grad=True)))


-    def get_adj_matrix(self, hard_attn_model, hidden_state, agent_mask, directed=True):
-        # hidden_state size: n * hid_size
+    def sub_scheduler(self, sub_scheduler_mlp, hidden_state, agent_mask, directed=True):
+        """
+        Function to perform a sub-scheduler
+        Arguments:
+            sub_scheduler_mlp (nn.Sequential): the MLP layers in a sub-scheduler
+            hidden_state (tensor): the encoded messages input to the sub-scheduler [n * hid_size]
+            agent_mask (tensor): [n * 1]
+            directed (bool): whether to generate directed graphs
+        Return:
+            adj (tensor): an adjacency matrix which is the communication graph [n * n]
+        """

+        # hidden_state: [n * hid_size]
         n = self.args.nagents
         hid_size = hidden_state.size(-1)
-        # hard_attn_input size: n * n * (2*hid_size)
+        # hard_attn_input: [n * n * (2*hid_size)]
         hard_attn_input = torch.cat([hidden_state.repeat(1, n).view(n * n, -1), hidden_state.repeat(n, 1)], dim=1).view(n, -1, 2 * hid_size)
-        # hard_attn_output size: n * n * 2
+        # hard_attn_output: [n * n * 2]
         if directed:
-            hard_attn_output = F.gumbel_softmax(hard_attn_model(hard_attn_input), hard=True)
+            hard_attn_output = F.gumbel_softmax(sub_scheduler_mlp(hard_attn_input), hard=True)
         else:
-            hard_attn_output = F.gumbel_softmax(0.5*hard_attn_model(hard_attn_input)+0.5*hard_attn_model(hard_attn_input.permute(1,0,2)), hard=True)
-        # hard_attn_output size: n * n * 1
+            hard_attn_output = F.gumbel_softmax(0.5*sub_scheduler_mlp(hard_attn_input)+0.5*sub_scheduler_mlp(hard_attn_input.permute(1,0,2)), hard=True)
+        # hard_attn_output: [n * n * 1]
         hard_attn_output = torch.narrow(hard_attn_output, 2, 1, 1)
-        # agent_mask and its transpose size: n * n
+        # agent_mask and agent_mask_transpose: [n * n]
         agent_mask = agent_mask.expand(n, n)
         agent_mask_transpose = agent_mask.transpose(0, 1)
-        # adj size: n * n
+        # adj: [n * n]
         adj = hard_attn_output.squeeze() * agent_mask * agent_mask_transpose

         return adj
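The renamed sub_scheduler is the heart of this diff: every ordered pair of agent states is concatenated, the MLP scores each pair with two logits, and a straight-through Gumbel-softmax (hard=True) samples a binary edge while keeping gradients; torch.narrow(..., 2, 1, 1) then keeps the "edge on" channel. A standalone sketch with toy sizes and an illustrative MLP (not the repository's exact module):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    n, hid_size = 4, 16                                  # toy sizes
    mlp = nn.Sequential(nn.Linear(2 * hid_size, 8), nn.ReLU(), nn.Linear(8, 2))
    hidden_state = torch.randn(n, hid_size)
    agent_mask = torch.ones(n, 1)

    # all ordered pairs (i, j): concat state_i with state_j -> [n, n, 2*hid_size]
    pair_input = torch.cat(
        [hidden_state.repeat(1, n).view(n * n, -1), hidden_state.repeat(n, 1)],
        dim=1).view(n, n, 2 * hid_size)

    logits = mlp(pair_input)                             # [n, n, 2]
    # an undirected variant would symmetrize: 0.5*mlp(x) + 0.5*mlp(x.permute(1, 0, 2))
    edges = F.gumbel_softmax(logits, hard=True)          # one-hot over {off, on}, differentiable
    adj = torch.narrow(edges, 2, 1, 1).squeeze(-1)       # keep the 'on' channel -> [n, n]
    adj = adj * agent_mask.expand(n, n) * agent_mask.expand(n, n).t()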

     def get_complete_graph(self, agent_mask):
+        """
+        Function to generate a complete graph, and mask it with agent_mask
+        """
         n = self.args.nagents
         adj = torch.ones(n, n)
         agent_mask = agent_mask.expand(n, n)
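The tail of get_complete_graph is truncated in this view; judging from the visible lines and the new docstring, the remaining step is presumably the same row/column masking used in sub_scheduler. A hedged reconstruction under that assumption:

    import torch

    def complete_graph(n, agent_mask):
        # assumed remainder of get_complete_graph: zero dead agents' rows/columns
        adj = torch.ones(n, n)
        mask = agent_mask.expand(n, n)
        return adj * mask * mask.transpose(0, 1)

    print(complete_graph(3, torch.tensor([[1.], [1.], [0.]])))
    # rows and columns for the dead agent come out all zeros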
27 changes: 15 additions & 12 deletions main.py
@@ -25,6 +25,7 @@

 parser = argparse.ArgumentParser(description='Multi-Agent Graph Attention Communication')

+# training
 parser.add_argument('--num_epochs', default=100, type=int,
                     help='number of training epochs')
 parser.add_argument('--epoch_size', type=int, default=10,
@@ -51,15 +52,15 @@
 parser.add_argument('--ge_num_heads', default=4, type=int,
                     help='number of heads in the gat encoder')
 parser.add_argument('--first_gat_normalize', action='store_true', default=False,
-                    help='if normalize first gat layer')
+                    help='if normalize the coefficients in the first gat layer of the message processor')
 parser.add_argument('--second_gat_normalize', action='store_true', default=False,
-                    help='if normilize second gat layer')
-parser.add_argument('--gconv_gat_normalize', action='store_true', default=False,
-                    help='if normilize gconv gat layer')
-parser.add_argument('--use_gconv_encoder', action='store_true', default=False,
-                    help='if use gconv encoder before learning the first graph')
-parser.add_argument('--gconv_encoder_out_size', default=64, type=int,
-                    help='hidden size of output of the gconv encoder')
+                    help='if normalize the coefficients in the second gat layer of the message processor')
+parser.add_argument('--gat_encoder_normalize', action='store_true', default=False,
+                    help='if normalize the coefficients in the gat encoder (they have been normalized if the input graph is complete)')
+parser.add_argument('--use_gat_encoder', action='store_true', default=False,
+                    help='if use gat encoder before learning the first graph')
+parser.add_argument('--gat_encoder_out_size', default=64, type=int,
+                    help='hidden size of output of the gat encoder')
 parser.add_argument('--first_graph_complete', action='store_true', default=False,
                     help='if the first graph is set to a complete graph')
 parser.add_argument('--second_graph_complete', action='store_true', default=False,
@@ -75,20 +76,22 @@
 parser.add_argument('--mean_ratio', default=0, type=float,
                     help='how much coooperative to do? 1.0 means fully cooperative')
 parser.add_argument('--detach_gap', default=10000, type=int,
-                    help='detach hidden state and cell state for rnns at this interval.'
-                    + ' Default 10000 (very high)')
+                    help='detach hidden state and cell state for rnns at this interval')
 parser.add_argument('--comm_init', default='uniform', type=str,
                     help='how to initialise comm weights [uniform|zeros]')
 parser.add_argument('--advantages_per_action', default=False, action='store_true',
-                    help='Whether to multipy log porb for each chosen action with advantages')
+                    help='whether to multiply log prob for each chosen action with advantages')
+parser.add_argument('--comm_mask_zero', action='store_true', default=False,
+                    help="whether to block the communication")


+# optimization
 parser.add_argument('--gamma', type=float, default=1.0,
                     help='discount factor')
 parser.add_argument('--tau', type=float, default=1.0,
                     help='gae (remove?)')
 parser.add_argument('--seed', type=int, default=-1,
-                    help='random seed. Pass -1 for random seed') # TODO: works in thread?
+                    help='random seed. Pass -1 for random seed')
 parser.add_argument('--normalize_rewards', action='store_true', default=False,
                     help='normalize rewards in each batch')
 parser.add_argument('--lrate', type=float, default=0.001,
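Since several flags are renamed in this hunk (gconv_* to gat_*), a quick sketch of exercising the new names through a minimal stand-in parser (not the full argument list; values are arbitrary):

    import argparse

    # minimal stand-in mirroring the renamed flags above
    parser = argparse.ArgumentParser()
    parser.add_argument('--use_gat_encoder', action='store_true', default=False)
    parser.add_argument('--gat_encoder_out_size', default=64, type=int)
    parser.add_argument('--gat_encoder_normalize', action='store_true', default=False)
    parser.add_argument('--comm_mask_zero', action='store_true', default=False)

    args = parser.parse_args(['--use_gat_encoder', '--gat_encoder_out_size', '32'])
    print(args.use_gat_encoder, args.gat_encoder_out_size)  # True 32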
1 change: 1 addition & 0 deletions plot.py
@@ -140,6 +140,7 @@ def parse_plot(files, term='Reward'):
 # plt.title('GFootball {} {}'.format(sys.argv[2], term))

 files = glob.glob(sys.argv[1] + "*")
+# filter out files with ".pt"
 files = list(filter(lambda x: x.find(".pt") == -1, files))

 # 'Epoch'/ 'Steps-taken'