package com.poizon.security;

import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;


public class TtlPerfHighConcurrencyTest {

    // ---- Benchmark configuration, read once from JVM system properties (-Dname=value) ----

    /** Default thread count for each stage pool ("threads"); defaults to max(4, available processors). */
    static final int THREADS            = Integer.getInteger("threads", Math.max(4, Runtime.getRuntime().availableProcessors()));
    /** Number of producer threads started by runCase ("producers"); defaults to THREADS * 30. */
    static final int PRODUCERS          = Integer.getInteger("producers", THREADS * 30);
    /** Iterations per producer ("tasks.per.producer"); runCase may lower it to respect MAX_TOTAL_TASKS. */
    static final int TASKS_PER_PRODUCER = Integer.getInteger("tasks.per.producer", 200_000);
    /** Number of TransmittableThreadLocal instances created by main and set by every producer ("ttl.count"). */
    static final int TTL_COUNT          = Integer.getInteger("ttl.count", 0);

    /** Number of pipeline stages, i.e. stage executors created per case ("stages"). */
    static final int STAGES             = Integer.getInteger("stages", 3);
    /** Comma-separated per-stage fan-out factors; missing entries are filled with a default by main. */
    static final String FANOUTS_STR     = System.getProperty("fanouts", "");          // e.g. "5,5,5,5"
    /** Comma-separated per-stage thread counts; missing entries fall back to THREADS. */
    static final String THREADS_STR     = System.getProperty("threads.stages", "");   // e.g. "8,8,8,8"
    /** Comma-separated per-stage queue kinds ("sync", "small" or "large"); see newExecutor. */
    static final String QUEUES_STR      = System.getProperty("queues.stages", "");    // e.g. "sync,sync,small,sync"

    /** Capacity of the ArrayBlockingQueue used for queue kind "small". */
    static final int QCAP_SMALL         = Integer.getInteger("qcap.small", 1024);
    /** Capacity of the LinkedBlockingQueue used for queue kind "large". */
    static final int QCAP_LARGE         = Integer.getInteger("qcap.large", 65536);

    static final int LEAF_EXECUTORS     = Integer.getInteger("leaf.executors", 3);    // number of ping-pong leaf executors (runCase uses at least 1)
    static final int LEAF_SWITCH_ROUNDS = Integer.getInteger("leaf.switch.rounds", 32);// number of executor switches per leaf (one execute() call per round)

    /** Upper bound on the total number of leaf tasks per case; runCase shrinks tasks-per-producer to honor it. */
    static final long MAX_TOTAL_TASKS   = Long.getLong("max.total.tasks", 5_000_000L);

    /** The TransmittableThreadLocal instances under test; populated by main before any case runs. */
    static final List<TransmittableThreadLocal<String>> TTLS = new ArrayList<>(TTL_COUNT);

    /** Shared sink updated by leaf tasks (incremented once per completed leaf, see LeafSwitcher.Hop). */
    static final AtomicInteger BLACKHOLE = new AtomicInteger();

