package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.InsnOpcode;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.SemanticMatcher;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.flow.ControlFlowGraph;
import net.covers1624.coffeegrinder.bytecode.flow.ControlFlowNode;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.TryCatch.TryCatchHandler;
import net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching;
import net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching;
import net.covers1624.coffeegrinder.bytecode.matching.TryCatchMatching;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformer;
import net.covers1624.coffeegrinder.util.None;
import net.covers1624.quack.collection.ColUtils;
import org.jetbrains.annotations.Nullable;

import java.util.*;

import static net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching.matchBranch;
import static net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching.matchReturn;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchLoadLocal;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchStore;

/**
 * Created by covers1624 on 31/5/21.
 */
public class TryCatches implements MethodTransformer {

    /**
     * Runs the try/catch clean-up passes over the given method, in this order:
     * <ol>
     *     <li>Inline each handler's stack-slot exception variable into the local it is stored to.</li>
     *     <li>Convert handlers flagged as unprocessed finally blocks into {@link TryFinally} instructions.</li>
     *     <li>Merge directly nested try-catches into their outer try-catch.</li>
     *     <li>Move blocks that directly follow a try, start with a branch/return and are only branched to
     *     from inside the try body, into that try body.</li>
     *     <li>Move the blocks dominated by each catch handler's real entry into the handler's container.</li>
     * </ol>
     *
     * @param function The method to transform.
     * @param ctx      The transform context, used to record transformation steps.
     */
    @Override
    public void transform(MethodDecl function, MethodTransformContext ctx) {
        ctx.pushStep("Inline handler variable stores");
        inlinedHandlerVariables(function, ctx);
        ctx.popStep();

        ctx.pushStep("Convert Finally Handlers");
        transformFinallyBlocks(function, ctx);
        ctx.popStep();

        // combine handlers here
        ctx.pushStep("Combine handlers");
        combineHandlers(function, ctx);
        ctx.popStep();

        ctx.pushStep("Capture try exit blocks");
        captureTryBodyExitBlocks(function, ctx);
        ctx.popStep();

        ctx.pushStep("Capture handler blocks");
        captureBlocks(function, ctx);
        ctx.popStep();
    }

    /**
     * Replaces each catch handler's stack-slot exception variable with the local variable it is stored into.
     * <p>
     * Every handler body is expected (and asserted) to be a single block holding a single branch. The branch target
     * must start with {@code STORE(handlerLocalVar, LOAD handlerStackVar)}. That store is removed, and the handler is
     * re-pointed at the local, which inherits the stack variable's type. If the target block is left with a single
     * instruction, that instruction replaces the handler's branch and the now-empty target block is removed.
     *
     * @param function The method to process.
     * @param ctx      The transform context.
     */
    public static void inlinedHandlerVariables(MethodDecl function, MethodTransformContext ctx) {
        // all catch blocks should branch to STORE(handlerLocalVar, handlerStackVar)
        // take this store and set handlerStackVar -> handlerLocalVAR
        // if the block is now empty, inline it.

        function.<TryCatchHandler>descendantsToList(InsnOpcode.TRY_CATCH_HANDLER)
                .forEach(handler -> {
                    LocalReference handlerVarDecl = handler.getVariable();
                    LocalVariable handlerVar = handlerVarDecl.variable;
                    assert handlerVar.getKind() == LocalVariable.VariableKind.STACK_SLOT;
                    assert handlerVar.getLoadCount() == 1;

                    BlockContainer body = handler.getBody();
                    Block bodyBlock = body.blocks.only();
                    Branch firstBranch = matchBranch(bodyBlock.instructions.only());
                    assert firstBranch != null;

                    Block target = firstBranch.getTargetBlock();
                    assert target.getIncomingEdgeCount() == 1;

                    Store varStore = LoadStoreMatching.matchStoreLocalLoadLocal(target.getFirstChild(), handlerVar);
                    assert varStore != null;

                    LocalVariable handlerLocalVar = varStore.getVariable();

                    ctx.pushStep("Inline handler variable " + handlerLocalVar.getUniqueName());
                    handlerLocalVar.setType(handlerVar.getType());
                    handlerVarDecl.replaceWith(new LocalReference(handlerLocalVar));
                    varStore.remove();

                    if (target.instructions.size() == 1) { // remove empty block
                        firstBranch.replaceWith(target.getFirstChild());
                        target.remove();
                    }

                    ctx.popStep();
                });
    }

