package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.AccessFlag;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.LocalVariable.VariableKind;
import net.covers1624.coffeegrinder.bytecode.insns.tags.RecordPatternComponentTag;
import net.covers1624.coffeegrinder.bytecode.insns.tags.SwitchRecordPatternTag;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformer;
import net.covers1624.coffeegrinder.type.*;
import net.covers1624.quack.collection.ColUtils;
import net.covers1624.quack.collection.FastStream;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;

import java.util.Objects;

import static java.util.Objects.requireNonNull;
import static net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching.matchInvokeDynamic;
import static net.covers1624.coffeegrinder.bytecode.matching.LdcMatching.matchLdcBoolean;
import static net.covers1624.coffeegrinder.bytecode.matching.LdcMatching.matchLdcInt;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.*;

/**
 * Recovers source-level switches from switches dispatched through
 * {@code java.lang.runtime.SwitchBootstraps} ({@code typeSwitch} and {@code enumSwitch}).
 * <p>
 * For each matching switch, the integer case values (indices into the bootstrap arguments)
 * are rewritten into {@code null}, string/integer constants, enum field references or
 * pattern variables. If the switch is wrapped in a {@code while (true)} retry loop that
 * re-enters the switch with a new index argument, that loop is unwrapped and the re-entry
 * branches are replaced with {@link Switch.SwitchGuard} instructions.
 * <p>
 * Does nothing if {@code SwitchBootstraps} cannot be resolved.
 * <p>
 * Created by covers1624 on 1/20/26.
 */
public class PrepareSwitchOnType implements MethodTransformer {

    /** The resolved {@code java.lang.String} type, used when emitting string case constants. */
    private final ClassType string;
    /** {@code java.lang.runtime.SwitchBootstraps}, or {@code null} if the resolver could not resolve it. */
    private final @Nullable ClassType c_switchBootstraps;

    /**
     * @param resolver The type resolver used to look up {@code String} and {@code SwitchBootstraps}.
     */
    public PrepareSwitchOnType(TypeResolver resolver) {
        string = resolver.resolveClass(String.class);
        c_switchBootstraps = resolver.tryResolveClassDecl("java/lang/runtime/SwitchBootstraps");
    }

    /**
     * Processes every {@link Switch} found in the method. Skipped entirely when
     * {@code SwitchBootstraps} is unavailable, since matching the bootstrap indy requires it.
     */
    @Override
    public void transform(MethodDecl function, MethodTransformContext ctx) {
        if (c_switchBootstraps == null) return;

        function.descendantsToList(Switch.class).forEach(e -> processSwitch(function, e, ctx));
    }

