/*
 * Decompiled with CFR 0.152.
 */
package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import java.util.Iterator;
import java.util.LinkedList;
import java.util.Objects;
import java.util.function.Function;
import net.covers1624.coffeegrinder.bytecode.InsnOpcode;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.InstructionFlag;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.Block;
import net.covers1624.coffeegrinder.bytecode.insns.BlockContainer;
import net.covers1624.coffeegrinder.bytecode.insns.Branch;
import net.covers1624.coffeegrinder.bytecode.insns.Leave;
import net.covers1624.coffeegrinder.bytecode.insns.MethodDecl;
import net.covers1624.coffeegrinder.bytecode.insns.Switch;
import net.covers1624.coffeegrinder.bytecode.insns.TryCatch;
import net.covers1624.coffeegrinder.bytecode.insns.TryFinally;
import net.covers1624.coffeegrinder.bytecode.insns.TryWithResources;
import net.covers1624.coffeegrinder.bytecode.insns.WhileLoop;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformer;
import net.covers1624.coffeegrinder.util.None;
import net.covers1624.quack.collection.FastStream;

/**
 * Method transformer that rewrites control flow leaving structured constructs.
 * <p>
 * For try/catch, try/finally, try-with-resources, while-loop and switch instructions flagged
 * {@link InstructionFlag#END_POINT_UNREACHABLE}, branches that jump from inside the construct to a
 * block outside of it are turned into {@link Leave} instructions, and a single {@link Branch} to the
 * chosen exit block is inserted directly after the construct.
 * <p>
 * Each visit method delegates to {@code super} before processing its own instruction, presumably
 * recursing into children first, so nested constructs are rewritten before enclosing ones.
 */
