import { createContext, useContext, useMemo } from "react"
import Table from "@mui/material/Table"
import TableBody from "@mui/material/TableBody"
import TableCell from "@mui/material/TableCell"
import TableHead from "@mui/material/TableHead"
import TableRow from "@mui/material/TableRow"
import TableSortLabel from "@mui/material/TableSortLabel"
import TableContainer from "@mui/material/TableContainer"
import TablePagination from "@mui/material/TablePagination"
import Paper from "@mui/material/Paper"
import Typography from "@mui/material/Typography"
import CircularProgress from "@mui/material/CircularProgress"
import Box from "@mui/material/Box"
import { makeStyles } from "tss-react/mui"

import EmptyState from "./EmptyState"
import { PaginatedTableProps, TableEntity } from "./types"
import { once } from "lodash"

/**
 * The subset of `PaginatedTableProps` that `PaginatedTable` publishes through
 * React context so the header and body sub-components can read it without
 * prop drilling.
 */
type TableContextType<T extends TableEntity> = Pick<
  PaginatedTableProps<T>,
  // column definitions and the role used for per-column visibility checks
  | "columns"
  | "userRole"
  // sorting state and the click handler for sortable headers
  | "sortBy"
  | "sortDirection"
  | "onSortClick"
  // data plus loading / empty-state information
  | "records"
  | "recordsLoading"
  | "emptyMessage"
  | "loadingQueryKey"
  // optional wrapper used instead of the default TableRow
  | "RenderRow"
>

/**
 * tss-react styles for the paginated table:
 * - `table`: top margin separating the table from preceding content
 * - `tableHeader`: grey background for the header row
 * - `loadingContainer` / `loadingContent`: column flex layout of the loading
 *   caption, with auto horizontal margins on the spinner and the text
 */
const useStyles = makeStyles()(theme => {
  return {
    table: { marginTop: theme.spacing(2) },
    tableHeader: { backgroundColor: "#EBEBEB" },
    loadingContainer: {
      display: "flex",
      flexDirection: "column",
      padding: theme.spacing(4, "20%"),
    },
    loadingContent: { margin: theme.spacing(1, "auto") },
  }
})

/**
 * Lazily creates the React context shared by the table sub-components.
 *
 * `once()` makes every call return the same context object. Each caller can
 * therefore ask for it with its own type argument (`createTableContext<T>()`)
 * and still get the same underlying context. The type parameter is only a
 * compile-time view: at runtime every `T` shares this single context.
 */
const createTableContext = once(<T extends TableEntity>() => createContext<TableContextType<T> | null>(null))

/**
 * Reads the table context that `PaginatedTable` provides to its sub-components.
 *
 * @returns the non-null context value
 * @throws Error when rendered outside a `PaginatedTable`, where the context
 *   still holds its `null` default
 */
const useTableContext = <T extends TableEntity>() => {
  const context = useContext(createTableContext<T>())
  if (context === null) {
    // The provider is set up by PaginatedTable; there is no separately
    // exported provider component, so name the component callers must use.
    throw new Error("useTableContext must be used within a PaginatedTable")
  }
  return context
}

/**
 * Renders the table's header row from the shared table context.
 *
 * A column is omitted when it defines `restrictView` and that predicate returns
 * false for the current `userRole`. Columns with `sortable === true` wrap their
 * text in a `TableSortLabel`. The label is active when the column is the
 * current `sortBy`, and clicking it calls `onSortClick(column.id)`.
 */
const TableHeadRow = <T extends TableEntity>() => {
  const { columns, userRole, sortBy, sortDirection, onSortClick } = useTableContext<T>()

  const headerCells = columns?.map(column => {
    const hiddenForRole = column.restrictView && !column.restrictView(userRole)
    if (hiddenForRole) {
      return
    }

    const headerKey = `table-header-cell-${column.id}`
    const label =
      column.sortable !== true ? (
        column.text
      ) : (
        <TableSortLabel
          active={sortBy === column.id}
          direction={sortDirection}
          onClick={() => {
            onSortClick(column.id)
          }}
        >
          {column.text}
        </TableSortLabel>
      )

    return (
      <TableCell
        key={headerKey}
        data-test={headerKey}
        align={column.align ?? "left"}
        padding={column.disablePadding ? "none" : "normal"}
        style={column.style ?? {}}
      >
        {label}
      </TableCell>
    )
  })

  return <TableRow>{headerCells}</TableRow>
}

/**
 * Table caption shown instead of the body while records are loading:
 * a spinner followed by a "Loading..." heading.
 */
const LoadingCaption = () => {
  const { classes: styles } = useStyles()
  const contentClassName = styles.loadingContent

  return (
    <caption>
      <Box data-test="table-loading-indicator" className={styles.loadingContainer}>
        <CircularProgress className={contentClassName} />
        <Typography className={contentClassName} variant="h5">
          Loading...
        </Typography>
      </Box>
    </caption>
  )
}

/**
 * Renders the table body, or an `EmptyState` caption when there are no records.
 *
 * Any received props are forwarded to the rows, together with the context's
 * `records`, `columns` and `userRole`.
 */
const TableBodyContent = <T extends TableEntity>({ ...rest }) => {
  const { records, columns, userRole, emptyMessage } = useTableContext<T>()

  // Missing or empty record lists both fall through to the empty-state caption.
  if (!records?.length) {
    return (
      <caption>
        <EmptyState message={emptyMessage} />
      </caption>
    )
  }

  return (
    <TableBody>
      <TableBodyRows<T> records={records} columns={columns} userRole={userRole} {...rest} />
    </TableBody>
  )
}

