package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import net.covers1624.coffeegrinder.bytecode.AccessFlag;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.TopLevelClassTransformer;
import net.covers1624.coffeegrinder.type.ClassType;
import net.covers1624.coffeegrinder.type.Field;
import net.covers1624.quack.util.JavaVersion;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;

import java.util.HashMap;
import java.util.Map;
import java.util.function.IntFunction;

import static java.util.Objects.requireNonNull;
import static net.covers1624.coffeegrinder.bytecode.insns.Invoke.InvokeKind;
import static net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching.matchInvoke;
import static net.covers1624.coffeegrinder.bytecode.matching.LdcMatching.matchLdcInt;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchLoadFieldRef;

/**
 * Rewrites compiled switches over enum values back into a switch on the enum itself.
 * <p>
 * Two compiled shapes are recognised:
 * <ul>
 *     <li>The switch-map shape: the switch value is an element of a synthetic {@code $SwitchMap$...}
 *     array field, indexed by {@code ordinal()} of the enum value. Case constants are mapped back to
 *     enum constants using the array stores found in the synthetic class's {@code <clinit>}. Once any
 *     switch has used the switch-map class, that class is removed from the root class.</li>
 *     <li>The raw-ordinal shape (class files at Java 21 or above only): the switch value is
 *     {@code ordinal()} invoked directly on an enum whose top-level class is the same as the root
 *     class's. Case constants are mapped to the enum's constants by position.</li>
 * </ul>
 * <p>
 * Created by covers1624 on 5/12/21.
 */
public class SwitchOnEnum implements TopLevelClassTransformer {

    @Override
    public void transform(ClassDecl cInsn, ClassTransformContext ctx) {
        SwitchState switchState = new SwitchState(cInsn);

        for (Switch aSwitch : cInsn.descendantsOfType(Switch.class)) {
            // Try the switch-map shape first; fall back to the raw-ordinal shape when it does not match.
            if (transformSwitchMapEnumSwitch(cInsn.getClazz(), aSwitch, switchState, ctx)) continue;

            transformRawEnumSwitch(cInsn.getClazz(), aSwitch, ctx);
        }

        // switchMapClass is only resolved (lazily, in SwitchState.getSwitchMap) when some switch used it.
        // NOTE(review): removal assumes every reference to the switch-map class was rewritten above.
        if (switchState.switchMapClass == null) return;

        ctx.pushStep("Remove switchmap inner class");
        switchState.switchMapClass.remove();
        ctx.popStep();
    }

    /**
     * Attempts to rewrite a switch that dispatches through a synthetic {@code $SwitchMap$} array.
     *
     * @param rootClass   The class being transformed; the switch-map class must be enclosed by it.
     * @param aSwitch     The switch to inspect.
     * @param switchState Lazily-parsed switch-map information for {@code rootClass}.
     * @param ctx         The transform context.
     * @return {@code true} if the switch matched the pattern and was rewritten, otherwise {@code false}.
     */
    private boolean transformSwitchMapEnumSwitch(ClassType rootClass, Switch aSwitch, SwitchState switchState, ClassTransformContext ctx) {
        // Match the following pattern:
        // SWITCH(ARRAY_LOAD(LOAD_FIELD the.current.RootClass$1.$SwitchMap$the$current$RootClass$EnumName(NOP), INVOKE.VIRTUAL(LOAD dir) java.lang.Enum.ordinal())) {
        //     CASE LDC_INT ...: {
        //     }
        //     ... more cases
        // }
        // ->
        // SWITCH(LOAD dir) {
        //     CASE LOAD_FIELD EnumName.ENUM_CONSTANT(NOP): {
        //     }
        //     ... more cases
        // }

        if (!(aSwitch.getValue() instanceof Load load)) return false;
        if (!(load.getReference() instanceof ArrayElementReference arrayElemRef)) return false;

        // Match invoke first.
        Invoke invoke = matchInvoke(arrayElemRef.getIndex(), InvokeKind.VIRTUAL, "ordinal", Type.getMethodType("()I"));
        if (invoke == null) return false;

        FieldReference fieldRef = matchLoadFieldRef(arrayElemRef.getArray());
        if (fieldRef == null) return false;
        Field field = fieldRef.getField();
        ClassType fieldDecl = field.getDeclaringClass();

        // The field's declaring class must be synthetic, and its enclosing class must be the rootClass:
        // SwitchState looks the switch-map class up among the rootClass's class members.
        if (!fieldDecl.isSynthetic()) return false;
        // filter (not map): map(...) only yields empty when there is no enclosing class at all,
        // letting switch maps enclosed by some other class slip through and later fail the lookup.
        if (fieldDecl.getEnclosingClass().filter(e -> e == rootClass).isEmpty()) return false;

        Int2ObjectMap<Field> switchMap = switchState.getSwitchMap(field);

        transformSwitchOnEnum(aSwitch, invoke, ctx, switchMap::get);

        return true;
    }

