package net.covers1624.coffeegrinder.type.accessors;

import com.google.common.collect.ImmutableList;
import net.covers1624.coffeegrinder.type.Field;
import net.covers1624.coffeegrinder.type.TypeResolver;
import net.covers1624.coffeegrinder.type.accessors.SyntheticAccessor.CtorAccessor;
import net.covers1624.coffeegrinder.type.accessors.SyntheticAccessor.FieldAccessor;
import net.covers1624.coffeegrinder.type.accessors.SyntheticAccessor.FieldIncrementAccessor;
import net.covers1624.coffeegrinder.type.accessors.SyntheticAccessor.MethodAccessor;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;

import java.util.List;
import java.util.Objects;

import static net.covers1624.coffeegrinder.type.accessors.SyntheticAccessor.AccessorType.*;
import static org.objectweb.asm.Opcodes.*;

/**
 * Created by covers1624 on 4/7/22.
 */
public class AccessorParser {

    // The "constant one" opcode expected by matchFieldPrePostIncrementAccessor for each field type.
    // Indexed by `Type.getSort() - Type.CHAR`; ASM numbers the primitive sorts consecutively as
    // CHAR, BYTE, SHORT, INT, FLOAT, LONG, DOUBLE, which is the order of the entries below.
    private static final int[] PRE_POST_CONST = {
            ICONST_1, // char
            ICONST_1, // byte
            ICONST_1, // short
            ICONST_1, // int
            FCONST_1, // float
            LCONST_1, // long
            DCONST_1  // double
    };
    // The narrowing cast opcode expected by matchFieldPrePostIncrementAccessor after the add/subtract.
    // Indexed by `Type.getSort() - Type.CHAR` (CHAR, BYTE, SHORT, INT, FLOAT, LONG, DOUBLE).
    // An entry of -1 means no cast instruction is expected for that type.
    private static final int[] PRE_POST_CAST = {
            I2C, // char
            I2B, // byte
            I2S, // short
            -1,  // int
            -1,  // float
            -1,  // long
            -1   // double
    };

    // All known accessor shapes. parseAccessor tries these in list order and returns the result of
    // the first one that does not throw MatchFailed.
    private static final List<AccessorMatcher> FUNCS = ImmutableList.of(
            AccessorParser::matchFieldLoadAccessor,
            AccessorParser::matchFieldSetAccessor,
            AccessorParser::matchMethodAccessor,
            AccessorParser::matchCtorAccessor,
            AccessorParser::matchFieldPrePostIncrementAccessor
    );

    /**
     * Runs every matcher in {@link #FUNCS} against the given method, in list order.
     *
     * @param resolver Handed to the matcher so it can resolve the accessed member.
     * @param mNode    The method whose instructions are inspected.
     * @return The result of the first matcher that does not throw {@code MatchFailed},
     * or {@code null} if every matcher failed.
     */
    @Nullable
    public static SyntheticAccessor parseAccessor(TypeResolver resolver, MethodNode mNode) {
        for (int i = 0; i < FUNCS.size(); i++) {
            AccessorMatcher candidate = FUNCS.get(i);
            // Every attempt starts from a fresh cursor at the first instruction.
            InsnCursor cursor = new InsnCursor(mNode.instructions);
            try {
                return candidate.match(resolver, mNode, cursor);
            } catch (MatchFailed ignored) {
                // Not this accessor shape, try the next matcher.
            }
        }
        return null;
    }

    /**
     * Matches an accessor that reads a field and returns it.
     * <p>
     * Expected shape: either {@code ALOAD 0} followed by {@code GETFIELD}, or a lone {@code GETSTATIC},
     * then the return opcode for the method's return type.
     *
     * @return A {@code FIELD_LOAD} accessor for the field named by the matched instruction.
     * @throws MatchFailed If the instructions do not have this shape.
     */
    private static FieldAccessor matchFieldLoadAccessor(TypeResolver resolver, MethodNode mNode, InsnCursor c) {
        int returnOpcode = Type.getMethodType(mNode.desc).getReturnType().getOpcode(IRETURN);

        // A leading ALOAD 0 decides whether we expect an instance or a static read.
        boolean hasReceiver = c.peekVarInsn(ALOAD, 0);
        if (hasReceiver) {
            c.requireVarInsn(ALOAD, 0);
        }
        FieldInsnNode read = c.requireFieldInsn(hasReceiver ? GETFIELD : GETSTATIC);
        c.requireInsn(returnOpcode);

        Type fieldDesc = Type.getType(read.desc);
        return new FieldAccessor(FIELD_LOAD, resolver.resolveClassDecl(read.owner).resolveField(read.name, fieldDesc));
    }

