package net.covers1624.coffeegrinder.bytecode.transform.transformers.statement;

import net.covers1624.coffeegrinder.bytecode.AccessFlag;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformer;
import net.covers1624.coffeegrinder.type.ClassType;
import net.covers1624.coffeegrinder.type.Field;
import net.covers1624.coffeegrinder.type.TypeResolver;
import net.covers1624.quack.collection.FastStream;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;

import java.util.Objects;

import static java.util.Objects.requireNonNull;
import static net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching.matchLeave;
import static net.covers1624.coffeegrinder.bytecode.matching.IfMatching.matchNopFalseIf;
import static net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching.matchInvokeDynamic;
import static net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching.matchNew;
import static net.covers1624.coffeegrinder.bytecode.matching.LdcMatching.matchLdcBoolean;
import static net.covers1624.coffeegrinder.bytecode.matching.LdcMatching.matchLdcInt;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchLoadLocal;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchStoreLocal;

/**
 * Rebuilds pattern-matching {@code switch} constructs from the lowered form javac emits, which drives
 * the switch through a {@code java/lang/runtime/SwitchBootstraps} {@code typeSwitch} or
 * {@code enumSwitch} invokedynamic.
 * <p>
 * The lowered shape matched here is:
 * <pre>
 * STORE (valueVar, ...)
 * STORE (startFromVar, LDC_NUMBER 0)
 * [WHILE (true) {]              // only present when a case has a guard
 *     SWITCH (INVOKE_DYNAMIC typeSwitch|enumSwitch(LOAD valueVar, LOAD startFromVar)) { ... }
 * [}]
 * </pre>
 * The transform rewrites the case labels as follows:
 * <ul>
 *     <li>The {@code -1} label becomes {@code null}.</li>
 *     <li>{@code enumSwitch} labels become enum constants.</li>
 *     <li>{@code typeSwitch} labels become the bootstrap constant, or a type pattern with its captured
 *     variable and optional guard.</li>
 * </ul>
 * It then removes the synthetic variables and loop, and switches directly on the original value expression.
 * <p>
 * The transform does nothing when {@code SwitchBootstraps} cannot be resolved.
 * <p>
 * Created by covers1624 on 8/27/25.
 */
public class SwitchOnType implements StatementTransformer {

    /** Resolved {@code java.lang.String}, used as the type of string case labels. */
    private final ClassType string;
    /** {@code java/lang/runtime/SwitchBootstraps}, or {@code null} if unresolvable (transform disabled). */
    private final @Nullable ClassType switchBootstraps;
    /** {@code java/lang/MatchException}, or {@code null} if unresolvable (default-throw cleanup skipped). */
    private final @Nullable ClassType matchException;

    public SwitchOnType(TypeResolver resolver) {
        string = resolver.resolveClass(String.class);
        switchBootstraps = resolver.tryResolveClassDecl("java/lang/runtime/SwitchBootstraps");
        matchException = resolver.tryResolveClassDecl("java/lang/MatchException");
    }

