"""
HST ACS Bullet Cluster: query + download.
Use Proposal 10863 (Bradac/Clowe HST ACS imaging) which contains F775W and F850LP.
"""
import os, sys, urllib3, requests, time
# --- TLS workaround --------------------------------------------------------
# SECURITY NOTE: this section turns off HTTPS certificate verification for
# every requests.Session call in this process, including astroquery's MAST
# traffic. It also silences urllib3's InsecureRequestWarning. Downloads are
# therefore open to man-in-the-middle tampering.
# NOTE(review): presumably needed to get past an intercepting proxy or a
# broken CA bundle -- confirm it is still required before reusing this script.
import functools

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
old_request = requests.Session.request


# Wrapper that forwards to the original Session.request with verify=False.
# Any verify value passed by the caller is overridden. functools.wraps keeps
# the original name, docstring and signature visible for debugging.
@functools.wraps(old_request)
def _patched(self, *args, **kwargs):
    kwargs['verify'] = False
    return old_request(self, *args, **kwargs)


requests.Session.request = _patched

from astroquery.mast import Observations
try:
    # Private astroquery attribute; its layout may differ between versions.
    Observations._portal_api_connection.session.verify = False
except Exception as err:
    # Best effort only: the Session.request patch above still applies.
    # Report the failure instead of hiding it.
    print(f"Note: could not set verify=False on the MAST portal session ({err})")

OUT_DIR = 'HST_ACS_BulletCluster'
os.makedirs(OUT_DIR, exist_ok=True)

# HST proposal IDs queried for Bullet Cluster observations. An earlier inline
# note attributed these to Bradac, Clowe and others (not verified here).
PROGRAMS = ['10863', '10200', '10666', '9722', '9290']
# Target-name spellings accepted in the MAST query (matched as alternatives).
_TARGET_ALIASES = ['1E0657-558', '1E0657-56', 'BULLETCLUSTER', 'BULLET-CLUSTER']

print("--- HST ACS programs at Bullet Cluster ---")
all_obs = []
for proposal in PROGRAMS:
    # A failure anywhere for one program (query or summary) is reported and
    # skipped, so the remaining programs are still tried.
    try:
        table = Observations.query_criteria(
            obs_collection='HST',
            proposal_id=proposal,
            target_name=_TARGET_ALIASES,
        )
        n_rows = len(table)
        print(f"  Program {proposal}: {n_rows} obs")
        if n_rows:
            filter_names = sorted({str(name) for name in table['filters']})
            print(f"    filters: {filter_names}")
            all_obs.append(table)
    except Exception as err:
        print(f"  Program {proposal}: query failed: {err}")

if not all_obs:
    # No program query produced rows (or every query failed): fall back to a
    # cone search around the cluster and keep only HST observations.
    # NOTE(review): neither this search nor the program queries restrict
    # instrument_name, so non-ACS HST data may be included.
    print("\nFalling back to position search.")
    from astropy.coordinates import SkyCoord
    import astropy.units as u
    coord = SkyCoord(ra=104.66*u.deg, dec=-55.95*u.deg)
    obs_region = Observations.query_region(coord, radius=0.07*u.deg)
    mask = (obs_region['obs_collection'] == 'HST')
    all_obs = [obs_region[mask]]
    print(f"  Position-based HST obs: {len(all_obs[0])}")

from astropy.table import vstack
obs_combined = vstack(all_obs) if len(all_obs) > 1 else all_obs[0]
print(f"\nTotal HST obs: {len(obs_combined)}")
if len(obs_combined) == 0:
    # Nothing to expand into products. get_product_list() is expected to
    # reject an empty observation list, so stop here with a clear message.
    sys.exit("No HST observations found; nothing to download.")

# Expand the selected observations into their individual data products.
print("\nGetting product list...")
products = Observations.get_product_list(obs_combined)
print(f"Total products: {len(products)}")

# Keep only drizzled (DRZ) SCIENCE products whose filenames end in "fits".
drz_criteria = dict(
    productType=['SCIENCE'],
    productSubGroupDescription=['DRZ'],
    extension='fits',
)
drz = Observations.filter_products(products, **drz_criteria)
print(f"DRZ science products: {len(drz)}")

# Report the expected download volume when a size column is present;
# missing or non-positive sizes are ignored.
if 'size' in drz.colnames:
    total = sum(s for s in drz['size'] if s is not None and s > 0) / 1e9
    print(f"Total DRZ volume: {total:.2f} GB")

# Save the product manifest to the current working directory (not OUT_DIR).
drz.write('hst_drz_products.csv', format='csv', overwrite=True)

# Download every DRZ product, skipping files that already look complete.
# Observations.download_file() reports most failures (e.g. HTTP errors)
# through its returned (status, msg, url) tuple rather than by raising, so
# the status is checked explicitly. Exceptions it lets escape are still caught.
total_dl = 0
total_sk = 0
total_fail = 0
print(f"\nDownloading {len(drz)} DRZ files to {OUT_DIR}/ ...")
for row in drz:
    fn = row['productFilename']
    # An empty or masked size is treated as unknown (0).
    size = row['size'] if row['size'] else 0
    target = os.path.join(OUT_DIR, fn)
    # Resume heuristic: a local file larger than 90% of the listed size counts
    # as complete. When the size is unknown, any non-empty file counts.
    if os.path.exists(target) and os.path.getsize(target) > 0.9 * size:
        total_sk += 1
        continue
    uri = row['dataURI']
    print(f"  -> {fn} ({size/1e6:.1f} MB)", end=" ", flush=True)
    t0 = time.time()
    try:
        status, msg, _url = Observations.download_file(uri, local_path=target)
    except Exception as e:
        status, msg = 'ERROR', f"{type(e).__name__}: {e}"
    elapsed = time.time() - t0
    if status == 'COMPLETE':
        actual = os.path.getsize(target) if os.path.exists(target) else 0
        total_dl += 1
        print(f"[{elapsed:.1f}s, {actual/1e6:.1f} MB]")
    elif status == 'SKIPPED':
        total_sk += 1
        print(f"SKIPPED: {msg}")
    else:
        total_fail += 1
        print(f"FAILED: {msg}")
        # Remove any partial file so the size heuristic above cannot mistake
        # a truncated download for a complete one on the next run.
        try:
            os.remove(target)
        except FileNotFoundError:
            pass

print(f"\nDownloaded: {total_dl}, Skipped: {total_sk}, Failed: {total_fail}")
