package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.InstructionFlag;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.TopLevelClassTransformer;
import net.covers1624.coffeegrinder.debug.Step;
import net.covers1624.coffeegrinder.type.ClassType;
import net.covers1624.coffeegrinder.type.TypeResolver;
import net.covers1624.coffeegrinder.util.None;
import org.jetbrains.annotations.Nullable;

import static net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching.matchLeave;
import static net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching.matchNew;

/**
 * Post-processing cleanups for decompiled {@link Switch} instructions.
 * <p>
 * For every switch in a class, once its children have been visited, this transformer:
 * <ol>
 *     <li>removes a leading, default-only section whose body is {@code throw new MatchException(null, null)};</li>
 *     <li>removes case labels that share the default section, when they lie strictly inside the
 *     range of a switch that originated from a {@code TABLE_SWITCH};</li>
 *     <li>removes a trailing {@code default} section whose body only leaves the switch.</li>
 * </ol>
 * <p>
 * Created by covers1624 on 11/16/25.
 */
public class SwitchCleanup extends SimpleInsnVisitor<ClassTransformContext> implements TopLevelClassTransformer {

    /**
     * The {@code java/lang/MatchException} class, or {@code null} if it could not be resolved.
     * When {@code null}, MatchException stripping is skipped entirely.
     */
    private final @Nullable ClassType c_MatchException;

    public SwitchCleanup(TypeResolver resolver) {
        c_MatchException = resolver.tryResolveClassDecl("java/lang/MatchException");
    }

    /**
     * Entry point: visits the whole class, cleaning up every switch reachable from it.
     */
    @Override
    public void transform(ClassDecl cInsn, ClassTransformContext ctx) {
        cInsn.accept(this, ctx);
    }

    /**
     * Cleans up nested switches first (via {@code super}), then applies the cleanups to this switch.
     * <p>
     * Order matters: {@link #stripRedundantDefaultCases} may strip every extra label from the default
     * section, leaving it with only its default label, which is what {@link #removeRedundantDefault}
     * requires before it will consider removing the section.
     */
    @Override
    public None visitSwitch(Switch swtch, ClassTransformContext ctx) {
        super.visitSwitch(swtch, ctx);

        stripRedundantMatchExceptions(swtch, ctx);
        stripRedundantDefaultCases(swtch, ctx);
        removeRedundantDefault(swtch, ctx);

        return NONE;
    }

    /**
     * Wraps the visit of each method whose parent is a {@link ClassDecl} in its own debug step.
     * Method declarations with any other parent are visited without opening a step.
     */
    @Override
    public None visitMethodDecl(MethodDecl methodDecl, ClassTransformContext ctx) {
        var doPush = methodDecl.getParent() instanceof ClassDecl;

        if (doPush) ctx.pushStep("Cleanup method " + methodDecl.getMethod().getName(), Step.StepContextType.METHOD);
        super.visitMethodDecl(methodDecl, ctx);
        if (doPush) ctx.popStep();

        return NONE;
    }

    /**
     * Removes the first section of the switch when it is the default section (its only label is
     * the default {@link Nop}) and its body is exactly {@code throw new MatchException(null, null)}.
     * <p>
     * Does nothing if {@code MatchException} could not be resolved.
     */
    private void stripRedundantMatchExceptions(Switch swtch, ClassTransformContext ctx) {
        if (c_MatchException == null) return;

        // TODO there may be cases where this is too aggressive. We need counter cases to detect explicit
        //      vs implicit.

        // Match the following pattern in the first section, which must be the default section:
        // THROW NEW MatchException(LDC_NULL, LDC_NULL)

        var firstSection = swtch.getSwitchTable().sections.firstOrDefault();
        if (firstSection == null || !(firstSection.values.onlyOrDefault() instanceof Nop)) return;

        if (!(firstSection.getBody() instanceof Throw thrw)) return;

        if (!(matchNew(thrw.getArgument(), c_MatchException) instanceof New nw)) return;

        // Both constructor arguments (message and cause) must be null literals.
        if (nw.getArguments().size() != 2) return;
        if (!nw.getArguments().allMatch(e -> e instanceof LdcNull)) return;

        ctx.pushStep("Remove redundant MatchException");
        firstSection.remove();
        ctx.popStep();
    }