    /**
     * Attempts to match {@code statement} as the start of a lowered pattern switch and, on success,
     * rewrites it in place into a pattern-matching switch.
     *
     * @param statement The candidate first statement (the {@code STORE} of the switch value).
     * @param ctx       The transform context, used for step tracking and type resolution.
     */
    @Override
    public void transform(Instruction statement, StatementTransformContext ctx) {
        // Only process if we have SwitchBootstraps
        if (switchBootstraps == null) return;

        Store valueStore = matchStoreLocal(statement);
        if (valueStore == null) return;

        // Match:
        // STORE (value, ...)
        // STORE (startFrom, LDC_NUMBER 0)
        LocalVariable valueVar = valueStore.getVariable();

        Store startFromStore = matchStoreLocal(valueStore.getNextSiblingOrNull());
        if (startFromStore == null) return;
        LocalVariable startFromVar = startFromStore.getVariable();
        if (matchLdcInt(startFromStore.getValue(), 0) == null) return;

        // Pattern switches may be contained within a very specific loop construct
        // when guards are present.
        // TODO make sure these work inside a manual loop, they should because the variables will be in the way otherwise.
        Switch basicSwitch;
        WhileLoop loop = null;
        if (startFromStore.getNextSiblingOrNull() instanceof WhileLoop llp) {
            loop = llp;
            if (matchLdcBoolean(loop.getCondition(), true) == null) return;
            basicSwitch = matchPatternSwitchDecl(loop.getBody().getEntryPoint().getFirstChild(), valueVar, startFromVar);
        } else {
            Instruction next = startFromStore.getNextSiblingOrNull();
            basicSwitch = matchPatternSwitchDecl(next, valueVar, startFromVar);
        }
        if (basicSwitch == null) return;
        InvokeDynamic indy = (InvokeDynamic) basicSwitch.getValue();
        // For enumSwitch, the first descriptor argument is the enum type being switched on.
        ClassType indyEnumType = indy.bootstrapHandle.getName().equals("enumSwitch") ? (ClassType) indy.descriptorArgs[0] : null;

        ctx.pushStep("Produce pattern matching switch");
        for (SwitchTable.SwitchSection section : basicSwitch.getSwitchTable().sections) {
            for (Instruction value : section.values) {
                if (value instanceof Nop) {
                    ctx.pushStep("Try capture unconditional default pattern");
                    // Default cases may be an unconditional pattern (same type as the switch input type)
                    tryCapturePatternVariable(valueVar, value, section, null, ctx);
                    ctx.popStep();
                } else {
                    int idx = ((LdcNumber) value).getValue().intValue();
                    ctx.pushStep("Rewrite case " + idx);
                    rewriteSwitchCase(loop, indyEnumType, valueVar, startFromVar, section, indy, idx, value, ctx);
                    ctx.popStep();
                }
            }
        }

        // After rewriting, the only remaining uses of the synthetic variables are the indy's own arguments.
        assert valueVar.getLoadCount() == 1 && valueVar.getStoreCount() == 1;
        assert startFromVar.getLoadCount() == 1 && startFromVar.getStoreCount() == 1;
        if (loop != null) {
            BlockContainer body = loop.getBody();
            assert loop.getContinues().isEmpty(); // Shouldn't exist yet anyway, but sure.
            assert body.getEntryPoint().getIncomingEdgeCount() == 1;
            assert body.getFirstChild() == body.getLastChild();
            assert body.getLeaveCount() == 1;
            assert matchLeave(body.getEntryPoint().getLastChild(), body) != null;
            ctx.pushStep("Remove redundant loop");
            // Use first child here, because we may already be an expression.
            loop.replaceWith(body.getEntryPoint().getFirstChild());
            ctx.popStep();
        }

        // Switch directly on the original value expression and drop the synthetic stores.
        basicSwitch.getValue().replaceWith(valueStore.getValue());
        startFromStore.remove();
        valueStore.remove();

        cleanupRedundantMatchException(basicSwitch, ctx);

        ctx.popStep();
    }

    /**
     * Replaces a single integer case label with the source-level label it stands for.
     * <p>
     * Label {@code -1} becomes {@code null}. For {@code enumSwitch} the bootstrap argument is an enum
     * constant name. For {@code typeSwitch} it is a String, Integer, {@link Type} (type pattern) or
     * {@link Field} (enum constant).
     *
     * @param loop         The guard loop surrounding the switch, or {@code null} if none.
     * @param indyEnumType The enum type for {@code enumSwitch}, or {@code null} for {@code typeSwitch}.
     * @param idx          The integer label, used as an index into the bootstrap arguments.
     * @param value        The label instruction to replace.
     * @throws UnsupportedOperationException If the bootstrap argument is of an unhandled kind.
     */
    private void rewriteSwitchCase(
            @Nullable WhileLoop loop,
            @Nullable ClassType indyEnumType,
            LocalVariable valueVar, LocalVariable startFromVar,
            SwitchTable.SwitchSection section, InvokeDynamic indy,
            int idx, Instruction value,
            StatementTransformContext ctx) {
        // -1 is always 'null'
        if (idx == -1) {
            value.replaceWith(new LdcNull());
            return;
        }

        Object indyArg = indy.bootstrapArguments[idx];
        // If we have an indyEnumType (we are a SwitchBootstraps.enumSwitch), then every case is an enum switch.
        if (indyEnumType != null) {
            assert indyArg instanceof String;
            Field enumField = FastStream.of(indyEnumType.getFields())
                    .filter(e -> e.getAccessFlags().get(AccessFlag.ENUM) && e.getName().equals(indyArg))
                    .only();
            value.replaceWith(new FieldReference(enumField));
            return;
        }

        // Otherwise, SwitchBootstraps.typeSwitch only supports String, Integer, a Class, or an EnumDesc
        switch (indyArg) {
            case String s -> value.replaceWith(new LdcString(string, s));
            case Integer i -> value.replaceWith(new LdcNumber(i));
            case Type type -> {
                assert section.values.size() == 1; // Type switches can only have a single value per label.

                // Type's switches must always include a cast & captured var, and may include a guard condition
                transformTypeSwitchCase(loop, valueVar, startFromVar, value, section, idx, type, ctx);
            }
            case Field enumConstant -> value.replaceWith(new FieldReference(enumConstant));
            case null, default -> throw new UnsupportedOperationException("Unhandled type switch type: " + indyArg);
        }
    }

