package net.covers1624.coffeegrinder.util.asm;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Label;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.LabelNode;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
 * Represents a contiguous, ordered run of {@link Label}s exactly as they are declared
 * in a method's bytecode ({@link InsnList}).
 * <p>
 * Equality and hashing are based only on the <em>set</em> of contained labels; ordering
 * is not considered.
 * <p>
 * Created by covers1624 on 6/4/21.
 */
public class LabelRange {

    public final Label start;
    public final Label end;

    public final List<Label> range;
    private final Set<Label> _range;

    /**
     * Creates a new range from the given labels.
     *
     * @param range The labels, in bytecode order. Must not be empty. The list is copied.
     * @throws IllegalArgumentException If {@code range} is empty.
     */
    public LabelRange(List<Label> range) {
        if (range.isEmpty()) throw new IllegalArgumentException("Expected non empty range.");

        this.range = ImmutableList.copyOf(range);
        _range = ImmutableSet.copyOf(range);
        start = this.range.get(0);
        end = this.range.get(this.range.size() - 1);
    }

    /**
     * Gets the first label included in this range.
     *
     * @return The {@link Label}.
     */
    public Label getStart() {
        return start;
    }

    /**
     * Gets the last label included in this range.
     * This may be the same as {@link #getStart()}
     *
     * @return The {@link Label}.
     */
    public Label getEnd() {
        return end;
    }

    /**
     * Gets the ordered list of labels included in this range.
     * This is exactly how they appear in Bytecode, and are not sorted.
     *
     * @return The range.
     */
    public List<Label> getRange() {
        return range;
    }

    /**
     * Checks if a given label is contained within the range.
     *
     * @param label The label to check.
     * @return If the label is contained within the range.
     */
    public boolean containsLabel(Label label) {
        return _range.contains(label);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;

        LabelRange that = (LabelRange) o;

        return _range.equals(that._range);
    }

    @Override
    public int hashCode() {
        return _range.hashCode();
    }

    @Override
    public String toString() {
        return "LabelRange" + range;
    }

    /**
     * Computes a {@link LabelRange} from the given {@link InsnList}.
     * <p>
     * Labels are collected in declaration order, starting at {@code start} and stopping
     * at {@code end}. Labels declared before {@code start} are ignored.
     *
     * @param list         The {@link InsnList} to extract the range from.
     * @param start        The first {@link Label} to be included in the range.
     * @param end          The {@link Label} terminating the range. It is the last label
     *                     included when {@code exclusiveEnd} is {@code false}, and is
     *                     excluded otherwise.
     * @param exclusiveEnd If the provided end {@link Label} should be exclusive or inclusive.
     * @return The built range.
     * @throws IllegalStateException If either the start or end label are not included in
     *                               the provided {@link InsnList}, if the end label is only
     *                               declared before the start label, or if the resulting
     *                               range would be empty (start == end with an exclusive end).
     */
    public static LabelRange compute(InsnList list, Label start, Label end, boolean exclusiveEnd) {
        List<Label> labels = new ArrayList<>();
        boolean foundStart = false;
        boolean foundEnd = false;
        for (AbstractInsnNode insnNode : list) {
            if (!(insnNode instanceof LabelNode)) continue;

            Label label = ((LabelNode) insnNode).getLabel();
            if (!foundStart) {
                // Skip everything declared before the start label.
                if (label != start) continue;
                foundStart = true;
            }
            if (label == end) {
                foundEnd = true;
                if (!exclusiveEnd) labels.add(label);
                break;
            }
            labels.add(label);
        }
        if (!foundStart) {
            throw new IllegalStateException("Unable to find start label.");
        }
        // Checked for both inclusive and exclusive ends; previously an exclusive end that was
        // never found silently produced a range running to the end of the method.
        if (!foundEnd) {
            throw new IllegalStateException("Unable to find end label.");
        }
        if (labels.isEmpty()) {
            // Only reachable when start == end and the end is exclusive.
            throw new IllegalStateException("Computed range is empty, start and end labels are the same with an exclusive end.");
        }
        return new LabelRange(labels);
    }

    /**
     * Get the label declared immediately after the target label.
     *
     * @param list   The {@link InsnList} to search.
     * @param target The target label to search for.
     * @return The label after the target.
     * @throws IllegalStateException If the target label is not declared in the list, or
     *                               no label is declared after it.
     */
    public static Label getLabelAfter(InsnList list, Label target) {
        boolean foundTarget = false;
        for (AbstractInsnNode insnNode : list) {
            if (!(insnNode instanceof LabelNode)) continue;

            Label label = ((LabelNode) insnNode).getLabel();
            if (label == target) {
                foundTarget = true;
            } else if (foundTarget) {
                return label;
            }
        }
        if (!foundTarget) {
            throw new IllegalStateException("Unable to find target label.");
        }
        throw new IllegalStateException("Unable to find next label.");
    }

    /**
     * Builds a {@link LabelRange} covering every label declared in the given {@link InsnList},
     * in declaration order.
     *
     * @param list The {@link InsnList} to extract labels from.
     * @return The range, or {@code null} if the list declares no labels.
     */
    public static @Nullable LabelRange extractRange(InsnList list) {
        List<Label> labels = new ArrayList<>();
        for (AbstractInsnNode insnNode : list) {
            if (insnNode instanceof LabelNode) {
                labels.add(((LabelNode) insnNode).getLabel());
            }
        }
        if (labels.isEmpty()) return null;
        return new LabelRange(labels);
    }
}
