package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.InstructionFlag;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.Comparison.ComparisonKind;
import net.covers1624.coffeegrinder.bytecode.matching.*;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformer;
import net.covers1624.coffeegrinder.bytecode.transform.transformers.statement.ExpressionTransforms;
import net.covers1624.coffeegrinder.bytecode.transform.transformers.statement.Inlining;
import net.covers1624.coffeegrinder.type.ClassType;
import net.covers1624.coffeegrinder.type.Method;
import net.covers1624.coffeegrinder.type.TypeResolver;
import net.covers1624.coffeegrinder.type.TypeSystem;
import net.covers1624.coffeegrinder.util.None;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;

import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import static java.util.Objects.requireNonNull;
import static net.covers1624.coffeegrinder.bytecode.insns.Continue.matchContinue;
import static net.covers1624.coffeegrinder.bytecode.matching.AssignmentMatching.matchCompoundAssignment;
import static net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching.matchLeave;
import static net.covers1624.coffeegrinder.bytecode.matching.ComparisonMatching.matchComparison;
import static net.covers1624.coffeegrinder.bytecode.matching.LdcMatching.matchLdcInt;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.*;

/**
 * Raises basic {@link WhileLoop} structures into higher-level Java loop forms.
 * <p>
 * Loops are processed children-first. For each loop, branches to the body's entry point are first
 * replaced with {@link Continue} instructions. Then, in order of preference, the loop is rewritten as:
 * <ul>
 *     <li>a {@code while} loop with an extracted condition, which may further become a {@code for}
 *     loop, or failing that an enhanced {@code for} loop over an {@link Iterable};</li>
 *     <li>otherwise a {@code do-while} loop;</li>
 *     <li>otherwise a {@code for} loop keeping the loop's existing condition.</li>
 * </ul>
 * A produced {@code for} loop is additionally checked for the array enhanced-{@code for} pattern.
 * <p>
 * Created by covers1624 on 26/4/21.
 */
public class HighLevelLoops extends SimpleInsnVisitor<MethodTransformContext> implements MethodTransformer {

    private static final Type ITERATOR = Type.getType(Iterator.class);
    private static final Type ITERABLE = Type.getType(Iterable.class);

    private final ClassType iteratorClass;
    private final ClassType iterableClass;

    /**
     * @param typeResolver Used to resolve {@link Iterator} and {@link Iterable}, which are needed
     *                     for matching Iterator-based enhanced for loops.
     */
    public HighLevelLoops(TypeResolver typeResolver) {
        iteratorClass = typeResolver.resolveClass(ITERATOR);
        iterableClass = typeResolver.resolveClass(ITERABLE);
    }

    @Override
    public void transform(MethodDecl function, MethodTransformContext ctx) {
        function.accept(this, ctx);
    }

    /**
     * Visits nested instructions first, then attempts to raise {@code basicLoop} into a
     * higher-level loop form.
     */
    @Override
    public None visitWhileLoop(WhileLoop basicLoop, MethodTransformContext ctx) {
        super.visitWhileLoop(basicLoop, ctx);

        ctx.pushStep("Replace branches with continues");
        basicLoop.getBody().getEntryPoint().getBranches().toList().forEach(e -> e.replaceWith(new Continue(basicLoop).withOffsets(e)));
        ctx.popStep();

        if (transformWhileLoop(basicLoop, ctx)) {
            // A conditional while loop may be refined into a for loop, or failing that, an Iterator for-each loop.
            if (!transformForLoop(basicLoop, ctx)) {
                transformIteratorForEachLoop(basicLoop, ctx);
            }
        } else if (!transformDoWhileLoop(basicLoop, ctx)) {
            // Try to make a for loop with a true condition, this could reduce labelled breaks
            transformForLoop(basicLoop, ctx);
        }

        return NONE;
    }

