useScrollRestoration

All about the `useScrollRestoration` custom hook.


Restores an element's scroll position after its component re-mounts. Positions are stored under a string key and can optionally be persisted to `localStorage` or `sessionStorage`.

Usage

/**
 * Example: a fixed-height scroll container that returns to its previous scroll
 * position when it re-mounts.
 *
 * NOTE(review): `currentCourseFromParams` (the restoration key) and `Props` are
 * assumed to exist in the surrounding code; they are not defined in this snippet.
 */
function ScrollAreaWithScrollRestoration({ children }: Props) {
  // Save at most every 200 ms and mirror positions to localStorage.
  const restorationOptions = {
    debounceTime: 200,
    persist: "localStorage",
  } as const
  const { ref } = useScrollRestoration(
    currentCourseFromParams,
    restorationOptions,
  )

  return (
    <div ref={ref} className="h-[500px] overflow-y-auto">
      {children}
    </div>
  )
}

Hook

import { useCallback, useEffect, useState } from "react"
import { useScrollStore } from "./use-scroll-store"
import { debounce } from "./use-debounce"
 
/** Options accepted by `useScrollRestoration`. */
interface ScrollRestorationOptions {
  /**
   * Web Storage area to mirror scroll positions into (`"localStorage"` or
   * `"sessionStorage"`), or `false` to keep them in memory only.
   * The hook defaults this to `false`.
   */
  persist?: false | "localStorage" | "sessionStorage"
  /**
   * Delay in milliseconds handed to `debounce` for the scroll listener.
   * The hook defaults this to `100`.
   */
  debounceTime?: number
}
 
/**
 * Saves an element's scroll offsets and restores them when the component re-mounts.
 *
 * Offsets live in the shared `useScrollStore` under `key`. When `persist` is set,
 * they are also mirrored to Web Storage under `scrollRestoration-${key}` and read
 * back the first time `key` is seen (for example after a full page reload).
 *
 * @param key - Identifies the stored position; components using the same key share it.
 * @param options - `debounceTime` (ms, default `100`) delays writes from the scroll
 *   listener; `persist` (default `false`) selects the Web Storage area, if any.
 * @returns `ref`, a callback ref to attach to the scrollable element, and
 *   `setScroll`, which updates the stored position (and so scrolls the element).
 *   An omitted axis keeps its current stored value.
 */
export function useScrollRestoration<U extends HTMLElement>(
  key: string,
  { debounceTime = 100, persist = false }: ScrollRestorationOptions = {},
) {
  // Subscribe to this key's entry only, so scrolling under other keys does not
  // re-render every component that uses this hook.
  const currentScrollRestoration = useScrollStore(
    (state) => state.scrollRestoration[key],
  )
  const setScrollRestoration = useScrollStore(
    (state) => state.setScrollRestoration,
  )

  const [element, setElement] = useState<U | null>(null)
  // Forward `null` too, so listeners are detached when the node goes away or
  // the ref moves to a different element.
  const ref = useCallback((node: U | null) => {
    setElement(node)
  }, [])

  // Record scroll offsets (debounced) while the user scrolls.
  useEffect(() => {
    if (!element) return

    // Snapshot offsets at event time. Reading them inside the debounced callback
    // is wrong if it fires after the element was detached (e.g. the component
    // unmounted within the debounce window): a detached element reports 0, which
    // would overwrite the very position we want to restore.
    let latest = { scrollTop: element.scrollTop, scrollLeft: element.scrollLeft }
    const commit = debounce(() => {
      setScrollRestoration(key, latest)
    }, debounceTime)

    const handleScroll = () => {
      latest = { scrollTop: element.scrollTop, scrollLeft: element.scrollLeft }
      commit()
    }

    element.addEventListener("scroll", handleScroll)
    return () => {
      element.removeEventListener("scroll", handleScroll)
    }
  }, [debounceTime, key, element, setScrollRestoration])

  // Apply the stored position, or seed the store the first time `key` is seen
  // (from Web Storage when persisted, otherwise from the element's offsets).
  // This also re-runs whenever the stored entry changes, which is how
  // `setScroll` moves the element.
  useEffect(() => {
    if (!element) return

    if (currentScrollRestoration) {
      element.scrollTo(
        currentScrollRestoration.scrollLeft,
        currentScrollRestoration.scrollTop,
      )
      return
    }

    setScrollRestoration(
      key,
      readPersistedScroll(persist, key) ?? {
        scrollTop: element.scrollTop,
        scrollLeft: element.scrollLeft,
      },
    )
  }, [currentScrollRestoration, element, key, persist, setScrollRestoration])

  // Mirror the stored position to Web Storage when persistence is enabled.
  useEffect(() => {
    if (currentScrollRestoration) {
      writePersistedScroll(persist, key, currentScrollRestoration)
    }
  }, [key, persist, currentScrollRestoration])

  const setScroll = useCallback(
    ({ x, y }: { x?: number; y?: number }) => {
      // Before the first stored entry exists (e.g. the element is not mounted yet),
      // fall back to the element's live offsets, then 0, instead of throwing.
      setScrollRestoration(key, {
        scrollLeft:
          x ?? currentScrollRestoration?.scrollLeft ?? element?.scrollLeft ?? 0,
        scrollTop:
          y ?? currentScrollRestoration?.scrollTop ?? element?.scrollTop ?? 0,
      })
    },
    [currentScrollRestoration, element, key, setScrollRestoration],
  )

  return { ref, setScroll }
}

/** Web Storage key under which the position for `key` is persisted. */
function scrollStorageKey(key: string) {
  return `scrollRestoration-${key}`
}

/** Returns the storage area selected by `persist`, or `null` when disabled or inaccessible. */
function getScrollStorage(
  persist: ScrollRestorationOptions["persist"],
): Storage | null {
  if (!persist) return null
  try {
    // Touching the storage globals can throw (e.g. SecurityError when storage is blocked).
    return persist === "localStorage" ? localStorage : sessionStorage
  } catch {
    return null
  }
}

/** Reads a persisted position for `key`; `undefined` if missing, unreadable, or malformed. */
function readPersistedScroll(
  persist: ScrollRestorationOptions["persist"],
  key: string,
): { scrollTop: number; scrollLeft: number } | undefined {
  const storage = getScrollStorage(persist)
  if (!storage) return undefined
  try {
    const raw = storage.getItem(scrollStorageKey(key))
    if (!raw) return undefined
    // Stored data is untrusted: validate the shape instead of trusting JSON.parse.
    const parsed = JSON.parse(raw) as
      | { scrollTop?: unknown; scrollLeft?: unknown }
      | null
    const scrollTop = parsed?.scrollTop
    const scrollLeft = parsed?.scrollLeft
    return typeof scrollTop === "number" && typeof scrollLeft === "number"
      ? { scrollTop, scrollLeft }
      : undefined
  } catch {
    // Corrupt JSON or a storage error: ignore it rather than crash the component.
    return undefined
  }
}

/** Best-effort write of `value` for `key`; storage failures are ignored. */
function writePersistedScroll(
  persist: ScrollRestorationOptions["persist"],
  key: string,
  value: { scrollTop: number; scrollLeft: number },
) {
  const storage = getScrollStorage(persist)
  if (!storage) return
  try {
    storage.setItem(scrollStorageKey(key), JSON.stringify(value))
  } catch {
    // e.g. QuotaExceededError: losing the persisted copy is acceptable.
  }
}

A Zustand store holds the scroll positions outside the component tree, so they survive the component unmounting and re-mounting (not just re-rendering).

hooks/use-scroll-store
import { create } from "zustand"
 
/** Shape of the shared scroll-restoration store. */
interface ScrollState {
  /** Records `value` as the scroll offsets for `key`. */
  setScrollRestoration: (
    key: string,
    value: { scrollLeft: number; scrollTop: number },
  ) => void
  /** Last recorded scroll offsets, indexed by restoration key. */
  scrollRestoration: Record<string, { scrollLeft: number; scrollTop: number }>
}
 
/**
 * Global Zustand store of scroll offsets, indexed by restoration key. Because it
 * lives outside React, a position outlives the component that recorded it.
 */
export const useScrollStore = create<ScrollState>((set) => {
  const setScrollRestoration: ScrollState["setScrollRestoration"] = (
    key,
    value,
  ) => {
    // Build a new map (rather than mutating) so subscribers see a new reference.
    set(({ scrollRestoration }) => ({
      scrollRestoration: { ...scrollRestoration, [key]: value },
    }))
  }

  return { scrollRestoration: {}, setScrollRestoration }
})

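The hook also uses a `debounce` helper, imported from `./use-debounce` (not shown here), to limit how often scroll positions are saved.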