package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.InstructionFlag;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.matching.IfMatching;
import net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching;
import net.covers1624.coffeegrinder.bytecode.transform.BlockTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.BlockTransformer;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.transformers.statement.ExpressionTransforms;
import net.covers1624.coffeegrinder.bytecode.transform.transformers.statement.Inlining;
import net.covers1624.coffeegrinder.type.Field;
import net.covers1624.quack.collection.ColUtils;
import net.covers1624.quack.collection.FastStream;
import net.covers1624.quack.util.JavaVersion;
import org.jetbrains.annotations.Nullable;

import java.util.LinkedList;
import java.util.List;

import static java.util.Objects.requireNonNull;
import static net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching.compatibleExitInstruction;
import static net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching.matchBranch;
import static net.covers1624.coffeegrinder.bytecode.matching.IfMatching.matchNopFalseIf;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.*;

/**
 * Recovers structured conditional control flow at the end of a (still basic) block.
 * <p>
 * Successor blocks that have exactly one incoming edge, and that directly follow the current block in its
 * enclosing {@link BlockContainer}, are inlined either into the true branch of the block's trailing
 * {@link IfInstruction} or into the block's own exit. The inlined code is then rewritten into higher level forms:
 * inverted conditions, short-circuit {@link LogicAnd} conditions, {@link Ternary} conditions, {@code instanceof}
 * patterns (Java 17+ class files only), and if/else statements created by merging common exits.
 * <p>
 * Created by covers1624 on 25/5/21.
 */
public class ConditionDetection implements BlockTransformer {

    /**
     * The context of the block currently being transformed.
     * Assigned at the start of every {@link #transform} call and read by the private helpers.
     * NOTE(review): this is per-call state stored on the instance, so a single instance presumably
     * must not be used for concurrent transforms — confirm.
     */
    @SuppressWarnings ("NotNullFieldNotInitialized")
    private BlockTransformContext ctx;

    /**
     * Runs condition detection on a single block.
     * <p>
     * If the second-to-last instruction is an {@link IfInstruction}, the full if-handling pipeline
     * ({@link #handleIfInstruction}) runs. Otherwise only the block's exit branch is inlined, when possible.
     *
     * @param block The block to transform.
     * @param ctx   The transform context; stored in {@link #ctx} for use by the helpers.
     */
    @Override
    public void transform(Block block, BlockTransformContext ctx) {
        this.ctx = ctx;

        // Because this transform runs at the beginning of the block transforms,
        // we know that `block` is still a (non-extended) basic block.

        // Previous-to-last instruction might have conditional control flow,
        // usually an IfInstruction with a branch:
        if (block.instructions.secondToLastOrDefault() instanceof IfInstruction ifInsn) {
            handleIfInstruction(block, ifInsn);
        } else {
            inlineExitBranch(block, ctx);
        }
    }

    /**
     * Repeatedly inlines and simplifies, maintaining a good block exit and then attempting to match bytecode order.
     * <p>
     * Each loop iteration first inlines either the true branch of {@code ifInsn} or the block's exit branch.
     * It then tries, in order:
     * <ol>
     *   <li>preparing a short-circuit or, or an inverted ternary (either may invert a nested if);</li>
     *   <li>merging a compatible true-exit with the fallthrough;</li>
     *   <li>introducing a short-circuit {@code &&};</li>
     *   <li>producing a ternary condition.</li>
     * </ol>
     * If a preparation step inverted a nested if, {@code ifInsn} itself is inverted at the end of the iteration.
     * Once nothing more can be inlined, common exits are merged via {@link #mergeExitsForInline}.
     *
     * @param block  The block containing the IfInstruction.
     * @param ifInsn The IfInstruction.
     */
    private void handleIfInstruction(Block block, IfInstruction ifInsn) {
        invertIf(ifInsn); // javac emits all if instructions as if (!cond) goto else;

        while (inlineTrueBranch(ifInsn) || inlineExitBranch(block, ctx)) {
            boolean needsInvert = prepareShortCircuitOr(ifInsn) || prepareInvertedTernary(ifInsn);
            tryMergeTrueExitWithFallthrough(block, ifInsn);
            introduceShortCircuit(ifInsn);

            produceTernary(ifInsn);

            if (needsInvert) {
                invertIf(ifInsn);
            }
        }
        mergeExitsForInline(block, ifInsn);
    }

