| // Copyright 2012 The Gorilla Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package sessions |
| |
| import ( |
| "context" |
| "encoding/gob" |
| "fmt" |
| "net/http" |
| "time" |
| ) |
| |
// flashesKey is the default key in Session.Values under which flash messages
// are stored. Flashes and AddFlash use it when the caller does not pass a key
// of its own.
const flashesKey = "_flash"
| |
| // Session -------------------------------------------------------------------- |
| |
| // NewSession is called by session stores to create a new session instance. |
| func NewSession(store Store, name string) *Session { |
| return &Session{ |
| Values: make(map[interface{}]interface{}), |
| store: store, |
| name: name, |
| Options: new(Options), |
| } |
| } |
| |
// Session stores the values and optional configuration for a session.
type Session struct {
	// The ID of the session, generated by stores. It should not be used for
	// user data.
	ID string
	// Values contains the user-data for the session. Flash messages added
	// with AddFlash are kept here too, as a []interface{} under the flash key.
	Values map[interface{}]interface{}
	// Options holds the configuration of the session. NewSession allocates a
	// zero-valued Options for every session it creates.
	Options *Options
	// IsNew is not written anywhere in this file.
	// NOTE(review): presumably set by stores to tell a freshly created
	// session from a loaded one -- confirm against the Store implementations.
	IsNew bool
	// store is the Store that Save delegates to; exposed through Store().
	store Store
	// name is the name the session was registered under; exposed through
	// Name().
	name string
}
| |
| // Flashes returns a slice of flash messages from the session. |
| // |
| // A single variadic argument is accepted, and it is optional: it defines |
| // the flash key. If not defined "_flash" is used by default. |
| func (s *Session) Flashes(vars ...string) []interface{} { |
| var flashes []interface{} |
| key := flashesKey |
| if len(vars) > 0 { |
| key = vars[0] |
| } |
| if v, ok := s.Values[key]; ok { |
| // Drop the flashes and return it. |
| delete(s.Values, key) |
| flashes = v.([]interface{}) |
| } |
| return flashes |
| } |
| |
| // AddFlash adds a flash message to the session. |
| // |
| // A single variadic argument is accepted, and it is optional: it defines |
| // the flash key. If not defined "_flash" is used by default. |
| func (s *Session) AddFlash(value interface{}, vars ...string) { |
| key := flashesKey |
| if len(vars) > 0 { |
| key = vars[0] |
| } |
| var flashes []interface{} |
| if v, ok := s.Values[key]; ok { |
| flashes = v.([]interface{}) |
| } |
| s.Values[key] = append(flashes, value) |
| } |
| |
| // Save is a convenience method to save this session. It is the same as calling |
| // store.Save(request, response, session). You should call Save before writing to |
| // the response or returning from the handler. |
| func (s *Session) Save(r *http.Request, w http.ResponseWriter) error { |
| return s.store.Save(r, w, s) |
| } |
| |
// Name returns the name used to register the session, i.e. the name given to
// NewSession or assigned by Registry.Get.
func (s *Session) Name() string {
	return s.name
}
| |
// Store returns the session store used to register the session: the store
// given to NewSession, or the one most recently passed to Registry.Get for
// this session.
func (s *Session) Store() Store {
	return s.store
}
| |
| // Registry ------------------------------------------------------------------- |
| |
// sessionInfo stores a session tracked by the registry, together with the
// error the store returned when that session was created.
type sessionInfo struct {
	s *Session // session obtained from Store.New
	e error    // error returned alongside it; handed back again by later Get calls
}
| |
// contextKey is the type used to store the registry in the context. Because
// the type is unexported, keys of this type cannot collide with context keys
// defined by other packages.
type contextKey int

