"""
Exhaustive HST search at Bullet Cluster field — all 5 Cha 2025 filters.
Position-based query to find programs beyond 10863 covering F435W, F606W, F814W.
"""
import os, sys, urllib3, requests, time
# NOTE: the lines below switch TLS certificate verification off for every
# `requests` session in this process and silence the warning urllib3 would
# otherwise emit for each unverified request.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Keep a handle on the unpatched method so the wrapper can delegate to it.
old_request = requests.Session.request


def _patched(self, *args, **kwargs):
    """Call the original ``Session.request`` with ``verify`` forced to False.

    Any ``verify`` value supplied by the caller is overridden; all other
    positional and keyword arguments are passed through untouched.
    """
    return old_request(self, *args, **{**kwargs, 'verify': False})


requests.Session.request = _patched

from astroquery.mast import Observations
# Best-effort: also switch off TLS verification directly on the session object
# reached through the MAST portal connection. `_portal_api_connection` is a
# private astroquery attribute, so this lookup may not exist on every release
# (NOTE(review): confirm the attribute path for the installed astroquery).
try:
    Observations._portal_api_connection.session.verify = False
except Exception as exc:
    # Stay best-effort (never abort the run), but report the failure instead
    # of swallowing it silently so a broken override does not go unnoticed.
    print(
        "Note: could not disable TLS verification on the MAST session "
        f"({type(exc).__name__}: {exc})",
        file=sys.stderr,
    )

from astropy.coordinates import SkyCoord
import astropy.units as u

# Directory that receives the downloaded files; created up front if missing.
OUT_DIR = 'HST_ACS_BulletCluster'
os.makedirs(OUT_DIR, exist_ok=True)

# Cone search around the field centre (RA/Dec in degrees).
coord = SkyCoord(ra=104.66, dec=-55.95, unit=u.deg)
print("--- Position-based HST search at Bullet Cluster (0.15 deg radius) ---")
obs = Observations.query_region(coord, radius=u.Quantity(0.15, u.deg))
print(f"Total obs at field: {len(obs)}")

# Keep only the rows whose collection is HST.
hst_mask = obs['obs_collection'] == 'HST'
hst = obs[hst_mask]
print(f"HST: {len(hst)}")

# Keep only ACS rows: OR together one equality mask per ACS channel name.
instruments = hst['instrument_name']
acs_mask = instruments == 'ACS/WFC'
for detector in ('ACS/HRC', 'ACS/SBC'):
    acs_mask = acs_mask | (instruments == detector)
acs = hst[acs_mask]
print(f"HST ACS: {len(acs)}")

# For each wanted filter, pick the ACS rows whose `filters` entry contains the
# filter name as a substring, and report which proposal IDs supplied them.
filters_to_find = ['F435W', 'F606W', 'F775W', 'F814W', 'F850LP']
matched_obs = []
for band in filters_to_find:
    hits = acs[[band in str(entry) for entry in acs['filters']]]
    programs = sorted({str(pid) for pid in hits['proposal_id']})
    print(f"  Filter {band}: {len(hits)} obs, programs: {programs}")
    matched_obs.append(hits)

# Stack the per-filter selections and drop rows that repeat an obs_id.
from astropy.table import vstack, unique
all_matched = unique(vstack(matched_obs), keys='obs_id')
print(f"\nTotal unique HST/ACS matched obs: {len(all_matched)}")

# Persist the matched observation list first, so the selection can be
# inspected even when the product lookup below fails or is skipped.
all_matched.write('hst_all_filters_matched_obs.csv', format='csv', overwrite=True)

# Nothing matched: stop with a clear message rather than handing an empty
# observation list to the product query (NOTE(review): astroquery is expected
# to reject an empty list as an invalid query — confirm for the installed
# version).
if len(all_matched) == 0:
    raise SystemExit(
        "No HST/ACS observations matched the requested filters; nothing to download."
    )

# Look up the data products attached to every matched observation.
print("\nFetching products...")
products = Observations.get_product_list(all_matched)
print(f"Total products: {len(products)}")

# Keep only SCIENCE-type products in the DRZ/DRC subgroups with a .fits
# extension.
drz = Observations.filter_products(
    products,
    productType=['SCIENCE'],
    productSubGroupDescription=['DRZ', 'DRC'],
    extension='fits'
)
print(f"DRZ/DRC science products: {len(drz)}")

# Collapse repeated filenames, keeping the first product seen for each name
# (dicts preserve insertion order, so the original ordering survives).
first_by_name = {}
for product in drz:
    first_by_name.setdefault(product['productFilename'], product)
to_download = list(first_by_name.values())
print(f"Unique DRZ filenames: {len(to_download)}")

# Report the expected volume only when the product table carries a size
# column; entries without a positive size are left out of the sum.
if 'size' in drz.colnames:
    sizes = [
        nbytes
        for nbytes in (product['size'] for product in to_download)
        if nbytes is not None and nbytes > 0
    ]
    total = sum(sizes) / 1e9
    print(f"Total unique-file volume: {total:.2f} GB")

# Download every selected product, skipping files that already look complete
# and keeping separate tallies of downloaded, skipped and failed files.
total_dl = 0
total_sk = 0
total_failed = 0
total_bytes = 0
print(f"\nDownloading to {OUT_DIR}/ ...")
for row in to_download:
    fn = row['productFilename']
    # A missing or zero catalogue size is treated as "unknown" (0 bytes).
    size = row['size'] if row['size'] else 0
    target = os.path.join(OUT_DIR, fn)
    # A file already on disk that exceeds 90% of the catalogued size is
    # assumed complete and is not fetched again.
    if os.path.exists(target) and os.path.getsize(target) > 0.9 * size:
        total_sk += 1
        continue
    uri = row['dataURI']
    print(f"  -> {fn} ({size/1e6:.1f} MB)", end=" ", flush=True)
    t0 = time.time()
    try:
        # download_file reports its outcome as (status, message, url); HTTP
        # failures come back as status "ERROR" rather than as an exception,
        # so the status has to be inspected explicitly.
        status, msg, _url = Observations.download_file(uri, local_path=target)
    except Exception as e:
        total_failed += 1
        print(f"FAILED: {type(e).__name__}: {e}")
        continue
    elapsed = time.time() - t0
    if status == 'ERROR':
        # Previously this case was counted and printed as a successful download.
        total_failed += 1
        print(f"FAILED: {msg}")
        continue
    actual = os.path.getsize(target) if os.path.exists(target) else 0
    total_dl += 1
    total_bytes += actual
    print(f"[{elapsed:.1f}s, {actual/1e6:.1f} MB]")

print(
    f"\nDownloaded: {total_dl} ({total_bytes/1e9:.2f} GB), "
    f"Skipped: {total_sk}, Failed: {total_failed}"
)
