#
#
# Model3GibbsSampler.R
#
# Andrew Brown
# 2 March 2015
#
# This script contains the Gibbs sampler for implementing Model 3 in our Bayesian
# mult. testing model paper. "Model 3" is the one with weights assigned so that
# points with no neighbors can be included in the joint specification. Note that
# whether or not to include \rho is optional, as setting \rho = 1 still results
# in a nonsingular precision matrix. As such, the rho input is an option in this
# code. If absent, rho is set to 1 by default. 
#
# Input: W - The incidence (adjacency) matrix containing the neighborhood 
#              structure for the data
#        y - The vector of observed values
#        p.hyper - The scalar hyperparameter in the p prior for the model
#        burn.in - The scalar number of burn-in iterations to perform
#        mu.init - The vector of initial mu values
#        p.init - The scalar initial value for p
#        sig2.init - The scalar initial value for the noise variance
#        xi.init - The scalar initial value for the signal/noise variance ratio
#        rho.init - The scalar initial value for the propriety parameter in the
#                   CAR structure
#
# Output: out - The list containing 2000 sample draws for each of the parameters
#               in the model.
#
Model3GS<- function(W, y, p.hyper, burn.in, mu.init, p.init, 
                    sig2.init, xi.init, rho.init, diff= 1) {
  
  # The eigenvalue computation is only necessary if we include rho ~= 0 in the 
  # model. D_w is needed, regardless, so we calculate it.
  # Calculate the scaled adjacency matrix and find its eigenvalues. This involves
  # calculating D_w^(-.5) and taking D_w^(-.5) %*% W %*% D_w^(-.5).
  d.w= W
  d.w[]= 0
  
  num.nabors<- vector(length= dim(W)[1])
  for (i in 1:dim(W)[1]){
    
    num.nabors[i]= sum(W[i, ])
    
  }
  rm(i)
  
  diag(d.w)= num.nabors + diff  # <- This is the main difference between Model 2
                             # and Model 3. D_w is redifined here, and thus can
                             # be used as normal in the remainder of the code.
  
  
  if (!missing(rho.init)) {
    
    # Scale the adjacency matrix. I'm hoping this will be faster than doing
    # matrix multiplication in R
    w.star <- W
    for (i in 1:dim(w.star)[1]){
      
      for (j in 1:dim(w.star)[2]) {
        
        w.star[i,j] <- w.star[i,j]/(sqrt(num.nabors[i]+diff)*sqrt(num.nabors[j]+diff))
        
      }
      
    }  
    
    lambdas= eigen(w.star, symmetric= TRUE, only.values= TRUE)$values
    
    rho.low= 1/min(lambdas)
    rho.high= 1/max(lambdas)
    
    #rho.low= .9  # <- Try an informative prior on rho
    
  }  # End if
  
  
  
  ################################################################################
  
  
  
  # Now we proceed with the algorithm.
  
  
  # 1. Construct initial estimates for mu, sigma^2, p, xi, and rho. gamma will be
  #    drawn as the first step of the sampler. 
  mu= mu.init  
  p= p.init
  sig2= sig2.init
  xi= xi.init  # Xi is the scale factor between sigma^2 and tau^2
  
  # By default, rho= 1. Else, include it in the Gibbs updates
  if (!missing(rho.init)) {
    
    rho= rho.init  # We know that "high" values of rho are needed to induce 
    # reasonable spatial interaction
    
  } else {
    
    rho= 1
    
  }  # end if
  
  
  M= length(y)
  gam= vector(length= M)
  
  for (iter in 1:burn.in){
  
    
    # 2a. Draw gamma 
    for (indx in 1:M){
      
      num= ((1-p)/p)*exp(-1*(mu[indx]^2 - 2*y[indx]*mu[indx])/(2*sig2))
      p.star= num/(num + 1)
      
      gam[indx]= rbinom(1, 1, p.star)
      
    }
    rm(indx, p.star)
    
    # Check for NaNs
    if(is.nan(sum(gam))) {
      
      print('Uh oh!')
      break
      
    }
    
    
    
    # 2b. Draw sigma2 
    d.w.qf= sum((num.nabors+diff)*mu^2)
    w.qf= colSums(as.matrix(mu)*(W %*% (as.matrix(mu))))  # <- found this on
            # stackoverflow
    
    sig2.shp2= 2/(sum((y - gam*mu)^2) + 
                              (1/xi)*(d.w.qf - rho*w.qf))
    
    temp= rgamma(1, shape= M, scale= sig2.shp2)
    sig2= 1/temp
    
    rm(temp,sig2.shp2)
    
    
    
    
    # 2c. Draw p (Same as in Model 2)
    p.shp1= M - sum(gam) + p.hyper
    p.shp2= sum(gam) + 1
    
    p= rbeta(1, p.shp1, p.shp2)
    
    
    
    
    # 2d. Draw xi using rejection sampling.
    xi.shp1= M/2
    xi.rate= (d.w.qf - rho*w.qf)/(2*sig2)
    
    xi.star= 1/rgamma(1, shape= xi.shp1, rate= xi.rate)
    
    dec.rule= 4*xi.star/((xi.star+1)^2)
    u= runif(1)
    
    while(u > dec.rule){
      
      xi.star= 1/rgamma(1, shape= xi.shp1, rate= xi.rate)
      dec.rule= 4*xi.star/((xi.star+1)^2)
      u= runif(1)
      
    }  # End while loop
    
    
    xi= xi.star
    rm(dec.rule, xi.star, u)
    #xi= 1.5
    
    
    # The rho update is only done if we include it as a parameter in the model    
    if (!missing(rho.init)) {
      
      # 2e. Draw rho. This is done with slice sampling. 
      
      # To make the slice sampler more efficient, I'm (attempting) to use to 
      # 'doubling' procedure described by Neal (2003, AOS). 
      e= rexp(1)
      
      log.h.rho= (1/2)*sum(log(1 - rho*lambdas)) - ((d.w.qf - rho*w.qf)/(2*xi*sig2))
      
      z= log.h.rho - e  
      
      # Determine an interval from which to draw rho.star. 
      w= .001
      
      u= runif(1)
      l= max(rho - w*u, rho.low)
      r= min(l + w, rho.high)
      
      k= 15
      while (k > 0 && ( (z < (1/2)*sum(log(1 - l*lambdas)) - 
                         ((d.w.qf - l*w.qf)/(2*xi*sig2)) ) 
             || (z < (1/2)*sum(log(1 - r*lambdas)) - 
                   ((d.w.qf - r*w.qf)/(2*xi*sig2))))) {
        
        v= runif(1)
        if (v < 0.5) {
          
          l= max(l - (r - l), rho.low)  # We have to truncate the interval if it
          # runs up against the boundary
          
        } else {
          
          r= min(r + (r - l), rho.high)
          
        }  # End if-else
        
        k= k-1
        
      }  # End while loop
      
      
      # Sampele rho.star uniformly from (l,r), shrinking each time a rho.star is
      # drawn that is not in A = set of acceptable candidate values.
      l.bar= l
      r.bar= r
      repeat {
        
        u= runif(1)
        rho.star= l.bar + u*(r.bar - l.bar)
        
        # Determine whether rho.star is acceptable
        l.hat= l.bar
        r.hat= r.bar
        d= FALSE
        
        accept= TRUE
        while (r.hat - l.hat > 1.1*w) {
          
          m= (l.hat + r.hat)/2
          
          if ((rho < m && rho.star >= m) || (rho > m && rho.star <= m)) {
            
            d= TRUE
            
          }  # End if
          
          
          if (rho.star < m) {
            
            r.hat= m
            
          } else {
            
            l.hat= m
            
          }  # End if-else
          
          
          # Acceptable?
          if (d && ((z >= (1/2)*sum(log(1 - l.hat*lambdas)) - 
                       ((d.w.qf - l.hat*w.qf)/(2*xi*sig2))) && 
                      (z >= (1/2)*sum(log(1 - r.hat*lambdas)) - 
                         ((d.w.qf - r.hat*w.qf)/(2*xi*sig2))))) {
            
            accept= FALSE
            break  # You can stop the while loop here if you find that rho.star
            # is unacceptable.
            
          }  # End if
          
        }  # End while loop
        
        # Check both conditions before deciding whether or not to shrink the
        # interval of possible draws
        if (z < (1/2)*sum(log(1 - rho.star*lambdas)) - 
              ((d.w.qf - rho.star*w.qf)/(2*xi*sig2)) && accept) {
          
          break  # This is what we need to break out of the loop and keep the value
          
        }  # End if
        
        # Otherwise, shrink the interval and try again
        if (rho.star < rho) {
          
          l.bar= rho.star
          
        } else {
          
          r.bar= rho.star
          
        }  # End if-else
        
      }  # End repeat loop
      
      # After all of this mess, we should end up with an acceptable update for rho
      rho= rho.star
      #rho= 0.998  
      
    }  # end if

    
  
    # 2f. Draw mu
    for (indx in 1:M){
      
      ngbrs= W[indx, ]

      w.i.dot= d.w[indx, indx]
      
      mu.mean= (gam[indx]*y[indx] + (rho/xi)*sum(ngbrs*mu))/(gam[indx] + w.i.dot/xi)  
      mu.var= sig2/(gam[indx] + w.i.dot/xi)
      
      mu[indx]= rnorm(1, mean= mu.mean, sd= sqrt(mu.var))
      
    }  # End loop over mu
    rm(indx, ngbrs, w.i.dot)
    
    
    # Save every 100 iterations
    if (iter %% 10 == 0){
      
      print(iter)
      
    }
  #   
  #   if (iter %% 100 == 0){
  #     
  #     save(iter, y, M, gam, sig2, p, xi, rho, mu, file= "currentState.rda")
  #     
  #   } # End if
  
  }  # End burn-in loop
  rm(iter)
  




  # Repeat 2a - 2f until convergence.
  
  # Run an additional 5,000 iteration, saving every 5th draw to (i) improve 
  # convergence, and (ii) save space
  sig2.draws= vector(length= 2000)
  xi.draws= vector(length= 2000)
  gam.draws= matrix(nrow= 2000, ncol= length(gam))
  p.draws= vector(length= 2000)
  rho.draws= vector(length= 2000)
  mu.draws= matrix(nrow= 2000, ncol= length(mu))
  
  # Sampling loop
  for (iter in 1:10000){
    
    if (iter %% 100 == 0) {
      
      print(iter)
      
    } # End if
    
    if (is.nan(sig2)) {
      
      print('sig2 is NaN')
      break
      
    }
    
    # 2a. Draw gamma
    for (indx in 1:M){
      
      num= (1-p)*exp((-1/(2*sig2))*(y[indx] - mu[indx])^2)
      denom= num + p*exp((-1/(2*sig2))*y[indx]^2)
      
      if (is.nan(num)) {
        
        print('num is NaN')
        break
        
      }
      
      p.star= num/denom
      
      gam[indx]= rbinom(1, 1, p.star)
      
    }
    rm(indx, p.star)
    
    if (is.nan(num)) {
      
      break
      
    }
    
    # Check for NaNs
    if(sum(is.nan(gam) > 0) || (sum(is.na(gam) > 0))) {
      
      print('Uh oh!')
      break
      
    }
    
    
    # 2b. Draw sigma2
    d.w.qf= sum((num.nabors+diff)*mu^2)
    w.qf= colSums(as.matrix(mu)*(W %*% (as.matrix(mu))))  # <- found this on
      # stackoverflow
    
    sig2.shp2= 2/(sum((y - gam*mu)^2) + 
                    (1/xi)*(d.w.qf - rho*w.qf))
    
    temp= rgamma(1, shape= M, scale= sig2.shp2)
    sig2= 1/temp
    
    rm(temp,sig2.shp2)
    
    
    
    
    # 2c. Draw p
    p.shp1= M - sum(gam) + p.hyper
    p.shp2= sum(gam) + 1
    
    p= rbeta(1, p.shp1, p.shp2)
    
    if (is.nan(p)) {
      
      print('p is NaN')
      break
      
    }
    
    
    # 2d. Draw xi using rejection sampling. 
    xi.shp1= M/2
    xi.rate= (d.w.qf - rho*w.qf)/(2*sig2)
    
    xi.star= 1/rgamma(1, shape= xi.shp1, rate= xi.rate)
    
    dec.rule= 4*xi.star/((xi.star+1)^2)
    u= runif(1)
    
    while(u > dec.rule){
      
      xi.star= 1/rgamma(1, shape= xi.shp1, rate= xi.rate)
      dec.rule= dec.rule= 4*xi.star/((xi.star+1)^2)
    u= runif(1)
    
    }  # End while loop
    
    xi= xi.star
    rm(dec.rule, xi.star, u)
    #xi= 1.5
    
    
    
    # The rho update is only done if we include it as a parameter in the model    
    if (!missing(rho.init)) {
      
      # 2e. Draw rho. This is done with slice sampling. See my notes about it.
      
      # To make the slice sampler more efficient, I'm (attempting) to use to 
      # 'doubling' procedure described by Neal (2003, AOS). 
      e= rexp(1)
      
      log.h.rho= (1/2)*sum(log(1 - rho*lambdas)) - ((d.w.qf - rho*w.qf)/(2*xi*sig2))
      
      z= log.h.rho - e  
      
      # Determine an interval from which to draw rho.star. 
      w= .001
      
      u= runif(1)
      l= max(rho - w*u, rho.low)
      r= min(l + w, rho.high)
      
      k= 15
      while (k > 0 && ( (z < (1/2)*sum(log(1 - l*lambdas)) - 
                         ((d.w.qf - l*w.qf)/(2*xi*sig2)) ) 
             || (z < (1/2)*sum(log(1 - r*lambdas)) - 
                   ((d.w.qf - r*w.qf)/(2*xi*sig2))))) {
        
        v= runif(1)
        if (v < 0.5) {
          
          l= max(l - (r - l), rho.low)  # We have to truncate the interval if it
          # runs up against the boundary
          
        } else {
          
          r= min(r + (r - l), rho.high)
          
        }  # End if-else
        
        k= k-1
        
      }  # End while loop
      
      # Sampele rho.star uniformly from (l,r), shrinking each time a rho.star is
      # drawn that is not in A = set of acceptable candidate values.
      l.bar= l
      r.bar= r
      repeat {
        
        u= runif(1)
        rho.star= l.bar + u*(r.bar - l.bar)
        
        # Determine whether rho.star is acceptable
        l.hat= l.bar
        r.hat= r.bar
        d= FALSE
        
        accept= TRUE
        while (r.hat - l.hat > 1.1*w) {
          
          m= (l.hat + r.hat)/2
          
          if ((rho < m && rho.star >= m) || (rho > m && rho.star <= m)) {
            
            d= TRUE
            
          }  # End if
          
          
          if (rho.star < m) {
            
            r.hat= m
            
          } else {
            
            l.hat= m
            
          }  # End if-else
          
          
          # Acceptable?
          if (d && ((z >= (1/2)*sum(log(1 - l.hat*lambdas)) - 
                       ((d.w.qf - l.hat*w.qf)/(2*xi*sig2))) && 
                      (z >= (1/2)*sum(log(1 - r.hat*lambdas)) - 
                         ((d.w.qf - r.hat*w.qf)/(2*xi*sig2))))) {
            
            accept= FALSE
            break  # You can stop the while loop here if you find that rho.star
            # is unacceptable.
            
          }  # End if
          
        }  # End while loop
        
        # Check both conditions before deciding whether or not to shrink the
        # interval of possible draws
        if (z < (1/2)*sum(log(1 - rho.star*lambdas)) - 
              ((d.w.qf - rho.star*w.qf)/(2*xi*sig2)) && accept) {
          
          break  # This is what we need to break out of the loop and keep the value
          
        }  # End if
        
        # Otherwise, shrink the interval and try again
        if (rho.star < rho) {
          
          l.bar= rho.star
          
        } else {
          
          r.bar= rho.star
          
        }  # End if-else
        
      }  # End repeat loop
      
      # After all of this mess, we should end up with an acceptable update for rho
      rho= rho.star
      #rho= 0.998  
      
    }  # end if
    
    
    
    # 2f. Draw mu
    for (indx in 1:M){
      
      ngbrs= W[indx, ]
      # w.i.dot= sum(ngbrs)  # This is another place where the difference 
                             # between the two models shows up
      w.i.dot= d.w[indx, indx]
      
      mu.mean= (gam[indx]*y[indx] + (rho/xi)*sum(ngbrs*mu))/(gam[indx] + w.i.dot/xi)  
      mu.var= sig2/(gam[indx] + w.i.dot/xi)
      
      mu[indx]= rnorm(1, mean= mu.mean, sd= sqrt(mu.var))
      
    }  # End loop over mu
    rm(indx, ngbrs, w.i.dot)
    
    
    # Save every 5th iteration
    if (iter %% 5 == 0) {
      
      sig2.draws[iter/5]= sig2
      xi.draws[iter/5]= xi
      gam.draws[iter/5, ]= gam
      p.draws[iter/5]= p
      rho.draws[iter/5]= rho
      mu.draws[iter/5, ]= mu
      
    }  # End if
    
  }  # End sampling loop
  
  # Return a of all the sample draws (if rho.init is missing, then rho.draws is
  # just a vector of zeros.)
  out= list(sig2= sig2.draws, xi= xi.draws, gam= gam.draws, p= p.draws,
            rho= rho.draws, mu= mu.draws)

  return(out)

}  # End function Model2GS
