package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.SemanticMatcher;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.flow.ControlFlowGraph;
import net.covers1624.coffeegrinder.bytecode.flow.ControlFlowNode;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.TryCatch.TryCatchHandler;
import net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformer;
import net.covers1624.coffeegrinder.util.None;
import org.jetbrains.annotations.Nullable;

import java.util.*;

import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchLoadLocal;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchStoreLocal;

/**
 * Structures raw try/catch regions into try-catch, try-finally and try-catch-finally instructions.
 * <p>
 * Created by covers1624 on 31/5/21.
 */
public class TryCatches implements MethodTransformer {

    /**
     * Runs the try/catch/finally structuring passes over the given method, in order:
     * <ol>
     *     <li>inline the handler-variable stores at the start of catch handlers,</li>
     *     <li>convert handlers flagged as unprocessed finally into try-finally bodies,</li>
     *     <li>capture simple exit blocks that follow a try-catch and are only branched to from its body,</li>
     *     <li>combine try-catches directly nested in another try-catch,</li>
     *     <li>capture the blocks dominated by each handler's entry into the handler body,</li>
     *     <li>merge try-catches directly nested in a try-finally into that try-finally.</li>
     * </ol>
     * Each pass runs inside its own named {@code ctx} step.
     *
     * @param function The method to transform.
     * @param ctx      The transform context.
     */
    @Override
    public void transform(MethodDecl function, MethodTransformContext ctx) {
        ctx.pushStep("Inline handler variable stores");
        inlinedHandlerVariables(function, ctx);
        ctx.popStep();

        ctx.pushStep("Convert Finally Handlers");
        transformFinallyBlocks(function, ctx);
        ctx.popStep();

        ctx.pushStep("Capture try exit blocks");
        captureTryBodyExitBlocks(function, ctx);
        ctx.popStep();

        ctx.pushStep("Combine catch handlers");
        combineHandlers(function, ctx);
        ctx.popStep();

        ctx.pushStep("Capture handler blocks");
        captureBlocks(function, ctx);
        ctx.popStep();

        ctx.pushStep("TryCatchFinally");
        createTryCatchFinally(function, ctx);
        ctx.popStep();
    }

    /**
     * Replaces each catch handler's stack-slot exception variable with the local variable it is
     * immediately stored into.
     * <p>
     * Each handler body is expected to be a single block holding a single branch. The target of that
     * branch is processed only when it has exactly one incoming edge; it must then start with
     * {@code STORE(handlerLocalVar, LOAD handlerStackVar)}. Processing does the following:
     * <ul>
     *     <li>The handler is re-declared with {@code handlerLocalVar}, which takes the stack variable's type.</li>
     *     <li>The store is removed.</li>
     *     <li>If only one instruction remains in the target block, that instruction replaces the
     *     handler's branch and the target block is removed.</li>
     * </ul>
     *
     * @param function The method to process.
     * @param ctx      The transform context.
     */
    public static void inlinedHandlerVariables(MethodDecl function, MethodTransformContext ctx) {
        // all catch blocks should branch to STORE(handlerLocalVar, handlerStackVar)
        // take this store and set handlerStackVar -> handlerLocalVAR
        // if the block is now empty, inline it.

        function.descendantsToList(TryCatchHandler.class)
                .forEach(handler -> {
                    LocalReference handlerVarDecl = handler.getVariable();
                    LocalVariable handlerVar = handlerVarDecl.variable;
                    assert handlerVar.getKind() == LocalVariable.VariableKind.STACK_SLOT;
                    assert handlerVar.getLoadCount() == 1;

                    BlockContainer body = handler.getBody();
                    Block bodyBlock = body.blocks.only();
                    Branch firstBranch = (Branch) bodyBlock.instructions.only();

                    Block target = firstBranch.getTargetBlock();
                    // Only rewrite when the handler is the sole way into the target block.
                    if (target.getIncomingEdgeCount() != 1) return;

                    Store varStore = LoadStoreMatching.matchStoreLocalLoadLocal(target.getFirstChild(), handlerVar);
                    assert varStore != null;

                    LocalVariable handlerLocalVar = varStore.getVariable();

                    ctx.pushStep("Inline handler variable " + handlerLocalVar.getUniqueName());
                    handlerLocalVar.setType(handlerVar.getType());
                    handlerVarDecl.replaceWith(new LocalReference(handlerLocalVar));
                    varStore.remove();

                    if (target.instructions.size() == 1) { // remove empty block
                        firstBranch.replaceWith(target.getFirstChild());
                        target.remove();
                    }

                    ctx.popStep();
                });
    }