    /**
     * Merges nested try-catches into their outer try-catch.
     * <p>
     * A try-catch qualifies when its try body is a single block whose only instruction is another try-catch. The inner
     * try-catch's first handler is prepended to the outer's handlers, and the outer adopts the inner's try body. This
     * repeats until the outer no longer qualifies. Nested try-catches are then visited normally.
     *
     * @param function The function to merge handlers in.
     * @param ctx      The context.
     */
    private static void combineHandlers(MethodDecl function, MethodTransformContext ctx) {
        function.accept(new SimpleInsnVisitor<MethodTransformContext>() {
            @Override
            public None visitTryCatch(TryCatch tryCatch, MethodTransformContext ctx) {
                // Continually process this try until we don't find any mergeable nested trys.
                while (true) {
                    BlockContainer container = tryCatch.getTryBody();
                    if (container.blocks.size() != 1) break;

                    Block block = container.blocks.first();
                    if (block.instructions.size() != 1) break;

                    TryCatch inner = TryCatchMatching.matchTryCatch(block.instructions.first());
                    if (inner == null) break;

                    // We have found a nested TryCatch, with no extra blocks inside the outer TryCatch.
                    // We can safely inline this.

                    ctx.pushStep("Combine " + block.getName() + " into " + ((Block) tryCatch.getParent()).getName());

                    // Add the handler to the beginning to preserve handler order.
                    // As we process in pre-order, there should only ever be a single handler in the inner.
                    tryCatch.handlers.addFirst(inner.handlers.first());

                    // Replace the body with that of the inners.
                    tryCatch.setTryBody(inner.getTryBody());
                    ctx.popStep();
                }

                return super.visitTryCatch(tryCatch, ctx);
            }
        }, ctx);
    }

    /**
     * Moves blocks that directly follow a try-catch's block into that try-catch's body.
     * <p>
     * Starting at the sibling after the try-catch's block, consecutive blocks are collected while each one starts with
     * a {@code BRANCH} or {@code RETURN} and all of its incoming branches come from inside the try body. Try-catches are
     * processed in reverse descendant order (presumably so inner ones are handled before outer ones).
     *
     * @param function The method to process.
     * @param ctx      The transform context.
     */
    private static void captureTryBodyExitBlocks(MethodDecl function, MethodTransformContext ctx) {
        function.<TryCatch>descendantsOfType(InsnOpcode.TRY_CATCH).reversed().forEach(tryCatch -> {
            BlockContainer tryBody = tryCatch.getTryBody();
            Block tryBlock = (Block) tryCatch.getParent();

            List<Block> blocksToMove = new LinkedList<>();
            for (Block b = (Block) tryBlock.getNextSiblingOrNull(); b != null; b = (Block) b.getNextSiblingOrNull()) {
                Instruction firstInsn = b.getFirstChild();

                if (firstInsn.opcode != InsnOpcode.BRANCH && firstInsn.opcode != InsnOpcode.RETURN) break;
                if (!b.getBranches().allMatch(e -> e.isDescendantOf(tryBody))) break;

                blocksToMove.add(b);
            }

            if (blocksToMove.isEmpty()) return;

            ctx.pushStep(tryBody.getEntryPoint().getName());
            BlockContainer parentContainer = (BlockContainer) tryBlock.getParent();
            TransformerUtils.moveBlocksIntoContainer(blocksToMove, parentContainer, tryBody, null);
            ctx.popStep();
        });
    }

    /**
     * Runs {@link #captureCatchBlocks(BlockContainer)} for every try-catch handler in the method.
     *
     * @param function The method to process.
     * @param ctx      The transform context.
     */
    private static void captureBlocks(MethodDecl function, MethodTransformContext ctx) {
        for (TryCatchHandler handler : function.<TryCatchHandler>descendantsToList(InsnOpcode.TRY_CATCH_HANDLER)) {
            ctx.pushStep("Capture blocks " + handler.getBody().getEntryPoint().getName());
            captureCatchBlocks(handler.getBody());
            ctx.popStep();
        }
    }

