package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.InstructionFlag;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.Block;
import net.covers1624.coffeegrinder.bytecode.insns.Branch;
import net.covers1624.coffeegrinder.bytecode.insns.MethodDecl;
import net.covers1624.coffeegrinder.bytecode.insns.TryCatch;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformer;
import net.covers1624.coffeegrinder.util.None;
import org.jetbrains.annotations.Nullable;

import java.util.function.BiConsumer;

/**
 * Tidies try-catch structure with two passes, applied in this order by {@code transform}:
 * <ol>
 *     <li>{@code fixJavacFinallyExHandlerEntryInTryRangeBug}: finally handlers whose branch target lies inside the
 *     try range they protect get that target moved out, or the whole try-finally is unwrapped.</li>
 *     <li>{@code mergeSplitTries}: try-catches whose handlers branch to the same target block are folded into one,
 *     provided the absorbed one is only branched to from inside the first one's try body.</li>
 * </ol>
 * <p>
 * Created by covers1624 on 11/25/25.
 */
public class TryCatchMerging implements MethodTransformer {

    /**
     * Runs the finally-entry fix-up over the whole method, then the split-try merge.
     */
    @Override
    public void transform(MethodDecl function, MethodTransformContext ctx) {
        fixJavacFinallyExHandlerEntryInTryRangeBug(function, ctx);
        mergeSplitTries(function, ctx);
    }

    /**
     * Repairs try-finally regions whose exception handler branches to a block that lives inside the very try range
     * the handler protects (believed to be a javac quirk). Every handler in the method is inspected by
     * {@code hoistHandlerEntryOutOfTry}.
     * <pre>
     * BLOCK L0 (incoming: 3) {
     *     TRY_CATCH BLOCK_CONTAINER {
     *         BLOCK L_handler (incoming: 1) {
     *             STORE (LOCAL var, LOAD (LOCAL ex))
     *             ...maybe some monitor exit stuff
     *             BRANCH L_next
     *         }
     *     }
     *     CATCH (Throwable ex, Unprocessed Finally) BLOCK_CONTAINER {
     *         BLOCK SYN_L1 (incoming: 1) {
     *             BRANCH L_handler
     *         }
     *     }
     * }
     * </pre>
     */
    private static void fixJavacFinallyExHandlerEntryInTryRangeBug(MethodDecl function, MethodTransformContext ctx) {
        function.descendantsToList(TryCatch.TryCatchHandler.class)
                .forEach(handler -> hoistHandlerEntryOutOfTry(handler, ctx));
    }

    /**
     * Handles one handler for {@code fixJavacFinallyExHandlerEntryInTryRangeBug}. Does nothing unless the handler's
     * branch target is a descendant of its own try-catch. Otherwise:
     * <ul>
     *     <li>a single-block try body is unwrapped: its instructions are prepended to the enclosing block and the
     *     try-catch is removed;</li>
     *     <li>with more blocks, the target block is handed to {@code insertAfter} on the enclosing block
     *     (presumably placing it right after that block, outside the try range).</li>
     * </ul>
     * Assumes (checked only with {@code -ea}) that the try-catch has a single, unprocessed-finally handler and that
     * the target block cannot throw.
     */
    private static void hoistHandlerEntryOutOfTry(TryCatch.TryCatchHandler handler, MethodTransformContext ctx) {
        var owningTry = (TryCatch) handler.getParent();
        var outerBlock = (Block) owningTry.getParent();
        var handlerEntry = getHandlerTarget(handler);

        // Only handlers that jump back into their own protected range are of interest.
        if (!handlerEntry.isDescendantOf(owningTry)) return;
        assert owningTry.handlers.only().isUnprocessedFinally;
        assert !handlerEntry.hasFlag(InstructionFlag.MAY_THROW);

        var protectedBody = owningTry.getTryBody();
        if (protectedBody.blocks.size() != 1) {
            ctx.pushStep("Move finally handler entry block out of try range");
            outerBlock.insertAfter(handlerEntry);
            ctx.popStep();
            return;
        }

        ctx.pushStep("Unwrap redundant try-finally");
        outerBlock.instructions.addAllFirst(protectedBody.blocks.only().instructions);
        owningTry.remove();
        ctx.popStep();
    }