    /**
     * Prepares an inverted ternary by inverting the if that directly follows {@code ifInsn}.
     * <p>
     * The true insn of {@code ifInsn} must be a block starting with an if ({@code thenIf}). The instruction after
     * {@code ifInsn} must also be an if ({@code elseIf}); both may be matched through a potential inline. The
     * exits must be swapped: {@code thenIf}'s true exit must be compatible with the instruction after
     * {@code elseIf}, and the instruction after {@code thenIf} must be compatible with {@code elseIf}'s true exit.
     * <pre>
     *    if (a) {
     *      if (b) br false;
     *      br true;
     *    }
     *    if (c) br true;
     *    br false;
     * ->
     *    if (a) {
     *      if (b) br false;
     *      br true;
     *    }
     *    if (!c) br false;
     *    br true;
     * </pre>
     *
     * @param ifInsn The IfInstruction whose true insn may start with the nested if.
     * @return {@code true} if {@code elseIf} was inverted; the caller then inverts {@code ifInsn} as well.
     */
    private boolean prepareInvertedTernary(IfInstruction ifInsn) {
        if (!(ifInsn.getTrueInsn() instanceof Block ternaryThenCond)) return false;

        // match with potential inline to ensure that introduceShortCircuit below will run. No need to actually do the inlining here, just invert the if.
        IfInstruction thenIf = Inlining.matchWithPotentialInline(ternaryThenCond.getFirstChildOrNull(), new LinkedList<>(), ctx, IfMatching::matchNopFalseIf);
        if (thenIf == null) return false;

        IfInstruction elseIf = Inlining.matchWithPotentialInline(ifInsn.getNextSiblingOrNull(), new LinkedList<>(), ctx, IfMatching::matchNopFalseIf);
        if (elseIf == null) return false;

        // The exits of the two ifs must be crosswise compatible (then-true <-> else-fallthrough and vice versa).
        if (!compatibleExitInstruction(thenIf.getTrueInsn(), elseIf.getNextSiblingOrNull())) return false;
        if (!compatibleExitInstruction(thenIf.getNextSiblingOrNull(), elseIf.getTrueInsn())) return false;

        ctx.pushStep("Prepare inverted ternary");
        invertIf(elseIf);
        ctx.popStep();
        return true;
    }

    /**
     * Only inlines branches that are strictly dominated by this block (incoming edge count == 1)
     * <pre>
     *    if (...) br trueBlock;
     * ->
     *    if (...) { trueBlock... }
     * </pre>
     * The inline is also skipped when the fallthrough instruction (unless it is a {@link Branch} or
     * {@link Leave}) has a lower bytecode offset than the target block, so that instructions are not reordered.
     * After inlining, an {@code instanceof} pattern match is attempted (see {@link #tryProducePatternMatch}).
     * If the inlined block holds a single instruction, it is replaced by that instruction.
     *
     * @param ifInsn The IfInstruction; expected to have no else branch.
     * @return If the true branch was inlined.
     */
    private boolean inlineTrueBranch(IfInstruction ifInsn) {
        if (!canInline(ifInsn.getTrueInsn())) return false;
        assert matchNopFalseIf(ifInsn) != null;

        // Don't reorder instructions
        Block targetBlock = ((Branch) ifInsn.getTrueInsn()).getTargetBlock();
        Instruction fallthrough = ifInsn.getNextSibling();
        // Branch/Leave fallthroughs don't pin a bytecode position, so they never block the inline (-1).
        int nextInsnBytecodeOffset = !(fallthrough instanceof Branch) && !(fallthrough instanceof Leave) ? fallthrough.getBytecodeOffset() : -1;
        if (nextInsnBytecodeOffset >= 0 && nextInsnBytecodeOffset < targetBlock.getBytecodeOffset()) return false;

        ctx.pushStep("Inline true-branch");
        // The targetBlock was already processed, and is ready to embed
        ifInsn.setTrueInsn(targetBlock);
        tryProducePatternMatch(ifInsn);

        if (targetBlock.instructions.size() == 1) { // makes matching easier for other transforms/ternaries
            targetBlock.replaceWith(targetBlock.getFirstChild());
        }

        ctx.popStep();
        return true;
    }