    /**
     * Capture the catch dominator tree.
     * <p>
     * The handler's entry block must consist of a single branch to the real handler entry. That entry must have only
     * one incoming edge. When both hold, the run of consecutive sibling blocks starting at the real entry that the
     * entry dominates is moved into the handler's container. The run stops at the first block that is not dominated.
     *
     * @param catchContainer The BlockContainer to traverse.
     */
    private static void captureCatchBlocks(BlockContainer catchContainer) {
        Block catchEntry = catchContainer.getEntryPoint();
        Branch entryBranch = matchBranch(catchEntry.instructions.only());
        if (entryBranch == null) return;

        Block realCatchEntry = entryBranch.getTargetBlock();
        // Block is not dominated by the catch handler. Cannot be inlined
        if (realCatchEntry.getIncomingEdgeCount() > 1) return;

        ControlFlowGraph graph = new ControlFlowGraph((BlockContainer) realCatchEntry.getParent());
        ControlFlowNode node = graph.getNode(realCatchEntry);

        LinkedList<Block> blocks = new LinkedList<>();
        Block block = realCatchEntry;
        while (block != null && node.dominates(graph.getNode(block))) {
            blocks.add(block);
            block = (Block) block.getNextSiblingOrNull();
        }

        TransformerUtils.moveBlocksIntoContainer(blocks, graph.container, catchContainer, null);
    }

    /**
     * Converts every try-catch whose handler is flagged as an unprocessed finally into a {@link TryFinally}.
     * <p>
     * For each such try-catch:
     * <ol>
     *     <li>Find the finally handler's rethrow sequence and the blocks reachable before it.</li>
     *     <li>Check that every exit of the try body passes through a semantically matching inlined copy of the finally
     *     body, then delete that copy and rewire the exit.</li>
     *     <li>Replace the try-catch with a {@link TryFinally} and move the finally blocks into its handler. The rethrow
     *     sequence is replaced with a {@link Leave} of the new handler.</li>
     * </ol>
     * The method is re-scanned after each conversion.
     *
     * @param function The method to process.
     * @param ctx      The transform context.
     * @throws IllegalStateException If a try body exit does not pass through a matching finally copy.
     */
    private static void transformFinallyBlocks(MethodDecl function, MethodTransformContext ctx) {
        // While true, since we only want to handle a single try at a time, then re-scan.
        while (true) {
            // Find any finally blocks.
            TryCatchHandler handler = findUnprocessedFinally(function);
            if (handler == null) break;

            TryCatch tryCatch = (TryCatch) handler.getParent();
            assert tryCatch.handlers.size() == 1;
            BlockContainer tryBody = tryCatch.getTryBody();
            BlockContainer parentContainer = (BlockContainer) tryCatch.getParent().getParent();

            BlockContainer handlerBodyContainer = handler.getBody();
            Block entryPoint = handlerBodyContainer.getEntryPoint();

            ctx.pushStep("Process finally " + entryPoint.getName());

            Branch finallyBranch = matchBranch(entryPoint.instructions.first());
            assert finallyBranch != null;
            Block finallyBody = finallyBranch.getTargetBlock();

            LocalVariable exVariable = handler.getVariable().variable;

            // We use this Set for 2 things.
            // - Remembering the blocks we have already visited when finding the throw block for the finally.
            // - Tracking which blocks are contained within the finally block.
            // LinkedHashSet is used to preserve the order of these blocks.
            LinkedHashSet<Block> finallyBlocks = new LinkedHashSet<>();

            // First instruction of the 'throw'.
            Store finallyEndpoint = findFinallyEndpoint(finallyBlocks, finallyBody, exVariable);

            // The finally throw block should only be missing if there are no loads of the exception variable.
            assert finallyEndpoint != null || exVariable.getLoadCount() == 0;

            // Find all the exits from the try body.
            // The branches are collected up front because extractMonitorExitBlock splits blocks and appends branches
            // inside the try body; collecting first avoids traversing the tree while it is being modified.
            List<Branch> exitBranches = tryBody.<Branch>descendantsOfType(InsnOpcode.BRANCH)
                    .filter(e -> !e.getTargetContainer().isDescendantOf(tryBody))
                    .toList();
            // Insertion-ordered set: distinct exit targets, in the order their branches were found.
            Set<Block> tryBodyExits = new LinkedHashSet<>();
            for (Branch exitBranch : exitBranches) {
                tryBodyExits.add(extractMonitorExitBlock(tryCatch, exitBranch, ctx));
            }

            // Match all the exits of the try body
            for (Block exitTarget : tryBodyExits) {
                ctx.pushStep("Match try exit " + exitTarget.getName());
                SemanticMatcher matcher = new SemanticMatcher(finallyEndpoint);
                if (!matcher.equivalent(finallyBody, exitTarget)) {
                    // (╯°□°）╯︵ ┻━┻ We tried, someone broke their bytecode.
                    throw new IllegalStateException("Exit branch of Try container " + exitTarget.getName() + " does not pass through a semantically matching finally block: " + finallyBody.getName());
                }

                // Remember the endpoint and all matched blocks for the branch target.
                List<Block> blocksToRemove = new ArrayList<>(matcher.blockMap.values());

                // If we have a finally endpoint we must have a matched endpoint.
                assert finallyEndpoint == null || matcher.matchedEndpoint != null;

                ctx.pushStep("Rewrite exits");
                Block newExit = rewriteExits(tryBody, exitTarget, matcher.matchedEndpoint);
                ctx.popStep();

                ctx.pushStep("Delete duplicate blocks");
                for (Block block : blocksToRemove) {
                    // The rewritten exit target is kept; anything already detached is skipped.
                    if (block != newExit && block.isConnected()) {
                        block.remove();
                    }
                }
                ctx.popStep();

                if (newExit != null) {
                    tryInlineTryFinallyBodyExitBlock(tryCatch, newExit, finallyBody, ctx);
                }
                ctx.popStep();
            }

            // Generate the finally

            BlockContainer tryFinallyHandler = new BlockContainer();
            tryCatch.replaceWith(new TryFinally(tryBody, tryFinallyHandler).withOffsets(tryCatch));

            if (finallyEndpoint != null) {
                // Replace endpoint block with leave of the finally to preserve fallthrough exits.
                Block finallyEndpointBlock = getOrSplitBlockStartsAt(finallyEndpoint);
                finallyBlocks.add(finallyEndpointBlock);
                finallyEndpointBlock.instructions.clear();
                finallyEndpointBlock.instructions.add(new Leave(tryFinallyHandler));
            }

            TransformerUtils.moveBlocksIntoContainer(finallyBlocks, parentContainer, tryFinallyHandler, null);

            ctx.popStep();
        }
    }