    /**
     * Replaces every integer case value of {@code aSwitch} with a reference to the enum constant
     * returned by {@code fieldLookup}, then replaces the switch value with the target of the
     * {@code ordinal()} invoke.
     *
     * @param aSwitch     The switch to rewrite.
     * @param invoke      The matched {@code ordinal()} invoke; its target becomes the new switch value.
     * @param ctx         The transform context.
     * @param fieldLookup Maps an integer case value to the enum constant field it represents.
     * @throws NullPointerException If a case value is not an integer constant, or has no mapped field.
     */
    private static void transformSwitchOnEnum(Switch aSwitch, Invoke invoke, ClassTransformContext ctx, IntFunction<Field> fieldLookup) {
        ctx.pushStep("Transform switch on enum");
        for (SwitchTable.SwitchSection section : aSwitch.getSwitchTable().sections) {
            for (Instruction value : section.values) {
                // Nop values carry no constant; presumably the default label's placeholder — TODO confirm.
                if (value instanceof Nop) continue;

                // requireNonNull instead of assert: with assertions disabled a missing mapping would
                // otherwise silently produce a FieldReference to null.
                LdcNumber ldc = requireNonNull(matchLdcInt(value), "Expected an LDC_INT case value in enum switch.");
                int caseValue = ldc.intValue();
                Field enumField = requireNonNull(fieldLookup.apply(caseValue), () -> "No enum constant mapped for case value " + caseValue);
                ldc.replaceWith(new FieldReference(enumField));
            }
        }
        aSwitch.getValue().replaceWith(invoke.getTarget());
        ctx.popStep();
    }

    /**
     * Attempts to rewrite a switch directly on {@code enum.ordinal()}, which Java 21+ javac emits for
     * enums in the same compilation unit instead of a switch map.
     *
     * @param rootClass The class being transformed.
     * @param swtch     The switch to inspect.
     * @param ctx       The transform context.
     * @return {@code true} if the switch matched and was rewritten, otherwise {@code false}.
     */
    private boolean transformRawEnumSwitch(ClassType rootClass, Switch swtch, ClassTransformContext ctx) {
        if (!ctx.classVersion.isAtLeast(JavaVersion.JAVA_21)) return false;

        // On Java 21+ switch-on-enum, for enums defined in the same compilation unit as a switch usage will
        // not have a switch map generated, instead it will use the raw enum ordinal output.
        Invoke invoke = matchInvoke(swtch.getValue(), InvokeKind.VIRTUAL, "ordinal", Type.getMethodType("()I"));
        if (invoke == null) return false;

        var enumClass = invoke.getTargetClassType();
        if (!enumClass.isEnum()) return false;

        // Ensure the enum class is contained inside the same compilation unit as the switch.
        if (!rootClass.getTopLevelClass().equals(enumClass.getTopLevelClass())) return false;

        // Key each static ENUM-flagged field by its position among such fields.
        // NOTE(review): assumes getFields() returns enum constants in declaration (ordinal) order — TODO confirm.
        Int2ObjectMap<Field> switchMap = new Int2ObjectArrayMap<>();
        for (Field field : enumClass.getFields()) {
            if (!field.isStatic()) continue;
            if (!field.getAccessFlags().get(AccessFlag.ENUM)) continue;

            switchMap.put(switchMap.size(), field);
        }

        transformSwitchOnEnum(swtch, invoke, ctx, switchMap::get);

        return true;
    }

