Skip to content

Commit

Permalink
methods is struct not cell
Browse files Browse the repository at this point in the history
git-svn-id: https://pmtk3.googlecode.com/svn/trunk@2734 b6abd7f4-f95b-11de-aa3c-59de0406b4f5
  • Loading branch information
[email protected] committed Mar 16, 2011
1 parent e5d01a5 commit 5cdea2e
Show file tree
Hide file tree
Showing 4 changed files with 389 additions and 125 deletions.
82 changes: 82 additions & 0 deletions projects/sceneContext/obsModelTest.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@


%% Check the reasonableness of the local observation model

% Data
loadData('sceneContextSUN09', 'ismatfile', false)
load('SUN09data')


train = data.train;
test = data.test;
objectnames = data.names;

[Ntrain, Nobjects] = size(train.presence);
[Ntest, Nobjects2] = size(test.presence);

obstypes = {'gauss', 'quantize'};

for oo=1:numel(obstypes)
obstype = obstypes{oo};

labels = train.presence;
scores = train.detect_maxprob;
%[quantizedScores, discretizeParams] = discretizePMTK(scores, 10);
[obsmodel] = obsModelFit(labels, scores, obstype);


% we plot the distribution of scores for 2 classes

for c=[1 110]

% Empirical distributon
scores = train.detect_maxprob;
ndx=(train.presence(:,c)==1);
figure;
subplot(2,2,1)
[counts, bins]=hist(scores(ndx,c));
binstr =cellfun(@(b) sprintf('%2.1f', b), num2cell(bins), 'uniformoutput', false);
bar(counts); set(gca, 'xticklabel', binstr)
title(sprintf('%s present, m %5.3f, v %5.3f', ...
objectnames{c}, mean(scores(ndx,c)),var(scores(ndx,c))));

subplot(2,2,2)
[counts, bins] = hist(scores(~ndx,c));
binstr =cellfun(@(b) sprintf('%2.1f', b), num2cell(bins), 'uniformoutput', false);
bar(counts); set(gca, 'xticklabel', binstr)
title(sprintf('%s absent, m %5.3f, v %5.3f', ...
objectnames{c}, mean(scores(~ndx,c)), var(scores(~ndx,c))));

% Model distribution

switch obsmodel.obsType
case 'gauss'
xmin = min(scores(:,c));
xmax = max(scores(:,c));
xvals = linspace(xmin, xmax, 100);
mu = squeeze(obsmodel.mu(1,:,c));
Sigma = permute(obsmodel.Sigma(:,:,:,c), [3 4 1 2]);
p = gaussProb(xvals, mu(2), Sigma(2));
subplot(2,2,3)
plot(xvals, p, 'b-');
title(sprintf('model for %s presence', objectnames{c}))
subplot(2,2,4)
p = gaussProb(xvals, mu(1), Sigma(1));
plot(xvals, p, 'r:');
title(sprintf('model for %s absence', objectnames{c}))
case 'quantize'
% CPT(label, feature, node)
subplot(2,2,3)
bar(squeeze(obsmodel.CPT(2,:,c)))
title(sprintf('model for %s presence', objectnames{c}))
bins = obsmodel.discretizeParams.bins{c};
binstr =cellfun(@(b) sprintf('%2.1f', b), num2cell(bins), 'uniformoutput', false);
set(gca,'xticklabel', binstr)
subplot(2,2,4)
bar(squeeze(obsmodel.CPT(1,:,c)))
title(sprintf('model for %s absence', objectnames{c}))
set(gca,'xticklabel',binstr)
end
end

end
155 changes: 30 additions & 125 deletions projects/sceneContext/tagContextDemo3.m
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,24 @@


%% Models/ methods
methodNames = { 'dag' };
methodNames = { 'tree' };
%methodNames = { 'dag', 'mix10', 'mix15', 'tree' };


