package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.AccessFlag;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching;
import net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching;
import net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching;
import net.covers1624.coffeegrinder.bytecode.matching.LogicMatching;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformer;
import net.covers1624.coffeegrinder.bytecode.transform.transformers.statement.ExpressionTransforms;
import net.covers1624.coffeegrinder.type.*;
import net.covers1624.coffeegrinder.util.None;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;

import java.util.Objects;

import static java.util.Objects.requireNonNull;
import static net.covers1624.coffeegrinder.bytecode.matching.IfMatching.matchNopFalseIf;
import static net.covers1624.coffeegrinder.type.TypeResolver.CLASS_TYPE;
import static net.covers1624.quack.util.SneakyUtils.notPossible;
import static org.objectweb.asm.Type.getMethodType;

/**
 * Reconstructs {@code assert} statements from their lowered bytecode form.
 * <p>
 * The transform first looks for a {@code static final synthetic boolean} field whose initializer is
 * {@code !TopLevelClass.class.desiredAssertionStatus()}. If one is found, every
 * {@code if (!field && condition) throw new AssertionError([message])} in the class is rewritten into
 * an {@link Assert} instruction equivalent to {@code assert !condition : message;}. The field
 * declaration is then removed.
 * <p>
 * Created by covers1624 on 30/8/22.
 */
public class AssertTransform implements ClassTransformer {

    public static final Type ASSERTION_ERROR_TYPE = Type.getType(AssertionError.class);

    /** {@code desiredAssertionStatus()Z} resolved on {@link TypeResolver#CLASS_TYPE}; the synthetic field's initializer must call it. */
    private final Method m_desiredAssertionStatus;
    /** Resolved declaration of {@code AssertionError}, the type a lowered assert must throw. */
    private final ClassType assertionErrorType;

    public AssertTransform(TypeResolver resolver) {
        ClassType clazzType = resolver.resolveClassDecl(CLASS_TYPE);
        m_desiredAssertionStatus = requireNonNull(clazzType.resolveMethod("desiredAssertionStatus", getMethodType("()Z")));

        assertionErrorType = resolver.resolveClassDecl(ASSERTION_ERROR_TYPE);
    }

    /**
     * Finds the assertions-disabled field of {@code cInsn}, rewrites every guarded throw that reads it
     * into an {@link Assert}, and finally removes the field declaration.
     *
     * @param cInsn the class being transformed
     * @param ctx   the transform context, used for step tracking and expression cleanup
     */
    @Override
    public void transform(ClassDecl cInsn, ClassTransformContext ctx) {
        FieldDecl assertionField = cInsn.getFieldMembers()
                .map(this::matchAssertionsDisabled)
                .filter(Objects::nonNull)
                .firstOrDefault();
        // Nothing to do unless a field with the expected shape exists.
        if (assertionField == null) return;

        cInsn.accept(new SimpleInsnVisitor<>() {
            @Override
            public None visitIfInstruction(IfInstruction ifInsn, None none) {
                Assert assertInsn = processIf(ifInsn, assertionField, ctx);
                if (assertInsn != null) {
                    // Keep visiting inside the replacement so that guarded throws nested in its
                    // condition or message are converted as well.
                    return assertInsn.accept(this);
                }

                return super.visitIfInstruction(ifInsn, none);
            }
        });

        // NOTE(review): the declaration is removed unconditionally. This assumes every read of the
        // field was consumed by an Assert above; an unmatched read would be left referencing a
        // removed field. TODO confirm whether a usage check is needed here.
        assertionField.remove();
    }

