Skip to content

Commit

Permalink
add MRF mean field
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed May 28, 2017
1 parent d1b3fe2 commit 3aedcb4
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 0 deletions.
11 changes: 11 additions & 0 deletions chapter08/betheEnergy.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function lnZ = betheEnergy(A, nodePot, edgePot, nodeBel, edgeBel)
% Compute Bethe free energy
% TBD: deal with log(0) for entropy
edgePot = reshape(edgePot,[],size(edgePot,3));
edgeBel = reshape(edgeBel,[],size(edgeBel,3));
Ex = dot(nodeBel,nodePot,1);
Exy = dot(edgeBel,edgePot,1);
Hx = -dot(nodeBel,log(nodeBel),1);
Hxy = -dot(edgeBel,log(edgeBel),1);
d = full(sum(logical(A),1));
lnZ = -sum(Ex)-sum(Exy)-sum((d-1).*Hx)+sum(Hxy);
76 changes: 76 additions & 0 deletions chapter08/demo.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
clear; close all;
% load letterA.mat;
% X = A;
load letterX.mat
%% Original image
epoch = 50;
J = 1; % ising parameter
sigma = 1; % noise level

img = double(X);
img = sign(img-mean(img(:)));

figure;
subplot(2,3,1);
imagesc(img);
title('Original image');
axis image;
colormap gray;
%% Noisy image
y = img + sigma*randn(size(img)); % noisy signal
subplot(2,3,2);
imagesc(y);
title('Noisy image');
axis image;
colormap gray;
%% Mean Field
[A, nodePot, edgePot] = im2mrf(y, sigma, J);
[nodeBel, edgeBel, lnZ] = meanField(A, nodePot, edgePot, epoch);
lnZ0 = gibbsEnergy(nodePot, edgePot, nodeBel, edgeBel);
lnZ1 = betheEnergy(A, nodePot, edgePot, nodeBel, edgeBel);
maxdiff(lnZ0, lnZ(end))
maxdiff(lnZ0, lnZ1)

subplot(2,3,3);
imagesc(reshape(nodeBel(1,:),size(img)));
title('MF');
axis image;
colormap gray;
%% Belief Propagation
% [nodeBel,edgeBel] = belProp(A, nodePot, edgePot, epoch);
%
% [nodeBel0,edgeBel0] = belProp0(A, nodePot, edgePot, epoch);
% maxdiff(nodeBel,nodeBel0)
% maxdiff(edgeBel,edgeBel0)
%
% subplot(2,3,4);
% imagesc(reshape(nodeBel(1,:),size(img)));
% title('BP');
% axis image;
% colormap gray;
% %% Expectation Propagation
% [nodeBel,edgeBel] = expProp(A, nodePot, edgePot, epoch);
%
% lnZ0 = betheEnergy(A, nodePot, edgePot, nodeBel, edgeBel);
%
% [nodeBel0,edgeBel0] = expProp0(A, nodePot, edgePot, epoch);
% maxdiff(nodeBel,nodeBel0)
% maxdiff(edgeBel,edgeBel0)
%
% subplot(2,3,5);
% imagesc(reshape(nodeBel(1,:),size(img)));
% title('EP');
% axis image;
% colormap gray;
% %% EP-BP
% [nodeBel,edgeBel] = expBelProp(A, nodePot, edgePot, epoch);
%
% [nodeBel0,edgeBel0] = expBelProp0(A, nodePot, edgePot, epoch);
% maxdiff(nodeBel,nodeBel0)
% maxdiff(edgeBel,edgeBel0)
%
% subplot(2,3,6);
% imagesc(reshape(nodeBel(1,:),size(img)));
% title('EBP');
% axis image;
% colormap gray;
9 changes: 9 additions & 0 deletions chapter08/gibbsEnergy.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
function lnZ = gibbsEnergy(nodePot, edgePot, nodeBel, edgeBel)
% Compute Gibbs free energy
% TBD: deal with log(0) for entropy
edgePot = reshape(edgePot,[],size(edgePot,3));
edgeBel = reshape(edgeBel,[],size(edgeBel,3));
Ex = dot(nodeBel,nodePot,1);
Exy = dot(edgeBel,edgePot,1);
Hx = dot(nodeBel,log(nodeBel),1);
lnZ = -(sum(Ex)+sum(Exy)+sum(Hx));
20 changes: 20 additions & 0 deletions chapter08/im2mrf.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
function [A, nodePot, edgePot] = im2mrf(im, sigma, J)
% Convert a image to Ising MRF with distribution p(x)=exp(-sum(nodePot)-sum(edgePot)-lnZ)
% Input:
% im: row x col image
% sigma: variance of Gaussian node potential
% J: parameter of Ising edge
% Output:
% nodePot: 2 x n node potential
% edgePot: 2 x 2 x m edge potential

A = lattice(size(im));
[s,t,e] = find(tril(A));
nEdge = numel(e);
e(:) = 1:nEdge;
A = sparse([s;t],[t;s],[e;e]);

z = [1;-1];
y = reshape(im,1,[]);
nodePot = (y-z).^2/(2*sigma^2);
edgePot = repmat(-J*(z*z'),[1, 1, nEdge]);
Binary file added chapter08/letterX.mat
Binary file not shown.
38 changes: 38 additions & 0 deletions chapter08/meanField.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
function [nodeBel, edgeBel, lnZ] = meanField(A, nodePot, edgePot, epoch)
% Mean field for MRF
% Assuming egdePot is symmetric
% Input:
% A: n x n adjacent matrix of undirected graph, where value is edge index
% nodePot: k x n node potential
% edgePot: k x k x m edge potential
% Output:
% nodeBel: k x n node belief
% edgeBel: k x k x m edge belief
% L: variational lower bound
% Written by Mo Chen ([email protected])
tol = 0;
if nargin < 4
epoch = 10;
tol = 1e-4;
end
lnZ = -inf(1,epoch+1);
[nodeBel,L] = softmax(-nodePot,1); % init nodeBel
for iter = 1:epoch
for i = 1:numel(L)
[~,j,e] = find(A(i,:)); % neighbors
np = nodePot(:,i);
[lnp ,lnz] = lognormexp(-np-reshape(edgePot(:,:,e),2,[])*reshape(nodeBel(:,j),[],1));
p = exp(lnp);
L(i) = -dot(p,lnp+np)+lnz; %
nodeBel(:,i) = p;
end
lnZ(iter+1) = sum(L)/2;
if abs(lnZ(iter+1)-lnZ(iter))/abs(lnZ(iter)) < tol; break; end
end
lnZ = lnZ(2:iter);

[s,t,e] = find(tril(A));
edgeBel = zeros(size(edgePot));
for l = 1:numel(e)
edgeBel(:,:,e(l)) = nodeBel(:,s(l))*nodeBel(:,t(l))';
end

0 comments on commit 3aedcb4

Please sign in to comment.