    /**
     * Finds the first handler, in a depth-first pre-order walk starting at {@code insn}, flagged as an unprocessed
     * finally.
     *
     * @param insn The instruction to start searching from (inclusive).
     * @return The handler, or {@code null} if none exists.
     */
    @Nullable
    private static TryCatchHandler findUnprocessedFinally(Instruction insn) {
        if (insn.opcode == InsnOpcode.TRY_CATCH) {
            TryCatch tc = (TryCatch) insn;
            for (TryCatchHandler handler : tc.handlers) {
                if (handler.isUnprocessedFinally) {
                    return handler;
                }
            }
        }

        for (Instruction c = insn.getFirstChildOrNull(); c != null; c = c.getNextSiblingOrNull()) {
            TryCatchHandler r = findUnprocessedFinally(c);
            if (r != null) {
                return r;
            }
        }

        return null;
    }

    /**
     * Splits and rewrites matched exits of a specific exit target in a try body.
     *
     * @param tryBody         The try body being processed.
     * @param tryBodyExit     The exit target whose incoming branches are rewritten.
     * @param matchedEndpoint The instruction in the exit path matching the finally endpoint, or {@code null} if the
     *                        finally has no endpoint.
     * @return {@code null} when {@code matchedEndpoint} is {@code null}; in that case every branch to the exit is
     * replaced with a {@link Leave} of the try body. Otherwise, the block starting at the matched endpoint, which now
     * receives those branches instead.
     */
    @Nullable
    private static Block rewriteExits(BlockContainer tryBody, Block tryBodyExit, @Nullable Instruction matchedEndpoint) {
        if (matchedEndpoint == null) {
            tryBodyExit.getBranches().toList().forEach(e -> e.replaceWith(new Leave(tryBody)));
            // TODO, if there's a dead synthetic local store right before this exit, merge it into a return insn right here
            return null;
        }

        Block matchedEndpointBlock = getOrSplitBlockStartsAt(matchedEndpoint);

        for (Branch exit : tryBodyExit.getBranches().toList()) {
            exit.replaceWith(new Branch(matchedEndpointBlock));
        }
        return matchedEndpointBlock;
    }