    /**
     * Attempts to rewrite a single {@code if} instruction into an {@link Assert}.
     *
     * @param ifInsn         the candidate if instruction
     * @param assertionField the assertions-disabled field found by {@link #matchAssertionsDisabled}
     * @param ctx            the transform context
     * @return the {@link Assert} that replaced {@code ifInsn}, or {@code null} if it did not match
     */
    @Nullable
    private Assert processIf(IfInstruction ifInsn, FieldDecl assertionField, ClassTransformContext ctx) {
        // Match the following case:
        // IF (TERNARY(COMPARISON.EQUAL (FIELD ${assertionField}, LDC_BOOLEAN false), ${condition}, LDC_BOOLEAN false)) THROW (NEW AssertionError(${message}))
        // Logically the same as:
        // if (!${assertionField} && ${condition}) throw new AssertionError(${message})
        // ->
        // ASSERT (!${condition}, ${message})

        if (matchNopFalseIf(ifInsn) == null) return null;

        LogicAnd and = LogicMatching.matchLogicAnd(ifInsn.getCondition());
        if (and == null) return null;

        // Left operand must be exactly `!assertionField`.
        LogicNot not = LogicMatching.matchLogicNot(and.getLeft());
        if (not == null) return null;
        if (LoadStoreMatching.matchLoadField(not.getArgument(), assertionField.getField()) == null) return null;

        Throw throwInsn = BranchLeaveMatching.matchThrow(ifInsn.getTrueInsn());
        if (throwInsn == null) return null;

        New newInsn = InvokeMatching.matchNew(throwInsn.getArgument(), assertionErrorType);
        if (newInsn == null) return null;

        // An assert statement can carry at most one detail message. Any other constructor arity
        // (e.g. AssertionError(String, Throwable)) cannot be expressed as `assert c : msg` without
        // silently dropping arguments, so such code is left untouched.
        var args = newInsn.getArguments();
        if (args.size() > 1) return null;
        Instruction message = args.isEmpty() ? null : args.get(0);

        ctx.pushStep("Emit assertion");
        // The throw happens when `condition` holds, so the asserted expression is its negation.
        Assert assertInsn = new Assert(new LogicNot(and.getRight()), message);
        ifInsn.replaceWith(assertInsn);
        ExpressionTransforms.runOnExpression(assertInsn.getCondition(), ctx);
        ctx.popStep();

        return assertInsn;
    }

    /**
     * Checks whether {@code fieldDecl} is the synthetic assertions-disabled field.
     *
     * @param fieldDecl the field declaration to test
     * @return {@code fieldDecl} if it matches, otherwise {@code null}
     */
    @Nullable
    private FieldDecl matchAssertionsDisabled(FieldDecl fieldDecl) {
        // Find the following field:
        // FIELD_DECL static final synthetic boolean ${name} (COMPARISON.EQUAL (INVOKE.VIRTUAL (LDC_CLASS ${currentClass}) Class.desiredAssertionStatus(), LDC_BOOLEAN false))

        Field field = fieldDecl.getField();

        // Cheap shape checks first: the field must be static final synthetic boolean.
        if (!field.isStatic()) return null;
        if (!field.getAccessFlags().get(AccessFlag.FINAL)) return null;
        if (!field.isSynthetic()) return null;
        if (field.getType() != PrimitiveType.BOOLEAN) return null;

        // Value must be a logic not of ...
        LogicNot not = LogicMatching.matchLogicNot(fieldDecl.getValue());
        if (not == null) return null;

        // ... a virtual call to Class.desiredAssertionStatus() ...
        Invoke invoke = InvokeMatching.matchInvoke(not.getArgument(), Invoke.InvokeKind.VIRTUAL, m_desiredAssertionStatus);
        if (invoke == null) return null;

        // ... on a class literal of a class type (guarded rather than blindly cast).
        if (!(invoke.getTarget() instanceof LdcClass ldcClass)) return null;
        if (!(ldcClass.getType() instanceof ClassType)) return null;
        ClassType ldcType = (ClassType) ldcClass.getType();

        // Inner/Local classes will LDC the top level class, so resolve it only for real candidates.
        ClassType topLevel = field.getDeclaringClass();
        while (topLevel.getDeclType() != ClassType.DeclType.TOP_LEVEL) {
            topLevel = topLevel.getEnclosingClass().orElseThrow(notPossible());
        }

        // Compare declarations: the LDC'd type might be raw due to the top-level class having type parameters.
        if (ldcType.getDeclaration() != topLevel) return null;

        return fieldDecl;
    }
}
