package net.covers1624.coffeegrinder.type.asm;

import net.covers1624.coffeegrinder.bytecode.AccessFlag;
import net.covers1624.coffeegrinder.type.*;
import net.covers1624.coffeegrinder.util.EnumBitSet;
import net.covers1624.coffeegrinder.util.Util;
import net.covers1624.coffeegrinder.util.asm.TypeParameterParser;
import net.covers1624.coffeegrinder.util.resolver.CachedClassNode;
import net.covers1624.quack.collection.FastStream;
import net.covers1624.quack.util.JavaVersion;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.InnerClassNode;

import java.lang.annotation.ElementType;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;

import static java.util.Objects.requireNonNull;
import static net.covers1624.quack.util.SneakyUtils.unsafeCast;

/**
 * A Class as loaded from Bytecode.
 * <p>
 * Created by covers1624 on 7/4/21.
 */
public class AsmClass extends ClassType {

    // Special internal annotation used by Javac to determine which functions have polymorphic signatures.
    private static final Type POLYMORPHIC_SIGNATURE = Type.getObjectType("java/lang/invoke/MethodHandle$PolymorphicSignature");

    private static final Type ANNOTATION_TARGET = Type.getObjectType("java/lang/annotation/Target");

    // These classes are known to have methods annotated with PolymorphicSignature.
    private static final Set<Type> KNOWN_POLYMORPHIC = Set.of(
            Type.getObjectType("java/lang/invoke/MethodHandle"),
            Type.getObjectType("java/lang/invoke/VarHandle")
    );

    private final TypeResolver typeResolver;
    private final CachedClassNode cNode;
    private final Type descriptor;
    // Simple name; for inner/local classes this is the InnerClasses 'innerName'.
    private final String name;
    // Dotted package name, empty for the default package.
    private final String pkg;
    private final String fullName;
    // Empty for top level classes, or when no valid outer class could be determined.
    private final Optional<AsmClass> declaringClass;

    private final DeclType declType;

    private final EnumBitSet<AccessFlag> accessFlags;
    private final boolean hasTypeParameters;
    // Lazily populated as a side effect of computing signatureInfo.
    @Nullable
    private List<AsmTypeParameter> typeParameters;
    private final Supplier<SignatureInfo> signatureInfo;
    private final Supplier<ClassType> superClass;
    private final Supplier<List<ClassType>> nestedClasses;
    private final Supplier<List<ClassType>> interfaces;
    private final Supplier<List<Field>> fields;
    private final Supplier<List<Method>> methods;
    private final Supplier<Optional<Method>> enclosingMethod;

    private final Supplier<Map<Object, Field>> instanceConstantLookup;
    private final Supplier<Map<Object, Field>> staticConstantLookup;

    private final AnnotationSupplier annotationSupplier;
    private final Supplier<List<ElementType>> annotationTargets;

    private final Supplier<List<ClassType>> permittedSubclasses;

    private final RawClass rawClass;

    private final JavaVersion classVersion;

    private final boolean containsPolymorphicMethods;

