package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.AccessFlag;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.ClassDecl.RecordComponentDecl;
import net.covers1624.coffeegrinder.bytecode.insns.tags.CompactConstructorTag;
import net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformer;
import net.covers1624.coffeegrinder.type.*;
import net.covers1624.coffeegrinder.type.AnnotationSupplier.TypeAnnotationLocation;
import net.covers1624.quack.collection.ColUtils;
import net.covers1624.quack.collection.FastStream;
import net.covers1624.quack.util.JavaVersion;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;

import java.lang.annotation.ElementType;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.*;

/**
 * Reconstructs {@code record} source syntax from the compiled form of a record class.
 * <p>
 * Only runs for classes flagged as records, and only when {@code java/lang/runtime/ObjectMethods}
 * could be resolved (it is needed to recognise the compiler-generated object methods). It:
 * <ul>
 *     <li>builds a {@link RecordComponentDecl} for every instance field, merging the field and
 *     canonical constructor parameter annotations;</li>
 *     <li>removes the {@code toString}, {@code hashCode} and {@code equals} implementations that
 *     bootstrap through {@code ObjectMethods};</li>
 *     <li>removes accessors whose body and position match the implicit form and whose annotations
 *     can be expressed on the component;</li>
 *     <li>deletes the canonical constructor, or turns it into a compact constructor, where possible.</li>
 * </ul>
 * <p>
 * Created by covers1624 on 20/4/23.
 */
public class RecordTransformer implements ClassTransformer {

    // Resolved lazily by the TypeResolver; null when the runtime class is unavailable.
    private final @Nullable ClassType objectMethods;

    public RecordTransformer(TypeResolver typeResolver) {
        objectMethods = typeResolver.tryResolveClassDecl("java/lang/runtime/ObjectMethods");
    }

    @Override
    public void transform(ClassDecl cInsn, ClassTransformContext ctx) {
        ClassType type = cInsn.getClazz();
        if (!type.isRecord() || objectMethods == null) return;

        // Records use all instance fields as their record components.
        List<FieldDecl> componentFields = cInsn.getFieldMembers()
                .filterNot(e -> e.getField().isStatic())
                .toList();
        MethodDecl mainCtor = findMainConstructor(cInsn, componentFields);

        buildRecordComponents(cInsn, componentFields, mainCtor);

        // Find the implicit implementations of toString, hashCode and equals.
        List<MethodDecl> implicitMethods = FastStream.of(
                        findDefaultMethod(cInsn, "toString", "()Ljava/lang/String;"),
                        findDefaultMethod(cInsn, "hashCode", "()I"),
                        findDefaultMethod(cInsn, "equals", "(Ljava/lang/Object;)Z")
                )
                .filter(Objects::nonNull)
                .toList();

        // Determine which accessors are implicit and strip them.
        detectAndStripImplicitAccessors(cInsn, FastStream.of(implicitMethods).firstOrDefault(), ctx);

        implicitMethods.forEach(e -> removeDefaultMethod(e, ctx));

        // Only methods have been removed so far, so the component field list is still accurate.
        transformCanonicalConstructor(cInsn, mainCtor, componentFields, ctx);
    }

    /**
     * Adds one {@link RecordComponentDecl} per component field to {@code cInsn.recordComponents}.
     * <p>
     * Each component collects the field's annotations and the annotations of the canonical
     * constructor parameter at the same index. The last component is marked varargs when the
     * canonical constructor has the VARARGS flag.
     *
     * @param cInsn           The record class.
     * @param componentFields The instance fields, in declaration order.
     * @param mainCtor        The canonical constructor; its parameters line up with {@code componentFields} by index.
     */
    private void buildRecordComponents(ClassDecl cInsn, List<FieldDecl> componentFields, MethodDecl mainCtor) {
        for (int index = 0; index < componentFields.size(); index++) {
            FieldDecl fieldDecl = componentFields.get(index);
            Field field = fieldDecl.getField();
            AType type = field.getType();
            String name = field.getName();

            MethodDecl accessor = cInsn.getMethodMembers()
                    .filter(e -> e.getMethod().getName().equals(name)
                                 && e.getMethod().getParameters().isEmpty()
                                 && e.getMethod().getReturnType().equals(type)
                    )
                    .only();

            TypeAnnotationData typeAnnotations = new TypeAnnotationData();
            List<AnnotationData> regularAnnotations = new ArrayList<>();

            // Parse Type annotations.
            // The field annotations are always declared on the component.
            field.getAnnotationSupplier()
                    .parseTypeAnnotations(typeAnnotations, TypeAnnotationLocation.FIELD, type, 0);
            // The constructor parameter annotations are always declared on the component.
            mainCtor.getMethod()
                    .getAnnotationSupplier()
                    .parseTypeAnnotations(typeAnnotations, TypeAnnotationLocation.PARAMETER, type, index);

            // Parse 'regular' annotations.
            // As above, both are always declared on the component.
            field.getAnnotationSupplier().parseAnnotations(regularAnnotations, typeAnnotations);
            Parameter parameter = mainCtor.getMethod().getParameters().get(index);
            parameter.getAnnotationSupplier().parseAnnotations(regularAnnotations, typeAnnotations);

            cInsn.recordComponents.add(new RecordComponentDecl(
                    fieldDecl,
                    index == componentFields.size() - 1 && mainCtor.getMethod().getAccessFlags().get(AccessFlag.VARARGS),
                    accessor,
                    typeAnnotations,
                    regularAnnotations
            ));
        }
    }

