package net.covers1624.coffeegrinder.util.resolver;

import com.google.common.collect.ImmutableMap;
import net.covers1624.coffeegrinder.asm.ASMClassTransformer;
import net.covers1624.coffeegrinder.util.asm.CustomInsnList;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.FieldNode;
import org.objectweb.asm.tree.MethodNode;

import java.util.List;
import java.util.Map;

/**
 * Represents a class that is lazily materialized from its bytes. Initially only a
 * {@link ClassReader} is held; on demand this becomes either a {@link ClassNode}
 * without any method instructions parsed, or a full {@link ClassNode} with frames expanded.
 * <p>
 * Every {@link ClassNode} built here has the supplied {@link ASMClassTransformer}s
 * applied to it before it is handed out.
 * <p>
 * Lazy initialization is thread-safe: nodes are built while holding this instance's
 * monitor and published through {@code volatile} fields.
 * <p>
 * Created by covers1624 on 8/4/21.
 */
//TODO, Reader-only should probably be removed.
public class CachedClassNode {

    // Only dereferenced while holding 'this'; cleared once the full node has been built.
    private ClassReader reader;
    private final List<ASMClassTransformer> transformers;

    // Written only while holding 'this'. Volatile so the unsynchronized fast paths in
    // requirePartial()/requireFull() never observe a node or lookup map before it is
    // fully initialized (plain fields would make this double-checked locking unsafe).
    @Nullable
    private volatile ClassNode partialNode;
    @Nullable
    private volatile ClassNode node;

    private final int access;
    private final String name;
    @Nullable
    private final String superName;
    private final List<String> interfaces;

    // Keyed by memberKey(name, desc); rebuilt from whichever node was most recently materialized.
    // Always assigned before the corresponding node field is published.
    @Nullable
    private volatile Map<String, FieldNode> fieldNodes;
    @Nullable
    private volatile Map<String, MethodNode> methodNodes;

    CachedClassNode(byte[] bytes, List<ASMClassTransformer> transformers) {
        reader = new ClassReader(bytes);
        this.transformers = transformers;
        access = reader.getAccess();
        name = reader.getClassName();
        superName = reader.getSuperName();
        interfaces = List.of(reader.getInterfaces());

        // Only java/lang/Object and module descriptors (JVMS 4.1) declare no superclass.
        assert superName != null || name.equals("java/lang/Object") || name.equals("module-info");
    }

    /**
     * The class's access flags as returned by {@link ClassReader#getAccess()}.
     *
     * @return The access flags.
     */
    public int getAccess() {
        return access;
    }

    /**
     * Gets the {@link Type#getInternalName()} for this class.
     *
     * @return The name.
     */
    public String getClassName() {
        return name;
    }

    /**
     * Gets the {@link Type#getInternalName()} for the declared superclass.
     *
     * @return The name, or {@code null} if this class declares no superclass.
     */
    @Nullable
    public String getSuperName() {
        return superName;
    }

    /**
     * Gets the list of {@link Type#getInternalName()}s for the declared interfaces.
     *
     * @return The interface names.
     */
    public List<String> getInterfaces() {
        return interfaces;
    }

    /**
     * Gets a partial {@link ClassNode} without any method instructions being parsed.
     * This is preferable when the Class is only needed as reference.
     * <p>
     * This method will return the full expanded {@link ClassNode} if this class has had
     * {@link #getNode()} called. Callers should not be sensitive to this.
     * <p>
     * The returned {@link ClassNode} is modifiable. Please do not modify it.
     *
     * @return The partial {@link ClassNode}, never {@code null}.
     */
    public ClassNode getPartialNode() {
        return requirePartial();
    }

    /**
     * Gets the full {@link ClassNode}.
     * <p>
     * Calling this method will cause the {@link #getPartialNode()} to be replaced
     * with the full class node.
     * <p>
     * The returned {@link ClassNode} is modifiable. Please do not modify it.
     *
     * @return The full {@link ClassNode}, never {@code null}.
     */
    public ClassNode getNode() {
        return requireFull();
    }

    /**
     * Tries to find a field in this class for the given name and descriptor.
     *
     * @param name The name of the field.
     * @param desc The descriptor of the field.
     * @return The {@link FieldNode} if it exists. Literal <code>null</code> otherwise.
     */
    @Nullable
    public FieldNode findField(String name, String desc) {
        requirePartial();
        Map<String, FieldNode> fields = fieldNodes;
        assert fields != null;
        return fields.get(memberKey(name, desc));
    }

