/*
 * Copyright (c) 2016, 2017, 2018, 2019 FabricMC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.fabricmc.fabric.mixin.registry.sync;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.mojang.serialization.Lifecycle;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.spongepowered.asm.mixin.Final;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Shadow;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;
import net.fabricmc.fabric.api.event.Event;
import net.fabricmc.fabric.api.event.EventFactory;
import net.fabricmc.fabric.api.event.registry.RegistryEntryAddedCallback;
import net.fabricmc.fabric.api.event.registry.RegistryEntryRemovedCallback;
import net.fabricmc.fabric.api.event.registry.RegistryIdRemapCallback;
import net.fabricmc.fabric.impl.registry.sync.ListenableRegistry;
import net.fabricmc.fabric.impl.registry.sync.RemapException;
import net.fabricmc.fabric.impl.registry.sync.RemapStateImpl;
import net.fabricmc.fabric.impl.registry.sync.RemappableRegistry;
import net.minecraft.class_2370;
import net.minecraft.class_2378;
import net.minecraft.class_2960;
import net.minecraft.class_5321;

/**
 * Mixin into the registry class {@code class_2370} that makes it remappable to an externally
 * supplied identifier-to-raw-ID table (and restorable afterwards), and exposes add, remove and
 * remap events to listeners.
 *
 * <p>Remapping rebuilds only the raw-ID tables ({@code field_26682}: raw ID to object, and
 * {@code field_26683}: object to raw ID). The identifier and key maps are only modified when
 * {@code REMOTE} mode drops entries the remote side does not know about, and when
 * {@link #unmap} restores the snapshot.
 *
 * @param <T> the registry's entry type
 */
@Mixin(class_2370.class)
public abstract class MixinIdRegistry<T> implements RemappableRegistry, ListenableRegistry {
	// Raw ID -> object; may contain gaps after remap() pads it to fit the requested raw IDs.
	@Shadow
	@Final
	private ObjectList<T> field_26682;
	// Object -> raw ID.
	@Shadow
	@Final
	private Object2IntMap<T> field_26683;
	// Identifier -> object.
	@Shadow
	@Final
	private BiMap<class_2960, T> entriesById;
	// Registry key -> object.
	@Shadow
	@Final
	private BiMap<class_5321<T>, T> entriesByKey;
	// remap() keeps this one past the highest raw ID it assigns.
	@Shadow
	private int nextId;
	@Unique
	private static final Logger FABRIC_LOGGER = LogManager.getLogger();

	@Unique
	private final Event<RegistryEntryAddedCallback> fabric_addObjectEvent = EventFactory.createArrayBacked(RegistryEntryAddedCallback.class,
			(callbacks) -> (rawId, id, object) -> {
				for (RegistryEntryAddedCallback callback : callbacks) {
					//noinspection unchecked
					callback.onEntryAdded(rawId, id, object);
				}
			}
	);

	@Unique
	private final Event<RegistryEntryRemovedCallback> fabric_removeObjectEvent = EventFactory.createArrayBacked(RegistryEntryRemovedCallback.class,
			(callbacks) -> (rawId, id, object) -> {
				for (RegistryEntryRemovedCallback callback : callbacks) {
					//noinspection unchecked
					callback.onEntryRemoved(rawId, id, object);
				}
			}
	);

	@Unique
	private final Event<RegistryIdRemapCallback> fabric_postRemapEvent = EventFactory.createArrayBacked(RegistryIdRemapCallback.class,
			(callbacks) -> (a) -> {
				for (RegistryIdRemapCallback callback : callbacks) {
					//noinspection unchecked
					callback.onRemap(a);
				}
			}
	);

	// Snapshot taken by the first remap() after construction or after the last unmap().
	// Later remaps deliberately do not refresh it, so unmap() restores the original raw-ID order.
	@Unique
	private Object2IntMap<class_2960> fabric_prevIndexedEntries;
	@Unique
	private BiMap<class_2960, T> fabric_prevEntries;

	@Override
	public Event<RegistryEntryAddedCallback<T>> fabric_getAddObjectEvent() {
		//noinspection unchecked
		return (Event) fabric_addObjectEvent;
	}

