package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.InsnOpcode;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.flow.ControlFlowNode;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.transform.BlockTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.BlockTransformer;
import net.covers1624.quack.collection.ColUtils;
import net.covers1624.quack.collection.FastStream;
import org.jetbrains.annotations.Nullable;

import java.util.LinkedList;
import java.util.List;
import java.util.function.Predicate;

import static net.covers1624.quack.collection.FastStream.of;

/**
 * Block transformer that finds loops in the bytecode's control flow graph and wraps
 * each one in a {@code while (true)} {@link WhileLoop}.
 * <p>
 * A block is considered a loop head when at least one of its CFG predecessors is
 * dominated by it, i.e. that predecessor reaches the head through a back edge.
 * <p>
 * Created by covers1624 on 19/4/21.
 */
public class LoopDetection implements BlockTransformer {

    @Override
    public void transform(Block block, BlockTransformContext ctx) {
        // This transformer runs early, so the block has not yet left the container the CFG was built for.
        assert block.getParentOrNull() == ctx.getControlFlowGraph().container;

        // CFG node of the block under inspection: the candidate loop head.
        ControlFlowNode loopHead = ctx.getControlFlowNode();
        assert loopHead.block == block;

        // Every predecessor dominated by the head jumps back to it, so it is the source of a back edge.
        List<Block> backEdgeSources = new LinkedList<>();
        for (ControlFlowNode pred : loopHead.getPredecessors()) {
            if (!loopHead.dominates(pred)) {
                continue;
            }
            backEdgeSources.add(pred.getBlock());
        }
        if (backEdgeSources.isEmpty()) {
            return; // No back edges: this block does not head a loop.
        }

        LinkedList<Block> loopBlocks = new LinkedList<>();
        Block exit = extendLoop(loopHead, loopBlocks, backEdgeSources);

        ctx.pushTiming("Construct loop");
        constructLoop(loopBlocks, exit);
        ctx.popTiming();
    }

    /**
     * Searches the dominator tree below {@code root} for blocks accepted by {@code p}.
     * <p>
     * Despite the name, this is not the textbook dominance frontier. Each dominator-tree child of
     * {@code root} whose block matches {@code p} is yielded, and its subtree is not searched further.
     * A child that does not match is searched recursively in the same way.
     *
     * @param root The dominator-tree node whose descendants are searched (root itself is not tested).
     * @param p    Filter deciding which blocks are yielded.
     * @return The topmost matching descendant blocks along every dominator-tree path from {@code root}.
     */
    private static FastStream<Block> getDominanceFrontier(ControlFlowNode root, Predicate<Block> p) {
        return of(root.getDominatorTreeChildren()).flatMap(child -> {
            Block candidate = child.getBlock();
            return p.test(candidate) ? of(candidate) : getDominanceFrontier(child, p);
        });
    }

    /**
     * Determines which blocks belong to the loop headed by {@code entryPoint} and where the loop exits to.
     * <p>
     * The loop is the contiguous run of sibling blocks in the entry block's container. It starts at the
     * entry block and ends either on the block just before the chosen exit block, or, when no exit block
     * is found, on the highest-offset back-edge source (lifted to the container's level).
     * <p>
     * NOTE(review): the walk only follows {@code getNextSibling()}. It therefore assumes the last loop
     * block sits at or after the entry block in sibling order.
     *
     * @param entryPoint The CFG node of the loop head.
     * @param loop       Output list; receives the loop's blocks in sibling order, entry block first.
     * @param backEdges  Blocks that branch back to the loop head ({@link #transform} only passes a non-empty list).
     * @return The block the loop should exit to, or {@code null} if no suitable block was found.
     */
    @Nullable
    private static Block extendLoop(ControlFlowNode entryPoint, List<Block> loop, List<Block> backEdges) {
        BlockContainer container = (BlockContainer) entryPoint.getBlock().getParent();
        Block furthestBackEdge = ColUtils.requireMaxBy(backEdges, Instruction::getBytecodeOffset);

        // Exit candidates are dominated by the head, sit directly in this container,
        // and come after every back edge in bytecode order. The latest such candidate wins.
        FastStream<Block> exitCandidates = getDominanceFrontier(entryPoint,
                candidate -> candidate.getParent() == container
                        && candidate.getBytecodeOffset() > furthestBackEdge.getBytecodeOffset());
        Block goodExit = exitCandidates.maxByOrDefault(Instruction::getBytecodeOffset);

        Block lastBlock;
        if (goodExit == null) {
            lastBlock = furthestBackEdge;
        } else {
            lastBlock = (Block) goodExit.getPrevSibling();
        }

        // The last block may already have been nested inside an inner loop.
        // Climb up until we reach the block that lives directly in our container.
        while (lastBlock.getParent() != container) {
            lastBlock = lastBlock.getParent().firstAncestorOfType(InsnOpcode.BLOCK);
        }

        // Collect every sibling from the entry block up to and including lastBlock.
        Block current = entryPoint.getBlock();
        loop.add(current);
        while (current != lastBlock) {
            current = (Block) current.getNextSibling();
            loop.add(current);
        }

        return goodExit;
    }

    /**
     * Moves the loop's blocks into the body container of a new {@code while (true)} loop.
     * <p>
     * The original entry block stays in its container, because branches outside the loop may target it.
     * Its instructions are handed to a fresh entry block inside the loop. The while loop is then appended
     * to the old entry block, followed by a branch to {@code exitTargetBlock} when one is given.
     * Finally, branches inside the loop that still target the old entry block are redirected to the new one.
     *
     * @param blocks          The loop's blocks, entry block first. Modified in place: the old entry block
     *                        is removed and the new entry block is inserted at the front.
     * @param exitTargetBlock The block to branch to after the loop, or {@code null} for none.
     */
    private void constructLoop(List<Block> blocks, @Nullable Block exitTargetBlock) {
        Block outerEntry = blocks.remove(0);
        BlockContainer sourceContainer = (BlockContainer) outerEntry.getParent();
        assert outerEntry.getBytecodeOffset() >= 0;

        WhileLoop loop = new WhileLoop(new BlockContainer(), new LdcBoolean(true)).withOffsets(outerEntry);
        BlockContainer loopBody = loop.getBody();

        // The old entry block itself cannot move (outside branches may target it),
        // so its contents go to a new entry block that will live inside the loop.
        Block innerEntry = new Block(outerEntry.getSubName("loop")).withOffsets(outerEntry);
        innerEntry.instructions.addAll(outerEntry.instructions);
        blocks.add(0, innerEntry);

        // The old entry block now runs the loop and then, if an exit is known, continues there.
        outerEntry.instructions.add(loop);
        if (exitTargetBlock != null) {
            outerEntry.instructions.add(new Branch(exitTargetBlock));
        }
        TransformerUtils.moveBlocksIntoContainer(blocks, sourceContainer, loopBody, exitTargetBlock);

        // Branches now inside the loop that still jump to the old entry must target the new entry instead.
        for (Branch edge : outerEntry.getBranches().toList()) {
            if (!edge.isDescendantOf(loopBody)) {
                continue;
            }
            edge.replaceWith(new Branch(innerEntry).withOffsets(edge));
        }
    }
}