    /**
     * Builds a class declaration from the given cached class node.
     * <p>
     * Structural information (name, package, declaring class, access flags and declaration type)
     * is computed eagerly from the partial node. Everything that may need other classes to be
     * resolved (super types, members, signatures, annotation targets, permitted subclasses, ...)
     * is memoized and only computed on first use.
     *
     * @param typeResolver The resolver used to look up other classes referenced by this class.
     * @param cNode        The cached class node this class is read from.
     * @throws IllegalStateException If this class has an explicit {@code outerClass} attribute, but the
     *                               resolved outer class has no InnerClasses entry referencing this class.
     */
    public AsmClass(TypeResolver typeResolver, CachedClassNode cNode) {
        this.typeResolver = typeResolver;
        this.cNode = cNode;
        descriptor = Type.getObjectType(cNode.getClassName());
        ClassNode pNode = cNode.getPartialNode();

        int lastSlash = pNode.name.lastIndexOf('/');
        String pkg;
        String name;
        if (lastSlash == -1) {
            pkg = "";
            name = pNode.name;
        } else {
            name = pNode.name.substring(lastSlash + 1);
            pkg = pNode.name.substring(0, lastSlash).replace('/', '.');
        }
        String fullName = pNode.name.replace('/', '.');

        // ASM documentation is unclear here, outerClass is only present when outerMethod is present.
        String outerClass = pNode.outerClass;
        boolean outerFromNameMangling = false;
        if (outerClass == null) { // fallback, try and find outer via name mangling
            int lastDollar = pNode.name.lastIndexOf('$');
            if (lastDollar != -1) {
                outerClass = pNode.name.substring(0, lastDollar);
                outerFromNameMangling = true;
            }
        }

        // an outer class may not exist if the name demangling fails (eg, if the library author added $ characters for other reasons)
        AsmClass outer = outerClass != null ? (AsmClass) typeResolver.tryResolveClassDecl(outerClass) : null;
        InnerClassNode innerNode = null;
        if (outer != null) {
            innerNode = FastStream.of(outer.cNode.getPartialNode().innerClasses)
                    .filter(e -> e.name.equals(pNode.name))
                    .findFirst()
                    .orElse(null);
            if (innerNode == null) {
                // An explicit outerClass attribute must be backed by an InnerClasses entry in the outer class.
                if (!outerFromNameMangling) {
                    throw new IllegalStateException("Class has outerClass attribute, but outer does not have innerClass reference to us. " + pNode.name);
                }
                // The guessed outer exists, but does not list us as nested. The '$' was not a nesting
                // separator (eg, Scala's 'Foo$' next to 'Foo'), so treat this class as top level.
                outer = null;
            }
        }
        declaringClass = Optional.ofNullable(outer);

        accessFlags = AccessFlag.unpackClass(pNode.access);

        InnerClassNode ownInnerClassNode = FastStream.of(pNode.innerClasses)
                .filter(e -> e.name.equals(pNode.name))
                .onlyOrDefault();
        if (ownInnerClassNode != null) {
            // Records don't persist the RECORD flag onto both the inner flags and the real class flags.
            // So we must do this annoying merge nonsense, instead of just choosing which to use, we only really
            // care about copying over access related flags here, as javac declares inners public when they are protected,
            // and package private when they are private, but persists the real flags on the inner attribute flags.
            accessFlags.clear(AccessFlag.PUBLIC);
            accessFlags.clear(AccessFlag.PRIVATE);
            accessFlags.clear(AccessFlag.PROTECTED);
            accessFlags.or(AccessFlag.unpackClass(ownInnerClassNode.access));
        }

        // innerNode is only non-null when a valid declaring class was found.
        if (innerNode != null) {
            if (innerNode.innerName == null) {
                declType = DeclType.ANONYMOUS;
            } else if (innerNode.outerName == null) {
                declType = DeclType.LOCAL;
                name = innerNode.innerName;
            } else {
                declType = DeclType.INNER;
                name = innerNode.innerName;
                fullName = outer.getFullName() + "$" + name;
            }
            accessFlags.or(AccessFlag.unpackInnerClass(innerNode.access));
        } else {
            declType = DeclType.TOP_LEVEL;
        }
        this.name = name;
        this.pkg = pkg;
        this.fullName = fullName;

        hasTypeParameters = pNode.signature != null && pNode.signature.charAt(0) == '<';

        signatureInfo = Util.singleMemoize(() -> {
            // Populates typeParameters as a side effect, see getTypeParameters.
            typeParameters = pNode.signature != null ? TypeParameterParser.parse(pNode.signature, this) : List.of();
            if (pNode.signature == null) return SignatureInfo.NONE;

            ClassSignatureParser visitor = ClassSignatureParser.parse(typeResolver, this, pNode.signature);
            return new SignatureInfo(visitor.getSuperClass(), visitor.getInterfaces());
        });

        superClass = Util.singleMemoize(() -> {
            if (pNode.signature != null) return requireNonNull(signatureInfo.get().superClass);

            if (pNode.superName == null) {
                throw new UnsupportedOperationException("Object does not have a super type.");
            }

            return typeResolver.resolveClass(pNode.superName);
        });
        nestedClasses = Util.singleMemoize(() -> {
                    FastStream<String> innerClassStream;
                    // Noticed on Javac17, the innerClasses list is no longer ordered correctly.
                    // Use nestMembers if it exists to keep inner class order.
                    if (pNode.nestMembers != null && !pNode.nestMembers.isEmpty()) {
                        innerClassStream = FastStream.of(pNode.nestMembers);
                    } else {
                        innerClassStream = FastStream.of(pNode.innerClasses)
                                .map(e -> e.name);
                    }

                    // Only keep classes whose declaring class is this exact instance (excludes ourselves and deeper nesting).
                    return innerClassStream.filter(e -> e.startsWith(pNode.name))
                            .map(typeResolver::resolveClassDecl)
                            .filter(e -> e.getEnclosingClass().filter(d -> d == this).isPresent())
                            .toImmutableList();
                }
        );
        interfaces = Util.singleMemoize(() -> {
            if (pNode.signature != null) return signatureInfo.get().interfaces;

            return FastStream.of(cNode.getInterfaces())
                    .map(typeResolver::resolveClass)
                    .toImmutableList();
        });
        fields = Util.singleMemoize(() -> FastStream.of(pNode.fields)
                .<Field>map(e -> new AsmField(typeResolver, this, e))
                .toImmutableList()
        );

        methods = Util.singleMemoize(() -> FastStream.of(pNode.methods)
                .<Method>map(e -> new AsmMethod(typeResolver, this, e))
                .toImmutableList()
        );

        instanceConstantLookup = Util.singleMemoize(() -> FastStream.of(fields.get())
                .filter(e -> e.getConstantValue() != null)
                .filter(e -> !e.isStatic())
                .toImmutableMap(Field::getConstantValue, Function.identity())
        );

        staticConstantLookup = Util.singleMemoize(() -> FastStream.of(fields.get())
                .filter(e -> e.getConstantValue() != null)
                .filter(Field::isStatic)
                .toImmutableMap(Field::getConstantValue, Function.identity())
        );
        enclosingMethod = Util.singleMemoize(() -> {
            if (pNode.outerMethod == null) {
                return Optional.empty();
            }
            return declaringClass.map(e -> e.methodsLookup.get().get(pNode.outerMethod + pNode.outerMethodDesc));
        });

        annotationSupplier = new AnnotationSupplier(typeResolver,
                Util.safeConcat(pNode.visibleAnnotations, pNode.invisibleAnnotations),
                Util.safeConcat(pNode.visibleTypeAnnotations, pNode.invisibleTypeAnnotations)
        );
        annotationTargets = Util.singleMemoize(() -> {
            if (!getAccessFlags().get(AccessFlag.ANNOTATION)) throw new UnsupportedOperationException(getFullName() + " is not an annotation.");

            AnnotationData targetAnnotation = FastStream.of(getAnnotationSupplier().getAnnotations())
                    .filter(e -> e.type().getDescriptor().equals(ANNOTATION_TARGET))
                    .onlyOrDefault();
            if (targetAnnotation == null) return List.of();

            //noinspection unchecked
            return FastStream.of((List<Field>) targetAnnotation.values().get("value"))
                    .map(e -> ElementType.valueOf(e.getName()))
                    .toImmutableList();
        });

        permittedSubclasses = Util.singleMemoize(() -> {
            if (pNode.permittedSubclasses == null || pNode.permittedSubclasses.isEmpty()) return List.of();

            return FastStream.of(pNode.permittedSubclasses)
                    .map(typeResolver::resolveClassDecl)
                    .toImmutableList();
        });

        rawClass = new RawClass(this);
        classVersion = JavaVersion.parseFromClass(pNode.version);

        containsPolymorphicMethods = KNOWN_POLYMORPHIC.contains(descriptor);
    }