    /**
     * Attempts to extract a loop condition from the loop's entry point.
     * <p>
     * Matches an entry point containing exactly {@code if (cond) { ... }} with no else branch,
     * followed by a {@code leave} of the loop body. The if is removed and its guarding condition
     * becomes the loop's condition.
     *
     * @return {@code true} if the condition was extracted.
     */
    private static boolean transformWhileLoop(WhileLoop loop, MethodTransformContext ctx) {
        BlockContainer body = loop.getBody();
        Block entryPoint = body.getEntryPoint();

        if (!(entryPoint.getFirstChild() instanceof IfInstruction ifInsn)) return false;
        if (!(ifInsn.getFalseInsn() instanceof Nop)) return false;

        // The loop-body is nested within the if
        // if (loop-condition) { loop-body }
        // leave loop-container
        if (entryPoint.instructions.size() != 2 || matchLeave(entryPoint.instructions.last(), body) == null) return false;

        // Make sure control flow falling off the end of the nested body still exits the loop body.
        if (!ifInsn.getTrueInsn().hasFlag(InstructionFlag.END_POINT_UNREACHABLE)) {
            ((Block) ifInsn.getTrueInsn()).instructions.add(new Leave(body));
        }

        ConditionDetection.invertIf(ifInsn, ctx);
        ctx.pushStep("Capture While condition");
        loop.setCondition(new LogicNot(ifInsn.getCondition()));
        ifInsn.remove();
        ExpressionTransforms.runOnExpression(loop.getCondition(), ctx);
        ctx.popStep();
        return true;
    }

    /**
     * Attempts to convert the loop into a {@code for} loop.
     * <p>
     * The body's last block must end with a {@code continue} of this loop, and no other continues
     * may exist. The increment is taken from the end of that last block (excluding the trailing
     * continue), chosen as follows:
     * <ul>
     *     <li>if the last block contains a {@link BlockContainer}, the increment is every
     *     instruction after the last such container;</li>
     *     <li>otherwise, the body must be a single block preceded by a local store, and the
     *     increment is the single simple statement before the continue, which must reference
     *     the stored variable.</li>
     * </ul>
     * The local store directly before the loop, if any, becomes the initializer. On success,
     * the array enhanced-for pattern is also attempted.
     *
     * @return {@code true} if a for loop was produced.
     */
    private static boolean transformForLoop(WhileLoop whileLoop, MethodTransformContext ctx) {
        // if the while loop has continues, then we can't change their target by inserting an increment block
        if (whileLoop.getContinues().count() > 1) return false;

        BlockContainer body = whileLoop.getBody();
        Block lastBlock = (Block) body.getLastChild();
        if (matchContinue(lastBlock.getLastChild(), whileLoop) == null) return false;

        // Try and match 'int i = 0' before the while loop.
        // todo, multi-store?
        Store store = matchStoreLocal(whileLoop.getPrevSiblingOrNull());

        // Determine the inclusive instruction range forming the increment, before mutating anything,
        // so the extraction below is recorded as part of the 'Produce for loop' step.
        Instruction incStart;
        Instruction incEnd;
        BlockContainer loopBodyEnd = findLabelledContainerFromEndOfBlock(lastBlock);
        if (loopBodyEnd != null) {
            if (loopBodyEnd.getNextSibling() == lastBlock.getLastChild()) {
                return false;
            }
            incStart = loopBodyEnd.getNextSibling();
            incEnd = requireNonNull(lastBlock.instructions.secondToLastOrDefault());
        } else {
            // don't necessarily need to make a for loop, but we will if there's one that looks nice
            if (store == null || body.blocks.size() > 1) return false;

            Instruction increment = lastBlock.instructions.secondToLastOrDefault();
            if (increment == null || !isSimpleStatement(increment)) return false;

            // TODO, this heuristic might be better explicitly only matching STORE/COMPOUND_ASSIGNMENT/POST_INCREMENT
            if (increment.descendantsMatching(e -> matchLocalRef(e, store.getVariable())).isEmpty()) return false;

            incStart = increment;
            incEnd = increment;
        }

        ctx.pushStep("Produce for loop");
        Block incBlock = lastBlock.extractRange("increment", incStart, incEnd);
        Instruction initializer = store != null ? store : new Nop();
        ForLoop forLoop = replaceLoop(whileLoop, new ForLoop(initializer, whileLoop.getCondition(), body, incBlock));

        transformArrayForEachLoop(forLoop, ctx);
        ctx.popStep();
        return true;
    }

