function [bmask,uniqidx,nc,nn,totalarea,iso]=mcxsvmc(vol, varargin)
%
% Format:
%    newvol=mcxsvmc(vol)
%       or
%    [bmask,uniqidx,nc,nn,totalarea,iso]=mcxsvmc(vol)
%    [bmask,uniqidx,nc,nn,totalarea,iso]=mcxsvmc(vol,'option1',v1,'option2',v2,...)
%
% Preprocessing voxelated labels to find and incorporate curved boundary
% information and improve accuracy
%
% Author: Qianqian Fang <q.fang at neu.edu>
%
% Input:
%    vol: a 3D volume with integer labels, a label 0 is assumed to be
%         background and tissue labels are non-zero positive numbers
%    options (optional): one can add additional 'name', value pairs to
%         specify user options, these options include
%         'debug': [0], if set to 1, plot example surfaces for debugging
%         'smoothing': [1], if set to 1, do a Gaussian smoothing for the
%                 labels before extracting isosurface using marching cube
%         'kernelsize': [3], the 3D gaussian kernel size
%         'kernelstd': [1], the 3D gaussian kernel standard-deviation(sigma)
%         'threshold': [0.5], default threshold when extracting isosurfaces
%         'debugpatch': [1], ID of the patch to plot
%         'bmask': [], a pre-computed mask of the same size as vol; a
%                 voxel of non-zero (and non-NaN) value prevents adding new
%                 boundary info; voxels of value 0 or NaN are not blocked
%         'curveonly': [1] if set to 1, remove voxels where patch normals
%                 are aligned with x/y/z axes.
%
% Output:
%    bmask: an Nx2 array, where N is the length of uniqidx, with the 1st
%         column records the lower-valued label ID, and the 2nd column
%         records the higher-valued label ID in a mixed label voxel; the
%         2nd column is set to 0 when both columns hold the same label
%    uniqidx: the 1-D index of all voxels that are made of mixed labels
%    nc: nc is a Nx3 array, where N is the length of uniqidx, representing
%         the reference point (area-weighted patch centroid) of the
%         intravoxel interface in a mixed label voxel; only the fractional
%         (intra-voxel) part is kept, discretized to 0-254 as floor(frac*255)
%    nn: nn is a Nx3 array, where N is the length of uniqidx, representing
%         the area-weighted, normalized normal vector in a mixed label
%         voxel, pointing from low to high label numbers; each component
%         is mapped from [-1,1] to 0-254 as min(floor((n+1)*255/2),254)
%    totalarea: an Nx1 array, where N is the length of uniqidx, representing
%         the total cross-sectional area of the boundary in each mixed
%         label voxel
%    iso: a struct, iso.vertices denotes the nodes and iso.faces denotes
%         the triangle patches of the combined isosurfaces from all labels.
%
%    if a single output is given, bmask, nc, nn are merged into a single
%    4-D volume of size [8, size(vol)], where the indices in the 1st
%    dimension represent 1-2: bmask, 3-5: nc, 6-8: nn; in voxels that are
%    not mixed-label voxels, index 1 holds the label from vol and 2-8 are 0
%
% Example:
%    [xi,yi,zi]=ndgrid(0.5:59.5,0.5:59.5,0.5:59.5);
%    dist=(xi-30).*(xi-30)+(yi-30).*(yi-30)+(zi-30).*(zi-30);
%    vol=zeros(size(dist));
%    vol(dist<10*10)=1;
%    vol(20:40,20:40,1:30)=2;
%    nuvox=mcxsvmc(vol);
%
% Dependency:
%    this function depends on the Iso2Mesh toolbox (http://iso2mesh.sf.net)
%
% This function is part of Monte Carlo eXtreme (MCX) URL: http://mcx.space
%
% License: GNU General Public License version 3, please read LICENSE.txt for details
%

%% parse user options
opt=varargin2struct(varargin{:});
dodebug=jsonopt('debug',0,opt);
ksize=jsonopt('kernelsize',3,opt);
kstd=jsonopt('kernelstd',1,opt);
level=jsonopt('threshold',0.5,opt);
debugpatch=jsonopt('debugpatch',1,opt);
bmask=jsonopt('bmask',NaN(size(vol)),opt);
dosmooth=jsonopt('smoothing',1,opt);
curveonly=jsonopt('curveonly',1,opt);

% bmask tracks which voxels are already "claimed": NaN means free, any
% other value blocks the voxel. Normalize a user-supplied mask to that
% convention: an empty mask means "no mask", and 0 means "not blocked".
% Convert to double first so NaN can be stored in integer/logical masks.
if(isempty(bmask))
    bmask=NaN(size(vol));
end
if(numel(bmask)~=numel(vol))
    error('the ''bmask'' option must have the same size as vol');
end
bmask=double(bmask);
bmask(bmask==0)=NaN;

% bmask2 records the last (highest) label whose isosurface passes
% through each voxel
bmask2=zeros(size(vol));

%% read unique labels (unique() already returns them in ascending order);
%  the background label 0 is processed as well
labels=unique(vol(:));

if(length(labels)>255)
    error('MCX currently supports up to 255 labels for this function');
end

%% loop over unique labels in ascending order

iso=struct('vertices',[],'faces',[]);

% grid coordinates are the same for every label, build them once
[xi,yi,zi]=ndgrid(1:size(vol,1),1:size(vol,2),1:size(vol,3));

for i=1:length(labels)
    % convert each label into a binary mask, smooth it, then extract the
    % isosurface using marching cube algorithm (matlab builtin)
    if(dosmooth)
        volsmooth=smooth3(double(vol==labels(i)),'g',ksize,kstd);
    else
        volsmooth=double(vol==labels(i));
    end
    fv0=isosurface(xi,yi,zi,volsmooth,level);
    if(isempty(fv0.vertices))
        continue;
    end

    % get the containing voxel linear id: a patch whose centroid is c
    % belongs to voxel floor(c)+1
    c0=meshcentroid(fv0.vertices,fv0.faces);
    voxid=sub2ind(size(vol),floor(c0(:,1))+1,floor(c0(:,2))+1,floor(c0(:,3))+1);
    % identify unique voxels
    labelvox=unique(voxid);
    bmask2(labelvox)=labels(i);

    % find new boundary voxels that are not covered by previous levelsets
    labelvox=labelvox(isnan(bmask(labelvox)));
    goodpatchidx=ismember(voxid,labelvox);

    % merge surface patches located inside new boundary voxels
    iso.faces=[iso.faces; size(iso.vertices,1)+fv0.faces(goodpatchidx,:)];
    iso.vertices=[iso.vertices; fv0.vertices];

    % label those voxels as covered
    bmask(labelvox)=labels(i);
end

%% handle uniform domains (no boundary patch survived)
if(isempty(iso.faces))
    uniqidx=[];
    nc=zeros(0,3);
    nn=zeros(0,3);
    totalarea=zeros(0,1);
    bmask=zeros(0,2);
    if(nargout==1)
        % keep the merged-volume contract: labels in channel 1, rest 0
        bmask=zeros([8,size(vol)]);
        bmask(1,:,:,:)=vol;
    end
    return;
end

%% get the voxel mapping for the final combined isosurface
[iso.vertices,iso.faces]=removeisolatednode(iso.vertices,iso.faces);
c0=meshcentroid(iso.vertices,iso.faces);
voxid=sub2ind(size(vol),floor(c0(:,1))+1,floor(c0(:,2))+1,floor(c0(:,3))+1);
% uniqidx: sorted boundary voxel ids; patch2vox(k) is the row of uniqidx
% containing the k-th patch
[uniqidx,~,patch2vox]=unique(voxid);
patch2vox=patch2vox(:);
nvox=numel(uniqidx);

if(dodebug)
    plotmesh(iso.vertices,iso.faces,'facealpha',0.4,'facecolor','b','edgealpha',0.2)
    disp(max(iso.vertices)-min(iso.vertices))
end

%% obtain the low and high labels in all mix-label voxels
bmask=[bmask(uniqidx),bmask2(uniqidx)];
bmask(bmask(:,1)==bmask(:,2),2)=0;

%% computing total area, centroid and normal vector for each boundary voxel
areas=elemvolume(iso.vertices,iso.faces);
areas=areas(:);
normals=surfacenorm(iso.vertices,iso.faces);

% total areas of cross-sections in each boundary voxel (Nx1, uniqidx order)
totalarea=accumarray(patch2vox,areas,[nvox 1]);

% area-weighted sums of patch centroids and patch normals per voxel
nc=zeros(nvox,3);
nn=zeros(nvox,3);
for i=1:3
    nc(:,i)=accumarray(patch2vox,c0(:,i).*areas,[nvox 1]);
    nn(:,i)=accumarray(patch2vox,normals(:,i).*areas,[nvox 1]);
end

% weighted average centroid, and normalized weighted average normal
nc=nc./repmat(totalarea,1,3);
nnlen=sqrt(sum(nn.*nn,2));
nn=nn./repmat(nnlen,1,3);

%% remove x/y/z oriented boundary voxels (and their neighbors)
if(curveonly)
    % rows where some centroid component is integer-valued and some normal
    % component is exactly +/-1
    [ixc,~]=find((nc-floor(nc))<1e-6);
    [ixn,~]=find(abs(nn)==1);
    ix=intersect(ixc,ixn);
    boxmask=zeros(size(vol));
    boxmask(uniqidx(ix))=1;
    % dilate by one voxel with a 3x3x3 box filter
    boxmask=smooth3(boxmask,'b',3);
    boxmask=(boxmask>0);
    ix=find(boxmask(uniqidx)==1);

    nn(ix,:)=[];
    nc(ix,:)=[];
    uniqidx(ix)=[];
    bmask(ix,:)=[];
    totalarea(ix)=[];
end

% keep the continuous unit normals for the debug plot below
nnunit=nn;

%% discretize nn and nc vector components to 0-255 gray-scale numbers

nc=nc-floor(nc);
nc=floor(nc*255);
nn=min(floor((nn+1)*255/2),254);

%% assemble the final volume
if(nargout==1)
    newvol=zeros([8,size(vol)]);
    newvol(1,:,:,:)=vol;
    newvol(1:2,uniqidx)=bmask';
    newvol(3:5,uniqidx)=nc';
    newvol(6:8,uniqidx)=nn';
    bmask=newvol;
end

%% plotting for verification

if(dodebug)
    % patches that share their voxel with at least one other patch
    patchcount=accumarray(patch2vox,1);
    pcidx=find(patchcount(patch2vox)>1);
    if(debugpatch>=1 && debugpatch<=numel(pcidx))
        pidx0=pcidx(debugpatch);  % pick one such patch to debug
        idx1=voxid(pidx0);        % the voxel id containing patch pidx0
        patid=(voxid==idx1);      % all patches within voxel idx1

        testmask=zeros(size(vol));
        testmask(idx1)=1;
        [no3,fc3]=binsurface(testmask,4);

        figure;
        plotmesh(no3,fc3,'facealpha',0.3,'facecolor','none')
        hold on;
        plotmesh(iso.vertices,iso.faces(patid,:),'facealpha',0.3,'facecolor','b','edgealpha',0.2)

        % the voxel may have been removed by 'curveonly'; look up its row
        row=find(uniqidx==idx1,1);
        if(~isempty(row))
            disp(nnunit(row,:))   % unit normal in voxel idx1
            % plot centroid of pidx0-th patch to the normal in voxel idx1
            plotmesh([c0(pidx0,:); c0(pidx0,:)+nnunit(row,:)],'ro-');
        end
    else
        warning('debugpatch (%d) is out of range: only %d patches share a voxel',debugpatch,numel(pcidx));
    end
end