    /**
     * Attempts to fold an {@code instanceof} check, and the cast at the start of its true block,
     * into a pattern {@code instanceof}. Only runs for Java 17+ class files.
     * <p>
     * All of the following are required:
     * <ul>
     *   <li>the true insn is a block whose first instruction stores a {@link Cast} into a {@code LOCAL} variable;</li>
     *   <li>the if condition pops a value pushed by an {@link InstanceOf} whose parent is the instruction
     *       directly before the if;</li>
     *   <li>the cast and the instanceof use the same type, and load the same variable;</li>
     *   <li>every write to the stored variable is dominated by the true block in the control flow graph.</li>
     * </ul>
     * On success, the store's reference becomes the instanceof's pattern and the store is removed.
     * The if may then be inverted (see the comment in the body).
     *
     * @param ifInsn The IfInstruction whose true branch was just inlined.
     */
    private void tryProducePatternMatch(IfInstruction ifInsn) {
        // Only do this when we are on J17+
        if (ctx.classVersion.ordinal() < JavaVersion.JAVA_17.ordinal()) return;

        // Match the following pattern:
        // STORE (LOCAL s_36, INSTANCE_OF (LOAD (LOCAL s_35)) String)
        // IF (LOAD (LOCAL s_36)) BLOCK L16_1 {
        //     STORE (LOCAL str$4, CHECK_CAST String(LOAD (LOCAL var_5)))
        //     ..
        // }

        if (!(ifInsn.getTrueInsn() instanceof Block trueBlock)) return;

        // We must find a store + cast as the first thing in the true block.
        if (!(matchStoreLocal(trueBlock.getFirstChild()) instanceof Store store)) return;
        var variable = store.getVariable();
        if (variable.getKind() != LocalVariable.VariableKind.LOCAL) return;
        if (!(store.getValue() instanceof Cast cast)) return;

        // Prior to the if, there should be an InstanceOf (its result is still held in a stack variable at this point).
        var condPushFrom = matchPushForPop(ifInsn.getCondition());
        if (condPushFrom == null || condPushFrom.getParent() != ifInsn.getPrevSiblingOrNull()) return;
        if (!(condPushFrom instanceof InstanceOf instanceOf)) return;

        // Instanceof and Cast must match types (still raw at this point)
        if (!cast.getType().equals(instanceOf.getType())) return;
        // Must load the same variable.
        if (!(matchLoadLocal(cast.getArgument()) instanceof Load castLoad)) return;
        if (matchLoadLocal(matchPushForPop(instanceOf.getArgument()), castLoad.getVariable()) == null) return;

        // Every write to the pattern variable must be dominated by the true block.
        // NOTE(review): presumably so the variable is only ever assigned within the pattern's scope — confirm.
        var cfg = ctx.getControlFlowGraph();
        var head = cfg.getNode(trueBlock);
        if (!ColUtils.allMatch(variable.getReferences(), i -> !i.isWrittenTo() || cfg.dominates(head, i))) return;

        ctx.pushStep("Inline Instanceof pattern");
        // Move the ref, and remove the store.
        instanceOf.setPattern((LocalReference) store.getReference());
        store.remove();

        // The cast for the pattern match is always placed in the true block. However, if the condition is inverted
        // in source, the actual true block of the if will still be waiting to be inlined at the current else exit.
        // If the remaining true block (after removing the pattern cast) is just a branch to a to-be-inlined block,
        // and there are other blocks that need to be inlined before it,
        // then those other blocks must belong in the true branch of the if, and we should invert.
        // A Leave as the first remaining instruction is likewise moved to the fallthrough by inverting.
        //
        // After making a pattern instanceof:
        // BLOCK L0 {
        //     STORE (LOCAL s_36, INSTANCE_OF (LOAD (LOCAL s_35), LOCAL var_5) String)
        //     IF (LOAD (LOCAL s_36)) BLOCK L16_1 {
        //         BRANCH LATER
        //     }
        //     BRANCH NEXT
        // }
        // BLOCK NEXT { ... }
        // BLOCK LATER { ... }
        // ->
        // (invertIf negates the condition, makes the old fallthrough the true insn
        //  and appends the old true insn's contents to the block)
        // BLOCK L0 {
        //     STORE (LOCAL s_36, INSTANCE_OF (LOAD (LOCAL s_35), LOCAL var_5) String)
        //     IF (LOGIC_NOT (LOAD (LOCAL s_36))) BRANCH NEXT
        //     BRANCH LATER
        // }
        // BLOCK NEXT { ... }
        // BLOCK LATER { ... }
        if ((trueBlock.getFirstChild() instanceof Branch br)) {
            if (br.getTargetBlock() != ifInsn.getParent().getNextSiblingOrNull()) {
                invertIf(ifInsn);
            }
        } else if (trueBlock.getFirstChild() instanceof Leave) {
            invertIf(ifInsn);
        }

        ctx.popStep();
    }

