Skip to content

Commit

Permalink
noisy mixture model demo works
Browse files Browse the repository at this point in the history
git-svn-id: https://pmtk3.googlecode.com/svn/trunk@2724 b6abd7f4-f95b-11de-aa3c-59de0406b4f5
  • Loading branch information
[email protected] committed Mar 10, 2011
1 parent 533dc5c commit f638f68
Show file tree
Hide file tree
Showing 8 changed files with 306 additions and 121 deletions.
55 changes: 0 additions & 55 deletions demos/noisyBinaryMixDemo.m

This file was deleted.

106 changes: 106 additions & 0 deletions demos/noisyMixDemo.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@

% Z -> Xj -> Yj
% Z is mixture node
% Xj is binary tag
% Yj is response of tag detector

% We generate some synthetic correlated binary data
% consisting of 10s and 01s

setSeed(0);
D = 4; % num bits/ nodes
mixWeights = normalize(ones(1,D));
N2 = 20;
X1 = repmat([1 0 0 1], N2, 1);
X2 = repmat([0 1 1 0], N2, 1);
X = [X1; X2];
X = X+1;
N = 2*N2;

options = {'maxIter', 10, 'verbose', true};
K = 2;
[model, loglikHist] = mixModelFit(X, K, 'discrete', options{:});

%{
This works as expected: in cluster 1,
the first feature (col 1) is most likely 1,
the second feature (col 2) is most likely 0
etc
> squeeze(model.cpd.T(1,:,:)) % soft version of 0 1 1 0 pattern
ans =
0.0455 0.9545 0.9545 0.0455
0.9545 0.0455 0.0455 0.9545
and vice versa for cluster 2
squeeze(model.cpd.T(2,:,:)) % soft version of 1 0 0 1 pattern
ans =
0.9545 0.0455 0.0455 0.9545
0.0455 0.9545 0.9545 0.0455
%}

% Now generate noisy function of X

setSeed(0);

% mu(c,j)
% Make off bits be -ve, on bits be +ve
mu = [-1 -2 -3 -4;
1 2 3 4];

sigma = 2;
Y = zeros(N, D);

for j=1:D
for c=1:2
ndx = find(X(:,j)==c);
Y(ndx,j) = sigma*randn(numel(ndx),1)+mu(c,j);
end
end

[model2] = noisyMixModelFit(X, Y, K);

%{
Mix model Z->X is same as before (modulo label switching)
squeeze(model2.mixmodel.cpd.T(1,:,:)) % soft version of 1 0 0 1 pattern
squeeze(model2.mixmodel.cpd.T(2,:,:)) % soft version of 0 1 1 0 pattern
Obs model
for j=1:D
mu2(:,j) = model2.obsmodel.localCPDs{j}.mu';
end
mu2 =
-0.8839 -2.1571 -2.9969 -4.0548
1.1921 2.0855 2.5789 3.8201
%}

[pZ, pX] = noisyMixModelInfer(model2, Y);

% Check that each row assigned to 'correct' cluster
[pmax, Zhat] = max(pZ,[],2);
% Zhat(1:20) = 1 for 1st cluster
% Zhat(21:40) = 2

figure; imagesc(Y); colorbar; title('raw data Y (score of detector)');
figure; imagesc(pZ); colorbar; title('soft clustering pZ')

% Now check that we infered correct tag, marginalizing out cluster
Xhat = zeros(N, D);
for j=1:D
[pmax, Xhat(:, j)] = max(pX(:, :, j), [], 2); % pX is N*2*D
end
% Xhat recovers X
figure; imagesc(X); colorbar; title('true tag');
figure; imagesc(Xhat); colorbar; title('est tag');

137 changes: 109 additions & 28 deletions matlabTools/graphics/graphviz.m
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ function graphviz(adj, varargin)
% removeIsolatedNodes - [1]
% filename - name for output file [default foo.pdf]
% directed - [default: determine from symmetry of adj]
%
% landscape - [default 1]
%
%
% Examples
% graphviz(adj, 'filename', 'foo') % default tmp
% graphviz(adj, 'labels', {'a','bbbb','c','d','e'})
% graphviz(adj, 'removeIsolatedNodes', 0, 'removeSelfLoops', 0);
%
% Written by Kevin Murphy, Leon Peshkin, et al
% 8 March 2011
% Written by Kevin Murphy, Leon Peshkin, Mark Schmidt, et al
% Last updated 8 March 2011



