package main import ( "database/sql" "encoding/json" "errors" "flag" "github.com/elliotchance/phpserialize" _ "github.com/go-sql-driver/mysql" "log" "net" "os" "sort" "sync" "time" ) // Globals... var sockpath *string var dsn *string var db *sql.DB type metadataValues struct { Values []string `php:"v"` } type metadataObject struct { lastHit time.Time entries map[string]metadataValues lock sync.RWMutex loading sync.Once } type MetadataCache struct { objects map[int64]metadataObject lock sync.RWMutex objectType string sizeLimit int } type age struct { key int64 lastHit time.Time } func (c *MetadataCache) Purge() { if len(c.objects) < c.sizeLimit { // We aren't full so let's not run return } var ages []age for x := range c.objects { var a age a.key = x a.lastHit = c.objects[x].lastHit ages = append(ages, a) } // Sort by age sort.Slice(ages, func(a, b int) bool { return ages[a].lastHit.After(ages[b].lastHit) }) // Now we remove all entries from the 90% age mark up purgeto := int(float64(c.sizeLimit) * 0.9) for x := purgeto; x < len(ages); x++ { delete(c.objects, ages[x].key) } } func (c *MetadataCache) Delete(o int64) { // We could be selective about what we purge but let's see if it matters before we do that c.lock.Lock() delete(c.objects, o) c.lock.Unlock() } func (c *MetadataCache) Get(o int64, k string) ([]string, error) { // Check cache for entry c.lock.RLock() defer c.lock.Unlock() object, ok := c.objects[o] if !ok { // Object is not in the cache so let's load it up var tmpobject metadataObject c.objects[o] = tmpobject c.lock.Unlock() tmpobject.loading.Do(func() { // Only do this once even if concurrent requests come in tmpobject.entries = loadDbEntries(c.objectType, o) }) object = tmpobject } object.lock.RLock() defer object.lock.Unlock() // Load entries from db entries, ok := object.entries[k] object.lock.Unlock() object.lock.Lock() object.lastHit=time.Now() object.lock.Unlock() c.lock.Lock() c.objects[o]=object c.lock.Unlock() if !ok { // Value does not exist! Send it back. return nil, errors.New("Value not found") } return entries.Values, nil } func loadDbEntries(ot string, id int64) map[string]metadataValues { var entries map[string]metadataValues entries=make (map[string]metadataValues) var table string var column string if ot=="u" { table="wp_usermeta" column="user_id" } else if ot=="p" { table="wp_postmeta" column="post_id" } else { log.Printf("Invalid object type: %s", ot) return entries } rows,err:=db.Query("select meta_key, meta_value from ? where ? = ?", table, column, id) if err != nil { log.Printf("db.Query: %s\n", err.Error()) return entries } for rows.Next() { var key string var value string rows.Scan(&key, &value) values,_:=entries[key] values.Values=append(values.Values, value) entries[key]=values } return entries } func init() { // Get database and socket running dsn = flag.String("dsn", "", "Database connection string") sockpath = flag.String("sock", "", "Unix socket path") flag.Parse() db, err := sql.Open("mysql", *dsn) if err != nil { log.Fatalf("sql.Open: %s", err.Error()) } err = db.Ping() if err != nil { log.Fatalf("sql.Ping: %s", err.Error()) } } func main() { var UC MetadataCache UC.objects = make(map[int64]metadataObject) UC.sizeLimit = 100 UC.objectType="u" var PC MetadataCache PC.objects = make(map[int64]metadataObject) PC.sizeLimit = 100 PC.objectType="p" if err := os.RemoveAll(*sockpath); err != nil { log.Fatal(err) } unixListener, err := net.Listen("unix", *sockpath) if err!=nil { log.Fatal(err) } os.Chmod(*sockpath, 0777) for { conn, err := unixListener.Accept() if err != nil { // handle error } go handleConnection(conn, UC, PC) } } type cachecommand struct { ObjectType string `json:"t"` ObjectId int64 `json:"i"` Key string `json:"k"` Command string `json:"c"` } func handleConnection(conn net.Conn, UC MetadataCache, PC MetadataCache) { var buf []byte var m *MetadataCache _, err := conn.Read(buf) var c cachecommand err = json.Unmarshal(buf, &c) var values []string if c.ObjectType == "u" { m = &UC } else if c.ObjectType == "p" { m = &PC } if c.Command == "g" { values, err = m.Get(c.ObjectId, c.Key) if err != nil { conn.Write([]byte("404")) conn.Close() return } p, err := phpserialize.Marshal(values, nil) if err != nil { log.Fatalf("phpserialize.Marshal: %s", err.Error()) } conn.Write(p) conn.Close() return } if c.Command=="d" { m.Delete(c.ObjectId) conn.Write([]byte("200")) conn.Close() return } }