    /**
     * Merges nested try-catches into their outer try-catch.
     * <p>
     * A try-catch whose try body is a single block holding a single try-catch with at least one
     * handler absorbs that inner try-catch. The inner handler is prepended to the outer handlers, and
     * the outer try body becomes the inner try body. This repeats until no such nesting remains.
     * Try-catches with a finally body are not used as the outer try-catch.
     *
     * @param function The function to merge handlers in.
     * @param ctx      The context.
     */
    private static void combineHandlers(MethodDecl function, MethodTransformContext ctx) {
        function.accept(new SimpleInsnVisitor<MethodTransformContext>() {
            @Override
            public None visitTryCatch(TryCatch tryCatch, MethodTransformContext ctx) {
                if (tryCatch.getFinallyBody() != null) return super.visitTryCatch(tryCatch, ctx);

                // Continually process this try until we don't find any mergeable nested trys.
                while (true) {
                    BlockContainer container = tryCatch.getTryBody();
                    if (container.blocks.size() != 1) break;

                    Block block = container.blocks.first();
                    if (block.instructions.size() != 1) break;

                    if (!(block.instructions.first() instanceof TryCatch inner) || inner.handlers.isEmpty()) break;

                    // We have found a nested TryCatch, with no extra blocks inside the outer TryCatch.
                    // We can safely inline this.

                    ctx.pushStep("Combine " + block.getName() + " into " + ((Block) tryCatch.getParent()).getName());

                    // Add the handler to the beginning to preserve handler order.
                    // As we process in pre-order, there should only ever be a single handler in the inner.
                    // NOTE(review): only the first inner handler is moved; this relies on the invariant above.
                    tryCatch.handlers.addFirst(inner.handlers.first());

                    // Replace the body with that of the inners.
                    tryCatch.setTryBody(inner.getTryBody());
                    ctx.popStep();
                }

                return super.visitTryCatch(tryCatch, ctx);
            }
        }, ctx);
    }

    /**
     * Merges a try-catch that makes up the entire body of a try-finally into that try-finally,
     * producing a single try-catch-finally.
     * <p>
     * A merge requires the following:
     * <ul>
     *     <li>The try-finally's try body is one block with a single incoming edge and a single instruction.</li>
     *     <li>That instruction is a try-catch with no finally body of its own.</li>
     * </ul>
     * The inner handlers are appended to the outer ones, and the outer try body becomes the inner one.
     * Try-finally blocks recognised by {@link #isUnprocessedSynchronized(TryCatch)} are left untouched.
     *
     * @param function The function to merge try-finally blocks in.
     * @param ctx      The context.
     */
    private static void createTryCatchFinally(MethodDecl function, MethodTransformContext ctx) {
        function.accept(new SimpleInsnVisitor<MethodTransformContext>() {
            @Override
            public None visitTryCatch(TryCatch tryCatch, MethodTransformContext ctx) {
                if (!isUnprocessedSynchronized(tryCatch) && tryCatch.getFinallyBody() != null) {
                    tryMerge(tryCatch, ctx);
                }

                return super.visitTryCatch(tryCatch, ctx);
            }

            private static void tryMerge(TryCatch tryFinally, MethodTransformContext ctx) {
                BlockContainer container = tryFinally.getTryBody();
                if (container.blocks.size() != 1) return;

                Block block = container.blocks.first();
                if (block.instructions.size() != 1 || block.getIncomingEdgeCount() != 1) return;

                if (!(block.instructions.first() instanceof TryCatch inner)) return;
                if (inner.getFinallyBody() != null) return;

                // We have found a nested TryCatch, with no extra blocks inside the outer TryFinally.
                // We can safely inline this.

                ctx.pushStep("Combine " + block.getName() + " into " + ((Block) tryFinally.getParent()).getName());
                tryFinally.handlers.addAll(inner.handlers);
                tryFinally.setTryBody(inner.getTryBody());
                ctx.popStep();
            }
        }, ctx);
    }