	@Override
	public Event<RegistryEntryRemovedCallback<T>> fabric_getRemoveObjectEvent() {
		//noinspection unchecked
		return (Event) fabric_removeObjectEvent;
	}

	@Override
	public Event<RegistryIdRemapCallback<T>> fabric_getRemapEvent() {
		//noinspection unchecked
		return (Event) fabric_postRemapEvent;
	}

	// The rest of the registry isn't thread-safe, so this one need not be either.
	// Set by setPre() and consumed by setPost() within the same set() call.
	@Unique
	private boolean fabric_isObjectNew = false;

	/**
	 * Runs at the head of {@code set}.
	 *
	 * <p>It rejects objects that already have a raw ID, and it rejects identifiers re-registered at
	 * a different raw ID. When an existing identifier is rebound to a different object at the same
	 * raw ID, it fires the remove event for the replaced object. It records in
	 * {@code fabric_isObjectNew} whether {@link #setPost} should fire the add event.
	 */
	@SuppressWarnings({"unchecked", "ConstantConditions"})
	@Inject(method = "set", at = @At("HEAD"))
	public void setPre(int id, class_5321<T> registryId, Object object, Lifecycle lifecycle, CallbackInfoReturnable info) {
		// NOTE(review): relies on field_26683 returning a negative default for unknown objects
		// (presumably -1 as configured by vanilla) - TODO confirm.
		int indexedEntriesId = field_26683.getInt((T) object);

		if (indexedEntriesId >= 0) {
			throw new RuntimeException("Attempted to register object " + object + " twice! (at raw IDs " + indexedEntriesId + " and " + id + " )");
		}

		class_2960 objectId = registryId.method_29177();

		if (!entriesById.containsKey(objectId)) {
			fabric_isObjectNew = true;
		} else {
			T oldObject = entriesById.get(objectId);

			if (oldObject != null && oldObject != object) {
				int oldId = field_26683.getInt(oldObject);

				if (oldId != id) {
					throw new RuntimeException("Attempted to register ID " + registryId + " at different raw IDs (" + oldId + ", " + id + ")! If you're trying to override an item, use .set(), not .register()!");
				}

				fabric_removeObjectEvent.invoker().onEntryRemoved(oldId, objectId, oldObject);
				fabric_isObjectNew = true;
			} else {
				fabric_isObjectNew = false;
			}
		}
	}

	/**
	 * Runs when {@code set} returns and fires the add event if {@link #setPre} flagged the object
	 * as new.
	 */
	@SuppressWarnings("unchecked")
	@Inject(method = "set", at = @At("RETURN"))
	public void setPost(int id, class_5321<T> registryId, Object object, Lifecycle lifecycle, CallbackInfoReturnable info) {
		if (fabric_isObjectNew) {
			fabric_addObjectEvent.invoker().onEntryAdded(id, registryId.method_29177(), object);
		}
	}

