package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.InstructionFlag;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformer;
import net.covers1624.coffeegrinder.bytecode.transform.transformers.statement.SwitchOnType;
import net.covers1624.coffeegrinder.type.ClassType;
import net.covers1624.coffeegrinder.type.TypeResolver;
import net.covers1624.coffeegrinder.util.None;
import net.covers1624.quack.collection.ColUtils;
import net.covers1624.quack.collection.FastStream;
import org.jetbrains.annotations.Nullable;

import java.util.List;
import java.util.Objects;
import java.util.function.Function;

import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchPushForPop;

/**
 * Method transformer that visits try-catch, while-loop and switch instructions and, for those
 * flagged {@link InstructionFlag#END_POINT_UNREACHABLE}, picks one branch target outside the
 * instruction as its exit: branches to that target from inside the instruction are rewritten
 * into {@link Leave}s, and a single {@link Branch} to the target is inserted after the instruction.
 * <p>
 * Additionally, all leaves of a catch handler body are funnelled through one trailing block, and
 * switches recognised by {@link SwitchOnType#matchTypeSwitchIndy} get their guard container leaves
 * redirected to the switch body.
 */
public class DetectExitPoints extends SimpleInsnVisitor<MethodTransformContext> implements MethodTransformer {

    /** The resolved {@code java/lang/runtime/SwitchBootstraps} class, or {@code null} when it could not be resolved. */
    private final @Nullable ClassType switchBootstraps;

    /**
     * @param resolver The resolver used to look up {@code java/lang/runtime/SwitchBootstraps}.
     */
    public DetectExitPoints(TypeResolver resolver) {
        this.switchBootstraps = resolver.tryResolveClassDecl("java/lang/runtime/SwitchBootstraps");
    }

    /**
     * Runs this visitor over the given method.
     *
     * @param function The method to process.
     * @param ctx      The transform context.
     */
    @Override
    public void transform(MethodDecl function, MethodTransformContext ctx) {
        function.accept(this, ctx);
    }

    /**
     * Visits children first, then detects an exit point for the try-catch itself and finally
     * merges the leaves of each handler body.
     */
    @Override
    public None visitTryCatch(TryCatch tryCatch, MethodTransformContext ctx) {
        super.visitTryCatch(tryCatch, ctx);
        detectExitPoints(tryCatch, ctx);
        for (TryCatch.TryCatchHandler catchHandler : tryCatch.handlers) {
            combineExitsAtEndOfCatchContainer(ctx, catchHandler);
        }
        return NONE;
    }

    /**
     * Appends a new block to the handler body, redirects every existing leave of the body to
     * that block via a {@link Branch}, and makes the new block the only place that leaves the body.
     * Does nothing when the body has no leaves.
     *
     * @param ctx     The transform context, used for step tracking.
     * @param handler The catch handler whose body is rewritten.
     */
    private void combineExitsAtEndOfCatchContainer(MethodTransformContext ctx, TryCatch.TryCatchHandler handler) {
        BlockContainer catchBody = handler.getBody();
        if (catchBody.getLeaveCount() != 0) {
            ctx.pushStep("Introduce leave block for catch");
            Block sharedExit = new Block();
            catchBody.blocks.add(sharedExit);
            // Iterate a snapshot, the leaves are being replaced as we go.
            catchBody.getLeaves().toList().forEach(existing -> existing.replaceWith(new Branch(sharedExit).withOffsets(existing)));
            sharedExit.instructions.add(new Leave(catchBody));
            ctx.popStep();
        }
    }

    /**
     * Visits children first, then detects an exit point for the loop.
     */
    @Override
    public None visitWhileLoop(WhileLoop whileLoop, MethodTransformContext ctx) {
        super.visitWhileLoop(whileLoop, ctx);
        detectExitPoints(whileLoop, ctx);
        return NONE;
    }

    /**
     * Visits children first, then canonicalizes a guarded pattern switch (if it is one) and
     * detects an exit point for the switch.
     */
    @Override
    public None visitSwitch(Switch switchInsn, MethodTransformContext ctx) {
        super.visitSwitch(switchInsn, ctx);
        canonicalizeGuardedPatternMatch(switchInsn, ctx);
        detectExitPoints(switchInsn, ctx);
        return NONE;
    }

    /**
     * For a switch that is flagged {@link InstructionFlag#END_POINT_UNREACHABLE}, matches as a type
     * switch, and whose nearest enclosing {@link BlockContainer} is a direct child of a
     * {@link WhileLoop}: when that container has leaves and all of them are inside the switch,
     * each one is replaced with a leave of the switch body, and a single leave of the container
     * is inserted after the switch.
     *
     * @param switchInsn The switch to inspect.
     * @param ctx        The transform context, used for step tracking.
     */
    private void canonicalizeGuardedPatternMatch(Switch switchInsn, MethodTransformContext ctx) {
        if (!switchInsn.hasFlag(InstructionFlag.END_POINT_UNREACHABLE) || matchTypeSwitch(switchInsn) == null) {
            return;
        }

        BlockContainer guardContainer = switchInsn.firstAncestorOfType(BlockContainer.class);
        boolean insideLoop = guardContainer.getParent() instanceof WhileLoop;
        if (!insideLoop || guardContainer.getLeaveCount() == 0) {
            return;
        }

        // Only applicable when every leave of the container originates from within the switch.
        boolean allLeavesInSwitch = guardContainer.getLeaves().allMatch(leave -> leave.isDescendantOf(switchInsn));
        if (!allLeavesInSwitch) {
            return;
        }

        ctx.pushStep("Canonicalize guarded pattern switch");
        // Iterate a snapshot, the leaves are being replaced as we go.
        for (Leave guardLeave : guardContainer.getLeaves().toList()) {
            guardLeave.replaceWith(new Leave(switchInsn.getBody()));
        }
        switchInsn.insertAfter(new Leave(guardContainer));
        ctx.popStep();
    }