Expand All @@ -41,12 +41,28 @@ function graphviz(adj, varargin)
graphviz(adj, 'removeIsolatedNodes', 1, 'removeSelfLoops', 0);
%}

%
%{
Generated .dot file looks liek this
graph G {
center = 1;
size="10,10";
1 [ label = "field" ];
2 [ label = "sea" ];
1 -- 11 [dir=none];
2 -- 8 [dir=none];
}
Syntax is defined here
http://www.graphviz.org/content/dot-language
[labels, removeSelfLoops, removeIsolatedNodes, filename, directed] = ...
%}

[labels, removeSelfLoops, removeIsolatedNodes, filename, directed, landscape] = ...
process_options(varargin, ...
'labels', [], 'removeSelfLoops', 1, ...
'removeIsolatedNodes', 0, 'filename','tmp', 'directed', []);
'removeIsolatedNodes', 0, 'filename','tmp', 'directed', [], 'landscape', 1);

if removeSelfLoops
adj = setdiag(adj, 0);
Expand Down Expand Up @@ -85,15 +101,36 @@ function graphviz(adj, varargin)
end
end

adj = double(adj > 0); % make sure it is a binary matrix cast to double type
% Since we want to handle weighted graphs, we don't binarize
%adj = double(adj > 0); % make sure it is a binary matrix cast to double type

%{
u = unique(adj);
if ~isequal(u, [0 1])
% weighted graph
edgeSign = sign(adj);
Nbins = 10;
% quantize edge weights
h = hist(abs(adj(:)), Nbins);
edgeWeights = zeros(size(adj));
for hi=1:numel(h)
ndx=find(abs(adj)==h(hi));
edgeWeights(ndx) = hi;
end
else
edgeSign = 1;
edgeWeights = adj;
end
%}

tmpDOTfile = sprintf('%s.dot', filename);

graph_to_dot(adj, 'directed', directed, ...
'filename', tmpDOTfile, 'node_label', labels);

cmd = sprintf('dot -Tps %s -o %s.ps', tmpDOTfile, filename);
status = system(cmd);

dot_to_ps(tmpDOTfile, sprintf('%s.ps', filename), landscape);


if ~isempty(filename)
cmd = sprintf('ps2pdf %s.ps %s.pdf', filename, filename);
Expand All @@ -109,15 +146,52 @@ function graphviz(adj, varargin)
error(sprintf('error executing %s', cmd));
end

%{
cmd = sprintf('rm %s.dot', filename);
status = system(cmd);
if status ~= 0
error(sprintf('error executing %s', cmd));
end
%}
end

end

function status = dot_to_ps(dotname, outname, landscape)


