Skip to content

Commit

Permalink
VizConsoleAdded
Browse files Browse the repository at this point in the history
  • Loading branch information
fliu2 committed Aug 9, 2017
1 parent 8ad8fa2 commit 0e4bfbd
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 145 deletions.
2 changes: 1 addition & 1 deletion examples/MnistDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def find_data(self,args={'type':"train",'id':{"$lt":500}}):
print "training reading"
print t6

return d,l
return d.astype(np.float32),l.astype(np.int32)

def generator_data(self,batch_size=20,args={'type':"train",'id':{"$lt":500}}):

Expand Down
104 changes: 104 additions & 0 deletions examples/VizConsole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import dash
from dash.dependencies import Input,Output
import dash_core_components as dcc
import dash_html_components as html
import plotly.graph_objs as go
import pandas as pd
import pymongo
import seed

db=seed.db







df=pd.DataFrame(db.queryValLog({}))




print df.columns.values.tolist()

print df.groupby('studyID')


print df.studyID.unique()


app=dash.Dash()





app.layout=html.Div(children=[
html.H1(children="hello dash"),

html.Div(children='''
Dash: a Weba pplication framework for python
'''),

dcc.Input(id='my-id', value='initial value', type="text"),
html.P(),

html.Div(id='my-div'),


##for date visuzlaiton, server side rendering#
## html.Img(id='my-img',src=""),


html.P(children="SelecViz"),

dcc.Dropdown(
options=[
{'label': 'New York City', 'value': 'NYC'},
{'label': u'Montreal', 'value': 'MTL'},
{'label': 'San Francisco', 'value': 'SF'}
],
value='MTL',
multi=True

),

dcc.Graph(id='example-graph',
figure={
'data':[
go.Scatter(
x=df[df['studyID'] == i]['epoch'],
y=df[df['studyID'] == i]['acc'],
text=df[df['studyID'] == i]['time'],
mode='markers',
opacity=0.7,
marker={
'size': 10,
'line': {'width': 0.5, 'color': 'white'}
},
name=i
) for i in df.studyID.unique()
],
'layout':{
'title':'Dash visulization'
}


}

)

])


@app.callback(
Output(component_id='my-div', component_property='children'),
[Input(component_id='my-id', component_property='value')]
)
def update_output_div(input_value):
print "server call back"
return "{}".format(input_value)

if __name__ =='__main__':
app.run_server(debug=True)
Empty file removed examples/__init__.py
Empty file.
17 changes: 13 additions & 4 deletions examples/mnistDemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,26 @@
mndb=seed.mn
db=seed.db

d, l = mndb.find_data({'type': 'train'})
dv,lv = mndb.find_data({'type':'val'})



def run_one(name_test):
global d, l, c, f, fn
d, l = mndb.find_data({'type': 'train'})
print d.shape
c, f, fn = db.load_model_architecture({'name': 'mlp'})
db.studyID =name_test
m1 = Model(fn, name_test, False)
m1.fit(10, d, l, 128, [DBLogger(db, m1)])
m1.fit(100, dv, lv, 128, [DBLogger(db, m1)],dv,lv)




if __name__ == "__main__":
run_one('run3')
run_one('run4')


run_one()



Expand Down
5 changes: 5 additions & 0 deletions examples/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,24 @@ def add_model():

def init_data():
X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 784))

X_train = np.asarray(X_train, dtype=np.float32)
y_train = np.asarray(y_train, dtype=np.int32)

X_val = np.asarray(X_val, dtype=np.float32)
y_val = np.asarray(y_val, dtype=np.int32)

X_test = np.asarray(X_test, dtype=np.float32)
y_test = np.asarray(y_test, dtype=np.int32)

print('X_train.shape', X_train.shape)
print('y_train.shape', y_train.shape)
print('X_val.shape', X_val.shape)
print('y_val.shape', y_val.shape)
print('X_test.shape', X_test.shape)
print('y_test.shape', y_test.shape)
print('X %s y %s' % (X_test.dtype, y_test.dtype))

mn.import_data(X_train, y_train, {'type': 'train'})
mn.import_data(X_val, y_val, {'type': 'val'})
mn.import_data(X_test, y_test, {'type': 'test'})
Expand Down
3 changes: 2 additions & 1 deletion mytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def build_network(name,reuse=False):

#n1.Params=w

n1.fit(5,d,l,128,[DBLogger(p2,n1)])
n1.fit(5,d,l,128,[DBLogger(p2,n1)],d[0:128],l[0:128])


#p.del_params({})
#p.del_train_log({})
Expand Down
1 change: 1 addition & 0 deletions tensorlab/Logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def on_epoch_end(self, epoch, logs={}):
w=self.model.Params
fid=self.db.save_params(w,logs)
logs.update({'params':fid})

self.db.valid_log(logs)
def on_batch_begin(self, batch,logs={}):
self.t=time.time()
Expand Down
9 changes: 5 additions & 4 deletions tensorlab/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def fit(self,n_epoch, X_train, y_train,batch_size, callback=[],X_val=None,y_val=
c.on_epoch_begin(epoch,{})
w=0

for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train,batch_size, shuffle=True):
for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):


for c in callback:
c.on_batch_begin(w)
Expand All @@ -87,9 +88,8 @@ def fit(self,n_epoch, X_train, y_train,batch_size, callback=[],X_val=None,y_val=
val_args={}
if X_val is not None:

feed_dict = {self.x: X_val, self.y_: y_val}
[dc]=self.sess.run([self.loss],feed_dict=feed_dict)
val_args.update({'acc':dc})
vc=self.validate(X_val,y_val)
val_args.update({'acc':vc})

for c in callback:
c.on_epoch_end(epoch,val_args)
Expand Down Expand Up @@ -159,6 +159,7 @@ def inference(self, X):

def validate(self, X,Y):
feed_dict={self.x:X,self.y_:Y}
feed_dict.update(self.network.all_drop) # enable dropout or dropconnect layers
[lo]=self.sess.run([self.loss],feed_dict=feed_dict)
return lo

Expand Down
Loading

0 comments on commit 0e4bfbd

Please sign in to comment.