    /**
     * Entry point of the benchmark.
     *
     * <p>Resolves the per-stage configuration, prints it, creates {@code TTL_COUNT}
     * thread-locals, runs one warm-up pass of each variant, then measures the
     * baseline and the TTL-wrapped variant and prints a comparison table.
     *
     * @param args ignored; all configuration comes from system properties
     * @throws Exception if a benchmark case is interrupted or fails
     */
    public static void main(String[] args) throws Exception {
        final int[] fanouts = parseIntListOrDefault(FANOUTS_STR, STAGES, 5);
        final int[] threadsPerStage = parseIntListOrDefault(THREADS_STR, STAGES, THREADS);
        final String[] queuePerStage = parseStrListOrDefault(QUEUES_STR, STAGES, "sync");
        final String fanoutsText = Arrays.toString(fanouts);

        // Echo the effective configuration so every run is self-describing.
        final String banner =
                "JDK=%s, stages=%d, fanouts=%s, threads.stages=%s, queues.stages=%s%n" +
                "threads=%d, producers=%d, tasks.per.producer=%d, ttl.count=%d%n" +
                "leaf.executors=%d, leaf.switch.rounds=%d, qcap.small=%d, qcap.large=%d, max.total.tasks=%d%n";
        System.out.printf(
                banner,
                System.getProperty("java.version"),
                STAGES, fanoutsText, Arrays.toString(threadsPerStage), Arrays.toString(queuePerStage),
                THREADS, PRODUCERS, TASKS_PER_PRODUCER, TTL_COUNT,
                LEAF_EXECUTORS, LEAF_SWITCH_ROUNDS, QCAP_SMALL, QCAP_LARGE, MAX_TOTAL_TASKS
        );

        // Create the thread-locals that every producer will populate.
        for (int remaining = TTL_COUNT; remaining > 0; remaining--) {
            TTLS.add(new TransmittableThreadLocal<>());
        }

        // Warm-up: results are discarded.
        runCase("warmup-baseline", fanouts, threadsPerStage, queuePerStage, false);
        runCase("warmup-withTTL",  fanouts, threadsPerStage, queuePerStage, true);

        // Measured comparison, with a GC pause between the two cases.
        final Result baseline = runCase("baseline", fanouts, threadsPerStage, queuePerStage, false);
        gcPause();
        final Result wrapped = runCase("withTTL",  fanouts, threadsPerStage, queuePerStage, true);

        // Summary table.
        System.out.println();
        System.out.println("==== Summary ====");
        printRow("Case", "Tasks", "Time(s)", "Throughput(t/s)", "Avg(ns/task)");
        for (Result row : new Result[] {baseline, wrapped}) {
            printRow(row.name, row.totalTasks, fmtSec(row.seconds), fmtNum(row.throughput), fmtNum(row.avgNs));
        }
        System.out.println("-----------------");
        System.out.printf("Slowdown (withTTL vs baseline): %.2fx  (ttl.count=%d, stages=%d, fanouts=%s, leaf.switch.rounds=%d)%n",
                wrapped.seconds / baseline.seconds, TTL_COUNT, STAGES, fanoutsText, LEAF_SWITCH_ROUNDS);
    }

