/*
 * Copyright (c) 2016, 2017, 2018, 2019 FabricMC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.fabricmc.fabric.impl.registry.sync.packet;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import Id;
import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.objects.Object2IntLinkedOpenHashMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import org.jetbrains.annotations.Nullable;
import net.fabricmc.fabric.api.networking.v1.PacketByteBufs;
import net.minecraft.class_2540;
import net.minecraft.class_2960;
import net.minecraft.class_8710;
import net.minecraft.class_9139;

/**
 * A more compact way to sync registry ids to the client.
 * Produces smaller packets than the old NBT-based method.
 *
 * <p>This method optimizes the packet in multiple ways:
 * <ul>
 *     <li>Data is written directly into the buffer instead of going through NBT;</li>
 *     <li>All {@link class_2960}s with the same namespace are grouped together, and each unique namespace is sent only once per group;</li>
 *     <li>Consecutive rawIds are grouped into bulks, and only the difference between the first rawId of a bulk and the last rawId of the previous bulk is sent.
 *     This is based on the assumption that mods generally register all of their objects at once,
 *     therefore making the rawIds somewhat densely packed.</li>
 * </ul>
 *
 * <p>The encoded data is split into multiple packets if it exceeds the maximum payload size
 * (1 MiB by default, configurable with the {@code fabric.registry.direct.maxPayloadSize} system property);
 * an empty packet marks the end of the sequence.
 */
public class DirectRegistryPacketHandler extends RegistryPacketHandler<DirectRegistryPacketHandler.Payload> {
	/**
	 * Maximum number of bytes carried by a single payload. Defaults to {@code 0x100000} (1 MiB) and can be
	 * overridden with the {@code fabric.registry.direct.maxPayloadSize} system property.
	 *
	 * @see net.minecraft.network.packet.s2c.play.CustomPayloadS2CPacket#MAX_PAYLOAD_SIZE
	 */
	@SuppressWarnings("JavadocReference")
	private static final int MAX_PAYLOAD_SIZE = Integer.getInteger("fabric.registry.direct.maxPayloadSize", 0x100000);

	/** Concatenation of the non-empty payloads received so far; {@code null} between transfers. */
	@Nullable
	private class_2540 combinedBuf;

	/** Registry map decoded from the last completed transfer; cleared by {@link #getSyncedRegistryMap()}. */
	@Nullable
	private Map<class_2960, Object2IntMap<class_2960>> syncedRegistryMap;

	/** Set once the empty terminating payload has been received; reset by {@link #getSyncedRegistryMap()}. */
	private boolean isPacketFinished = false;
	/** Number of payloads received in the current transfer, including the empty terminating one. */
	private int totalPacketReceived = 0;

	@Override
	public class_8710.class_9154<DirectRegistryPacketHandler.Payload> getPacketId() {
		return Payload.ID;
	}

	/**
	 * Encodes {@code registryMap} and sends it through {@code sender} as one or more payloads of at most
	 * {@link #MAX_PAYLOAD_SIZE} bytes each, followed by an empty payload that marks the end of the sequence.
	 *
	 * @param sender      receives every payload, in order
	 * @param registryMap registry id -> (object id -> raw id) mapping to send
	 */
	@Override
	public void sendPacket(Consumer<DirectRegistryPacketHandler.Payload> sender, Map<class_2960, Object2IntMap<class_2960>> registryMap) {
		class_2540 buf = PacketByteBufs.create();

		try {
			writeRegistryMap(buf, registryMap);

			// Split the packet to multiple MAX_PAYLOAD_SIZEd buffers.
			int readableBytes = buf.readableBytes();
			int sliceIndex = 0;

			while (sliceIndex < readableBytes) {
				int sliceSize = Math.min(readableBytes - sliceIndex, MAX_PAYLOAD_SIZE);
				class_2540 slicedBuf = PacketByteBufs.slice(buf, sliceIndex, sliceSize);
				sender.accept(createPayload(slicedBuf));
				sliceIndex += sliceSize;
			}

			// Send an empty buffer to mark the end of the split.
			sender.accept(createPayload(PacketByteBufs.empty()));
		} finally {
			// Every payload owns a copy of its bytes, so the staging buffer is no longer needed.
			buf.release();
		}
	}

