package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.flow.ControlFlowGraph;
import net.covers1624.coffeegrinder.bytecode.flow.ControlFlowNode;
import net.covers1624.coffeegrinder.bytecode.flow.Dominance;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching;
import net.covers1624.coffeegrinder.bytecode.transform.BlockTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.BlockTransformer;
import net.covers1624.quack.collection.FastStream;
import org.jetbrains.annotations.Nullable;

import java.util.*;

import static net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching.matchLeave;

/**
 * Structures switch statements. When a block ends in a {@link SwitchTable}, the blocks making up the
 * switch body are moved into a new {@link BlockContainer}. The original block then ends in a
 * {@link Switch} instruction over that container, followed by a branch to the exit block if one was found.
 * <p>
 * Created by covers1624 on 9/8/21.
 */
public class SwitchDetection implements BlockTransformer {

    /**
     * Sentinel used by the exit-point searches to mean "no exit point was determined".
     * {@code extendSwitch} returns this node's {@code block} as the exit block. That field is presumably
     * {@code null} for a no-arg {@link ControlFlowNode} (TODO confirm). This is what lets
     * {@code transformSwitch} treat the switch as having no exit.
     */
    private static final ControlFlowNode NO_EXIT_POINT = new ControlFlowNode();

    // Context of the block currently being transformed. It is assigned at the start of every
    // transform(...) call and read by the helper methods below, instead of being passed as a parameter.
    @SuppressWarnings ("NotNullFieldNotInitialized")
    private BlockTransformContext ctx;

    /**
     * Restructures {@code block} into a {@link Switch} if its last instruction is a {@link SwitchTable}.
     * Any other block is left untouched.
     *
     * @param block the block being transformed
     * @param ctx   the transform context; stored in {@link #ctx} for use by the helper methods
     */
    @Override
    public void transform(Block block, BlockTransformContext ctx) {
        this.ctx = ctx;

        Instruction last = block.instructions.lastOrDefault();
        if (last instanceof SwitchTable) {
            transformSwitch(block, (SwitchTable) last);
        }
    }

    /**
     * Replaces the {@link SwitchTable} at the end of {@code block} with a structured {@link Switch}.
     * <p>
     * Steps:
     * <ul>
     *     <li>The switch body blocks are collected by {@code extendSwitch}.</li>
     *     <li>{@code block} itself is dropped from that list. A fresh entry block (named via
     *     {@code block.getSubName("switch")}) that holds the {@link SwitchTable} takes its place.</li>
     *     <li>The table's value becomes the value of a new {@link Switch}, which is appended to
     *     {@code block}. The table's own value is then replaced with a {@link Nop}.</li>
     *     <li>The collected blocks are moved from {@code block}'s parent container into the new switch
     *     container via {@link TransformerUtils#moveBlocksIntoContainer}.</li>
     * </ul>
     *
     * @param block       the block ending in {@code switchTable}
     * @param switchTable the switch table that terminates {@code block}
     */
    private void transformSwitch(Block block, SwitchTable switchTable) {
        // Assumes the context's CFG node corresponds to `block`; the assertion below checks that the
        // first collected block is `block`.
        ControlFlowNode entryPoint = ctx.getControlFlowNode();

        LinkedList<Block> switchNodes = new LinkedList<>();
        Block exitBlock = extendSwitch(entryPoint, switchNodes);
        assert switchNodes.getFirst() == block;
        switchNodes.removeFirst(); // Remove the switch entry point; `block` stays where it is.

        BlockContainer switchContainer = new BlockContainer();
        // New entry block of the switch body. It takes its offsets from the table and then holds the table.
        // NOTE(review): presumably add() re-parents switchTable, detaching it from `block` — confirm.
        Block newEntryPoint = new Block(block.getSubName("switch")).withOffsets(switchTable);
        newEntryPoint.instructions.add(switchTable);
        switchNodes.addFirst(newEntryPoint);

        // The Switch takes over the table's value; the table's value slot is then filled with a Nop.
        block.instructions.add(new Switch(switchTable.getValue(), switchContainer).withOffsets(switchTable));
        switchTable.setValue(new Nop());

        // When an exit was found, control continues at the exit block after the switch.
        if (exitBlock != null) {
            block.instructions.add(new Branch(exitBlock));
        }

        TransformerUtils.moveBlocksIntoContainer(switchNodes, (BlockContainer) block.getParent(), switchContainer, exitBlock);
    }