    /**
     * Matches an accessor that stores its last parameter into a field and returns the stored value.
     * <p>
     * With two parameters the expected shape is {@code ALOAD 0}, load of slot 1, {@code DUP_X1}/{@code DUP2_X1},
     * {@code PUTFIELD}. With one parameter it is load of slot 0, {@code DUP}/{@code DUP2}, {@code PUTSTATIC}.
     * Both end with the return opcode for the return type. Methods named {@code <init>}, other parameter
     * counts, and a return type differing from the last parameter are rejected.
     *
     * @return A {@code FIELD_STORE} accessor for the field named by the matched instruction.
     * @throws MatchFailed If the method does not have this shape.
     */
    private static FieldAccessor matchFieldSetAccessor(TypeResolver resolver, MethodNode mNode, InsnCursor c) {
        if (mNode.name.equals("<init>")) throw MatchFailed.INSTANCE;

        Type signature = Type.getMethodType(mNode.desc);
        Type[] params = signature.getArgumentTypes();
        Type valueType = signature.getReturnType();

        boolean instance = params.length == 2;
        if (!instance && params.length != 1) throw MatchFailed.INSTANCE;
        // The value being stored is the last parameter, and it must also be what the accessor returns.
        if (!params[params.length - 1].equals(valueType)) throw MatchFailed.INSTANCE;

        boolean singleSlot = valueType.getSize() == 1;
        int loadOpcode = valueType.getOpcode(ILOAD);

        FieldInsnNode store;
        if (instance) {
            c.requireVarInsn(ALOAD, 0);
            c.requireVarInsn(loadOpcode, 1);
            c.requireInsn(singleSlot ? DUP_X1 : DUP2_X1);
            store = c.requireFieldInsn(PUTFIELD);
        } else {
            c.requireVarInsn(loadOpcode, 0);
            c.requireInsn(singleSlot ? DUP : DUP2);
            store = c.requireFieldInsn(PUTSTATIC);
        }
        c.requireInsn(valueType.getOpcode(IRETURN));

        Type fieldDesc = Type.getType(store.desc);
        return new FieldAccessor(FIELD_STORE, resolver.resolveClassDecl(store.owner).resolveField(store.name, fieldDesc));
    }

