Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/honeynet/droidbot
Browse files Browse the repository at this point in the history
  • Loading branch information
connglli committed Mar 16, 2021
2 parents 1858210 + 58fab7c commit 6ae1535
Showing 1 changed file with 211 additions and 58 deletions.
269 changes: 211 additions & 58 deletions droidbot/input_policy2.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,39 @@
import sys
import json
import logging
import random
import time
import collections
import spacy
import copy
import numpy as np
from abc import abstractmethod
import logging
import random
import time
import math

import numpy as np
import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from .input_event import InputEvent, KeyEvent, IntentEvent, TouchEvent, UIEvent, KillAppEvent
from .input_event import KeyEvent, IntentEvent, TouchEvent, UIEvent, KillAppEvent
from .input_policy import UtgBasedInputPolicy
from .device_state import DeviceState
from .utg import UTG
from .utils import lazy_property

# Configure logging at import time: INFO level, with timestamp, logger name
# and level in every record.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
DEBUG = True
ACTION_INEFFECTIVE = 'ineffective'

CLOSER_ACTION_ENCOURAGEMENT = 0.1
# NOTE(review): RANDOM_EXPLORE_PROB is assigned twice in a row, so the value
# in effect is the second one (0.3). The first assignment looks like the
# pre-change line of the diff -- confirm and drop it.
RANDOM_EXPLORE_PROB = 0.4
RANDOM_EXPLORE_PROB = 0.3
# Passed as `size` when selecting transitions for one training step.
N_ACTIONS_TRAINING = 32

MAX_NUM_STEPS_OUTSIDE = 3
MAX_NUM_STEPS_OUTSIDE_KILL = 5
MAX_NAV_STEPS = 10