    /**
     * Only inlines branches that are strictly dominated by this block (incoming edge count == 1)
     * <pre>
     *    ...; br nextBlock;
     * ->
     *    ...; { nextBlock... }
     * </pre>
     * The trailing branch is removed, the target block's instructions are appended to {@code block},
     * and the target block is removed.
     *
     * @param block The block; its last instruction must have an unreachable end point.
     * @param ctx   The transform context, used to record the transform step.
     * @return If the exit was inlined.
     * @throws NullPointerException If {@code block} does not end in an instruction with an unreachable end point.
     */
    public static boolean inlineExitBranch(Block block, BlockTransformContext ctx) {

        Instruction exitInsn = getExit(block);
        if (!canInline(exitInsn)) return false;

        ctx.pushStep("Inline exit-branch");
        Block targetBlock = ((Branch) exitInsn).getTargetBlock();
        // The targetBlock was already processed, and is ready to embed
        block.instructions.last().remove();
        block.instructions.addAll(targetBlock.instructions);
        targetBlock.remove();
        ctx.popStep();

        return true;
    }

    /**
     * If this function returns true, we can replace the branch instruction with the block itself.
     * <p>
     * This is the case when {@code exitInsn} is a {@link Branch} whose target block meets two conditions:
     * it has exactly one incoming edge, and it is the next sibling of the nearest enclosing block that sits
     * directly inside a {@link BlockContainer}.
     *
     * @param exitInsn The insn to check.
     * @return If it should be inlined.
     */
    private static boolean canInline(Instruction exitInsn) {
        if (!(exitInsn instanceof Branch branch)) return false;

        Block targetBlock = branch.getTargetBlock();
        if (targetBlock.getIncomingEdgeCount() != 1) return false;

        Block parentBlock = exitInsn.ancestorsOfType(Block.class).filter(b -> b.getParent() instanceof BlockContainer).first();
        return parentBlock.getNextSiblingOrNull() == targetBlock;
    }

