From ad7b0ee0eb81501859f5b0a7bb3972c732fb4fee Mon Sep 17 00:00:00 2001 From: Josh Holloway Date: Wed, 8 Jul 2015 12:31:02 +0100 Subject: [PATCH] Added MaxLength support port of: https://github.com/boj/redistore/pull/5 --- pgstore.go | 19 ++++++++++++++++--- pgstore_test.go | 19 +++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/pgstore.go b/pgstore.go index a91789a..dcef38f 100644 --- a/pgstore.go +++ b/pgstore.go @@ -3,13 +3,14 @@ package pgstore import ( "database/sql" "encoding/base32" + "net/http" + "strings" + "time" + "github.com/coopernurse/gorp" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" _ "github.com/lib/pq" - "net/http" - "strings" - "time" ) type PGStore struct { @@ -123,6 +124,18 @@ func (db *PGStore) Save(r *http.Request, w http.ResponseWriter, session *session return nil } +// MaxLength restricts the maximum length of new sessions to l. +// If l is 0 there is no limit to the size of a session, use with caution. +// The default for a new PGStore is 4096. PostgreSQL allows for max +// value sizes of up to 1GB (http://www.postgresql.org/docs/current/interactive/datatype-character.html) +func (s *PGStore) MaxLength(l int) { + for _, c := range s.Codecs { + if codec, ok := c.(*securecookie.SecureCookie); ok { + codec.MaxLength(l) + } + } +} + //load fetches a session by ID from the database and decodes its content into session.Values func (db *PGStore) load(session *sessions.Session) error { var s Session diff --git a/pgstore_test.go b/pgstore_test.go index 2adbb7a..2cafe06 100644 --- a/pgstore_test.go +++ b/pgstore_test.go @@ -1,6 +1,7 @@ package pgstore import ( + "encoding/base64" "net/http" "os" "testing" @@ -99,6 +100,24 @@ func TestPGStore(t *testing.T) { t.Fatal("Retrieved session had wrong value in round 3:", session.Values["counter"]) } + // ROUND 3 - Increase max length + req, err = http.NewRequest("GET", "http://www.example.com", nil) + if err != nil { + t.Fatal("failed to create round 3 request", err) + } + + req.AddCookie(sessions.NewCookie(session.Name(), encoded, session.Options)) + session, err = ss.New(req, "my session") + session.Values["big"] = make([]byte, base64.StdEncoding.DecodedLen(4096*2)) + + if err = ss.Save(req, headerOnlyResponseWriter(m), session); err == nil { + t.Fatal("expected an error, got nil") + } + + ss.MaxLength(4096 * 3) // A bit more than the value size to account for encoding overhead. + if err = ss.Save(req, headerOnlyResponseWriter(m), session); err != nil { + t.Fatal("Failed to save session:", err.Error()) + } } func TestSessionOptionsAreUniquePerSession(t *testing.T) {