% We requre that fitting methods have this form
% model = fn(truth(N, D), features(N, D, :))
% where truth(n,d) in {0,1}

%{
fitMethods = {
@(labels, features) dgmFit(labels)
};
%}


%{
fitMethods = {
@(labels, features) noisyMixModelFit(labels, [], 1)
@(labels, features) treegmFit(labels)
};
%}


%{
Expand All @@ -49,7 +50,6 @@
infMethods = {
@(model, features, softev) argout(2, @noisyMixModelInferNodes, model, [], softev)
};
%}
infMethods = {
Expand All @@ -58,13 +58,18 @@
@(model, features, softev) argout(2, @noisyMixModelInferNodes, model, [], softev), ...
@(model, features, softev) argout(2, @treegmInferNodes, model, [], softev)
};
%}


infMethods = {
@(model, features, softev) argout(2, @treegmInferNodes, model, [], softev)
};


%{
logprobMethods = {
@(model, X) mixModelLogprob(model.mixmodel, X)
};
%}
logprobMethods = {
Expand All @@ -74,6 +79,11 @@
@(model, X) treegmLogprob(model, X)
};
%}

logprobMethods = {
@(model, X) treegmLogprob(model, X)
};


%% CV
Expand All @@ -88,7 +98,7 @@
detect_maxprob = [data.train.detect_maxprob; data.test.detect_maxprob];
N = size(presence, 1);
assert(N==Ntrain+Ntest);
Nfolds = 1;
Nfolds = 3;
if Nfolds == 1
% use original train/ test split
trainfolds{1} = 1:Ntrain;
Expand All @@ -102,83 +112,20 @@
train.detect_maxprob = detect_maxprob(trainfolds{fold}, :);
test.presence = presence(testfolds{fold}, :);
test.detect_maxprob = detect_maxprob(testfolds{fold}, :);


[Ntrain, Nobjects] = size(train.presence);
[Ntest, Nobjects2] = size(test.presence);

%% Train p(scores | labels)
%obstype = 'localev';
obstype = 'gauss';
%obstype = 'quantize';

labels = train.presence;
scores = train.detect_maxprob;
%[quantizedScores, discretizeParams] = discretizePMTK(scores, 10);
[obsmodel] = obsModelFit(labels, scores, obstype);


%% Check the reasonableness of the local observation model for class c
% note that p(score|label) is same for all models

%{
for c=[1 110]
% Empirical distributon
scores = train.detect_maxprob;
ndx=(train.presence(:,c)==1);
figure;
subplot(2,2,1)
[counts, bins]=hist(scores(ndx,c));
binstr =cellfun(@(b) sprintf('%2.1f', b), num2cell(bins), 'uniformoutput', false);
bar(counts); set(gca, 'xticklabel', binstr)
title(sprintf('%s present, m %5.3f, v %5.3f', ...
objectnames{c}, mean(scores(ndx,c)),var(scores(ndx,c))));
subplot(2,2,2)
[counts, bins] = hist(scores(~ndx,c));
binstr =cellfun(@(b) sprintf('%2.1f', b), num2cell(bins), 'uniformoutput', false);
bar(counts); set(gca, 'xticklabel', binstr)
title(sprintf('%s absent, m %5.3f, v %5.3f', ...
objectnames{c}, mean(scores(~ndx,c)), var(scores(~ndx,c))));
% Model distribution
switch obsmodel.obstype
case 'gauss'
xmin = min(scores(:,c));
xmax = max(scores(:,c));
xvals = linspace(xmin, xmax, 100);
mu = squeeze(obsmodel.mu(1,:,c));
Sigma = permute(obsmodel.Sigma(:,:,:,c), [3 4 1 2]);
p = gaussProb(xvals, mu(2), Sigma(2));
subplot(2,2,3)
plot(xvals, p, 'b-');
title(sprintf('model for %s presence', objectnames{c}))
subplot(2,2,4)
p = gaussProb(xvals, mu(1), Sigma(1));
plot(xvals, p, 'r:');
title(sprintf('model for %s absence', objectnames{c}))
case 'quantize'
% CPT(label, feature, node)
subplot(2,2,3)
bar(squeeze(obsmodel.CPT(2,:,c)))
title(sprintf('model for %s presence', objectnames{c}))
bins = obsmodel.discretizeParams.bins{c};
binstr =cellfun(@(b) sprintf('%2.1f', b), num2cell(bins), 'uniformoutput', false);
set(gca,'xticklabel', binstr)
subplot(2,2,4)
bar(squeeze(obsmodel.CPT(1,:,c)))
title(sprintf('model for %s absence', objectnames{c}))
set(gca,'xticklabel',binstr)
end
end
obstype = 'gauss';
[obsmode] = obsModelFit(labels, scores, obstype);

%}


