// wp-metacache/main.go
//
// 229 lines
// 4.7 KiB
// Go

package main
import (
"database/sql"
"encoding/json"
"errors"
"flag"
"github.com/elliotchance/phpserialize"
_ "github.com/go-sql-driver/mysql"
"log"
"net"
"os"
"sort"
"sync"
"time"
)
// Package-level configuration and shared state.
var (
	// sockpath holds the -sock flag: the Unix socket path main listens on.
	sockpath *string
	// dsn holds the -dsn flag: the MySQL connection string.
	dsn *string
	// db is the database handle that loadDbEntries queries.
	db *sql.DB
)
type (
	// metadataValues holds every meta_value row read for one meta_key, in the
	// order the rows were read.
	metadataValues struct {
		Values []string `php:"v"`
	}

	// metadataObject is the cached metadata of a single user or post.
	//
	// NOTE(review): MetadataCache.objects stores these by value, so lock and
	// loading are copied on every map read/write and cannot coordinate
	// goroutines; the map would need to hold *metadataObject for them to work.
	metadataObject struct {
		lastHit time.Time                 // time of the most recent Get; Purge evicts by this
		entries map[string]metadataValues // meta_key -> values
		lock    sync.RWMutex
		loading sync.Once
	}

	// MetadataCache caches metadata for one object type, keyed by object ID.
	MetadataCache struct {
		objects    map[int64]metadataObject
		lock       sync.RWMutex // guards objects
		objectType string       // "u" or "p"; selects the table in loadDbEntries
		sizeLimit  int          // object count at which Purge starts evicting
	}

	// age pairs an object ID with its last hit time so Purge can sort by recency.
	age struct {
		key     int64
		lastHit time.Time
	}
)
// Purge evicts the least recently hit objects once the cache holds at least
// sizeLimit of them. It keeps only the newest 90% of sizeLimit.
//
// Purge does not take c.lock; it mutates c.objects directly. Callers that
// share the cache across goroutines must hold c.lock for writing.
func (c *MetadataCache) Purge() {
	if len(c.objects) < c.sizeLimit {
		// Still below the limit: nothing to evict.
		return
	}

	byRecency := make([]age, 0, len(c.objects))
	for key := range c.objects {
		byRecency = append(byRecency, age{key: key, lastHit: c.objects[key].lastHit})
	}

	// Most recently hit first, so the tail of the slice holds the victims.
	sort.Slice(byRecency, func(i, j int) bool {
		return byRecency[i].lastHit.After(byRecency[j].lastHit)
	})

	keep := int(float64(c.sizeLimit) * 0.9)
	for _, victim := range byRecency[keep:] {
		delete(c.objects, victim.key)
	}
}
// Delete drops everything cached for object o, so the next Get for it reloads
// from the database. Deleting an object that is not cached is a no-op.
func (c *MetadataCache) Delete(o int64) {
	// The whole object is invalidated rather than single keys; finer-grained
	// invalidation can be added if it ever matters.
	c.lock.Lock()
	defer c.lock.Unlock()
	delete(c.objects, o)
}
// Get returns the values stored under meta key k for object o.
//
// On a miss, Get loads all of the object's metadata via loadDbEntries, caches
// it, and calls Purge to keep the cache near sizeLimit. Every call, hit or
// miss, refreshes the object's lastHit, so Purge evicts the least recently
// used objects first. A non-nil error means the object has no key k.
//
// Only c.lock is used for synchronization. Objects are stored by value in
// c.objects, so their own lock/loading fields are per-copy and cannot guard
// anything; they are deliberately left unused here.
func (c *MetadataCache) Get(o int64, k string) ([]string, error) {
	log.Printf("Get %d, %s", o, k)

	// A hit still writes lastHit, so it needs the write lock. Entry maps are
	// never modified after loading, so they are safe to read once unlocked.
	c.lock.Lock()
	entries := c.objects[o].entries
	if entries != nil {
		c.objects[o] = metadataObject{lastHit: time.Now(), entries: entries}
	}
	c.lock.Unlock()

	if entries == nil {
		// Miss: run the query without holding the lock so a slow database
		// does not stall requests for other objects. Concurrent misses on the
		// same object may each query; the first result stored wins.
		loaded := loadDbEntries(c.objectType, o)

		c.lock.Lock()
		if entries = c.objects[o].entries; entries == nil {
			entries = loaded
		}
		c.objects[o] = metadataObject{lastHit: time.Now(), entries: entries}
		// Purge expects c.lock to be held. The object just stored has the
		// newest lastHit, so Purge keeps it.
		c.Purge()
		c.lock.Unlock()
	}

	values, ok := entries[k]
	if !ok {
		return nil, errors.New("Value not found")
	}
	return values.Values, nil
}
// loadDbEntries reads every meta_key/meta_value row for object id and groups
// the values by key, in the order the rows are returned. The table depends on
// ot: "u" uses wp_usermeta (by user_id) and "p" uses wp_postmeta (by post_id).
//
// It never returns nil. For an unknown object type or a database error, the
// problem is logged and whatever was read so far (possibly nothing) is
// returned.
func loadDbEntries(ot string, id int64) map[string]metadataValues {
	log.Println("loadDbEntries")
	entries := make(map[string]metadataValues)

	// Identifiers (table/column names) cannot be bound as "?" parameters, so
	// each query is a fixed literal chosen here; only the id is bound.
	var query string
	switch ot {
	case "u":
		query = "select meta_key, meta_value from wp_usermeta where user_id = ?"
	case "p":
		query = "select meta_key, meta_value from wp_postmeta where post_id = ?"
	default:
		log.Printf("Invalid object type: %s", ot)
		return entries
	}

	rows, err := db.Query(query, id)
	if err != nil {
		log.Printf("db.Query: %s\n", err.Error())
		return entries
	}
	// Release the pooled connection even if iteration stops early.
	defer rows.Close()

	for rows.Next() {
		// NOTE(review): these columns are assumed nullable (as in the stock
		// WordPress schema — verify). Scanning NULL into a plain string fails,
		// so NullString is used and NULL is read as "".
		var key, value sql.NullString
		if err := rows.Scan(&key, &value); err != nil {
			log.Printf("rows.Scan: %s", err)
			continue
		}
		values := entries[key.String]
		values.Values = append(values.Values, value.String)
		entries[key.String] = values
	}
	if err := rows.Err(); err != nil {
		log.Printf("rows.Err: %s", err)
	}
	return entries
}
// init parses the command-line flags and opens the shared database handle.
// It exits the process if the handle cannot be opened or pinged.
//
// NOTE(review): calling flag.Parse during package initialization interferes
// with `go test` flag handling; consider moving flag parsing and DB setup
// into main.
func init() {
	dsn = flag.String("dsn", "", "Database connection string")
	sockpath = flag.String("sock", "", "Unix socket path")
	flag.Parse()

	// Assign to the package-level db. `db, err := ...` would declare a new
	// local db and leave the global nil for loadDbEntries.
	var err error
	db, err = sql.Open("mysql", *dsn)
	if err != nil {
		log.Fatalf("sql.Open: %s", err.Error())
	}
	if err = db.Ping(); err != nil {
		log.Fatalf("sql.Ping: %s", err.Error())
	}
}
// main builds one cache for user meta ("u") and one for post meta ("p"),
// listens on the Unix socket named by -sock, and serves every accepted
// connection on its own goroutine.
func main() {
	// Maximum objects per cache before Purge starts evicting.
	const cacheSize = 100

	if *sockpath == "" {
		log.Fatal("-sock is required")
	}

	// The caches are shared by all connection goroutines, so they are passed
	// by pointer. A copied MetadataCache would carry its own copy of lock and
	// would no longer exclude the other goroutines.
	uc := &MetadataCache{objects: make(map[int64]metadataObject), sizeLimit: cacheSize, objectType: "u"}
	pc := &MetadataCache{objects: make(map[int64]metadataObject), sizeLimit: cacheSize, objectType: "p"}

	// Remove a socket file left by a previous run; Listen fails if the path exists.
	if err := os.RemoveAll(*sockpath); err != nil {
		log.Fatal(err)
	}
	unixListener, err := net.Listen("unix", *sockpath)
	if err != nil {
		log.Fatal(err)
	}
	// Let clients running as other users connect.
	if err := os.Chmod(*sockpath, 0777); err != nil {
		log.Printf("os.Chmod: %s", err)
	}

	for {
		conn, err := unixListener.Accept()
		if err != nil {
			// Never hand a nil conn to a handler. Pause briefly so a
			// persistent failure (e.g. fd exhaustion) doesn't spin the CPU.
			log.Printf("Accept: %s", err)
			time.Sleep(100 * time.Millisecond)
			continue
		}
		go handleConnection(conn, uc, pc)
	}
}

