package net.covers1624.coffeegrinder.bytecode.transform.transformers.statement;

import net.covers1624.coffeegrinder.bytecode.InsnOpcode;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.tags.PotentialConstantLookupTag;
import net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformer;
import net.covers1624.coffeegrinder.type.ClassType;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;

import static net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching.matchInvoke;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchLoadLocal;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchStoreLocal;

/**
 * Created by covers1624 on 17/9/21.
 */
public class GeneratedNullChecks implements StatementTransformer {

    @SuppressWarnings ("NotNullFieldNotInitialized")
    private StatementTransformContext ctx;

    @Override
    public void transform(Instruction statement, StatementTransformContext ctx) {
        this.ctx = ctx;

        Store deadStore = matchDeadStore(statement);
        if (deadStore == null) return;

        Instruction value = deadStore.getValue();
        Instruction nullCheckedValue = matchNullCheck(value);
        if (nullCheckedValue != null && removeSyntheticNullCheck(statement, nullCheckedValue)) return;

        tagPotentialConstantLookup(statement, value, nullCheckedValue);
    }

    @Nullable
    private Store matchDeadStore(Instruction statement) {
        Instruction prev = statement.getPrevSiblingOrNull();
        if (prev == null) return null;

        Store deadStore = matchStoreLocal(prev);
        if (deadStore == null) return null;

        LocalVariable var = deadStore.getVariable();
        if (!var.isSynthetic() || var.getReferenceCount() > 1) return null;
        return deadStore;
    }

    private boolean removeSyntheticNullCheck(Instruction statement, Instruction nullCheckedValue) {
        Load nullChecked = LoadStoreMatching.matchLoadLocal(nullCheckedValue);
        if (nullChecked == null) return false;

        LocalVariable var = nullChecked.getVariable();
        if (var.getKind() != LocalVariable.VariableKind.STACK_SLOT) return false;
        if (var.getLoadCount() < 2) return false; // Ensure the stack variable is used after the null check.

        ctx.pushStep("Remove synthetic null check");
        ctx.moveNext();
        statement.getPrevSibling().remove();
        ctx.popStep();

        return true;
    }

    private void tagPotentialConstantLookup(Instruction statement, Instruction deadStoreValue, @Nullable Instruction nullCheckedValue) {
        Store store = matchStoreLocal(statement);
        if (store == null) return;

        if (!(store.getValue() instanceof LdcInsn)) return;
        LdcInsn ldc = (LdcInsn) store.getValue();
        if (ldc.opcode == InsnOpcode.LDC_NULL || ldc.opcode == InsnOpcode.LDC_CLASS) return;

        boolean isStatic = nullCheckedValue == null;
        Instruction target = isStatic ? deadStoreValue : nullCheckedValue;
        if (!(target.getResultType() instanceof ClassType)) return;

        ctx.pushStep("Tag ldc: " + ldc.getRawValue());
        deadStoreValue.setTag(new PotentialConstantLookupTag(isStatic, ldc));
        ctx.popStep();
    }

    @Nullable
    @Contract ("null->null")
    public static Instruction matchNullCheck(@Nullable Instruction statement) {
        if (statement == null) return null;

        // Used by <= J8, 'expr.getClass()'
        Invoke invoke = matchInvoke(statement, Invoke.InvokeKind.VIRTUAL, "getClass", Type.getMethodType("()Ljava/lang/Class;"));
        if (invoke != null) {
            return invoke.getTarget();
        }

        // Used by >= J9 'Objects.requireNonNull(expr)'
        invoke = matchInvoke(statement, Invoke.InvokeKind.STATIC, "requireNonNull", Type.getMethodType("(Ljava/lang/Object;)Ljava/lang/Object;"));
        if (invoke != null) {
            return invoke.getArguments().first();
        }
        return null;

    }
}