    /**
     * Collects the blocks forming the body of the switch headed by {@code entryPoint} and picks the
     * switch's exit node.
     * <p>
     * The exit is first chosen by {@code pickExitViaBytecodeOffset}. If that yields no exit, the method
     * falls back to {@code computeImmediatePostDominator}. Then every node of {@code entryPoint}'s
     * dominator subtree is added to {@code switchNodes} in pre-order, with the exit node's subtree pruned.
     * Continue targets are ignored (see {@link LoopContext#getDominatorTreeChildren}).
     *
     * @param entryPoint  CFG node of the block that ends in the switch table
     * @param switchNodes output list receiving the switch body blocks; {@code transformSwitch} expects
     *                    the entry point's block to be the first element
     * @return the exit node's block, or {@code NO_EXIT_POINT}'s block (presumably {@code null}) when no
     * exit was found
     */
    @Nullable
    private Block extendSwitch(ControlFlowNode entryPoint, List<Block> switchNodes) {
        LoopContext loopCtx = new LoopContext(ctx.getControlFlowGraph(), entryPoint);

        Set<ControlFlowNode> inSwitchNodes = new HashSet<>(); // nodes which must be inside the switch because they are predecessors of a case block (fall-through)
        loopCtx.getDominatorTreeChildren(entryPoint) // case nodes can only be dominated by the switch head
                .forEach(c -> addPredecessorsInDominatorTree(entryPoint, c, inSwitchNodes));

        // NOTE(review): this heuristic exit is not checked against inSwitchNodes; only the assertion
        // at the end of this method guards against pruning a fall-through predecessor.
        ControlFlowNode exitNode = pickExitViaBytecodeOffset(entryPoint, loopCtx);
        if (exitNode == NO_EXIT_POINT) {
            exitNode = computeImmediatePostDominator(entryPoint, inSwitchNodes, loopCtx);
        }

        // Copy into an effectively-final local so the lambda below can capture it.
        ControlFlowNode finalExitNode = exitNode;
        entryPoint.streamPreOrder(e -> loopCtx.getDominatorTreeChildren(e).filter(c -> c != finalExitNode))
                .forEach(e -> switchNodes.add(e.block));

        // Every fall-through predecessor must have ended up in the switch body.
        assert FastStream.of(inSwitchNodes).map(ControlFlowNode::getBlock).allMatch(switchNodes::contains);
        return exitNode.block;
    }

    /**
     * Adds to {@code preds} every node reachable backwards from {@code n}, following predecessor edges
     * through nodes dominated by {@code head}.
     * <p>
     * Nodes already in {@code preds} are not explored again, so cycles terminate. The walk stops at
     * {@code head}. {@code head} itself may be added if it is a direct predecessor and {@code dominates}
     * is reflexive.
     *
     * @param head  the switch head; only nodes it dominates are collected
     * @param n     the node whose predecessors are explored
     * @param preds accumulator for the collected nodes
     */
    private void addPredecessorsInDominatorTree(ControlFlowNode head, ControlFlowNode n, Set<ControlFlowNode> preds) {
        if (n == head) return;

        for (ControlFlowNode pred : n.getPredecessors()) {
            if (head.dominates(pred) && preds.add(pred)) {
                addPredecessorsInDominatorTree(head, pred, preds);
            }
        }
    }