// cachecommand is one JSON request read from a client connection.
type cachecommand struct {
	ObjectType string `json:"t"` // "u" (user cache) or "p" (post cache)
	ObjectID   int64  `json:"i"` // ID of the user or post
	Key        string `json:"k"` // meta key to fetch; used by "g"
	Command    string `json:"c"` // "g" = get, "d" = delete (invalidate)
}

// handleConnection serves a single request on conn and then closes it.
//
// The request is one JSON-encoded cachecommand. Replies are raw bytes:
//   - "g": the PHP-serialized values, or "404" if Get returns an error
//   - "d": "200" once the object has been invalidated
//   - "400" for an undecodable request or an unknown object type/command
//   - "500" if the values cannot be serialized
func handleConnection(conn net.Conn, uc *MetadataCache, pc *MetadataCache) {
	log.Println("handleConnection started")
	defer conn.Close()

	reply := func(b []byte) {
		if _, err := conn.Write(b); err != nil {
			log.Printf("conn.Write: %s", err)
		}
	}

	// Decode keeps reading until a complete JSON value has arrived, however
	// the client's bytes are split across reads.
	var c cachecommand
	if err := json.NewDecoder(conn).Decode(&c); err != nil {
		log.Printf("json.Decode: %s", err)
		reply([]byte("400"))
		return
	}
	log.Printf("JSON got: %#v", c)

	var m *MetadataCache
	switch c.ObjectType {
	case "u":
		m = uc
	case "p":
		m = pc
	default:
		log.Printf("unknown object type %q", c.ObjectType)
		reply([]byte("400"))
		return
	}

	switch c.Command {
	case "g":
		values, err := m.Get(c.ObjectID, c.Key)
		if err != nil {
			reply([]byte("404"))
			return
		}
		p, err := phpserialize.Marshal(values, nil)
		if err != nil {
			// One bad value must not take down the whole cache server.
			log.Printf("phpserialize.Marshal: %s", err.Error())
			reply([]byte("500"))
			return
		}
		reply(p)
	case "d":
		m.Delete(c.ObjectID)
		reply([]byte("200"))
	default:
		log.Printf("unknown command %q", c.Command)
		reply([]byte("400"))
	}
}