    /**
     * Checks whether a try's finally body starts with a local STORE immediately followed by a
     * MONITOR_EXIT. Presumably this marks a synchronized block that a later pass still has to
     * convert (assumption, based on the name).
     *
     * @param tryCatch The try to inspect.
     * @return {@code true} if the finally body is a {@link BlockContainer} whose entry point begins with
     * {@code STORE; MONITOR_EXIT}.
     */
    private static boolean isUnprocessedSynchronized(TryCatch tryCatch) {
        if (!(tryCatch.getFinallyBody() instanceof BlockContainer body)) return false;
        if (!(matchStoreLocal(body.getEntryPoint().getFirstChildOrNull()) instanceof Store store)) return false;
        return store.getNextSiblingOrNull() instanceof MonitorExit;
    }

    /**
     * Moves simple exit blocks into the body of the try-catch they follow, for every try-catch without
     * a finally body.
     * <p>
     * Starting directly after the block containing the try-catch, consecutive sibling blocks are
     * collected while both of the following hold:
     * <ul>
     *     <li>The block's first instruction is a {@link Branch} or {@link Return}.</li>
     *     <li>Every branch to the block originates inside the try body.</li>
     * </ul>
     * Try-catches are visited in reverse of the {@code descendantsOfType} order.
     *
     * @param function The method to process.
     * @param ctx      The transform context.
     */
    private static void captureTryBodyExitBlocks(MethodDecl function, MethodTransformContext ctx) {
        function.descendantsOfType(TryCatch.class).filter(e -> e.getFinallyBody() == null).reversed().forEach(tryCatch -> {
            BlockContainer tryBody = tryCatch.getTryBody();
            Block tryBlock = (Block) tryCatch.getParent();

            List<Block> blocksToMove = new LinkedList<>();
            for (Block b = (Block) tryBlock.getNextSiblingOrNull(); b != null; b = (Block) b.getNextSiblingOrNull()) {
                Instruction firstInsn = b.getFirstChild();

                if (!(firstInsn instanceof Branch) && !(firstInsn instanceof Return)) break;
                // Only blocks reachable exclusively from inside the try body belong to it.
                if (!b.getBranches().allMatch(e -> e.isDescendantOf(tryBody))) break;

                blocksToMove.add(b);
            }

            if (blocksToMove.isEmpty()) return;

            ctx.pushStep(tryBody.getEntryPoint().getName());
            BlockContainer parentContainer = (BlockContainer) tryBlock.getParent();
            TransformerUtils.moveBlocksIntoContainer(blocksToMove, parentContainer, tryBody, null);
            ctx.popStep();
        });
    }

    /**
     * Runs {@link #captureCatchBlocks(BlockContainer)} on every handler body in the method.
     *
     * @param function The method to process.
     * @param ctx      The transform context.
     */
    private static void captureBlocks(MethodDecl function, MethodTransformContext ctx) {
        for (TryCatchHandler handler : function.descendantsToList(TryCatchHandler.class)) {
            ctx.pushStep("Capture blocks " + handler.getBody().getEntryPoint().getName());
            captureCatchBlocks(handler.getBody());
            ctx.popStep();
        }
    }

