Skip to content

Commit

Permalink
Merge pull request hashicorp#2031 from hashicorp/transit-helper
Browse files Browse the repository at this point in the history
Transit helper
  • Loading branch information
vishalnayak authored Oct 27, 2016
2 parents 7958b2e + 484f899 commit f2adc02
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 183 deletions.
5 changes: 3 additions & 2 deletions builtin/logical/transit/backend.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package transit

import (
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
Expand Down Expand Up @@ -39,12 +40,12 @@ func Backend(conf *logical.BackendConfig) *backend {
Secrets: []*framework.Secret{},
}

b.lm = newLockManager(conf.System.CachingDisabled())
b.lm = keysutil.NewLockManager(conf.System.CachingDisabled())

return &b
}

type backend struct {
*framework.Backend
lm *lockManager
lm *keysutil.LockManager
}
33 changes: 17 additions & 16 deletions builtin/logical/transit/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
logicaltest "github.com/hashicorp/vault/logical/testing"
Expand Down Expand Up @@ -289,7 +290,7 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool)
if d.Name != name {
return fmt.Errorf("bad name: %#v", d)
}
if d.Type != KeyType(keyType_AES256_GCM96).String() {
if d.Type != keysutil.KeyType(keysutil.KeyType_AES256_GCM96).String() {
return fmt.Errorf("bad key type: %#v", d)
}
// Should NOT get a key back
Expand Down Expand Up @@ -583,13 +584,13 @@ func testAccStepDecryptDatakey(t *testing.T, name string,

func TestKeyUpgrade(t *testing.T) {
key, _ := uuid.GenerateRandomBytes(32)
p := &policy{
p := &keysutil.Policy{
Name: "test",
Key: key,
Type: keyType_AES256_GCM96,
Type: keysutil.KeyType_AES256_GCM96,
}

p.migrateKeyToKeysMap()
p.MigrateKeyToKeysMap()

if p.Key != nil ||
p.Keys == nil ||
Expand All @@ -604,18 +605,18 @@ func TestDerivedKeyUpgrade(t *testing.T) {
key, _ := uuid.GenerateRandomBytes(32)
context, _ := uuid.GenerateRandomBytes(32)

p := &policy{
p := &keysutil.Policy{
Name: "test",
Key: key,
Type: keyType_AES256_GCM96,
Type: keysutil.KeyType_AES256_GCM96,
Derived: true,
}

p.migrateKeyToKeysMap()
p.upgrade(storage) // Need to run the upgrade code to make the migration stick
p.MigrateKeyToKeysMap()
p.Upgrade(storage) // Need to run the upgrade code to make the migration stick

if p.KDF != kdf_hmac_sha256_counter {
t.Fatalf("bad KDF value by default; counter val is %d, KDF val is %d, policy is %#v", kdf_hmac_sha256_counter, p.KDF, *p)
if p.KDF != keysutil.Kdf_hmac_sha256_counter {
t.Fatalf("bad KDF value by default; counter val is %d, KDF val is %d, policy is %#v", keysutil.Kdf_hmac_sha256_counter, p.KDF, *p)
}

derBytesOld, err := p.DeriveKey(context, 1)
Expand All @@ -632,8 +633,8 @@ func TestDerivedKeyUpgrade(t *testing.T) {
t.Fatal("mismatch of same context alg")
}

p.KDF = kdf_hkdf_sha256
if p.needsUpgrade() {
p.KDF = keysutil.Kdf_hkdf_sha256
if p.NeedsUpgrade() {
t.Fatal("expected no upgrade needed")
}

Expand Down Expand Up @@ -692,15 +693,15 @@ func testConvergentEncryptionCommon(t *testing.T, ver int) {
t.Fatalf("bad: expected error response, got %#v", *resp)
}

p := &policy{
p := &keysutil.Policy{
Name: "testkey",
Type: keyType_AES256_GCM96,
Type: keysutil.KeyType_AES256_GCM96,
Derived: true,
ConvergentEncryption: true,
ConvergentVersion: ver,
}

err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -976,7 +977,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) {
resp, err := be.pathDecryptWrite(req, fd)
if err != nil {
// This could well happen since the min version is jumping around
if resp.Data["error"].(string) == ErrTooOld {
if resp.Data["error"].(string) == keysutil.ErrTooOld {
continue
}
t.Fatalf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id)
Expand Down
15 changes: 8 additions & 7 deletions builtin/logical/transit/path_encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"

"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
Expand Down Expand Up @@ -116,7 +117,7 @@ func (b *backend) pathEncryptWrite(
}

// Get the policy
var p *policy
var p *keysutil.Policy
var lock *sync.RWMutex
var upserted bool
if req.Operation == logical.CreateOperation {
Expand All @@ -125,17 +126,17 @@ func (b *backend) pathEncryptWrite(
return logical.ErrorResponse("convergent encryption requires derivation to be enabled, so context is required"), nil
}

polReq := policyRequest{
storage: req.Storage,
name: name,
derived: len(context) != 0,
convergent: convergent,
polReq := keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
Derived: len(context) != 0,
Convergent: convergent,
}

keyType := d.Get("type").(string)
switch keyType {
case "aes256-gcm96":
polReq.keyType = keyType_AES256_GCM96
polReq.KeyType = keysutil.KeyType_AES256_GCM96
case "ecdsa-p256":
return logical.ErrorResponse(fmt.Sprintf("key type %v not supported for this operation", keyType)), logical.ErrInvalidRequest
default:
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func TestTransit_HMAC(t *testing.T) {
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="

// Rotate
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
Expand Down
23 changes: 12 additions & 11 deletions builtin/logical/transit/path_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"strconv"

"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
Expand Down Expand Up @@ -95,17 +96,17 @@ func (b *backend) pathPolicyWrite(
return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil
}

polReq := policyRequest{
storage: req.Storage,
name: name,
derived: derived,
convergent: convergent,
polReq := keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
Derived: derived,
Convergent: convergent,
}
switch keyType {
case "aes256-gcm96":
polReq.keyType = keyType_AES256_GCM96
polReq.KeyType = keysutil.KeyType_AES256_GCM96
case "ecdsa-p256":
polReq.keyType = keyType_ECDSA_P256
polReq.KeyType = keysutil.KeyType_ECDSA_P256
default:
return logical.ErrorResponse(fmt.Sprintf("unknown key type %v", keyType)), logical.ErrInvalidRequest
}
Expand Down Expand Up @@ -158,10 +159,10 @@ func (b *backend) pathPolicyRead(

if p.Derived {
switch p.KDF {
case kdf_hmac_sha256_counter:
case keysutil.Kdf_hmac_sha256_counter:
resp.Data["kdf"] = "hmac-sha256-counter"
resp.Data["kdf_mode"] = "hmac-sha256-counter"
case kdf_hkdf_sha256:
case keysutil.Kdf_hkdf_sha256:
resp.Data["kdf"] = "hkdf_sha256"
}
resp.Data["convergent_encryption"] = p.ConvergentEncryption
Expand All @@ -171,14 +172,14 @@ func (b *backend) pathPolicyRead(
}

switch p.Type {
case keyType_AES256_GCM96:
case keysutil.KeyType_AES256_GCM96:
retKeys := map[string]int64{}
for k, v := range p.Keys {
retKeys[strconv.Itoa(k)] = v.CreationTime
}
resp.Data["keys"] = retKeys

case keyType_ECDSA_P256:
case keysutil.KeyType_ECDSA_P256:
type ecdsaKey struct {
Name string `json:"name"`
PublicKey string `json:"public_key"`
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_rotate.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (b *backend) pathRotateWrite(
}

// Rotate the policy
err = p.rotate(req.Storage)
err = p.Rotate(req.Storage)

return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/transit/path_sign_verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,11 @@ func TestTransit_SignVerify(t *testing.T) {
signRequest(req, true, "")

// Rotate and set min decryption version
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading

0 comments on commit f2adc02

Please sign in to comment.