Skip to content

Commit

Permalink
Wrapping enhancements (hashicorp#1927)
Browse files Browse the repository at this point in the history
  • Loading branch information
jefferai authored Sep 29, 2016
1 parent e65979e commit 60deff1
Show file tree
Hide file tree
Showing 25 changed files with 1,271 additions and 137 deletions.
20 changes: 11 additions & 9 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,17 +327,19 @@ func (c *Client) NewRequest(method, path string) *Request {
Params: make(map[string][]string),
}

var lookupPath string
switch {
case strings.HasPrefix(path, "/v1/"):
lookupPath = strings.TrimPrefix(path, "/v1/")
case strings.HasPrefix(path, "v1/"):
lookupPath = strings.TrimPrefix(path, "v1/")
default:
lookupPath = path
}
if c.wrappingLookupFunc != nil {
var lookupPath string
switch {
case strings.HasPrefix(path, "/v1/"):
lookupPath = strings.TrimPrefix(path, "/v1/")
case strings.HasPrefix(path, "v1/"):
lookupPath = strings.TrimPrefix(path, "v1/")
default:
lookupPath = path
}
req.WrapTTL = c.wrappingLookupFunc(method, lookupPath)
} else {
req.WrapTTL = DefaultWrappingLookupFunc(method, lookupPath)
}

return req
Expand Down
57 changes: 54 additions & 3 deletions api/logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package api
import (
"bytes"
"fmt"
"net/http"
"os"

"github.com/hashicorp/vault/helper/jsonutil"
)
Expand All @@ -11,6 +13,26 @@ const (
wrappedResponseLocation = "cubbyhole/response"
)

var (
// The default TTL that will be used with `sys/wrapping/wrap`, can be
// changed
DefaultWrappingTTL = "5m"

// The default function used if no other function is set, which honors the
// env var and wraps `sys/wrapping/wrap`
DefaultWrappingLookupFunc = func(operation, path string) string {
if os.Getenv(EnvVaultWrapTTL) != "" {
return os.Getenv(EnvVaultWrapTTL)
}

if (operation == "PUT" || operation == "POST") && path == "sys/wrapping/wrap" {
return DefaultWrappingTTL
}

return ""
}
)

// Logical is used to perform logical backend operations on Vault.
type Logical struct {
c *Client
Expand Down Expand Up @@ -96,10 +118,39 @@ func (c *Logical) Delete(path string) (*Secret, error) {
}

func (c *Logical) Unwrap(wrappingToken string) (*Secret, error) {
origToken := c.c.Token()
defer c.c.SetToken(origToken)
var data map[string]interface{}
if wrappingToken != "" {
data = map[string]interface{}{
"token": wrappingToken,
}
}

r := c.c.NewRequest("PUT", "/v1/sys/wrapping/unwrap")
if err := r.SetJSONBody(data); err != nil {
return nil, err
}

c.c.SetToken(wrappingToken)
resp, err := c.c.RawRequest(r)
if resp != nil {
defer resp.Body.Close()
}
if err != nil && resp.StatusCode != 404 {
return nil, err
}

switch resp.StatusCode {
case http.StatusOK: // New method is supported
return ParseSecret(resp.Body)
case http.StatusNotFound: // Fall back to old method
default:
return nil, nil
}

if wrappingToken == "" {
origToken := c.c.Token()
defer c.c.SetToken(origToken)
c.c.SetToken(wrappingToken)
}

secret, err := c.Read(wrappedResponseLocation)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion command/rekey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,10 @@ func TestRekey_init_pgp(t *testing.T) {
MaxLeaseTTLVal: time.Hour * 24 * 32,
},
}
sysBackend := vault.NewSystemBackend(core, bc)
sysBackend, err := vault.NewSystemBackend(core, bc)
if err != nil {
t.Fatal(err)
}

ui := new(cli.MockUi)
c := &RekeyCommand{
Expand Down
24 changes: 14 additions & 10 deletions command/unwrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,25 @@ func (c *UnwrapCommand) Run(args []string) int {
return 1
}

var tokenID string

args = flags.Args()
if len(args) != 1 || len(args[0]) == 0 {
c.Ui.Error("Unwrap expects one argument: the ID of the wrapping token")
switch len(args) {
case 0:
case 1:
tokenID = args[0]
_, err = uuid.ParseUUID(tokenID)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Given token could not be parsed as a UUID: %v", err))
return 1
}
default:
c.Ui.Error("Unwrap expects zero or one argument (the ID of the wrapping token)")
flags.Usage()
return 1
}

tokenID := args[0]
_, err = uuid.ParseUUID(tokenID)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Given token could not be parsed as a UUID: %s", err))
return 1
}

client, err := c.Client()
if err != nil {
c.Ui.Error(fmt.Sprintf(
Expand Down
3 changes: 2 additions & 1 deletion command/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"os"
"reflect"
"time"

"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/token"
Expand Down Expand Up @@ -55,7 +56,7 @@ func PrintRawField(ui cli.Ui, secret *api.Secret, field string) int {
case "wrapping_token_ttl":
val = secret.WrapInfo.TTL
case "wrapping_token_creation_time":
val = secret.WrapInfo.CreationTime.String()
val = secret.WrapInfo.CreationTime.Format(time.RFC3339Nano)
case "wrapped_accessor":
val = secret.WrapInfo.WrappedAccessor
default:
Expand Down
39 changes: 31 additions & 8 deletions http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ func Handler(core *vault.Core) http.Handler {
mux.Handle("/v1/sys/rekey/update", handleRequestForwarding(core, handleSysRekeyUpdate(core, false)))
mux.Handle("/v1/sys/rekey-recovery-key/init", handleRequestForwarding(core, handleSysRekeyInit(core, true)))
mux.Handle("/v1/sys/rekey-recovery-key/update", handleRequestForwarding(core, handleSysRekeyUpdate(core, true)))
mux.Handle("/v1/sys/capabilities-self", handleRequestForwarding(core, handleLogical(core, true, sysCapabilitiesSelfCallback)))
mux.Handle("/v1/sys/wrapping/lookup", handleRequestForwarding(core, handleLogical(core, false, wrappingVerificationFunc)))
mux.Handle("/v1/sys/wrapping/rewrap", handleRequestForwarding(core, handleLogical(core, false, wrappingVerificationFunc)))
mux.Handle("/v1/sys/wrapping/unwrap", handleRequestForwarding(core, handleLogical(core, false, wrappingVerificationFunc)))
mux.Handle("/v1/sys/capabilities-self", handleRequestForwarding(core, handleLogical(core, true, nil)))
mux.Handle("/v1/sys/", handleRequestForwarding(core, handleLogical(core, true, nil)))
mux.Handle("/v1/", handleRequestForwarding(core, handleLogical(core, false, nil)))

Expand All @@ -58,15 +61,35 @@ func Handler(core *vault.Core) http.Handler {
return handler
}

// ClientToken is required in the handler of sys/capabilities-self endpoint in
// system backend. But the ClientToken gets obfuscated before the request gets
// forwarded to any logical backend. So, setting the ClientToken in the data
// field for this request.
func sysCapabilitiesSelfCallback(req *logical.Request) error {
if req == nil || req.Data == nil {
// A lookup on a token that is about to expire returns nil, which means by the
// time we can validate a wrapping token lookup will return nil since it will
// be revoked after the call. So we have to do the validation here.
func wrappingVerificationFunc(core *vault.Core, req *logical.Request) error {
if req == nil {
return fmt.Errorf("invalid request")
}
req.Data["token"] = req.ClientToken

var token string
if req.Data != nil && req.Data["token"] != nil {
if tokenStr, ok := req.Data["token"].(string); !ok {
return fmt.Errorf("could not decode token in request body")
} else if tokenStr == "" {
return fmt.Errorf("empty token in request body")
} else {
token = tokenStr
}
} else {
token = req.ClientToken
}

valid, err := core.ValidateWrappingToken(token)
if err != nil {
return fmt.Errorf("error validating wrapping token: %v", err)
}
if !valid {
return fmt.Errorf("wrapping token is not valid or does not exist")
}

return nil
}

Expand Down
71 changes: 44 additions & 27 deletions http/logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/hashicorp/vault/vault"
)

type PrepareRequestFunc func(req *logical.Request) error
type PrepareRequestFunc func(*vault.Core, *logical.Request) error

func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Request, int, error) {
// Determine the path...
Expand Down Expand Up @@ -99,8 +99,8 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
// will have a callback registered to do the needed operations, so
// invoke it before proceeding.
if prepareRequestCallback != nil {
if err := prepareRequestCallback(req); err != nil {
respondError(w, http.StatusInternalServerError, err)
if err := prepareRequestCallback(core, req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
}
Expand Down Expand Up @@ -160,8 +160,8 @@ func respondLogical(w http.ResponseWriter, r *http.Request, req *logical.Request
}

// Check if this is a raw response
if _, ok := resp.Data[logical.HTTPContentType]; ok {
respondRaw(w, r, req.Path, resp)
if _, ok := resp.Data[logical.HTTPStatusCode]; ok {
respondRaw(w, r, resp)
return
}

Expand Down Expand Up @@ -197,51 +197,68 @@ func respondLogical(w http.ResponseWriter, r *http.Request, req *logical.Request
// respondRaw is used when the response is using HTTPContentType and HTTPRawBody
// to change the default response handling. This is only used for specific things like
// returning the CRL information on the PKI backends.
func respondRaw(w http.ResponseWriter, r *http.Request, path string, resp *logical.Response) {
func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response) {
retErr := func(w http.ResponseWriter, err string) {
w.Header().Set("X-Vault-Raw-Error", err)
w.WriteHeader(http.StatusInternalServerError)
w.Write(nil)
}

// Ensure this is never a secret or auth response
if resp.Secret != nil || resp.Auth != nil {
respondError(w, http.StatusInternalServerError, nil)
retErr(w, "raw responses cannot contain secrets or auth")
return
}

// Get the status code
statusRaw, ok := resp.Data[logical.HTTPStatusCode]
if !ok {
respondError(w, http.StatusInternalServerError, nil)
retErr(w, "no status code given")
return
}
status, ok := statusRaw.(int)
if !ok {
respondError(w, http.StatusInternalServerError, nil)
retErr(w, "cannot decode status code")
return
}

// Get the header
nonEmpty := status != http.StatusNoContent

var contentType string
var body []byte

// Get the content type header; don't require it if the body is empty
contentTypeRaw, ok := resp.Data[logical.HTTPContentType]
if !ok {
respondError(w, http.StatusInternalServerError, nil)
if !ok && !nonEmpty {
retErr(w, "no content type given")
return
}
contentType, ok := contentTypeRaw.(string)
if !ok {
respondError(w, http.StatusInternalServerError, nil)
return
if ok {
contentType, ok = contentTypeRaw.(string)
if !ok {
retErr(w, "cannot decode content type")
return
}
}

// Get the body
bodyRaw, ok := resp.Data[logical.HTTPRawBody]
if !ok {
respondError(w, http.StatusInternalServerError, nil)
return
}
body, ok := bodyRaw.([]byte)
if !ok {
respondError(w, http.StatusInternalServerError, nil)
return
if nonEmpty {
// Get the body
bodyRaw, ok := resp.Data[logical.HTTPRawBody]
if !ok {
retErr(w, "no body given")
return
}
body, ok = bodyRaw.([]byte)
if !ok {
retErr(w, "cannot decode body")
return
}
}

// Write the response
w.Header().Set("Content-Type", contentType)
if contentType != "" {
w.Header().Set("Content-Type", contentType)
}
w.WriteHeader(status)
w.Write(body)
}
Expand Down
Loading

0 comments on commit 60deff1

Please sign in to comment.