blob: 859b4e0dccb8149c3be5a9565b5db2b89f2f81cf [file] [log] [blame]
// Copyright 2020 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controllers
import (
	"context"
	crand "crypto/rand"
	"fmt"
	"math/rand"
	"net/http"
	"os"
	"time"

	secretmanager "cloud.google.com/go/secretmanager/apiv1"
	"github.com/gorilla/sessions"
	log "github.com/sirupsen/logrus"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/google"
	secretmanagerpb "google.golang.org/genproto/googleapis/cloud/secretmanager/v1"
)
// Session settings.
const (
	// sessionName is the name used for every store.Get lookup in this file.
	sessionName = "changelog"
	// sessionKeyLength is not referenced in this part of the file.
	// NOTE(review): presumably the expected session key size — confirm.
	sessionKeyLength = 32
)

// OAuth state generation settings.
const (
	// oauthStateCharset is the alphabet random state characters are drawn from.
	oauthStateCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890"
	// oauthStateLength is the number of random characters that prefix the
	// return URL inside the OAuth state value.
	oauthStateLength = 16
)
// config is the OAuth2 configuration shared by the handlers in this file.
// ClientID, ClientSecret and RedirectURL are left at their zero value here
// and are filled in by init.
var config = &oauth2.Config{
	Endpoint: google.Endpoint,
	Scopes: []string{
		"https://www.googleapis.com/auth/gerritcodereview",
	},
}
var (
	// store is the cookie-backed session store; it is created in init from
	// the session secret.
	store *sessions.CookieStore

	// projectID is the project whose secrets are read by getSecret.
	projectID = os.Getenv("COS_CHANGELOG_PROJECT_ID")

	// The variables below hold secret names, not the secret values
	// themselves; init resolves each of them through getSecret. That includes
	// redirectURL, which names the secret containing the OAuth callback URL.
	clientIDName      = os.Getenv("COS_CHANGELOG_OAUTH_CLIENT_ID_NAME")
	clientSecretName  = os.Getenv("COS_CHANGELOG_CLIENT_SECRET_NAME")
	sessionSecretName = os.Getenv("COS_CHANGELOG_SESSION_SECRET_NAME")
	redirectURL       = os.Getenv("COS_CHANGELOG_OAUTH_CALLBACK_NAME")
)
// init reads the OAuth client ID, OAuth client secret, session secret and
// OAuth redirect URL through getSecret and uses them to populate config and
// store. Any failure terminates the process via log.Fatalf.
func init() {
	client, err := secretmanager.NewClient(context.Background())
	if err != nil {
		log.Fatalf("failed to setup client: %v", err)
	}
	// The client is only needed while the secrets are being read; close it so
	// it is not kept open for the lifetime of the process.
	defer client.Close()

	config.ClientID, err = getSecret(client, clientIDName)
	if err != nil {
		log.Fatalf("failed to retrieve client id: %s\n%v", clientIDName, err)
	}
	config.ClientSecret, err = getSecret(client, clientSecretName)
	if err != nil {
		log.Fatalf("failed to retrieve secret: %s\n%v", clientSecretName, err)
	}
	sessionSecret, err := getSecret(client, sessionSecretName)
	if err != nil {
		log.Fatalf("failed to retrieve secret :%s\n%v", sessionSecretName, err)
	}
	store = sessions.NewCookieStore([]byte(sessionSecret))
	// redirectURL holds the name of the secret that contains the callback URL.
	config.RedirectURL, err = getSecret(client, redirectURL)
	if err != nil {
		log.Fatalf("failed to retrieve secret :%s\n%v", redirectURL, err)
	}
}
// getSecret returns the payload of the latest version of the secret named
// secretName in project projectID, read through client.
//
// It returns an error if the secret cannot be accessed or if the response
// carries no payload.
func getSecret(client *secretmanager.Client, secretName string) (string, error) {
	accessRequest := &secretmanagerpb.AccessSecretVersionRequest{
		Name: fmt.Sprintf("projects/%s/secrets/%s/versions/latest", projectID, secretName),
	}
	result, err := client.AccessSecretVersion(context.Background(), accessRequest)
	if err != nil {
		return "", fmt.Errorf("failed to access secret at %s: %w", accessRequest.Name, err)
	}
	// Guard against a response without a payload instead of dereferencing a
	// nil pointer.
	payload := result.GetPayload()
	if payload == nil {
		return "", fmt.Errorf("secret at %s has no payload", accessRequest.Name)
	}
	return string(payload.GetData()), nil
}
// randomString returns stringSize random characters drawn from
// oauthStateCharset, followed by suffix.
//
// The result is used as the OAuth state value, so the characters come from
// crypto/rand rather than a clock-seeded generator. Only if the system random
// source reports an error does it fall back to the time-seeded math/rand
// generator, so that a state is still produced.
func randomString(stringSize int, suffix string) string {
	const charsetLen = len(oauthStateCharset)
	// Random bytes at or above this bound are discarded so the modulo below
	// maps the remaining byte values evenly onto the charset.
	const unbiasedLimit = 256 - 256%charsetLen

	stateArr := make([]byte, 0, stringSize)
	buf := make([]byte, stringSize)
	for len(stateArr) < stringSize {
		if _, err := crand.Read(buf); err != nil {
			// Best-effort fallback: fill whatever is still missing from a
			// time-seeded generator.
			fallback := rand.New(rand.NewSource(time.Now().UnixNano()))
			for len(stateArr) < stringSize {
				stateArr = append(stateArr, oauthStateCharset[fallback.Intn(charsetLen)])
			}
			break
		}
		for _, b := range buf {
			if int(b) >= unbiasedLimit {
				continue
			}
			stateArr = append(stateArr, oauthStateCharset[int(b)%charsetLen])
			if len(stateArr) == stringSize {
				break
			}
		}
	}
	return string(stateArr) + suffix
}
// returnURLFromState extracts the return URL that follows the
// oauthStateLength random characters at the start of state.
//
// If state is too short to contain a return URL, "/" is returned instead of
// slicing out of range or producing an empty redirect target.
func returnURLFromState(state string) string {
	if len(state) <= oauthStateLength {
		return "/"
	}
	return state[oauthStateLength:]
}
// tokenExpired reports whether the OAuth token stored in the request's
// session is expired. A session whose "expiry" value is missing, is not a
// string, or cannot be parsed as RFC 3339 is treated as expired.
func tokenExpired(r *http.Request) bool {
	// The store.Get error is ignored on purpose; a missing or unusable expiry
	// is handled below.
	session, _ := store.Get(r, sessionName)
	// Checked assertion: "expiry" is absent before login and is set to nil by
	// HandleSignOut, and an unchecked assertion would panic in both cases.
	expiry, ok := session.Values["expiry"].(string)
	if !ok {
		return true
	}
	parsedExpiry, err := time.Parse(time.RFC3339, expiry)
	return err != nil || parsedExpiry.Before(time.Now())
}
// GetLoginURL builds the path of the login page, carrying redirect and auto
// as the "redirect" and "auto" query parameters.
//
// NOTE(review): redirect is inserted verbatim, without query escaping —
// confirm callers pass a value that is already safe to embed.
func GetLoginURL(redirect string, auto bool) string {
	return "/login/?redirect=" + redirect + "&auto=" + fmt.Sprint(auto)
}
// SignedIn reports whether the request carries an existing session that has
// a non-nil value for each of the stored token fields.
func SignedIn(r *http.Request) bool {
	session, err := store.Get(r, sessionName)
	if err != nil {
		return false
	}
	if session.IsNew {
		return false
	}
	// A missing key reads back as nil, so one nil comparison covers both the
	// absent and the explicitly cleared case.
	required := [...]string{"accessToken", "refreshToken", "tokenType", "expiry"}
	for i := range required {
		if session.Values[required[i]] == nil {
			return false
		}
	}
	return true
}
// RequireToken checks whether the request is signed in (see SignedIn).
// If it is not, the prompt-login template is rendered to w with activePage
// set on the page data; a template error is logged and reported as a 500.
// Returns true when the prompt-login page was written (the caller should
// stop handling the request) and false when the user is signed in.
func RequireToken(w http.ResponseWriter, r *http.Request, activePage string) bool {
	if SignedIn(r) {
		return false
	}
	page := &statusPage{ActivePage: activePage}
	if err := promptLoginTemplate.Execute(w, page); err != nil {
		log.Errorf("error executing promptLogin template: %v", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
	return true
}
// HTTPClient creates an authorized HTTP client from the token credentials
// stored in the request's session.
//
// It returns an error if any of the stored token fields is missing or is not
// a string, or if the stored expiry cannot be parsed as RFC 3339.
func HTTPClient(w http.ResponseWriter, r *http.Request) (*http.Client, error) {
	// The store.Get error is ignored on purpose; an unusable session surfaces
	// as a missing-value error below.
	session, _ := store.Get(r, sessionName)

	// Use checked assertions so an incomplete session (for example after
	// sign-out) is reported as an error instead of a panic.
	fields := make(map[string]string, 4)
	for _, key := range []string{"accessToken", "refreshToken", "tokenType", "expiry"} {
		val, ok := session.Values[key].(string)
		if !ok {
			return nil, fmt.Errorf("session value %q is missing or not a string", key)
		}
		fields[key] = val
	}

	parsedExpiry, err := time.Parse(time.RFC3339, fields["expiry"])
	if err != nil {
		return nil, err
	}
	token := &oauth2.Token{
		AccessToken:  fields["accessToken"],
		RefreshToken: fields["refreshToken"],
		TokenType:    fields["tokenType"],
		Expiry:       parsedExpiry,
	}
	return config.Client(context.Background(), token), nil
}
// HandleLogin starts the OAuth flow. It generates a state value made of
// oauthStateLength random characters followed by the "redirect" form value
// ("/" when empty), stores it in the session, and redirects the user to the
// authorization URL. Unless the "auto" form value is "true", the
// authorization URL is built with oauth2.ApprovalForce.
//
// NOTE(review): the redirect value comes from the request and is later used
// as the post-login redirect target without validation — confirm this is
// intended.
func HandleLogin(w http.ResponseWriter, r *http.Request) {
	if parseErr := r.ParseForm(); parseErr != nil {
		log.Errorf("could not parse request: %v\n", parseErr)
		http.Error(w, parseErr.Error(), http.StatusBadRequest)
		return
	}

	returnTo := r.FormValue("redirect")
	if returnTo == "" {
		returnTo = "/"
	}
	state := randomString(oauthStateLength, returnTo)

	// A store.Get error only means the previous session could not be
	// deciphered; a fresh session is returned either way, so the error is
	// deliberately ignored.
	session, _ := store.Get(r, sessionName)
	session.Values["oauthState"] = state
	if saveErr := session.Save(r, w); saveErr != nil {
		log.Errorf("Error saving key: %v", saveErr)
		http.Error(w, saveErr.Error(), http.StatusInternalServerError)
		return
	}

	opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline}
	if r.FormValue("auto") != "true" {
		opts = append(opts, oauth2.ApprovalForce)
	}
	http.Redirect(w, r, config.AuthCodeURL(state, opts...), http.StatusTemporaryRedirect)
}
// HandleCallback handles the response from the OAuth callback URL.
// It verifies that the "state" form value matches the state stored in the
// session, exchanges the "code" form value for a token, and stores the token
// fields in the session. On success it redirects to the return URL embedded
// in the state.
//
// A missing or mismatched state is answered with 400; session and token
// exchange failures are answered with 500.
func HandleCallback(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		log.Errorf("could not parse query: %v\n", err)
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	authCode := r.FormValue("code")
	callbackState := r.FormValue("state")

	session, err := store.Get(r, sessionName)
	if err != nil {
		log.Errorf("error retrieving session: %v", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// err is nil from here on, so the state failures below must not call
	// err.Error(); they report a fixed message instead.
	sessionState, ok := session.Values["oauthState"].(string)
	if !ok || sessionState == "" {
		log.Errorf("oauth state missing from session")
		http.Error(w, "missing oauth state", http.StatusBadRequest)
		return
	}
	if callbackState != sessionState {
		log.Errorf("oauth state from callback does not match session")
		http.Error(w, "invalid oauth state", http.StatusBadRequest)
		return
	}
	// The state is single-use: drop it so it is not persisted again when the
	// session is saved below.
	delete(session.Values, "oauthState")

	token, err := config.Exchange(r.Context(), authCode)
	if err != nil {
		log.Errorf("error exchanging token: %v", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	session.Values["accessToken"] = token.AccessToken
	session.Values["refreshToken"] = token.RefreshToken
	session.Values["tokenType"] = token.TokenType
	// Stored as an RFC 3339 string; readers parse it back with time.Parse.
	session.Values["expiry"] = token.Expiry.Format(time.RFC3339)
	if err := session.Save(r, w); err != nil {
		log.Errorf("error saving session: %v", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	http.Redirect(w, r, returnURLFromState(sessionState), http.StatusTemporaryRedirect)
}
// HandleSignOut signs out the user by clearing the token fields in the
// session, then redirects to the "redirect" form value ("/" when empty).
// If the session cannot be saved, a 500 is returned and no redirect is sent.
func HandleSignOut(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		log.Errorf("could not parse request: %v\n", err)
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	redirect := r.FormValue("redirect")
	if redirect == "" {
		redirect = "/"
	}

	// The store.Get error is ignored on purpose: the token fields are
	// overwritten regardless of what the previous session contained.
	session, _ := store.Get(r, sessionName)
	for _, key := range []string{"accessToken", "refreshToken", "tokenType", "expiry"} {
		session.Values[key] = nil
	}
	if err := session.Save(r, w); err != nil {
		log.Errorf("error saving session: %v", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		// Stop here: the error response has already been written, so a
		// redirect must not follow it.
		return
	}
	http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}