    /**
     * Removes accessors that can be regenerated implicitly by the compiler.
     * <p>
     * Components are visited last-to-first. An accessor is removed only if all of these hold:
     * <ul>
     *     <li>its body directly returns the component field;</li>
     *     <li>its annotations can be merged with the component's annotations;</li>
     *     <li>it is positioned like an implicit accessor, i.e. after {@code firstImplicitMethod}, or,
     *     when there is no implicit method, it is not followed by another method.</li>
     * </ul>
     * The accessor's annotations are then merged into the component and the component's
     * {@code accessor} is cleared.
     *
     * @param cInsn               The record class.
     * @param firstImplicitMethod The first detected implicit object method, or {@code null} if none was found.
     * @param ctx                 The transform context.
     */
    private void detectAndStripImplicitAccessors(ClassDecl cInsn, @Nullable MethodDecl firstImplicitMethod, ClassTransformContext ctx) {
        for (RecordComponentDecl component : cInsn.recordComponents.reversed()) {
            Field field = component.field.getField();
            AType type = field.getType();
            assert component.accessor != null;
            // If the body of the accessor is not the default pattern, it must remain.
            if (!isAccessorBodySynthetic(component.accessor, component.field)) continue;

            AnnotationSupplier supplier = component.accessor
                    .getMethod()
                    .getAnnotationSupplier();

            // Parse the type/regular annotations for merge query.
            TypeAnnotationData tAnnotations = new TypeAnnotationData();
            List<AnnotationData> rAnnotations = new ArrayList<>();
            supplier.parseTypeAnnotations(tAnnotations, TypeAnnotationLocation.METHOD_RETURN, type, 0);
            supplier.parseAnnotations(rAnnotations, tAnnotations);

            // Check that the two annotation sets are equal, or that neither side contains annotations
            // missing from the other that would apply to the other's locations.
            if (!canMergeAnnotations(component, rAnnotations, tAnnotations)) continue;

            // Position heuristics: implicit accessors are emitted after the implicit object methods.
            if (firstImplicitMethod != null && !isDeclaredLater(component.accessor, firstImplicitMethod)) continue;
            if (firstImplicitMethod == null && (component.accessor.getNextSiblingOrNull() instanceof MethodDecl)) continue;

            // The annotation set is applicable. Reparse it into the component.
            supplier.parseTypeAnnotations(component.typeAnnotations, TypeAnnotationLocation.METHOD_RETURN, type, 0);
            supplier.parseAnnotations(component.regularAnnotations, component.typeAnnotations);

            // Strip it! \o/
            ctx.pushStep("Remove default accessor for: " + component.field.getField().getName());
            component.accessor.remove();
            component.accessor = null;
            ctx.popStep();
        }
    }

    /**
     * Checks whether the accessor's entry block starts with a return of the component field.
     */
    private boolean isAccessorBodySynthetic(MethodDecl accessor, FieldDecl fieldDecl) {
        // First instruction should be a return.
        if (!(accessor.getBody().getEntryPoint().getFirstChildOrNull() instanceof Return ret)) return false;

        // The return must reference the record component field.
        return matchLoadField(ret.getValue(), fieldDecl.getField()) != null;
    }