    /**
     * Runs one benchmark case and returns its timing.
     *
     * <p>Builds {@code STAGES} stage executors plus the leaf executors, starts
     * {@code PRODUCERS} producer threads that each submit root tasks into stage 1,
     * and measures the wall-clock time from releasing the producers until every
     * leaf task has counted down the completion latch.
     *
     * <p>Fan-out values below 1 are treated as 1 (with a warning). The expected task
     * count is computed by {@code productOf}, which already counts such values as 1;
     * without clamping, the submit loops would run zero times and the completion
     * latch would never reach zero.
     *
     * @param name         label used in the printed result
     * @param fanouts      per-stage fan-out factors; at least {@code STAGES} entries
     * @param stageThreads per-stage pool sizes; at least {@code STAGES} entries
     * @param stageQueues  per-stage queue kinds; at least {@code STAGES} entries
     * @param withTTL      whether every submitted task is wrapped with {@code TtlRunnable.get}
     * @return the measured result for this case
     * @throws IllegalArgumentException if {@code STAGES < 1}, an array is shorter than
     *                                  {@code STAGES}, or the total task count exceeds
     *                                  {@code Integer.MAX_VALUE}
     * @throws InterruptedException     if the calling thread is interrupted while waiting
     */
    static Result runCase(String name, int[] fanouts, int[] stageThreads, String[] stageQueues, boolean withTTL)
            throws InterruptedException {

        // With zero stages every producer would die on fanouts[0] and the latch would never open.
        if (STAGES < 1) {
            throw new IllegalArgumentException("stages must be >= 1, got " + STAGES);
        }
        if (fanouts.length < STAGES || stageThreads.length < STAGES || stageQueues.length < STAGES) {
            throw new IllegalArgumentException("fanouts, stageThreads and stageQueues need at least "
                    + STAGES + " entries, got " + fanouts.length + ", " + stageThreads.length
                    + ", " + stageQueues.length);
        }

        // Work on a private, clamped copy so the latch count and the submit loops agree.
        final int[] effectiveFanouts = Arrays.copyOf(fanouts, STAGES);
        for (int i = 0; i < effectiveFanouts.length; i++) {
            if (effectiveFanouts[i] < 1) {
                System.out.printf("[WARN] fanouts[%d]=%d is < 1; using 1.%n", i, effectiveFanouts[i]);
                effectiveFanouts[i] = 1;
            }
        }

        final long fanoutProduct = productOf(effectiveFanouts);
        final long wouldTotal = 1L * PRODUCERS * TASKS_PER_PRODUCER * fanoutProduct;

        final int tasksPerProducerUsed;
        if (wouldTotal > MAX_TOTAL_TASKS) {
            long denom = 1L * PRODUCERS * fanoutProduct;
            tasksPerProducerUsed = (int) Math.max(1L, MAX_TOTAL_TASKS / Math.max(1L, denom));
            System.out.printf("[WARN] totalTasks would be %,d > %,d. Auto-reducing tasks.per.producer to %d.%n",
                    wouldTotal, MAX_TOTAL_TASKS, tasksPerProducerUsed);
        } else {
            tasksPerProducerUsed = TASKS_PER_PRODUCER;
        }

        // The latch takes an int. A silently clamped count would release the latch before
        // all tasks finished and make the reported timing meaningless, so fail fast instead.
        final long exactTotal = 1L * PRODUCERS * tasksPerProducerUsed * fanoutProduct;
        if (exactTotal > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("total task count " + exactTotal
                    + " exceeds Integer.MAX_VALUE; reduce producers, fanouts or stages");
        }
        final int totalTasks = safeToInt(exactTotal);

        final ThreadPoolExecutor[] stageExecs = new ThreadPoolExecutor[STAGES];
        for (int i = 0; i < STAGES; i++) {
            stageExecs[i] = newExecutor("stage-" + (i + 1) + "-",
                    stageThreads[i], stageQueues[i], QCAP_SMALL, QCAP_LARGE);
        }

        // Leaf executors: one thread each, direct hand-off, caller runs the task when the thread is busy.
        final ThreadPoolExecutor[] leafExecs = new ThreadPoolExecutor[Math.max(1, LEAF_EXECUTORS)];
        for (int i = 0; i < leafExecs.length; i++) {
            leafExecs[i] = new ThreadPoolExecutor(
                    1, 1, 0L, TimeUnit.MILLISECONDS,
                    new SynchronousQueue<>(),
                    new NamedFactory("leaf-" + (i + 1) + "-"),
                    new ThreadPoolExecutor.CallerRunsPolicy());
            leafExecs[i].allowCoreThreadTimeOut(false);
        }

        CountDownLatch start = new CountDownLatch(1);
        CountDownLatch done  = new CountDownLatch(totalTasks);

        List<Thread> producers = new ArrayList<>(PRODUCERS);
        for (int p = 0; p < PRODUCERS; p++) {
            final int pid = p;
            Thread t = new Thread(() -> {
                // Give every thread-local a producer-specific value before any task is submitted.
                for (int i = 0; i < TTL_COUNT; i++) {
                    TTLS.get(i).set("P" + pid + "-V" + i);
                }

                await(start);

                for (int i = 0; i < tasksPerProducerUsed; i++) {
                    Runnable stageTask = new StageTask(1, withTTL, stageExecs, effectiveFanouts, leafExecs, done);

                    // The same root task is submitted fanouts[0] times; each submission is wrapped separately.
                    for (int k = 0; k < effectiveFanouts[0]; k++) {
                        Runnable r = stageTask;
                        if (withTTL) r = TtlRunnable.get(r, false, false);
                        stageExecs[0].execute(r);
                    }

                }
            }, "producer-" + pid);
            t.setDaemon(true);
            t.start();
            producers.add(t);
        }

        long t0 = System.nanoTime();
        start.countDown();

        done.await();
        long t1 = System.nanoTime();

        // Tear down: producers first, then the pools.
        for (Thread t : producers) t.join();
        for (ThreadPoolExecutor ex : stageExecs) shutdownAndAwait(ex);
        for (ThreadPoolExecutor ex : leafExecs) shutdownAndAwait(ex);

        double seconds    = (t1 - t0) / 1_000_000_000.0;
        double throughput = totalTasks / Math.max(1e-9, seconds);
        double avgNs      = (t1 - t0) * 1.0 / Math.max(1, totalTasks);

        System.out.printf("%-12s => tasks=%d, time=%.3fs, throughput=%.0f tasks/s, avg=%.1f ns/task%n",
                name, totalTasks, seconds, throughput, avgNs);

        return new Result(name, totalTasks, seconds, throughput, avgNs);
    }

