# services/audit_pdf.py
"""
Generates a PDF audit log for an automated sampling run.
Uses reportlab Platypus for clean multi-page layout.
"""
import io
import math
from datetime import datetime
from xml.sax.saxutils import escape

from reportlab.lib.pagesizes import A4
from reportlab.lib import colors
from reportlab.lib.units import mm
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_RIGHT
from reportlab.platypus import (
    SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
    HRFlowable, PageBreak, KeepTogether
)


# ── Colours ──────────────────────────────────────────────────────────────────
# Palette shared by the page chrome (header/footer) and the table/paragraph
# styles in this module.
(
    PRIMARY,    # main text colour; header-bar and table-header fill
    ACCENT,     # monospace ID text and the rule under the title
    LIGHT_BG,   # alternating (zebra) row fill
    BORDER,     # table grid lines and horizontal rules
    MUTED,      # secondary text: subtitles, labels, footer
    SUCCESS,    # not referenced in this file (possibly used by importers)
    WARNING,    # not referenced in this file (possibly used by importers)
) = (
    colors.HexColor(hex_code)
    for hex_code in (
        "#111827", "#2563eb", "#f6f7f9", "#e7e9ef",
        "#6b7280", "#16a34a", "#b45309",
    )
)
WHITE = colors.white


def _styles():
    """Build the named paragraph styles used by the audit PDF.

    Every style is defined from scratch (none has a ``parent``), so no sample
    stylesheet is needed. New ``ParagraphStyle`` objects are created on each
    call, so a caller may tweak the returned styles without affecting other
    reports.

    Returns:
        dict mapping a role name ("title", "subtitle", "section", "body",
        "small", "label", "value", "code", "center", "right") to its
        ``ParagraphStyle``.
    """
    return {
        "title": ParagraphStyle("title",
            fontName="Helvetica-Bold", fontSize=22,
            textColor=PRIMARY, spaceAfter=4, leading=28),
        "subtitle": ParagraphStyle("subtitle",
            fontName="Helvetica", fontSize=11,
            textColor=MUTED, spaceAfter=2),
        "section": ParagraphStyle("section",
            fontName="Helvetica-Bold", fontSize=13,
            textColor=PRIMARY, spaceBefore=14, spaceAfter=6, leading=18),
        "body": ParagraphStyle("body",
            fontName="Helvetica", fontSize=9,
            textColor=PRIMARY, leading=14, spaceAfter=2),
        "small": ParagraphStyle("small",
            fontName="Helvetica", fontSize=8,
            textColor=MUTED, leading=12),
        "label": ParagraphStyle("label",
            fontName="Helvetica-Bold", fontSize=8,
            textColor=MUTED, leading=11),
        "value": ParagraphStyle("value",
            fontName="Helvetica-Bold", fontSize=9,
            textColor=PRIMARY, leading=13),
        # Monospace style for record IDs / UUIDs.
        "code": ParagraphStyle("code",
            fontName="Courier", fontSize=8,
            textColor=ACCENT, leading=12),
        "center": ParagraphStyle("center",
            fontName="Helvetica", fontSize=9,
            textColor=PRIMARY, alignment=TA_CENTER),
        "right": ParagraphStyle("right",
            fontName="Helvetica", fontSize=9,
            textColor=PRIMARY, alignment=TA_RIGHT),
    }