    /**
     * Rewrites a single switch if its value is the result of a {@code SwitchBootstraps.typeSwitch}
     * or {@code enumSwitch} invokedynamic; any other switch is left untouched.
     * <p>
     * Once the indy is matched, the surrounding shape (selector push, index push, indy push, switch)
     * is expected rather than matched: a mismatch fails via assertions, casts or {@code requireNonNull}.
     *
     * @param function The method containing the switch.
     * @param swtch    The switch to process.
     * @param ctx      The transform context.
     */
    private void processSwitch(MethodDecl function, Switch swtch, MethodTransformContext ctx) {
        assert c_switchBootstraps != null;
        // Match:
        // Object selPush = selVar
        // int idxPush = indexVar
        // int s_3 = INVOKE_DYNAMIC SwitchBootstraps.typeSwitch(.., selPush, idxPush)
        // [... s_swExpression = ] switch (s_3) {...}
        var indy = matchTypeSwitchIndy(matchPushForPop(swtch.getValue()), c_switchBootstraps);
        if (indy == null) return;

        ctx.pushStep("Process switch " + swtch.getBody().getEntryPoint().getName());
        // A switch expression is the value of a Store.
        boolean isExpression = swtch.getParent() instanceof Store;
        // Note: && binds tighter than ||, so this reads (isExpression && ...) || (...).
        assert isExpression && indy.getParent() == swtch.getParent().getPrevSibling()    // An expression, indy must be direct prior sibling of store.
               || indy.getParent() == swtch.getPrevSiblingOrNull();                      // Not an expression, indy must be direct prior sibling
        var indyPush = (Store) indy.getParent();

        // Second indy argument: the index, pushed from idxVar directly before the indy push.
        var idxLoad = requireNonNull(matchLoadLocal(matchPushForPop(indyPush.getPrevSibling(), indy.arguments.get(1))));
        var idxPush = (Store) idxLoad.getParent();
        var idxVar = idxLoad.getVariable();

        // First indy argument: the selector, pushed from selVar directly before the index push.
        var selLoad = requireNonNull(matchLoadLocal(matchPushForPop(idxPush.getPrevSibling(), indy.arguments.get(0))));
        var selPush = (Store) selLoad.getParent();
        var selVar = selLoad.getVariable();

        // The switch may be inside a loop for retries. If so, unwrap it and produce SwitchGuards.
        var outerLoop = matchOuterLoop(swtch, selPush);
        if (outerLoop != null) {
            tryUnwrapGuardLoop(outerLoop, swtch, isExpression, idxVar, ctx);
        }

        processSwitchCases(function, swtch, indy, selVar, ctx);
        ctx.popStep();
    }

    /**
     * Finds the {@code while (true)} retry loop wrapping a type switch, if there is one.
     * <p>
     * A loop is assumed when the selector push is the first instruction of its block and the switch
     * has no following sibling. The loop is then taken to be the grandparent of the selector push's
     * block. The remaining shape is asserted:
     * <ul>
     *     <li>the loop body has a single block;</li>
     *     <li>the loop is first in its block and followed by a {@link Branch};</li>
     *     <li>the loop condition is constant {@code true}.</li>
     * </ul>
     * The cast throws if that ancestor is not a {@link WhileLoop}.
     * <p>
     * NOTE(review): when the switch is an expression it is the value of a {@link Store}, so the
     * next-sibling check on {@code swtch} presumably never rejects it. Confirm this is intended.
     *
     * @param swtch   The switch.
     * @param selPush The store pushing the selector onto the stack for the indy.
     * @return The retry loop, or {@code null} if {@code selPush} has a previous sibling or
     *         {@code swtch} has a next sibling.
     */
    private @Nullable WhileLoop matchOuterLoop(Switch swtch, Store selPush) {
        // We expect the declarations for the selector and index variables to
        // directly precede the switch. If they don't, we have a guard loop.
        // Match:
        // L0: { selVar declared here, BR L1}
        // L1: {
        //     while (true) {
        //         L1_loop: {
        //             Object selPush = selVar;
        //             ....
        //         }
        //     }
        // }
        //
        if (selPush.getPrevSiblingOrNull() != null) return null; // Nothing to do.
        if (swtch.getNextSiblingOrNull() != null) return null;

        WhileLoop loop = (WhileLoop) selPush.getParent().getParent().getParent();
        assert loop.getBody().blocks.size() == 1;
        assert loop.getPrevSiblingOrNull() == null;
        assert loop.getNextSiblingOrNull() instanceof Branch;
        assert matchLdcBoolean(loop.getCondition(), true) != null;
        return loop;
    }

