#!/usr/bin/env python3
"""
Temporal DAG Generator — Reproducibility Artifact
===================================================
Generates a synthetic sparse DAG with a power-law fan-in (in-degree)
distribution and temporal-bias attachment, then computes cache-line
utilization metrics identical to those reported in:

  "Organic Cache Locality: Implicit Memory Coalescing in Dynamically
   Grown Sparse Neural Networks" — MecanoAI, 2026.

The generator models the memory layout of a real neural network training
engine: each sub-DAG is grown independently, then flattened into a single
contiguous activation array for GPU dispatch. New nodes are appended
chronologically within each sub-DAG. Edges preferentially connect to
recently created nodes ("temporal-bias attachment"), causing source IDs
to cluster numerically in the flattened array — producing near-perfect
cache-line coalescing without any explicit graph partitioning.
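
A back-of-envelope illustration: with the default temporal_window=80, a new
node's sources come almost entirely from the previous ~80 IDs. Eighty
consecutive float activations span 320 bytes, i.e. at most four 128-byte
cache lines, so a 32-lane SIMD group reading them touches roughly 3-4 lines
instead of up to 32 (one line per lane) in the shuffled worst case.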

Usage:
    python temporal_dag_generator.py [--nodes N] [--seed S] [--no-plot]

Requirements: numpy (plotting requires matplotlib)
License: MIT
"""

import argparse
import sys
from collections import defaultdict

import numpy as np


# ---------------------------------------------------------------------------
# Hardware constants
# ---------------------------------------------------------------------------

CACHE_LINE_BYTES = 128                              # Apple Silicon / NVIDIA
FLOAT_BYTES = 4                                     # sizeof(float)
FLOATS_PER_CACHE_LINE = CACHE_LINE_BYTES // FLOAT_BYTES  # 32
SIMD_WIDTH = 32                                     # SIMD-group / warp width (Apple GPU, NVIDIA)
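
# Sanity check (follows directly from the constants above): a 32-lane SIMD
# group reading 32 contiguous, line-aligned floats covers exactly one cache
# line (32 * 4 B = 128 B); unaligned contiguous reads touch at most two.
assert SIMD_WIDTH * FLOAT_BYTES == CACHE_LINE_BYTES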


# ---------------------------------------------------------------------------
# 1. Sub-DAG Builder
# ---------------------------------------------------------------------------

def build_subdags(
    n_hidden_total: int = 4000,
    n_subdags: int = 98,
    ctx_positions: int = 2,
    temporal_window: int = 80,
    temporal_prob: float = 0.90,
    power_law_alpha: float = 1.8,
    min_edges: int = 2,
    max_edges: int = 200,
    seed: int = 42,
) -> list[dict]:
    """
    Grow sub-DAGs independently, each with its own local node IDs.

    Each sub-DAG starts with:
      - ctx_positions input nodes (IDs 0..ctx-1)
      - 1 output node (ID ctx)
      - Dense input->output edges

    Hidden nodes are appended with sequential IDs (ctx+1, ctx+2, ...).
    Each new node wires to existing nodes using temporal-bias attachment.

    Returns a list of sub-DAG dicts, each with:
      - edges: list of (src_local, tgt_local) tuples
      - n_nodes: total nodes in this sub-DAG
    """
    rng = np.random.default_rng(seed)
    subdags = []
    hidden_per_sd = [0] * n_subdags

    # Distribute hidden nodes uniformly at random across sub-DAGs
    for _ in range(n_hidden_total):
        hidden_per_sd[rng.integers(n_subdags)] += 1

    for sd in range(n_subdags):
        n_input = ctx_positions
        output_id = n_input  # output node is right after inputs
        edges = []

        # Initial input->output edges
        for inp in range(n_input):
            edges.append((inp, output_id))

        next_id = output_id + 1
        all_ids = list(range(next_id))  # [inp0, inp1, ..., output]

        for h in range(hidden_per_sd[sd]):
            new_id = next_id
            next_id += 1

            fan_in = _sample_power_law(rng, power_law_alpha, min_edges, max_edges)
            fan_in = min(fan_in, len(all_ids))

            sources = _select_sources_temporal(
                rng, all_ids, fan_in, temporal_window, temporal_prob
            )
            for s in sources:
                edges.append((s, new_id))

            all_ids.append(new_id)

        subdags.append({
            "edges": edges,
            "n_nodes": next_id,
            "n_input": n_input,
            "n_hidden": hidden_per_sd[sd],
        })

    return subdags


def _sample_power_law(rng, alpha, lo, hi):
    """Sample a fan-in k from the truncated power law P(k) ~ k**(-alpha), k in [lo, hi]."""
    k = np.arange(lo, hi + 1, dtype=np.float64)
    pmf = k ** (-alpha)
    pmf /= pmf.sum()
    return int(rng.choice(k, p=pmf))


def _select_sources_temporal(rng, existing, fan_in, window, prob):
    """
    Select up to fan_in distinct sources: with probability `prob`, draw from
    the `window` most recently created nodes, otherwise from older nodes.
    Draws are capped at fan_in * 4, so heavy duplicate draws may yield
    fewer than fan_in sources.
    """
    n = len(existing)
    recent_start = max(0, n - window)
    recent = existing[recent_start:]
    older = existing[:recent_start] if recent_start > 0 else []

    selected = set()
    attempts = 0
    while len(selected) < fan_in and attempts < fan_in * 4:
        attempts += 1
        if rng.random() < prob and recent:
            selected.add(recent[rng.integers(len(recent))])
        elif older:
            selected.add(older[rng.integers(len(older))])
        elif recent:
            selected.add(recent[rng.integers(len(recent))])
        else:
            break
    return list(selected)


# ---------------------------------------------------------------------------
# 2. Flatten into Global Activation Array (mimics SubDAGFlat)
# ---------------------------------------------------------------------------

def flatten_subdags(subdags: list[dict]) -> dict:
    """
    Flatten all sub-DAGs into a single contiguous activation array.

    Layout: [subdag_0 nodes | subdag_1 nodes | ... | subdag_N nodes]

    Within each sub-DAG, nodes keep their chronological order:
      [input_0, input_1, output, hidden_0, hidden_1, ...]

    This is the actual memory layout used during GPU dispatch.
    The key insight: because each sub-DAG's nodes are contiguous in the
    flat array, and edges connect to nearby local IDs (due to temporal
    bias), the resulting global source IDs are numerically clustered.
    """
    global_src = []
    global_tgt = []
    offset = 0

    for sd in subdags:
        for (s, t) in sd["edges"]:
            global_src.append(s + offset)
            global_tgt.append(t + offset)
        offset += sd["n_nodes"]

    return {
        "src": np.array(global_src, dtype=np.int32),
        "tgt": np.array(global_tgt, dtype=np.int32),
        "n_edges": len(global_src),
        "n_nodes": offset,
    }


# ---------------------------------------------------------------------------
# 3. Cache-Line Analysis
# ---------------------------------------------------------------------------

def analyze_cache(flat: dict) -> dict:
    """
    Simulate Vector CSR forward dispatch and measure cache-line
    utilization. Each target node's fan-in is processed in 32-edge
    chunks, one 32-thread SIMD group per chunk, with each thread
    reading activations[src_id].

    Metrics:
      - cache_lines_per_32: avg cache lines touched per 32-edge warp
      - utilization_pct: useful bytes / fetched bytes
    """
    src = flat["src"]
    tgt = flat["tgt"]

    # Group source IDs by target node (per-target fan-in lists)
    tgt_to_srcs = defaultdict(list)
    for s, t in zip(src, tgt):
        tgt_to_srcs[t].append(s)

    total_useful = 0
    total_fetched = 0
    total_warps = 0
    total_lines = 0

    for sources in tgt_to_srcs.values():
        if not sources:
            continue

        # Sort so each 32-wide chunk reads numerically adjacent IDs
        # (best-case warp assignment for this target's fan-in)
        srcs_sorted = sorted(sources)
        for ws in range(0, len(srcs_sorted), SIMD_WIDTH):
            warp = srcs_sorted[ws:ws + SIMD_WIDTH]
            lines = len({s // FLOATS_PER_CACHE_LINE for s in warp})

            total_lines += lines
            total_useful += len(warp) * FLOAT_BYTES
            total_fetched += lines * CACHE_LINE_BYTES
            total_warps += 1

    lines_per_32 = total_lines / total_warps if total_warps else 0
    util = total_useful / total_fetched * 100 if total_fetched else 0
    ampl = total_fetched / total_useful if total_useful else 0

    return {
        "edges": flat["n_edges"],
        "nodes": flat["n_nodes"],
        "cache_lines_per_32": lines_per_32,
        "utilization_pct": util,
        "useful_kb": total_useful / 1024,
        "fetched_kb": total_fetched / 1024,
        "amplification": ampl,
    }
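
# Minimal programmatic use (a sketch; the parameter values are illustrative):
#
#     sds = build_subdags(n_hidden_total=500, seed=0)
#     metrics = analyze_cache(flatten_subdags(sds))
#     print(f"{metrics['utilization_pct']:.1f}% cache-line utilization")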


# ---------------------------------------------------------------------------
# 4. Growth Trajectory (multiple snapshots)
# ---------------------------------------------------------------------------

def growth_trajectory(
    max_hidden: int = 4000,
    n_snapshots: int = 12,
    temporal_prob: float = 0.90,
    seed: int = 42,
    **kw,
) -> list[dict]:
    """Run multiple growth sizes and collect cache metrics at each."""
    results = []
    sizes = np.linspace(100, max_hidden, n_snapshots, dtype=int)
    for h in sizes:
        sds = build_subdags(
            n_hidden_total=int(h), temporal_prob=temporal_prob, seed=seed, **kw
        )
        flat = flatten_subdags(sds)
        r = analyze_cache(flat)
        r["hidden"] = int(h)
        results.append(r)
    return results


# ---------------------------------------------------------------------------
# 5. Main
# ---------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(
        description="Temporal DAG Generator — Cache Locality Analysis"
    )
    parser.add_argument("--nodes", type=int, default=4000,
                        help="Max hidden nodes to grow (default: 4000)")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--no-plot", action="store_true")
    args = parser.parse_args()

    print("=" * 74)
    print("  Temporal DAG Generator — Cache Locality Analysis")
    print("  MecanoAI Inc. — Reproducibility Artifact")
    print("=" * 74)

    # Temporal growth trajectory
    print("\n[1/3] Temporal-bias DAG trajectory (90% recent / 10% old) ...")
    traj_t = growth_trajectory(
        max_hidden=args.nodes, temporal_prob=0.90, seed=args.seed
    )

    # Random baseline trajectory
    print("[2/3] Random-attachment DAG trajectory (baseline) ...")
    traj_r = growth_trajectory(
        max_hidden=args.nodes, temporal_prob=0.00, seed=args.seed + 1
    )

    # Shuffled baseline (worst case: random relabeling of all node IDs,
    # applied to both source and target indices)
    print("[3/3] Shuffled-index DAG (worst case) ...")
    sds_shuf = build_subdags(n_hidden_total=args.nodes, seed=args.seed + 2)
    flat_shuf = flatten_subdags(sds_shuf)
    rng = np.random.default_rng(args.seed + 2)
    perm = rng.permutation(flat_shuf["n_nodes"]).astype(np.int32)
    flat_shuf["src"] = perm[flat_shuf["src"]]
    flat_shuf["tgt"] = perm[flat_shuf["tgt"]]
    r_shuf = analyze_cache(flat_shuf)

    # --- Results ---
    print("\n" + "=" * 74)
    print("  TEMPORAL-BIAS DAG (90% recent / 10% old)")
    print("=" * 74)
    _print_table(traj_t)

    print("\n" + "=" * 74)
    print("  RANDOM-ATTACHMENT DAG (0% temporal bias)")
    print("=" * 74)
    _print_table(traj_r)

    print("\n" + "=" * 74)
    print(f"  SHUFFLED-INDEX DAG (random permutation, worst case)")
    print("=" * 74)
    print(f"  Edges: {r_shuf['edges']:,}  Nodes: {r_shuf['nodes']:,}")
    print(f"  cache_lines/32 = {r_shuf['cache_lines_per_32']:.2f}  "
          f"utilization = {r_shuf['utilization_pct']:.1f}%  "
          f"amplification = {r_shuf['amplification']:.1f}x")

    # --- Summary ---
    t_f = traj_t[-1]
    r_f = traj_r[-1]
    print("\n" + "=" * 74)
    print("  SUMMARY")
    print("=" * 74)
    print(f"  Temporal (organic):   {t_f['cache_lines_per_32']:.2f} lines/warp  "
          f"({t_f['utilization_pct']:.1f}% util)")
    print(f"  Random attachment:    {r_f['cache_lines_per_32']:.2f} lines/warp  "
          f"({r_f['utilization_pct']:.1f}% util)")
    print(f"  Shuffled (worst):     {r_shuf['cache_lines_per_32']:.2f} lines/warp  "
          f"({r_shuf['utilization_pct']:.1f}% util)")
    if r_shuf['utilization_pct'] > 0:
        print(f"  Temporal vs Shuffled: "
              f"{t_f['utilization_pct'] / r_shuf['utilization_pct']:.1f}x improvement")
    print()

    if not args.no_plot:
        try:
            _plot(traj_t, traj_r, r_shuf)
        except ImportError:
            print("  (matplotlib not available — skipping plot)")


def _print_table(results):
    print(f"  {'Hidden':>7}  {'Edges':>8}  {'Nodes':>6}  "
          f"{'Lines/32':>8}  {'Util%':>6}  {'Useful':>8}  {'Fetched':>8}  {'Ampl':>5}")
    print("  " + "-" * 68)
    for r in results:
        print(f"  {r['hidden']:7d}  {r['edges']:8d}  {r['nodes']:6d}  "
              f"{r['cache_lines_per_32']:8.2f}  {r['utilization_pct']:5.1f}%  "
              f"{r['useful_kb']:7.0f}KB  {r['fetched_kb']:7.0f}KB  "
              f"{r['amplification']:4.1f}x")


def _plot(temporal, random, shuffled):
    import matplotlib.pyplot as plt

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    for data, label, color, ls, marker in [
        (temporal, "Temporal (90%)", "#0d59f2", "-", "o"),
        (random, "Random (0%)", "#888", "--", "s"),
    ]:
        edges = [r["edges"] for r in data]
        ax1.plot(edges, [r["utilization_pct"] for r in data],
                 f"{marker}{ls}", color=color, label=label, lw=2, ms=4)
        ax2.plot(edges, [r["cache_lines_per_32"] for r in data],
                 f"{marker}{ls}", color=color, label=label, lw=2, ms=4)

    ax1.axhline(shuffled["utilization_pct"], color="red", ls=":", lw=1.5,
                label=f"Shuffled ({shuffled['utilization_pct']:.1f}%)")
    ax2.axhline(shuffled["cache_lines_per_32"], color="red", ls=":", lw=1.5,
                label=f"Shuffled ({shuffled['cache_lines_per_32']:.1f})")
    ax2.axhline(1.0, color="green", ls=":", alpha=0.4, label="Perfect (1.0)")

    ax1.set_xlabel("Edge Count")
    ax1.set_ylabel("Cache Utilization (%)")
    ax1.set_title("Cache-Line Utilization vs. Graph Size")
    ax1.legend(fontsize=9)
    ax1.grid(True, alpha=0.3)

    ax2.set_xlabel("Edge Count")
    ax2.set_ylabel("Cache Lines per 32-Edge Warp")
    ax2.set_title("Memory Fetch Overhead vs. Graph Size")
    ax2.legend(fontsize=9)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig("cache_locality_results.png", dpi=150, bbox_inches="tight")
    print(f"\n  Plot saved: cache_locality_results.png")
    plt.show()


if __name__ == "__main__":
    sys.exit(main())