    /**
     * Looks for common exits in the inlined 'then' and 'else' branches of an if instruction
     * and performs inversion and simplifications to merge them provided they don't
     * isolate a higher priority block exit.
     * <p>
     * Only runs when all of these hold: {@code ifInsn} has no else branch, {@code block} has a next sibling,
     * and every branch to that next block originates inside {@code block}. Branches to the next block are
     * then hoisted towards the root of {@code block} ({@link #moveBranchExitsTowardsRoot}), and the block's
     * exit is inlined if possible.
     *
     * @param block  The block containing the IfInstruction.
     * @param ifInsn The IfInstruction.
     */
    private void mergeExitsForInline(Block block, IfInstruction ifInsn) {
        if (matchNopFalseIf(ifInsn) == null) return;

        Block nextBlock = (Block) block.getNextSiblingOrNull();
        if (nextBlock == null) return;

        if (!nextBlock.getBranches().allMatch(b -> b.isDescendantOf(block))) return;

        moveBranchExitsTowardsRoot(block, nextBlock);
        inlineExitBranch(block, ctx);
    }

    /**
     * Rewrites ifs in {@code block} whose true branch exits via a branch to {@code targetBlock},
     * so that the branch ends up at the end of {@code block} instead.
     * <p>
     * Does nothing unless {@code block}'s end point is unreachable. Ifs without an else are visited from
     * last to first. For each one, a true block is first processed recursively. Then:
     * <ul>
     *   <li>if {@code block}'s current exit is compatible with the branch, the true exit is merged with the
     *       fallthrough;</li>
     *   <li>otherwise the code after the if becomes its else branch, and the true exit is moved after it.</li>
     * </ul>
     *
     * @param block       The block to process.
     * @param targetBlock The branch target whose exits should be hoisted.
     */
    private void moveBranchExitsTowardsRoot(Block block, Block targetBlock) {
        if (!block.hasFlag(InstructionFlag.END_POINT_UNREACHABLE)) return;

        for (Instruction insn = block.getLastChild(); insn != null; insn = insn.getPrevSiblingOrNull()) {
            IfInstruction ifInsn = matchNopFalseIf(insn);
            if (ifInsn == null) continue;

            if (ifInsn.getTrueInsn() instanceof Block trueBlock) {
                moveBranchExitsTowardsRoot(trueBlock, targetBlock);
            }

            Branch trueExit = matchBranch(tryGetExit(ifInsn.getTrueInsn()), targetBlock);
            if (trueExit == null) continue;

            if (compatibleExitInstruction(tryGetExit(block), trueExit)) {
                mergeTrueExitWithFallthrough(block, ifInsn);
            } else {
                moveTrueExitToFallthrough(block, ifInsn);
            }
        }
    }

    /**
     * Turns the code after {@code ifInsn} into its else branch, and moves the true branch's exit
     * to the end of {@code block}.
     *
     * @param block  The block containing {@code ifInsn}.
     * @param ifInsn An if without an else, whose true insn ends in an exit.
     */
    private void moveTrueExitToFallthrough(Block block, IfInstruction ifInsn) {
        //   if (...) {
        //     ...
        //     exit;
        //   }
        //   ...
        //   [exit]
        // ->
        //   if (...) {
        //     ...
        //   }
        //   else {
        //     ...
        //   }
        //   exit;
        ctx.pushStep("Move true exit to fallthrough");

        ctx.pushStep("Introduce else");
        ifInsn.setFalseInsn(block.extractRange(ifInsn.getNextSibling(), block.getLastChild()));
        ctx.popStep();

        Instruction exit = getExit(ifInsn.getTrueInsn());
        ensureParentIsBlock(exit);
        block.instructions.add(exit);
        ctx.popStep();
    }

    /**
     * Merges the true exit of {@code ifInsn} into the block's exit.
     * This only happens when {@code ifInsn} has no else and the two exits are compatible;
     * see {@link #mergeTrueExitWithFallthrough}.
     *
     * @param block  The block containing {@code ifInsn}.
     * @param ifInsn The IfInstruction.
     */
    private void tryMergeTrueExitWithFallthrough(Block block, IfInstruction ifInsn) {
        if (matchNopFalseIf(ifInsn) == null) return;
        Instruction trueExit = tryGetExit(ifInsn.getTrueInsn());
        Instruction exit = tryGetExit(block);
        if (compatibleExitInstruction(trueExit, exit)) {
            mergeTrueExitWithFallthrough(block, ifInsn);
        }
    }