    /**
     * Checks whether the accessor's annotations can be folded into the component without changing
     * which elements they end up on.
     * <p>
     * Accessor annotations missing from the component must not target FIELD or PARAMETER.
     * Component annotations missing from the accessor must not target METHOD. Both checks are
     * done for regular and left-most type annotations.
     */
    private boolean canMergeAnnotations(RecordComponentDecl component, List<AnnotationData> rAnnotations, TypeAnnotationData tAnnotations) {
        List<AnnotationData> componentTypeAnnotations = component.typeAnnotations.getLeftMost().annotations;
        List<AnnotationData> methodTypeAnnotations = tAnnotations.getLeftMost().annotations;

        if (areAnnotationsIncompatible(rAnnotations, component.regularAnnotations, ElementType.FIELD, ElementType.PARAMETER)) return false;
        if (areAnnotationsIncompatible(methodTypeAnnotations, componentTypeAnnotations, ElementType.FIELD, ElementType.PARAMETER)) return false;

        if (areAnnotationsIncompatible(component.regularAnnotations, rAnnotations, ElementType.METHOD)) return false;
        if (areAnnotationsIncompatible(componentTypeAnnotations, methodTypeAnnotations, ElementType.METHOD)) return false;

        return true;
    }

    /**
     * Returns {@code true} if any annotation in {@code toCheck} is absent from {@code target}
     * and applies to one of {@code targetLocations}.
     */
    private boolean areAnnotationsIncompatible(List<AnnotationData> toCheck, List<AnnotationData> target, ElementType... targetLocations) {
        for (AnnotationData annotation : toCheck) {
            // Annotation already exists in target location, it is allowed.
            if (target.contains(annotation)) continue;

            // Annotation does not exist, and it applies to one of the target locations.
            if (doesApplyTo(annotation, targetLocations)) {
                // We are not able to merge the annotations into the target.
                return true;
            }
        }
        return false;
    }

