import { CollectionViewer } from '@angular/cdk/collections'
import { DataSource as CdkDataSource } from '@angular/cdk/table'
import { DocumentChangeAction, Query } from '@angular/fire/compat/firestore'
import { MatPaginator } from '@angular/material/paginator'
import { MatSort } from '@angular/material/sort'
import { BehaviorSubject, Observable, of, Subject, Subscription } from 'rxjs'
import {
  filter,
  map,
  startWith,
  switchMap,
  take,
  takeUntil,
  tap
} from 'rxjs/operators'
import firebase from 'firebase/compat/app'
import 'firebase/compat/firestore'

/**
 * Sort and pagination settings for a Firestore-backed query.
 *
 * Every field is optional; `applyFirestoreRequest` translates the fields
 * into `orderBy`, cursor and limit calls on a Firestore query.
 */
export interface FirestoreRequest {
  /** Field path to order the results by. */
  orderBy?: string
  /** Direction of the ordering. */
  orderDir?: 'desc' | 'asc'

  /** Zero-based index of the requested page. */
  pageIndex?: number
  /** Number of rows shown per page. */
  pageSize?: number
  /** Upper bound on the number of documents the query returns. */
  limit?: number

  /** Inclusive start cursor (field value(s) or a document snapshot). */
  startAt?: any
  /** Exclusive start cursor (field value(s) or a document snapshot). */
  startAfter?: any
  /** Inclusive end cursor (field value(s) or a document snapshot). */
  endAt?: any
  /** Exclusive end cursor (field value(s) or a document snapshot). */
  endBefore?: any
}

/**
 * Applies the ordering, cursor and limit settings of `request` to `query`.
 *
 * Ordering: the query is always ordered (cursors need an ordering). It is
 * ordered by `request.orderBy` when given, otherwise by document ID. The
 * direction is `request.orderDir`, defaulting to ascending. Previously a
 * request with `orderBy` but no `orderDir` silently fell back to
 * document-ID order.
 *
 * Cursors: each of `startAt` / `startAfter` / `endAt` / `endBefore` is
 * applied when it is neither `null` nor `undefined`. Falsy cursor values
 * such as `0`, `''` or `false` are legitimate and are applied.
 *
 * Limit: applied only when truthy, because Firestore rejects non-positive
 * limits. When an end cursor is present the limit is applied with
 * `limitToLast`, so the returned documents are the ones immediately before
 * the end cursor rather than the first ones of the ordering.
 *
 * @param query the query to extend
 * @param request sort / pagination settings to apply
 * @returns the extended query
 */
export const applyFirestoreRequest = (
  query: Query,
  request: FirestoreRequest
): Query => {
  const isSet = (value: unknown): boolean =>
    value !== undefined && value !== null

  const direction = request.orderDir ?? 'asc'
  let result: Query = request.orderBy
    ? result0(query.orderBy(request.orderBy, direction))
    : result0(
        query.orderBy(firebase.firestore.FieldPath.documentId(), direction)
      )

  if (isSet(request.startAt)) {
    result = result.startAt(request.startAt)
  }
  if (isSet(request.startAfter)) {
    result = result.startAfter(request.startAfter)
  }
  if (isSet(request.endAt)) {
    result = result.endAt(request.endAt)
  }
  if (isSet(request.endBefore)) {
    result = result.endBefore(request.endBefore)
  }

  if (request.limit) {
    // With an end cursor we want the page that ends at that cursor, which
    //  requires counting back from the end of the ordering.
    const endBounded = isSet(request.endAt) || isSet(request.endBefore)
    result = endBounded
      ? result.limitToLast(request.limit)
      : result.limit(request.limit)
  }

  return result

  // Identity helper so both ordering branches share one typed expression.
  function result0(q: Query): Query {
    return q
  }
}

/**
 * Abstract CDK table data source that publishes its current rows through a
 * `BehaviorSubject` and offers convenience views over them.
 *
 * Implementations supply the paginator and sort controls and decide what
 * `refresh()` does.
 */
export abstract class DataSource<
  DataType = any
> extends CdkDataSource<DataType> {
  abstract paginator: MatPaginator
  abstract sort: MatSort

  /** Holds the latest rows; starts out as an empty array. */
  protected readonly _data$ = new BehaviorSubject<Array<DataType>>([])

  /** Stream of the current rows (replays the latest value to new subscribers). */
  get data$(): Observable<Array<DataType>> {
    return this._data$
  }

  /** Synchronous snapshot of the current rows; `[]` when no value is held. */
  get data() {
    const rows = this._data$.value
    return rows ? rows : []
  }

  /** Emits `true` while there are no rows. */
  get isEmpty$(): Observable<boolean> {
    return this._data$.pipe(map((rows) => !rows.length))
  }

  /** Emits `true` while at least one row is present. */
  get isNotEmpty$(): Observable<boolean> {
    return this._data$.pipe(map((rows) => rows.length > 0))
  }

  /** Re-runs whatever produces the rows. */
  abstract refresh(): void
}