    /**
     * Returns a block that starts at {@code start}, splitting its parent block if necessary.
     * <p>
     * If {@code start} is a forward (non back-edge) branch, its target block is returned directly. If {@code start}
     * already begins its block, that block is returned. Otherwise the tail of the block, from {@code start} onwards,
     * is extracted into a new block. That new block is inserted after the original, which gets a branch to it.
     *
     * @param start The instruction the returned block must begin with.
     * @return The block starting at {@code start} (or at the forward branch target).
     */
    private static Block getOrSplitBlockStartsAt(Instruction start) {
        // TODO, can we remove this entirely? Is this just making more work elsewhere?
        //  Does this branch to branch occur predictably in the scenarios where we need to remove it?
        if (start.opcode == InsnOpcode.BRANCH) {
            Block jmpBlock = ((Branch) start).getTargetBlock();
            if (jmpBlock.getBytecodeOffset() >= start.getBytecodeOffset()) {
                // simplify a branch to branch as the actual exit
                // except when the branch is a back-edge (to a loop head) because then it's relevant as an indicator of fall-through continue
                return jmpBlock;
            }
        }

        Block block = (Block) start.getParent();
        if (block.getFirstChild() == start) {
            return block;
        }

        Block splitBlock = block.extractRange(start, block.getLastChild());
        block.instructions.add(new Branch(splitBlock));
        block.insertAfter(splitBlock);
        return splitBlock;
    }

    /**
     * Attempts to move a simple synthetic return block into the try body of the finally being processed.
     * <p>
     * The block is moved after the highest-offset block containing one of its incoming branches. This happens only
     * when all of the following hold:
     * <ul>
     *     <li>The block does not come after the finally body in bytecode.</li>
     *     <li>All of its incoming branches are inside {@code processingFinally}.</li>
     *     <li>It is either a single {@code RETURN}, or {@code STORE s1; RETURN LOAD s1}.</li>
     * </ul>
     *
     * @param processingFinally The try-catch whose finally is currently being converted.
     * @param block             The rewritten exit block.
     * @param finallyBody       The first block of the finally handler body.
     * @param ctx               The transform context.
     */
    private static void tryInlineTryFinallyBodyExitBlock(TryCatch processingFinally, Block block, Block finallyBody, MethodTransformContext ctx) {
        // The exit must not be positioned after the finally handler body: an exit with a higher bytecode offset
        // than the handler body is left where it is.
        if (block.getBytecodeOffset() > finallyBody.getBytecodeOffset()) return;

        // All incoming branches must be inside the finally we are processing.
        if (!ColUtils.allMatch(block.getBranches(), br -> getContainingFinally(br) == processingFinally)) {
            return;
        }

        // Only simple return blocks are safe to inline.
        // J11TryWithResources relies on matching resource closure exits outside the try body,
        // as this is the only indicator that they are compiler-generated.

        // Match the following:
        // [STORE (LOCAL s1, ...)]
        // RETURN [s1]
        Return ret = matchReturn(block.getLastChild());
        if (ret == null) return;
        Instruction prev = ret.getPrevSiblingOrNull();
        if (prev != null) {
            Store store = matchStore(prev);
            if (store == null) return;
            if (matchLoadLocal(ret.getValue(), store.getVariable()) == null) return;
            if (block.getFirstChild() != store) return;
        }

        // Exit gets inlined after the block with the highest bytecode offset.
        Block highestBlock = block.getBranches()
                .map(br -> br.<Block>firstAncestorOfType(InsnOpcode.BLOCK))
                .maxBy(Instruction::getBytecodeOffset);

        ctx.pushStep("Move synthetic return block " + block.getName() + " into try body");
        highestBlock.insertAfter(block);
        ctx.popStep();
    }

