package pgstore import ( "database/sql" "encoding/base32" "errors" "fmt" "net/http" "strings" "time" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" // Include the pq postgres driver. _ "github.com/lib/pq" ) // PGStore represents the currently configured session store. type PGStore struct { Codecs []securecookie.Codec Options *sessions.Options Path string DbPool *sql.DB } // PGSession type type PGSession struct { ID int64 Key string Data string CreatedOn time.Time ModifiedOn time.Time ExpiresOn time.Time } // NewPGStore creates a new PGStore instance and a new database/sql pool. // This will also create in the database the schema needed by pgstore. func NewPGStore(dbURL string, keyPairs ...[]byte) (*PGStore, error) { db, err := sql.Open("postgres", dbURL) if err != nil { // Ignore and return nil. return nil, err } return NewPGStoreFromPool(db, keyPairs...) } // NewPGStoreFromPool creates a new PGStore instance from an existing // database/sql pool. // This will also create the database schema needed by pgstore. func NewPGStoreFromPool(db *sql.DB, keyPairs ...[]byte) (*PGStore, error) { dbStore := &PGStore{ Codecs: securecookie.CodecsFromPairs(keyPairs...), Options: &sessions.Options{ Path: "/", MaxAge: 86400 * 30, }, DbPool: db, } // Create table if it doesn't exist err := dbStore.createSessionsTable() if err != nil { return nil, err } return dbStore, nil } // Close closes the database connection. func (db *PGStore) Close() { db.DbPool.Close() } // Get Fetches a session for a given name after it has been added to the // registry. func (db *PGStore) Get(r *http.Request, name string) (*sessions.Session, error) { return sessions.GetRegistry(r).Get(db, name) } // New returns a new session for the given name without adding it to the registry. func (db *PGStore) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(db, name) if session == nil { return nil, nil } opts := *db.Options session.Options = &(opts) session.IsNew = true var err error if c, errCookie := r.Cookie(name); errCookie == nil { err = securecookie.DecodeMulti(name, c.Value, &session.ID, db.Codecs...) if err == nil { err = db.load(session) if err == nil { session.IsNew = false } } } db.MaxAge(db.Options.MaxAge) return session, err } // Save saves the given session into the database and deletes cookies if needed func (db *PGStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { // Set delete if max-age is < 0 if session.Options.MaxAge < 0 { if err := db.destroy(session); err != nil { return err } http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) return nil } if session.ID == "" { // Generate a random session ID key suitable for storage in the DB session.ID = strings.TrimRight( base32.StdEncoding.EncodeToString( securecookie.GenerateRandomKey(32), ), "=") } if err := db.save(session); err != nil { return err } // Keep the session ID key in a cookie so it can be looked up in DB later. encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, db.Codecs...) if err != nil { return err } http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options)) return nil } // MaxLength restricts the maximum length of new sessions to l. // If l is 0 there is no limit to the size of a session, use with caution. // The default for a new PGStore is 4096. PostgreSQL allows for max // value sizes of up to 1GB (http://www.postgresql.org/docs/current/interactive/datatype-character.html) func (db *PGStore) MaxLength(l int) { for _, c := range db.Codecs { if codec, ok := c.(*securecookie.SecureCookie); ok { codec.MaxLength(l) } } } // MaxAge sets the maximum age for the store and the underlying cookie // implementation. Individual sessions can be deleted by setting Options.MaxAge // = -1 for that session. func (db *PGStore) MaxAge(age int) { db.Options.MaxAge = age // Set the maxAge for each securecookie instance. for _, codec := range db.Codecs { if sc, ok := codec.(*securecookie.SecureCookie); ok { sc.MaxAge(age) } } } // load fetches a session by ID from the database and decodes its content // into session.Values. func (db *PGStore) load(session *sessions.Session) error { var s PGSession err := db.selectOne(&s, session.ID) if err != nil { return err } return securecookie.DecodeMulti(session.Name(), string(s.Data), &session.Values, db.Codecs...) } // save writes encoded session.Values to a database record. // writes to http_sessions table by default. func (db *PGStore) save(session *sessions.Session) error { encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, db.Codecs...) if err != nil { return err } crOn := session.Values["created_on"] exOn := session.Values["expires_on"] var expiresOn time.Time createdOn, ok := crOn.(time.Time) if !ok { createdOn = time.Now() } if exOn == nil { expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge)) } else { expiresOn = exOn.(time.Time) if expiresOn.Sub(time.Now().Add(time.Second*time.Duration(session.Options.MaxAge))) < 0 { expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge)) } } s := PGSession{ Key: session.ID, Data: encoded, CreatedOn: createdOn, ExpiresOn: expiresOn, ModifiedOn: time.Now(), } if session.IsNew { return db.insert(&s) } return db.update(&s) } // Delete session func (db *PGStore) destroy(session *sessions.Session) error { _, err := db.DbPool.Exec("DELETE FROM http_sessions WHERE key = $1", session.ID) return err } func (db *PGStore) createSessionsTable() error { stmt := `CREATE TABLE IF NOT EXISTS http_sessions ( id BIGSERIAL PRIMARY KEY, key BYTEA, data BYTEA, created_on TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, modified_on TIMESTAMPTZ, expires_on TIMESTAMPTZ);` _, err := db.DbPool.Exec(stmt) if err != nil { msg := fmt.Sprintf("Unable to create http_sessions table in the database: %s\n", err.Error()) return errors.New(msg) } return nil } func (db *PGStore) selectOne(s *PGSession, key string) error { stmt := "SELECT id, key, data, created_on, modified_on, expires_on FROM http_sessions WHERE key = $1" err := db.DbPool.QueryRow(stmt, key).Scan(&s.ID, &s.Key, &s.Data, &s.CreatedOn, &s.ModifiedOn, &s.ExpiresOn) if err != nil { msg := fmt.Sprintf("Unable to find session in the database: %s\n", err.Error()) return errors.New(msg) } return nil } func (db *PGStore) insert(s *PGSession) error { stmt := `INSERT INTO http_sessions (key, data, created_on, modified_on, expires_on) VALUES ($1, $2, $3, $4, $5)` _, err := db.DbPool.Exec(stmt, s.Key, s.Data, s.CreatedOn, s.ModifiedOn, s.ExpiresOn) return err } func (db *PGStore) update(s *PGSession) error { stmt := `UPDATE http_sessions SET data=$1, modified_on=$2, expires_on=$3 WHERE key=$4` _, err := db.DbPool.Exec(stmt, s.Data, s.ModifiedOn, s.ExpiresOn, s.Key) return err }