    /**
     * Finds the switch's exit as the immediate post-dominator of {@code entryPoint}, skipping nodes that
     * must stay inside the switch.
     * <p>
     * A reversed copy of the region is built:
     * <ul>
     *     <li>For every node in {@code entryPoint}'s dominator subtree (continue targets excluded), each
     *     successor edge that is not a depth-1 continue is added as a reverse edge successor -&gt; node.</li>
     *     <li>Successors not dominated by {@code entryPoint}, and all successors of nodes for which the CFG
     *     reports a direct exit, are additionally linked from a synthetic exit node.</li>
     * </ul>
     * Dominance computed from the synthetic exit on the reversed graph gives post-dominance on the original.
     * Starting from the entry point's immediate post-dominator, the walk moves upwards while the node is in
     * {@code excludingNodes}.
     *
     * @param entryPoint     the switch head
     * @param excludingNodes nodes that cannot be the exit because they must remain in the switch
     * @param loopCtx        loop information used to filter out continue targets and continue edges
     * @return the chosen exit node, or {@link #NO_EXIT_POINT} if no edge leaves the region or the walk
     * reaches the synthetic exit
     */
    private ControlFlowNode computeImmediatePostDominator(ControlFlowNode entryPoint, Set<ControlFlowNode> excludingNodes, LoopContext loopCtx) {
        // Original node -> its counterpart in the reversed graph. NO_EXIT_POINT maps to the synthetic exit.
        BiMap<ControlFlowNode, ControlFlowNode> map = HashBiMap.create();
        ControlFlowNode revExitNode = map.computeIfAbsent(NO_EXIT_POINT, ControlFlowNode::new);

        entryPoint.streamPreOrder(loopCtx::getDominatorTreeChildren).forEach(node -> {
            ControlFlowNode revNode = map.computeIfAbsent(node, ControlFlowNode::new);

            for (ControlFlowNode succ : node.getSuccessors()) {
                // Edges to the depth-1 continue target are ignored.
                if (loopCtx.matchContinue(succ, 1)) continue;
                ControlFlowNode succNode = map.computeIfAbsent(succ, ControlFlowNode::new);

                // Reverse the original edge node -> succ.
                succNode.addEdgeTo(revNode);
                // Edges leaving the switch head's dominance region, and successors of nodes with a direct
                // exit, are attached to the synthetic exit so that dominance is rooted there.
                if (!entryPoint.dominates(succ) || ctx.getControlFlowGraph().hasDirectExit(node)) {
                    revExitNode.addEdgeTo(succNode);
                }
            }
        });

        // Nothing leaves the region, so there is no exit to pick.
        if (revExitNode.getSuccessors().isEmpty()) return NO_EXIT_POINT;

        // Dominators on the reversed graph are post-dominators on the original graph.
        Dominance.computeDominance(revExitNode);

        ControlFlowNode revEntryPoint = Objects.requireNonNull(map.get(entryPoint));
        assert revEntryPoint.isReachable();

        ControlFlowNode revPostDominator = revEntryPoint.getImmediateDominator();
        while (excludingNodes.contains(map.inverse().get(revPostDominator))) { // this node must be in the switch, because it is a predecessor of another case
            revPostDominator = revPostDominator.getImmediateDominator();
        }

        return map.inverse().get(revPostDominator);
    }

    /**
     * Heuristic exit selection. Among {@code entryPoint}'s dominator-tree children (continue targets
     * excluded), takes the one whose block has the highest bytecode offset.
     *
     * @param entryPoint  the switch head
     * @param loopContext loop information used to filter out continue targets
     * @return that node, or {@link #NO_EXIT_POINT} if there are no children or the chosen node has at most
     * one predecessor
     */
    private static ControlFlowNode pickExitViaBytecodeOffset(ControlFlowNode entryPoint, LoopContext loopContext) {
        ControlFlowNode node = loopContext.getDominatorTreeChildren(entryPoint)
                .maxByOrDefault(e -> e.getBlock().getBytecodeOffset());
        if (node == null) return NO_EXIT_POINT;
        // Only treat it as the exit (and so insert a break) if it is reached from more than one place.
        if (node.getPredecessors().size() <= 1) return NO_EXIT_POINT;
        return node;
    }