    /**
     * Removes case labels that share the default section, but only for switches that came from a
     * {@code TABLE_SWITCH}, and only for labels strictly inside the table's min/max range.
     * <p>
     * Javac will conditionally emit {@code TABLE_SWITCH} instead of {@code LOOKUP_SWITCH} depending on
     * a cost function using the case count and the min and max values. Removing a label strictly
     * inside the range keeps both bounds intact. Removing a boundary label would change the emitted
     * table, and could make Javac choose a {@code LOOKUP_SWITCH} instead.
     * <p>
     * A debug step is only opened when at least one label will actually be removed.
     */
    private static void stripRedundantDefaultCases(Switch swtch, ClassTransformContext ctx) {
        var table = swtch.getSwitchTable();
        // We can only do this processing on switches which came from a TableSwitch.
        var tI = table.tableInfo;
        if (tI == null) return;

        var defaultSection = table.sections
                .filter(e -> e.values.anyMatch(v -> v instanceof Nop))
                .onlyOrDefault();

        if (defaultSection == null) return;

        var min = tI.min();
        var max = tI.max();
        // Nothing to strip: don't emit an empty debug step.
        if (!defaultSection.values.anyMatch(v -> isStrictlyInsideTable(v, min, max))) return;

        ctx.pushStep("Remove redundant default table switch case labels.");
        // Iterate a snapshot, as removing an entry mutates the section's value list.
        for (Instruction e : defaultSection.values.toList()) {
            if (isStrictlyInsideTable(e, min, max)) {
                e.remove();
            }
        }
        ctx.popStep();
    }

    /**
     * Tests whether a case label is a number or char constant lying strictly between the table bounds.
     *
     * @param insn The case label instruction.
     * @param min  The table's minimum value.
     * @param max  The table's maximum value.
     * @return {@code true} if {@code insn} is an {@link LdcNumber} (by its int value) or {@link LdcChar}
     * with {@code min < value < max}.
     */
    private static boolean isStrictlyInsideTable(Instruction insn, int min, int max) {
        return switch (insn) {
            case LdcNumber num -> betweenE(min, num.intValue(), max);
            case LdcChar chr -> betweenE(min, chr.getValue(), max);
            default -> false;
        };
    }

    /**
     * Exclusive range check. The bounds themselves are excluded so that removing a label can never
     * change the table's min/max.
     *
     * @return {@code true} if {@code min < value < max}.
     */
    private static boolean betweenE(int min, int value, int max) {
        return min < value && value < max;
    }

    /**
     * Removes a trailing {@code default} section whose body only leaves the switch, i.e.
     * {@code default: break;}.
     * <p>
     * Pattern switches are left untouched. The default section must be the last section, and its
     * only label must be the default {@link Nop}. If the preceding section's end point is reachable,
     * the matched leave is inserted after that section's last instruction before the default section
     * is removed.
     */
    public void removeRedundantDefault(Switch swtch, ClassTransformContext ctx) {
        var body = swtch.getBody();
        var table = swtch.getSwitchTable();

        // No cleanup for pattern switches, none of the entries are redundant.
        if (table.sections.anyMatch(e -> e.values.anyMatch(v -> v instanceof SwitchTable.SwitchPattern))) return;

        var defaultSection = table.sections
                .filter(e -> e.values.size() == 1 && e.values.first() instanceof Nop)
                .onlyOrDefault();
        if (defaultSection == null) return;

        // The default section must be the final section, and must only leave the switch.
        if (defaultSection.getNextSiblingOrNull() != null) return;
        if (!(matchLeave(defaultSection.getBody(), body) instanceof Leave leave)) return;

        ctx.pushStep("Remove redundant default case");
        // If the previous section can fall through into the default, keep an explicit leave on its end.
        // NOTE(review): assumes prev's body has a last child (i.e. is non-empty) — confirm this always holds.
        if (defaultSection.getPrevSiblingOrNull() instanceof SwitchTable.SwitchSection prev
            && !prev.getBody().getFlags().get(InstructionFlag.END_POINT_UNREACHABLE)) {
            prev.getBody().getLastChild().insertAfter(leave);
        }
        defaultSection.remove();
        ctx.popStep();
    }
}