    /**
     * Removes the retry loop around a type switch, converting its re-entry points into
     * {@link Switch.SwitchGuard}s.
     * <p>
     * Every branch into the loop's entry block is expected (via assertions and {@code requireNonNull})
     * to come from a block of the form {@code int tmp = <const>; idxVar = tmp; BRANCH entry}. Each branch
     * targeting such a block is replaced with a {@code SwitchGuard} for {@code swtch}, and the block is
     * removed.
     * <p>
     * Loop leaves are then handled by switch kind:
     * <ul>
     *     <li>expression switch: the single leave, asserted to follow the switch's Store, is removed;</li>
     *     <li>statement switch: every leave is turned into a leave of the switch body.</li>
     * </ul>
     * Finally, the loop's single body block is spliced into the enclosing block in place of the loop.
     *
     * @param outerLoop    The retry loop returned by {@link #matchOuterLoop}.
     * @param swtch        The switch inside the loop.
     * @param isExpression Whether the switch is the value of a {@link Store}.
     * @param idxVar       The variable passed as the index argument to the bootstrap.
     * @param ctx          The transform context.
     */
    private void tryUnwrapGuardLoop(WhileLoop outerLoop, Switch swtch, boolean isExpression, LocalVariable idxVar, MethodTransformContext ctx) {
        var entry = outerLoop.getBody().getEntryPoint();

        ctx.pushStep("Unwrap switch on type guard loop");
        // All loop entry points are matched and asserted to be retry guards.
        for (Branch branch : outerLoop.getBody().getEntryPoint().getBranches().toList()) {
            // Match:
            // L3_1: {
            //     int idxTmp = 1
            //     idxVar = idxTmp
            //     BRANCH loopEntry
            // }
            var block = ((Block) branch.getParent());
            var idxStore = requireNonNull(matchStoreLocal(branch.getPrevSibling(), idxVar));
            var idxTmp = requireNonNull(matchLdcInt(matchPushForPop(idxStore.getPrevSibling(), idxStore.getValue())));
            assert idxTmp.getParent().getPrevSiblingOrNull() == null;

            // We need to remove the whole block. Otherwise, exit point inlining would
            // put the whole switch body inside every guard.

            block.getBranches().toList().forEach(e -> e.replaceWith(new Switch.SwitchGuard(swtch)));
            block.remove();
        }
        if (isExpression) {
            // If we are an expression, we expect a single leave from the loop, directly after the switch.
            var leave = outerLoop.getBody().getLeaves().only();
            assert swtch.getParent().getNextSibling() == leave;
            leave.remove();
        } else {
            // If we aren't an expression, all loop exits become switch exits.
            outerLoop.getBody().getLeaves().toList().forEach(leave -> leave.replaceWith(new Leave(swtch.getBody()).withOffsets(leave)));
        }

        // The loop has a single block. Unwrap it directly into the parent, where the loop was
        // (matchOuterLoop asserts the loop is the first instruction of its block).
        var loopBlock = ((Block) outerLoop.getParent());
        loopBlock.instructions.addAllFirst(entry.instructions);
        outerLoop.remove();
        ctx.popStep();
    }

    /**
     * Rewrites every case value of the switch table.
     * <p>
     * {@link Nop} values (the default case) may be turned into an unconditional pattern. All other
     * values are {@link LdcNumber} indices into the bootstrap arguments, rewritten by
     * {@link #rewriteSwitchCase}. For {@code enumSwitch}, String bootstrap arguments are first
     * resolved to fields of the enum type passed to the indy.
     */
    private void processSwitchCases(MethodDecl function, Switch swtch, InvokeDynamic indy, LocalVariable selVar, MethodTransformContext ctx) {
        var indyArgs = indy.bootstrapArguments;
        if (indy.bootstrapHandle.getName().equals("enumSwitch")) {
            indyArgs = resolveEnumSwitchArgs((ClassType) indy.descriptorArgs[0], indyArgs);
        }

        ctx.pushStep("Process switch cases");
        for (SwitchTable.SwitchSection section : swtch.getSwitchTable().sections) {
            for (Instruction value : section.values) {
                if (value instanceof Nop) {
                    ctx.pushStep("Try capture unconditional default pattern");
                    // Default cases may be an unconditional pattern (same type as the switch input type).
                    tryCapturePatternVariable(function, swtch, selVar, value, section, null, ctx);
                    ctx.popStep();
                } else {
                    int idx = ((LdcNumber) value).getValue().intValue();
                    ctx.pushStep("Rewrite case " + idx);
                    rewriteSwitchCase(function, swtch, selVar, section, indyArgs, idx, value, ctx);
                    ctx.popStep();
                }
            }
        }
        ctx.popStep();
    }