    /**
     * Attempts to convert an Iterator-driven while loop into an enhanced for loop over an
     * {@link Iterable}.
     *
     * @return {@code true} if the loop was replaced with a {@link ForEachLoop}.
     */
    private boolean transformIteratorForEachLoop(WhileLoop whileLoop, MethodTransformContext ctx) {
        // Match the following case:
        // STORE iteratorVar (INVOKE.[INTERFACE/VIRTUAL] (LOAD iterable) java.lang.Iterable.iterator())
        // WHILE_LOOP (INVOKE.INTERFACE(LOAD iteratorVar) java.util.Iterator.hasNext()) {
        //     STORE iteratorStore (INVOKE.INTERFACE (LOAD iteratorVar) java.util.Iterator.next())
        //     ...
        // }
        // ->
        // FOR_EACH_LOOP (STORE iteratorStore (NOP) : LOAD iterable) {
        //     ...
        // }

        // First match the 'hasNext' condition of the Loop.
        Invoke hasNextInvoke = InvokeMatching.matchInvoke(whileLoop.getCondition(), Invoke.InvokeKind.INTERFACE, "hasNext", Type.getMethodType("()Z"));
        if (hasNextInvoke == null) return false;

        Load condLoad = matchLoadLocal(hasNextInvoke.getTarget());
        if (condLoad == null) return false;
        LocalVariable iteratorVar = condLoad.getVariable();
        // Variable must be synthetic, loaded twice, and stored to once.
        if (!iteratorVar.isSynthetic() || iteratorVar.getLoadCount() != 2 || iteratorVar.getStoreCount() != 1) return false;
        // Variable must be an instance of Iterator.
        if (!TypeSystem.isAssignableTo(iteratorVar.getType(), iteratorClass)) return false;

        // Next match the store directly before the Loop.
        Store iteratorStore = matchStoreLocal(whileLoop.getPrevSiblingOrNull(), iteratorVar);
        if (iteratorStore == null) return false;

        // The call to 'Iterable.iterator()' may be Virtual if the type returned by Iterable.iterator() is an implementation of Iterable.
        if (!(iteratorStore.getValue() instanceof Invoke iteratorInvoke)) return false;
        if (iteratorInvoke.getKind() != Invoke.InvokeKind.VIRTUAL && iteratorInvoke.getKind() != Invoke.InvokeKind.INTERFACE) return false;
        Method iteratorInvokeMethod = iteratorInvoke.getMethod();
        if (!iteratorInvokeMethod.getName().equals("iterator")) return false;
        if (!iteratorInvokeMethod.getParameters().isEmpty()) return false;
        // Return type must be an instance of Iterator.
        if (!TypeSystem.isAssignableTo(iteratorInvokeMethod.getReturnType(), iteratorClass)) return false;
        // The Invoke target must be an instance of Iterable.
        if (!TypeSystem.isAssignableTo(iteratorInvoke.getTarget().getResultType(), iterableClass)) return false;

        // Next match the Variable store inside the Loop.
        Store nextStore = matchStoreLocal(whileLoop.getBody().getEntryPoint().getFirstChildOrNull());
        if (nextStore == null) return false;
        Instruction nextStoreValue = nextStore.getValue();
        // We may have an unboxing to perform.
        if (BoxingMatching.matchUnboxing(nextStoreValue) instanceof Invoke unbox) {
            nextStoreValue = unbox.getTarget();
        }
        // The store value may be:
        // STORE iteratorStore (CAST java/lang/String) (INVOKE.INTERFACE (LOAD iteratorVar) java.util.Iterator.next())
        if (nextStoreValue instanceof Cast cast) {
            // Unwrap cast if present. TODO, this should go away with generics. Maybe.
            nextStoreValue = cast.getArgument();
        }
        Invoke nextInvoke = InvokeMatching.matchInvoke(nextStoreValue, Invoke.InvokeKind.INTERFACE, "next");
        if (nextInvoke == null || matchLoadLocal(nextInvoke.getTarget(), iteratorVar) == null) return false;

        ctx.pushStep("Produce foreach loop from Iterator");

        replaceLoop(whileLoop, new ForEachLoop((LocalReference) nextStore.getReference(), iteratorInvoke.getTarget(), whileLoop.getBody()));

        iteratorStore.remove();
        nextStore.remove();
        ctx.popStep();
        // The loop has been replaced; report success like the other transform methods.
        return true;
    }