    /**
     * Removes the exit of {@code ifInsn}'s true branch, relying on the block's own compatible exit instead.
     * Any instructions between the if and the block's exit are first moved into an else branch.
     * Both callers check that the two exits are compatible before calling this.
     *
     * @param block  The block containing {@code ifInsn}.
     * @param ifInsn An if without an else.
     */
    private void mergeTrueExitWithFallthrough(Block block, IfInstruction ifInsn) {
        ctx.pushStep("Merge true exit with fallthrough");

        // if (...) { ...; blockExit; } ...; blockExit;
        // -> if (...) { ...; blockExit; } else { ... } blockExit;
        if (ifInsn != block.instructions.secondToLastOrDefault()) {
            assert ifInsn.getFalseInsn() instanceof Nop;
            ctx.pushStep("Introduce else");
            ifInsn.setFalseInsn(block.extractRange(ifInsn.getNextSibling(), block.getLastChild().getPrevSibling()));
            ctx.popStep();
        }

        // if (...) { ...; goto blockExit; } [else { ... }] blockExit;
        // -> if (...) { ... } [else { ... }] blockExit;
        Instruction exit = getExit(ifInsn.getTrueInsn());
        ensureParentIsBlock(exit);
        exit.remove();
        ctx.popStep();
    }

    /**
     * Prepares a short-circuit or.
     * <p>
     * This applies when {@code ifInsn} is directly followed by a branch, its true block ends in an exit, and the
     * true block starts with an if (possibly after an inline) that branches to that same target. That nested if
     * is then inverted.
     * <pre>
     *    if (a) {
     *      if (b) br else;
     *      ...
     *      trueExit
     *    }
     *    br else;
     * ->
     *    if (a) {
     *      if (!b) { ...; trueExit }
     *      br else;
     *    }
     *    br else;
     * </pre>
     *
     * @param ifInsn The IfInstruction.
     * @return {@code true} if the nested if was inverted; the caller then inverts {@code ifInsn} as well.
     */
    private boolean prepareShortCircuitOr(IfInstruction ifInsn) {
        if (!(ifInsn.getNextSibling() instanceof Branch elseBranch)) return false;

        if (!(ifInsn.getTrueInsn() instanceof Block trueBlock) || tryGetExit(trueBlock) == null) return false;

        // match with potential inline to ensure that introduceShortCircuit below will run. No need to actually do the inlining here, just invert the if.
        IfInstruction trueIf = Inlining.matchWithPotentialInline(trueBlock.getFirstChildOrNull(), new LinkedList<>(), ctx, IfMatching::matchNopFalseIf);
        if (trueIf == null || matchBranch(trueIf.getTrueInsn(), elseBranch.getTargetBlock()) == null) return false;

        ctx.pushStep("Prepare short-circuit or");
        invertIf(trueIf);
        ctx.popStep();
        return true;
    }