    /**
     * Capture the catch dominator tree.
     * <p>
     * The handler's entry block must consist of a single branch, and that branch's target (the real
     * catch entry) must have no more than one incoming edge. The real catch entry must also live in the
     * container that encloses the handler. Starting at the real catch entry, consecutive sibling blocks
     * that it dominates are moved into the handler's container.
     *
     * @param catchContainer The BlockContainer to traverse.
     */
    private static void captureCatchBlocks(BlockContainer catchContainer) {
        Block catchEntry = catchContainer.getEntryPoint();
        if (!(catchEntry.instructions.only() instanceof Branch entryBranch)) return;

        Block realCatchEntry = entryBranch.getTargetBlock();
        // Block is not dominated by the catch handler. Cannot be inlined
        if (realCatchEntry.getIncomingEdgeCount() > 1) return;

        var realCatchContainer = (BlockContainer) realCatchEntry.getParent();
        if (realCatchContainer != catchContainer.getParent().firstAncestorOfType(BlockContainer.class)) return;

        ControlFlowGraph graph = new ControlFlowGraph(realCatchContainer);
        ControlFlowNode node = graph.getNode(realCatchEntry);

        // Collect the contiguous run of sibling blocks dominated by the real catch entry.
        LinkedList<Block> blocks = new LinkedList<>();
        Block block = realCatchEntry;
        while (block != null && node.dominates(graph.getNode(block))) {
            blocks.add(block);
            block = (Block) block.getNextSiblingOrNull();
        }

        TransformerUtils.moveBlocksIntoContainer(blocks, graph.container, catchContainer, null);
    }

    /**
     * Converts every handler flagged {@code isUnprocessedFinally} into a try-finally, one at a time.
     * <p>
     * For each such handler:
     * <ol>
     *     <li>Locate the finally body (the target of the handler's first branch) and the
     *     {@code STORE s = LOAD ex; THROW s} endpoint that rethrows the exception.</li>
     *     <li>For each block the try body exits to, check that it starts with a semantically matching
     *     copy of the finally body (throwing {@link IllegalStateException} if not). Delete that copy and
     *     rewrite the exits to skip it.</li>
     *     <li>Clear the handlers, install a new finally body, replace the rethrow with a
     *     {@link Leave}, and move the finally blocks into it.</li>
     * </ol>
     * The method is re-scanned after each conversion.
     *
     * @param function The method to process.
     * @param ctx      The transform context.
     */
    private static void transformFinallyBlocks(MethodDecl function, MethodTransformContext ctx) {
        // While true, since we only want to handle a single try at a time, then re-scan.
        while (true) {
            // Find any finally blocks.
            TryCatchHandler handler = findUnprocessedFinally(function);
            if (handler == null) break;

            TryCatch tryCatch = (TryCatch) handler.getParent();
            assert tryCatch.handlers.size() == 1;
            BlockContainer tryBody = tryCatch.getTryBody();
            BlockContainer parentContainer = (BlockContainer) tryCatch.getParent().getParent();

            BlockContainer handlerBodyContainer = handler.getBody();
            Block entryPoint = handlerBodyContainer.getEntryPoint();

            ctx.pushStep("Process finally " + entryPoint.getName());

            Branch finallyBranch = (Branch) entryPoint.instructions.first();
            Block finallyBody = finallyBranch.getTargetBlock();

            LocalVariable exVariable = handler.getVariable().variable;

            // We use this Set for 2 things.
            // - Remembering the blocks we have already visited when finding the throw block for the finally.
            // - Tracking which blocks are contained within the finally block.
            // LinkedHashSet is used to preserve the order of these blocks.
            LinkedHashSet<Block> finallyBlocks = new LinkedHashSet<>();

            // First instruction of the 'throw'.
            Store finallyEndpoint = findFinallyEndpoint(finallyBlocks, finallyBody, exVariable);

            // The finally throw block should only be missing if there are no loads of the exception variable.
            assert finallyEndpoint != null || exVariable.getLoadCount() == 0;

            // Find all the exits from the try body.
            // Exits preceded by a MONITOR_EXIT are first split into their own block outside the try.
            Iterable<Block> tryBodyExits = tryBody.descendantsOfType(Branch.class)
                    .filter(e -> !e.getTargetContainer().isDescendantOf(tryBody))
                    .map(e -> extractMonitorExitBlock(tryCatch, e, ctx))
                    .map(Branch::getTargetBlock)
                    .distinct()
                    .toLinkedList();

            // Match all the exits of the try body
            tryBodyExits.forEach(exitTarget -> {
                ctx.pushStep("Match try exit " + exitTarget.getName());
                SemanticMatcher matcher = new SemanticMatcher(finallyEndpoint);
                if (!matcher.equivalent(finallyBody, exitTarget)) {
                    // (╯°□°）╯︵ ┻━┻ We tried, someone broke their bytecode.
                    throw new IllegalStateException("Exit branch of Try container " + exitTarget.getName() + " does not pass through a semantically matching finally block: " + finallyBody.getName());
                }

                // Remember the endpoint and all matched blocks for the branch target.
                List<Block> matchedBlocks = new ArrayList<>(matcher.blockMap.values());

                // If we have a finally endpoint we must have a matched endpoint.
                assert finallyEndpoint == null || matcher.matchedEndpoint != null;

                ctx.pushStep("Delete duplicate finally code");

                // The block holding the matched endpoint is kept, minus the duplicated
                // instructions that precede the endpoint.
                Block exitBlock = null;
                if (matcher.matchedEndpoint != null) {
                    exitBlock = (Block) matcher.matchedEndpoint.getParent();
                    while (exitBlock.getFirstChild() != matcher.matchedEndpoint) {
                        exitBlock.getFirstChild().remove();
                    }
                }

                for (Block block : matchedBlocks) {
                    if (block != exitBlock && block.isConnected()) {
                        block.remove();
                    }
                }

                ctx.popStep();

                ctx.pushStep("Rewrite exits");
                rewriteExits(tryBody, exitTarget, exitBlock);
                ctx.popStep();

                if (exitBlock != null && exitBlock.getIncomingEdgeCount() == 1) {
                    tryInlineTryFinallyBodyExitBlock(tryCatch, exitBlock, finallyBody, ctx);
                }
                ctx.popStep();
            });

            // Generate the finally

            BlockContainer tryFinallyHandler = new BlockContainer();
            tryCatch.handlers.clear();
            tryCatch.setFinallyBody(tryFinallyHandler);

            if (finallyEndpoint != null) {
                // Replace endpoint block with leave of the finally
                // STORE s_tr = LOAD ex
                // THROW LOAD s_tr
                var next = finallyEndpoint.getNextSibling();
                assert next instanceof Throw;
                next.remove();
                finallyEndpoint.replaceWith(new Leave(tryFinallyHandler));
            }

            TransformerUtils.moveBlocksIntoContainer(finallyBlocks, parentContainer, tryFinallyHandler, null);

            ctx.popStep();
        }
    }