    /**
     * Tries to turn a case value into a pattern, capturing the variable the section body binds from the selector.
     * <p>
     * This only applies when the section body is a {@link Branch}. On success, one of two things happens:
     * <ul>
     *     <li>if {@link #insertSyntheticRecordPatternInstanceofRoot} produces a synthetic case variable,
     *         that variable is used and the matched stores are kept;</li>
     *     <li>otherwise, the matched store sequence is removed from the target block.</li>
     * </ul>
     *
     * @param indyArg The type the selector is expected to be cast to, or {@code null} for an
     *                unconditional default pattern.
     * @return {@code true} if {@code value} was replaced with a {@link SwitchTable.SwitchPattern}.
     */
    private boolean tryCapturePatternVariable(
            MethodDecl function, Switch swtch, LocalVariable selVar,
            Instruction value, SwitchTable.SwitchSection section,
            @Nullable ClassType indyArg,
            MethodTransformContext ctx
    ) {
        if (!(section.getBody() instanceof Branch switchBranch)) return false;
        Block sectionBody = switchBranch.getTargetBlock();

        // Match the following pattern with an optional cast.
        // STORE s_push1 = LOAD selVar
        // STORE s_push2 = CAST indyArg (LOAD s_push1) // Optional, excluded for unconditional patterns.
        // STORE var = LOAD s_push2

        if (!(matchPush(sectionBody.getFirstChild()) instanceof Store push1)) return false;
        if (!(matchLoadLocal(push1.getValue(), selVar) instanceof Load selVarLoad)) return false;
        var varStoreSource = push1;
        Cast cast = matchStoreCast(varStoreSource.getNextSiblingOrNull(), indyArg);
        if (cast != null) {
            varStoreSource = (Store) cast.getParent();
        }
        if (!(matchStoreLocalLoadLocal(varStoreSource.getNextSiblingOrNull(), varStoreSource.getVariable()) instanceof Store varStore)) return false;

        if (cast != null) {
            // Record-type casts may get a synthetic instanceof root instead. The existing stores are kept in that case.
            var patternVar = insertSyntheticRecordPatternInstanceofRoot(function, swtch, sectionBody, switchBranch, selVarLoad, cast, varStore, ctx);
            if (patternVar != null) {
                value.replaceWith(new SwitchTable.SwitchPattern(new LocalReference(patternVar)));
                return true;
            }
        }

        ctx.pushStep("Capture variable into switch section.");
        value.replaceWith(new SwitchTable.SwitchPattern((LocalReference) varStore.getReference()));
        push1.remove();
        if (varStoreSource != push1) varStoreSource.remove();
        varStore.remove();
        ctx.popStep();
        return true;
    }