    /**
     * Matches a block that can be kept outside the switch as a loop increment block.
     * <p>
     * The block's last instruction must be a branch. In addition, every reference to each stack-slot
     * variable referenced inside the block must itself lie inside the block.
     *
     * @param block the candidate block
     * @return the block's trailing {@link Branch}, or {@code null} if the block does not match
     */
    @Nullable
    public static Branch matchIncrementBlock(Block block) {
        // NOTE(review): uses last() rather than lastOrDefault(); assumes `block` is non-empty — confirm.
        Branch branch = BranchLeaveMatching.matchBranch(block.instructions.last());
        if (branch == null) return null;

        // Ensure there are no stack variables assigned in other blocks referenced in this block.
        // If there are, we can't keep the block outside the switch as an increment block.
        // TODO, we should detect this boundary and split the block.
        if (!block.descendantsOfType(LocalReference.class)
                .filter(e -> e.variable.getKind() == LocalVariable.VariableKind.STACK_SLOT)
                .flatMap(e -> e.variable.getReferences())
                .allMatch(e -> e.isDescendantOf(block))) {
            return null;
        }
        return branch;
    }

    /**
     * Checks whether {@code block} looks like the condition block of a do-while loop headed by
     * {@code loopHead}.
     * <p>
     * The block needs at least two instructions, and the second to last must be an {@link IfInstruction}
     * whose false branch is a {@link Nop}. It matches when the if's true branch is a branch to
     * {@code loopHead}. Otherwise the true branch must match a return, and the last instruction is then
     * checked (see the note below).
     *
     * @param block    the candidate block; {@code findContinue} passes the loop head's single back-edge
     *                 predecessor
     * @param loopHead the loop head block
     * @return {@code true} if the block matches
     */
    private static boolean matchDoWhileConditionBlock(Block block, Block loopHead) {
        if (block.instructions.size() < 2) return false;

        Instruction last = block.instructions.last();
        Instruction secondToLast = block.instructions.secondToLastOrDefault();
        if (!(secondToLast instanceof IfInstruction ifInstruction) || !(ifInstruction.getFalseInsn() instanceof Nop)) return false;

        // Form 1: `if (cond) goto loopHead;` — the trailing instruction is not inspected.
        Branch target = BranchLeaveMatching.matchBranch(ifInstruction.getTrueInsn());
        if (target != null && loopHead == target.getTargetBlock()) return true;

        // Form 2: the if's true branch must match a return.
        if (BranchLeaveMatching.matchReturn(ifInstruction.getTrueInsn()) == null) return false;

        // NOTE(review): `last` must satisfy BOTH matchBranch and matchLeave below. If those helpers match
        // distinct instruction types (Branch vs Leave), this can never be true and form 2 is dead code.
        // Confirm against BranchLeaveMatching; the Leave check may have been meant for another instruction.
        target = BranchLeaveMatching.matchBranch(last);
        Leave leave = matchLeave(last);

        return leave != null && target != null && loopHead == target.getTargetBlock();
    }

    /**
     * Loop information around a context node (the switch head).
     * <p>
     * The constructor walks successors from {@code contextNode}, stopping at and recording any node that
     * dominates {@code contextNode}. Such a node is reachable from the context node and dominates it, so it
     * heads a loop containing the context node. For each loop head, a continue target is chosen by
     * {@code findContinue}. These targets are numbered 1, 2, ... in ascending {@code postOrderNumber} of
     * their loop heads.
     */
    public static class LoopContext {

        // Continue target node -> loop depth (1-based). Nodes absent from the map are not continue targets.
        private final Object2IntMap<ControlFlowNode> continueDepth = new Object2IntOpenHashMap<>();

