import type {
  Table,
  Row,
  RowData,
  RowModel,
} from '@tanstack/table-core';

import {
  createRow,
  getMemoOptions,
  memo,
} from '@tanstack/table-core';

/**
 * One user-defined group in a custom-ordered grouping configuration.
 */
export interface CustomOrderedGroup {
  /** Text displayed for the group; stored as the header row's `groupingValue`. */
  label: string;
  /** Identifier of the group; also used as the id of the group's header row. */
  id: string;
  /**
   * Values of the configured `uniqueKeyColumnId` field, listed in the order the
   * matching rows should appear inside this group.
   */
  orderedKeys: string[];
}

/**
 * Configuration for the custom-ordered grouped row model. It is read as a
 * JSON-encoded string from the first entry of the table's `grouping` state.
 */
export interface CustomOrderedGroupsOptions<TData extends RowData> {
  /**
   * Property of each row's `original` data whose value identifies the row; it is
   * matched against each group's `orderedKeys`.
   */
  uniqueKeyColumnId: keyof TData;
  /** The groups, in the order their header rows are emitted. */
  groups: CustomOrderedGroup[];
  /** When true, the "Ungrouped" header row is emitted even if it has no rows. */
  showUngroupedEvenIfEmpty?: boolean;
}

/** Id of the synthetic header row that collects rows whose key no group lists. */
export const UNGROUPED_GROUP_ID = '__ungrouped__' as const;

/**
 * Parses the custom-ordered grouping options that callers encode as a JSON
 * string in the first entry of the table's `grouping` state.
 *
 * @param raw - The first `grouping` entry, if any.
 * @returns The parsed options, or `undefined` in any of these cases:
 *   - `raw` is missing or is not valid JSON;
 *   - there is no non-empty `groups` array whose entries all have a string `id`
 *     and an array `orderedKeys`;
 *   - `uniqueKeyColumnId` is not truthy.
 */
function parseCustomOrderedGroupsOptions<TData extends RowData>(
  raw: string | undefined,
): CustomOrderedGroupsOptions<TData> | undefined {
  if (typeof raw !== 'string') return undefined;

  let parsed: unknown;
  try {
    parsed = JSON.parse(raw);
  } catch {
    // Not a JSON payload (e.g. an ordinary column id): let the caller fall back
    // to the ungrouped model instead of throwing inside the memoized computation.
    return undefined;
  }
  if (typeof parsed !== 'object' || parsed === null) return undefined;

  const candidate = parsed as Partial<CustomOrderedGroupsOptions<TData>>;
  const groups = candidate.groups;
  if (
    !Array.isArray(groups) ||
    groups.length === 0 ||
    !candidate.uniqueKeyColumnId ||
    !groups.every(
      (group) =>
        typeof group === 'object' &&
        group !== null &&
        typeof group.id === 'string' &&
        Array.isArray(group.orderedKeys),
    )
  ) {
    return undefined;
  }
  return candidate as CustomOrderedGroupsOptions<TData>;
}

/**
 * Extension of the TanStack Table API to allow for custom-ordered groups.
 *
 * Reads a JSON-encoded `CustomOrderedGroupsOptions` from the first entry of the
 * table's `grouping` state and builds a two-level row model:
 * - one depth-0 header row per configured group, in configuration order. Its
 *   `subRows` are the data rows whose `uniqueKeyColumnId` value appears in the
 *   group's `orderedKeys`, in `orderedKeys` order;
 * - an "Ungrouped" header row (id `UNGROUPED_GROUP_ID`) holding the rows whose
 *   key is listed by no group. It is emitted when non-empty or when
 *   `showUngroupedEvenIfEmpty` is set.
 *
 * A key listed by several groups is shown only under the group whose `id` sorts
 * highest with `localeCompare` (ties keep the earlier group). Only the first row
 * carrying a given key is placed in a group.
 *
 * Every data row, at any depth of the upstream model, is emitted as a flat
 * leaf: its `depth` is set to 1 and its `subRows` are cleared. These row objects
 * are the upstream model's own rows and are updated in place.
 *
 * If `grouping` is empty, the upstream model has no rows, or the options cannot
 * be parsed, the pre-grouped row model is returned unchanged.
 */