    /**
     * For an instruction flagged {@link InstructionFlag#END_POINT_UNREACHABLE}, collects the
     * targets of all contained branches which lie outside the instruction, chooses one of them
     * and rewrites the branches to it into leaves.
     *
     * @param insn The instruction to find an exit for.
     * @param ctx  The transform context.
     */
    private void detectExitPoints(Instruction insn, MethodTransformContext ctx) {
        if (!insn.hasFlag(InstructionFlag.END_POINT_UNREACHABLE)) return;

        List<Block> outsideTargets = insn.descendantsOfType(Branch.class)
                .map(Branch::getTargetBlock)
                .filter(target -> !target.isDescendantOf(insn))
                .toLinkedList();

        if (!outsideTargets.isEmpty()) {
            assert insn.getBytecodeOffset() >= 0;
            replaceBranchesWithLeaves(selectBestExit(outsideTargets, insn), insn, ctx);
        }
    }

    /**
     * Chooses the exit block for {@code insn} out of the given candidates.
     *
     * @param exits The non-empty list of candidate blocks, one entry per branch.
     * @param insn  The instruction being exited.
     * @return The chosen block.
     */
    private Block selectBestExit(List<Block> exits, Instruction insn) {
        // prioritise the nearest exit to this container, located _after_ it (fallthrough/break)
        Block nearestFollowing = FastStream.of(exits)
                .filter(candidate -> candidate.getBytecodeOffset() > insn.getBytecodeOffset())
                .maxByOrDefault(candidate -> -candidate.getBytecodeOffset());
        if (nearestFollowing != null) {
            return nearestFollowing;
        }

        // followed by whichever exit occurs most often (cosmetic)
        return FastStream.of(exits)
                .groupBy(Function.identity())
                .maxBy(FastStream::count)
                .getKey();
    }

    /**
     * Replaces every branch to {@code target} located inside {@code insn} with a {@link Leave}
     * of the container returned by {@link #getClosestContainer}, then inserts one branch to
     * {@code target} directly after {@code insn}. Does nothing when no such branch exists.
     *
     * @param target The block chosen as the exit.
     * @param insn   The instruction whose inner branches are rewritten.
     * @param ctx    The transform context, used for step tracking.
     */
    private void replaceBranchesWithLeaves(Block target, Instruction insn, MethodTransformContext ctx) {
        List<Branch> innerBranches = target.getBranches()
                .filter(branch -> branch.isDescendantOf(insn))
                .toLinkedList();
        if (!innerBranches.isEmpty()) {
            ctx.pushStep("Rewrite branches to leaves");
            for (Branch branch : innerBranches) {
                BlockContainer exited = getClosestContainer(insn, branch);
                branch.replaceWith(new Leave(exited));
            }

            insn.insertAfter(new Branch(target));
            ctx.popStep();
        }
    }

    /**
     * Walks up from {@code insn} until reaching the direct child of {@code parent}, and returns
     * the last {@link BlockContainer} encountered on the way, i.e. the one closest to
     * {@code parent}. The starting instruction itself is not considered.
     *
     * @param parent The ancestor to stop at.
     * @param insn   A descendant of {@code parent}.
     * @return The container found.
     * @throws NullPointerException If no container lies between {@code insn} and {@code parent}.
     */
    private BlockContainer getClosestContainer(Instruction parent, Instruction insn) {
        assert insn.isDescendantOf(parent);
        BlockContainer outermost = null;
        Instruction cursor = insn;
        while (cursor.getParent() != parent) {
            cursor = cursor.getParent();
            if (cursor instanceof BlockContainer) {
                outermost = (BlockContainer) cursor;
            }
        }

        return Objects.requireNonNull(outermost);
    }

    /**
     * Attempts to match the value of the given switch as a type switch invoke-dynamic.
     *
     * @param swtch The switch to inspect.
     * @return The matched {@link InvokeDynamic}, or {@code null} when {@code SwitchBootstraps}
     * is unavailable. Otherwise whatever {@link SwitchOnType#matchTypeSwitchIndy} returns.
     */
    private @Nullable InvokeDynamic matchTypeSwitch(Switch swtch) {
        return switchBootstraps == null
                ? null
                : SwitchOnType.matchTypeSwitchIndy(matchPushForPop(swtch.getValue()), switchBootstraps);
    }
}