    //@formatter:off
    public CachedClassNode getNode() { return cNode; }
    @Override public String getName() { return name; }
    @Override public Optional<ClassType> getEnclosingClass() { return unsafeCast(declaringClass); }
    @Override public String getPackage() { return pkg; }
    @Override public String getFullName() { return fullName; }
    @Override public ClassType getSuperClass() { return superClass.get(); }
    @Override public List<ClassType> getNestedClasses() { return nestedClasses.get(); }
    @Override public List<ClassType> getInterfaces() { return interfaces.get(); }
    @Override public List<Field> getFields() { return fields.get(); }
    @Override public List<Method> getMethods() { return methods.get(); }
    @Override public Optional<Method> getEnclosingMethod() { return enclosingMethod.get(); }
    @Override public AnnotationSupplier getAnnotationSupplier() { return annotationSupplier; }
    @Override public List<ElementType> getAnnotationTargets() { return annotationTargets.get(); }
    @Override public Type getDescriptor() { return descriptor; }
    @Override public ClassType getDeclaration() { return this; }
    @Override public DeclType getDeclType() { return declType; }
    @Override public EnumBitSet<AccessFlag> getAccessFlags() { return accessFlags; }
    @Override public ClassType asRaw() { return rawClass; }
    @Override public JavaVersion getClassVersion() { return classVersion; }
    @Override public TypeResolver getTypeResolver() { return typeResolver; }
    @Override public List<ClassType> getPermittedSubclasses() { return permittedSubclasses.get(); }
    private Map<Object, Field> getInstanceConstantLookup() { return instanceConstantLookup.get(); }
    private Map<Object, Field> getStaticConstantLookup() { return staticConstantLookup.get(); }
    //@formatter:on