    /**
     * For a case whose selector is cast to a record type, inserts a synthetic block that performs an
     * {@code instanceof} on a new tagged case variable, so that ConditionDetection can later recover a
     * record pattern.
     *
     * @return The synthetic case variable (tagged with {@link SwitchRecordPatternTag}) to use as the switch
     *         pattern, or {@code null} when any of the following holds:
     *         <ul>
     *             <li>the cast type is not a record;</li>
     *             <li>the pattern variable is not stored exactly once;</li>
     *             <li>none of its references feed a record pattern component.</li>
     *         </ul>
     */
    private @Nullable LocalVariable insertSyntheticRecordPatternInstanceofRoot(
            MethodDecl function, Switch swtch,
            Block sectionBody, Branch switchBranch,
            Load selVarLoad,
            Cast cast, Store varStore,
            MethodTransformContext ctx
    ) {

        // This works around some issues in ConditionDetection: it recovers record patterns far more easily
        // from an instanceof. We generate and insert the same shape ConditionDetection would see for a regular
        // instanceof record pattern. The variable shared between that instanceof and the switch pattern is
        // tagged with SwitchRecordPatternTag, so it can be picked up again after condition detection.

        // Only do this for records.
        var castedType = cast.getType();
        var synVar = varStore.getVariable();
        if (!(castedType instanceof ClassType recordType) || !recordType.isRecord()) return null;

        // Even if it's a record, it might not be a record pattern. To reduce redundant work, only do this
        // if the variable eventually gets used by a tagged record pattern deconstruction invoke.
        // Technically _all_ usages of the variable should end in a deconstruction, but it doesn't really matter.
        if (synVar.getStoreCount() != 1) return null;
        if (!ColUtils.anyMatch(synVar.getReferences(), this::isRecordPatternUsage)) return null;

        // Transform:
        // case value: BRANCH L2;
        // ...
        // L2 {
        //     Object s_push1 = LOAD selVar
        //     SomeType s_push2 = CAST indyArg (LOAD s_push1)
        // }
        // ->
        // case SomeType case$0: BRANCH L2_syn_record_pattern_instanceof;
        // ...
        // L2_syn_record_pattern_instanceof: {
        //     <selVar's type> s_0 = case$0;
        //     boolean s_1 = s_0 instanceof SomeType
        //     if (!s_1) __SwitchGuard;
        //     BRANCH L2;
        // }
        // L2 {
        //     Object s_push1 = LOAD case$0
        //     SomeType s_push2 = CAST indyArg (LOAD s_push1)
        // }
        ctx.pushStep("Produce synthetic instanceof for record pattern case");
        Block synBlock = new Block(sectionBody.getSubName("syn_record_pattern_instanceof"));
        sectionBody.insertBefore(synBlock);
        // Names are derived from the current variable count, so the add order below determines the names.
        var varTemp = new LocalVariable(VariableKind.STACK_SLOT, selVarLoad.getVariable().getType(), "s_" + function.variables.size());
        function.variables.add(varTemp);
        var boolStack = new LocalVariable(VariableKind.STACK_SLOT, PrimitiveType.BOOLEAN, "s_" + function.variables.size());
        function.variables.add(boolStack);

        var tempVar = new LocalVariable(VariableKind.TEMP_LOCAL, castedType, "case$" + function.variables.size());
        function.variables.add(tempVar);
        tempVar.setTag(new SwitchRecordPatternTag());

        var instanceOf = new InstanceOf(new Load(new LocalReference(varTemp)), castedType);
        synBlock.instructions.add(new Store(new LocalReference(varTemp), new Load(new LocalReference(tempVar))));
        synBlock.instructions.add(new Store(new LocalReference(boolStack), instanceOf));
        synBlock.instructions.add(new IfInstruction(new LogicNot(new Load(new LocalReference(boolStack))), new Switch.SwitchGuard(swtch)));
        synBlock.instructions.add(new Branch(sectionBody));
        // The case now enters through the synthetic block, and the original body reads the case variable instead of the selector.
        switchBranch.replaceWith(new Branch(synBlock));
        selVarLoad.getReference().replaceWith(new LocalReference(tempVar));
        ctx.popStep();
        return tempVar;
    }

    /**
     * Checks whether a variable reference is loaded and pushed to a stack variable whose pop's parent
     * instruction is tagged with {@link RecordPatternComponentTag}.
     * <p>
     * Returns {@code false} if no pop of the pushed stack variable is found.
     * NOTE(review): behavior with multiple pops depends on {@code FastStream.onlyOrDefault()}. Confirm.
     */
    private boolean isRecordPatternUsage(LocalReference ref) {
        if (!(matchLoadLocal(ref.getParent()) instanceof Load load)) return false;
        if (!(matchPush(load.getParent()) instanceof Store push)) return false;
        var pop = FastStream.of(push.getVariable().getReferences())
                .map(e -> matchPop(e.getParent()))
                .filter(Objects::nonNull)
                .onlyOrDefault();
        if (pop == null) return false;

        return pop.getParent().getTag() instanceof RecordPatternComponentTag;
    }