	/**
	 * Rebuilds this registry's raw-ID tables so that every identifier in
	 * {@code remoteIndexedEntries} ends up at the raw ID given there, then fires the remap event.
	 *
	 * <p>Behaviour per mode:
	 * <ul>
	 *   <li>{@code AUTHORITATIVE} accepts the table as-is and appends local-only identifiers after
	 *       its highest raw ID.</li>
	 *   <li>{@code REMOTE} rejects tables that contain identifiers unknown locally. It drops
	 *       local-only entries and fires remove events for them.</li>
	 *   <li>{@code EXACT} rejects tables whose identifier set differs from the local one.</li>
	 * </ul>
	 * The caller's map is never modified.
	 *
	 * <p>The first call after construction or after {@link #unmap} snapshots the current state so
	 * that {@link #unmap} can restore it.
	 *
	 * @param name registry name, used in error messages
	 * @param remoteIndexedEntries identifier to raw ID table to apply
	 * @param mode how to reconcile differences between the table and the local registry
	 * @throws RemapException if the table is incompatible with the local registry under {@code mode}
	 */
	@Override
	public void remap(String name, Object2IntMap<class_2960> remoteIndexedEntries, RemapMode mode) throws RemapException {
		//noinspection unchecked, ConstantConditions
		class_2370<Object> registry = (class_2370<Object>) (Object) this;

		// Throw on invalid conditions.
		switch (mode) {
		case AUTHORITATIVE:
			break;
		case REMOTE: {
			List<String> strings = null;

			for (class_2960 remoteId : remoteIndexedEntries.keySet()) {
				if (!entriesById.containsKey(remoteId)) {
					if (strings == null) {
						strings = new ArrayList<>();
					}

					strings.add(" - " + remoteId);
				}
			}

			if (strings != null) {
				StringBuilder builder = new StringBuilder("Received ID map for " + name + " contains IDs unknown to the receiver!");

				for (String s : strings) {
					builder.append('\n').append(s);
				}

				throw new RemapException(builder.toString());
			}

			break;
		}
		case EXACT: {
			if (!entriesById.keySet().equals(remoteIndexedEntries.keySet())) {
				List<String> strings = new ArrayList<>();

				for (class_2960 remoteId : remoteIndexedEntries.keySet()) {
					if (!entriesById.containsKey(remoteId)) {
						strings.add(" - " + remoteId + " (missing on local)");
					}
				}

				for (class_2960 localId : registry.method_10235()) {
					if (!remoteIndexedEntries.containsKey(localId)) {
						strings.add(" - " + localId + " (missing on remote)");
					}
				}

				StringBuilder builder = new StringBuilder("Local and remote ID sets for " + name + " do not match!");

				for (String s : strings) {
					builder.append('\n').append(s);
				}

				throw new RemapException(builder.toString());
			}

			break;
		}
		}

		// Make a copy of the previous maps.
		// For now, only one is necessary - on an integrated server scenario,
		// AUTHORITATIVE == CLIENT, which is fine.
		// The reason we preserve the first one is because it contains the
		// vanilla order of IDs before mods, which is crucial for vanilla server
		// compatibility.
		if (fabric_prevIndexedEntries == null) {
			fabric_prevIndexedEntries = new Object2IntOpenHashMap<>();
			fabric_prevEntries = HashBiMap.create(entriesById);

			for (Object o : registry) {
				fabric_prevIndexedEntries.put(registry.method_10221(o), registry.method_10206(o));
			}
		}

		// Raw ID -> identifier as it was before this remap; handed to the remap event.
		Int2ObjectMap<class_2960> oldIdMap = new Int2ObjectOpenHashMap<>();

		for (Object o : registry) {
			oldIdMap.put(registry.method_10206(o), registry.method_10221(o));
		}

		// If we're AUTHORITATIVE, we append entries which only exist on the
		// local side to the new entry list. For REMOTE, we instead drop them.
		switch (mode) {
		case AUTHORITATIVE: {
			// Start below every valid raw ID. Otherwise, if the supplied table is empty,
			// the first appended local entry would get raw ID 1 and slot 0 would stay unused.
			int maxValue = -1;

			// Work on a copy so the caller's map is left untouched.
			Object2IntMap<class_2960> oldRemoteIndexedEntries = remoteIndexedEntries;
			remoteIndexedEntries = new Object2IntOpenHashMap<>();

			for (class_2960 id : oldRemoteIndexedEntries.keySet()) {
				int v = oldRemoteIndexedEntries.getInt(id);
				remoteIndexedEntries.put(id, v);

				if (v > maxValue) {
					maxValue = v;
				}
			}

			for (class_2960 id : registry.method_10235()) {
				if (!remoteIndexedEntries.containsKey(id)) {
					FABRIC_LOGGER.warn("Adding {} to saved/remote registry.", id);
					remoteIndexedEntries.put(id, ++maxValue);
				}
			}

			break;
		}
		case REMOTE: {
			// TODO: Is this what mods really want?
			Set<class_2960> droppedIds = new HashSet<>();

			for (class_2960 id : registry.method_10235()) {
				if (!remoteIndexedEntries.containsKey(id)) {
					Object object = registry.method_10223(id);
					int rid = registry.method_10206(object);

					droppedIds.add(id);

					// Emit RemoveObject events for removed objects.
					//noinspection unchecked
					fabric_getRemoveObjectEvent().invoker().onEntryRemoved(rid, id, (T) object);
				}
			}

			// Note: field_26682/field_26683 cannot be safely remove()d from here; they are
			// rebuilt from scratch below, which leaves the dropped objects out.
			entriesById.keySet().removeAll(droppedIds);
			entriesByKey.keySet().removeIf(registryKey -> droppedIds.contains(registryKey.method_29177()));

			break;
		}
		}

		// Old raw ID -> new raw ID for every entry that survives the remap.
		Int2IntMap idMap = new Int2IntOpenHashMap();

		for (Object o : field_26682) {
			class_2960 id = registry.method_10221(o);
			int rid = registry.method_10206(o);

			// field_26682 may still hold objects dropped above (whose identifier is gone),
			// so only map entries whose identifier is part of the target table.
			if (remoteIndexedEntries.containsKey(id)) {
				idMap.put(rid, remoteIndexedEntries.getInt(id));
			}
		}

		// entriesById/entriesByKey were already adjusted above where needed;
		// only the raw-ID tables are rebuilt here.
		field_26682.clear();
		field_26683.clear();
		nextId = 0;

		List<class_2960> orderedRemoteEntries = new ArrayList<>(remoteIndexedEntries.keySet());
		orderedRemoteEntries.sort(Comparator.comparingInt(remoteIndexedEntries::getInt));

		for (class_2960 identifier : orderedRemoteEntries) {
			int id = remoteIndexedEntries.getInt(identifier);
			T object = entriesById.get(identifier);

			// Warn if an object is missing from the local registry.
			// This should only happen in AUTHORITATIVE mode, and as such we
			// throw an exception otherwise.
			if (object == null) {
				if (mode != RemapMode.AUTHORITATIVE) {
					throw new RemapException(identifier + " missing from registry, but requested!");
				} else {
					FABRIC_LOGGER.warn("{} missing from registry, but requested!", identifier);
				}

				continue;
			}

			// Add the new object, growing the list (with gaps) so that index `id` exists,
			// and keep nextId one past the highest raw ID used.
			field_26682.size(Math.max(this.field_26682.size(), id + 1));
			field_26682.set(id, object);
			field_26683.put(object, id);

			if (nextId <= id) {
				nextId = id + 1;
			}
		}

		//noinspection unchecked
		fabric_getRemapEvent().invoker().onRemap(new RemapStateImpl(registry, oldIdMap, idMap));
	}

