Skip to content

Commit

Permalink
separate plot
Browse files Browse the repository at this point in the history
  • Loading branch information
delta2323 committed Jan 10, 2016
1 parent ea80887 commit cf4d561
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 28 deletions.
17 changes: 17 additions & 0 deletions plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from matplotlib import cm
from matplotlib import pyplot
import numpy


def visualize2D(fig, ax, xs, ys, bins=200, xlabel='x', ylabel='y'):
H, xedges, yedges = numpy.histogram2d(xs, ys, bins)
H = numpy.rot90(H)
H = numpy.flipud(H)
Hmasked = numpy.ma.masked_where(H == 0, H)

ax.pcolormesh(xedges, yedges, Hmasked)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_xlim(min(xs), max(xs))
ax.set_ylim(min(ys), max(ys))
fig.colorbar(pyplot.contourf(Hmasked))
18 changes: 8 additions & 10 deletions toy_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@

import chainer
from chainer import functions as F
import matplotlib.pyplot as plt
from matplotlib import pyplot
import numpy
import six

import model
import plot


parser = argparse.ArgumentParser(description='HMC')
Expand All @@ -32,6 +33,8 @@
help='If true, rejection phase is introduced')
# others
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--visualize', default='visualize_hmc.png', type=str,
help='path to output file')
args = parser.parse_args()

n_batch = (args.N + args.batchsize - 1) // args.batchsize
Expand Down Expand Up @@ -112,12 +115,7 @@ def accept(p, theta, p_propose, theta_propose):
theta2_all[epoch * n_batch + i // args.batchsize] = theta[1]
print(epoch, theta, theta[0] * 2 + theta[1])

H, xedges, yedges = numpy.histogram2d(theta1_all, theta2_all, bins=200)
H = numpy.rot90(H)
H = numpy.flipud(H)
Hmasked = numpy.ma.masked_where(H == 0, H)
plt.pcolormesh(xedges, yedges, Hmasked)
plt.xlabel('x')
plt.ylabel('y')
cbar = plt.colorbar()
plt.savefig('visualize_hmc.png')
fig, axes = pyplot.subplots(ncols=1, nrows=1)
plot.visualize2D(fig, axes, theta1_all, theta2_all,
xlabel='theta1', ylabel='theta2')
fig.savefig(args.visualize)
16 changes: 7 additions & 9 deletions toy_sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import model
import stepsize
import plot


parser = argparse.ArgumentParser(description='SGHMC')
Expand All @@ -22,6 +23,8 @@
help='If true, initialize moment in each sample')
# others
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--visualize', default='visualize_sghmc.png', type=str,
help='path to output file')
args = parser.parse_args()


Expand Down Expand Up @@ -63,12 +66,7 @@ def update_theta(theta, p, eps):
if i == 0:
print(epoch, theta, theta[0] * 2 + theta[1])

H, xedges, yedges = numpy.histogram2d(theta1_all, theta2_all, bins=200)
H = numpy.rot90(H)
H = numpy.flipud(H)
Hmasked = numpy.ma.masked_where(H == 0, H)
plt.pcolormesh(xedges, yedges, Hmasked)
plt.xlabel('x')
plt.ylabel('y')
cbar = plt.colorbar()
plt.savefig('visualize_sghmc.png')
fig, axes = pyplot.subplots(ncols=1, nrows=1)
plot.visualize2D(fig, axes, theta1_all, theta2_all,
xlabel='theta1', ylabel='theta2')
fig.savefig(args.visualize)
15 changes: 6 additions & 9 deletions toy_sgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
parser.add_argument('--epoch', default=1000, type=int, help='epoch num')
# others
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--visualize', default='visualize_hmc.png', type=str,
help='path to output file')
args = parser.parse_args()


Expand Down Expand Up @@ -63,12 +65,7 @@ def update(theta, x, epoch, eps):
if i == 0:
print(epoch, theta, theta[0] * 2 + theta[1])

H, xedges, yedges = numpy.histogram2d(theta1_all, theta2_all, bins=200)
H = numpy.rot90(H)
H = numpy.flipud(H)
Hmasked = numpy.ma.masked_where(H == 0, H)
plt.pcolormesh(xedges, yedges, Hmasked)
plt.xlabel('x')
plt.ylabel('y')
cbar = plt.colorbar()
plt.savefig('visualize_sgld.png')
fig, axes = pyplot.subplots(ncols=1, nrows=1)
plot.visualize2D(fig, axes, theta1_all, theta2_all,
xlabel='theta1', ylabel='theta2')
fig.savefig(args.visualize)

0 comments on commit cf4d561

Please sign in to comment.