class UIEmbedModel(nn.Module):
class UIEmbedLSTM(nn.Module):
def __init__(self):
super().__init__()
input_size = 68
self.text_encoder = TextEncoder(method='spacy')
input_size = 18 + self.text_encoder.embed_size
embed_size = 100
output_size = 50
self.lstm = nn.LSTM(
Expand All @@ -49,43 +46,10 @@ def __init__(self):
self.fc = nn.Linear(embed_size, output_size)
# self.fc = nn.Linear(input_size, output_size)

def forward(self, x):
emb, _ = self.lstm(x)
return F.normalize(self.fc(emb))
# return F.normalize(self.fc(x))

def encode_state(self, state, views):
return torch.stack([self._encode_view(state, view) for view in views])

class Memory:
def __init__(self, utg, app):
self.utg = utg
self.app = app
self.known_states = collections.OrderedDict()
self.known_transitions = collections.OrderedDict()
self.nlp = spacy.load("en_core_web_md")
self.model = UIEmbedModel()

def _memorize_state(self, state):
if state.get_app_activity_depth(self.app) != 0:
return None
if state.state_str not in self.known_states:
views = state.views
views_str = [view['view_str'] for view in views]
views_enc = torch.stack([self._encode_view(view, state) for view in views])
embedder = self.model
embedder.eval()
with torch.no_grad():
views_emb = self.model(views_enc.unsqueeze(0))
views_emb = views_emb.detach().cpu()[0]
self.known_states[state.state_str] = {
'state': state,
'views': views,
'views_str': views_str,
'views_enc': views_enc,
'views_emb': views_emb
}
return self.known_states[state.state_str]

def _encode_view(self, view, state):
def _encode_view(self, state, view):
# print(view)
view_children = view['children'] if 'children' in view else []
is_parent = 1 if len(view_children) > 0 else -1
Expand All @@ -110,7 +74,7 @@ def _encode_view(self, view, state):
size = view_w * view_h
wh_ratio = view_w / (view_h + 0.0001)
wh_ratio = min(wh_ratio, 10)
text_emb = self.nlp(text=view_text).vector[:50] if view_text else np.zeros(50)
text_emb = self.text_encoder.encode(view_text)
encoding = np.concatenate([np.array([
is_parent, is_text, is_password, visible,
l, r, t, b, size, wh_ratio,
Expand All @@ -119,6 +83,194 @@ def _encode_view(self, view, state):
]), text_emb])
return torch.Tensor(encoding)

def forward(self, state_encs):
state_encs = pad_sequence(state_encs, batch_first=True)
emb, _ = self.lstm(state_encs)
return F.normalize(self.fc(emb))
# return F.normalize(self.fc(x))


class TextEncoder:
    """Encode a piece of UI text as a fixed-size vector.

    The backend is chosen by ``method``:

    * ``'spacy'`` -- the ``en_core_web_md`` pipeline; ``embed_size`` is 300.
    * ``'bert'`` -- the pooled output of ``bert-base-multilingual-cased``;
      ``embed_size`` is 768.
    """

    def __init__(self, method='spacy'):
        """Load the backend selected by ``method``.

        Args:
            method: ``'spacy'`` (default) or ``'bert'``.

        Raises:
            ValueError: if ``method`` names no supported backend. Failing
                here avoids a half-built encoder whose ``embed_size`` is -1.
        """
        self.method = method
        self.embed_size = -1
        if method == 'spacy':
            self.nlp = spacy.load("en_core_web_md")
            self.embed_size = 300
        elif method == 'bert':
            # Imported lazily so transformers is only needed for this backend.
            from transformers import BertTokenizer, BertModel
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
            self.text_encoder = BertModel.from_pretrained('bert-base-multilingual-cased')
            self.embed_size = 768
        else:
            raise ValueError(f'unsupported text encoding method: {method!r}')

    def encode(self, text):
        """Return the embedding of ``text`` as a 1-D numpy array.

        Args:
            text: the string to embed; ``None`` and ``''`` are accepted.

        Returns:
            An array of length ``embed_size``. Empty or missing text maps to
            the all-zero vector.

        Raises:
            ValueError: if ``self.method`` is not a supported backend.
        """
        if not text:
            return np.zeros(self.embed_size)
        if self.method == 'spacy':
            return self.nlp(text).vector
        if self.method == 'bert':
            encoding = self.tokenizer([text], return_tensors='pt', padding=True, truncation=True)
            # The result is detached anyway, so skip building the autograd graph.
            with torch.no_grad():
                text_encoder_out = self.text_encoder(
                    encoding['input_ids'],
                    attention_mask=encoding['attention_mask'],
                )
            text_emb = text_encoder_out['pooler_output'][0]
            return text_emb.detach().cpu().numpy()
        raise ValueError(f'unsupported text encoding method: {self.method!r}')


class BertLayerNorm(nn.Module):
    """Layer normalisation over the last dimension, TF style.

    "TF style" means the epsilon is added to the variance inside the square
    root. The module holds a learnable per-feature scale (``weight``,
    initialised to ones) and shift (``bias``, initialised to zeros).
    """

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalise ``x`` along its last axis, then scale and shift it."""
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        # Biased variance: the mean of the squared deviations.
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalised = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalised + self.bias


class UIEmbedTransformer(nn.Module):
    """Transformer encoder producing one embedding per view of a UI state.

    A state is encoded (see ``encode_state``) as three per-view tensors:
    12 binary meta flags, 6 discretised position values and a text vector.
    ``encode_state_batch`` projects each of them to ``nhid`` dimensions, sums
    them, applies layer normalisation and dropout, and pads the states of a
    batch to a common length. ``forward`` runs the padded batch through a
    ``TransformerEncoder``.
    """

    # View attributes read as boolean flags, in the order they are placed
    # after the ``is_parent`` and ``is_text`` flags in the meta encoding.
    _META_KEYS = (
        'is_password', 'visible',
        'enabled', 'checked', 'selected',
        'clickable', 'long_clickable', 'checkable', 'editable', 'scrollable',
    )

    def __init__(self, nhid=50, nhead=2, nlayers=2, dropout=0.8):
        """Build the embedding layers and the transformer encoder.

        Args:
            nhid: embedding size used throughout the model.
            nhead: number of attention heads per encoder layer.
            nlayers: number of stacked encoder layers.
            dropout: dropout probability, used both on the summed input
                embeddings and inside the encoder layers.
                NOTE(review): 0.8 is an unusually high default -- confirm.
        """
        super().__init__()
        self.model_type = 'Transformer'
        # Positions are bucketed into [0, pos_max - 1] before embedding lookup.
        self.pos_max = 100
        self.text_encoder = TextEncoder(method='spacy')
        self.x_position_embeddings = nn.Embedding(self.pos_max, nhid)
        self.y_position_embeddings = nn.Embedding(self.pos_max, nhid)
        self.h_position_embeddings = nn.Embedding(self.pos_max, nhid)
        self.w_position_embeddings = nn.Embedding(self.pos_max, nhid)
        dim_feedforward = 256
        encoder_layers = nn.TransformerEncoderLayer(nhid, nhead, dim_feedforward, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.meta2hid = nn.Linear(12, nhid)
        self.text2hid = nn.Linear(self.text_encoder.embed_size, nhid)
        self.layer_norm = BertLayerNorm(nhid)
        self.dropout = nn.Dropout(dropout)

    def forward(self, state_encs):
        """Embed a batch of encoded states.

        Args:
            state_encs: list of ``(meta_enc, pos_enc, text_enc)`` tuples as
                returned by ``encode_state``.

        Returns:
            Tensor of shape ``(batch, max_views, nhid)``; rows past a state's
            own number of views are padding.
        """
        embs, padding_mask = self.encode_state_batch(state_encs)
        output = self.transformer_encoder(embs, src_key_padding_mask=padding_mask)
        # The encoder works sequence-first; hand back batch-first.
        return output.permute(1, 0, 2)

    def encode_state(self, state, views):
        """Encode every view of ``state`` into (meta, position, text) tensors."""
        meta_enc = torch.stack([self._encode_view_meta(state, view) for view in views])
        pos_enc = torch.stack([self._encode_view_pos(state, view) for view in views])
        text_enc = torch.stack([self._encode_view_text(state, view) for view in views])
        return meta_enc, pos_enc, text_enc

    def encode_state_batch(self, state_encs):
        """Turn encoded states into a padded embedding batch and its mask.

        Args:
            state_encs: list of ``(meta_enc, pos_enc, text_enc)`` tuples.

        Returns:
            ``(embs_pad, attn_mask)``: ``embs_pad`` has shape
            ``(max_views, batch, nhid)``; ``attn_mask`` is a boolean tensor of
            shape ``(batch, max_views)`` that is ``True`` on padding slots.
        """
        embs = []
        for meta_enc, pos_enc, text_enc in state_encs:
            meta_emb = self.meta2hid(meta_enc)
            # pos_enc columns are [left, right, top, bottom, width, height].
            pos_emb = (
                self.x_position_embeddings(pos_enc[:, 0])
                + self.x_position_embeddings(pos_enc[:, 1])
                + self.y_position_embeddings(pos_enc[:, 2])
                + self.y_position_embeddings(pos_enc[:, 3])
                + self.w_position_embeddings(pos_enc[:, 4])
                + self.h_position_embeddings(pos_enc[:, 5])
            )
            text_emb = self.text2hid(text_enc)
            emb = self.dropout(self.layer_norm(meta_emb + pos_emb + text_emb))
            embs.append(emb)
        embs_pad = pad_sequence(embs, batch_first=False)
        # Derive the padding mask from the true lengths. Testing the padded
        # values for an all-zero sum would also flag a real view whose
        # embedding happens to sum to zero.
        lengths = torch.tensor([emb.shape[0] for emb in embs], device=embs_pad.device)
        positions = torch.arange(embs_pad.shape[0], device=embs_pad.device)
        attn_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)
        return embs_pad, attn_mask

    @staticmethod
    def _flag(value):
        """Map a truthy value to 1 and anything else to -1."""
        return 1 if value else -1

    def _encode_view_meta(self, state, view):
        """Return the 12 meta flags of ``view`` as a float tensor of +1/-1.

        Order: is_parent, is_text, then the attributes in ``_META_KEYS``.
        A missing attribute counts as false.
        """
        flags = [
            self._flag(view.get('children')),  # is_parent: has any children
            self._flag(view.get('text')),      # is_text: has non-empty text
        ]
        flags.extend(self._flag(view.get(key)) for key in self._META_KEYS)
        return torch.tensor(flags, dtype=torch.float32)

    def _encode_view_pos(self, state, view):
        """Return ``[left, right, top, bottom, width, height]`` as a long tensor.

        Bounds are scaled by the screen size of ``state`` into integer buckets
        in ``[0, pos_max - 1]``; a view without ``bounds`` is treated as
        ``[[0, 0], [0, 0]]``.
        """
        (left, top), (right, bottom) = view.get('bounds', [[0, 0], [0, 0]])
        scale = self.pos_max - 1

        def bucket(value, extent):
            return int(scale * value / extent)

        # Sort each pair so left <= right and top <= bottom, then clamp.
        left, right = sorted((bucket(left, state.width), bucket(right, state.width)))
        top, bottom = sorted((bucket(top, state.height), bucket(bottom, state.height)))
        left, right, top, bottom = (
            max(0, min(scale, value)) for value in (left, right, top, bottom)
        )
        return torch.tensor(
            [left, right, top, bottom, right - left, bottom - top],
            dtype=torch.long,
        )

    def _encode_view_text(self, state, view):
        """Return the text embedding of ``view`` as a float tensor."""
        emb = self.text_encoder.encode(view.get('text'))
        return torch.Tensor(emb)


class Memory:
def __init__(self, utg, app):
    """Create an empty memory bound to ``utg`` and ``app``.

    Known states and known transitions start out as empty insertion-ordered
    mappings, and a fresh ``UIEmbedTransformer`` is created as the embedding
    model.
    """
    self.utg, self.app = utg, app
    # Insertion-ordered stores for the states and transitions seen so far.
    self.known_states = collections.OrderedDict()
    self.known_transitions = collections.OrderedDict()
    # Model used to embed the views of each memorised state.
    self.model = UIEmbedTransformer()

def _memorize_state(self, state):
if state.get_app_activity_depth(self.app) != 0:
return None
if state.state_str not in self.known_states:
views = state.views
views_str = [view['view_str'] for view in views]
state_enc = self.model.encode_state(state, views)
embedder = self.model
embedder.eval()
with torch.no_grad():
# state_encs = self.model.encode_state_batch([state_enc])
views_emb = self.model.forward([state_enc])
views_emb = views_emb.detach().cpu()[0]
self.known_states[state.state_str] = {
'state': state,
'views': views,
'views_str': views_str,
'state_enc': state_enc,
'views_emb': views_emb
}
return self.known_states[state.state_str]

def save_transition(self, action, from_state, to_state):
if not from_state or not to_state:
return
Expand Down Expand Up @@ -176,7 +328,7 @@ def encode_action_pairs(self, action_strs=None):
return

state_strs = [self.known_transitions[action_str]['from_state'].state_str for action_str in action_strs]
state_encs = [self.known_states[state_str]['views_enc'] for state_str in state_strs]
state_encs = [self.known_states[state_str]['state_enc'] for state_str in state_strs]
action_pairs = []
for i, action_str1 in enumerate(action_strs):
state_str1 = self.known_transitions[action_str1]['from_state'].state_str
Expand Down Expand Up @@ -213,7 +365,7 @@ def action_info_str(action_info):

def train_model(self):
embedder = self.model
optimizer = torch.optim.Adam(embedder.parameters(), lr=1e-2)
optimizer = torch.optim.Adam(embedder.parameters(), lr=1e-3)
n_iterations = 10

def compute_loss(ele_embed, action_pairs):
Expand Down Expand Up @@ -250,8 +402,7 @@ def train():
embedder.train()
action_strs = self._select_transitions_for_training(size=N_ACTIONS_TRAINING)
state_encs, action_pairs = self.encode_action_pairs(action_strs)
state_encs = pad_sequence(state_encs, batch_first=True)
ele_embed = embedder(state_encs)
ele_embed = embedder.forward(state_encs)

loss = compute_loss(ele_embed, action_pairs)
optimizer.zero_grad()
Expand All @@ -268,9 +419,9 @@ def train():
# update embedding
with torch.no_grad():
embedder.eval()
state_encs = [v['views_enc'] for k,v in self.known_states.items()]
state_encs_pad = pad_sequence(state_encs, batch_first=True)
ele_embed = embedder(state_encs_pad)
state_encs = [v['state_enc'] for k,v in self.known_states.items()]
# state_encs_pad = pad_sequence(state_encs, batch_first=True)
ele_embed = embedder(state_encs)
ele_embed = ele_embed.detach().cpu()
for i, (k,v) in enumerate(self.known_states.items()):
self.known_states[k]['views_emb'] = ele_embed[i]
Expand Down Expand Up @@ -323,6 +474,8 @@ def generate_event_based_on_utg(self):
self.memory.save_transition(self.last_event, self.last_state, self.current_state)
except Exception as e:
self.logger.warning(f'failed to save transition: {e}')
import traceback
traceback.print_exc()
if self.action_count % self.num_actions_train == 0:
self.memory.train_model()
# self.logger.info(f'we have {len(self.memory.known_transitions)} transitions now')
Expand Down

0 comments on commit 6ae1535

Please sign in to comment.