	/**
	 * Writes {@code registryMap} into {@code buf} in the following order:
	 * <pre>
	 * registry namespace group count
	 *   registry namespace ("" for the default namespace), registry count
	 *     registry path, id namespace group count
	 *       id namespace ("" for the default namespace), bulk count
	 *         raw id difference, bulk size
	 *           id path (one per entry of the bulk, in raw id order)
	 * </pre>
	 * The raw id difference is the first raw id of the bulk minus the last raw id of the previously written bulk
	 * of the same registry (0 before the first bulk).
	 */
	private static void writeRegistryMap(class_2540 buf, Map<class_2960, Object2IntMap<class_2960>> registryMap) {
		// Group registry ids with same namespace.
		Map<String, List<class_2960>> regNamespaceGroups = registryMap.keySet().stream()
				.collect(Collectors.groupingBy(class_2960::method_12836));

		buf.method_10804(regNamespaceGroups.size());

		regNamespaceGroups.forEach((regNamespace, regIds) -> {
			buf.method_10814(optimizeNamespace(regNamespace));
			buf.method_10804(regIds.size());

			for (class_2960 regId : regIds) {
				buf.method_10814(regId.method_12832());

				Object2IntMap<class_2960> idMap = registryMap.get(regId);

				// Sort object ids by its namespace. We use linked map here to keep the original namespace ordering.
				Map<String, List<Object2IntMap.Entry<class_2960>>> idNamespaceGroups = idMap.object2IntEntrySet().stream()
						.collect(Collectors.groupingBy(e -> e.getKey().method_12836(), LinkedHashMap::new, Collectors.toCollection(ArrayList::new)));

				buf.method_10804(idNamespaceGroups.size());

				// Carried across id namespace groups of the same registry; the reader mirrors this.
				int lastBulkLastRawId = 0;

				for (Map.Entry<String, List<Object2IntMap.Entry<class_2960>>> idNamespaceEntry : idNamespaceGroups.entrySet()) {
					// Make sure the ids are sorted by its raw id.
					List<Object2IntMap.Entry<class_2960>> idPairs = idNamespaceEntry.getValue();
					idPairs.sort(Comparator.comparingInt(Object2IntMap.Entry::getIntValue));

					// Group consecutive raw ids together. groupingBy never yields an empty group, so next() is safe.
					List<List<Object2IntMap.Entry<class_2960>>> bulks = new ArrayList<>();

					Iterator<Object2IntMap.Entry<class_2960>> idPairIter = idPairs.iterator();
					List<Object2IntMap.Entry<class_2960>> currentBulk = new ArrayList<>();
					Object2IntMap.Entry<class_2960> currentPair = idPairIter.next();
					currentBulk.add(currentPair);

					while (idPairIter.hasNext()) {
						currentPair = idPairIter.next();

						if (currentBulk.get(currentBulk.size() - 1).getIntValue() + 1 != currentPair.getIntValue()) {
							bulks.add(currentBulk);
							currentBulk = new ArrayList<>();
						}

						currentBulk.add(currentPair);
					}

					bulks.add(currentBulk);

					buf.method_10814(optimizeNamespace(idNamespaceEntry.getKey()));
					buf.method_10804(bulks.size());

					for (List<Object2IntMap.Entry<class_2960>> bulk : bulks) {
						int firstRawId = bulk.get(0).getIntValue();
						int bulkRawIdStartDiff = firstRawId - lastBulkLastRawId;

						buf.method_10804(bulkRawIdStartDiff);
						buf.method_10804(bulk.size());

						for (Object2IntMap.Entry<class_2960> idPair : bulk) {
							buf.method_10814(idPair.getKey().method_12832());

							lastBulkLastRawId = idPair.getIntValue();
						}
					}
				}
			}
		});
	}

	/**
	 * Accumulates one payload of a split transfer. Non-empty payloads are appended to {@link #combinedBuf};
	 * the empty terminating payload triggers decoding into {@link #syncedRegistryMap}.
	 *
	 * @throws IllegalStateException if a previous transfer finished and its map has not been taken yet
	 */
	@Override
	public void receivePayload(Payload payload) {
		Preconditions.checkState(!isPacketFinished);
		totalPacketReceived++;

		if (combinedBuf == null) {
			combinedBuf = PacketByteBufs.create();
		}

		byte[] data = payload.data();

		if (data.length != 0) {
			combinedBuf.method_52983(data);
			return;
		}

		isPacketFinished = true;

		try {
			computeBufSize(combinedBuf);
			syncedRegistryMap = readRegistryMap(combinedBuf);
		} finally {
			// Release even when the data is malformed, so a failed decode does not leak the buffer.
			combinedBuf.release();
			combinedBuf = null;
		}
	}