    /**
     * Leaf stage of the pipeline: bounces a single task between the leaf executors
     * for a fixed number of rounds, then performs the leaf work and counts down
     * the completion latch.
     */
    static final class LeafSwitcher {
        final ThreadPoolExecutor[] execs;
        final boolean withTTL;
        final int rounds;
        final CountDownLatch done;

        /**
         * @param execs   leaf executors to hop between; index 0 receives the first submission
         * @param withTTL whether each submission is wrapped with {@code TtlRunnable.get}
         * @param rounds  number of hops before the leaf work runs
         * @param done    latch counted down once when the leaf work has finished
         */
        LeafSwitcher(ThreadPoolExecutor[] execs, boolean withTTL, int rounds, CountDownLatch done) {
            this.execs = execs;
            this.withTTL = withTTL;
            this.rounds = rounds;
            this.done = done;
        }

        /** Creates the hop task and submits it to the first leaf executor. */
        void start() {
            final Hop hop = new Hop(execs, withTTL, rounds, done);
            final Runnable first;
            if (withTTL) {
                first = TtlRunnable.get(hop, false, false);
            } else {
                first = hop;
            }
            execs[0].execute(first); // take off from the first leaf executor
        }

        /**
         * Self-resubmitting task. Each run either forwards itself to the next
         * executor (round-robin) or, once all rounds are used up, does the leaf work.
         */
        static final class Hop implements Runnable {
            final ThreadPoolExecutor[] execs;
            final boolean withTTL;
            final int rounds;
            final CountDownLatch done;
            int round = 0;
            int idx = 0;

            Hop(ThreadPoolExecutor[] execs, boolean withTTL, int rounds, CountDownLatch done) {
                this.execs = execs;
                this.withTTL = withTTL;
                this.rounds = rounds;
                this.done = done;
            }

            @Override
            public void run() {
                if (round >= rounds) {
                    // All hops done: do the work, feed the sink, signal completion.
                    heavyWork();
                    BLACKHOLE.incrementAndGet();
                    done.countDown();
                    return;
                }

                round++;
                idx = (idx + 1) % execs.length;
                final ThreadPoolExecutor target = execs[idx];
                // The same object is reused for every hop; hops run one after another, so there is no concurrent race.
                final Runnable next;
                if (withTTL) {
                    next = TtlRunnable.get(this, false, false);
                } else {
                    next = this;
                }
                target.execute(next);
            }
        }
    }

    /**
     * One node of the fan-out tree. When run, it either submits
     * {@code fanouts[nextStageIdx]} children to the next stage executor or, past
     * the last stage, hands over to a {@link LeafSwitcher}.
     */
    static final class StageTask implements Runnable {
        final int nextStageIdx;
        final boolean withTTL;
        final ThreadPoolExecutor[] stageExecs;
        final int[] fanouts;
        final ThreadPoolExecutor[] leafExecs;
        final CountDownLatch done;