    /**
     * Tries to turn the first statement of a section's target block into a pattern variable on the label.
     * <p>
     * Matches a section whose body is a {@link Branch} to a block starting with
     * {@code STORE (local, [CAST indyArg](LOAD valueVar))}. On a match, {@code value} is replaced with a
     * {@link SwitchTable.SwitchPattern} capturing {@code local}, and the store is removed.
     *
     * @param indyArg The pattern type, or {@code null} for an unconditional (default) pattern, which has no cast.
     * @return The new pattern, or {@code null} if the section does not have this shape.
     */
    private @Nullable SwitchTable.SwitchPattern tryCapturePatternVariable(
            LocalVariable valueVar,
            Instruction value, SwitchTable.SwitchSection section,
            @Nullable Type indyArg,
            StatementTransformContext ctx) {

        // Match the following pattern with an optional cast.
        // STORE (LOCAL ..., CAST indyArg(LOAD valueVar))
        if (!(section.getBody() instanceof Branch switchBranch)) return null;
        Block switchBody = switchBranch.getTargetBlock();
        if (!(switchBody.getFirstChild() instanceof Store patternVarStore)) return null;

        // If we have an unconditional pattern (exact switch type) we won't have a cast.
        Instruction unwrappedStoreValue = patternVarStore.getValue();
        if (unwrappedStoreValue instanceof Cast cast) {
            assert indyArg != null : "Expected indyArg to exist for pattern variables with a cast.";
            assert cast.getType().equals(ctx.getTypeResolver().resolveClass(indyArg));
            unwrappedStoreValue = cast.getArgument();
        }
        if (matchLoadLocal(unwrappedStoreValue, valueVar) == null) return null;

        ctx.pushStep("Capture variable into switch section.");
        SwitchTable.SwitchPattern pattern = value.replaceWith(new SwitchTable.SwitchPattern((LocalReference) patternVarStore.getReference()));
        patternVarStore.remove();
        ctx.popStep();
        return pattern;
    }

    /**
     * Rewrites a type-pattern case.
     * <p>
     * First it captures the pattern variable, which is required. Then, if present, it folds the guard
     * check into the pattern as its condition.
     *
     * @throws IllegalStateException If no pattern variable could be captured for the case.
     */
    private void transformTypeSwitchCase(
            @Nullable WhileLoop loop,
            LocalVariable valueVar, LocalVariable startFromVar,
            Instruction value, SwitchTable.SwitchSection section,
            int idx, Type indyArg,
            StatementTransformContext ctx) {

        Branch switchBranch = (Branch) section.getBody();
        Block switchBody = switchBranch.getTargetBlock();

        var pattern = tryCapturePatternVariable(valueVar, value, section, indyArg, ctx);
        if (pattern == null) throw new IllegalStateException("Did not match a pattern variable for a type switch.");

        // Match guard patterns (a failed guard resumes the type switch after this case):
        // IF (NOT(something)) BLOCK {
        //     STORE (startFromVar, LDC_NUMBER idx + 1)
        //     BRANCH <guard loop body entry point>
        // }
        IfInstruction guardIf = matchNopFalseIf(switchBody.getFirstChild());
        if (guardIf == null) return;
        if (!(guardIf.getTrueInsn() instanceof Block guardBlock)) return;
        if (guardBlock.getIncomingEdgeCount() > 1) return;

        Store setStartFromStore = matchStoreLocal(guardBlock.getFirstChild(), startFromVar);
        if (setStartFromStore == null) return;
        assert matchLdcInt(setStartFromStore.getValue(), idx + 1) != null;
        assert loop != null;
        assert ((Branch) setStartFromStore.getNextSibling()).getTargetBlock() == loop.getBody().getEntryPoint();

        ctx.pushStep("Capture guard condition");
        // The IF runs when the guard fails, so the guard itself is the negation of its condition.
        pattern = pattern.replaceWith(new SwitchTable.SwitchPattern(pattern.getReference(), new LogicNot(guardIf.getCondition())));
        guardIf.remove();
        ExpressionTransforms.runOnExpression(pattern.getCondition(), ctx); // Must run after guardIf.remove(), as ast is in an invalid state and this pushes steps.
        ctx.popStep();
    }