	/** Decodes the layout produced by {@link #writeRegistryMap}, preserving the order in which entries were written. */
	private static Map<class_2960, Object2IntMap<class_2960>> readRegistryMap(class_2540 buf) {
		Map<class_2960, Object2IntMap<class_2960>> registryMap = new LinkedHashMap<>();
		int regNamespaceGroupAmount = buf.method_10816();

		for (int i = 0; i < regNamespaceGroupAmount; i++) {
			String regNamespace = unoptimizeNamespace(buf.method_19772());
			int regNamespaceGroupLength = buf.method_10816();

			for (int j = 0; j < regNamespaceGroupLength; j++) {
				String regPath = buf.method_19772();
				Object2IntMap<class_2960> idMap = new Object2IntLinkedOpenHashMap<>();
				int idNamespaceGroupAmount = buf.method_10816();

				// Mirrors the writer: the previous bulk's last raw id carries over across id namespace groups.
				int lastBulkLastRawId = 0;

				for (int k = 0; k < idNamespaceGroupAmount; k++) {
					String idNamespace = unoptimizeNamespace(buf.method_19772());
					int rawIdBulkAmount = buf.method_10816();

					for (int l = 0; l < rawIdBulkAmount; l++) {
						int bulkRawIdStartDiff = buf.method_10816();
						int bulkSize = buf.method_10816();

						// Start one below the bulk's first raw id so the increment in the loop yields it.
						int currentRawId = (lastBulkLastRawId + bulkRawIdStartDiff) - 1;

						for (int m = 0; m < bulkSize; m++) {
							currentRawId++;
							String idPath = buf.method_19772();
							idMap.put(new class_2960(idNamespace, idPath), currentRawId);
						}

						lastBulkLastRawId = currentRawId;
					}
				}

				registryMap.put(new class_2960(regNamespace, regPath), idMap);
			}
		}

		return registryMap;
	}

	@Override
	public boolean isPacketFinished() {
		return isPacketFinished;
	}

	/** @throws IllegalStateException if the current transfer has not finished yet */
	@Override
	public int getTotalPacketReceived() {
		Preconditions.checkState(isPacketFinished);
		return totalPacketReceived;
	}

	/**
	 * Returns the decoded registry map and resets this handler so it can accept a new transfer.
	 *
	 * @throws IllegalStateException if the current transfer has not finished yet
	 */
	@Override
	@Nullable
	public Map<class_2960, Object2IntMap<class_2960>> getSyncedRegistryMap() {
		Preconditions.checkState(isPacketFinished);
		Map<class_2960, Object2IntMap<class_2960>> map = syncedRegistryMap;
		isPacketFinished = false;
		totalPacketReceived = 0;
		syncedRegistryMap = null;
		return map;
	}

	/**
	 * Copies the readable bytes of {@code buf} into a new payload, consuming them from {@code buf}.
	 *
	 * <p>The bytes must be copied rather than taken from {@code buf.array()}. For a slice, that array is the
	 * parent's whole backing array: it ignores the slice's offset and length and includes unused capacity.
	 */
	private DirectRegistryPacketHandler.Payload createPayload(class_2540 buf) {
		if (buf.readableBytes() == 0) {
			return new Payload(new byte[0]);
		}

		return new Payload(Payload.readAllBytes(buf));
	}

	/** Replaces the default namespace with {@code ""} to save bytes on the wire. */
	private static String optimizeNamespace(String namespace) {
		return namespace.equals(class_2960.field_33381) ? "" : namespace;
	}

	/** Inverse of {@link #optimizeNamespace(String)}. */
	private static String unoptimizeNamespace(String namespace) {
		return namespace.isEmpty() ? class_2960.field_33381 : namespace;
	}

	/** A single chunk of the encoded registry data; an empty {@code data} array terminates the transfer. */
	public record Payload(byte[] data) implements RegistrySyncPayload {
		public static final class_8710.class_9154<Payload> ID = new class_8710.class_9154<>(new class_2960("fabric", "registry/sync/direct"));
		public static final class_9139<class_2540, Payload> CODEC = class_8710.method_56484(Payload::write, Payload::new);

		Payload(class_2540 buf) {
			this(readAllBytes(buf));
		}

		private void write(class_2540 buf) {
			buf.method_52983(data);
		}

		/** Reads every remaining readable byte of {@code buf} into a new array. */
		private static byte[] readAllBytes(class_2540 buf) {
			byte[] bytes = new byte[buf.readableBytes()];
			buf.method_52979(bytes);
			return bytes;
		}

		@Override
		public class_8710.class_9154<? extends class_8710> getId() {
			return ID;
		}
	}
}
