package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.tags.RecordPatternComponentTag;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformer;
import net.covers1624.coffeegrinder.type.ClassType;
import net.covers1624.coffeegrinder.type.TypeResolver;
import org.jetbrains.annotations.Nullable;

import java.util.LinkedList;
import java.util.List;

import static net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching.matchBranch;
import static net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching.matchConstructorInvokeSpecial;
import static net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching.matchInvoke;
import static net.covers1624.coffeegrinder.bytecode.matching.LdcMatching.matchLdcBoolean;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.*;

/**
 * Created by covers1624 on 11/25/25.
 */
public class PrepareRecordPatterns implements MethodTransformer {

    private final @Nullable ClassType c_MatchException;

    public PrepareRecordPatterns(TypeResolver resolver) {
        c_MatchException = resolver.tryResolveClassDecl("java/lang/MatchException");
    }

    @Override
    public void transform(MethodDecl function, MethodTransformContext ctx) {
        if (c_MatchException == null) return;

        var matchExceptions = function.descendantsOfType(NewObject.class)
                .map(this::matchRecordPatternMatchExceptionThrow)
                .filterNonNull()
                .toList();
        matchExceptions.forEach(mE -> transformFromMatchException(mE, ctx));
    }

    private void transformFromMatchException(Block handlerBlock, MethodTransformContext ctx) {
        List<Runnable> actions = new LinkedList<>();
        for (Branch branch : handlerBlock.getBranches()) {
            if (!matchRecordPropertyRetrieval(branch, actions, ctx)) {
                return; // Not a valid pattern.
            }
        }

        ctx.pushStep("Process Record Pattern MatchException.");
        actions.forEach(Runnable::run);
        if (handlerBlock.isConnected()) {
            // MatchException throw block was shared between multiple deconstruction calls, clean it up.
            ctx.pushStep("Remove MatchException throw block");
            handlerBlock.remove();
            ctx.popStep();
        }
        ctx.popStep();
    }

    private boolean matchRecordPropertyRetrieval(Branch branch, List<Runnable> actions, MethodTransformContext ctx) {
        // Match the following pattern:
        // BLOCK L1 (incoming: 1) {
        //     TRY_CATCH BLOCK_CONTAINER { <-- Unwrap this try completely, replacing it with the store.
        //         BLOCK L1_try (incoming: 1) {
        //             STORE (LOCAL s_5, INVOKE.VIRTUAL (LOAD (LOCAL s_4)) value()) <-- The invoke we tag
        //             LEAVE L1_try
        //         }
        //     }
        //     CATCH (Throwable var_3) BLOCK_CONTAINER {
        //         BLOCK SYN_L0 (incoming: 1) {
        //             BRANCH L5 <-- our input branch
        //         }
        //     }
        //     BRANCH L2
        // }
        // BLOCK L2 (incoming: 1) {

        // Find our outer try.
        if (!(branch.getParentOrNull() instanceof Block inBlock)) return false;
        if (!(inBlock.getParentOrNull() instanceof BlockContainer inContainer)) return false;
        if (!(inContainer.getParentOrNull() instanceof TryCatch.TryCatchHandler handler)) return false;
        var tryCatch = handler.getTry();
        if (tryCatch.getNextSiblingOrNull() != null) return false;

        // Try body should have a single block
        var body = tryCatch.getTryBody();
        if (body.blocks.size() != 1) return false;

        // Ensure we immediately branch to the next block.
        if (!(body.getEntryPoint().getLastChild() instanceof Branch nextBr)) return false;
        Block tryInside = (Block) tryCatch.getParent();
        if (tryInside != nextBr.getTargetBlock().getPrevSiblingOrNull()) return false;

        // Our var store should be the first thing.
        var entry = body.getEntryPoint();
        if (!(matchStoreLocal(entry.getFirstChild()) instanceof Store valueStore)) return false;
        if (valueStore.getVariable().getKind() != LocalVariable.VariableKind.STACK_SLOT) return false;
        // Must be an invoke to a record class, with no arguments.
        if (!(matchInvoke(valueStore.getValue(), Invoke.InvokeKind.VIRTUAL) instanceof Invoke valueRetriever)) return false;
        if (!valueRetriever.getTargetClassType().isRecord()) return false;
        if (!valueRetriever.getArguments().isEmpty()) return false;
        if (valueStore.getNextSiblingOrNull() != nextBr) return false;

        actions.add(() -> {
            ctx.pushStep("Unwrap record pattern try-catch");
            tryInside.instructions.addAll(body.blocks.only().instructions);
            tryCatch.remove();
            valueRetriever.setTag(new RecordPatternComponentTag());
            tryReplaceExactConversionWithInstanceOf(nextBr.getTargetBlock(), ctx);
            ctx.popStep();
        });
        return true;
    }