        /**
         * @param nextStageIdx index of the stage this task submits to when it runs
         * @param withTTL      whether child submissions are wrapped with {@code TtlRunnable.get}
         * @param stageExecs   executors of all stages
         * @param fanouts      per-stage fan-out factors
         * @param leafExecs    executors used by the leaf hop phase
         * @param done         completion latch passed through to the leaves
         */
        StageTask(int nextStageIdx,
                  boolean withTTL,
                  ThreadPoolExecutor[] stageExecs,
                  int[] fanouts,
                  ThreadPoolExecutor[] leafExecs,
                  CountDownLatch done) {
            this.nextStageIdx = nextStageIdx;
            this.withTTL = withTTL;
            this.stageExecs = stageExecs;
            this.fanouts = fanouts;
            this.leafExecs = leafExecs;
            this.done = done;
        }

        @Override
        public void run() {
            if (nextStageIdx >= stageExecs.length) {
                // Past the last stage: start the leaf hop phase.
                new LeafSwitcher(leafExecs, withTTL, LEAF_SWITCH_ROUNDS, done).start();
                return;
            }

            final Runnable child = new StageTask(nextStageIdx + 1, withTTL, stageExecs, fanouts, leafExecs, done);
            final int width = fanouts[nextStageIdx]; // fan-out factor of this level
            final ThreadPoolExecutor target = stageExecs[nextStageIdx];
            for (int remaining = width; remaining > 0; remaining--) {
                // One child instance is shared; every submission gets its own wrapper when TTL is on.
                final Runnable submission;
                if (withTTL) {
                    submission = TtlRunnable.get(child, false, false);
                } else {
                    submission = child;
                }
                target.execute(submission);
            }
        }
    }


    /**
     * Creates a fixed-size pool for one pipeline stage.
     *
     * <p>The work queue depends on {@code queueKind}: {@code "small"} gives an
     * {@link ArrayBlockingQueue} of {@code qcapSmall}, {@code "large"} a
     * {@link LinkedBlockingQueue} of {@code qcapLarge}, and {@code "sync"} or any
     * other value a {@link SynchronousQueue}. Rejected tasks run on the submitting
     * thread ({@code CallerRunsPolicy}); worker threads are named by {@code NamedFactory}.
     *
     * @param namePrefix prefix for worker thread names
     * @param threads    core and maximum pool size
     * @param queueKind  queue selector, must not be {@code null}
     * @param qcapSmall  capacity used for {@code "small"}
     * @param qcapLarge  capacity used for {@code "large"}
     * @return the configured executor
     */
    static ThreadPoolExecutor newExecutor(String namePrefix, int threads, String queueKind, int qcapSmall, int qcapLarge) {
        final BlockingQueue<Runnable> workQueue;
        if (queueKind.equals("small")) {
            workQueue = new ArrayBlockingQueue<>(qcapSmall);
        } else if (queueKind.equals("large")) {
            workQueue = new LinkedBlockingQueue<>(qcapLarge);
        } else {
            // "sync" and every unrecognised kind: direct hand-off.
            workQueue = new SynchronousQueue<>();
        }

        final ThreadFactory factory = new NamedFactory(namePrefix);
        final RejectedExecutionHandler onReject = new ThreadPoolExecutor.CallerRunsPolicy();
        final ThreadPoolExecutor pool =
                new ThreadPoolExecutor(threads, threads, 0L, TimeUnit.MILLISECONDS, workQueue, factory, onReject);
        pool.allowCoreThreadTimeOut(false);
        return pool;
    }

    /**
     * Thread factory producing daemon threads named {@code prefix + counter},
     * with the counter starting at 1.
     */
    static final class NamedFactory implements ThreadFactory {
        final String prefix;
        final AtomicInteger n = new AtomicInteger(1);

        /** @param prefix text placed before the sequence number in every thread name */
        NamedFactory(String prefix) {
            this.prefix = prefix;
        }

