package main import ( "bufio" "database/sql" "encoding/json" "errors" "flag" "fmt" "github.com/elliotchance/phpserialize" _ "github.com/go-sql-driver/mysql" "log" "net" "net/textproto" "os" "sort" "sync" "time" ) // Globals... var sockpath *string var dsn *string var db *sql.DB type metadataValues struct { Values []string `php:"v"` } type metadataObject struct { lastHit time.Time entries map[string]metadataValues lock sync.RWMutex loading sync.Once } type MetadataCache struct { objects map[int64]metadataObject lock sync.RWMutex objectType string sizeLimit int } type age struct { key int64 lastHit time.Time } func (c *MetadataCache) Purge() { if len(c.objects) < c.sizeLimit { // We aren't full so let's not run return } var ages []age for x := range c.objects { var a age a.key = x a.lastHit = c.objects[x].lastHit ages = append(ages, a) } // Sort by age sort.Slice(ages, func(a, b int) bool { return ages[a].lastHit.After(ages[b].lastHit) }) // Now we remove all entries from the 90% age mark up purgeto := int(float64(c.sizeLimit) * 0.9) for x := purgeto; x < len(ages); x++ { delete(c.objects, ages[x].key) } } func (c *MetadataCache) Delete(o int64) { // We could be selective about what we purge but let's see if it matters before we do that c.lock.Lock() delete(c.objects, o) c.lock.Unlock() } func (c *MetadataCache) Get(o int64, k string) ([]string, error) { // Check cache for entry c.lock.RLock() object, ok := c.objects[o] if !ok { // Object is not in the cache so let's load it up var tmpobject metadataObject c.objects[o] = tmpobject tmpobject.loading.Do(func() { // Only do this once even if concurrent requests come in tmpobject.entries = loadDbEntries(c.objectType, o) }) object = tmpobject } c.lock.RUnlock() object.lock.RLock() entries, ok := object.entries[k] object.lock.RUnlock() object.lock.Lock() object.lastHit = time.Now() object.lock.Unlock() c.lock.Lock() c.objects[o] = object c.lock.Unlock() if !ok { // Value does not exist! Send it back. return nil, errors.New("Value not found") } return entries.Values, nil } func loadDbEntries(ot string, id int64) map[string]metadataValues { log.Println("loadDbEntries") var entries map[string]metadataValues entries = make(map[string]metadataValues) var table string var column string if ot == "u" { table = "wp_usermeta" column = "user_id" } else if ot == "p" { table = "wp_postmeta" column = "post_id" } else { log.Printf("Invalid object type: %s", ot) return entries } query := fmt.Sprintf("select meta_key, meta_value from %s where %s = ?", table, column) rows, err := db.Query(query, id) if err != nil { log.Printf("db.Query: %s\n", err.Error()) return entries } for rows.Next() { var key string var value string rows.Scan(&key, &value) values, _ := entries[key] values.Values = append(values.Values, value) entries[key] = values } return entries } func init() { // Get database and socket running dsn = flag.String("dsn", "", "Database connection string") sockpath = flag.String("sock", "", "Unix socket path") flag.Parse() var err error db, err = sql.Open("mysql", *dsn) if err != nil { log.Fatalf("sql.Open: %s", err.Error()) } err = db.Ping() if err != nil { log.Fatalf("sql.Ping: %s", err.Error()) } } func main() { var UC MetadataCache UC.objects = make(map[int64]metadataObject) UC.sizeLimit = 100 UC.objectType = "u" var PC MetadataCache PC.objects = make(map[int64]metadataObject) PC.sizeLimit = 100 PC.objectType = "p" if err := os.RemoveAll(*sockpath); err != nil { log.Fatal(err) } unixListener, err := net.Listen("unix", *sockpath) if err != nil { log.Fatal(err) } os.Chmod(*sockpath, 0777) for { conn, err := unixListener.Accept() if err != nil { // handle error } go handleConnection(conn, UC, PC) } } type cachecommand struct { ObjectType string `json:"t"` ObjectId int64 `json:"i"` Key string `json:"k"` Command string `json:"c"` } func handleConnection(conn net.Conn, UC MetadataCache, PC MetadataCache) { log.Println("handleConnection started") var m *MetadataCache reader := bufio.NewReader(conn) tp := textproto.NewReader(reader) num:=0 for { buf, err := tp.ReadLine() if err != nil { conn.Close() return } num++ var c cachecommand err = json.Unmarshal([]byte(buf), &c) if err != nil { log.Printf("json.Unmarshal: %s returned %s", buf, err.Error()) } var values []string if c.ObjectType == "u" { m = &UC } else if c.ObjectType == "p" { m = &PC } if c.Command == "g" { values, err = m.Get(c.ObjectId, c.Key) if err != nil { conn.Write([]byte("404")) } else { p, err := phpserialize.Marshal(values, nil) if err != nil { log.Fatalf("phpserialize.Marshal: %s", err.Error()) } conn.Write(p) } } else if c.Command == "d" { m.Delete(c.ObjectId) conn.Write([]byte("200")) } conn.Write([]byte("\n")) } log.Printf("Ended connection after %d queries", num) }