    /**
     * Depth-first, pre-order search for the first handler flagged {@code isUnprocessedFinally}.
     *
     * @param insn The instruction to search from (inclusive).
     * @return The first matching handler, or {@code null} if there is none.
     */
    @Nullable
    private static TryCatchHandler findUnprocessedFinally(Instruction insn) {
        if (insn instanceof TryCatch tc) {
            for (TryCatchHandler handler : tc.handlers) {
                if (handler.isUnprocessedFinally) {
                    return handler;
                }
            }
        }

        for (Instruction c = insn.getFirstChildOrNull(); c != null; c = c.getNextSiblingOrNull()) {
            TryCatchHandler r = findUnprocessedFinally(c);
            if (r != null) {
                return r;
            }
        }

        return null;
    }

    /**
     * Redirects the try body's branches to {@code tryBodyExit} so that the now-deleted finally copy is
     * skipped.
     * <p>
     * If {@code tryBodyExit} has several incoming branches, they are first funnelled through a new
     * {@code finally_entry} block inside the try body, so there is exactly one edge to rewrite. That edge
     * then becomes a {@link Leave} of the try body when {@code matchedEndpoint} is {@code null}, or a
     * branch to {@code matchedEndpoint} otherwise.
     *
     * @param tryBody         The try body being processed.
     * @param tryBodyExit     The block the try body exited to (the start of the duplicated finally).
     * @param matchedEndpoint The block holding the code that followed the finally copy, if any.
     */
    private static void rewriteExits(BlockContainer tryBody, Block tryBodyExit, @Nullable Block matchedEndpoint) {
        Instruction edge;
        if (tryBodyExit.getIncomingEdgeCount() == 1) {
            edge = tryBodyExit.getBranches().only();
        } else {
            // since the presence of a finally copy indicates a unique exit from the tryBody
            // we should make a single point for all the branches to this copy to merge within the try
            var b = new Block(tryBodyExit.getSubName("finally_entry"));
            b.setOffsets(tryBodyExit);

            TransformerUtils.moveBlocksIntoContainer(List.of(b), null, tryBody, null);

            tryBodyExit.getBranches().toList().forEach(e -> e.replaceWith(new Branch(b)));

            b.instructions.add(edge = new Branch(tryBodyExit));
        }

        if (matchedEndpoint == null) {
            edge.replaceWith(new Leave(tryBody));
            // TODO, if there's a dead synthetic local store right before this exit, merge it into a return insn right here
        } else {
            edge.replaceWith(new Branch(matchedEndpoint));
        }
    }

