package pool
import (
"context"
"errors"
"sync"
"time"
)
// Sentinel errors returned by Pool.Get.
var (
	// ErrPoolClosed is returned by Get once the pool has been closed.
	ErrPoolClosed = errors.New("pool is closed")

	// ErrPoolTimeout is returned by Get when its context is done before a
	// connection becomes available.
	ErrPoolTimeout = errors.New("pool timeout")
)
// Conn is a connection managed by a Pool. Concrete connections are produced
// by a Factory.
type Conn interface {
	// IsAlive reports whether the connection is still usable. The pool asks
	// before handing out an idle connection and when one is returned with Put.
	IsAlive() bool

	// Close releases the connection. The pool does not inspect the error.
	Close() error
}

// Factory opens a brand-new connection. Get calls it, without holding the
// pool's lock, when no idle connection is available and the open limit allows.
type Factory func() (conn Conn, err error)
// Pool is a bounded pool of Conn values. Create one with New.
type Pool struct {
	// Configuration, fixed by New.
	factory     Factory       // opens new connections
	maxOpen     int           // upper bound on numOpen
	maxIdle     int           // upper bound on len(idle)
	idleTimeout time.Duration // idle entries at least this old are evicted by cleanup

	// requests carries wake-up signals from Put to Get calls that are blocked
	// at the maxOpen limit. New buffers it to maxOpen; Put's sends are
	// non-blocking.
	requests chan struct{}

	mu      sync.Mutex // guards the fields below
	idle    []poolConn // idle connections; Get takes from the end (LIFO)
	numOpen int        // open connections: idle, checked out, or being dialed
	closed  bool       // set by Close
}
// poolConn is an entry in the pool's idle list.
type poolConn struct {
	// createdAt is stamped by Put when the connection is placed on the idle
	// list (not when it was dialed). cleanup compares it against idleTimeout.
	createdAt time.Time

	// conn is the pooled connection itself.
	conn Conn
}
// New builds a Pool that opens connections with factory.
//
// maxOpen caps the number of open connections (idle plus checked out). maxIdle
// caps how many returned connections are kept for reuse. idleTimeout is how
// long an idle connection may sit before the cleanup goroutine evicts it.
//
// New starts that cleanup goroutine. It stops on its first tick after Close.
// A negative maxOpen or maxIdle makes the underlying make call panic.
func New(factory Factory, maxOpen, maxIdle int, idleTimeout time.Duration) *Pool {
	pool := new(Pool)
	pool.factory = factory
	pool.maxOpen = maxOpen
	pool.maxIdle = maxIdle
	pool.idleTimeout = idleTimeout
	pool.idle = make([]poolConn, 0, maxIdle)
	pool.requests = make(chan struct{}, maxOpen)

	go pool.cleanup()
	return pool
}
// Get returns a connection from the pool.
//
// It first reuses the most recently returned idle connection that still
// reports IsAlive, closing any dead ones it pops. If none is usable and fewer
// than maxOpen connections are open, it reserves a slot and dials a new one
// with the factory, outside the lock. Otherwise it blocks until it is woken
// through the pool's requests channel and then retries, or until ctx is done.
//
// ctx is only consulted while waiting. An idle or newly dialed connection is
// returned even if ctx is already done, and the factory call is not cancelled
// by ctx.
//
// Errors: ErrPoolClosed if the pool is closed, ErrPoolTimeout if ctx is done
// while waiting, or whatever error the factory returns.
func (p *Pool) Get(ctx context.Context) (Conn, error) {
	// Retry in a loop rather than recursing. A waiter can be woken many times
	// (stale or stolen signals) before it wins a connection.
	for {
		p.mu.Lock()
		if p.closed {
			p.mu.Unlock()
			return nil, ErrPoolClosed
		}

		// Prefer an idle connection (LIFO), discarding dead ones.
		for len(p.idle) > 0 {
			last := len(p.idle) - 1
			pc := p.idle[last]
			p.idle[last] = poolConn{} // don't keep the conn reachable via the backing array
			p.idle = p.idle[:last]
			if pc.conn.IsAlive() {
				p.mu.Unlock()
				return pc.conn, nil
			}
			p.numOpen--
			pc.conn.Close()
		}

		// Reserve a slot under the lock, then dial without holding it.
		if p.numOpen < p.maxOpen {
			p.numOpen++
			p.mu.Unlock()

			conn, err := p.factory()
			if err != nil {
				p.mu.Lock()
				p.numOpen--
				// The reserved slot is free again. Another Get may have blocked
				// on the limit while we were dialing, so wake it. Skip the send
				// once the pool is closed, since there is nothing to hand out.
				if !p.closed {
					select {
					case p.requests <- struct{}{}:
					default:
					}
				}
				p.mu.Unlock()
				return nil, err
			}
			return conn, nil
		}
		p.mu.Unlock()

		// At the open limit: wait for a wake-up, then re-check everything.
		select {
		case <-p.requests:
		case <-ctx.Done():
			return nil, ErrPoolTimeout
		}
	}
}
// Put hands conn back to the pool after use.
//
// If the pool is closed, conn is closed and its slot released. Otherwise a
// live conn is kept on the idle list when there is room (fewer than maxIdle
// idle connections). A dead conn, or one that does not fit, is closed and its
// slot released. In every non-closed case one blocked Get is woken, because
// either an idle connection or a free slot has just become available.
//
// NOTE(review): Put assumes conn came from Get on this pool and is returned
// only once. Every discarded connection decrements the open count, and
// nothing here checks ownership.
func (p *Pool) Put(conn Conn) {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.closed {
		p.numOpen--
		conn.Close()
		return
	}

	if conn.IsAlive() && len(p.idle) < p.maxIdle {
		p.idle = append(p.idle, poolConn{
			conn:      conn,
			createdAt: time.Now(),
		})
	} else {
		// Dead or surplus connection: drop it and free its slot.
		p.numOpen--
		conn.Close()
	}

	// Signal a waiting Get. This must also happen for dead connections, whose
	// freed slot a waiter can use. The send is non-blocking; the buffer
	// absorbs the signal if nobody is waiting yet.
	select {
	case p.requests <- struct{}{}:
	default:
	}
}
// Close shuts the pool down. Calling it more than once is harmless.
//
// Idle connections are closed. Checked-out connections stay open until they
// are handed back with Put, which closes them because the pool is closed.
// Every Get currently blocked waiting for a connection is woken and returns
// ErrPoolClosed, as does every later Get. The cleanup goroutine exits on its
// next tick.
func (p *Pool) Close() {
	p.mu.Lock()
	if p.closed {
		p.mu.Unlock()
		return
	}
	p.closed = true
	idle := p.idle
	p.idle = nil
	p.numOpen -= len(idle)

	// Closing requests wakes every blocked Get at once, and each then observes
	// closed. This is safe because sends on requests only happen under mu
	// while the pool is not closed.
	close(p.requests)
	p.mu.Unlock()

	// Close connections outside the lock so slow Close calls don't stall
	// concurrent Get/Put. These entries are no longer reachable from the pool.
	for _, pc := range idle {
		pc.conn.Close()
	}
}
// cleanup evicts stale idle connections. New runs it in its own goroutine.
//
// Once a minute it removes every idle entry that has been idle for at least
// idleTimeout, measured from poolConn.createdAt. If idleTimeout <= 0, every
// idle connection is evicted on each tick. The goroutine returns on the
// first tick after the pool is closed, so it can outlive Close by up to a
// minute.
func (p *Pool) cleanup() {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()

	for range ticker.C {
		p.mu.Lock()
		if p.closed {
			p.mu.Unlock()
			return
		}

		now := time.Now()
		var expired []Conn
		kept := p.idle[:0] // filter in place, reusing the backing array
		for _, pc := range p.idle {
			if now.Sub(pc.createdAt) < p.idleTimeout {
				kept = append(kept, pc)
				continue
			}
			p.numOpen--
			expired = append(expired, pc.conn)
		}
		// Zero the vacated tail. Otherwise the backing array keeps evicted
		// connections reachable until those slots are overwritten.
		for i := len(kept); i < len(p.idle); i++ {
			p.idle[i] = poolConn{}
		}
		p.idle = kept
		p.mu.Unlock()

		// Close outside the lock. Close may block on I/O and would otherwise
		// stall every Get/Put.
		for _, c := range expired {
			c.Close()
		}
	}
}