"""
Path B Step 3: Weak-lensing shape catalog for Bullet Cluster.

Runs source extraction + moments-based shape measurement on HST ACS deep imaging
to produce a galaxy shape catalog usable as WL constraints in lensing reconstruction.

Inputs:
  - HST ACS DRC/DRZ FITS files from HST_ACS_BulletCluster/
    Priority: F606W (best WL band) > F814W > F435W

Outputs:
  - bullet_wl_shape_catalog.fits — galaxy positions, ellipticities, sizes, S/N
  - bullet_wl_shape_summary.txt — pipeline parameters and statistics
"""
import os
import sys
import glob
import warnings
# NOTE(review): blanket filter — every warning category is silenced for the
# rest of the run, which includes numpy RuntimeWarnings (invalid value,
# divide-by-zero, mean of empty slice) as well as anything astropy emits while
# parsing FITS headers / WCS. Consider narrowing this to specific categories.
warnings.filterwarnings('ignore')

import numpy as np
import sep
from astropy.io import fits
from astropy.wcs import WCS
from astropy.table import Table

print("=" * 70)
print("Bullet Cluster WL Shape Catalog — Path B Step 3")
print("=" * 70)

# ===== Locate the input mosaic =====
HST_DIR = 'HST_ACS_BulletCluster'
# Filter preference, best first (matches the priority stated in the module
# docstring). Compared against lower-cased file names.
FILTER_PRIORITY = ('f606w', 'f814w', 'f435w')


def _find_mosaics(directory, product):
    """Return ``(size_bytes, path)`` pairs for one drizzle product, largest first.

    ``product`` is the file-name suffix before ``.fits`` (``'drc'`` or ``'drz'``).
    File size is used as a proxy for "larger combined-program mosaic".
    """
    pattern = os.path.join(directory, f'hst_*_acs_wfc_*_{product}.fits')
    found = [(os.path.getsize(fn), fn) for fn in sorted(glob.glob(pattern))]
    found.sort(reverse=True)
    return found


def _choose_mosaic(candidates, filter_priority=FILTER_PRIORITY,
                   program='hst_10200'):
    """Pick the input file from size-sorted ``(size, path)`` candidates.

    Order of preference:
      1. the top-priority filter from the preferred program,
      2. each filter in ``filter_priority`` in turn (largest file first),
      3. the largest file of any filter.
    """
    named = [(os.path.basename(fn).lower(), fn) for _, fn in candidates]
    for name, fn in named:
        if filter_priority[0] in name and program in name:
            return fn
    for band in filter_priority:
        for name, fn in named:
            if band in name:
                return fn
    return candidates[0][1]


# DRC (CTE-corrected) products are preferred; DRZ is only consulted when no
# DRC file exists, so runs that already have DRC inputs behave as before.
product = 'DRC'
candidates = _find_mosaics(HST_DIR, 'drc')
if not candidates:
    product = 'DRZ'
    candidates = _find_mosaics(HST_DIR, 'drz')

if not candidates:
    print("ERROR: no HST ACS DRC/DRZ files found")
    sys.exit(1)

print(f"Found {len(candidates)} HST ACS {product} files")
print("Top 5 by size:")
for sz, fn in candidates[:5]:
    print(f"  {sz/1e6:7.1f} MB  {os.path.basename(fn)}")

preferred = _choose_mosaic(candidates)

print(f"\nUsing: {preferred}")
print(f"Size: {os.path.getsize(preferred)/1e6:.1f} MB")

# ===== Load image =====
print("\n--- Loading image ---")


def _is_2d_image(hdu):
    """True when the HDU carries a 2-D pixel array (tables and empty HDUs fail)."""
    return getattr(hdu.data, 'ndim', 0) == 2