    /**
     * Attempts to move a simple return block that follows a finally copy into the try body, right after
     * the block that branches to it.
     * <p>
     * The block is moved only when all of these hold:
     * <ul>
     *     <li>It has a single incoming branch.</li>
     *     <li>Its bytecode offset is not greater than that of the finally body.</li>
     *     <li>It has the shape {@code [STORE s1 = ...;] RETURN [LOAD s1]}.</li>
     *     <li>The innermost unprocessed-finally try around its predecessor is {@code processingFinally}.</li>
     * </ul>
     * If {@code block} starts with a branch, that branch is first hoisted into its predecessor,
     * {@code block} is removed, and the branch target is considered instead.
     *
     * @param processingFinally The try whose finally handler is currently being processed.
     * @param block             The exit block; expected to have exactly one incoming edge.
     * @param finallyBody       The entry block of the finally handler body.
     * @param ctx               The transform context.
     */
    private static void tryInlineTryFinallyBodyExitBlock(TryCatch processingFinally, Block block, Block finallyBody, MethodTransformContext ctx) {
        if (block.getFirstChild() instanceof Branch br) {
            block.getBranches().only().replaceWith(br);
            block.remove();
            block = br.getTargetBlock();
        }

        // After skipping a branch-only block, the new target may have other predecessors.
        // Such a block cannot be attributed to a single try body, and getBranches().only() below needs one edge.
        if (block.getIncomingEdgeCount() != 1) return;

        // Skip exits placed after the finally handler body: only exits at or before its offset are moved.
        if (block.getBytecodeOffset() > finallyBody.getBytecodeOffset()) return;

        // Only simple return blocks are safe to inline.
        // J11TryWithResources relies on matching resource closure exits outside the try body,
        // as this is the only indicator that they are compiler-generated.

        // Match the following:
        // [STORE (LOCAL s1, ...)]
        // RETURN [s1]
        if (!(block.getLastChild() instanceof Return ret)) return;
        Instruction prev = ret.getPrevSiblingOrNull();
        if (prev != null) {
            if (!(prev instanceof Store store)) return;
            if (matchLoadLocal(ret.getValue(), store.getVariable()) == null) return;
            if (block.getFirstChild() != store) return;
        }

        Branch branch = block.getBranches().only();
        Block fromBlock = branch.firstAncestorOfType(Block.class);
        TryCatch containingFinally = fromBlock.ancestorsOfType(TryCatch.class)
                .filter(tc -> !tc.handlers.isEmpty() && tc.handlers.only().isUnprocessedFinally)
                .first();
        if (containingFinally != processingFinally) return;

        ctx.pushStep("Move synthetic return block " + block.getName() + " into try body");
        fromBlock.insertAfter(block);
        ctx.popStep();
    }