	/**
	 * Restores the state snapshotted by the first {@link #remap} since the last unmap.
	 *
	 * <p>It puts back the identifier and key maps, re-applies the original raw IDs in
	 * {@code AUTHORITATIVE} mode, and fires add events for entries a {@code REMOTE} remap had
	 * dropped. The snapshot is then discarded. This is a no-op if no snapshot exists.
	 *
	 * @param name registry name, passed through to {@link #remap}
	 * @throws RemapException declared because this delegates to {@link #remap}
	 */
	@Override
	public void unmap(String name) throws RemapException {
		if (fabric_prevIndexedEntries != null) {
			List<class_2960> addedIds = new ArrayList<>();

			// Emit AddObject events for previously culled objects.
			for (class_2960 id : fabric_prevEntries.keySet()) {
				if (!entriesById.containsKey(id)) {
					assert fabric_prevIndexedEntries.containsKey(id);
					addedIds.add(id);
				}
			}

			entriesById.clear();
			entriesByKey.clear();

			entriesById.putAll(fabric_prevEntries);

			// This registry's own identifier in the root registry; the same for every entry,
			// so look it up once instead of once per restored entry.
			//noinspection unchecked
			class_2960 registryName = ((class_2378) class_2378.field_11144).method_10221(this);

			for (Map.Entry<class_2960, T> entry : fabric_prevEntries.entrySet()) {
				entriesByKey.put(class_5321.method_29179(class_5321.method_29180(registryName), entry.getKey()), entry.getValue());
			}

			remap(name, fabric_prevIndexedEntries, RemapMode.AUTHORITATIVE);

			for (class_2960 id : addedIds) {
				T object = entriesById.get(id);
				fabric_getAddObjectEvent().invoker().onEntryAdded(field_26683.getInt(object), id, object);
			}

			fabric_prevIndexedEntries = null;
			fabric_prevEntries = null;
		}
	}
}