    /**
     * Matches an accessor that increments or decrements a numeric field by one and returns either the
     * old value (post) or the new value (pre).
     * <p>
     * Expected instruction shape:
     * <pre>
     * if (!static) ALOAD 0
     * if (!static) DUP
     * static ? GETSTATIC : GETFIELD field          (descriptor must equal the return type)
     * if (postInc) dup                             (DUP/DUP2 when static, DUP_X1/DUP2_X1 otherwise)
     * constant one                                 (PRE_POST_CONST entry for the type)
     * positive ? add : sub                         (typed variant of IADD/ISUB)
     * narrowing cast                               (only when PRE_POST_CAST has an entry for the type)
     * if (!postInc) dup
     * static ? PUTSTATIC : PUTFIELD same field
     * typed return
     * </pre>
     *
     * @return A {@code FIELD_POST_INC} or {@code FIELD_PRE_INC} accessor for the matched field.
     * @throws MatchFailed If the method does not have this shape, including when its return type
     *                     is not one of the types covered by the lookup tables.
     */
    private static SyntheticAccessor matchFieldPrePostIncrementAccessor(TypeResolver resolver, MethodNode mNode, InsnCursor c) {
        Type retType = Type.getMethodType(mNode.desc).getReturnType();

        // The lookup tables only cover the sorts CHAR..DOUBLE. Any other return type (void, boolean, array,
        // object) would index outside of them, so reject it as a regular failed match instead of letting an
        // ArrayIndexOutOfBoundsException escape from parseAccessor.
        int tableIndex = retType.getSort() - Type.CHAR;
        if (tableIndex < 0 || tableIndex >= PRE_POST_CONST.length) throw MatchFailed.INSTANCE;

        boolean isStatic = !c.peekVarInsn(ALOAD, 0);
        boolean singleSlot = retType.getSize() == 1;
        FieldInsnNode field;
        int dupOpcode;
        int putOpcode;
        if (!isStatic) {
            c.requireVarInsn(ALOAD, 0);
            c.requireInsn(DUP);
            field = c.requireFieldInsn(GETFIELD);
            // The value has to be duplicated below the receiver that is still on the stack.
            dupOpcode = singleSlot ? DUP_X1 : DUP2_X1;
            putOpcode = PUTFIELD;
        } else {
            field = c.requireFieldInsn(GETSTATIC);
            dupOpcode = singleSlot ? DUP : DUP2;
            putOpcode = PUTSTATIC;
        }
        if (!field.desc.equals(retType.getDescriptor())) throw MatchFailed.INSTANCE;

        // Post-increment duplicates the value before the arithmetic, pre-increment after it.
        boolean postInc = c.peekInsn(dupOpcode);
        if (postInc) {
            c.requireInsn(dupOpcode);
        }

        c.requireInsn(PRE_POST_CONST[tableIndex]);
        boolean positive = c.peekInsn(retType.getOpcode(IADD));
        c.requireInsn(retType.getOpcode(positive ? IADD : ISUB));

        int castOpcode = PRE_POST_CAST[tableIndex];
        if (castOpcode != -1) {
            c.requireInsn(castOpcode);
        }

        if (!postInc) {
            c.requireInsn(dupOpcode);
        }
        // The store must target the exact field that was read.
        c.requireFieldInsn(putOpcode, field);
        c.requireInsn(retType.getOpcode(IRETURN));

        Field fieldType = resolver.resolveClassDecl(field.owner).resolveField(field.name, retType);
        return new FieldIncrementAccessor(postInc ? FIELD_POST_INC : FIELD_PRE_INC, fieldType, positive);
    }

    /**
     * Matches an accessor that loads all of its parameters in order and forwards them to one method call.
     * <p>
     * Expected shape: one typed load per parameter starting at slot 0, then {@code INVOKESTATIC},
     * {@code INVOKEVIRTUAL} or {@code INVOKESPECIAL}, then the return opcode for the return type.
     * Methods named {@code <init>} and {@code INVOKESPECIAL} calls to {@code <init>} are rejected.
     *
     * @return An {@code INVOKE} accessor for the method named by the matched call instruction.
     * @throws MatchFailed If the method does not have this shape.
     */
    private static MethodAccessor matchMethodAccessor(TypeResolver resolver, MethodNode mNode, InsnCursor c) {
        if (mNode.name.equals("<init>")) throw MatchFailed.INSTANCE;

        Type signature = Type.getMethodType(mNode.desc);
        Type[] params = signature.getArgumentTypes();

        // Wide types (long/double) take two local slots, so track the slot separately from the parameter index.
        int slot = 0;
        for (int i = 0; i < params.length; i++) {
            c.requireVarInsn(params[i].getOpcode(ILOAD), slot);
            slot += params[i].getSize();
        }

        // Anything that is not a static or virtual call has to be INVOKESPECIAL.
        int invokeOpcode = INVOKESPECIAL;
        if (c.peekMethodInsn(INVOKESTATIC)) {
            invokeOpcode = INVOKESTATIC;
        } else if (c.peekMethodInsn(INVOKEVIRTUAL)) {
            invokeOpcode = INVOKEVIRTUAL;
        }
        MethodInsnNode call = c.requireMethodInsn(invokeOpcode);
        if (invokeOpcode == INVOKESPECIAL && call.name.equals("<init>")) throw MatchFailed.INSTANCE;

        c.requireInsn(signature.getReturnType().getOpcode(IRETURN));

        Type callDesc = Type.getMethodType(call.desc);
        return new MethodAccessor(INVOKE, resolver.resolveClassDecl(call.owner).resolveMethod(call.name, callDesc));
    }