    // State holder for the current classes switch map information.
    private static class SwitchState {

        private final ClassDecl rootClass;

        // The synthetic class holding the $SwitchMap$ fields; null until the first switch-map lookup.
        @Nullable
        private ClassDecl switchMapClass;
        // Switch-map field -> (case constant -> enum constant field); null until first lookup.
        @Nullable
        private Map<Field, Int2ObjectMap<Field>> switchMap;

        private SwitchState(ClassDecl rootClass) {
            this.rootClass = rootClass;
        }

        /**
         * Returns the case-constant to enum-constant mapping for the given switch-map field.
         * On first call, locates the field's declaring class among the root class's members and
         * parses its static initializer.
         *
         * @param switchMapField A {@code $SwitchMap$} array field.
         * @return The mapping for that field.
         * @throws NullPointerException If the field has no parsed mapping.
         */
        private Int2ObjectMap<Field> getSwitchMap(Field switchMapField) {
            if (switchMap == null) {
                switchMapClass = rootClass.getClassMembers()
                        .filter(e -> e.getClazz() == switchMapField.getDeclaringClass())
                        .only();
                switchMap = parseSwitchMap(switchMapClass);
            }
            return requireNonNull(switchMap.get(switchMapField), "No switch map parsed for field.");
        }

        /**
         * Collects every array-element store in the switch-map class's {@code <clinit>} into a
         * per-field mapping of case constant to enum constant field.
         */
        private Map<Field, Int2ObjectMap<Field>> parseSwitchMap(ClassDecl switchMapClass) {
            MethodDecl staticInit = switchMapClass.getMethod("<clinit>", Type.getMethodType("()V"));

            Map<Field, Int2ObjectMap<Field>> maps = new HashMap<>();

            staticInit.descendantsMatching(SwitchOnEnum::matchStoreElem).forEach(arrStore -> {
                // ARRAY_STORE (
                //     LOAD_FIELD RootClass$1.$SwitchMap$RootClass$EnumClass(NOP),               // Array
                //     INVOKE.VIRTUAL (LOAD_FIELD EnumClass.NAME(NOP)) java.lang.Enum.ordinal(), // Index
                //     LDC_INT <constant>                                                        // Value
                // )
                ArrayElementReference elemRef = (ArrayElementReference) arrStore.getReference();
                FieldReference arrayRef = requireNonNull(matchLoadFieldRef(elemRef.getArray()));
                Invoke invoke = requireNonNull(matchInvoke(elemRef.getIndex(), InvokeKind.VIRTUAL, "ordinal", Type.getMethodType("()I")));
                FieldReference enumFieldRef = requireNonNull(matchLoadFieldRef(invoke.getTarget()));
                LdcNumber constantLdc = requireNonNull(matchLdcInt(arrStore.getValue()));

                maps.computeIfAbsent(arrayRef.getField(), e -> new Int2ObjectArrayMap<>())
                        .put(constantLdc.intValue(), enumFieldRef.getField());
            });
            return maps;
        }
    }

    /**
     * Matches a {@link Store} whose reference is an array element.
     *
     * @return The store, or {@code null} if the instruction does not match.
     */
    @Nullable
    private static Store matchStoreElem(Instruction instruction) {
        if (!(instruction instanceof Store store) || !(store.getReference() instanceof ArrayElementReference)) return null;
        return store;
    }
}
