package net.covers1624.coffeegrinder.bytecode.transform;

import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.InstructionFlag;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.util.None;
import net.covers1624.quack.collection.ColUtils;
import net.covers1624.quack.collection.FastStream;
import org.jetbrains.annotations.Nullable;

import java.util.ArrayList;
import java.util.function.Function;

/**
 * Created by covers1624 on 30/11/22.
 */
/**
 * Cleans up control-flow exits left behind by earlier transforms.
 * <p>
 * This transformer:
 * <ul>
 *     <li>Inlines a single-block {@link BlockContainer} that is the second-to-last instruction of a block,
 *     when the instruction after it is (or flows to) a {@link Continue}. Its leaves are rewritten into
 *     continues.</li>
 *     <li>For a {@link Switch} whose body directly contains leaves to outer containers, tries to turn the
 *     switch exit into a continue, inline the trailing code into the last section, or otherwise picks the
 *     most frequently targeted outer container as the switch exit.</li>
 *     <li>Removes the instruction following a {@link TryCatch} whose end point is unreachable.</li>
 * </ul>
 */
public class ExitPointCleanup extends SimpleInsnVisitor<MethodTransformContext> implements MethodTransformer {

    @Override
    public void transform(MethodDecl function, MethodTransformContext ctx) {
        function.accept(this, ctx);
    }

    @Override
    public None visitBlock(Block block, MethodTransformContext ctx) {
        // Rewrite before descending, so children see the already-unwrapped structure.
        unwrapLeaveToContinueIn(block, ctx);
        return super.visitBlock(block, ctx);
    }

    @Override
    public None visitTryCatch(TryCatch tryCatch, MethodTransformContext ctx) {
        super.visitTryCatch(tryCatch, ctx);
        cleanupOrphanedExitAfterCatchFixups(tryCatch, ctx);
        return NONE;
    }

    @Override
    public None visitSwitch(Switch switchInsn, MethodTransformContext ctx) {
        pickBetterSwitchExit(switchInsn, ctx);
        super.visitSwitch(switchInsn, ctx);
        return NONE;
    }

    /**
     * Collects the leaves that sit directly in the switch body (not inside a nested container)
     * and target some container other than the switch body itself.
     *
     * @param switchInsn The switch to inspect.
     * @return A fresh list of such leaves; possibly empty.
     */
    private static ArrayList<Leave> findIndirectLeaves(Switch switchInsn) {
        return switchInsn.descendantsOfType(Leave.class)
                .filter(l -> l.getTargetContainer() != switchInsn.getBody())
                .filter(l -> l.firstAncestorOfType(BlockContainer.class) == switchInsn.getBody())
                .toList();
    }

    /**
     * Tries to turn leaves out of a switch body into plain switch exits.
     * <p>
     * While the switch end point is reachable, the simpler rewrites ({@link #unwrapLeaveToContinue(Switch, MethodTransformContext)}
     * and {@link #inlineLeaveToDefault}) are applied repeatedly. If neither applies, nothing else is done.
     * Once the end point is unreachable, the most frequently targeted outer container is chosen. Its
     * leaves inside the switch become leaves of the switch body, and the switch is followed by an exit
     * to that container, or the container is unwrapped when that is safe.
     */
    private void pickBetterSwitchExit(Switch switchInsn, MethodTransformContext ctx) {
        if (findIndirectLeaves(switchInsn).isEmpty()) return;

        while (!switchInsn.hasFlag(InstructionFlag.END_POINT_UNREACHABLE)) {
            // Recompute each iteration: inlineLeaveToDefault replaces leaves and moves
            // instructions into the switch, so an earlier snapshot would be stale.
            if (unwrapLeaveToContinue(switchInsn, ctx)
                    || inlineLeaveToDefault(switchInsn, ctx, findIndirectLeaves(switchInsn))) {
                continue;
            }

            return;
        }

        // The rewrites above may have removed or introduced indirect leaves.
        var indirectLeaves = findIndirectLeaves(switchInsn);
        if (indirectLeaves.isEmpty()) return;

        // Find the most frequent break target.
        var container = FastStream.of(indirectLeaves)
                .map(Leave::getTargetContainer)
                .groupBy(Function.identity())
                .maxBy(FastStream::count)
                .getKey();

        ctx.pushStep("Pick exit for switch");
        FastStream.of(container.getLeaves().toList())
                .filter(l -> l.isDescendantOf(switchInsn.getBody()))
                .forEach(l -> l.replaceWith(new Leave(switchInsn.getBody())));

        // The container can only be removed if nothing outside the switch still leaves it.
        boolean canUnwrap = switchInsn.getParent().getParent() == container
                && container.blocks.size() == 1
                && container.getParent() instanceof Block
                && container.getLeaves().toList().isEmpty();
        if (canUnwrap) {
            ctx.pushStep("Unwrap labelled break container");
            // Snapshot first: moving instructions while iterating their own collection is unsafe.
            var toMove = new ArrayList<Instruction>();
            container.getEntryPoint().instructions.forEach(toMove::add);
            // Chain the insertions so the instructions keep their original order.
            Instruction anchor = container;
            for (Instruction insn : toMove) {
                anchor.insertAfter(insn);
                anchor = insn;
            }
            container.remove();
            ctx.popStep();
        } else {
            switchInsn.insertAfter(new Leave(container));
        }
        ctx.popStep();
    }