    /**
     * Returns {@code true} if the annotation's targets include any of {@code required}, or it
     * declares no targets.
     */
    private boolean doesApplyTo(AnnotationData annotation, ElementType... required) {
        List<ElementType> types = annotation.type().getAnnotationTargets();
        // If the @Target annotation is missing, annotation applies everywhere.
        if (types.isEmpty()) return true;

        for (ElementType type : required) {
            if (types.contains(type)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Finds the method {@code mName}{@code desc} if its body immediately returns an
     * {@code ObjectMethods.bootstrap} invokedynamic whose name is {@code mName}.
     *
     * @return The matching method, or {@code null} if the method is absent or does not match.
     */
    private @Nullable MethodDecl findDefaultMethod(ClassDecl cInsn, String mName, String desc) {
        assert objectMethods != null;
        MethodDecl method = cInsn.getMethod(mName, Type.getMethodType(desc));
        // Defensive: treat a missing method the same as a non-implicit one.
        if (method == null) return null;
        if (!(method.getBody().getEntryPoint().getFirstChildOrNull() instanceof Return ret)) return null;
        // These implicit methods all use INVOKE_DYNAMIC to call their respective functions in ObjectMethods.
        InvokeDynamic indy = InvokeMatching.matchInvokeDynamic(ret.getValue(), objectMethods, "bootstrap");
        if (indy == null || !indy.name.equals(mName)) return null;

        return method;
    }

    /**
     * Removes {@code method} from its class, recorded as a transform step. Does nothing for {@code null}.
     */
    private static void removeDefaultMethod(@Nullable MethodDecl method, ClassTransformContext ctx) {
        if (method == null) return;

        ctx.pushStep("Remove default " + method.getMethod().getName() + " method.");
        method.remove();
        ctx.popStep();
    }

    /**
     * Simplifies the canonical constructor.
     * <p>
     * The redundant super call is always removed. If the body ends with
     * {@code this.f_i = p_i} assignments for every component followed by a return, then:
     * <ul>
     *     <li>the constructor is deleted when it has no user code and the same access as the class,
     *     unless Java 21 metadata marks it as compact;</li>
     *     <li>otherwise the assignments are removed and the constructor is tagged as compact, unless
     *     Java 21 metadata marks it as a non-compact constructor.</li>
     * </ul>
     * Nothing further is done if any parameter is {@code final}.
     *
     * @param cInsn        The record class.
     * @param ctor         The canonical constructor.
     * @param recordFields The component fields, in declaration order; they line up with {@code ctor}'s parameters by index.
     * @param ctx          The transform context.
     */
    private void transformCanonicalConstructor(ClassDecl cInsn, MethodDecl ctor, List<FieldDecl> recordFields, ClassTransformContext ctx) {
        ctx.pushStep("Remove redundant super call.");
        ImplicitConstructorCleanup.removeRedundantSuperCall(ctor);
        ctx.popStep();

        List<ParameterVariable> parameters = ctor.parameters.toList();

        // Ctor body will have return at the bottom of last block.
        var lastBlock = ctor.getBody().blocks.last();
        Instruction insn = lastBlock.getLastChildOrNull();
        if (!(insn instanceof Return)) return; // Might be a throw at the bottom, with return and assignments in another path.
        insn = insn.getPrevSiblingOrNull();

        // Find all the stores to the record fields. They should be in declaration order, so we match them walking up from the bottom.
        List<Store> stores = new ArrayList<>(recordFields.size());
        for (int i = recordFields.size() - 1; i >= 0; i--, insn = insn.getPrevSiblingOrNull()) {
            // The assignments should store to the field, loading from the parameter.
            Store store = matchStoreField(insn, recordFields.get(i).getField());
            if (store == null) break;
            if (matchLoadLocal(store.getValue(), parameters.get(i)) == null) break;

            stores.add(store);
        }
        // User code is anything left above the stores. It also exists when the return is not in the
        // entry block: an implicit canonical constructor is straight-line code, so extra blocks
        // mean user control flow that must not be deleted.
        boolean hasUserCode = insn != null || lastBlock != ctor.getBody().getEntryPoint();

        // We did not find them all.
        if (stores.size() != recordFields.size()) return;

        // If any parameter has FINAL, we can't do anything.
        // This flag either came from the --parameters javac argument, or Java21.
        if (ColUtils.anyMatch(parameters, e -> e.parameter.isFinal())) return;

        // On Java 21+, the first parameter's MANDATED flag decides compact vs non-compact; before that we cannot tell.
        Optional<Boolean> forceCompact = Optional.empty();
        if (ctx.classVersion.isAtLeast(JavaVersion.JAVA_21) && !parameters.isEmpty()) {
            forceCompact = Optional.of(parameters.getFirst().parameter.isMandated());
        }

        if (!forceCompact.orElse(false) && !hasUserCode && !areFlagsDifferent(cInsn, ctor)) {
            ctx.pushStep("Delete canonical constructor.");
            ctor.remove();
            ctx.popStep();
        } else if (forceCompact.orElse(true)) {
            ctx.pushStep("Create compact record constructor.");
            stores.forEach(Instruction::remove);
            ctor.setTag(new CompactConstructorTag());
            ctx.popStep();
        }
    }

    /**
     * Returns {@code true} if the constructor's access level differs from the class's.
     */
    private static boolean areFlagsDifferent(ClassDecl cDecl, MethodDecl mDecl) {
        return AccessFlag.getAccess(cDecl.getClazz().getAccessFlags()) != AccessFlag.getAccess(mDecl.getMethod().getAccessFlags());
    }

    /**
     * Finds the single constructor whose parameter types equal the component field types, in order.
     */
    private static MethodDecl findMainConstructor(ClassDecl cInsn, List<FieldDecl> componentFields) {
        // Find the canonical ctor, Record classes will have a single ctor with all instance fields as parameters, in the order they are declared.
        return cInsn.getMethodMembers()
                .filter(e -> e.getMethod().getName().equals("<init>"))
                .filter(e -> matchParameterTypes(e.getMethod().getParameters(), componentFields))
                .only();
    }

    /**
     * Returns {@code true} if the parameter types match the field types pairwise, with equal counts.
     */
    private static boolean matchParameterTypes(List<Parameter> parameters, List<FieldDecl> componentFields) {
        if (parameters.size() != componentFields.size()) return false;

        for (int i = 0; i < componentFields.size(); i++) {
            if (!componentFields.get(i).getField().getType().equals(parameters.get(i).getType())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns {@code true} if {@code a} is {@code b} itself or one of {@code b}'s following siblings.
     */
    private static boolean isDeclaredLater(Instruction a, Instruction b) {
        while (b != null) {
            if (b == a) return true;
            b = b.getNextSiblingOrNull();
        }
        return false;
    }
}
