package net.covers1624.coffeegrinder.bytecode;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.TryCatch.TryCatchHandler;
import net.covers1624.coffeegrinder.type.AType;
import net.covers1624.coffeegrinder.type.Field;
import net.covers1624.coffeegrinder.type.Method;
import org.jetbrains.annotations.Nullable;

import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Matches a tree of instructions to another tree of instructions.
 * <p>
 * Two trees are considered equivalent when they have the same opcodes, the same number of
 * children in the same order, and every instruction's own attributes (types, methods, fields,
 * constants, ...) agree. While matching, the matcher builds a mapping between the blocks and
 * the written local variables of the left and right hand trees.
 * <p>
 * A matcher instance accumulates state ({@link #blockMap}, the variable mapping, the failure
 * and endpoint fields) across calls to {@link #equivalent}; none of it is reset automatically.
 * <p>
 * Created by covers1624 on 2/6/21.
 */
public class SemanticMatcher {

    /**
     * Mapping of left hand variables to right hand variables, established when a
     * written-to local reference is matched. Bi-directional so a right hand variable
     * can not be claimed by two different left hand variables.
     */
    private final BiMap<LocalVariable, LocalVariable> varMap = HashBiMap.create();

    /**
     * The left hand instruction at which matching stops, or {@code null} for none.
     */
    @Nullable
    private final Instruction endpoint;

    /**
     * The block match map. Left hand block to Right hand block.
     * <p>
     * Entries are added as soon as a block pair is visited (before its children are compared)
     * and are not removed if the match subsequently fails.
     */
    public final Map<Block, Block> blockMap = new LinkedHashMap<>();

    /**
     * The left hand instruction of the first recorded mismatch, or {@code null} if none was recorded.
     */
    @Nullable
    public Instruction leftFail;

    /**
     * The right hand instruction of the first recorded mismatch, or {@code null} if none was recorded.
     */
    @Nullable
    public Instruction rightFail;

    /**
     * The right hand side endpoint.
     * <p>
     * Set to the right hand instruction that was being compared against the left hand
     * endpoint when it was reached.
     */
    @Nullable
    public Instruction matchedEndpoint;

    /**
     * Construct a new Semantic matcher.
     *
     * @param endpoint The Left hand side endpoint.
     */
    public SemanticMatcher(@Nullable Instruction endpoint) {
        this.endpoint = endpoint;
    }

    /**
     * Checks whether the two instruction trees are semantically equivalent.
     * <p>
     * The comparison proceeds as follows:
     * <ol>
     * <li>If {@code insn1} is the left hand endpoint, {@code insn2} is stored in
     * {@link #matchedEndpoint} and the match succeeds without looking any further.</li>
     * <li>Identical objects match; differing opcodes do not.</li>
     * <li>Block pairs are recorded in {@link #blockMap}. A pair which is already recorded
     * matches immediately, which also stops branches from re-visiting the same blocks.</li>
     * <li>Children are compared pairwise, in order. Once a child comparison has reached the
     * endpoint, the remaining children and attribute checks are skipped.</li>
     * <li>Finally the instruction specific attributes are compared.</li>
     * </ol>
     * When the match fails, the first mismatching pair is stored in {@link #leftFail} and
     * {@link #rightFail}.
     *
     * @param insn1 The left hand instruction.
     * @param insn2 The right hand instruction.
     * @return {@code true} if the instructions are equivalent.
     * @throws IllegalArgumentException If the instruction type is not supported by the matcher.
     */
    @SuppressWarnings ("ConstantConditions") //goway
    public boolean equivalent(Instruction insn1, Instruction insn2) {
        if (insn1 == endpoint) {
            matchedEndpoint = insn2;
            return true;
        }

        if (insn1 == insn2) return true; // If the insn is the same object, true match.
        if (insn1.opcode != insn2.opcode) return failMatch(insn1, insn2); // If the opcodes are different, no match.

        //Block comes first as we do some slightly custom matching.
        if (insn1 instanceof Block) {
            Block block1 = (Block) insn1;
            Block block2 = (Block) insn2;
            // Already paired up, nothing more to check. Prevents endless recursion via branches.
            if (blockMap.get(block1) == block2) return true;
            blockMap.put(block1, block2);
        }

        //Check all child instructions.
        Iterator<Instruction> children1 = insn1.getChildren().iterator();
        Iterator<Instruction> children2 = insn2.getChildren().iterator();
        while (children1.hasNext() || children2.hasNext()) {
            // One side ran out of children before the other.
            if (children1.hasNext() != children2.hasNext()) return failMatch(insn1, insn2);
            Instruction child1 = children1.next();
            Instruction child2 = children2.next();
            if (!equivalent(child1, child2)) return failMatch(child1, child2);
            // The endpoint was reached by this child, everything after it is out of scope.
            if (matchedEndpoint == child2) return true;
        }

        // Record attribute mismatches the same way as opcode and child mismatches, so the failure
        // fields are also populated when the root instructions themselves are the ones that differ.
        // For nested instructions this is a no-op for the parent, as only the first failure is kept.
        if (!matchesAttributes(insn1, insn2)) return failMatch(insn1, insn2);
        return true;
    }

    /**
     * Compares the attributes specific to the instruction type, children excluded.
     * Both instructions are expected to share the same opcode.
     *
     * @param insn1 The left hand instruction.
     * @param insn2 The right hand instruction.
     * @return {@code true} if the instruction specific attributes match.
     * @throws IllegalArgumentException If the instruction type is not supported by the matcher.
     */
    @SuppressWarnings ("ConstantConditions")
    private boolean matchesAttributes(Instruction insn1, Instruction insn2) {
        if (insn1.opcode == InsnOpcode.ARRAY_LEN) return true;
        if (insn1.opcode == InsnOpcode.BLOCK) return true;
        if (insn1.opcode == InsnOpcode.BLOCK_CONTAINER) return true;
        if (insn1.opcode == InsnOpcode.BRANCH) return matches((Branch) insn1, (Branch) insn2);
        if (insn1.opcode == InsnOpcode.CHECK_CAST) return matches((Cast) insn1, (Cast) insn2);
        if (insn1 instanceof Comparison) return matches((Comparison) insn1, (Comparison) insn2);
        if (insn1.opcode == InsnOpcode.FIELD_DECL) return matches((FieldDecl) insn1, (FieldDecl) insn2);
        if (insn1.opcode == InsnOpcode.IF) return true;
        if (insn1.opcode == InsnOpcode.POST_INCREMENT) return true;
        if (insn1.opcode == InsnOpcode.INSTANCE_OF) return matches((InstanceOf) insn1, (InstanceOf) insn2);
        if (insn1.opcode == InsnOpcode.INVOKE) return matches((Invoke) insn1, (Invoke) insn2);
        if (insn1.opcode == InsnOpcode.INVOKE_DYNAMIC) return true;
        if (insn1 instanceof LdcInsn) return matches((LdcInsn) insn1, (LdcInsn) insn2);
        if (insn1.opcode == InsnOpcode.LEAVE) return matches((Leave) insn1, (Leave) insn2);
        if (insn1.opcode == InsnOpcode.LOAD) return true;
        if (insn1.opcode == InsnOpcode.LOCAL_REFERENCE) return matches((LocalReference) insn1, (LocalReference) insn2);
        if (insn1.opcode == InsnOpcode.FIELD_REFERENCE) return matches((FieldReference) insn1, (FieldReference) insn2);
        if (insn1.opcode == InsnOpcode.ARRAY_ELEMENT_REFERENCE) return true;
        if (insn1.opcode == InsnOpcode.COMPOUND_ASSIGNMENT) return true;
        if (insn1.opcode == InsnOpcode.LOAD_THIS) return matches((LoadThis) insn1, (LoadThis) insn2);
        if (insn1.opcode == InsnOpcode.METHOD_REFERENCE) return matches((MethodReference) insn1, (MethodReference) insn2);
        if (insn1 instanceof Monitor) return true;
        if (insn1.opcode == InsnOpcode.NEW) return matches((New) insn1, (New) insn2);
        if (insn1.opcode == InsnOpcode.NEW_ARRAY) return matches((NewArray) insn1, (NewArray) insn2);
        if (insn1.opcode == InsnOpcode.NEW_OBJECT) return matches((NewObject) insn1, (NewObject) insn2);
        if (insn1.opcode == InsnOpcode.NOP) return true;
        if (insn1.opcode == InsnOpcode.BINARY) return matches((Binary) insn1, (Binary) insn2);
        if (insn1.opcode == InsnOpcode.STORE) return true;
        if (insn1.opcode == InsnOpcode.RETURN) return true;
        if (insn1.opcode == InsnOpcode.SWITCH_TABLE) return true;
        if (insn1.opcode == InsnOpcode.SWITCH_SECTION) return true;
        if (insn1.opcode == InsnOpcode.TERNARY) return true;
        if (insn1.opcode == InsnOpcode.THROW) return true;
        if (insn1.opcode == InsnOpcode.TRY_CATCH) return true;
        if (insn1.opcode == InsnOpcode.TRY_CATCH_HANDLER) return matches((TryCatchHandler) insn1, (TryCatchHandler) insn2);

        throw new IllegalArgumentException("Unhandled instruction type for semantic match: " + insn1.getClass().getName());
    }

    /**
     * Records the given pair as the point of failure, unless a failure was already recorded.
     * As recursive calls return before their callers, the recorded pair is the first
     * mismatch that was detected.
     *
     * @param left  The left hand instruction which failed to match.
     * @param right The right hand instruction which failed to match.
     * @return Always {@code false}, for use in return statements.
     */
    private boolean failMatch(Instruction left, Instruction right) {
        if (leftFail == null) {
            leftFail = left;
            rightFail = right;
        }
        return false;
    }

    //region Per instruction matching.
    // Branches match when their target blocks are equivalent.
    private boolean matches(Branch br1, Branch br2) {
        return equivalent(br1.getTargetBlock(), br2.getTargetBlock());
    }

    private boolean matches(Cast cast1, Cast cast2) {
        return matchType(cast1.getType(), cast2.getType());
    }

    private boolean matches(Comparison comp1, Comparison comp2) {
        return comp1.getKind() == comp2.getKind();
    }

    private boolean matches(FieldDecl f1, FieldDecl f2) {
        return f1.getField().equals(f2.getField());
    }

    private boolean matches(InstanceOf iof1, InstanceOf iof2) {
        return matchType(iof1.getType(), iof2.getType());
    }

    private boolean matches(Invoke invoke1, Invoke invoke2) {
        if (invoke1.getKind() != invoke2.getKind()) return false;
        return matches(invoke1.getMethod(), invoke2.getMethod());
    }

    private boolean matches(LdcInsn ldc1, LdcInsn ldc2) {
        return Objects.equals(ldc1.getRawValue(), ldc2.getRawValue());
    }

    // Leaves must target the very same container instance.
    private boolean matches(Leave leave1, Leave leave2) {
        return leave1.getTargetContainer() == leave2.getTargetContainer();
    }

    /**
     * Matches two local variable references.
     * <p>
     * Both references must agree on whether they are read from and written to. A written-to
     * reference establishes (or checks) a one to one mapping from the left hand variable to the
     * right hand variable. Any other reference matches if the variables are mapped to each
     * other, or are the same variable.
     */
    private boolean matches(LocalReference ref1, LocalReference ref2) {
        LocalVariable var1 = ref1.variable;
        LocalVariable var2 = ref2.variable;
        if (ref1.isWrittenTo() != ref2.isWrittenTo() || ref1.isReadFrom() != ref2.isReadFrom()) return false;
        if (ref1.isWrittenTo()) {
            LocalVariable rhMapping = varMap.get(var1);
            if (rhMapping != null) return rhMapping == var2; // If we got a result, it must match.
            if (varMap.containsValue(var2)) return false; // The second variable must not be mapped to anything else.
            varMap.put(var1, var2);
            return true;
        }

        return varMap.get(var1) == var2 || var1 == var2;
    }

    private boolean matches(FieldReference ref1, FieldReference ref2) {
        return matches(ref1.getField(), ref2.getField());
    }

    private boolean matches(LoadThis load1, LoadThis load2) {
        return matchType(load1.getType(), load2.getType());
    }

    private boolean matches(MethodReference insn1, MethodReference insn2) {
        return matches(insn1.getMethod(), insn2.getMethod());
    }

    private boolean matches(New n1, New n2) {
        return matches(n1.getMethod(), n2.getMethod());
    }

    private boolean matches(NewArray nArr1, NewArray nArr2) {
        return matchType(nArr1.getType(), nArr2.getType());
    }

    private boolean matches(NewObject nArr1, NewObject nArr2) {
        return matchType(nArr1.getType(), nArr2.getType());
    }

    private boolean matches(Binary n1, Binary n2) {
        return n1.getOp() == n2.getOp();
    }

    private boolean matches(TryCatchHandler handler1, TryCatchHandler handler2) {
        return handler1.isUnprocessedFinally == handler2.isUnprocessedFinally;
    }
    //endregion

    //region Common matching
    private boolean matchType(AType t1, AType t2) {
        return t1.equals(t2);
    }

    // Methods match on declaring class, name and descriptor.
    private boolean matches(Method m1, Method m2) {
        return matchType(m1.getDeclaringClass(), m2.getDeclaringClass())
                && m1.getName().equals(m2.getName())
                && m1.getDescriptor().equals(m2.getDescriptor());
    }

    // Fields match on declaring class, name and descriptor.
    private boolean matches(Field f1, Field f2) {
        return matchType(f1.getDeclaringClass(), f2.getDeclaringClass())
                && f1.getName().equals(f2.getName())
                && f1.getDescriptor().equals(f2.getDescriptor());
    }
    //endregion
}