        @Override
        public Thread newThread(Runnable r) {
            final String threadName = prefix + n.getAndIncrement();
            final Thread worker = new Thread(r, threadName);
            worker.setDaemon(true);
            return worker;
        }
    }

    /**
     * Shuts the executor down and waits up to two minutes for it to terminate.
     *
     * <p>If the pool has not terminated in time, or the wait is interrupted, the
     * pool is cancelled with {@code shutdownNow()} so that no worker keeps running
     * into the next benchmark case.
     *
     * @param ex the executor to stop
     * @throws InterruptedException if the calling thread is interrupted while waiting;
     *                              the pool has been asked to stop before this is rethrown
     */
    static void shutdownAndAwait(ThreadPoolExecutor ex) throws InterruptedException {
        ex.shutdown();
        try {
            if (!ex.awaitTermination(2, TimeUnit.MINUTES)) {
                // Still busy after the grace period: cancel instead of silently moving on.
                ex.shutdownNow();
                System.err.println("[WARN] executor did not terminate within 2 minutes; forced shutdownNow()");
            }
        } catch (InterruptedException e) {
            ex.shutdownNow();
            throw e;
        }
    }

    /**
     * Small fixed amount of floating-point work executed by every leaf task.
     *
     * <p>Evaluates a mix of trigonometric, logarithmic and power functions for
     * {@code ii = 1} and {@code ii = 2}. The result is only compared with
     * {@code 42.0} and feeds {@code BLACKHOLE} on a match, presumably so the
     * computation cannot be discarded as dead code.
     */
    static void heavyWork() {
        int ii = 1;
        while (ii < 3) {
            final double trig = Math.sin(ii) * Math.cos(ii) * Math.tan(ii);
            final double logExp = Math.log(ii) * Math.exp(Math.sqrt(ii));
            final double powers = Math.pow(ii, 0.5) * Math.pow(ii, 1.5);
            final double re = trig + logExp + powers;
            if (re == 42.0) {
                BLACKHOLE.addAndGet(ii);
            }
            ii++;
        }
    }