    /**
     * Merges nested ifs.
     * <pre>
     *    if (cond) { if (nestedCond) { nestedThen... } }
     * ->
     *    if (cond && nestedCond) { nestedThen... }
     * </pre>
     * Requirements: {@code ifInsn} has no else, and its true block holds only the nested if (an inline that
     * exposes it is allowed). The nested if must also not be an {@code assert} ({@link #isAssertion}).
     * The combined condition is simplified with {@link ExpressionTransforms}.
     *
     * @param ifInsn The IfInstruction.
     */
    private void introduceShortCircuit(IfInstruction ifInsn) {
        if (matchNopFalseIf(ifInsn) == null) return;
        if (!(ifInsn.getTrueInsn() instanceof Block trueBlock)) return;

        List<Runnable> extraTransforms = new LinkedList<>();
        IfInstruction nestedIf = Inlining.matchWithPotentialInline(trueBlock.getFirstChildOrNull(), extraTransforms, ctx, IfMatching::matchNopFalseIf);
        if (nestedIf == null) return;

        // short-circuit if must be the only instruction in the block, (aside from a potential extra required inline)
        if (nestedIf.getNextSiblingOrNull() != null) return;

        // Block short circuit when an if contains a single assertion.
        if (isAssertion(nestedIf)) return;

        ctx.pushStep("Introduce short-circuit");
        // Only perform the deferred inline(s) now that the whole pattern has matched.
        extraTransforms.forEach(Runnable::run);
        ifInsn.setCondition(new LogicAnd(ifInsn.getCondition(), nestedIf.getCondition()));
        ifInsn.setTrueInsn(nestedIf.getTrueInsn());
        ExpressionTransforms.runOnExpression(ifInsn.getCondition(), ctx);
        ctx.popStep();
    }

    /**
     * Detects the condition of a compiled {@code assert} statement.
     * The condition must have the form {@code !$assertionsDisabled && ...}, where
     * {@code $assertionsDisabled} is a synthetic field.
     *
     * @param ifInsn The IfInstruction to check.
     * @return {@code true} if the condition matches that shape.
     */
    // TODO unhardcode this. Perhaps use some type system analysis to identify the assertion field on the top level class
    private boolean isAssertion(IfInstruction ifInsn) {
        if (!(ifInsn.getCondition() instanceof LogicAnd and)) return false;
        if (!(and.getLeft() instanceof LogicNot not)) return false;

        Load load = LoadStoreMatching.matchLoadField(not.getArgument());
        if (load == null) return false;

        Field field = ((FieldReference) load.getReference()).getField();
        return field.isSynthetic() && field.getName().equals("$assertionsDisabled");
    }

    /**
     * Merges an if/else into a single if with a {@link Ternary} condition.
     * <p>
     * Both the true and false insn must be blocks, each holding only a single if (an inline that exposes it
     * is allowed), and the two nested ifs must have compatible exits. Of those two exits, the one with the
     * greater bytecode offset becomes the new true insn, and the else branch is cleared.
     *
     * @param ifInsn The IfInstruction.
     */
    private void produceTernary(IfInstruction ifInsn) {
        // Match and transform the following case:
        //  IF (cond) {
        //      IF (thenCond) BRANCH then
        //  } else {
        //      IF (elseCond) BRANCH then
        //  }
        //  ->
        //  IF (IF(cond) thenCond else elseCond) {
        //      BRANCH then
        //  }

        // Match true branch
        if (!(ifInsn.getTrueInsn() instanceof Block trueBlock)) return;

        List<Runnable> extraTransforms = new LinkedList<>();
        IfInstruction thenIf = Inlining.matchWithPotentialInline(trueBlock.getFirstChildOrNull(), extraTransforms, ctx, IfMatching::matchNopFalseIf);
        if (thenIf == null || thenIf.getNextSiblingOrNull() != null) return;

        // Match false branch
        if (!(ifInsn.getFalseInsn() instanceof Block falseBlock)) return;

        IfInstruction elseIf = Inlining.matchWithPotentialInline(falseBlock.getFirstChildOrNull(), extraTransforms, ctx, IfMatching::matchNopFalseIf);
        if (elseIf == null || elseIf.getNextSiblingOrNull() != null) return;

        if (!compatibleExitInstruction(thenIf.getTrueInsn(), elseIf.getTrueInsn())) return;

        // Make ternary
        ctx.pushStep("Produce ternary");
        // Deferred inlines are only applied once both branches have matched.
        extraTransforms.forEach(Runnable::run);
        ifInsn.setCondition(new Ternary(ifInsn.getCondition(), thenIf.getCondition(), elseIf.getCondition()));
        ifInsn.setTrueInsn(FastStream.of(thenIf.getTrueInsn(), elseIf.getTrueInsn()).maxBy(Instruction::getBytecodeOffset));
        ifInsn.setFalseInsn(new Nop());
        ctx.popStep();
    }