    /**
     * Replaces the switch's single exit (located in its last section) with the instructions that follow
     * the switch. This only applies when the switch sits in a single-block container, inside a block,
     * that one of the indirect leaves targets. The enclosing container is then replaced by the switch,
     * with its leaves redirected to the switch body.
     *
     * @return {@code true} if the rewrite was applied.
     */
    private boolean inlineLeaveToDefault(Switch switchInsn, MethodTransformContext ctx, ArrayList<Leave> indirectLeaves) {
        if (!(switchInsn.getParent().getParent() instanceof BlockContainer outer)) return false;
        if (!(outer.getParent() instanceof Block)) return false;
        if (!ColUtils.anyMatch(indirectLeaves, l -> l.getTargetContainer() == outer)) return false;
        if (outer.blocks.size() != 1) return false;
        if (!(switchInsn.getBody().getLeaves().onlyOrDefault() instanceof Leave leave)) return false;
        assert switchInsn.getSwitchTable().isExhaustive();
        if (switchInsn.getSwitchTable().sections.last().getBody() != leave) return false;

        var block = (Block) switchInsn.getParent();
        ctx.pushStep("Inline switch exit case");
        leave.replaceWith(block.extractRange(switchInsn.getNextSibling(), block.getLastChild()));
        outer.getLeaves().toList().forEach(l -> l.replaceWith(new Leave(switchInsn.getBody())));
        outer.replaceWith(switchInsn);
        ctx.popStep();
        return true;
    }

    /**
     * If the instruction after the switch is (or flows to) a {@link Continue}, rewrites every leave of
     * the switch body into a continue of the same loop and removes that following instruction.
     *
     * @return {@code true} if the rewrite was applied.
     */
    private boolean unwrapLeaveToContinue(Switch switchInsn, MethodTransformContext ctx) {
        Instruction exit = switchInsn.getNextSiblingOrNull();
        if (exit == null) return false;
        Continue cont = matchContinueOrLeaveToContinue(exit);
        if (cont == null) return false;

        ctx.pushStep("Inline break to continue");
        for (Leave e : switchInsn.getBody().getLeaves().toList()) {
            e.replaceWith(new Continue(cont.getLoop()).withOffsets(e));
        }

        exit.remove();
        ctx.popStep();
        return true;
    }

    /**
     * Applies {@link #unwrapLeaveToContinue(BlockContainer, MethodTransformContext)} to the block's
     * second-to-last instruction if it is a {@link BlockContainer}.
     */
    private void unwrapLeaveToContinueIn(Block block, MethodTransformContext ctx) {
        Instruction insn = block.instructions.secondToLastOrDefault();
        if (!(insn instanceof BlockContainer container)) return;

        unwrapLeaveToContinue(container, ctx);
    }

    /**
     * For a single-block container whose next sibling is (or flows to) a {@link Continue}: rewrites the
     * container's leaves into continues, removes that sibling, and splices the container's instructions
     * into the parent block in place of the container.
     * <p>
     * The splice appends to the end of the parent block. This relies on the container being the
     * parent's last instruction once its next sibling is removed, which holds for the only caller
     * ({@link #unwrapLeaveToContinueIn}).
     */
    private void unwrapLeaveToContinue(BlockContainer container, MethodTransformContext ctx) {
        if (container.blocks.size() > 1) return;

        Block parent = (Block) container.getParent();
        Instruction exit = container.getNextSibling();
        Continue cont = matchContinueOrLeaveToContinue(exit);
        if (cont == null) return;

        ctx.pushStep("Unwrap container with break to continue");
        for (Leave e : container.getLeaves().toList()) {
            e.replaceWith(new Continue(cont.getLoop()).withOffsets(e));
        }

        exit.remove();

        Block entry = container.getEntryPoint();
        parent.instructions.addAll(entry.instructions);
        container.remove();

        ctx.popStep();
    }

    /**
     * @return {@code insn} itself if it is a {@link Continue}; the continue reached by following a
     * {@link Leave} (see {@link #followFlowToContinue}); otherwise {@code null}.
     */
    @Nullable
    private Continue matchContinueOrLeaveToContinue(Instruction insn) {
        if (insn instanceof Continue cont) return cont;
        if (insn instanceof Leave leave) return followFlowToContinue(leave);
        return null;
    }

    /**
     * Removes the instruction following a {@link TryCatch} in a block once the try-catch end point is
     * unreachable, since that instruction can no longer be reached.
     */
    private void cleanupOrphanedExitAfterCatchFixups(TryCatch tryCatch, MethodTransformContext ctx) {
        if (!(tryCatch.getParent() instanceof Block)) return;

        Instruction next = tryCatch.getNextSiblingOrNull();
        if (!tryCatch.hasFlag(InstructionFlag.END_POINT_UNREACHABLE) || next == null) return;
        assert next.hasFlag(InstructionFlag.END_POINT_UNREACHABLE);

        ctx.pushStep("Remove unreachable fallthrough");
        next.remove();
        ctx.popStep();
    }

    /**
     * Follows a leave out of a try-catch handler body to the instruction after the try-catch, recursing
     * through further such leaves.
     *
     * @return The {@link Continue} that control flows to, or {@code null} if none is found.
     */
    @Nullable
    private Continue followFlowToContinue(Leave leave) {
        if (!(leave.getTargetContainer().getParent() instanceof TryCatch.TryCatchHandler)) return null;

        // Tolerate a try-catch with no following instruction instead of failing on the lookup.
        Instruction next = leave.getTargetContainer().getParent().getParent().getNextSiblingOrNull();
        if (next instanceof Continue cont) return cont;
        if (next instanceof Leave leave1) return followFlowToContinue(leave1);
        return null;
    }
}