        /**
         * @param cfg         the control flow graph; its visited flags are reset after the walk
         * @param contextNode the node (switch head) whose enclosing loops are analyzed
         */
        public LoopContext(ControlFlowGraph cfg, ControlFlowNode contextNode) {
            List<ControlFlowNode> loopHeads = new ArrayList<>();

            contextNode.getSuccessors().forEach(e -> analyze(contextNode, e, loopHeads));
            // Clear the visited flags set by analyze(...).
            cfg.resetVisited();

            // Depth 1 goes to the loop head with the lowest postOrderNumber, then upwards.
            int l = 1;
            for (ControlFlowNode loopHead : FastStream.of(loopHeads).sorted(Comparator.comparingInt(e -> e.postOrderNumber))) {
                continueDepth.put(findContinue(loopHead), l++);
            }
        }

        /**
         * Recursive depth-first walk over successors that uses {@code ControlFlowNode.visited} to avoid
         * revisiting nodes. A node that dominates {@code contextNode} is recorded in {@code loopHeads}, and
         * the walk does not continue past it.
         * <p>
         * {@code contextNode} is not pre-marked as visited. If the walk reaches it, it is recorded as a loop
         * head, presumably because {@code dominates} is reflexive.
         */
        private void analyze(ControlFlowNode contextNode, ControlFlowNode n, List<ControlFlowNode> loopHeads) {
            if (n.visited) return;

            n.visited = true;
            if (n.dominates(contextNode)) {
                loopHeads.add(n);
            } else {
                n.getSuccessors().forEach(e -> analyze(contextNode, e, loopHeads));
            }
        }

        /**
         * Picks the node that a {@code continue} of the loop headed by {@code loopHead} targets.
         * <p>
         * If the loop head has exactly one back-edge predecessor (a predecessor other than itself that it
         * dominates), that predecessor is returned when it matches either of:
         * <ul>
         *     <li>a for-loop increment block branching to the head ({@link #matchIncrementBlock});</li>
         *     <li>a do-while condition block ({@code matchDoWhileConditionBlock}).</li>
         * </ul>
         * In every other case the loop head itself is returned.
         */
        private static ControlFlowNode findContinue(ControlFlowNode loopHead) {
            // Potential continue target: the single back-edge predecessor. onlyOrDefault presumably yields
            // null when there is not exactly one candidate.
            ControlFlowNode pred = FastStream.of(loopHead.getPredecessors()).filter(e -> e != loopHead && loopHead.dominates(e)).onlyOrDefault();
            if (pred == null) return loopHead;
            assert pred.block != null;

            // match for loop increment block
            if (pred.getSuccessors().size() == 1) {
                Branch target = matchIncrementBlock(pred.block);
                if (target != null && target.getTargetBlock() == loopHead.block) {
                    return pred;
                }
            }

            // match do-while condition block (loops back and/or exits, so at most two successors)
            if (pred.getSuccessors().size() <= 2) {
                if (loopHead.block != null && matchDoWhileConditionBlock(pred.block, loopHead.block)) {
                    return pred;
                }
            }

            return loopHead;
        }

        /**
         * @param node  the node to test
         * @param depth the 1-based loop depth to match
         * @return {@code true} if {@code node} is the continue target recorded for exactly {@code depth}
         */
        public boolean matchContinue(ControlFlowNode node, int depth) {
            return continueDepth.getOrDefault(node, -1) == depth;
        }

        /**
         * Returns {@code node}'s dominator-tree children, excluding any that are continue targets (at any
         * depth). Continue targets are cut off here, much like the control flow graph ends at nodes that
         * return or throw unconditionally.
         */
        public FastStream<ControlFlowNode> getDominatorTreeChildren(ControlFlowNode node) {
            return FastStream.of(node.getDominatorTreeChildren()).filter(e -> !continueDepth.containsKey(e));
        }
    }
}
