package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformer;
import net.covers1624.coffeegrinder.type.Method;
import net.covers1624.coffeegrinder.type.TypeSystem;
import net.covers1624.coffeegrinder.util.None;
import net.covers1624.quack.collection.FastStream;
import org.jetbrains.annotations.Nullable;

import java.util.Objects;

import static net.covers1624.coffeegrinder.bytecode.matching.LdcMatching.matchLdcInt;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchLoadLocal;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchStoreLocal;

/**
 * Created by covers1624 on 19/8/21.
 * <p>
 * Statement transformer that folds the two-switch pattern generated for a switch-on-string
 * back into a single switch whose case values are string constants.
 * <p>
 * The matched shape is:
 * <pre>
 * STORE (strVar, ...)
 * STORE (synVar, LDC_INT -1)
 * SWITCH(INVOKE.VIRTUAL (LOAD strVar) hashCode()) {
 *     // IF (strVar.equals(LDC_String "...")) synVar = LDC_INT n;
 * }
 * SWITCH(LOAD synVar) {
 *     // real case bodies, keyed by n
 * }
 * </pre>
 */
public class SwitchOnString extends SimpleInsnVisitor<StatementTransformContext> implements StatementTransformer {

    /**
     * Entry point of the transformer, dispatches the statement to this visitor.
     *
     * @param statement The statement to inspect.
     * @param ctx       The transform context.
     */
    @Override
    public void transform(Instruction statement, StatementTransformContext ctx) {
        statement.accept(this, ctx);
    }

    /**
     * Every store statement is a candidate for the leading {@code STORE (strVar, ...)} of the pattern.
     *
     * @param store The visited store.
     * @param ctx   The transform context.
     * @return Always {@code NONE}.
     */
    @Override
    public None visitStore(Store store, StatementTransformContext ctx) {
        tryProcessSwitchOnString(store, ctx);
        return NONE;
    }

    /**
     * Attempts to match the switch-on-string pattern starting at the given store and, on a match,
     * rewrites the second switch to switch directly on the string value with string case labels.
     * The synthetic hashCode switch and both synthetic stores are removed afterwards.
     * <p>
     * Returns without touching anything if the leading instructions do not match. Once the
     * leading instructions matched, the remaining structure is treated as an invariant and is
     * enforced with casts and assertions rather than soft checks.
     *
     * @param strVarStore The candidate store of the switched-on string into a synthetic local.
     * @param ctx         The transform context, used to record the transformation step.
     */
    private void tryProcessSwitchOnString(Store strVarStore, StatementTransformContext ctx) {
        // Match the switch pattern:
        // STORE (strVar, ...)
        // STORE (synVar, LDC_INT -1)
        // SWITCH(INVOKE.VIRTUAL (LOAD strVar) hashCode()) {
        //     ...
        // }
        // SWITCH(LOAD synVar) {
        //     ...
        // }

        if (!(strVarStore.getReference() instanceof LocalReference)) return;
        LocalVariable strVar = strVarStore.getVariable();
        if (!strVar.isSynthetic()) return;

        // Match
        // STORE (synVar, LDC_INT -1)
        Store synVarStore = matchStoreLocal(strVarStore.getNextSiblingOrNull());
        if (synVarStore == null) return;
        LocalVariable synVar = synVarStore.getVariable();
        if (!synVar.isSynthetic()) return;
        if (matchLdcInt(synVarStore.getValue(), -1) == null) return;

        if (!(synVarStore.getNextSiblingOrNull() instanceof Switch synSwitch)) return;

        Invoke hashCodeInvoke = matchInvokeStringHashcode(synSwitch.getValue());
        if (hashCodeInvoke == null) return;

        // The hashCode() call must operate on strVar. Without this check any String.hashCode() switch
        // following two synthetic stores would be treated as ours, and strVarStore would be deleted.
        if (!isLoadedWithin(strVar, hashCodeInvoke)) return;

        // During inlining a switch expression may have been created and the secondary switch may have been inlined somewhere.
        // We need to find it via the only load of the synVar (bucket index used on the second switch).
        Load realSwitchLoad = FastStream.of(synVar.getReferences())
                .map(e -> matchLoadLocal(e.getParent()))
                .filter(Objects::nonNull)
                .only();
        Switch realSwitch = (Switch) realSwitchLoad.getParent();
        // The switch _should_ be somewhere in the next instruction.
        assert realSwitch.isDescendantOf(synSwitch.getNextSibling());

        // Build synthetic id -> String lookup map.
        Int2ObjectMap<LdcString> stringLookup = new Int2ObjectOpenHashMap<>();

        for (SwitchTable.SwitchSection section : synSwitch.getSwitchTable().sections) {
            // The synthetic switch should have a single value per section.
            assert section.values.size() == 1;
            // A Nop value marks a section without a hash to match on, it carries no id assignments.
            if (section.values.first() instanceof Nop) continue;

            collectSyntheticIds(stringLookup, synVar, (Block) section.getBody());
        }

        // Nothing above mutated the tree, all modifications happen inside this step.
        ctx.pushStep("Transform switch-on-string");
        // Set the switch value to the strVarStore value.
        realSwitch.getValue().replaceWith(strVarStore.getValue());
        for (SwitchTable.SwitchSection section : realSwitch.getSwitchTable().sections) {
            for (Instruction value : section.values) {
                if (value instanceof Nop) continue;

                // Replace ldc with string.
                LdcNumber ldc = (LdcNumber) value;
                // Entries are removed as they are consumed so leftovers can be detected below.
                LdcString stringVal = stringLookup.remove(ldc.intValue());
                assert stringVal != null : "Unknown or duplicate synthetic switch id: " + ldc.intValue();
                ldc.replaceWith(stringVal);
            }
        }

        // All cases should have been handled.
        assert stringLookup.isEmpty();

        // Remove the synthetic switch.
        synSwitch.remove();

        synVarStore.remove();
        strVarStore.remove();
        ctx.popStep();
    }