    /**
     * Waits on the latch without propagating {@link InterruptedException}.
     * If the wait is interrupted, the method returns immediately and the
     * thread's interrupt flag is set again.
     *
     * @param latch the latch to wait for
     */
    static void await(CountDownLatch latch) {
        try {
            latch.await();
        } catch (InterruptedException interrupted) {
            // Restore the flag so callers can still observe the interruption.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Requests a garbage collection and then sleeps for 300 ms; called between
     * the two measured cases.
     *
     * @throws InterruptedException if the sleep is interrupted
     */
    static void gcPause() throws InterruptedException {
        System.gc();
        TimeUnit.MILLISECONDS.sleep(300);
    }

    /**
     * Prints one row of the summary table to standard output. Columns are
     * left-aligned with fixed minimum widths (12, 10, 8, 16, 12) and separated by
     * {@code " | "}.
     *
     * @param a first column (case name)
     * @param b second column
     * @param c third column
     * @param d fourth column
     * @param e fifth column
     */
    static void printRow(String a, Object b, Object c, Object d, Object e) {
        final String rowFormat = "%-12s | %-10s | %-8s | %-16s | %-12s%n";
        System.out.format(rowFormat, a, b, c, d, e);
    }

    /**
     * Formats a number with no fraction digits.
     *
     * @param v the value to format
     * @return {@code v} rendered with the {@code "%.0f"} pattern
     */
    static String fmtNum(double v) {
        final String rounded = String.format("%.0f", v);
        return rounded;
    }
    /**
     * Formats a duration in seconds with three fraction digits.
     *
     * @param s the duration in seconds
     * @return {@code s} rendered with the {@code "%.3f"} pattern
     */
    static String fmtSec(double s) {
        final String text = String.format("%.3f", s);
        return text;
    }

    /** Immutable outcome of one benchmark case. */
    static final class Result {
        /** Label of the case. */
        final String name;
        /** Number of leaf tasks the case waited for. */
        final int totalTasks;
        /** Measured wall-clock duration in seconds. */
        final double seconds;
        /** Tasks per second. */
        final double throughput;
        /** Average nanoseconds per task. */
        final double avgNs;

        Result(String name, int totalTasks, double seconds, double throughput, double avgNs) {
            this.name = name;
            this.totalTasks = totalTasks;
            this.seconds = seconds;
            this.throughput = throughput;
            this.avgNs = avgNs;
        }
    }

    // ---------------- parse & math ----------------

    /**
     * Parses a comma-separated list of integers into an array of fixed length.
     *
     * <p>Every slot starts as {@code fill}. Tokens are trimmed and parsed in
     * order; a token that is not a valid integer leaves its slot at {@code fill},
     * and tokens beyond {@code len} are ignored. A {@code null} or blank
     * {@code csv} yields an array of defaults.
     *
     * @param csv  the list to parse, may be {@code null}
     * @param len  length of the returned array
     * @param fill default value for missing or unparsable entries
     * @return a new array of length {@code len}
     */
    static int[] parseIntListOrDefault(String csv, int len, int fill) {
        final int[] values = new int[len];
        Arrays.fill(values, fill);

        if (csv != null && !csv.trim().isEmpty()) {
            final String[] tokens = csv.split(",");
            final int limit = Math.min(len, tokens.length);
            for (int pos = 0; pos < limit; pos++) {
                try {
                    values[pos] = Integer.parseInt(tokens[pos].trim());
                } catch (NumberFormatException ignored) {
                    // Deliberately ignored: a malformed token keeps the default for its slot.
                }
            }
        }
        return values;
    }

    /**
     * Parses a comma-separated list of strings into an array of fixed length.
     *
     * <p>Every slot starts as {@code fill}. Tokens are trimmed and copied in
     * order; an empty token leaves its slot at {@code fill}, and tokens beyond
     * {@code len} are ignored. A {@code null} or blank {@code csv} yields an array
     * of defaults.
     *
     * @param csv  the list to parse, may be {@code null}
     * @param len  length of the returned array
     * @param fill default value for missing or empty entries
     * @return a new array of length {@code len}
     */
    static String[] parseStrListOrDefault(String csv, int len, String fill) {
        final String[] values = new String[len];
        Arrays.fill(values, fill);

        if (csv != null && !csv.trim().isEmpty()) {
            final String[] tokens = csv.split(",");
            final int limit = Math.min(len, tokens.length);
            for (int pos = 0; pos < limit; pos++) {
                final String token = tokens[pos].trim();
                if (token.isEmpty()) {
                    continue; // keep the default for this slot
                }
                values[pos] = token;
            }
        }
        return values;
    }

    /**
     * Multiplies the elements of {@code arr}, counting every element below 1 as 1.
     *
     * <p>The result saturates at {@link Long#MAX_VALUE} instead of wrapping around
     * when the product no longer fits in a {@code long}.
     *
     * @param arr the factors; an empty array yields 1
     * @return the product, at least 1 and at most {@code Long.MAX_VALUE}
     */
    static long productOf(int[] arr) {
        long product = 1L;
        for (int v : arr) {
            final long factor = Math.max(1, v);
            // product and factor are both >= 1, so this is an exact overflow test for product * factor.
            if (product > Long.MAX_VALUE / factor) {
                return Long.MAX_VALUE;
            }
            product *= factor;
        }
        return product;
    }

    /**
     * Narrows a {@code long} to an {@code int}, clamping to the {@code int} range
     * at both ends instead of wrapping around.
     *
     * @param v the value to narrow
     * @return {@code v} if it fits in an {@code int}, otherwise
     *         {@code Integer.MAX_VALUE} or {@code Integer.MIN_VALUE}
     */
    static int safeToInt(long v) {
        if (v > Integer.MAX_VALUE) {
            return Integer.MAX_VALUE;
        }
        if (v < Integer.MIN_VALUE) {
            return Integer.MIN_VALUE;
        }
        return (int) v;
    }
}