% Useful options:
% -Glandscape (outputs in landscape mode)
% -Gconcentrate (merges two-way edges into one way edge, displays
% parallel edges in different way)
% -Gratio=.707 (changes to A4 landscape aspect ratio, other options are "fill", "compress",
% "expand", "auto")
% -Ncolor="blue" (changes node outlines to blue)
% -Ecolor="red" (changes edges to red)
% -Earrowsize=2 (changes size of arrows)
% -Nstyle="filled" -Nfillcolor="#ddddff" (make nodes light blue)
% -Nfontsize=32 (change font size to 32pt)
% -Gnodesep=0.125 (make nodes twice as close
% -Nshape="box" (change node shape to box)
%
% Details here:
% http://www.graphviz.org/doc/info/attrs.html

%opts = ' -Gconcentrate -Gratio=.707 -Ncolor="blue" -Ecolor="green" -Earrowsize=2 -Nstyle="filled" -Nfillcolor="#ddddff" -Nfontsize=40 ';
opts = '-Gconcentrate -Gratio=0.707 -Earrowsize=2 -Nfontsize=50 -Epenwidth=5';
if landscape
opts = strcat(opts,' -Glandscape ');
end

%cmd = strcat('C:\temp\graphviz-2.8\bin\dot ',opts,' -T ps -o graphVizIt.ps graphVizIt.txt ')
cmd = sprintf('dot %s -Tps %s -o %s', opts, dotname, outname);
%cmd = sprintf('dot %s -Tpng %s -o %s', opts, dotname, outname);
status = system(cmd);
if status ~= 0
error(sprintf('error executing %s', cmd));
end
end


function graph_to_dot(adj, varargin)

Expand All @@ -133,11 +207,10 @@ function graph_to_dot(adj, varargin)
% 'leftright' - 1 means layout left-to-right, 0 means top-to-bottom [0]
% 'directed' - 1 means use directed arcs, 0 means undirected [1]
%
% For details on dotty, See http://www.research.att.com/sw/tools/graphviz
%
% by Dr. Leon Peshkin, Jan 2004 inspired by Kevin Murphy's BNT
% pesha @ ai.mit.edu /~pesha


% Details here:
% http://www.graphviz.org/doc/info/attrs.html

node_label = []; arc_label = []; % set default args
width = 10; height = 10;
leftright = 0; directed = 1; filename = 'tmp.dot';
Expand All @@ -159,20 +232,8 @@ function graph_to_dot(adj, varargin)
end
if directed
fprintf(fid, 'digraph G {\n');
arctxt = '->';
if isempty(arc_label)
labeltxt = '';
else
labeltxt = '[label="%s"]';
end
else
fprintf(fid, 'graph G {\n');
arctxt = '--';
if isempty(arc_label)
labeltxt = '[dir=none]';
else
labeltext = '[label="%s",dir=none]';
end
end
fprintf(fid, 'center = 1;\n');
fprintf(fid, 'size=\"%d,%d\";\n', width, height);
Expand All @@ -187,15 +248,35 @@ function graph_to_dot(adj, varargin)
fprintf(fid, '%d [ label = "%s" ];\n', node, node_label{node});
end
end
edgeformat = strcat(['%d ',arctxt,' %d ',labeltxt,';\n']);

for node1 = 1:Nnds % process ARCs
if directed
arcs = find(adj(node1,:)); % children(adj, node);
arctype = '->';
else
arcs = find(adj(node1,node1+1:Nnds)) + node1; % remove duplicate arcs
arctype = '--';
end
for node2 = arcs
fprintf(fid, edgeformat, node1, node2);
arcargs = {};
if ~isempty(arc_label)
arcargs{end+1} = sprintf('label="%s"', arc_label(node1, node2));
end
if ~directed
arcargs{end+1} = sprintf('dir=none');
end
if adj(node1, node2) < 0
%fprintf('%d to %d value %5.3f\n', node1, node2, adj(node1, node2));
arcargs{end+1} = sprintf('style="dotted"');
arcargs{end+1} = sprintf('color=red');
end
arctxt = sprintf('%s,', arcargs{:});
if length(arctxt)>1
arctxt = arctxt(1:end-1); % remove final comma
end
arctxt = sprintf('[%s]', arctxt);
edgestr = sprintf('%d %s %d %s;\n', node1, arctype, node2, arctxt);
fprintf(fid, edgestr);
end
end
fprintf(fid, '}');
Expand Down
9 changes: 9 additions & 0 deletions projects/sceneContext/treeContextDemo.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@

model = treegmFit(train.presence_truth, train.maxscores, 'gauss');

%{
folder = fileparts(which(mfilename())
folder = '/home/kpmurphy/Dropbox/figures';
% for some reason, the directed graph is much more readable
graphviz(model.edge_weights, 'labels', train.names, 'directed', 1, ...
'filename', fullfile(folder, 'SUN09treeNeg'));
%}


%% Check the reasonableness of the local observation model for class c
for c=[1 110]
ndx=(train.presence_truth(:,c)==1);
Expand Down
Loading

0 comments on commit f638f68

Please sign in to comment.