// registryKey is the key used to store the registry in the context.
const registryKey contextKey = 0
| |
| // GetRegistry returns a registry instance for the current request. |
| func GetRegistry(r *http.Request) *Registry { |
| var ctx = r.Context() |
| registry := ctx.Value(registryKey) |
| if registry != nil { |
| return registry.(*Registry) |
| } |
| newRegistry := &Registry{ |
| request: r, |
| sessions: make(map[string]sessionInfo), |
| } |
| *r = *r.WithContext(context.WithValue(ctx, registryKey, newRegistry)) |
| return newRegistry |
| } |
| |
// Registry stores sessions used during a request.
type Registry struct {
	// request is the request the registry was created for; it is passed to
	// Store.New and Store.Save.
	request *http.Request
	// sessions maps a session name to the session (and creation error)
	// obtained for it by Get.
	sessions map[string]sessionInfo
}
| |
| // Get registers and returns a session for the given name and session store. |
| // |
| // It returns a new session if there are no sessions registered for the name. |
| func (s *Registry) Get(store Store, name string) (session *Session, err error) { |
| if !isCookieNameValid(name) { |
| return nil, fmt.Errorf("sessions: invalid character in cookie name: %s", name) |
| } |
| if info, ok := s.sessions[name]; ok { |
| session, err = info.s, info.e |
| } else { |
| session, err = store.New(s.request, name) |
| session.name = name |
| s.sessions[name] = sessionInfo{s: session, e: err} |
| } |
| session.store = store |
| return |
| } |
| |
| // Save saves all sessions registered for the current request. |
| func (s *Registry) Save(w http.ResponseWriter) error { |
| var errMulti MultiError |
| for name, info := range s.sessions { |
| session := info.s |
| if session.store == nil { |
| errMulti = append(errMulti, fmt.Errorf( |
| "sessions: missing store for session %q", name)) |
| } else if err := session.store.Save(s.request, w, session); err != nil { |
| errMulti = append(errMulti, fmt.Errorf( |
| "sessions: error saving session %q -- %v", name, err)) |
| } |
| } |
| if errMulti != nil { |
| return errMulti |
| } |
| return nil |
| } |
| |
| // Helpers -------------------------------------------------------------------- |
| |
// init registers []interface{} with encoding/gob. AddFlash stores flash
// messages in Session.Values as a []interface{}, and gob needs a concrete
// type to be registered before it can encode it inside an interface value.
// NOTE(review): presumably the stores gob-encode Values -- the encoding step
// itself is not visible in this file.
func init() {
	gob.Register([]interface{}{})
}
| |
| // Save saves all sessions used during the current request. |
| func Save(r *http.Request, w http.ResponseWriter) error { |
| return GetRegistry(r).Save(w) |
| } |
| |
| // NewCookie returns an http.Cookie with the options set. It also sets |
| // the Expires field calculated based on the MaxAge value, for Internet |
| // Explorer compatibility. |
| func NewCookie(name, value string, options *Options) *http.Cookie { |
| cookie := newCookieFromOptions(name, value, options) |
| if options.MaxAge > 0 { |
| d := time.Duration(options.MaxAge) * time.Second |
| cookie.Expires = time.Now().Add(d) |
| } else if options.MaxAge < 0 { |
| // Set it to the past to expire now. |
| cookie.Expires = time.Unix(1, 0) |
| } |
| return cookie |
| } |
| |
| // Error ---------------------------------------------------------------------- |
| |
// MultiError stores multiple errors.
//
// Borrowed from the App Engine SDK.
type MultiError []error

// Error summarises the collection: it returns the message of the first
// non-nil error, followed by a count of the remaining non-nil errors when
// there are any. Nil entries are skipped, and a collection without any
// non-nil error is reported as "(0 errors)".
func (m MultiError) Error() string {
	var first string
	count := 0
	for _, err := range m {
		if err == nil {
			continue
		}
		if count == 0 {
			first = err.Error()
		}
		count++
	}

	if count == 0 {
		return "(0 errors)"
	}
	if count == 1 {
		return first
	}
	if count == 2 {
		return first + " (and 1 other error)"
	}
	return fmt.Sprintf("%s (and %d other errors)", first, count-1)
}