with fits.open(preferred, memmap=True) as hdul:
    # 1) A named SCI extension wins wherever it sits in the file.
    sci_ext = None
    for i, h in enumerate(hdul):
        if h.name == 'SCI' and _is_2d_image(h):
            sci_ext = i
            break
    # 2) Otherwise try extension 1, then the primary HDU, but only if they
    #    actually hold a 2-D image.
    if sci_ext is None:
        for i in (1, 0):
            if i < len(hdul) and _is_2d_image(hdul[i]):
                sci_ext = i
                break
    if sci_ext is None:
        print(f"ERROR: no 2-D image HDU found in {preferred}")
        sys.exit(1)
    # astype() makes an in-memory copy, so the array stays valid after the
    # memory-mapped file is closed.
    data = hdul[sci_ext].data.astype(np.float64)
    header = hdul[sci_ext].header
    primary_header = hdul[0].header
    wcs = WCS(header)
print(f"SCI extension: {sci_ext}")
print(f"Image shape: {data.shape}")
print(f"Image dtype: {data.dtype}")
print(f"WCS valid: {wcs.has_celestial}")
print(f"Pixel scale: {wcs.proj_plane_pixel_scales()[0].to('arcsec').value:.4f} arcsec/pixel")

# Filter name: FILTER1 (falling back to FILTER); when FILTER1 holds one of the
# CLEAR slots the filter name is read from FILTER2 instead.
filt = primary_header.get('FILTER1', primary_header.get('FILTER', '?'))
if filt in ('CLEAR1L', 'CLEAR1S'):
    filt = primary_header.get('FILTER2', '?')
print(f"Filter: {filt}")
print(f"Exposure: {primary_header.get('EXPTIME', '?')} s")

# ===== Background subtraction =====
print("\n--- Background estimation ---")
data = np.ascontiguousarray(data)

# Pixels that are NaN/inf (used as the no-coverage fill in some drizzled
# products — TODO confirm for this mosaic) must not enter the background mesh
# statistics, otherwise they contaminate the background and RMS estimates.
bad_pix = ~np.isfinite(data)
n_bad = int(bad_pix.sum())
if n_bad:
    print(f"Masking {n_bad} non-finite pixels")
    bkg = sep.Background(data, mask=bad_pix, bw=128, bh=128, fw=3, fh=3)
else:
    # No bad pixels: identical call to the unmasked original.
    bkg = sep.Background(data, bw=128, bh=128, fw=3, fh=3)
print(f"Global background: {bkg.globalback:.6f}")
print(f"Global RMS: {bkg.globalrms:.6f}")

data_sub = data - bkg.back()
if n_bad:
    # Zero == sky level in the background-subtracted frame, so these pixels
    # can neither trigger a detection nor feed NaNs into object moments.
    data_sub[bad_pix] = 0.0

# ===== Source extraction =====
print("\n--- Source extraction ---")
# Detection settings handed to sep.extract. Because `err` is supplied, the
# threshold is expressed in units of that error (here the global background
# RMS); a detection also needs at least `minarea` pixels.
extract_settings = dict(
    thresh=1.5,
    err=bkg.globalrms,
    minarea=5,
    deblend_nthresh=32,
    deblend_cont=0.005,
)
objects = sep.extract(data_sub, **extract_settings)
n_detected = len(objects)
print(f"Detected {n_detected} sources")

# ===== Moments-based shapes =====
# The extraction catalog already carries the second moments (x2, y2, xy) of
# every detection, so the ellipticity components follow directly:
#   e1 = (x2 - y2) / (x2 + y2)
#   e2 = 2*xy     / (x2 + y2)
# NOTE(review): these are measured along the pixel axes of the image; no
# rotation to sky axes and no PSF correction is applied here.
x2, y2, xy = (objects[key] for key in ('x2', 'y2', 'xy'))
ixx_iyy = x2 + y2

# Objects whose moment trace is (numerically) zero keep e1 = e2 = 0 instead of
# triggering a division by zero.
mask = ixx_iyy > 1e-10
n_obj = len(objects)
e1 = np.divide(x2 - y2, ixx_iyy, out=np.zeros(n_obj), where=mask)
e2 = np.divide(2 * xy, ixx_iyy, out=np.zeros(n_obj), where=mask)

