"""
Robust NIRCam download with hard socket timeouts and per-file retry.
Resumable: skips files already on disk.
"""
import os, sys, urllib3, requests, time, socket, ssl
# Requests below are sent with certificate verification switched off, so
# silence the InsecureRequestWarning urllib3 would otherwise emit each time.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Single timeout (seconds) shared by the socket default and the requests patch.
NETWORK_TIMEOUT_S = 120  # 2 min max per network operation

# Hard socket timeout — prevents indefinite hangs on sockets that are created
# without an explicit timeout of their own.
socket.setdefaulttimeout(NETWORK_TIMEOUT_S)

# Patch requests sessions: disable SSL verify + set timeout
old_request = requests.Session.request


def _patched(self, *args, **kwargs):
    """Wrap ``requests.Session.request`` with forced settings.

    * ``verify`` is always overwritten with ``False``.
    * ``timeout`` is set to ``NETWORK_TIMEOUT_S`` when the caller supplied
      none, *including* an explicit ``timeout=None``.  In requests, ``None``
      means "wait forever", so testing only for a missing key would let such
      callers bypass the hard timeout.
      NOTE(review): astroquery's download helper appears to forward
      ``timeout=None`` explicitly — confirm against the installed version.

    All other positional and keyword arguments are passed through unchanged,
    and the original method's return value is returned as is.
    """
    # NOTE(security): verification is disabled for every request made through
    # any Session in this process, even if the caller asked for verify=True.
    # This is deliberate for this script; do not reuse the patch elsewhere.
    kwargs['verify'] = False
    if kwargs.get('timeout') is None:
        kwargs['timeout'] = NETWORK_TIMEOUT_S
    return old_request(self, *args, **kwargs)


requests.Session.request = _patched

from astroquery.mast import Observations
try:
    # Best effort: also switch verification off on astroquery's own session
    # object.  `_portal_api_connection` is a private attribute and may be
    # absent or shaped differently depending on the astroquery version.
    Observations._portal_api_connection.session.verify = False
except Exception as exc:
    # Non-fatal: the patched Session.request above already forces
    # verify=False per request.  Report it instead of hiding it so a
    # silently ineffective override is visible in the log.
    print(f"  Note: could not set verify=False on MAST session "
          f"({type(exc).__name__}: {exc})", flush=True)

# Destination directory for every downloaded product; created if missing.
OUT_DIR = 'JWST_GO4598_NIRCAM'
os.makedirs(OUT_DIR, exist_ok=True)

# Filters to process, one MAST query per entry, in this order.
filters = [
    'F090W',
    'F115W',
    'F150W',
    'F200W',
    'F277W',
    'F356W',
    'F410M',
    'F444W',
]

# Run-wide tallies printed in the final summary.
total_downloaded = total_skipped = total_failed = total_bytes = 0

# Download tuning knobs.
MAX_ATTEMPTS = 3         # tries per file before giving up
RETRY_PAUSE_S = 5        # pause before a retry is RETRY_PAUSE_S * attempts so far
# A file counts as complete when its on-disk size exceeds this fraction of
# the size listed in the product table.
# NOTE(review): the tolerance means a file truncated in its last 10% is
# treated as complete, both when skipping and when verifying.  Tighten it if
# the listed sizes prove to be exact.
MIN_SIZE_FRACTION = 0.9


def _discard_partial(path):
    """Delete a failed or truncated download so the next attempt starts clean."""
    if os.path.exists(path):
        os.remove(path)


def _download_with_retry(uri, target, fn, size):
    """Download one product to ``target``, trying up to MAX_ATTEMPTS times.

    Args:
        uri: value of the product row's ``dataURI`` column.
        target: local path the file is written to.
        fn: product filename, used only in log lines.
        size: expected size in bytes from the product table (0 if not given).

    Returns:
        The on-disk size in bytes if an attempt left a file larger than
        ``MIN_SIZE_FRACTION * size``; ``None`` if every attempt failed.

    Any exception raised by the download call is logged and treated as a
    failed attempt; the partial file is removed before the next try.
    """
    for attempt in range(MAX_ATTEMPTS):
        if attempt:
            # Give a stalled or overloaded server a moment before retrying
            # instead of hitting it again immediately.
            time.sleep(RETRY_PAUSE_S * attempt)

        t0 = time.time()
        print(f"  -> {fn} ({size/1e6:.1f} MB) attempt {attempt+1}", flush=True)
        try:
            Observations.download_file(uri, local_path=target)
        except (socket.timeout, requests.exceptions.Timeout, ssl.SSLError) as e:
            # requests' ReadTimeout/ConnectTimeout are subclasses of Timeout.
            print(f"     TIMEOUT after {time.time()-t0:.0f}s: {type(e).__name__}", flush=True)
            _discard_partial(target)
            continue
        except Exception as e:
            print(f"     FAIL: {type(e).__name__}: {e}", flush=True)
            _discard_partial(target)
            continue

        # The call's return value is not inspected; success is judged solely
        # by what ended up on disk.
        elapsed = time.time() - t0
        actual = os.path.getsize(target) if os.path.exists(target) else 0
        if actual > MIN_SIZE_FRACTION * size:
            print(f"     OK [{elapsed:.1f}s, {actual/1e6:.1f} MB]", flush=True)
            return actual

        print(f"     PARTIAL ({actual/1e6:.1f} of {size/1e6:.1f} MB) — retrying", flush=True)
        _discard_partial(target)

    return None


# One pass per filter: query observations, list their products, keep the
# SCIENCE / I2D FITS files, and fetch whatever is not already on disk.
# A failed query or product listing skips that filter only.
for filt in filters:
    print(f"\n=== Filter {filt} ===", flush=True)
    try:
        obs = Observations.query_criteria(
            proposal_id='4598',
            instrument_name='NIRCAM/IMAGE',
            filters=filt
        )
    except Exception as e:
        print(f"  Query failed: {e}", flush=True)
        continue

    if len(obs) == 0:
        print(f"  No observations for {filt}", flush=True)
        continue
    print(f"  {len(obs)} observations", flush=True)

    try:
        products = Observations.get_product_list(obs)
    except Exception as e:
        print(f"  Product list failed: {e}", flush=True)
        continue

    i2d = Observations.filter_products(
        products,
        productType=['SCIENCE'],
        productSubGroupDescription=['I2D'],
        extension='fits'
    )
    print(f"  {len(i2d)} I2D products", flush=True)

    for row in i2d:
        fn = row['productFilename']
        # A falsy size (missing/zero) becomes 0, in which case any non-empty
        # file passes the completeness check.
        size = row['size'] or 0
        target = os.path.join(OUT_DIR, fn)

        # Resume support: skip files that already look complete.
        if os.path.exists(target) and os.path.getsize(target) > MIN_SIZE_FRACTION * size:
            total_skipped += 1
            continue

        fetched = _download_with_retry(row['dataURI'], target, fn, size)
        if fetched is None:
            total_failed += 1
            print(f"     GAVE UP on {fn}", flush=True)
        else:
            total_downloaded += 1
            total_bytes += fetched

# End-of-run report.  "Bytes new" covers only files fetched in this run;
# files skipped because they were already on disk are not included.
summary_lines = (
    f"\n=== Summary ===",
    f"Downloaded: {total_downloaded}",
    f"Skipped:    {total_skipped}",
    f"Failed:     {total_failed}",
    f"Bytes new:  {total_bytes/1e9:.2f} GB",
)
for summary_line in summary_lines:
    print(summary_line, flush=True)