export default function getCustomOrderedGroupedRowModel<
  TData extends RowData,
>(
): (
  table: Table<TData>,
) => () => RowModel<TData> {
  return (table) =>
    memo(
      () => [table.getState().grouping, table.getPreGroupedRowModel()],
      (grouping, rowModel) => {
        if (!rowModel.flatRows.length || !grouping.length) return rowModel;

        const options = parseCustomOrderedGroupsOptions<TData>(grouping[0]);
        if (!options) return rowModel;

        const { groups, uniqueKeyColumnId } = options;
        const keyOf = (row: Row<TData>): unknown => row.original[uniqueKeyColumnId];

        // Every data row of the upstream model, at any depth. `flatRows` is used
        // instead of walking `rows`/`subRows` because this function clears the
        // `subRows` of the leaves it places. Walking the tree would therefore lose
        // nested rows on the next recomputation. Header ids are skipped defensively.
        const headerIds = new Set<string>([UNGROUPED_GROUP_ID, ...groups.map((group) => group.id)]);
        const dataRows = rowModel.flatRows.filter((row) => !headerIds.has(row.id));

        // First row for each key; later rows sharing a key are never placed in a group.
        const firstRowByKey = new Map<unknown, Row<TData>>();
        for (const row of dataRows) {
          const key = keyOf(row);
          if (!firstRowByKey.has(key)) firstRowByKey.set(key, row);
        }

        // Owning group per key, resolved once: the highest id by localeCompare wins;
        // on ties the earlier group is kept.
        const ownerByKey = new Map<unknown, string>();
        for (const group of groups) {
          for (const key of group.orderedKeys) {
            const owner = ownerByKey.get(key);
            if (owner === undefined || group.id.localeCompare(owner) > 0) {
              ownerByKey.set(key, group.id);
            }
          }
        }

        // Leaves sit one level below their depth-0 header and are rendered flat, so
        // no row is shown both under its parent and on its own.
        const asLeaf = (row: Row<TData>): Row<TData> => {
          row.depth = 1;
          row.subRows = [];
          return row;
        };

        // Builds a synthetic depth-0 header row whose children are `leaves`.
        const createHeaderRow = (
          id: string,
          label: string,
          index: number,
          leaves: Row<TData>[],
        ): Row<TData> => {
          const header = createRow(table, id, {} as TData, index, 0, undefined, undefined);
          Object.assign(header, {
            depth: 0,
            groupingColumnId: uniqueKeyColumnId,
            groupingValue: label,
            subRows: leaves,
            leafRows: leaves,
            getValue: () => undefined,
          });
          return header;
        };

        const groupedRows: Row<TData>[] = groups.map((group, index) => {
          const leaves: Row<TData>[] = [];
          for (const key of group.orderedKeys) {
            const row = ownerByKey.get(key) === group.id ? firstRowByKey.get(key) : undefined;
            if (row) leaves.push(asLeaf(row));
          }
          return createHeaderRow(group.id, group.label, index, leaves);
        });

        const ungroupedRows = dataRows
          .filter((row) => !ownerByKey.has(keyOf(row)))
          .map((row) => asLeaf(row));

        if (ungroupedRows.length > 0 || options.showUngroupedEvenIfEmpty) {
          groupedRows.push(
            createHeaderRow(UNGROUPED_GROUP_ID, 'Ungrouped', groups.length, ungroupedRows),
          );
        }

        // Headers, each followed by its leaves, in display order.
        const groupedFlatRows: Row<TData>[] = [];
        const groupedRowsById: Record<string, Row<TData>> = {};
        for (const header of groupedRows) {
          groupedFlatRows.push(header, ...header.subRows);
        }
        for (const row of groupedFlatRows) {
          groupedRowsById[row.id] = row;
        }

        // Build a new model rather than overwriting the fields of `rowModel`. That
        // object is what `table.getPreGroupedRowModel()` returns, so it may be read
        // by other consumers.
        const groupedModel: RowModel<TData> = {
          rows: groupedRows,
          flatRows: groupedFlatRows,
          rowsById: groupedRowsById,
        };
        return groupedModel;
      },
      getMemoOptions(table.options, 'debugTable', 'getCustomOrderedGroupedRowModel', () => {})
    );
}