public class DetectExitPoints
extends SimpleInsnVisitor<MethodTransformContext>
implements MethodTransformer {
    @Override
    public void transform(MethodDecl function, MethodTransformContext ctx) {
        function.accept(this, ctx);
    }

    /**
     * Detects exit points of the try/catch itself, then funnels the leaves of each catch handler
     * body through a single trailing leave block.
     */
    @Override
    public None visitTryCatch(TryCatch tryCatch, MethodTransformContext ctx) {
        super.visitTryCatch(tryCatch, ctx);
        this.detectExitPoints(tryCatch, ctx);
        Iterator<TryCatch.TryCatchHandler> iterator = tryCatch.handlers.iterator();
        while (iterator.hasNext()) {
            this.combineExitsAtEndOfCatchContainer(ctx, iterator.next());
        }
        return NONE;
    }

    /**
     * If the handler's body container has any leaves, appends a new block containing a single
     * {@code Leave(body)} and redirects every existing leave to branch to that block instead.
     *
     * @param ctx     transform context used to record the step
     * @param handler the catch handler whose body is rewritten
     */
    private void combineExitsAtEndOfCatchContainer(MethodTransformContext ctx, TryCatch.TryCatchHandler handler) {
        BlockContainer body = handler.getBody();
        if (body.getLeaveCount() == 0) {
            return;
        }
        ctx.pushStep("Introduce leave block for catch");
        Block leaveBlock = new Block();
        body.blocks.add(leaveBlock);
        // Snapshot the leaves before replacing them; replaceWith mutates the tree being streamed.
        for (Leave leave : body.getLeaves().toList()) {
            leave.replaceWith(new Branch(leaveBlock).withOffsets(leave));
        }
        // Added after the loop so this new Leave is not itself redirected.
        leaveBlock.instructions.add(new Leave(body));
        ctx.popStep();
    }

    @Override
    public None visitTryFinally(TryFinally tryFinally, MethodTransformContext ctx) {
        super.visitTryFinally(tryFinally, ctx);
        this.detectExitPoints(tryFinally, ctx);
        return NONE;
    }

    @Override
    public None visitTryWithResources(TryWithResources tryWithResources, MethodTransformContext ctx) {
        super.visitTryWithResources(tryWithResources, ctx);
        this.detectExitPoints(tryWithResources, ctx);
        return NONE;
    }

    @Override
    public None visitWhileLoop(WhileLoop whileLoop, MethodTransformContext ctx) {
        super.visitWhileLoop(whileLoop, ctx);
        this.detectExitPoints(whileLoop, ctx);
        return NONE;
    }

    @Override
    public None visitSwitch(Switch switchInsn, MethodTransformContext ctx) {
        super.visitSwitch(switchInsn, ctx);
        this.detectExitPoints(switchInsn, ctx);
        return NONE;
    }

    /**
     * Picks the block that control should continue to after {@code insn}, and rewrites the
     * branches inside {@code insn} that target it.
     * <p>
     * Preference order:
     * <ol>
     *   <li>the exit block with the smallest bytecode offset greater than {@code insn}'s;</li>
     *   <li>otherwise, the exit block targeted by the most branches.</li>
     * </ol>
     * Does nothing unless {@code insn} is flagged {@link InstructionFlag#END_POINT_UNREACHABLE}
     * and contains at least one branch to a block outside of it.
     */
    private void detectExitPoints(Instruction insn, MethodTransformContext ctx) {
        if (!insn.hasFlag(InstructionFlag.END_POINT_UNREACHABLE)) {
            return;
        }
        // One entry per escaping branch, so duplicates record how often each target is used.
        LinkedList<Block> exits = insn.descendantsOfType(InsnOpcode.BRANCH)
                .map(Branch::getTargetBlock)
                .filter(b -> !b.isDescendantOf(insn))
                .toLinkedList();
        if (exits.isEmpty()) {
            return;
        }
        assert (insn.getBytecodeOffset() >= 0);
        // Maximizing the negated offset selects the nearest exit that follows insn.
        Block bestExit = FastStream.of(exits)
                .filter(b -> b.getBytecodeOffset() > insn.getBytecodeOffset())
                .maxByOrDefault(b -> -b.getBytecodeOffset());
        if (bestExit == null) {
            // No exit lies after insn: fall back to the most frequently targeted exit.
            bestExit = FastStream.of(exits)
                    .groupBy(Function.identity())
                    .maxBy(FastStream::count)
                    .getKey();
        }
        this.replaceBranchesWithLeaves(bestExit, insn, ctx);
    }

    /**
     * Replaces every branch to {@code target} located inside {@code insn} with a {@link Leave}
     * of the outermost {@link BlockContainer} below {@code insn} on that branch's ancestor path,
     * then inserts a single {@code Branch(target)} immediately after {@code insn}.
     *
     * @param target the exit block being branched to
     * @param insn   the construct whose exits are rewritten
     * @param ctx    transform context used to record the step
     */
    private void replaceBranchesWithLeaves(Block target, Instruction insn, MethodTransformContext ctx) {
        LinkedList<Branch> branches = target.getBranches()
                .filter(b -> b.isDescendantOf(insn))
                .toLinkedList();
        if (branches.isEmpty()) {
            return;
        }
        ctx.pushStep("Rewrite branches to leaves");
        for (Branch branch : branches) {
            branch.replaceWith(new Leave(this.getOutermostContainer(insn, branch)));
        }
        insn.insertAfter(new Branch(target));
        ctx.popStep();
    }

    /**
     * Walks from {@code insn} up towards {@code ancestor} and returns the last (that is, the
     * highest) {@link BlockContainer} seen strictly below {@code ancestor}. Leaving that
     * container exits the construct {@code ancestor}.
     *
     * @param ancestor an ancestor of {@code insn}
     * @param insn     the instruction to start from; itself is never considered
     * @return the outermost container between {@code insn} and {@code ancestor}
     * @throws NullPointerException if no {@link BlockContainer} lies between them
     */
    private BlockContainer getOutermostContainer(Instruction ancestor, Instruction insn) {
        assert (insn.isDescendantOf(ancestor));
        BlockContainer outermost = null;
        Instruction current = insn;
        while (current.getParent() != ancestor) {
            current = current.getParent();
            if (current instanceof BlockContainer) {
                outermost = (BlockContainer) current;
            }
        }
        return Objects.requireNonNull(outermost, "No BlockContainer between instruction and its ancestor");
    }
}