    /**
     * Resolves a method by name and descriptor.
     * <p>
     * Signature polymorphic methods (only looked for on MethodHandle and VarHandle) are matched by
     * name alone and specialized to the requested descriptor. Anything else is delegated to
     * {@link ClassType#resolveMethod}.
     *
     * @param name The method name.
     * @param desc The method descriptor at the call site.
     * @return The resolved method, or {@code null} if none was found.
     */
    @Nullable
    @Override
    public Method resolveMethod(String name, Type desc) {
        Method polyMethod = getPolymorphicMethod(name, desc);
        if (polyMethod != null) {
            return polyMethod;
        }

        return super.resolveMethod(name, desc);
    }

    /**
     * Finds a signature polymorphic method with the given name and wraps it with the call site descriptor.
     *
     * @return A new {@link PolymorphicSignatureMethod}, or {@code null} if this class is not known to
     * contain polymorphic methods, or no polymorphic method has that name.
     */
    @Nullable
    private Method getPolymorphicMethod(String name, Type desc) {
        if (!containsPolymorphicMethods) return null;
        for (Method method : getMethods()) {
            if (method.getName().equals(name) && isPolymorphicMethod(method)) {
                return new PolymorphicSignatureMethod(method, typeResolver, desc);
            }
        }
        return null;
    }

    /**
     * Checks whether a method is signature polymorphic: native, a single {@code Object[]} parameter,
     * and annotated with {@code MethodHandle$PolymorphicSignature}.
     */
    private static boolean isPolymorphicMethod(Method method) {
        // JLS says that native, Object[] methods in MethodHandle and VarHandle are polymorphic.
        // Javadoc says any method with PolymorphicSignature annotation. Let's test for both!

        // Must be native
        if (!method.getAccessFlags().get(AccessFlag.NATIVE)) return false;
        // Must have a single Object[] parameter.
        List<Parameter> params = method.getParameters();
        if (params.size() != 1) return false;
        AType first = params.getFirst().getType();
        if (!(first instanceof ArrayType)) return false;
        if (!TypeSystem.isObject(((ArrayType) first).getElementType())) return false;

        // Must have annotation.
        for (AnnotationData ann : method.getAnnotationSupplier().getAnnotations()) {
            if (ann.type().getDescriptor().equals(POLYMORPHIC_SIGNATURE)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public boolean hasTypeParameters() {
        return hasTypeParameters;
    }

    /**
     * Returns this class's type parameters, parsing the class signature on first access.
     *
     * @return The type parameters, empty if the class has no signature.
     */
    @Override
    public List<TypeParameter> getTypeParameters() {
        if (typeParameters == null) {
            // typeParameters is populated as a side effect of computing the signature info.
            signatureInfo.get();
        }
        return unsafeCast(typeParameters);
    }

    /**
     * Finds a field declared directly on this class whose constant value equals {@code value}.
     * <p>
     * Inherited fields are not searched yet.
     *
     * @param value    The constant value to look up.
     * @param isStatic Whether to search static fields ({@code true}) or instance fields ({@code false}).
     * @return The matching field, or {@code null}.
     */
    @Nullable
    @Override
    public Field findConstant(Object value, boolean isStatic) {
        Map<Object, Field> lookup = isStatic ? getStaticConstantLookup() : getInstanceConstantLookup();

        Field field = lookup.get(value);
        if (field != null) return field;

        // TODO, inheritance!
        return null;
    }

    @Override
    public String toString() {
        // Only touch the type parameters for generic classes, avoiding forced signature parsing
        // (and the super type resolution it triggers) for everything else.
        if (!hasTypeParameters) return cNode.getClassName();

        return cNode.getClassName() + "<" + FastStream.of(getTypeParameters()).map(e -> "_").join(", ") + ">";
    }

    /**
     * Super class and interfaces parsed from the generic class signature.
     * {@link #NONE} is used when the class has no signature.
     */
    private static class SignatureInfo {

        public static final SignatureInfo NONE = new SignatureInfo();

        @Nullable
        public final ClassType superClass;
        public final List<ClassType> interfaces;

        private SignatureInfo() {
            superClass = null;
            interfaces = List.of();
        }

        private SignatureInfo(ClassType superClass, List<ClassType> interfaces) {
            this.superClass = superClass;
            this.interfaces = interfaces;
        }
    }

}
