Cache stampede solution — Probabilistic Early Expiration
Please familiarize yourself with the cache stampede problem first. This post covers probabilistic early expiration, a technique for avoiding it.
Probabilistic Early Expiration
Fundamentally, the idea is to distribute the expiration times to reduce the likelihood of a cache stampede.
We achieve this by adding a randomized element to the expiration of cache entries. Instead of every request waiting for the entry to expire completely and then refreshing it at the same moment, each request that arrives shortly before expiry has a small probability of refreshing the entry early. The probability function should be defined so that the probability increases gradually as the expiration time approaches.
Note: In some cases, entries might be refreshed too early or too late, leading to suboptimal cache utilization and potentially higher latency if a cache miss occurs close to the normal expiration time.
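One simple choice of probability function is a linear ramp, which is what the example code later in this post uses. The symbols here are introduced only for illustration: a is the age of the cache entry, w is the age at which the early window opens, and T is the regular expiration age.

```latex
p(a) =
\begin{cases}
0 & \text{if } a < w \\[4pt]
\dfrac{a - w}{T - w} & \text{if } w \le a < T \\[4pt]
1 & \text{if } a \ge T
\end{cases}
```

For example, with T = 600 seconds and w = 480 seconds, an entry aged 510 seconds is refreshed with probability 30/120 = 0.25, one aged 540 seconds with probability 0.5, and one aged 570 seconds with probability 0.75.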
How does this work? 🤔
Whether to refresh a cache entry is decided using the following:
regular_expiration_time: the expiration time set for the cache entry. Once it has passed, the entry must be refreshed.
early_expiration_window: a period that begins before the regular expiration time. Within it, the entry may be refreshed, depending on the probability.
Random check: on each access within early_expiration_window, the app decides whether to refresh based on the probability function.
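The three steps above can be sketched as a single pure function. This is a minimal sketch, not the only way to do it: the names decideRefresh, regularExpiration and earlyWindowStart are hypothetical, and the random draw is passed in as a parameter (an assumption made here so the decision is deterministic and easy to unit test).

```go
package main

import "fmt"

// Hypothetical values for this sketch; ages are seconds since the entry was cached.
const (
	regularExpiration = 600 // the entry must be refreshed at or after this age
	earlyWindowStart  = 480 // age at which an early refresh becomes possible
)

// decideRefresh reports whether a request should refresh the entry.
// draw is a uniform random number in [0, 1) supplied by the caller.
func decideRefresh(age int64, draw float64) bool {
	switch {
	case age >= regularExpiration:
		return true // expired: always refresh
	case age < earlyWindowStart:
		return false // still fresh: never refresh
	}
	// Inside the early window the probability grows linearly from 0 to 1.
	p := float64(age-earlyWindowStart) / float64(regularExpiration-earlyWindowStart)
	return draw < p
}

func main() {
	// Use the same draw for every age to show how the outcome depends on age.
	for _, age := range []int64{100, 480, 510, 570, 600} {
		fmt.Printf("age=%ds draw=0.40 refresh=%v\n", age, decideRefresh(age, 0.40))
	}
}
```

With a draw of 0.40, only the entries aged 570 and 600 seconds are refreshed: at 570 seconds the probability is 0.75, and at 600 seconds the entry has expired. In real use the caller would pass rand.Float64() as the draw.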
Here's an example in Go, using Redis and a database simulated with a map:
package main
import (
"context"
"fmt"
"math/rand"
"time"
"github.com/go-redis/redis/v8"
)
// Simulate a database with a map
var database = map[string]string{
"key1": "value1",
"key2": "value2",
"key3": "value3",
}
var ctx = context.Background()
// Initialize Redis client
var rdb = redis.NewClient(&redis.Options{
Addr: "localhost:6377", // Redis listens on 6379 by default; adjust to match your setup
})
const (
normalExpiration = 600 // seconds (10 minutes): age at which the entry must be refreshed
earlyExpirationWindow = 480 // seconds (8 minutes): age at which probabilistic refresh may begin
)
// Fetching data from simulated db
func fetchFromDB(key string) (string, error) {
time.Sleep(100 * time.Millisecond) // Simulate database delay
return database[key], nil
}
// Check if the cache should be refreshed using probability function
func shouldRefresh(lastUpdated int64) bool {
age := time.Now().Unix() - lastUpdated
if age >= normalExpiration {
return true // we must refresh
}
if age < earlyExpirationWindow {
return false // no need to refresh
}
// determine if we need to refresh
earlyAge := age - earlyExpirationWindow
maxEarlyAge := normalExpiration - earlyExpirationWindow
probability := float64(earlyAge) / float64(maxEarlyAge)
return rand.Float64() < probability
}
func getValue(key string) (string, error) {
// get value from Redis
val, err := rdb.Get(ctx, key).Result()
if err == redis.Nil {
// Key does not exist, now get from db
fmt.Println("Cache miss. Fetching from db...")
value, dbErr := fetchFromDB(key)
if dbErr != nil {
return "", dbErr
}
// Store in Redis with timestamp as we use timestamp to determine expiration
err = rdb.Set(ctx, key, value, 0).Err()
if err != nil {
return "", err
}
err = rdb.Set(ctx, key+":timestamp", time.Now().Unix(), 0).Err()
if err != nil {
return "", err
}
return value, nil
} else if err != nil {
return "", err
}
// Check if the cache entry should be refreshed
timestamp, err := rdb.Get(ctx, key+":timestamp").Int64()
if err == nil && shouldRefresh(timestamp) {
fmt.Println("Probabilistically refreshing cache entry...")
value, dbErr := fetchFromDB(key)
if dbErr != nil {
return "", dbErr
}
// Update Redis with new value and timestamp
err = rdb.Set(ctx, key, value, 0).Err()
if err != nil {
return "", err
}
err = rdb.Set(ctx, key+":timestamp", time.Now().Unix(), 0).Err()
if err != nil {
return "", err
}
return value, nil
}
fmt.Println("Cache hit. Using cached data.")
return val, nil
}
func main() {
rand.Seed(time.Now().UnixNano()) // Seed random number generator (not needed since Go 1.20, which seeds the global source automatically)
key := "key1"
// fetch immediately
value, err := getValue(key)
if err != nil {
fmt.Printf("Error: %v\n", err)
return
}
fmt.Printf("Value for %s: %s\n", key, value)
// Simulate 500 seconds passing (just over 8 minutes), which puts the entry inside the early expiration window
time.Sleep(500 * time.Second)
value, err = getValue(key)
if err != nil {
fmt.Printf("Error: %v\n", err)
return
}
fmt.Printf("Value for %s: %s\n", key, value)
// Another 150 seconds pass: the entry is past normal expiration unless it was refreshed early above
time.Sleep(150 * time.Second)
value, err = getValue(key)
if err != nil {
fmt.Printf("Error: %v\n", err)
return
}
fmt.Printf("Value for %s: %s\n", key, value)
}
Comments, claps and feedback are welcome. Happy coding 👨‍💻
References
https://cseweb.ucsd.edu//~avattani/papers/cache_stampede.pdf