    /**
     * Rewrites a single indexed case value.
     * <p>
     * Index {@code -1} becomes {@code null}. Otherwise, the bootstrap argument at {@code idx} decides
     * the replacement:
     * <ul>
     *     <li>a {@code String} or {@code Integer} becomes a constant;</li>
     *     <li>a {@code Field} becomes a field reference;</li>
     *     <li>a {@code Type} requires a pattern variable to be captured.</li>
     * </ul>
     *
     * @throws IllegalStateException         If a type case has no capturable pattern variable.
     * @throws UnsupportedOperationException If the bootstrap argument is of an unhandled kind.
     */
    private void rewriteSwitchCase(
            MethodDecl function, Switch swtch, LocalVariable selVar,
            SwitchTable.SwitchSection section, Object[] indyArgs,
            int idx, Instruction value,
            MethodTransformContext ctx) {
        // -1 is always 'null'
        if (idx == -1) {
            value.replaceWith(new LdcNull());
            return;
        }

        switch (indyArgs[idx]) {
            case String s -> value.replaceWith(new LdcString(string, s));
            case Integer i -> value.replaceWith(new LdcNumber(i));
            case Type type -> {
                assert section.values.size() == 1; // Type switches can only have a single value per label.
                var cType = ctx.getTypeResolver().resolveClass(type);

                // Type is expected to always have a pattern variable that needs to be captured.
                if (!tryCapturePatternVariable(function, swtch, selVar, value, section, cType, ctx)) {
                    throw new IllegalStateException("Did not match a pattern variable for a type switch.");
                }
            }
            case Field enumConstant -> value.replaceWith(new FieldReference(enumConstant));
            case null, default -> throw new UnsupportedOperationException("Unhandled type switch type: " + indyArgs[idx]);
        }
    }

    /**
     * Matches an invokedynamic bootstrapped by {@code SwitchBootstraps.typeSwitch} or {@code enumSwitch}.
     *
     * @param insn             The instruction to match. May be {@code null}.
     * @param switchBootstraps The resolved {@code SwitchBootstraps} type.
     * @return The indy, or {@code null} if it doesn't match or doesn't take exactly two dynamic
     *         arguments (selector and index).
     */
    public static @Nullable InvokeDynamic matchTypeSwitchIndy(@Nullable Instruction insn, ClassType switchBootstraps) {
        InvokeDynamic indy = matchInvokeDynamic(insn, switchBootstraps, "typeSwitch");
        if (indy == null) {
            indy = matchInvokeDynamic(insn, switchBootstraps, "enumSwitch");
        }
        if (indy == null) return null;

        if (indy.arguments.size() != 2) return null;
        return indy;
    }

    /**
     * Returns a copy of {@code enumSwitch} bootstrap arguments in which every String (an enum constant name)
     * is replaced by the enum field of {@code indyEnumType} with that name. {@link Type} arguments are kept
     * as-is. The input array is not modified.
     *
     * @throws IllegalArgumentException If an argument is neither a String nor a Type.
     */
    private static Object[] resolveEnumSwitchArgs(ClassType indyEnumType, Object[] args) {
        args = args.clone();
        for (int i = 0; i < args.length; i++) {
            var arg = args[i];
            switch (arg) {
                // NOTE(review): only() presumably throws if there is not exactly one enum field with this name.
                case String enumName -> arg = FastStream.of(indyEnumType.getFields())
                        .filter(e -> e.getAccessFlags().get(AccessFlag.ENUM) && e.getName().equals(enumName))
                        .only();
                case Type ignored -> { }
                default -> throw new IllegalArgumentException("Unexpected type in SwitchBootstraps.enumSwitch indy arguments. " + arg);
            }
            args[i] = arg;
        }
        return args;
    }
}