def _kv_table(rows, styles):
    """Render ``(key, value)`` pairs as a two-column, zebra-striped table.

    Args:
        rows: iterable of ``(key, value)`` pairs. Both sides are converted
            with ``str()`` and XML-escaped, because ``Paragraph`` parses its
            text as markup. Free text such as file names or user names
            containing ``<`` or ``&`` would otherwise be misread or fail to
            render.
        styles: style dict from ``_styles()``; uses "label" and "value".

    Returns:
        A reportlab ``Table`` with 55 mm key and 110 mm value columns.
    """
    data = [
        [
            Paragraph(escape(str(key)), styles["label"]),
            Paragraph(escape(str(value)), styles["value"]),
        ]
        for key, value in rows
    ]
    t = Table(data, colWidths=[55*mm, 110*mm])
    t.setStyle(TableStyle([
        # ROWBACKGROUNDS paints every row, so no separate BACKGROUND is needed.
        ("ROWBACKGROUNDS", (0,0), (-1,-1), [WHITE, LIGHT_BG]),
        ("GRID", (0,0), (-1,-1), 0.3, BORDER),
        ("LEFTPADDING", (0,0), (-1,-1), 8),
        ("RIGHTPADDING", (0,0), (-1,-1), 8),
        ("TOPPADDING", (0,0), (-1,-1), 5),
        ("BOTTOMPADDING", (0,0), (-1,-1), 5),
        ("VALIGN", (0,0), (-1,-1), "TOP"),
    ]))
    return t


def _data_table(headers, rows, styles, col_widths=None):
    """Render a data table with a dark header row and zebra-striped body rows.

    Args:
        headers: column titles, one per column.
        rows: iterable of row sequences. Each cell is ``str()``-ed and
            XML-escaped before being wrapped in a ``Paragraph``.
        styles: style dict from ``_styles()``; uses "label" and "body".
        col_widths: explicit column widths. Defaults to 165 mm split evenly
            across the columns.

    Returns:
        A ``Table`` whose header row repeats on every page it spans.
    """
    # Cells are Paragraph flowables, which carry their own font and colour.
    # Table-level TEXTCOLOR/FONTNAME/FONTSIZE only affect plain-string cells.
    # The white header text therefore has to come from the paragraph style,
    # otherwise the grey "label" text sits on the dark header fill.
    header_style = ParagraphStyle(
        "table_header", parent=styles["label"], textColor=WHITE)

    data = [[Paragraph(escape(str(h)), header_style) for h in headers]]
    data.extend(
        [Paragraph(escape(str(cell)), styles["body"]) for cell in row]
        for row in rows
    )

    # True division: floor-dividing float point widths silently narrowed the table.
    widths = col_widths or [165*mm / len(headers)] * len(headers)
    t = Table(data, colWidths=widths, repeatRows=1)
    t.setStyle(TableStyle([
        ("BACKGROUND", (0,0), (-1,0), PRIMARY),
        ("ROWBACKGROUNDS", (0,1), (-1,-1), [WHITE, LIGHT_BG]),
        ("GRID", (0,0), (-1,-1), 0.3, BORDER),
        ("LEFTPADDING", (0,0), (-1,-1), 6),
        ("RIGHTPADDING", (0,0), (-1,-1), 6),
        ("TOPPADDING", (0,0), (-1,-1), 4),
        ("BOTTOMPADDING", (0,0), (-1,-1), 4),
        ("VALIGN", (0,0), (-1,-1), "TOP"),
    ]))
    return t


def _header_footer(canvas, doc):
    """Draw the fixed page chrome: a dark header bar plus a footer rule and text.

    ``generate_audit_pdf`` registers this as both ``onFirstPage`` and
    ``onLaterPages``, so reportlab calls it for every page with that page's
    canvas and the document template.
    """
    canvas.saveState()
    # Size the chrome to the template's actual page size rather than assuming A4.
    w, h = getattr(doc, "pagesize", A4)

    # Capture the generation time once per document. Otherwise pages rendered
    # on either side of a minute boundary would carry different header stamps.
    stamp = getattr(doc, "_audit_header_stamp", None)
    if stamp is None:
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M")
        doc._audit_header_stamp = stamp

    # Header bar
    canvas.setFillColor(PRIMARY)
    canvas.rect(0, h - 18*mm, w, 18*mm, fill=1, stroke=0)
    canvas.setFillColor(WHITE)
    canvas.setFont("Helvetica-Bold", 10)
    canvas.drawString(15*mm, h - 11*mm, "Sampling Audit Log")
    canvas.setFont("Helvetica", 9)
    canvas.drawRightString(w - 15*mm, h - 11*mm, stamp)

    # Footer
    canvas.setFillColor(MUTED)
    canvas.setFont("Helvetica", 8)
    canvas.drawString(15*mm, 10*mm, "Generated by Sampling Tool — Confidential")
    canvas.drawRightString(w - 15*mm, 10*mm, f"Page {doc.page}")
    canvas.setStrokeColor(BORDER)
    canvas.line(15*mm, 14*mm, w - 15*mm, 14*mm)

    canvas.restoreState()