    /**
     * Checks if any load of the given variable is located inside the given instruction.
     *
     * @param variable The variable to look for.
     * @param scope    The instruction to search within.
     * @return {@code true} if at least one load of the variable is a descendant of {@code scope}.
     */
    private static boolean isLoadedWithin(LocalVariable variable, Instruction scope) {
        for (Load load : FastStream.of(variable.getReferences())
                .map(e -> matchLoadLocal(e.getParent()))
                .filter(Objects::nonNull)) {
            if (load.isDescendantOf(scope)) return true;
        }
        return false;
    }

    /**
     * Collects the {@code synthetic id -> string constant} pairs assigned within one section
     * of the hashCode switch, following else branches for strings sharing the same hash.
     *
     * @param stringLookup The map to add the collected pairs to.
     * @param synVar       The synthetic id variable every matched store must target.
     * @param block        The section body (or else block) starting with the equals check.
     */
    private void collectSyntheticIds(Int2ObjectMap<LdcString> stringLookup, LocalVariable synVar, Block block) {
        // This is all synthetically generated by javac, we are safe to just raw cast and assert that way.
        // Match the following case:
        // IF (strVar.equals(LDC_String "Some string")) Block {
        //      synVar = LDC_INT 1212;
        // }
        IfInstruction ifInsn = (IfInstruction) block.instructions.first();
        Invoke invoke = (Invoke) ifInsn.getCondition();
        LdcString ldcString = (LdcString) invoke.getArguments().first();

        Store sectionStore = (Store) ((Block) ifInsn.getTrueInsn()).instructions.first();
        assert sectionStore.getVariable() == synVar;
        LdcString prev = stringLookup.put(((LdcNumber) sectionStore.getValue()).intValue(), ldcString);
        assert prev == null : "Duplicate synthetic switch id.";

        // Recurse into else blocks, javac will generate these in the case of a hash collision.
        if (!(ifInsn.getFalseInsn() instanceof Nop)) {
            collectSyntheticIds(stringLookup, synVar, (Block) ifInsn.getFalseInsn());
        }
    }

    /**
     * Matches a virtual invoke of a method named {@code hashCode} declared on String.
     *
     * @param value The instruction to match.
     * @return The matched invoke, or {@code null} if the instruction does not match.
     */
    @Nullable
    private static Invoke matchInvokeStringHashcode(Instruction value) {
        Invoke invoke = InvokeMatching.matchInvoke(value, Invoke.InvokeKind.VIRTUAL);
        if (invoke == null) return null;

        Method method = invoke.getMethod();
        if (!TypeSystem.isString(method.getDeclaringClass())) return null;
        if (!method.getName().equals("hashCode")) return null;
        return invoke;
    }
}
