import { type LocalId, type NormalizedConnection, type NormalizeLocalIdHandler } from './types';
import { type EditorState } from '@atlaskit/editor-prosemirror/state';
import { type EditorView } from '@atlaskit/editor-prosemirror/view';
import { decodeName, getNodesSupportingFragmentMark } from './utils';
import { type Schema } from '@atlaskit/editor-prosemirror/model';

/**
 * Indexes the nodes of the current document that can be referenced through a `localId`
 * (node attribute and/or fragment mark), together with their data-consumer relationships.
 *
 * The connection index is built once, in the constructor, from `view.state.doc`; afterwards
 * callers can look connections up by any of their localIds and iterate over unique, source,
 * consumer-only and nameless connections.
 */
export class ReferentialityContext {
	/**
	 * Highest number seen in a fragment mark name, keyed by the decoded default name
	 * (as returned by `decodeName`).
	 *
	 * A null-prototype object is used so that names which collide with `Object.prototype`
	 * members (e.g. `constructor`, `toString`, `__proto__`) are treated as ordinary keys
	 * instead of resolving to inherited properties.
	 */
	private readonly maxNumberUsedInFragmentMarkName: Record<string, number> = Object.create(null);

	/**
	 * Lookup from ANY localId of a connection (the node's own `localId` attribute and/or its
	 * fragment mark's `localId`) to that connection. Several keys may point at the same entry.
	 */
	private readonly connections: Map<LocalId, NormalizedConnection> = new Map<
		LocalId,
		NormalizedConnection
	>();

	/**
	 * Normalized localIds of every indexed connection, in the order they were visited.
	 */
	private readonly uniqueNormalizedIds = new Set<LocalId>();

	/**
	 * Normalized localIds of connections that are listed as a source by at least one
	 * data consumer in the document.
	 */
	private readonly uniqueSourceNormalizedIds = new Set<LocalId>();

	/**
	 * Normalized localIds of connections that have no fragment mark, or whose fragment
	 * mark has an empty/missing name.
	 */
	private readonly namelessNormalizedIds = new Set<LocalId>();

	/**
	 * @constructor
	 * @param view The editor view whose current document is indexed.
	 * @param normalizeLocalId Combines a node's `localId` attribute and its fragment mark's
	 * `localId` into the single id that identifies the connection. A falsy result causes the
	 * node to be skipped.
	 * @param shallow Controls the depth of the scan; defaults to true. A shallow context assumes
	 * that nodes with data consumers cannot contain descendants with fragment marks or data
	 * consumers, so it does not descend into such nodes.
	 */
	constructor(
		public readonly view: EditorView,
		private readonly normalizeLocalId: NormalizeLocalIdHandler,
		public readonly shallow: boolean = true,
	) {
		const { doc, schema } = this.view.state;
		const { fragment, dataConsumer } = schema.marks;

		const nodesSupportingFragmentMark = getNodesSupportingFragmentMark(schema);

		// srcId -> normalized ids of every consumer that lists it as a source. A source may be
		// visited after its consumers, so targets are attached in a second pass below.
		const sourceToTargetMappings: Map<LocalId, Set<LocalId>> = new Map();

		doc.descendants((node, pos, parent) => {
			// Nodes that cannot carry a fragment mark are not indexed, but their children may be.
			if (!nodesSupportingFragmentMark.has(node.type)) {
				return true;
			}

			const fragmentMark = fragment.isInSet(node.marks);

			// The node has no id it could be referenced by: skip it, but keep descending.
			if (!fragmentMark && !node.attrs?.localId) {
				return true;
			}

			const normalizedId: LocalId | undefined = this.normalizeLocalId(
				node.attrs?.localId,
				fragmentMark?.attrs?.localId,
			);

			// Skip nodes without a usable id, and ids that have already been indexed.
			if (!normalizedId || this.connections.has(normalizedId)) {
				return true;
			}

			let decodedName: NormalizedConnection['decodedName'];

			if (!!fragmentMark?.attrs.name) {
				decodedName = decodeName(fragmentMark.attrs.name);
				if (decodedName) {
					const currentMax = this.maxNumberUsedInFragmentMarkName[decodedName.defaultName];
					if (currentMax === undefined || currentMax < decodedName.defaultNameNumber) {
						this.maxNumberUsedInFragmentMarkName[decodedName.defaultName] =
							decodedName.defaultNameNumber;
					}
				}
			} else {
				this.namelessNormalizedIds.add(normalizedId);
			}

			const consumer = dataConsumer.isInSet(node.marks);

			const ref = {
				node,
				pos,
				parent,
				normalizedId,
				name: fragmentMark?.attrs?.name ?? normalizedId,
				ids: new Set<LocalId>(),
				targets: new Set<LocalId>(),
				fragmentMark,
				dataConsumer: consumer,
				decodedName,
			} as NormalizedConnection;

			// Register the connection under every id it can be referenced by.
			if (!!fragmentMark?.attrs?.localId) {
				ref.ids.add(fragmentMark.attrs.localId);
				this.connections.set(fragmentMark.attrs.localId, ref);
			}

			if (!!node.attrs?.localId) {
				ref.ids.add(node.attrs.localId);
				this.connections.set(node.attrs.localId, ref);
			}

			this.uniqueNormalizedIds.add(normalizedId);

			if (consumer) {
				// Guard against a data-consumer mark without a `sources` attribute instead of
				// throwing while building the context.
				const sources: LocalId[] = consumer.attrs.sources ?? [];
				sources.forEach((src: LocalId) =>
					sourceToTargetMappings.set(
						src,
						sourceToTargetMappings.get(src)?.add(normalizedId) ?? new Set([normalizedId]),
					),
				);
			}

			// Do not descend into children of a node which has a dataConsumer attached. This assumes
			// all children of the node cannot be referenced by another node. Unless of course this
			// is not a shallow context.
			// if consumer: true & shallow: true -> return false // don't descend
			// if consumer: false & shallow: true -> return true // descend
			// if consumer: true & shallow: false -> return true // descend
			// if consumer: false & shallow: false -> return true // descend
			return !(shallow && consumer);
		});

		// Second pass: attach consumer targets to their source connections. This needs the full
		// index, since a source id may be any localId of a connection visited at any position.
		// Sources that do not resolve to an indexed connection are ignored.
		for (const [source, targets] of sourceToTargetMappings) {
			const connection = this.connections.get(source);
			if (connection) {
				this.uniqueSourceNormalizedIds.add(connection.normalizedId);
				connection.targets = new Set([...connection.targets, ...targets]);
			}
		}
	}