    /**
     * Tries to find a partial method node in this class for the given name and descriptor.
     * In the case of this ClassNode being resolved into a Full {@link ClassNode} this method
     * will return a Full {@link MethodNode} with expanded frames.
     *
     * @param name The name of the method.
     * @param desc The descriptor of the method.
     * @return The {@link MethodNode} if it exists. Literal <code>null</code> otherwise.
     */
    @Nullable
    public MethodNode findMethod(String name, String desc) {
        requirePartial();
        Map<String, MethodNode> methods = methodNodes;
        assert methods != null;
        return methods.get(memberKey(name, desc));
    }

    /**
     * Tries to find a full method node in this class for the given name and descriptor.
     * This method is guaranteed to return a Full {@link MethodNode} with expanded frames.
     *
     * @param name The name of the method.
     * @param desc The descriptor of the method.
     * @return The {@link MethodNode} if it exists. Literal <code>null</code> otherwise.
     */
    @Nullable
    public MethodNode findFullMethod(String name, String desc) {
        requireFull();
        // requireFull() assigns the full maps before publishing 'node', and nothing
        // reassigns them afterwards, so this is guaranteed to be the full method map.
        Map<String, MethodNode> methods = methodNodes;
        assert methods != null;
        return methods.get(memberKey(name, desc));
    }

    /**
     * Ensures at least the partial node exists and returns the best node available:
     * the full node if it has been built, otherwise the partial node.
     */
    private ClassNode requirePartial() {
        // Read 'node' before 'partialNode': requireFull() publishes 'node' before clearing
        // 'partialNode', so seeing both as null here means we fall through to the lock,
        // where the state is consistent. Reading in the other order could yield null.
        ClassNode full = node;
        if (full != null) return full;
        ClassNode partial = partialNode;
        if (partial != null) return partial;

        synchronized (this) {
            full = node;
            if (full != null) return full;
            partial = partialNode;
            if (partial != null) return partial;

            partial = new ClassNode();
            reader.accept(partial, ClassReader.SKIP_CODE);

            for (ASMClassTransformer transformer : transformers) {
                transformer.transform(partial);
            }

            // Maps must be assigned before the node is published (see field comments).
            fieldNodes = indexFields(partial);
            methodNodes = indexMethods(partial);
            partialNode = partial;
            return partial;
        }
    }

    /**
     * Ensures the full node (instructions parsed, frames expanded) exists and returns it.
     */
    private ClassNode requireFull() {
        ClassNode full = node;
        if (full != null) return full;

        synchronized (this) {
            full = node;
            if (full != null) return full;

            full = new ClassNode();
            reader.accept(full, ClassReader.EXPAND_FRAMES);

            //Replace InsnLists with our extended InsnList.
            for (MethodNode method : full.methods) {
                method.instructions = new CustomInsnList(method.instructions);
            }

            for (ASMClassTransformer transformer : transformers) {
                transformer.transform(full);
            }

            // Maps must be assigned before the node is published (see field comments).
            fieldNodes = indexFields(full);
            methodNodes = indexMethods(full);
            node = full;

            reader = null;//Reader not needed, let this GC
            partialNode = null;//Same with partial node.
            return full;
        }
    }

    /**
     * Builds the lookup key for a member. '.' cannot occur in a JVM unqualified name
     * (JVMS 4.2.2), so unlike plain concatenation the key is unambiguous: e.g. field
     * "a" of type "LbLc;" and field "aLb" of type "Lc;" would otherwise both map to "aLbLc;".
     */
    private static String memberKey(String name, String desc) {
        return name + '.' + desc;
    }

    /** Indexes the node's fields by {@link #memberKey(String, String)}. */
    private static Map<String, FieldNode> indexFields(ClassNode classNode) {
        ImmutableMap.Builder<String, FieldNode> builder = ImmutableMap.builder();
        for (FieldNode field : classNode.fields) {
            builder.put(memberKey(field.name, field.desc), field);
        }
        return builder.build();
    }

    /** Indexes the node's methods by {@link #memberKey(String, String)}. */
    private static Map<String, MethodNode> indexMethods(ClassNode classNode) {
        ImmutableMap.Builder<String, MethodNode> builder = ImmutableMap.builder();
        for (MethodNode method : classNode.methods) {
            builder.put(memberKey(method.name, method.desc), method);
        }
        return builder.build();
    }
}