def _fmt_count(value, missing="—"):
    """Format *value* with thousands separators without ever raising.

    Returns *missing* for ``None``. Anything the ``,`` format spec rejects
    (e.g. a placeholder string) is returned as ``str(value)``.
    """
    if value is None:
        return missing
    try:
        return f"{value:,}"
    except (TypeError, ValueError):
        return str(value)


def _confidence_label(z):
    """Return "90%", "95%" or "99%" for a conventional Z score, else "Custom".

    Matching uses a small absolute tolerance, so a more precise value such as
    1.959964 still maps to "95%". Non-numeric input yields "Custom".
    """
    try:
        z_val = float(z)
    except (TypeError, ValueError):
        return "Custom"
    for z_ref, label in ((1.645, "90%"), (1.96, "95%"), (2.576, "99%")):
        if math.isclose(z_val, z_ref, abs_tol=1e-3):
            return label
    return "Custom"


def _id_grid(values, styles, columns, col_width, h_pad):
    """Lay out string *values* as a ``columns``-wide grid of monospace cells.

    The last row is padded with empty cells. Each value is XML-escaped before
    being wrapped in a ``Paragraph``. Returns ``None`` when *values* is empty.
    """
    rows = []
    for start in range(0, len(values), columns):
        cells = values[start:start + columns]
        cells += [""] * (columns - len(cells))
        rows.append([Paragraph(escape(v), styles["code"]) for v in cells])
    if not rows:
        return None
    grid = Table(rows, colWidths=[col_width] * columns)
    grid.setStyle(TableStyle([
        ("ROWBACKGROUNDS", (0,0), (-1,-1), [WHITE, LIGHT_BG]),
        ("GRID", (0,0), (-1,-1), 0.2, BORDER),
        ("LEFTPADDING", (0,0), (-1,-1), h_pad),
        ("RIGHTPADDING", (0,0), (-1,-1), h_pad),
        ("TOPPADDING", (0,0), (-1,-1), 3),
        ("BOTTOMPADDING", (0,0), (-1,-1), 3),
    ]))
    return grid