    /**
     * Walks the whole method and, for every try-catch encountered, absorbs matching split try-catches into it
     * via {@code mergeSplitTries(TryCatch, MethodTransformContext)} before descending into its children.
     */
    private static void mergeSplitTries(MethodDecl function, MethodTransformContext ctx) {
        function.accept(new SimpleInsnVisitor<>() {

            @Override
            public None visitDefault(Instruction insn, MethodTransformContext ctx) {
                // The next sibling is read only after the current child was visited, because merging can move
                // and remove blocks while the walk is in progress.
                for (Instruction child = insn.getFirstChildOrNull(); child != null; child = child.getNextSiblingOrNull()) {
                    assert child.getParent() == insn;
                    child.accept(this, ctx);
                }
                return NONE;
            }

            @Override
            public None visitTryCatch(TryCatch tryCatch, MethodTransformContext ctx) {
                // Merge first, so that descending afterwards (assuming super descends into children) also covers
                // the blocks that were just absorbed into this try body.
                mergeSplitTries(tryCatch, ctx);
                return super.visitTryCatch(tryCatch, ctx);
            }
        }, ctx);
    }

    /**
     * Finds every other try-catch whose handler body branches to the same block that {@code first}'s (single)
     * handler branches to, and moves each one's blocks into {@code first}'s try body.
     * <p>
     * A candidate is merged only when every branch into the block holding it originates inside {@code first}'s
     * try body. The first candidate that fails this check ends the whole merge for {@code first}.
     */
    private static void mergeSplitTries(TryCatch first, MethodTransformContext ctx) {
        var sharedTarget = getHandlerTarget(first.handlers.only());
        var splits = sharedTarget.getBranches()
                .map(TryCatchMerging::matchTryCatchHandlerBranch)
                .filterNonNull()
                .filter(candidate -> candidate != first)
                .toList();
        if (splits.isEmpty()) return;

        var firstBody = first.getTryBody();
        for (TryCatch split : splits) {
            var splitBlock = (Block) split.getParent();

            // Control flow is not analysed across containers; require all entries into splitBlock to come from
            // within the first try body.
            boolean reachedOnlyFromFirst = splitBlock.getBranches().allMatch(br -> br.isDescendantOf(firstBody));
            if (!reachedOnlyFromFirst) return;

            ctx.pushStep("Merge split try-catch " + splitBlock.getName() + " dominated by first " + ((Block) first.getParent()).getName());

            var splitBody = split.getTryBody();
            // Snapshot the branches of the split body now; they are revisited after the blocks have moved.
            var movedBranches = splitBody.descendantsOfType(Branch.class).toList();

            // splitBlock takes over the role (and keeps the label) of the split body's entry block.
            var splitEntry = splitBody.getEntryPoint();
            for (var incoming : splitEntry.getBranches().toList()) {
                incoming.setTargetBlock(splitBlock);
            }
            splitBlock.instructions.addAll(splitEntry.instructions);
            splitEntry.remove();

            // Relocate splitBlock and the rest of the split body into the first try body, then drop the shell.
            firstBody.blocks.add(splitBlock);
            firstBody.blocks.addAll(splitBody.blocks);
            split.remove();

            movedBranches.forEach(br -> retargetEnclosingBranch(br, splitBlock));
            ctx.popStep();
        }
    }

    /**
     * Fixes up a branch that, after the move, targets a block it is nested inside: it is redirected to the try
     * body entry point of the try-catch that is the first child of that block, repeatedly, until the target no
     * longer encloses the branch or is {@code splitBlock} itself. Assumes such an enclosing target always starts
     * with a {@link TryCatch} (the cast fails otherwise).
     */
    private static void retargetEnclosingBranch(Branch branch, Block splitBlock) {
        for (var dest = branch.getTargetBlock(); branch.isDescendantOf(dest) && dest != splitBlock; dest = branch.getTargetBlock()) {
            var enclosingTry = (TryCatch) dest.getFirstChild();
            branch.setTargetBlock(enclosingTry.getTryBody().getEntryPoint());
        }
    }

    /**
     * Returns the try-catch owning the handler whose body contains {@code branch}, when the branch's
     * great-grandparent is a {@link TryCatch.TryCatchHandler} (presumably: branch in a block that sits directly in
     * the handler's body container); {@code null} otherwise.
     */
    private static @Nullable TryCatch matchTryCatchHandlerBranch(Branch branch) {
        var ancestor = branch.getParent().getParent().getParent();
        return ancestor instanceof TryCatch.TryCatchHandler h ? (TryCatch) h.getParent() : null;
    }

    /**
     * Returns the block targeted by the branch that starts the handler body's entry block. Assumes that first
     * instruction is a {@link Branch}; the cast fails otherwise.
     */
    private static Block getHandlerTarget(TryCatch.TryCatchHandler handler) {
        var handlerEntryBlock = handler.getBody().getEntryPoint();
        var entryBranch = (Branch) handlerEntryBlock.getFirstChild();
        return entryBranch.getTargetBlock();
    }
}