/**
 * Renders one `TableBodyRow` per record in context, keyed by position.
 * Returns an empty fragment when the context holds no records array.
 * All received props are forwarded to each row.
 */
const TableBodyRows = <T extends TableEntity>({ ...forwarded }) => {
  const { records } = useTableContext<T>()

  if (!records) {
    return <></>
  }

  return records.map((record, position) => (
    <TableBodyRow<T> key={`table-row-${position}`} record={record} {...forwarded} />
  ))
}

/**
 * Renders a single record as a table row with one cell per visible column.
 *
 * A column is skipped when its `restrictView` predicate returns false for the
 * current `userRole`. Each cell renders the column's `cellContentComponent`
 * with `record` plus every extra prop this row received. When the context
 * supplies `RenderRow`, it wraps the cells. Otherwise a hoverable `TableRow`
 * tagged with the record's `pk` is used.
 *
 * @param record - the entity rendered by this row
 * @param props - additional props forwarded to every cell content component
 */
const TableBodyRow = <T extends TableEntity>({ record, ...props }: { record: T; [key: string]: unknown }) => {
  const { columns, userRole, RenderRow } = useTableContext<T>()

  const renderCell = () => {
    return columns?.map((column, columnIndex) => {
      if (typeof column?.restrictView === "function" && !column.restrictView(userRole)) {
        return
      }

      const CellContentComponent = column.cellContentComponent
      const cellContent = <CellContentComponent record={record} {...props} />
      const cellKey = `table-column-${columnIndex}`
      const cellTestId = `table-cell-${column.id}`
      // Same fallback operator as the header row, so both align identically.
      const align = column?.align ?? "left"

      const CustomTableCell = column.tableCellComponent
      if (CustomTableCell) {
        // Custom cell components keep receiving the row's record.
        return (
          <CustomTableCell data-test={cellTestId} key={cellKey} align={align} record={record}>
            {cellContent}
          </CustomTableCell>
        )
      }

      // The default MUI TableCell forwards unknown props to its <td>, so passing
      // `record` here would leak record="[object Object]" into the DOM.
      return (
        <TableCell data-test={cellTestId} key={cellKey} align={align}>
          {cellContent}
        </TableCell>
      )
    })
  }

  return RenderRow ? (
    <RenderRow record={record}>{renderCell()}</RenderRow>
  ) : (
    <TableRow hover data-test={record?.pk ? `table-row-${record.pk}` : `table-row`}>
      {renderCell()}
    </TableRow>
  )
}

/**
 * MUI table whose paging is driven by the caller (the backend does the paging).
 *
 * Publishes columns, records, sort state and rendering hooks through the
 * shared table context for the header and body sub-components. It renders a
 * loading caption while `recordsLoading` is true and the body otherwise. It
 * also renders a `TablePagination` footer that reports page and page-size
 * changes through `onPageChange` / `onRowsPerPageChange`.
 *
 * Defaults:
 * - `emptyMessage` is "No Data"
 * - `rowsPerPageOptions` is [5, 10, 50]
 * - `recordsLoading` is false
 * - `tableLayout` is "auto"
 * - a missing `totalCount` is shown as 0 and a missing `pageSize` as 50
 *
 * Any other props are forwarded down to every cell content component.
 */
const PaginatedTable = <T extends TableEntity>({
  records,
  columns,
  onSortClick,
  sortDirection,
  sortBy,
  page,
  totalCount,
  onPageChange,
  onRowsPerPageChange,
  userRole,
  emptyMessage = "No Data",
  rowsPerPageOptions = [5, 10, 50],
  loadingQueryKey,
  pageSize,
  recordsLoading = false,
  RenderRow,
  tableLayout = "auto",
  ...cellContentProps
}: PaginatedTableProps<T>) => {
  const { classes } = useStyles()
  const TableContext = createTableContext<T>()

  const contextValue = useMemo<TableContextType<T>>(
    () => ({
      columns,
      userRole,
      sortBy,
      sortDirection,
      onSortClick,
      records,
      recordsLoading,
      emptyMessage,
      loadingQueryKey,
      RenderRow,
    }),
    [
      columns,
      userRole,
      sortBy,
      sortDirection,
      onSortClick,
      records,
      recordsLoading,
      emptyMessage,
      loadingQueryKey,
      RenderRow,
    ]
  )

  const tableBody = recordsLoading ? <LoadingCaption /> : <TableBodyContent<T> {...cellContentProps} />
  // Paging happens on the backend, so while the next page loads there is
  // effectively no current page; reporting 0 avoids MUI's out-of-range dev warning.
  const displayedPage = recordsLoading ? 0 : page

  return (
    <TableContext.Provider value={contextValue}>
      <TableContainer component={Paper} className={classes.table} data-test="table">
        <Table sx={{ tableLayout }}>
          <TableHead className={classes.tableHeader}>
            <TableHeadRow<T> />
          </TableHead>
          {tableBody}
        </Table>
      </TableContainer>
      <TablePagination
        component="div"
        count={totalCount ?? 0}
        page={displayedPage}
        rowsPerPage={pageSize ?? 50}
        rowsPerPageOptions={rowsPerPageOptions}
        onPageChange={(_, nextPage) => {
          onPageChange(nextPage)
        }}
        onRowsPerPageChange={event => {
          onRowsPerPageChange(Number(event.target.value))
        }}
      />
    </TableContext.Provider>
  )
}

export { PaginatedTable as default }