%% Training p(labels, scores)
%% Training p(labels)

Nmethods = numel(methodNames);
models = cell(1, Nmethods);
Expand All @@ -194,50 +141,7 @@
models{m} = fitMethods{m}(labels, scores);
end

keyboard

%{
if isfield(models{1}, 'mixmodel') && models{1}.mixmodel.nmix==1
assert(approxeq(priorProb, models{1}.mixmodel.cpd.T(1,2,:)))
priorProb(1:5)
squeeze(models{1}.mixmodel.cpd.T(1,2,1:5))
end
%}


% Fit a depnet to the labels
model = depnetFit(data.train.presence, 'nodeNames', data.names)
% folder = fileparts(which(mfilename())
folder = '/home/kpmurphy/Dropbox/figures';
% for some reason, the directed graph is much more readable
graphviz(model.G, 'labels', model.nodeNames, 'directed', 1, ...
'filename', fullfile(folder, 'SUN09depnet'));


%{
% Visualize tree
folder = fileparts(which(mfilename())
folder = '/home/kpmurphy/Dropbox/figures';
% for some reason, the directed graph is much more readable
graphviz(model.edge_weights, 'labels', train.names, 'directed', 1, ...
'filename', fullfile(folder, 'SUN09treeNeg'));
%}

%{
% Visualize mix model
model = models{2};
K = model.mixmodel.nmix;
[nr,nc] = nsubplots(K);
figure;
for k=1:K
T = squeeze(model.mixmodel.cpd.T(k,2,:));
subplot(nr, nc, k)
bar(T);
[probs, perm] = sort(T, 'descend');
memberNames = sprintf('%s,', train.names{perm(1:5)})
title(sprintf('%5.3f, %s', model.mixmodel.mixWeight(k), memberNames))
end
%}

%% Probability of labels
% See if the models help with p(y(1:T))
Expand Down Expand Up @@ -279,20 +183,21 @@
for n=1:Ntest
frame = n;
if (n==1) || (mod(n,500)==0), fprintf('testing image %d of %d\n', n, Ntest); end

%{
% needs database of images
img = imread(fullfile(HOMEIMAGES, test.folder{frame}, test.filename{frame}));
figure(1); clf; image(img)
trueObjects = sprintf('%s,', test.names{find(test.presence(frame,:))});
title(trueObjects)
%}

softev = softevBatch(:,:,n); % Nstates * Nnodes * 1
[presence_indep(n,:)] = features(n, :);
%[presence_indep(n,:)] = softev(2, :);
bel = infMethods{m}(models{m}, features(n,:), softev);
presence_model(n,:,m) = bel(2,:);


% visualize predictions - % needs database of images
if fold==1 && (n <= 3)
img = imread(fullfile(HOMEIMAGES, test.folder{frame}, test.filename{frame}));
figure(1); clf; image(img)
trueObjects = sprintf('%s,', test.names{find(test.presence(frame,:))});
title(trueObjects)
end

end
end

Expand Down
Loading

0 comments on commit 5cdea2e

Please sign in to comment.