    /**
     * Attempts to convert an index-based for loop over a synthetic array copy into an enhanced
     * for loop over the array.
     *
     * @return {@code true} if the loop was replaced with a {@link ForEachLoop}.
     */
    private static boolean transformArrayForEachLoop(ForLoop forLoop, MethodTransformContext ctx) {
        // Match the following case:
        // STORE synArray(LOAD array)
        // STORE synLen(ARRAY_LEN.int(LOAD synArray))
        // FOR_LOOP (STORE synIndex(LDC_INT 0); COMPARISON.LESS_THAN(LOAD synIndex, LOAD synLen); COMPOUND_ASSIGNMENT.ADD(synIndex, LDC_INT 1)) {
        //     STORE arrayLoad(ARRAY_LOAD (LOAD synArray, LOAD synIndex))
        //     ...
        // }
        // ->
        // FOR_EACH_LOOP (STORE arrayLoad(NOP) : LOAD array) {
        //     ...
        // }

        // Match loop initializer.
        Store synIndexStore = matchStoreLocal(forLoop.getInitializer());
        if (synIndexStore == null || matchLdcInt(synIndexStore.getValue(), 0) == null) return false;
        LocalVariable synIndex = synIndexStore.getVariable();
        if (!synIndex.isSynthetic() || synIndex.getLoadCount() != 3 || synIndex.getStoreCount() != 2) return false;

        // Match synthetic length store above loop.
        Store synLenStore = matchStoreLocal(forLoop.getPrevSiblingOrNull());
        if (synLenStore == null) return false;
        LocalVariable synLen = synLenStore.getVariable();
        if (!synLen.isSynthetic() || synLen.getLoadCount() != 1 || synLen.getStoreCount() != 1) return false;

        // Match synthetic array variable.
        Store synArrayStore = matchStoreLocal(synLenStore.getPrevSiblingOrNull());
        if (synArrayStore == null) return false;
        LocalVariable synArray = synArrayStore.getVariable();
        if (!synArray.isSynthetic() || synArray.getLoadCount() != 2 || synArray.getStoreCount() != 1) return false;
        if (matchArrayLenLoad(synLenStore.getValue(), synArray) == null) return false;

        // Match the comparison.
        if (matchComparison(forLoop.getCondition(), ComparisonKind.LESS_THAN, synIndex, synLen) == null) return false;

        // Match the increment.
        if (matchCompoundAssignment(forLoop.getIncrement().getFirstChild(), BinaryOp.ADD, 1) == null) return false;

        // Match the array load inside the loop.
        Store arrayStore = matchStoreLocal(forLoop.getBody().getEntryPoint().getFirstChildOrNull());
        if (arrayStore == null) return false;
        Instruction storeValue = arrayStore.getValue();
        // We may have an unboxing to perform.
        if (BoxingMatching.matchUnboxing(storeValue) instanceof Invoke unbox) {
            storeValue = unbox.getTarget();
        }
        if (!isArrayElemLoad(storeValue, synArray, synIndex)) return false;

        ctx.pushStep("Produce foreach loop from array for-i");

        replaceLoop(forLoop, new ForEachLoop((LocalReference) arrayStore.getReference(), synArrayStore.getValue(), forLoop.getBody()));

        synLenStore.remove();
        synArrayStore.remove();
        arrayStore.remove();

        ctx.popStep();
        return true;
    }