# Size proxy: square root of the moment trace, in pixels.
size_T = np.sqrt(ixx_iyy)

# Modulus of the ellipticity.
e_mag = np.sqrt(e1**2 + e2**2)

# Signal-to-noise proxy: peak pixel value over the global background RMS.
snr = objects['peak'] / bkg.globalrms

# ===== Apply WL quality cuts =====
print("\n--- WL galaxy selection ---")
# Star/galaxy separation is done on size alone: a PSF FWHM of 0.1 arcsec is
# assumed and converted to pixels with the WCS pixel scale; anything whose
# size_T exceeds 1.5x that value is treated as resolved.
pix_scale = wcs.proj_plane_pixel_scales()[0].to('arcsec').value
psf_pix = 0.1 / pix_scale  # assumed PSF FWHM, in pixels

# Individual selection criteria (one boolean per detection).
cut_resolved = size_T > 1.5 * psf_pix   # larger than the PSF -> likely a galaxy
cut_snr = snr > 10                      # bright enough for a shape estimate
cut_emag = e_mag < 0.9                  # drop extreme elongations (artifacts)
cut_size_max = size_T < 30              # drop very extended objects
cut_flag = objects['flag'] == 0         # clean extraction only

# Report how many objects survive as each cut is applied in turn.
print(f"All detections: {len(objects)}")
surviving = np.ones(len(objects), dtype=bool)
report_stages = (
    (f"After resolved cut (size_T > {1.5*psf_pix:.2f} pix)", cut_resolved),
    ("After S/N > 10", cut_snr),
    ("After e_mag < 0.9", cut_emag),
    ("After size_T < 30 pix", cut_size_max),
)
for stage_label, stage_cut in report_stages:
    surviving = surviving & stage_cut
    print(f"{stage_label}: {surviving.sum()}")

# The flag cut is the last one; its result is the final WL sample.
wl_mask = surviving & cut_flag
n_wl = wl_mask.sum()
print(f"After flag==0: {n_wl}")
print(f"\nFinal WL galaxy count: {n_wl}")

# Sky coordinates for selected galaxies
sky = wcs.pixel_to_world(objects['x'][wl_mask], objects['y'][wl_mask])
ra = sky.ra.deg
dec = sky.dec.deg

# Source density per arcmin².
# Use only pixels that actually contain data: the array's bounding box
# typically includes empty borders around the mosaic footprint, and counting
# them would inflate the area and underestimate the density. A pixel is taken
# as empty when it is non-finite or exactly zero (presumed no-coverage fill
# values — verify against the weight map if one is available).
ny, nx = data.shape
n_valid_pix = np.count_nonzero(np.isfinite(data) & (data != 0))
if n_valid_pix == 0:
    # Degenerate image: fall back to the full array so the division is defined.
    n_valid_pix = nx * ny
area_arcmin2 = n_valid_pix * (pix_scale / 60.0) ** 2
density = n_wl / area_arcmin2
print(f"Field area: {area_arcmin2:.2f} arcmin²")
print(f"WL source density: {density:.1f} sources/arcmin²")

# ===== Save shape catalog =====
print("\n--- Saving WL shape catalog ---")
# Column name -> values, restricted to the selected WL sample. The order here
# is the column order of the output table.
catalog_columns = (
    ('id', np.arange(n_wl)),
    ('x_pix', objects['x'][wl_mask]),
    ('y_pix', objects['y'][wl_mask]),
    ('ra_deg', ra),
    ('dec_deg', dec),
    ('flux', objects['flux'][wl_mask]),
    ('peak', objects['peak'][wl_mask]),
    ('snr', snr[wl_mask]),
    ('e1', e1[wl_mask]),
    ('e2', e2[wl_mask]),
    ('e_mag', e_mag[wl_mask]),
    ('size_T_pix', size_T[wl_mask]),
    ('a_pix', objects['a'][wl_mask]),
    ('b_pix', objects['b'][wl_mask]),
    ('theta_rad', objects['theta'][wl_mask]),
)
tab = Table()
for col_name, col_values in catalog_columns:
    tab[col_name] = col_values