    /**
     * Matches a constructor that forwards to another constructor, passing along every parameter except its last.
     * <p>
     * Expected shape: {@code ALOAD 0}, one typed load per parameter except the last (starting at slot 1),
     * {@code INVOKESPECIAL <init>}, {@code RETURN}. Only methods named {@code <init>} with at least one
     * parameter are considered.
     *
     * @return A constructor accessor built from the invoked constructor and the class resolved from the
     * type of the last, unloaded parameter.
     * @throws MatchFailed If the method does not have this shape.
     */
    private static CtorAccessor matchCtorAccessor(TypeResolver resolver, MethodNode mNode, InsnCursor c) {
        if (!mNode.name.equals("<init>")) throw MatchFailed.INSTANCE;

        Type[] params = Type.getMethodType(mNode.desc).getArgumentTypes();
        if (params.length < 1) throw MatchFailed.INSTANCE;
        Type trailingParam = params[params.length - 1];

        c.requireVarInsn(ALOAD, 0);

        // Slot 0 is `this`; the trailing parameter is expected to go unused, so it is never loaded.
        int slot = 1;
        for (int i = 0; i + 1 < params.length; i++) {
            c.requireVarInsn(params[i].getOpcode(ILOAD), slot);
            slot += params[i].getSize();
        }
        MethodInsnNode call = c.requireMethodInsn(INVOKESPECIAL, "<init>");
        c.requireInsn(RETURN);

        return new CtorAccessor(
                resolver.resolveClassDecl(call.owner).resolveMethod(call.name, Type.getMethodType(call.desc)),
                resolver.resolveClassDecl(trailingParam)
        );
    }

    /**
     * A single accessor shape that a method's instructions can be tested against.
     */
    @FunctionalInterface
    public interface AccessorMatcher {

        /**
         * Attempts to match the given method against this accessor shape.
         *
         * @param resolver Used to resolve the member the accessor targets.
         * @param mNode    The method being inspected.
         * @param c        A cursor positioned at the method's first instruction.
         * @return The accessor describing the matched method.
         * @throws MatchFailed If the method does not have this shape.
         */
        SyntheticAccessor match(TypeResolver resolver, MethodNode mNode, InsnCursor c) throws MatchFailed;
    }

    /**
     * A forward-only cursor over the instructions of a method which transparently skips labels,
     * line numbers and frames.
     * <p>
     * The {@code peekXxx} methods test the current instruction without consuming it and return
     * {@code false} once the cursor has run past the last instruction. The {@code requireXxx}
     * methods consume the current instruction if it matches and throw {@code MatchFailed} otherwise.
     */
    public static class InsnCursor {

        // The current instruction, or null once every instruction has been consumed.
        @Nullable
        private AbstractInsnNode pointer;

        public InsnCursor(InsnList list) {
            pointer = skipUnimportant(list.getFirst());
        }

        /**
         * Consumes the current instruction if it has the given opcode.
         *
         * @throws MatchFailed If it does not, or if there is no current instruction.
         */
        public void requireInsn(int opcode) {
            if (!peekInsn(opcode)) {
                throw MatchFailed.INSTANCE;
            }
            move();
        }

        /**
         * Consumes the current instruction if it is a variable instruction with the given opcode and variable index.
         *
         * @throws MatchFailed If it is not, or if there is no current instruction.
         */
        public void requireVarInsn(int opcode, int operand) {
            if (!peekVarInsn(opcode, operand)) {
                throw MatchFailed.INSTANCE;
            }
            move();
        }

        /**
         * Consumes and returns the current instruction if it is a field instruction with the given opcode.
         *
         * @throws MatchFailed If it is not, or if there is no current instruction.
         */
        public FieldInsnNode requireFieldInsn(int opcode) {
            return requireFieldInsn(opcode, null);
        }

        /**
         * Consumes and returns the current instruction if it is a field instruction with the given opcode
         * and, when {@code other} is non-null, the same owner, name and descriptor as {@code other}.
         *
         * @throws MatchFailed If it is not, or if there is no current instruction.
         */
        public FieldInsnNode requireFieldInsn(int opcode, @Nullable FieldInsnNode other) {
            if (!peekFieldInsn(opcode, other)) {
                throw MatchFailed.INSTANCE;
            }
            return (FieldInsnNode) move();
        }

        /**
         * Consumes and returns the current instruction if it is a method instruction with the given opcode.
         *
         * @throws MatchFailed If it is not, or if there is no current instruction.
         */
        public MethodInsnNode requireMethodInsn(int opcode) {
            return requireMethodInsn(opcode, "*");
        }