    /**
     * Finds the lowered pattern {@link Switch} whose value is the bootstrap indy over {@code valueVar}
     * and {@code startFromVar}.
     * <p>
     * The switch is either {@code insn} itself, or a switch that has already been inlined somewhere
     * beneath {@code insn}.
     *
     * @param insn The instruction to inspect, may be {@code null}.
     * @return The matched switch, or {@code null} if none was found.
     */
    private @Nullable Switch matchPatternSwitchDecl(@Nullable Instruction insn, LocalVariable valueVar, LocalVariable startFromVar) {
        if (insn == null) return null;

        if (insn instanceof Switch swtch) {
            if (matchTypeSwitchIndy(swtch.getValue(), valueVar, startFromVar) == null) return null;
            return swtch;
        }

        // The switch may have already been turned into a switch expression and inlined somewhere, lets go hunting!
        // Walk from each reference of valueVar up through its LOAD to a candidate indy; require exactly one hit.
        Instruction indyCandidate = FastStream.of(valueVar.getReferences())
                .filter(e -> e.isDescendantOf(insn))
                .map(e -> matchLoadLocal(e.getParent(), valueVar))
                .filter(Objects::nonNull)
                .map(e -> matchTypeSwitchIndy(e.getParent(), valueVar, startFromVar))
                .filter(Objects::nonNull)
                .onlyOrDefault();
        if (indyCandidate == null) return null;

        return indyCandidate.getParent() instanceof Switch swtch ? swtch : null;
    }

    /**
     * Removes the compiler-generated default section of the rewritten switch.
     * <p>
     * The section must be the first section, have only a default ({@link Nop}) label, and branch to a
     * single-use block. That block must start with {@code THROW NEW MatchException(null, null)}.
     * Nothing is done if {@code MatchException} could not be resolved.
     */
    private void cleanupRedundantMatchException(Switch basicSwitch, StatementTransformContext ctx) {
        // A class library may provide SwitchBootstraps without MatchException; there is nothing to match then.
        if (matchException == null) return;

        SwitchTable.SwitchSection firstSection = basicSwitch.getSwitchTable().sections.firstOrDefault();

        if (firstSection == null) return; // No sections at all; unexpected, but nothing to clean up.
        if (!(firstSection.values.onlyOrDefault() instanceof Nop)) return;

        if (!(firstSection.getBody() instanceof Branch indirect)) return;

        Block block = indirect.getTargetBlock();
        if (block.getIncomingEdgeCount() != 1) return;

        if (!(block.getFirstChild() instanceof Throw thrw)) return;

        New nw = matchNew(thrw.getArgument(), matchException);
        if (nw == null) return;
        if (nw.getArguments().size() != 2) return;
        for (Instruction argument : nw.getArguments()) {
            if (!(argument instanceof LdcNull)) return;
        }
        ctx.pushStep("Remove default MatchException");
        firstSection.remove();
        block.remove();
        ctx.popStep();
    }

    /**
     * Matches a bootstrap indy whose two arguments are exactly {@code LOAD valueVar} and
     * {@code LOAD startFromVar}.
     *
     * @return The indy, or {@code null} if {@code insn} does not match.
     */
    private @Nullable InvokeDynamic matchTypeSwitchIndy(Instruction insn, LocalVariable valueVar, LocalVariable startFromVar) {
        var indy = matchTypeSwitchIndy(insn, requireNonNull(switchBootstraps));
        if (indy == null) return null;

        if (matchLoadLocal(indy.arguments.get(0), valueVar) == null) return null;
        if (matchLoadLocal(indy.arguments.get(1), startFromVar) == null) return null;
        return indy;
    }

    /**
     * Matches a {@code typeSwitch} or {@code enumSwitch} invokedynamic bootstrapped from
     * {@code switchBootstraps} that takes exactly two arguments (value and restart index).
     *
     * @param insn             The instruction to test, may be {@code null}.
     * @param switchBootstraps The resolved {@code SwitchBootstraps} class.
     * @return The matched indy, or {@code null}.
     */
    public static @Nullable InvokeDynamic matchTypeSwitchIndy(@Nullable Instruction insn, ClassType switchBootstraps) {
        InvokeDynamic indy = matchInvokeDynamic(insn, switchBootstraps, "typeSwitch");
        if (indy == null) {
            indy = matchInvokeDynamic(insn, switchBootstraps, "enumSwitch");
        }
        if (indy == null) return null;

        if (indy.arguments.size() != 2) return null;
        return indy;
    }
}