	/**
	 * Convenience helper to get the EditorState
	 */
	public get state(): EditorState {
		return this.view.state;
	}

	/**
	 * Convenience helper to get the Schema
	 */
	public get schema(): Schema {
		return this.view.state.schema;
	}

	/**
	 * Returns the connection registered under `localId` (node or fragment-mark id), if any.
	 */
	public getById(localId: LocalId): NormalizedConnection | undefined {
		return this.connections.get(localId);
	}

	/**
	 * Whether any connection is registered under `localId` (node or fragment-mark id).
	 */
	public hasId(localId: LocalId): boolean {
		return this.connections.has(localId);
	}

	/**
	 * Returns a shallow copy of the name -> max number record, so callers cannot mutate
	 * the internal state.
	 */
	public getMaxNumberUsedInFragmentMarkNames(): Record<string, number> {
		return { ...this.maxNumberUsedInFragmentMarkName };
	}

	/**
	 * Returns the highest number recorded for `name`, or 0 when none has been recorded.
	 */
	public getMaxNumberUsedInFragmentMarkName(name: string): number {
		return this.maxNumberUsedInFragmentMarkName[name] ?? 0;
	}

	/**
	 * Overwrites the number recorded for `name` (no max comparison is performed).
	 */
	public updateMaxNumberUsedInFragmentMarkNames(name: string, value: number): void {
		this.maxNumberUsedInFragmentMarkName[name] = value;
	}

	/**
	 * Yields each indexed connection exactly once.
	 */
	public *uniqueConnections(): Generator<NormalizedConnection> {
		for (const id of this.uniqueNormalizedIds) {
			const connection = this.connections.get(id);
			if (connection) {
				yield connection;
			}
		}
	}

	/**
	 * This generator will only return connections which contain DataConsumer sources.
	 */
	public *consumerOnlyConnections(): Generator<NormalizedConnection> {
		for (const id of this.uniqueNormalizedIds) {
			const connection = this.connections.get(id);
			if (connection && !!connection.dataConsumer?.attrs.sources?.length) {
				yield connection;
			}
		}
	}

	/**
	 * This generator will only return connections which are sources
	 */
	public *sourceOnlyConnections(): Generator<NormalizedConnection> {
		for (const id of this.uniqueSourceNormalizedIds) {
			const connection = this.connections.get(id);
			if (connection) {
				yield connection;
			}
		}
	}

	/**
	 * This generator will return connections which do not contain a name or fragment
	 */
	public *namelessOnlyConnections(): Generator<NormalizedConnection> {
		for (const id of this.namelessNormalizedIds) {
			const connection = this.connections.get(id);
			if (connection) {
				yield connection;
			}
		}
	}
}