        /**
         * Consumes and returns the current instruction if it is a method instruction with the given opcode
         * and name. A name of {@code "*"} matches any method name.
         *
         * @throws MatchFailed If it is not, or if there is no current instruction.
         */
        public MethodInsnNode requireMethodInsn(int opcode, String name) {
            if (!peekMethodInsn(opcode, name)) {
                throw MatchFailed.INSTANCE;
            }
            return (MethodInsnNode) move();
        }

        /**
         * @return {@code true} if there is a current instruction and it has the given opcode.
         */
        public boolean peekInsn(int opcode) {
            AbstractInsnNode insn = pointer;
            // An exhausted cursor simply doesn't match, so callers fail the match instead of hitting a NullPointerException.
            return insn != null && insn.getOpcode() == opcode;
        }

        /**
         * @return {@code true} if the current instruction is a variable instruction with the given
         * opcode and variable index.
         */
        public boolean peekVarInsn(int opcode, int operand) {
            AbstractInsnNode insn = pointer;
            if (insn == null || insn.getOpcode() != opcode) return false;
            if (!(insn instanceof VarInsnNode)) return false;

            return ((VarInsnNode) insn).var == operand;
        }

        /**
         * @return {@code true} if the current instruction is a field instruction with the given opcode.
         */
        public boolean peekFieldInsn(int opcode) {
            return peekFieldInsn(opcode, null);
        }

        /**
         * @return {@code true} if the current instruction is a field instruction with the given opcode
         * and, when {@code other} is non-null, the same owner, name and descriptor as {@code other}.
         */
        public boolean peekFieldInsn(int opcode, @Nullable FieldInsnNode other) {
            AbstractInsnNode insn = pointer;
            if (insn == null || insn.getOpcode() != opcode) return false;
            if (!(insn instanceof FieldInsnNode)) return false;

            if (other == null) return true;

            FieldInsnNode fInsn = (FieldInsnNode) insn;
            return other.owner.equals(fInsn.owner)
                    && other.name.equals(fInsn.name)
                    && other.desc.equals(fInsn.desc);
        }

        /**
         * @return {@code true} if the current instruction is a method instruction with the given opcode.
         */
        public boolean peekMethodInsn(int opcode) {
            return peekMethodInsn(opcode, "*");
        }

        /**
         * @return {@code true} if the current instruction is a method instruction with the given opcode
         * and name. A name of {@code "*"} matches any method name.
         */
        public boolean peekMethodInsn(int opcode, String name) {
            AbstractInsnNode insn = pointer;
            if (insn == null || insn.getOpcode() != opcode) return false;
            if (!(insn instanceof MethodInsnNode)) return false;

            return name.equals("*") || name.equals(((MethodInsnNode) insn).name);
        }

        /**
         * @return The current instruction, without consuming it.
         * @throws NullPointerException If every instruction has already been consumed.
         */
        public AbstractInsnNode peek() {
            return Objects.requireNonNull(pointer, "Already read entire method.");
        }

        /**
         * Advances to the next instruction that is not a label, line number or frame.
         *
         * @return The instruction that was current before advancing.
         * @throws NullPointerException If every instruction has already been consumed.
         */
        public AbstractInsnNode move() {
            AbstractInsnNode curr = peek();
            pointer = skipUnimportant(curr.getNext());
            return curr;
        }

        // Walks forward from the given node to the first one that is not a label, line number or frame.
        // Returns null if the list ends first.
        @Nullable
        private static AbstractInsnNode skipUnimportant(@Nullable AbstractInsnNode pointer) {
            while (pointer != null
                    && (pointer.getType() == AbstractInsnNode.LABEL
                    || pointer.getType() == AbstractInsnNode.LINE
                    || pointer.getType() == AbstractInsnNode.FRAME)) {
                pointer = pointer.getNext();
            }
            return pointer;
        }
    }

    /**
     * Control-flow signal thrown by matchers when a method does not have the shape they expect.
     * <p>
     * A single shared instance is thrown everywhere, so it carries no message, cause or stack trace.
     */
    private static final class MatchFailed extends RuntimeException {

        private static final MatchFailed INSTANCE = new MatchFailed();

        private MatchFailed() {
            // Suppression and the writable stack trace are both disabled: no stack trace is captured at
            // construction, and the shared instance cannot be mutated by addSuppressed or setStackTrace.
            super(null, null, false, false);
        }
    }
}