/**
 * Base class for CDK table data sources backed by a Firestore query.
 *
 * Subclasses provide:
 *  - `defaultArgs()`: the base request;
 *  - `executeQuery()`: runs a request;
 *  - `mapResponse()`: converts the raw change actions into rows.
 *
 * This class folds the attached `MatSort` / `MatPaginator` state into every
 * request. When the requested page overlaps or is adjacent to the page that
 * is currently loaded, it anchors the query on a loaded document with a
 * Firestore cursor instead of re-reading from the start of the ordering.
 */
export abstract class FirestoreDataSource<
  Request extends FirestoreRequest,
  DataType = any
> extends DataSource<DataType> {
  // Emits whenever internal state changes (a paginator or sorter is
  //  attached or changed, or `refresh()` is called). This makes
  //  `connect()` rebuild its request.
  private internalStateSubject = new BehaviorSubject<void>(undefined)

  // Page index and page size of the rows currently held in `_rawData`.
  private currentPage = 0
  private currentPageSize = 0

  // Raw change actions behind the rows currently displayed. They serve as
  //  cursor anchors when moving to a neighbouring page.
  private _rawData: Array<DocumentChangeAction<unknown>> = []

  private readonly _request$ = new BehaviorSubject<Request | undefined>(
    undefined
  )

  /** The latest request handed to `executeQuery` (emits once one exists). */
  get request$(): Observable<Request> {
    return this._request$
      .asObservable()
      .pipe(filter((req): req is Request => req !== undefined && req !== null))
  }

  get paginator() {
    return this._paginator
  }
  set paginator(value: MatPaginator) {
    // Release only *our* subscriptions on the previous paginator. Calling
    //  `page.unsubscribe()` closes the paginator's EventEmitter for every
    //  subscriber, so the previous paginator could never emit again.
    this._pageSubscription?.unsubscribe()
    this._pageSubscription = undefined
    this._totalSubscription?.unsubscribe()
    this._totalSubscription = undefined

    this._paginator = value
    this.internalStateSubject.next()
    if (!value) {
      // Paginator detached (set to null/undefined): nothing to wire up.
      return
    }

    this._totalSubscription = this.totalItems().subscribe(
      (count) => (value.length = count)
    )
    // Forward only `next`. Subscribing the subject itself would also forward
    //  complete/error and terminate the connect() pipeline.
    this._pageSubscription = value.page.subscribe(() =>
      this.internalStateSubject.next()
    )
  }
  private _totalSubscription?: Subscription
  private _pageSubscription?: Subscription
  private _paginator?: MatPaginator

  get sort() {
    return this._sort
  }
  set sort(value: MatSort) {
    // See the paginator setter: unsubscribe our listener, never the emitter.
    this._sortSubscription?.unsubscribe()
    this._sortSubscription = undefined

    this._sort = value
    if (value) {
      this._sortSubscription = value.sortChange.subscribe(() =>
        this.internalStateSubject.next()
      )
    }
    this.internalStateSubject.next()
  }
  private _sortSubscription?: Subscription
  private _sort?: MatSort

  /**
   * When `true`, `connect()` keeps forwarding every result of the query.
   * When `false`, it stops after the first two results.
   */
  get isFollowing(): boolean {
    return this._follow.value
  }
  set isFollowing(value: boolean) {
    this._follow.next(value)
  }
  private _follow = new BehaviorSubject<boolean>(false)

  protected readonly _disconnect$ = new Subject<void>()

  constructor() {
    super()
  }

  /** Rebuilds the request and re-runs the query. */
  refresh() {
    this.internalStateSubject.next()
  }

  /** The base request that sort and pagination settings are layered onto. */
  abstract defaultArgs(): Observable<Request>

  /**
   * Combines `defaultArgs()` with the current sort and paginator state.
   * Every stage works on a shallow copy, so settings from one emission
   * (cursors, an enlarged limit) never leak into the next one.
   */
  private createRequest(): Observable<Request> {
    return this.defaultArgs().pipe(
      // Add in the sort settings
      switchMap((req) => {
        const sort = this._sort
        if (!sort) {
          return of(req)
        }

        return sort.sortChange.pipe(
          startWith(sort),
          map(({ active, direction }) => {
            // TODO If the sort changes, then the paging is no longer accurate. In
            //  that case, we should run multiple queries until we end up at our
            //  desired page.
            const next: Request = { ...req }
            if (direction && active) {
              next.orderDir = direction === 'desc' ? 'desc' : 'asc'
              next.orderBy = active
            }
            return next
          })
        )
      }),
      // Add in the pagination settings
      switchMap((req) => {
        const paginator = this._paginator
        if (!paginator) {
          return of(req)
        }

        return paginator.page.pipe(
          startWith(paginator),
          map(({ pageIndex, pageSize }) => {
            const next: Request = { ...req }
            if (pageIndex !== null && pageIndex !== undefined) {
              next.pageIndex = pageIndex
            }
            // Always cap the query at one page so we never read an entire
            //  collection. If the paginator has no page size, fall back to
            //  the last known one (or 5).
            const size = pageSize ?? (this.currentPageSize || 5)
            next.pageSize = next.limit = size
            return next
          })
        )
      }),
      // Figure out the start/end points.
      map((req) => this.applyPageCursor(req))
    )
  }

  /**
   * Returns a copy of `req` with the cursor / limit fields needed to fetch
   * the page at `req.pageIndex`:
   *  - target row already loaded: `startAt` that row;
   *  - target page directly follows the loaded rows: `startAfter` the last one;
   *  - target page directly precedes the loaded rows: `endBefore` the first one;
   *  - otherwise: read from the start of the ordering up to the end of the
   *    target page (`limit = start + pageSize`). `connect()` then trims the
   *    result down to the target page.
   * The first page (index 0) needs no cursor and is returned unchanged.
   */
  private applyPageCursor(req: Request): Request {
    const { pageIndex, pageSize } = req
    if (!pageIndex || !pageSize) {
      return req
    }

    const next: Request = { ...req }
    const loaded = this._rawData
    const targetStartIndex = pageIndex * pageSize
    const firstDataIndex = this.currentPage * this.currentPageSize
    const lastDataIndex = firstDataIndex + loaded.length - 1

    if (
      loaded.length > 0 &&
      firstDataIndex <= targetStartIndex &&
      targetStartIndex <= lastDataIndex
    ) {
      // We have the target item in the current data set.
      next.startAt = loaded[targetStartIndex - firstDataIndex].payload.doc
    } else if (loaded.length > 0 && targetStartIndex === lastDataIndex + 1) {
      // The next page starts right after the last loaded item.
      next.startAfter = loaded[loaded.length - 1].payload.doc
    } else if (
      loaded.length > 0 &&
      targetStartIndex + pageSize === firstDataIndex
    ) {
      // The previous page ends right before the first loaded item.
      //  `applyFirestoreRequest` pairs `endBefore` with `limitToLast`.
      next.endBefore = loaded[0].payload.doc
    } else {
      // No usable anchor: fetch everything up to the end of the target page.
      next.limit = targetStartIndex + pageSize
    }
    return next
  }

  /** Hook for subclasses to transform a request before it is executed. */
  public mapRequest(request: Request): Observable<Request> {
    return of(request)
  }

  abstract executeQuery(
    request: Request
  ): Observable<Array<DocumentChangeAction<unknown>>>

  abstract mapResponse(
    response: Array<DocumentChangeAction<unknown>>
  ): Array<DataType>

  /**
   * Total number of items, written to the paginator's `length`.
   * Defaults to -1; override to supply a real count.
   */
  public totalItems(): Observable<number> {
    return of(-1)
  }

  connect(_collectionViewer: CollectionViewer): Observable<Array<DataType>> {
    return this.internalStateSubject.pipe(
      switchMap(() => this.createRequest()),
      switchMap((req) => this.mapRequest(req)),

      // Provide the request for listeners
      tap((req) => this._request$.next(req)),

      // This is not a switchMap because we want the actual observable
      //  next in the stream, so that we can pipe it through take(1) if
      //  we only want the first result.
      map((req) =>
        this.executeQuery(req).pipe(
          map((raw) =>
            // When the request asked for more than one page (the fallback in
            //  applyPageCursor), the target page starts at index
            //  `limit - pageSize`. Trim everything before it. This also stays
            //  correct when the last page is shorter than pageSize.
            req.limit !== undefined &&
            req.pageSize !== undefined &&
            req.limit > req.pageSize
              ? raw.slice(req.limit - req.pageSize)
              : raw
          ),
          tap((raw) => (this._rawData = raw)),
          map((raw) => this.mapResponse(raw)),
          map((res) => ({ request: req, response: res }))
        )
      ),
      // Stop observing after the first result if we're not following results.
      // Use take(2) because the first result is the initial data set, and
      //  the second result is the updated data set.
      switchMap((reqObservable) =>
        this._follow.pipe(
          switchMap((isFollowing) =>
            isFollowing ? reqObservable : reqObservable.pipe(take(2))
          )
        )
      ),
      tap(({ request, response }) => {
        // Update our page tracking controls
        this.currentPage = request.pageIndex ?? 0
        this.currentPageSize = request.pageSize ?? 0
        this._data$.next(response)
      }),
      map(({ response }) => response),
      // Must be the last operator. Placed before the switchMaps, it would
      //  only complete the outer source and leave the active inner query
      //  subscription (e.g. a followed Firestore listener) running after
      //  disconnect.
      takeUntil(this._disconnect$)
    )
  }

  disconnect(_collectionViewer: CollectionViewer) {
    this._disconnect$.next()
    this._disconnect$.complete()
    // Release the listeners this instance attached outside of connect().
    this._totalSubscription?.unsubscribe()
    this._pageSubscription?.unsubscribe()
    this._sortSubscription?.unsubscribe()
  }
}