def generate_audit_pdf(
    run_id: str,
    timestamp: str,
    filename: str,
    cochran_params: dict,
    cochran_results: dict,
    method: str,
    method_params: dict,
    method_info: dict,
    sampled_df,
    uuid_col: str = None,
    run_by: str = "Unknown",
) -> io.BytesIO:
    """
    Generate a complete audit PDF and return it as an in-memory buffer.

    Args:
        run_id: identifier of the sampling run (title, summary and closing line).
        timestamp: pre-formatted run timestamp, printed verbatim.
        filename: name of the sampled dataset file.
        cochran_params: Cochran inputs; reads "z" (default 1.96),
            "p" (default 0.5) and "e" (default 0.05).
        cochran_results: Cochran outputs; reads "N", "n0_raw", "n0_ceil",
            "n_corrected_raw" and "n_final". Missing entries render as a
            placeholder ("—", "N/A" or "Not provided") instead of failing.
        method: sampling method key ("simple_random", "stratified",
            "cluster", "systematic").
        method_params: method inputs; reads "random_state" and, depending on
            the method, "strata" or "cluster_column".
        method_info: method outputs; reads "rows_in_population",
            "rows_sampled", "breakdown" and, depending on the method,
            "clusters_selected" or "k" / "start".
        sampled_df: the sampled rows. Presumably a pandas DataFrame: only
            ``.columns``, ``.index`` and ``[col].astype(str)`` are used.
        uuid_col: optional column identifying each sampled record. If it is
            not given or not present, row indices are listed instead.
        run_by: name of whoever ran the sampling.

    Returns:
        io.BytesIO positioned at offset 0 containing the rendered PDF.
    """
    buf = io.BytesIO()
    doc = SimpleDocTemplate(
        buf, pagesize=A4,
        leftMargin=15*mm, rightMargin=15*mm,
        topMargin=22*mm, bottomMargin=20*mm,
    )

    st = _styles()
    story = []

    population = cochran_results.get("N")
    population_text = _fmt_count(population) if population else "Not provided"

    # ── Cover / Header ───────────────────────────────────────────────────────
    story.append(Spacer(1, 6*mm))
    story.append(Paragraph("Sampling Audit Log", st["title"]))
    # Paragraph parses markup, so interpolated free text is XML-escaped.
    story.append(Paragraph(
        f"Run ID: {escape(str(run_id))}  ·  Generated: {escape(str(timestamp))}",
        st["subtitle"],
    ))
    story.append(HRFlowable(width="100%", thickness=1.5, color=ACCENT, spaceAfter=10))

    # ── Run Summary ──────────────────────────────────────────────────────────
    story.append(Paragraph("1. Run Summary", st["section"]))
    story.append(_kv_table([
        ("Run ID",           run_id),
        ("Timestamp",        timestamp),
        ("Dataset File",     filename),
        ("Run By",           run_by),
        ("Sampling Method",  _method_label(method)),
        ("Population (N)",   population_text),
        ("Final Sample (n)", _fmt_count(cochran_results.get("n_final"))),
    ], st))
    story.append(Spacer(1, 4*mm))

    # ── Cochran Parameters ───────────────────────────────────────────────────
    story.append(Paragraph("2. Cochran Formula Parameters & Results", st["section"]))

    z = cochran_params.get("z", 1.96)
    p = cochran_params.get("p", 0.5)
    e = cochran_params.get("e", 0.05)

    story.append(_kv_table([
        ("Z (Confidence Level)",   f"{z}  ({_confidence_label(z)})"),
        ("p (Proportion)",         str(p)),
        ("e (Margin of Error)",    f"{e}  ({round(e*100, 1)}%)"),
        ("n0 (Base, raw)",         str(cochran_results.get("n0_raw", "—"))),
        ("n0 (Base, ceiling)",     str(cochran_results.get("n0_ceil", "—"))),
        ("N (Population)",         population_text),
        ("n (Corrected, raw)",     str(cochran_results.get("n_corrected_raw", "N/A"))),
        ("n (Final, ceiling)",     str(cochran_results.get("n_final", "—"))),
    ], st))

    # Formula note; the finite-population correction only applies when N is known.
    formula = "Formula: n0 = Z^2 * p * (1-p) / e^2"
    if population:
        formula += "  |  Finite correction: n = n0 / (1 + (n0-1)/N)"
    story.append(Spacer(1, 3*mm))
    story.append(Paragraph(formula, st["small"]))
    story.append(Spacer(1, 4*mm))

    # ── Sampling Parameters ──────────────────────────────────────────────────
    story.append(Paragraph("3. Sampling Parameters", st["section"]))
    story.append(_kv_table([
        ("Method",          _method_label(method)),
        ("Population Rows", _fmt_count(method_info.get("rows_in_population"))),
        ("Rows Sampled",    _fmt_count(method_info.get("rows_sampled"))),
        ("Random Seed",     str(method_params.get("random_state", 42))),
    ], st))
    story.append(Spacer(1, 3*mm))

    # Method-specific params
    if method == "stratified":
        story.append(Paragraph("Stratification Groups:", st["body"]))
        for stratum in method_params.get("strata", []):
            col = escape(str(stratum.get("column", "—")))
            for flt in stratum.get("filters", []):
                value = escape(str(flt.get("value")))
                pct = escape(str(flt.get("pct")))
                story.append(Paragraph(
                    f"  • {col} = {value} → {pct}% of sample",
                    st["body"],
                ))

    elif method == "cluster":
        max_names = 20
        clusters = method_info.get("clusters_selected", [])
        names = ", ".join(str(c) for c in clusters[:max_names])
        if len(clusters) > max_names:
            # Make truncation explicit instead of silently dropping names.
            names += f", … (+{len(clusters) - max_names} more)"
        story.append(_kv_table([
            ("Cluster Column",    method_params.get("cluster_column", "—")),
            ("Clusters Selected", str(len(clusters))),
            ("Cluster Names",     names),
        ], st))

    elif method == "systematic":
        story.append(_kv_table([
            ("Interval k",     str(method_info.get("k", "—"))),
            ("Starting Point", str(method_info.get("start", "—"))),
        ], st))

    story.append(Spacer(1, 4*mm))

    # ── Breakdown Table ──────────────────────────────────────────────────────
    # Only stratified and cluster runs have a breakdown layout. For other
    # methods the section is skipped so no empty heading is rendered.
    breakdown = method_info.get("breakdown")
    if breakdown and method in ("stratified", "cluster"):
        story.append(Paragraph("4. Sample Breakdown", st["section"]))

        if method == "stratified":
            headers = ["Group", "Group Population", "Allocated %", "Sampled"]
            rows = [
                [b["label"], f"{b['group_population']:,}", f"{b['allocated_pct']}%", str(b["sample_count"])]
                for b in breakdown
            ]
            story.append(_data_table(headers, rows, st,
                col_widths=[70*mm, 35*mm, 30*mm, 30*mm]))
        else:
            headers = ["Cluster", "Cluster Population", "Sampled"]
            rows = [
                [b["cluster"], f"{b['cluster_population']:,}", str(b["sampled_from_cluster"])]
                for b in breakdown
            ]
            story.append(_data_table(headers, rows, st,
                col_widths=[80*mm, 50*mm, 35*mm]))

        story.append(Spacer(1, 4*mm))

    # ── Sampled Row IDs ──────────────────────────────────────────────────────
    story.append(PageBreak())
    story.append(Paragraph("5. Sampled Record IDs / UUIDs", st["section"]))

    if uuid_col and uuid_col in sampled_df.columns:
        uuids = sampled_df[uuid_col].astype(str).tolist()
        story.append(Paragraph(
            f"UUID Column: <b>{escape(str(uuid_col))}</b>  ·  Total records: <b>{len(uuids):,}</b>",
            st["body"],
        ))
        story.append(Spacer(1, 3*mm))
        grid = _id_grid(uuids, st, columns=3, col_width=55*mm, h_pad=5)
    else:
        # Fall back to row indices, and say why: no column given vs. column missing.
        if uuid_col:
            notice = (
                f"UUID column '{escape(str(uuid_col))}' not found in sampled data. "
                "Showing row indices of sampled records."
            )
        else:
            notice = "No UUID column specified. Showing row indices of sampled records."
        story.append(Paragraph(notice, st["body"]))
        story.append(Spacer(1, 3*mm))
        indices = [str(i) for i in sampled_df.index.tolist()]
        grid = _id_grid(indices, st, columns=6, col_width=27*mm, h_pad=4)

    if grid is not None:
        story.append(grid)

    story.append(Spacer(1, 6*mm))
    story.append(HRFlowable(width="100%", thickness=0.5, color=BORDER))
    story.append(Spacer(1, 2*mm))
    story.append(Paragraph(
        f"End of audit log  ·  Run ID: {escape(str(run_id))}  ·  {escape(str(timestamp))}",
        st["small"],
    ))

    doc.build(story, onFirstPage=_header_footer, onLaterPages=_header_footer)
    buf.seek(0)
    return buf


def _method_label(method):
    """Translate an internal sampling-method key into a human-readable label.

    Keys without a known label are returned unchanged, so the report still
    shows something meaningful for new or custom methods.
    """
    labels = dict(
        simple_random="Simple Random Sampling",
        stratified="Stratified Sampling",
        cluster="Cluster Sampling",
        systematic="Systematic (Interval) Sampling",
    )
    if method in labels:
        return labels[method]
    return method
