"""C 2021 Bence Becsy
MCMC for CW fast likelihood (w/ Neil Cornish and Matthew Digman)"""
import numpy as np
from numba import njit,prange
from numpy.random import uniform
import QuickCW.CWFastPrior as CWFastPrior
import QuickCW.const_mcmc as cm
from QuickCW.QuickCorrectionUtils import check_merged,correct_intrinsic,correct_extrinsic_array
from QuickCW.QuickFisherHelpers import safe_reset_swap,get_FLI_mem
from time import perf_counter
################################################################################
#
#UPDATE INTRINSIC PARAMETERS AND RECALCULATE FILTERS
#
################################################################################
#version using multiple try mcmc (based on Table 6 of https://vixra.org/pdf/1712.0244v3.pdf)
#@profile
[docs]def do_intrinsic_update_mt(mcc, itrb):
"""do the intrinsic update using the multiple try mcmc algorithm
:param mcc: MCMCChain onject
:param itrb: Index within saved values (as opposed to block index itri or overall index itrn)
:return mcc.FLI_swap: FastLikeInfo object
"""
Npsr = mcc.x0s[0].Npsr
Ts = mcc.chain_params.Ts
for j in range(mcc.n_chain):
assert mcc.FLIs[j].get_lnlikelihood(mcc.x0s[j]) == mcc.log_likelihood[j,itrb]
#print('k',j,mcc.log_likelihood[j,itrb])
#print("EXT")
for j in range(mcc.n_chain):
#print(mcc.FLIs[j].get_lnlikelihood(mcc.x0s[j]),mcc.log_likelihood[j,itrb])
assert mcc.FLIs[j].get_lnlikelihood(mcc.x0s[j]) == mcc.log_likelihood[j,itrb]
mcc.FLIs[j].validate_consistent(mcc.x0s[j])
#save MMs and NN so we can revert them if the chain is rejected
FLI_mem_save = get_FLI_mem(mcc.FLIs[j])
samples_current = np.copy(mcc.samples[j,itrb,:])
#print('0',mcc.FLIs[j].get_lnlikelihood(mcc.x0s[j]),mcc.log_likelihood[j,itrb])
#should already be at this value
mcc.x0s[j].validate_consistent(samples_current)
mcc.x0s[j].update_params(samples_current)
total_weight = (mcc.chain_params.dist_jump_weight + mcc.chain_params.rn_jump_weight + mcc.chain_params.gwb_jump_weight +
mcc.chain_params.common_jump_weight + mcc.chain_params.all_jump_weight)
which_jump = np.random.choice(5, p=[mcc.chain_params.dist_jump_weight/total_weight,
mcc.chain_params.rn_jump_weight/total_weight,
mcc.chain_params.gwb_jump_weight/total_weight,
mcc.chain_params.common_jump_weight/total_weight,
mcc.chain_params.all_jump_weight/total_weight])
#replace checking which_jump==1 etc with indicator values for desired behavior so that more jump types can be added in the future
recompute_rn = False
recompute_gwb = False
recompute_int = False
recompute_dist = False
all_eigs = False
fail_point = False
merged_point = False
if which_jump==0: # update psr distances
recompute_dist = True
n_dist_loc = min(Npsr,mcc.chain_params.n_dist_main)#max(1,np.int64(cm.n_dist_main*mcc.chain_params.Ts[j])))
idx_choose_psr_dist = np.random.choice(Npsr,n_dist_loc,replace=False)
n_jump_loc = n_dist_loc
#idx_choose_psr[0] = pta.pulsars.index(par_names_cw_int[jump_select][:-11])
idx_choose = mcc.x0s[j].idx_dists[idx_choose_psr_dist]
scaling = 2.38*np.sqrt(Ts[j])/np.sqrt(n_jump_loc)
#scaling = 1.0
#scaling = 0.5
elif which_jump==1: # update per psr RN
recompute_rn = True
all_eigs = True
n_jump_loc = 2*Npsr
idx_choose_psr = list(range(Npsr))
idx_choose_psr_dist = idx_choose_psr
idx_choose = np.concatenate((mcc.x0s[j].idx_rn_gammas,mcc.x0s[j].idx_rn_log10_As))
scaling = 2.38*np.sqrt(Ts[j])/np.sqrt(n_jump_loc)
#scaling = 1/np.sqrt(n_jump_loc)
elif which_jump==2: # update common RN
recompute_gwb = True
n_jump_loc = 2
idx_choose = np.array([mcc.x0s[j].idx_gwb_gamma, mcc.x0s[j].idx_gwb_log10_A])
idx_choose_psr_dist = list(range(Npsr)) #all pulsars need to be updated in everything here
scaling = 2.38*np.sqrt(Ts[j])/np.sqrt(n_jump_loc/2)
#scaling = 1/np.sqrt(n_jump_loc)
elif which_jump==3: # update common intrinsic parameters (chirp mass, frequency, sky location[2])
recompute_int = True
n_jump_loc = 4 # 2+mcc.chain_params.ndist_extra
idx_choose = mcc.x0s[j].idx_cw_int[:4] # np.array([par_names.index(par_names_cw_int[itrk]) for itrk in range(4)])
#don't count parameters where jump sizes are probably saturated for the purposes of determining the appropriate jump sizing
saturated_idxs = np.sum((2.38*np.sqrt(Ts[j])*mcc.fisher_diag[j][idx_choose])>0.5)
if saturated_idxs==n_jump_loc:
saturated_idxs = n_jump_loc-1
scaling = 2.38*np.sqrt(Ts[j])/np.sqrt(n_jump_loc-saturated_idxs)
idx_choose_psr_dist = list(range(Npsr)) #all pulsars need to be updated in everything here
all_eigs = True
#scaling = 1.0
#scaling = 0.5
elif which_jump==4: # do every possible jump
#including this ensures any point in parameter space has some finite probability density to be reached in a single jump
recompute_rn = True
recompute_gwb = True
recompute_int = True
recompute_dist = True
all_eigs = True
n_dist_loc = min(Npsr,mcc.chain_params.n_dist_main)#max(1,np.int64(cm.n_dist_main*mcc.chain_params.Ts[j])))
idx_choose_psr_dist = np.random.choice(Npsr,n_dist_loc,replace=False)
idx_choose_psr = list(range(Npsr))
n_jump_loc = 2*Npsr+4+2 #distance+RN+common_pars+crn
idx_choose = np.concatenate((mcc.x0s[j].idx_cw_int[:4],
mcc.x0s[j].idx_rn_gammas, mcc.x0s[j].idx_rn_log10_As,
[mcc.x0s[j].idx_gwb_gamma, mcc.x0s[j].idx_gwb_log10_A]))
scaling = 2.38*np.sqrt(Ts[j])/np.sqrt(n_jump_loc)
else:
raise ValueError('jump index unrecognized',which_jump)
#decide what kind of jump we do
if recompute_rn and not recompute_gwb:
if mcc.rn_emp_dist is None: # RN jump w/o emp dist --> only do fisher
prior_draw_prob = 0
de_prob = 0
fisher_prob = mcc.chain_params.fisher_prob
else: # RN jump w/ emp dist --> do fisher and emp dist (called prior here)
prior_draw_prob = mcc.chain_params.prior_draw_prob
de_prob = 0
fisher_prob = mcc.chain_params.fisher_prob
elif recompute_gwb and not recompute_rn and not recompute_int: # GWB --> do fisher and DE
prior_draw_prob = 0
#if j==(mcc.n_chain-1): #never do DE on hottest chain
# de_prob = 0.
#else:
de_prob = mcc.chain_params.de_prob
fisher_prob = mcc.chain_params.fisher_prob
elif recompute_gwb and recompute_rn: #all --> do fisher and de only
prior_draw_prob = 0#mcc.chain_params.prior_draw_prob
#if j==(mcc.n_chain-1): #never do DE on hottest chain
# de_prob = 0.
#else:
de_prob = mcc.chain_params.de_prob
fisher_prob = mcc.chain_params.fisher_prob
elif j==(mcc.n_chain-1) and which_jump==3: #distance of common parameters and hottest chain --> only do prior draws
prior_draw_prob = mcc.chain_params.prior_draw_prob
de_prob = 0
fisher_prob = 0
elif which_jump!=3: #distance jump --> do prior draws and fisher
prior_draw_prob = mcc.chain_params.prior_draw_prob
de_prob = 0
fisher_prob = mcc.chain_params.fisher_prob
else: #common jump --> do everything
prior_draw_prob = mcc.chain_params.prior_draw_prob
de_prob = mcc.chain_params.de_prob
fisher_prob = mcc.chain_params.fisher_prob
total_type_weight = prior_draw_prob + de_prob + fisher_prob
which_jump_type = np.random.choice(3, p=[prior_draw_prob/total_type_weight,
de_prob/total_type_weight,
fisher_prob/total_type_weight])
if which_jump_type==1 and which_jump==4:
#force 'all' differential evolution jumps to be in both gwb and common parameters only
idx_choose = np.concatenate((mcc.x0s[j].idx_cw_int[:4],[mcc.x0s[j].idx_gwb_gamma, mcc.x0s[j].idx_gwb_log10_A]))
n_jump_loc = 6
scaling = 2.38*np.sqrt(Ts[j])/np.sqrt(n_jump_loc)
idx_choose_psr = []
idx_choose_psr_dist = []
recompute_int = True
recompute_gwb = True
recompute_rn = False
recompute_dist = False
if which_jump_type==0: # do prior draw (or empirical distribution in case of RN)
if which_jump==1: # updateing RN --> do empirical distribution step
new_point = samples_current.copy()
log_proposal_ratio = 0.
#overwrite the list of pulsars to update,
#because we might want to update fewer pulsars when using empirical distributions
#to help acceptence despite the penalty factors
#scale number of dimensions by a factor related to the temperature if it goes to T>~50 to avoid under-aggressive jumps
n_noise_emp_dist_loc = max(min(Npsr,np.int64(mcc.chain_params.n_noise_emp_dist*(Ts[j]/400.+1))),1)
idx_choose_psr = np.random.choice(Npsr,n_noise_emp_dist_loc,replace=False)
#log_proposal_ratio = 0.0
for psr_idx in idx_choose_psr:
#use temperature adapted empirical distributions if possible
#if mcc.rn_emp_dist_adapt is None:
rn_emp_dist_loc = mcc.rn_emp_dist
#else:
# rn_emp_dist_loc = mcc.rn_emp_dist_adapt[j]
#rn_draw = mcc.rn_emp_dist[psr_idx].draw()
rn_draw = rn_emp_dist_loc[psr_idx].draw()
new_point[mcc.x0s[j].idx_rn_log10_As[psr_idx]] = rn_draw[0]
new_point[mcc.x0s[j].idx_rn_gammas[psr_idx]] = rn_draw[1]
log_proposal_ratio += rn_emp_dist_loc[psr_idx].logprob(np.array([samples_current[mcc.x0s[j].idx_rn_log10_As[psr_idx]],
samples_current[mcc.x0s[j].idx_rn_gammas[psr_idx]]]))
log_proposal_ratio +=-rn_emp_dist_loc[psr_idx].logprob(rn_draw)
#if j==0: print("RNEmpDist--psr="+mcc.psrs[psr_idx].name)
#if j==0: print("RNEmpDist--log_prop_ratio="+str(log_proposal_ratio))
else: # other parameter --> do actual prior draw
new_point = CWFastPrior.get_sample_idxs(samples_current.copy(),idx_choose,mcc.FPI)
log_prior_old = CWFastPrior.get_lnprior(samples_current, mcc.FPI)
log_prior_new = CWFastPrior.get_lnprior(new_point, mcc.FPI)
#backwards/forwards proposal ratio not necessarily 1 (e.g. for distances with non-flat priors)
log_proposal_ratio = log_prior_old - log_prior_new
elif which_jump_type==1: # do differential evolution step
de_indices = np.random.choice(mcc.de_history.shape[1], size=2, replace=False)
ndim = idx_choose.size
#alpha0 = 2.38/np.sqrt(2*ndim)
alpha0 = 1.68/np.sqrt(ndim)*np.sqrt(Ts[j])
alpha = alpha0*np.random.normal(0.,1.)
x1 = np.copy(mcc.de_history[j,de_indices[0],idx_choose])
x2 = np.copy(mcc.de_history[j,de_indices[1],idx_choose])
new_point = np.copy(samples_current)
#new_point[idx_choose] += alpha0*(x1-x2)
#new_point[idx_choose] += alpha0*(1+alpha)*(x1-x2)
#backwards/forwards proposal ratio is always one for Gaussian jumps
log_proposal_ratio = 0.0
big_jump_decide = np.random.uniform(0.0, 1.0)
if big_jump_decide<mcc.chain_params.big_de_jump_prob: #do big jump
#new_point[idx_choose] += (1+alpha)*(x1-x2)
#TODO does this actually need to be scaled by a random amount?
new_point[idx_choose] += (x1-x2)
else: #do smaller jump scaled by alpha0
#new_point[idx_choose] += alpha0*(1+alpha)*(x1-x2)
new_point[idx_choose] += alpha*(x1-x2)
elif which_jump_type==2: # do regular fisher jump
#jumps don't necessarily need to be mutually exclusive so use the indicator variables
new_point = samples_current.copy()
jump = np.zeros(mcc.n_par_tot)
#backwards/forwards proposal ratio is always one for Gaussian jumps
log_proposal_ratio = 0.0
if recompute_rn: # use RN eigenvectors
scale_eig0 = scaling*mcc.eig_rn[j,:,0,:]
scale_eig1 = scaling*mcc.eig_rn[j,:,1,:]
new_point = add_rn_eig_jump(scale_eig0,scale_eig1,new_point,new_point[mcc.x0s[j].idx_rn],mcc.x0s[j].idx_rn,Npsr,all_eigs=all_eigs)
if recompute_gwb: # use diagonal fishers
idx_loc = np.array([mcc.x0s[j].idx_gwb_gamma, mcc.x0s[j].idx_gwb_log10_A])
fisher_diag_loc = scaling * mcc.fisher_diag[j][idx_loc]
jump[idx_loc] += fisher_diag_loc*np.random.normal(0.,1.,idx_loc.size)
if recompute_int: # use common parameter eigenvectors
if all_eigs:
#allows attempting all of the eigenvalue jumps simultaneously
for itrp in range(0,4):
jump[mcc.x0s[j].idx_cw_int[:4]] += scaling*mcc.eig_common[j,itrp,:].flatten()*np.random.normal(0., 1.)
else:
which_eig = np.random.choice(4, size=1)
jump[mcc.x0s[j].idx_cw_int[:4]] += scaling*mcc.eig_common[j,which_eig,:].flatten()*np.random.normal(0., 1.)
if recompute_dist: # use diagonal fishers
idx_loc = mcc.x0s[j].idx_dists[idx_choose_psr_dist]
fisher_diag_loc = scaling * mcc.fisher_diag[j][idx_loc]
#smoothly saturate the jump sizes by adding the prior - takes into account the approximate width of the priors
fisher_diag_loc = np.sqrt(1./(1./fisher_diag_loc**2+n_jump_loc/(2.38*mcc.dist_prior_sigmas[idx_choose_psr_dist])**2))
jump[idx_loc] += fisher_diag_loc*np.random.normal(0.,1.,idx_loc.size)
new_point = new_point + jump
else:
raise ValueError('jump type unrecognized',which_jump_type)
#TODO check wrapping is working right
new_point = correct_intrinsic(new_point,mcc.x0s[j],mcc.chain_params.freq_bounds,mcc.FPI.cut_par_ids, mcc.FPI.cut_lows, mcc.FPI.cut_highs)
#more thorough jump types take precedence
mask = None
if check_merged(new_point[mcc.x0s[j].idx_log10_fgw],new_point[mcc.x0s[j].idx_log10_mc],mcc.FLIs[j].max_toa):
#do not do anything if already merged
mcc.x0s[j].validate_consistent(samples_current)
mcc.FLIs[j].validate_consistent(mcc.x0s[j])
new_point = samples_current.copy()
merged_point = True
elif recompute_rn or recompute_gwb: # update per psr RN or GWB
#TODO check_merged should be done before this
if recompute_rn: #if rn update, set up mask to only update pulsars we need to update
mask = np.ones(Npsr,dtype=np.bool_)
mask[idx_choose_psr] = False
#don't necessarily need to update choleskys for distances
#if not in red noise mask but do need to update MM and NN so remove them from the mask
mask[idx_choose_psr_dist] = False
#print(mask)
#make sure FLI_swap corresponds to current chain and sample so that we can partially modify it
safe_reset_swap(mcc.FLI_swap,mcc.x0s[j],samples_current,FLI_mem_save)
for ii in range(Npsr):
mcc.FLI_swap.chol_Sigmas[ii][:] = mcc.FLIs[j].chol_Sigmas[ii]
assert mcc.FLI_swap.logdet == mcc.FLIs[j].logdet
mcc.x0s[j].update_params(new_point)
try:
mcc.flm.recompute_FastLike(mcc.FLI_swap,mcc.x0s[j],dict(zip(mcc.par_names, new_point)), mask=mask)
except np.linalg.LinAlgError:
print("failed to update parameters to requested point, rejecting proposal")
print("jump selections: ",which_jump,which_jump_type)
print("idx choose",idx_choose)
print("log proposal ratio",log_proposal_ratio)
if which_jump_type==1:
print("de jump selections: ",de_indices,big_jump_decide,alpha0,alpha)
print("de point 1",x1)
print("de point 2",x2)
elif which_jump_type==2:
print("fisher jump selections",jump)
t_err = perf_counter()
old_file = "err_state_old_"+str(t_err)+".npy"
new_file = "err_state_new_"+str(t_err)+".npy"
print("failure point:",new_point)
print("failure point output to:",new_file)
print("old point:",samples_current)
print("old point output to:",old_file)
np.save(new_file,new_point)
np.save(old_file,samples_current)
print("attempting recovery to old point")
mcc.x0s[j].update_params(samples_current)
safe_reset_swap(mcc.FLI_swap,mcc.x0s[j],samples_current,FLI_mem_save)
for ii in range(Npsr):
mcc.FLI_swap.chol_Sigmas[ii][:] = mcc.FLIs[j].chol_Sigmas[ii]
fail_point = True
mcc.FLI_swap.validate_consistent(mcc.x0s[j])
elif recompute_int: # update common intrinsic parameters (chirp mass, frequency, sky location[2])
mcc.x0s[j].update_params(new_point)
mcc.FLIs[j].update_intrinsic_params(mcc.x0s[j])
mcc.FLIs[j].validate_consistent(mcc.x0s[j])
elif recompute_dist: # update psr distances
mcc.x0s[j].update_params(new_point)
mcc.FLIs[j].update_pulsar_distances(mcc.x0s[j], idx_choose_psr_dist)
mcc.FLIs[j].validate_consistent(mcc.x0s[j])
else:
raise ValueError('no recompute type selected')
#save current MM and NN
FLI_mem_new = get_FLI_mem(mcc.FLIs[j])
#check_not_merged(mcc.x0s[j].log10_fgw,mcc.x0s[j].log10_mc,FLIs[j].max_toa)
#w0 = np.pi * 10.0**mcc.x0s[j].log10_fgw
#mc = 10.0**mcc.x0s[j].log10_mc# * const.Tsun
#check the maximum toa is not such that the source has already merged, and if so automatically reject the proposal
if fail_point:
log_acc_ratio = -np.inf
log_acc_decide = 1.
log_L_choose = -np.inf
chosen_trial = -1
print("Rejected due to error in point")
mcc.x0s[j].update_params(samples_current)
safe_reset_swap(mcc.FLIs[j],mcc.x0s[j],samples_current,FLI_mem_save)
mcc.x0s[j].validate_consistent(samples_current)
mcc.FLIs[j].validate_consistent(mcc.x0s[j])
elif merged_point:#check_merged(mcc.x0s[j].log10_fgw,mcc.x0s[j].log10_mc,mcc.FLIs[j].max_toa):
#TODO should do this check before updating choleskys
#set these so that the step is rejected
#acc_ratio = -1
#acc_decide = 0.
log_acc_ratio = -np.inf
log_acc_decide = 1.
log_L_choose = -np.inf
chosen_trial = -1
print("Rejected due to too fast evolution.")
#mcc.x0s[j].update_params(samples_current)
#safe_reset_swap(mcc.FLIs[j],mcc.x0s[j],samples_current,FLI_mem_save)
#for ii in range(Npsr):
# mcc.FLI_swap.chol_Sigmas[ii][:] = mcc.FLIs[j].chol_Sigmas[ii]
mcc.x0s[j].validate_consistent(samples_current)
mcc.FLIs[j].validate_consistent(mcc.x0s[j])
else:
log_acc_ratio,chosen_trial,sample_choose,log_L_choose = do_mt_step(mcc,j,itrb,new_point,samples_current,FLI_mem_save,recompute_rn or recompute_gwb,log_proposal_ratio)
if np.isfinite(log_acc_ratio):
log_acc_decide = np.log(uniform(1.e-304, 1.0))
else:
log_acc_decide = 1.
#if j==0 and which_jump_type==0 and which_jump==1:
# print("RNEmpDist--log_acc_ratio="+str(log_acc_ratio))
# print("RNEmpDist--log_L_current="+str(mcc.log_likelihood[j,itrb]))
# print("RNEmpDist--log_L_choose="+str(log_L_choose))
# print("RNEmpDist--log_L_choose_from_pta="+str(mcc.pta.get_lnlikelihood(sample_choose)))
# print(samples_current)
# print(sample_choose)
# print(sample_choose-samples_current)
if log_acc_decide<=log_acc_ratio:
#if j==0 and which_jump_type==0 and which_jump==1: print("Accepted")
#accepted
mcc.x0s[j].update_params(sample_choose)
mcc.samples[j,itrb+1,:] = sample_choose
if recompute_rn or recompute_gwb:
#swap the temporary FLI for the old one
FLI_temp = mcc.FLIs[j]
mcc.FLIs[j] = mcc.FLI_swap
mcc.FLI_swap = FLI_temp
mcc.FLIs[j].validate_consistent(mcc.x0s[j])
mcc.x0s[j].validate_consistent(sample_choose)
#print('3',mcc.FLIs[j].get_lnlikelihood(mcc.x0s[j]))
else:
#since we reverted to old ones for calculating the reference point likelihoods, revert that
safe_reset_swap(mcc.FLIs[j],mcc.x0s[j],sample_choose,FLI_mem_new)
mcc.FLIs[j].validate_consistent(mcc.x0s[j])
mcc.x0s[j].validate_consistent(sample_choose)
#print('4',mcc.FLIs[j].get_lnlikelihood(mcc.x0s[j]))
mcc.log_likelihood[j,itrb+1] = log_L_choose
if chosen_trial==0:
mcc.a_yes[6*which_jump+2*which_jump_type,j] += 1
else:
mcc.a_no[6*which_jump+2*which_jump_type,j] += 1
mcc.a_yes[6*which_jump+2*which_jump_type+1,j] += 1
#print('1',mcc.FLIs[j].get_lnlikelihood(mcc.x0s[j]))
else:
#if j==0 and which_jump_type==0 and which_jump==1: print("Rejected")
#rejected
mcc.samples[j,itrb+1,:] = samples_current
mcc.log_likelihood[j,itrb+1] = mcc.log_likelihood[j,itrb]
#Add to both elements of a_no, so we can get acceptance over total jumps w/ and w/o projection perturbation
if chosen_trial==0 and np.isfinite(log_acc_ratio):
mcc.a_yes[6*which_jump+2*which_jump_type,j] += 1
else:
mcc.a_no[6*which_jump+2*which_jump_type,j] += 1
mcc.a_no[6*which_jump+2*which_jump_type+1,j] += 1
mcc.x0s[j].update_params(samples_current)
#print('2',mcc.FLIs[j].get_lnlikelihood(mcc.x0s[j]))
if not recompute_rn and not recompute_gwb:
#don't needs to do anything if which_jump==1 because we didn't update FLIs[j] at all,
#and FLI_swap will just be completely overwritten next time it is used
#revert the changes to FastLs
safe_reset_swap(mcc.FLIs[j],mcc.x0s[j],samples_current,FLI_mem_save)
else:
#revert swap to guaranteed self consistent state
safe_reset_swap(mcc.FLI_swap,mcc.x0s[j],samples_current,FLI_mem_save)
for ii in range(Npsr):
mcc.FLI_swap.chol_Sigmas[ii][:] = mcc.FLIs[j].chol_Sigmas[ii]
#print('2',mcc.FLIs[j].get_lnlikelihood(mcc.x0s[j]))
#print(which_jump)
#print(mcc.FLIs[j].get_lnlikelihood(mcc.x0s[j]),mcc.log_likelihood[j,itrb+1])
mcc.FLIs[j].validate_consistent(mcc.x0s[j])
mcc.x0s[j].validate_consistent(mcc.samples[j,itrb+1,:])
assert mcc.FLIs[j].get_lnlikelihood(mcc.x0s[j]) == mcc.log_likelihood[j,itrb+1]
if not recompute_dist:
assert np.all(mcc.samples[j,itrb,mcc.x0_swap.idx_dists]==mcc.samples[j,itrb+1,mcc.x0_swap.idx_dists])
if not recompute_gwb:
assert np.all(mcc.samples[j,itrb,mcc.x0_swap.idx_gwb]==mcc.samples[j,itrb+1,mcc.x0_swap.idx_gwb])
if not recompute_int:
assert np.all(mcc.samples[j,itrb,mcc.x0_swap.idx_cw_int[:4]]==mcc.samples[j,itrb+1,mcc.x0_swap.idx_cw_int[:4]])
if not recompute_rn:
assert np.all(mcc.samples[j,itrb,mcc.x0_swap.idx_rn]==mcc.samples[j,itrb+1,mcc.x0_swap.idx_rn])
if fail_point or merged_point or log_acc_decide>log_acc_ratio:
#check nothing changed if the point failed
assert np.all(mcc.samples[j,itrb,:]==mcc.samples[j,itrb+1,:])
assert mcc.log_likelihood[j,itrb+1]==mcc.log_likelihood[j,itrb]
if mask is not None:
if np.any(mask):
#no updating gwb or common unless everything was updated
assert np.all(mcc.samples[j,itrb,mcc.x0_swap.idx_cw_int[:4]]==mcc.samples[j,itrb+1,mcc.x0_swap.idx_cw_int[:4]])
assert np.all(mcc.samples[j,itrb,mcc.x0_swap.idx_gwb]==mcc.samples[j,itrb+1,mcc.x0_swap.idx_gwb])
#no distance updates for parameters that were masked
assert np.all(mcc.samples[j,itrb,mcc.x0_swap.idx_dists[mask]]==mcc.samples[j,itrb+1,mcc.x0_swap.idx_dists[mask]])
#no red noise updates for parameters that were masked
assert np.all(mcc.samples[j,itrb,mcc.x0_swap.idx_rn_log10_As[mask]]==mcc.samples[j,itrb+1,mcc.x0_swap.idx_rn_log10_As[mask]])
assert np.all(mcc.samples[j,itrb,mcc.x0_swap.idx_rn_gammas[mask]]==mcc.samples[j,itrb+1,mcc.x0_swap.idx_rn_gammas[mask]])
#no red noise updates for parameters that were not masked
#print(j,Ts[j],which_jump,which_jump_type,recompute_dist,recompute_gwb,recompute_int,recompute_rn,log_acc_decide,log_acc_ratio,log_acc_decide<log_acc_ratio,log_L_choose,chosen_trial)
#print(idx_choose)
#print(mask)
#if log_acc_decide<=log_acc_ratio:
# print(mcc.FLI_swap.logdet)
# mcc.x0_swap.update_params(new_point)
# mcc.flm.recompute_FastLike(mcc.FLI_swap,mcc.x0_swap,dict(zip(mcc.par_names, new_point)))
# print(mcc.FLI_swap.logdet,mcc.FLIs[j].logdet)
# assert mcc.FLI_swap.logdet==mcc.FLIs[j].logdet
#else:
# print(mcc.FLI_swap.logdet)
# mcc.x0_swap.update_params(samples_current)
# mcc.flm.recompute_FastLike(mcc.FLI_swap,mcc.x0_swap,dict(zip(mcc.par_names, samples_current)))
# print(mcc.FLI_swap.logdet,mcc.FLIs[j].logdet)
# assert mcc.FLI_swap.logdet==mcc.FLIs[j].logdet
if fail_point:
#something went wrong so do extra test of self consistency
mcc.validate_consistent(itrb+1)
return mcc.FLI_swap
[docs]def do_mt_step(mcc,j,itrb,new_point,samples_current,FLI_mem_save,recompute_rn,log_proposal_ratio):
"""compute the multiple tries and chose a sample
:param mcc: MCMCChain onject
:param j: Index of PT chain
:param itrb: Index within saved values (as opposed to block index itri or overall index itrn)
:param new_point: Proposed new point (with new shape parameters)
:param samples_current: Current point in parameter space
:param FLI_mem_save: Parts of FLI object saved to memory
:param recompute_rn: If True, recompute everything needed to go to new RN parameters
:param log_proposal_ratio: Log of the proposal ratio needed to calculate acceptance probability
:return log_acc_ratio: Log of acceptance probability
:return chosen_trial: Index of chosen trial
:return sample_choose: Parameters of the chosen trial
:return log_Ls[chosen_trial]: Log likelihood of the chosen trial
"""
Ts = mcc.chain_params.Ts
log_prior_old = CWFastPrior.get_lnprior(samples_current, mcc.FPI)
log_posterior_old = mcc.log_likelihood[j,itrb]/Ts[j] + log_prior_old
assert np.isfinite(log_posterior_old)
#do multiple try MCMC step with random draws of projection parameters
#more parameters will be uniform at higher temperatures
fisher_mask = np.sqrt(Ts[j])*mcc.fisher_diag[j][mcc.x0s[0].idx_cw_ext]<0.5
#don't propose fisher jumps at all above some specified temperature
fisher_norm = 1.
if Ts[j]>cm.proj_prior_all_temp:
fisher_mask[:] = False
elif fisher_mask.sum()>0:
fisher_norm = 2.38/np.sqrt(fisher_mask.sum())
random_normals = np.random.normal(0.,fisher_norm,(cm.n_multi_try,fisher_mask.sum()))
jumps = random_normals*np.sqrt(Ts[j])*mcc.fisher_diag[j][mcc.x0s[0].idx_cw_ext][fisher_mask]
random_draws_from_prior = np.random.uniform(mcc.FPI.cw_ext_lows[~fisher_mask],mcc.FPI.cw_ext_highs[~fisher_mask],(cm.n_multi_try,(~fisher_mask).sum()))
#itrd = 0
#for ii,ll in enumerate(mcc.x0s[j].idx_cw_ext):
# fisher_loc = mcc.fisher_diag[j][ll]
# if fisher_loc<0.5: #not maxed out fisher --> do fisher update
# jumps_old[:,ii] = jumps[:,itrj]#np.random.normal(0.,fisher_loc,cm.n_multi_try)
# random_normals_old[:,ii] = random_normals[:,itrj]#jumps[:,itrj]/mcc.fisher_diag[j][mcc.x0s[0].idx_cw_ext][fisher_mask][itrj]
# itrj += 1
# else:
# random_draws_from_prior_old[:,ii] = random_draws_from_prior[:,itrd]#np.random.uniform(FPI.cw_ext_lows[ii],FPI.cw_ext_highs[ii],cm.n_multi_try)
# itrd += 1
#make sure the jumps are null for the initial sample
jumps[0,:] = 0.
random_draws_from_prior[0,:] = new_point[mcc.x0_swap.idx_cw_ext][~fisher_mask]
tries = set_params(new_point,jumps,fisher_mask,random_draws_from_prior,mcc.x0_swap)
tries[0] = new_point # just to make sure it didn't get reset
log_prior_news = CWFastPrior.get_lnprior_array(tries, mcc.FPI)
if recompute_rn:
FLI_use = mcc.FLI_swap
else:
FLI_use = mcc.FLIs[j]
mt_weights, log_Ls, log_mt_norm_shift = get_mt_weights(mcc.x0_extras, FLI_use, Ts[j],log_posterior_old,tries,log_prior_news)
#if j==0: print(mt_weights)
#not sure why but still can get nans here...
assert np.all(np.isfinite(mt_weights))
if np.sum(mt_weights)==0.0:
log_acc_ratio = -np.inf
chosen_trial = -1
sample_choose = new_point.copy()
else:
chosen_trial = np.random.choice(cm.n_multi_try, p=mt_weights/np.sum(mt_weights))
if not recompute_rn: # need to set back FLIs to old state to calculate likelihoods at reference points
safe_reset_swap(mcc.FLIs[j],mcc.x0s[j],samples_current,FLI_mem_save)
else:
mcc.x0s[j].update_params(samples_current)
mcc.FLIs[j].validate_consistent(mcc.x0s[j])
mcc.x0s[j].validate_consistent(samples_current)
sample_ref = samples_current.copy()
sample_ref[mcc.x0s[j].idx_cw_ext] = tries[chosen_trial,mcc.x0s[j].idx_cw_ext]
ref_tries = set_params(sample_ref,jumps,fisher_mask,random_draws_from_prior,mcc.x0_swap)
ref_tries[0] = sample_ref # fix if it got reset
log_prior_refs = CWFastPrior.get_lnprior_array(ref_tries, mcc.FPI)
ref_mt_weights,log_ref_mt_norm_shift = get_ref_mt_weights(mcc.x0_extras, mcc.FLIs[j], Ts[j],log_posterior_old,chosen_trial,ref_tries,log_prior_refs)
#must undo the normalization shifts; they aren't needed in log space anyway
log_acc_ratio = np.log(np.sum(mt_weights))-np.log(np.sum(ref_mt_weights))+log_mt_norm_shift-log_ref_mt_norm_shift+log_proposal_ratio
sample_choose = tries[chosen_trial].copy()
# if chosen_trial==0:
# print("selected trial was null?")
# print(sample_ref)
# print(sample_choose)
# print(samples_current)
# print(tries[1])
# print(new_point)
# print(log_ref_mt_norm_shift,log_mt_norm_shift)
# print(ref_mt_weights)
# print(mt_weights)
# print(log_Ls)
# print(log_prior_refs)
return log_acc_ratio,chosen_trial,sample_choose,log_Ls[chosen_trial]
[docs]@njit(parallel=True)
def get_mt_weights(x0_extras, FLI_use, Ts, log_posterior_old,tries,log_prior_news):
"""Helper function to quickly return multiple tries and their likelihoods fo MTMCMC
:param x0_extras: List of extra CWInfo objects for parallelizing multiple try
:param FLI_use: FastLikeInfo object
:param Ts: List of PT temperatures
:param log_posterior_old: Log posterior at old parameters
:param tries: Parameters at a set of multiple tries for which we want to calculate the weights
:param log_prior_news: Log prior values at propose new points
:return mt_weights: Multiple try weights
:return log_Ls: Log likelihoods
:return log_mt_norm_shift: Amount to shift the multiple try weights (helps with using floating point precision efficiently)
"""
#NOTE isfinite does not work with fastmath enabled
#set up needed arrays
log_mt_weights = np.zeros(cm.n_multi_try)
log_Ls = np.zeros(cm.n_multi_try)
#get mt_weights --------------------------------------------------------------------------------------------------------
for KK in prange(cm.n_x0_extra):
for kk in range(cm.n_block_try):
itrkk = KK*cm.n_block_try+kk
x0_extras[KK].update_params(tries[itrkk,:])
#print(x0_extras[KK].cos_gwtheta)
log_L = FLI_use.get_lnlikelihood(x0_extras[KK])
log_posterior_new = log_L/Ts + log_prior_news[itrkk]
if np.isfinite(log_posterior_new):
log_mt_weights[itrkk] = log_posterior_new - log_posterior_old
else:
log_mt_weights[itrkk] = -np.inf
log_Ls[itrkk] = log_L
#can apply the same multiplier to shift all the weights, prevents over/underflows in the exponential from breaking the code
log_mt_norm_shift = np.max(log_mt_weights)
log_mt_weights -= log_mt_norm_shift
mt_weights = np.zeros(log_mt_weights.shape)
#get weights while preventing underflow (values which are <1.e-304 times as likely to be chosen as the most likely value are totally irrelevant)
mt_weights[log_mt_weights>-700] = np.exp(log_mt_weights[log_mt_weights>-700])
return mt_weights, log_Ls, log_mt_norm_shift
[docs]@njit()
def add_rn_eig_jump(scale_eig0,scale_eig1,new_point,rn_base,idx_rn,Npsr,all_eigs=False):
"""add a fisher eigenvalue jump to the red noise parameters in place
:param scale_eig0: Amount to scale jump in gamma values by
:param scale_eig1: Amount to scale in log10_A values by
:param new_point: Parameter values to add RN jump to
:param rn_base: RN values to jump from (usually justa slice of new_point)
:param idx_rn: Indices of new_point containing RN parameters
:param Npsr: Number of pulsars
:param all_eigs: If True, perturb all pulsars' RN, if False, pick randomly [False]
:return new_point: Perturbed parameter values
"""
which_eig = np.random.choice(2, size=Npsr)
jump_sizes = np.random.normal(0., 1.,Npsr)
jump = np.zeros(2*Npsr)
for ll in range(Npsr):
if all_eigs or which_eig[ll] == 0:
jump[ll] += scale_eig0[ll,0]*jump_sizes[ll]
jump[ll+Npsr] += scale_eig0[ll,1]*jump_sizes[ll]
if all_eigs or which_eig[ll] == 1:
jump[ll] += scale_eig1[ll,0]*jump_sizes[ll]
jump[ll+Npsr] += scale_eig1[ll,1]*jump_sizes[ll]
new_point[idx_rn] = rn_base + jump
return new_point
[docs]@njit()
def set_params(sample_set,jumps,fisher_mask,random_draws_from_prior,x0):
"""assign parameters to tries for multiple try mcmc
:param sample_set: Samples to start from
:param jumps: Precaluclated fisher jumps to use
:param fisher_mask: Mask determining which projection parameters to do fisher jump vs prior draw in
:param random_draws_from_prior: Precalculated prior draws to use
:param x0: CWInfo object
:return ref_tries: 2D array holding samples at multiple trials
"""
ref_tries = np.zeros((cm.n_multi_try, sample_set.size))
#jumps and random_draws_from_prior should give a null jump for the 0th value
#copy in intrinsic parameters
#ref_tries[:] = sample_set
#ref_tries[:,x0.idx_cw_int] = sample_set[x0.idx_cw_int]
#ref_tries[:,x0.idx_rn] = sample_set[x0.idx_rn]
#ref_tries[:,x0.idx_gwb] = sample_set[x0.idx_gwb]
ref_tries[:,x0.idx_int] = sample_set[x0.idx_int]
ref_tries[:,x0.idx_cw_ext[fisher_mask]] = sample_set[x0.idx_cw_ext[fisher_mask]]+jumps
ref_tries[:,x0.idx_cw_ext[~fisher_mask]] = random_draws_from_prior
ref_tries = correct_extrinsic_array(ref_tries,x0)
return ref_tries
[docs]@njit(parallel=True)
def get_ref_mt_weights(x0_extras, FLI_use, Ts, log_posterior_old, chosen_trial,ref_tries,log_prior_refs):
"""Helper function to quickly return multiple tries and their likelihoods fo MTMCMC
:param x0_extras: List of extra CWInfo objects for parallelizing multiple try
:param FLI_use: FastLikeInfo object
:param Ts: List of PT temperatures
:param log_posterior_old: Log posterior at old parameters
:param chosen_trial: Index of chosen trial
:param ref_tries: Parameters at a set of reference multiple tries for which we want to calculate the weights
:param log_prior_refs: Log prior values at reference points
:return ref_mt_weights: Reference point multiply try weights
:return log_ref_mt_norm_shift: Amount to shift the reference point multiple try weights (helps with using floating point precision efficiently)
"""
#NOTE isfinite does not work with fastmath enabled
#set up needed arrays
log_ref_mt_weights = np.zeros(cm.n_multi_try)
##get ref_mt_weights ----------------------------------------------------------------------------------------------------
for KK in prange(cm.n_x0_extra):
for kk in range(cm.n_block_try):
itrkk = KK*cm.n_block_try+kk
x0_extras[KK].update_params(ref_tries[itrkk,:])
log_L = FLI_use.get_lnlikelihood(x0_extras[KK])
log_posterior_ref = log_L/Ts + log_prior_refs[itrkk]
if np.isfinite(log_posterior_ref):
log_ref_mt_weights[itrkk] = log_posterior_ref - log_posterior_old
else:
log_ref_mt_weights[itrkk] = -np.inf
#can apply the same multiplier to shift all the weights, prevents over/underflows in the exponential from breaking the code
log_ref_mt_weights[chosen_trial] = 0. # np.log(1)=0. is the value it should be at the chosen trial pre-shift
log_ref_mt_norm_shift = np.max(log_ref_mt_weights)
log_ref_mt_weights -= log_ref_mt_norm_shift
ref_mt_weights = np.zeros(log_ref_mt_weights.shape)
#get weights while preventing underflow (values which are <1.e-304 times as likely to be chosen as the most likely value are totally irrelevant)
ref_mt_weights[log_ref_mt_weights>-700] = np.exp(log_ref_mt_weights[log_ref_mt_weights>-700])
return ref_mt_weights,log_ref_mt_norm_shift