    /**
     * Inverts {@code ifInsn} using the current {@link #ctx}; see {@link #invertIf(IfInstruction, MethodTransformContext)}.
     *
     * @param ifInsn The IfInstruction.
     */
    private void invertIf(IfInstruction ifInsn) {
        invertIf(ifInsn, ctx);
    }

    /**
     * Invert an IfInstruction.
     * Assumes the IfInstruction does not have an else block.
     * <pre>
     *    if (cond) { then... }
     *    else...;
     *    exit;
     * ->
     *    if (!cond) { else...; exit }
     *    then...;
     * </pre>
     * The instructions after the if, up to and including the block's exit, become the new true insn.
     * They are extracted into a block unless the exit directly follows the if. The old true insn
     * (or its contents, when it is a block) is appended to the parent block. The negated condition
     * is simplified with {@link ExpressionTransforms}.
     *
     * @param ifInsn The IfInstruction; its parent must be a {@link Block}.
     * @param ctx    The transform context.
     * @throws NullPointerException If the true insn or the parent block does not end in an unreachable end point.
     */
    public static void invertIf(IfInstruction ifInsn, MethodTransformContext ctx) {
        Block block = (Block) ifInsn.getParent();
        ctx.pushStep("Invert if");

        assert ifInsn.getParentOrNull() == block;
        assert ifInsn.getFalseInsn() instanceof Nop;

        //assert then block terminates
        getExit(ifInsn.getTrueInsn());
        Instruction elseInsn = getExit(block);
        if (ifInsn.getNextSibling() != elseInsn) {
            elseInsn = block.extractRange(ifInsn.getNextSibling(), elseInsn);
        }

        Instruction thenInsn = ifInsn.getTrueInsn();
        if (thenInsn instanceof Block thenBlock) {
            block.instructions.addAll(thenBlock.instructions);
        } else {
            block.instructions.add(thenInsn);
        }

        ifInsn.setTrueInsn(elseInsn);
        ifInsn.setCondition(new LogicNot(ifInsn.getCondition()));
        ExpressionTransforms.runOnExpression(ifInsn.getCondition(), ctx);
        ctx.popStep();
    }

    /**
     * Wraps {@code insn} in a new {@link Block} if its parent is not already a block.
     * Both callers do this right before removing or moving an exit instruction.
     * NOTE(review): presumably so the slot the exit occupied (e.g. an if's true insn) is left holding a
     * block rather than nothing — confirm.
     *
     * @param insn The instruction to wrap.
     */
    private static void ensureParentIsBlock(Instruction insn) {
        if (insn.getParent() instanceof Block) return;

        Block b = new Block();
        b.instructions.add(insn);
        insn.replaceWith(b);
    }

    /**
     * Determine if the specified instruction necessarily exits (END_POINT_UNREACHABLE)
     * and if so, return last (or single) exit instruction.
     * <p>
     * For a {@link Block}, its last child is checked; an empty block yields {@code null}.
     *
     * @param insn The instruction to check.
     * @return The exit instruction, or <code>null</code>.
     */
    @Nullable
    private static Instruction tryGetExit(@Nullable Instruction insn) {
        Instruction exitInsn = insn;
        if (insn instanceof Block) {
            exitInsn = insn.getLastChildOrNull();
        }
        if (exitInsn != null && exitInsn.hasFlag(InstructionFlag.END_POINT_UNREACHABLE)) {
            return exitInsn;
        }
        return null;
    }

    /**
     * Gets the final instruction from a block (or a single instruction) assuming that all blocks
     * or instructions in this position have unreachable endpoints.
     *
     * @param insn The insn or block to check.
     * @return The exit insn.
     * @throws NullPointerException If {@link #tryGetExit} finds no exit.
     */
    private static Instruction getExit(Instruction insn) {
        return requireNonNull(tryGetExit(insn));
    }
}
