# Figure layout: 2 rows x 3 columns. The top row (65% of the height) holds the
# velocity panels and the bottom row (35%) holds the residual panels.
fig, ax = plt.subplots(
    nrows=2,
    ncols=3,
    figsize=(16, 6),
    gridspec_kw=dict(height_ratios=(0.65, 0.35)),
)
def commons(ax, i):
    """Draw the elements shared by every column of the comparison figure.

    In column ``i`` this plots the reference curve (labelled "true mean")
    and the observed data with error bars in the top panel. The bottom
    panel shows the data-minus-reference residuals.

    Parameters
    ----------
    ax : grid of matplotlib Axes
        Indexed as ``ax[row, col]``. Row 0 receives the velocity plot and
        row 1 the residual plot.
    i : int
        Column index to draw into.

    Notes
    -----
    Reads the module-level names ``x``, ``vm``, ``e_vm``, ``lmh``, ``lc``
    and ``jax_vhalo`` instead of taking them as arguments.
    """
    # Evaluate the reference curve at the data radii once; it is needed both
    # for the top panel and for the residuals below.
    v_true = jax_vhalo({'log_mh': lmh, 'log_c': lc}, x)

    # rotation curve
    top = ax[0, i]
    top.plot(x, v_true, '--', c='tab:pink', label='true mean')
    top.errorbar(x, vm, yerr=e_vm, fmt='.', c='C0', lw=0.5, label='data')
    top.legend(loc='lower right')
    top.set_ylabel('velocity')

    # residuals: data minus the reference curve, around a zero line
    bottom = ax[1, i]
    bottom.axhline(y=0, ls='--', c='tab:pink')
    bottom.errorbar(x, vm - v_true, yerr=e_vm, fmt='.', c='C0', lw=0.5)
    bottom.set_xlabel('radius')
    bottom.set_ylabel('residuals')
    bottom.set_ylim(-60, 60)
# Reference curve on the prediction grid. It is the same for all three
# columns, so it is computed once instead of once per residual panel.
v_true_grid = jax_vhalo({'log_mh': lmh, 'log_c': lc}, r_grid)

# One entry per column: (predicted curves, line color, line alpha, title).
panel_specs = [
    (pred_wn, 'k', 0.1, "Model assuming independent data"),
    (pred, 'C0', 0.2, "Model with GP: correlated data"),
    (pred_cd, 'C2', 0.1, "Model with GP " + r"$\rm\bf conditioned$" + " to data"),
]

for icol, (pred_set, pred_color, pred_alpha, panel_title) in enumerate(panel_specs):
    # Shared data / reference elements for this column.
    commons(ax, icol)
    # Selected predicted curves; transposed so each row of pred_set[inds]
    # is drawn as its own line against r_grid.
    pred_curves = pred_set[inds]
    ax[0, icol].plot(r_grid, pred_curves.T, pred_color, alpha=pred_alpha)
    ax[1, icol].plot(r_grid, (pred_curves - v_true_grid).T, pred_color, alpha=pred_alpha)
    ax[0, icol].set_title(panel_title)