import React, {
  forwardRef,
  useEffect,
  useImperativeHandle,
  useMemo,
  useRef,
  useState,
} from "react";
import {makeStyles} from "@material-ui/core/styles";
import TableContainer from "@material-ui/core/TableContainer";
import Table from "@material-ui/core/Table";
import TablePagination from "@material-ui/core/TablePagination";
import TableHead from "@material-ui/core/TableHead";
import TableBody from "@material-ui/core/TableBody";
import TableRow from "@material-ui/core/TableRow";
import TableCell from "@material-ui/core/TableCell";
import TableSortLabel from "@material-ui/core/TableSortLabel";
import Checkbox from "@material-ui/core/Checkbox";
import CircularProgress from "@material-ui/core/CircularProgress";
import styles from "./styles";

// Material-UI styling hook built from the local ./styles definitions.
// DataGrid calls it to get the class names for its root container,
// loading overlay, table and cells.
const useStyles = makeStyles(styles);

const DataGrid = (props, ref) => {
  const {
    rows,
    columns,
    disableSelectionOnClick = false,
    checkboxSelection = false,
    onSelectionChange = () => {},
    onPageChange = () => {},
    noPagination = false,
    rowCount = 0,
    pageSize = 0,
    rowsPerPageOptions = [20],
    windowHeight,
    loading = false,
  } = props;

  const classes = useStyles();
  const [selection, setSelection] = useState([]);
  const [page, setPage] = useState(0);
  const [order, setOrder] = useState("desc");
  const [orderBy, setOrderBy] = useState("");

  const columnsMap = useMemo(() => {
    const map = {};

    columns.forEach(column => (map[column.field] = column));
    return map;
  }, [columns]);

  const rowModels = useMemo(() => {
    const rowModels = rows.map((row, i) => ({
      id: row.id || `${i}`,
      data: row,
      selected: selection.indexOf(row) > -1,
    }));

    if (orderBy) {
      const columnDef = columnsMap[orderBy];
      rowModels.sort((a, b) => {
        const valA = columnDef.valueGetter
          ? columnDef.valueGetter(a.data)
          : a.data[columnDef.field];
        const valB = columnDef.valueGetter
          ? columnDef.valueGetter(b.data)
          : b.data[columnDef.field];

        if (order === "desc") return valA > valB ? -1 : valA === valB ? 0 : 1;
        else return valA > valB ? 1 : valA === valB ? 0 : -1;
      });
    }

    return rowModels;
  }, [rows, selection, order, orderBy, columnsMap]);

  // Flags
  const allRowsSelected = selection.length === rows.length;
  const someRowsSelected =
    selection.length > 0 && selection.length < rows.length;
  const dropPage =
    page > 0 && (page + 1) * pageSize === rowCount && rows.length === 0;

  function handleRowClick(rowModel) {
    if (disableSelectionOnClick && !checkboxSelection) return;

    if (rowModel.selected)
      setSelection(selection =>
        selection.filter(row => row.id !== rowModel.id)
      );
    else setSelection(selection => [...selection, rowModel.data]);
  }

  function handleCheckboxSelection() {
    if (someRowsSelected || selection.length === 0) setSelection([...rows]);
    else if (allRowsSelected) setSelection([]);
  }

  function sortHandler(fieldName) {
    return () => {
      const isAsc = fieldName === orderBy && order === "asc";
      setOrder(isAsc ? "desc" : "asc");
      setOrderBy(fieldName);
    };
  }

  useImperativeHandle(ref, () => ({setSelection}), []);

  useEffect(() => {
    onSelectionChange(selection);
  }, [selection, onSelectionChange]);

  useEffect(() => {
    if (dropPage) {
      setPage(page => {
        onPageChange(page); // Since Datagrid pagination is 0 index based.
        return page - 1;
      });
    }
  }, [dropPage, onPageChange]);

  return (
    <>
      <div className={classes.root}>
        {loading && (
          <div className={classes.loadingOverlay}>
            <div className={classes.circularProgressContainer}>
              <CircularProgress />
            </div>
          </div>
        )}

        <TableContainer style={{height: windowHeight}}>
          <Table stickyHeader className={classes.table}>
            <TableHead>
              <TableRow>
                {checkboxSelection && (
                  <TableCell className={classes.tableCell} padding="checkbox">
                    <Checkbox
                      checked={allRowsSelected}
                      indeterminate={someRowsSelected}
                      onChange={handleCheckboxSelection}
                    />
                  </TableCell>
                )}
                {columns.map(column => (
                  <TableCell
                    className={classes.tableCell}
                    key={column.field}
                    style={{minWidth: column.width}}
                    align={column.align}
                    sortDirection={orderBy === column.field ? order : false}
                  >
                    {column.sortable ? (
                      <TableSortLabel
                        active={orderBy === column.field}
                        direction={orderBy === column.field ? order : "asc"}
                        onClick={sortHandler(column.field)}
                      >
                        {column.headerName || column.field}
                      </TableSortLabel>
                    ) : (
                      column.headerName || column.field
                    )}
                  </TableCell>
                ))}
              </TableRow>
            </TableHead>
            <TableBody>
              {rowModels.map(rowModel => (
                <TableRow
                  key={"" + rowModel.id}
                  selected={rowModel.selected}
                  onClick={ev => {
                    ev.stopPropagation();
                    if (disableSelectionOnClick) return;
                    handleRowClick(rowModel);
                  }}
                >
                  {checkboxSelection && (
                    <TableCell className={classes.tableCell} padding="checkbox">
                      <Checkbox
                        checked={rowModel.selected}
                        onClick={ev => {
                          ev.stopPropagation();
                          handleRowClick(rowModel);
                        }}
                      />
                    </TableCell>
                  )}
                  {columns.map(column => (
                    <TableCell
                      className={classes.tableCell}
                      key={column.field}
                      align={column.align}
                      padding={column.padding}
                    >
                      {column.renderCell
                        ? column.renderCell(rowModel.data)
                        : column.valueGetter
                        ? column.valueGetter(rowModel.data)
                        : rowModel.data[column.field]}
                    </TableCell>
                  ))}
                </TableRow>
              ))}
            </TableBody>
          </Table>
        </TableContainer>
      </div>

      {!noPagination && (
        <TablePagination
          component="div"
          count={rowCount}
          rowsPerPage={pageSize}
          rowsPerPageOptions={rowsPerPageOptions}
          page={page}
          onChangePage={(ev, page) => {
            setPage(page);
            onPageChange(page + 1);
            setSelection([]);
          }}
        />
      )}
    </>
  );
};

// Wrap DataGrid with forwardRef so a parent's ref reaches the component and
// can use the `setSelection` handle it exposes through useImperativeHandle.
const ForwardedDataGrid = forwardRef(DataGrid);

export default ForwardedDataGrid;
