package consistenthash
import (
"hash/crc32"
"sort"
"strconv"
"sync"
)
// Ring is a consistent-hash ring mapping keys to node IDs. Each physical node
// is placed on the ring at vnodes points (virtual nodes). AddNode and
// RemoveNode take the write lock; GetNode and GetNodes take the read lock.
type Ring struct {
	mu sync.RWMutex // guards nodes and sortedHashes

	vnodes       int               // ring points created per physical node
	nodes        map[uint32]string // ring point (hash) -> owning node ID
	sortedHashes []uint32          // ring points in ascending order, for binary search
}
// New returns an empty Ring that will place vnodes virtual nodes on the ring
// for every node added later. Any vnodes value below 1 is replaced by 100.
func New(vnodes int) *Ring {
	const fallbackVNodes = 100

	r := &Ring{
		vnodes: vnodes,
		nodes:  make(map[uint32]string),
	}
	if r.vnodes < 1 {
		r.vnodes = fallbackVNodes
	}
	return r
}
// hash positions key on the ring: the CRC-32 (IEEE polynomial) checksum of
// the key's bytes. It does not read any Ring state.
func (r *Ring) hash(key string) uint32 {
	data := []byte(key)
	return crc32.Checksum(data, crc32.IEEETable)
}
// AddNode places nodeID on the ring at r.vnodes points, obtained by hashing
// "nodeID#0" through "nodeID#<vnodes-1>".
//
// Each ring point appears in sortedHashes at most once. Re-adding a node that
// is already present therefore leaves the ring unchanged instead of
// duplicating its points. If a virtual-node hash collides with a point that
// already exists, the point is reassigned to nodeID (the most recent AddNode
// wins).
func (r *Ring) AddNode(nodeID string) {
	r.mu.Lock()
	defer r.mu.Unlock()

	added := false
	for i := 0; i < r.vnodes; i++ {
		h := r.hash(nodeID + "#" + strconv.Itoa(i))
		if _, exists := r.nodes[h]; !exists {
			// Only a brand-new point needs a slot in sortedHashes; an existing
			// point (re-add or collision) is already recorded there.
			r.sortedHashes = append(r.sortedHashes, h)
			added = true
		}
		r.nodes[h] = nodeID
	}

	// Re-sort only if the set of points actually changed.
	if added {
		sort.Slice(r.sortedHashes, func(i, j int) bool {
			return r.sortedHashes[i] < r.sortedHashes[j]
		})
	}
}
// RemoveNode takes nodeID off the ring. It deletes the points derived from
// "nodeID#0" through "nodeID#<vnodes-1>", then rebuilds the sorted point list
// from the remaining points.
//
// A point is deleted only if nodeID still owns it. If another node's
// virtual-node hash collided with one of nodeID's and took that point over,
// the other node keeps it. Removing a node that is not on the ring changes
// nothing.
func (r *Ring) RemoveNode(nodeID string) {
	r.mu.Lock()
	defer r.mu.Unlock()

	for i := 0; i < r.vnodes; i++ {
		h := r.hash(nodeID + "#" + strconv.Itoa(i))
		if r.nodes[h] == nodeID {
			delete(r.nodes, h)
		}
	}

	// Rebuild from the map, reusing the slice's storage, so sortedHashes holds
	// exactly one entry per remaining point.
	r.sortedHashes = r.sortedHashes[:0]
	for h := range r.nodes {
		r.sortedHashes = append(r.sortedHashes, h)
	}
	sort.Slice(r.sortedHashes, func(i, j int) bool {
		return r.sortedHashes[i] < r.sortedHashes[j]
	})
}
// GetNode returns the node that owns key. The owner is the node holding the
// first ring point whose hash is >= hash(key); if key hashes past the last
// point, the lookup wraps around to the lowest point. It returns "" when the
// ring has no points.
func (r *Ring) GetNode(key string) string {
	r.mu.RLock()
	defer r.mu.RUnlock()

	points := r.sortedHashes
	total := len(points)
	if total == 0 {
		return ""
	}

	target := r.hash(key)
	pos := sort.Search(total, func(k int) bool { return points[k] >= target })
	if pos == total {
		pos = 0 // past the largest point: wrap around the ring
	}
	return r.nodes[points[pos]]
}
// GetNodes returns up to n distinct node IDs for key. They are listed in ring
// order, starting at the point GetNode would choose and wrapping around the
// ring. Fewer than n IDs are returned when the ring holds fewer distinct
// nodes. It returns nil when the ring is empty or n <= 0.
func (r *Ring) GetNodes(key string, n int) []string {
	r.mu.RLock()
	defer r.mu.RUnlock()

	total := len(r.sortedHashes)
	// A negative n would otherwise make the make() call below panic.
	if total == 0 || n <= 0 {
		return nil
	}

	h := r.hash(key)
	start := sort.Search(total, func(i int) bool {
		return r.sortedHashes[i] >= h
	})

	// There can never be more distinct owners than ring points, so bound the
	// preallocation by that rather than by a caller-supplied (possibly huge) n.
	capacity := n
	if capacity > total {
		capacity = total
	}
	seen := make(map[string]struct{}, capacity)
	result := make([]string, 0, capacity)
	for i := 0; i < total && len(result) < n; i++ {
		// (start+i)%total walks the ring and also wraps when start == total.
		nodeID := r.nodes[r.sortedHashes[(start+i)%total]]
		if _, dup := seen[nodeID]; dup {
			continue
		}
		seen[nodeID] = struct{}{}
		result = append(result, nodeID)
	}
	return result
}