    /**
     * @return {@code true} if {@code insn} loads the element of local {@code arrayVar} at index
     * local {@code indexVar}.
     */
    private static boolean isArrayElemLoad(Instruction insn, LocalVariable arrayVar, LocalVariable indexVar) {
        ArrayElementReference elemRef = matchLoadElemRef(insn);
        if (elemRef == null) return false;

        Load array = matchLoadLocal(elemRef.getArray(), arrayVar);
        if (array == null) return false;

        Load index = matchLoadLocal(elemRef.getIndex(), indexVar);
        return index != null;
    }

    /**
     * Attempts to convert the loop into a {@code do-while} loop.
     * <p>
     * The body's last block must end with {@code if (cond) leave body; continue loop;}, and no
     * other continues may exist. The if is removed and {@code !cond} becomes the loop condition.
     *
     * @return {@code true} if a do-while loop was produced.
     */
    private static boolean transformDoWhileLoop(WhileLoop basicLoop, MethodTransformContext ctx) {
        // if the while loop has continues, then we can't change their target by inserting an increment block
        if (basicLoop.getContinues().count() > 1) return false;

        BlockContainer body = basicLoop.getBody();
        Block lastBlock = (Block) body.getLastChild();

        // Match:
        // IF (cond) LEAVE body
        // CONTINUE basicLoop
        if (matchContinue(lastBlock.getLastChild(), basicLoop) == null) return false;
        IfInstruction ifInsn = IfMatching.matchNopFalseIf(lastBlock.instructions.secondToLastOrDefault());
        if (ifInsn == null) return false;
        if (matchLeave(ifInsn.getTrueInsn(), body) == null) return false;

        ctx.pushStep("Produce do-while loop");

        BlockContainer loopBodyEnd = findLabelledContainerFromEndOfBlock(lastBlock);
        if (loopBodyEnd != null) {
            // Use a block container as extra pressure to inline into the condition. If we can inline, the container will be able to be unwrapped
            List<Runnable> extraTasks = new LinkedList<>();
            if (Inlining.matchWithPotentialInline(loopBodyEnd.getNextSibling(), extraTasks, ctx, e -> e == ifInsn ? ifInsn : null) != null) {
                extraTasks.forEach(Runnable::run);
            }
        }

        DoWhileLoop loop = replaceLoop(basicLoop, new DoWhileLoop(body, new LogicNot(ifInsn.getCondition())));
        ifInsn.remove();
        ExpressionTransforms.runOnExpression(loop.getCondition(), ctx);

        ctx.popStep();
        return true;
    }

    /**
     * Scans {@code inBlock} backwards from its last instruction.
     *
     * @return The last {@link BlockContainer} directly inside {@code inBlock}, or {@code null} if there is none.
     */
    @Nullable
    private static BlockContainer findLabelledContainerFromEndOfBlock(Block inBlock) {
        for (Instruction insn = inBlock.getLastChild(); insn != null; insn = insn.getPrevSiblingOrNull()) {
            if (insn instanceof BlockContainer container) return container;
        }
        return null;
    }

    /**
     * Replaces {@code oldLoop} with {@code newLoop} in the instruction tree and re-targets every
     * {@link Continue} that referenced {@code oldLoop} onto {@code newLoop}.
     *
     * @return {@code newLoop}, for chaining.
     */
    private static <T extends AbstractLoop> T replaceLoop(AbstractLoop oldLoop, T newLoop) {
        oldLoop.replaceWith(newLoop);

        for (Continue aContinue : oldLoop.getContinues().toList()) {
            aContinue.setLoop(newLoop);
        }

        assert oldLoop.getContinues().isEmpty();
        return newLoop;
    }

    /**
     * @return {@code true} if {@code insn} is a store, load, post-increment or compound assignment,
     * the statement kinds accepted as a {@code for} loop increment by {@link #transformForLoop}.
     */
    private static boolean isSimpleStatement(Instruction insn) {
        return switch (insn) {
            case Store ignored -> true;
            case Load ignored -> true;
            case PostIncrement ignored -> true;
            case CompoundAssignment ignored -> true;
            default -> false;
        };
    }
}