# Any existing catalog with the same name is overwritten.
out_fits = 'bullet_wl_shape_catalog.fits'
tab.write(out_fits, format='fits', overwrite=True)
catalog_kb = os.path.getsize(out_fits) / 1024
print(f"Wrote {out_fits}  ({catalog_kb:.1f} KB)")

# ===== Summary report =====
# Plain-text record of the inputs, the cuts and the ellipticity statistics of
# the selected sample.
summary_path = 'bullet_wl_shape_summary.txt'

# Ellipticity statistics of the selected sample, computed once. With an empty
# sample these evaluate to NaN and are written as such.
e1_mean, e1_rms = tab['e1'].mean(), tab['e1'].std()
e2_mean, e2_rms = tab['e2'].mean(), tab['e2'].std()
sqrt_n = np.sqrt(n_wl)

# The text contains non-ASCII characters (em dash), so the encoding is pinned
# explicitly: relying on the locale default can raise UnicodeEncodeError or
# produce a file whose encoding differs from platform to platform.
with open(summary_path, 'w', encoding='utf-8') as f:
    f.write("Bullet Cluster WL Shape Catalog — Path B Step 3\n")
    f.write("Charles Anthony Hyatt Battiste UM/FUM — USPTO 19/640,364 anchors\n\n")
    f.write(f"Input image: {preferred}\n")
    f.write(f"Filter: {filt}\n")
    f.write(f"Image shape: {data.shape}\n")
    f.write(f"Pixel scale: {pix_scale:.4f} arcsec/pixel\n")
    f.write(f"Field area: {area_arcmin2:.2f} arcmin^2\n\n")
    f.write(f"Background: {bkg.globalback:.6f}\n")
    f.write(f"Background RMS: {bkg.globalrms:.6f}\n\n")
    f.write(f"All detections: {len(objects)}\n")
    f.write(f"WL galaxy count: {n_wl}\n")
    f.write(f"WL source density: {density:.2f} sources/arcmin^2\n\n")
    f.write("Quality cuts applied:\n")
    f.write(f"  size_T > {1.5*psf_pix:.2f} pixels (resolved galaxies)\n")
    f.write("  S/N > 10\n")
    f.write("  e_mag < 0.9 (no extreme elongations)\n")
    f.write("  size_T < 30 pixels (no very-extended sources)\n")
    f.write("  extraction flag == 0\n\n")
    f.write("Ellipticity statistics (for selected WL sample):\n")
    # Mean with its standard error (population std / sqrt(N)).
    f.write(f"  <e1> = {e1_mean:.4f} +/- {e1_rms/sqrt_n:.4f}\n")
    f.write(f"  <e2> = {e2_mean:.4f} +/- {e2_rms/sqrt_n:.4f}\n")
    f.write(f"  e1 RMS = {e1_rms:.4f}\n")
    f.write(f"  e2 RMS = {e2_rms:.4f}\n")
    f.write(f"  <e_mag> = {tab['e_mag'].mean():.4f}\n")
    f.write("\nNotes:\n")
    f.write("  Moments-based shape measurement, no PSF correction yet.\n")
    f.write("  Initial Path B Step 3 baseline. Production WL would add KSB+ PSF correction.\n")
    f.write("  This catalog is sufficient for joint SL+WL reconstruction at baseline accuracy.\n")

print(f"Wrote {summary_path}  ({os.path.getsize(summary_path)/1024:.1f} KB)")
print("\n=== Path B Step 3 baseline complete ===")
print("\nNext step: Run joint SL+WL reconstruction with both")
print("  - bullet_kappa_observed_SLbaseline.fits (Step 2 baseline)")
print("  - bullet_wl_shape_catalog.fits (Step 3 shape catalog)")