    private @Nullable Block matchRecordPatternMatchExceptionThrow(NewObject newObj) {
        if (!newObj.getType().equals(c_MatchException)) return null;

        // Match the following pattern:
        // STORE (LOCAL s_9, NEW_OBJECT MatchException)
        //
        // STORE (LOCAL s_10, LOAD (LOCAL var_3))
        // STORE (LOCAL s_11, INVOKE.VIRTUAL (LOAD (LOCAL s_10)) toString())
        //
        // STORE (LOCAL s_12, LOAD (LOCAL var_3))
        // INVOKE.SPECIAL (LOAD (LOCAL s_9)) MatchException.<init>(LOAD (LOCAL s_11), LOAD (LOCAL s_12))
        //
        // STORE (LOCAL s_13, LOAD (LOCAL s_9))
        // THROW (LOAD (LOCAL s_13))
        if (!(matchStoreLocal(newObj.getParent()) instanceof Store objStore)) return null;
        if (!(objStore.getParent() instanceof Block handlerBlock)) return null;

        if (!(matchStoreLocal(objStore.getNextSiblingOrNull()) instanceof Store toStringTarget)) return null;
        if (!(matchLoadLocal(toStringTarget.getValue()) instanceof Load exLoad)) return null;
        if (!(matchStoreLocal(toStringTarget.getNextSiblingOrNull()) instanceof Store toStringResult)) return null;
        if (!(matchInvoke(toStringResult.getValue(), Invoke.InvokeKind.VIRTUAL, "toString") instanceof Invoke toStringCall) || !toStringCall.getArguments().isEmpty()) return null;

        if (!(matchStoreLocal(toStringResult.getNextSiblingOrNull()) instanceof Store ctorP2)) return null;
        if (matchLoadLocal(ctorP2.getValue(), exLoad.getVariable()) == null) return null;
        if (!(matchConstructorInvokeSpecial(ctorP2.getNextSiblingOrNull(), c_MatchException) instanceof Invoke ctorCall) || ctorCall.getArguments().size() != 2) return null;
        // Match ctor target and args
        if (matchLoadLocal(ctorCall.getTarget(), objStore.getVariable()) == null) return null;
        if (matchLoadLocal(ctorCall.getArguments().get(0), toStringResult.getVariable()) == null) return null;
        if (matchLoadLocal(ctorCall.getArguments().get(1), ctorP2.getVariable()) == null) return null;

        if (!(matchStoreLocal(ctorCall.getNextSiblingOrNull()) instanceof Store throwStore)) return null;
        if (matchLoadLocal(throwStore.getValue(), objStore.getVariable()) == null) return null;
        if (!(throwStore.getNextSiblingOrNull() instanceof Throw thrw)) return null;
        if (matchLoadLocal(thrw.getArgument(), throwStore.getVariable()) == null) return null;

        return handlerBlock;
    }

    private void tryReplaceExactConversionWithInstanceOf(Block block, MethodTransformContext ctx) {
        // BLOCK L2 (incoming: 1) {
        //     ...
        //     STORE (LOCAL var_2, LOAD (LOCAL s_6))
        //     STORE (LOCAL s_7, LDC_BOOLEAN true)
        //     IF (LOGIC_NOT (LOAD (LOCAL s_7))) BRANCH L4
        //     BRANCH L2_1
        // }
        // BLOCK L2_1 (incoming: 1) {
        //     STORE (LOCAL s_8, LOAD (LOCAL var_1))
        // ->
        // BLOCK L2 (incoming: 1) {
        //     ...
        //     STORE (LOCAL s_7, INSTANCE_OF int(LOAD (LOCAL s_6)))
        //     IF (LOGIC_NOT (LOAD (LOCAL s_7))) BRANCH L4
        //     BRANCH L2_1
        // }
        // BLOCK L2_1 (incoming: 1) {
        //     STORE (LOCAL s_8, CHECK_CAST int(LOAD (LOCAL var_1)))
        if (!(block.instructions.secondToLastOrDefault() instanceof IfInstruction ifInsn)) return;

        // TODO Not always if true, may be `ExactConversionSupport` call when those aren't preview
        if (!(ifInsn.getCondition() instanceof LogicNot lN)) return;
        if (!(matchLdcBoolean(matchPushForPop(lN.getArgument()), true) instanceof LdcBoolean ldcTrue)) return;
        if (!(ifInsn.getPrevSibling().getPrevSiblingOrNull() instanceof Store deadVar) || deadVar.getVariable().getLoadCount() != 0) return;
        if (!(matchLoadLocal(matchPushForPop(deadVar.getValue())) instanceof Load sharedLoad)) return;

        if (!(block.getNextSiblingOrNull() instanceof Block next)) return;
        if (matchBranch(block.getLastChild(), next) == null) return;

        if (!(next.getFirstChild() instanceof Store store) || matchLoadLocal(store.getValue(), sharedLoad.getVariable()) == null) return;

        ctx.pushStep("Replace exact conversion with instanceof");
        var type = deadVar.getResultType();
        ldcTrue.replaceWith(new InstanceOf(deadVar.getValue(), type));
        store.getValue().replaceWith(new Cast(store.getValue(), type));
        deadVar.remove();
        ctx.popStep();
    }
}