    /**
     * Splits a trailing {@code STORE; MONITOR_EXIT; BRANCH exit} sequence out of a block that sits
     * directly in {@code tryCatch}'s try body.
     * <p>
     * The sequence is moved into a new block placed after the block containing {@code tryCatch}. The
     * original block then ends with a branch to that new block. If the original block is left holding
     * only that branch, it is moved to just after the block that contains its highest-offset incoming
     * branch, provided that block is in a different container.
     *
     * @param tryCatch The try being processed.
     * @param exit     A branch leaving the try body.
     * @return The branch now leaving the original block: the new branch if a split happened,
     * otherwise {@code exit}.
     */
    private static Branch extractMonitorExitBlock(TryCatch tryCatch, Branch exit, MethodTransformContext ctx) {
        // Match the following case:
        // STORE s_0(LOAD var_0)
        // MONITOR_EXIT(LOAD s_0)
        // BRANCH exit
        Instruction prev = exit.getPrevSiblingOrNull();
        if (!(prev instanceof MonitorExit)) return exit;

        // Only handle blocks that are direct children of the try body container.
        Block block = (Block) exit.getParent();
        if (block.getParent().getParent() != tryCatch) return exit;

        ctx.pushStep("Split MONITOR_EXIT block out of try");
        // Split the STORE, MONITOR_EXIT and BRANCH instructions into a new block.
        Block split = block.extractRange(exit.getPrevSibling().getPrevSibling(), exit);
        // Double-check that we actually captured a STORE as the first instruction in the split block.
        assert split.getFirstChild() instanceof Store;

        // Insert branch at the end of the old block to the new split block.
        Branch newBranch = new Branch(split);
        block.instructions.add(newBranch);

        // Insert split block into parent block container.
        Block tryCatchBlock = (Block) tryCatch.getParent();
        tryCatchBlock.insertAfter(split);
        // synchronized statements put the monitor exit block at the end of the try-finally
        //   in cases where all we're left with is a single branch, we should move it to the same container as its incoming branches
        //   as its current location (outside any nested try bodies) doesn't represent a logical control flow point
        //
        //   Note, we can't just skip splitting the block (or equivalently, inline the branch to branch), because it may have multiple incoming branches
        if (block.instructions.size() == 1) {
            Block containingLastExit = block.getBranches().map(e -> e.firstAncestorOfType(Block.class)).maxBy(Instruction::getBytecodeOffset);
            if (containingLastExit.getParent() != block.getParent()) {
                ctx.pushStep("Move block");
                containingLastExit.insertAfter(block);
                ctx.popStep();
            }
        }
        ctx.popStep();

        return newBranch;
    }

    /**
     * Traverse the entire reachability tree of the given block, ending in a throw of the exception variable.
     * <p>
     * A block matches when it ends with {@code STORE synVar = LOAD exceptionVariable; THROW LOAD synVar}.
     * Otherwise, the targets of every branch inside the block are searched recursively. Because
     * {@code visited} is shared, each block is visited only once. Distinct results from the branches
     * are reduced with {@code onlyOrDefault()}. NOTE(review): this presumably yields {@code null} when
     * no single endpoint is found; confirm against FastStream's contract.
     *
     * @param visited           The blocks that were visited whilst finding the throw block. This will contain the throw block itself.
     * @param block             The block to search from. This will be included in the visited set.
     * @param exceptionVariable The {@link LocalVariable} storing the finally exception.
     * @return The first instruction of the finally throw sequence, or {@code null} if none was found
     * (including when {@code block} was already visited).
     */
    @Nullable
    private static Store findFinallyEndpoint(Set<Block> visited, Block block, LocalVariable exceptionVariable) {
        // Match the following case:
        // STORE synVar LOAD exceptionVariable
        // THROW LOAD synVar
        if (!visited.add(block)) return null;

        if (block.getLastChild() instanceof Throw thr) {
            Store store = LoadStoreMatching.matchStoreLocalLoadLocal(thr.getPrevSiblingOrNull(), exceptionVariable);
            if (store != null) {
                assert LoadStoreMatching.matchLoadLocal(thr.getArgument(), store.getVariable()) != null;
                return store;
            }
        }

        return block.descendantsOfType(Branch.class)
                .map(br -> findFinallyEndpoint(visited, br.getTargetBlock(), exceptionVariable))
                .filter(Objects::nonNull)
                .distinct()
                .onlyOrDefault();
    }
}