    /**
     * Returns the first ancestor try-catch of {@code br} whose handler is flagged as an unprocessed finally.
     * <p>
     * NOTE(review): {@code handlers.only()} requires each ancestor try-catch to have exactly one handler at this
     * point in the pipeline; this appears to hold because handler combining runs after finally conversion.
     *
     * @param br The branch to inspect.
     * @return The containing unprocessed-finally try-catch.
     */
    private static TryCatch getContainingFinally(Branch br) {
        return br.<TryCatch>ancestorsOfType(InsnOpcode.TRY_CATCH)
                .filter(tc -> tc.handlers.only().isUnprocessedFinally)
                .first();
    }

    /**
     * Resolves the block that should be matched against the finally body for the given try body exit.
     * <p>
     * Normally this is the exit branch's target. If the exit is the tail of a
     * {@code STORE s_0(LOAD var_0); MONITOR_EXIT(LOAD s_0); BRANCH exit} sequence, the sequence is treated as the exit
     * instead. This only applies to a block directly inside {@code tryCatch}'s body. If the sequence already forms
     * the whole block, that block is returned. Otherwise the sequence is split into a new block placed after the
     * try-catch's block, the original block branches to it, and the new block is returned.
     *
     * @param tryCatch The try-catch being processed.
     * @param exit     A branch leaving the try body.
     * @param ctx      The transform context.
     * @return The exit block to match against the finally body.
     */
    private static Block extractMonitorExitBlock(TryCatch tryCatch, Branch exit, MethodTransformContext ctx) {
        // Match the following case:
        // STORE s_0(LOAD var_0)
        // MONITOR_EXIT(LOAD s_0)
        // BRANCH exit
        Instruction prev = exit.getPrevSiblingOrNull();
        if (prev == null || prev.opcode != InsnOpcode.MONITOR_EXIT) return exit.getTargetBlock();
        Block block = (Block) exit.getParent();
        if (block.getParent().getParent() != tryCatch) return exit.getTargetBlock();

        Instruction monitorStore = prev.getPrevSibling();

        // If the block containing the MONITOR_EXIT is by itself, we can just match this block directly.
        // This is required for cases where there is another try with the actual return nested inside the synchronized block.
        if (block.getFirstChild() == monitorStore) {
            return block;
        }

        ctx.pushStep("Split MONITOR_EXIT block out of try");
        // Split the STORE, MONITOR_EXIT and BRANCH instructions into a new block.
        Block split = block.extractRange(monitorStore, exit);
        // Double-check that we actually captured a STORE as the first instruction in the split block.
        assert matchStore(split.getFirstChild()) != null;

        // Insert branch at the end of the old block to the new split block.
        block.instructions.add(new Branch(split));

        // Insert split block into parent block container.
        Block tryCatchBlock = (Block) tryCatch.getParent();
        tryCatchBlock.insertAfter(split);
        ctx.popStep();

        return split;
    }

    /**
     * Traverse the entire reachability tree of the given block, ending in a throw of the exception variable.
     *
     * @param visited           The blocks that were visited whilst finding the throw block. This will contain the throw block itself.
     * @param block             The block to search from. This will be included in the visited set.
     * @param exceptionVariable The {@link LocalVariable} storing the finally exception.
     * @return The first instruction of the finally throw sequence, or {@code null} if {@code block} was already
     * visited or no single distinct endpoint is reachable from it.
     */
    @Nullable
    private static Store findFinallyEndpoint(Set<Block> visited, Block block, LocalVariable exceptionVariable) {
        // Match the following case:
        // STORE synVar LOAD exceptionVariable
        // THROW LOAD synVar
        if (!visited.add(block)) return null;

        Throw thr = BranchLeaveMatching.matchThrow(block.getLastChild());
        if (thr != null) {
            Store store = LoadStoreMatching.matchStoreLocalLoadLocal(thr.getPrevSiblingOrNull(), exceptionVariable);
            if (store != null) {
                assert LoadStoreMatching.matchLoadLocal(thr.getArgument(), store.getVariable()) != null;
                return store;
            }
        }

        return block.<Branch>descendantsOfType(InsnOpcode.BRANCH)
                .map(br -> findFinallyEndpoint(visited, br.getTargetBlock(), exceptionVariable))
                .filter(Objects::nonNull)
                .distinct()
                .onlyOrDefault();
    }
}
