Skip to content

Commit

Permalink
refactoring and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mhenriks committed Sep 19, 2018
1 parent 36e73be commit b806786
Show file tree
Hide file tree
Showing 14 changed files with 1,495 additions and 182 deletions.
28 changes: 11 additions & 17 deletions cmd/cdi-uploadserver/uploadserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ import (
)

const (
defaultListenPort = uint16(8443)
defaultListenPort = 8443
defaultListenAddress = "0.0.0.0"

defaultPVCDir = common.IMPORTER_WRITE_DIR
defaultDestination = common.IMPORTER_WRITE_PATH
)

Expand All @@ -46,19 +45,18 @@ func main() {

listenAddress, listenPort := getListenAddressAndPort()

pvcDir, destination := getPVCDirAndDestination()
destination := getDestination()

server := uploadserver.NewUploadServer(
listenAddress,
listenPort,
pvcDir,
destination,
os.Getenv("TLS_KEY_FILE"),
os.Getenv("TLS_CERT_FILE"),
os.Getenv("TLS_CA_FILE"),
os.Getenv("TLS_KEY"),
os.Getenv("TLS_CERT"),
os.Getenv("CLIENT_CERT"),
)

glog.Infof("PVC dir: %s, destination: %s", pvcDir, destination)
glog.Infof("Upload destination: %s", destination)

glog.Infof("Running server on %s:%d", listenAddress, listenPort)

Expand All @@ -71,7 +69,7 @@ func main() {
glog.Info("UploadServer successfully exited")
}

func getListenAddressAndPort() (string, uint16) {
func getListenAddressAndPort() (string, int) {
addr, port := defaultListenAddress, defaultListenPort

// empty value okay here
Expand All @@ -83,23 +81,19 @@ func getListenAddressAndPort() (string, uint16) {
if val := os.Getenv("LISTEN_PORT"); len(val) > 0 {
n, err := strconv.ParseUint(val, 10, 16)
if err == nil {
port = uint16(n)
port = int(n)
}
}

return addr, port
}

func getPVCDirAndDestination() (string, string) {
pvcDir, destination := defaultPVCDir, defaultDestination

if val := os.Getenv("PVC_DIR"); len(val) > 0 {
pvcDir = val
}
func getDestination() string {
destination := defaultDestination

if val := os.Getenv("DESTINATION"); len(val) > 0 {
destination = val
}

return pvcDir, destination
return destination
}
56 changes: 20 additions & 36 deletions pkg/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"reflect"
"strings"

"k8s.io/client-go/util/cert/triple"

"github.com/golang/glog"
"github.com/pkg/errors"

Expand All @@ -20,7 +22,6 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/util/cert"
"k8s.io/client-go/util/cert/triple"
apiregistrationv1beta1 "k8s.io/kube-aggregator/pkg/apis/apiregistration/v1beta1"
aggregatorclient "k8s.io/kube-aggregator/pkg/client/clientset_generated/clientset"

Expand Down Expand Up @@ -188,44 +189,27 @@ func (app *uploadAPIApp) getClientCert() error {

func (app *uploadAPIApp) getSelfSignedCert() error {
namespace := util.GetNamespace()
keyPairAndCertBytes, err := keys.GetKeyPairAndCertBytes(app.client, namespace, apiCertSecretName)
caKeyPair, err := triple.NewCA("api.cdi.kubevirt.io")
if err != nil {
return errors.Wrap(err, "Error getting secret")
}

if keyPairAndCertBytes == nil {
caKeyPair, err := triple.NewCA("api.cdi.kubevirt.io")
if err != nil {
return errors.Wrap(err, "Error creating CA")
}

err = keys.CreateServerKeyPairAndCert(app.client,
namespace,
apiCertSecretName,
caKeyPair,
caKeyPair.Cert,
apiServiceName+"."+namespace,
apiServiceName,
false,
nil,
)
if err != nil {
return errors.Wrap(err, "Error creating secret")
}

keyPairAndCertBytes, err = keys.GetKeyPairAndCertBytes(app.client, namespace, apiCertSecretName)
if err != nil {
return errors.Wrap(err, "Error getting secret")
}

if keyPairAndCertBytes == nil {
return errors.Wrap(err, "Error getting secret the second time")
}
return errors.Wrap(err, "Error creating CA")
}

keyPairAndCert, err := keys.GetOrCreateServerKeyPairAndCert(app.client,
namespace,
apiCertSecretName,
caKeyPair,
caKeyPair.Cert,
apiServiceName+"."+namespace,
apiServiceName,
nil,
)
if err != nil {
return errors.Wrapf(err, "Error getting/creating secret %s", apiCertSecretName)
}

app.keyBytes = keyPairAndCertBytes.PrivateKey
app.certBytes = keyPairAndCertBytes.Cert
app.signingCertBytes = keyPairAndCertBytes.CACert
app.keyBytes = cert.EncodePrivateKeyPEM(keyPairAndCert.KeyPair.Key)
app.certBytes = cert.EncodeCertPEM(keyPairAndCert.KeyPair.Cert)
app.signingCertBytes = cert.EncodeCertPEM(keyPairAndCert.CACert)

privateKey, err := keys.GetOrCreatePrivateKey(app.client, namespace, apiSigningKeySecretName)
if err != nil {
Expand Down
25 changes: 16 additions & 9 deletions pkg/apiserver/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import (
"reflect"
"testing"

"k8s.io/client-go/util/cert/triple"

"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/diff"
k8sfake "k8s.io/client-go/kubernetes/fake"
core "k8s.io/client-go/testing"
"kubevirt.io/containerized-data-importer/pkg/keys"
"kubevirt.io/containerized-data-importer/pkg/keys/keystest"
)

func signingKeySecretGetAction() core.Action {
Expand All @@ -25,7 +27,7 @@ func signingKeySecretGetAction() core.Action {
}

func signingKeySecretCreateAction(privateKey *rsa.PrivateKey) core.Action {
secret, _ := keys.NewPrivateKeySecret("kube-system", apiSigningKeySecretName, privateKey)
secret, _ := keystest.NewPrivateKeySecret("kube-system", apiSigningKeySecretName, privateKey)
return core.NewCreateAction(
schema.GroupVersionResource{
Resource: "secrets",
Expand All @@ -52,7 +54,7 @@ func tlsSecretCreateAction(privateKeyBytes, certBytes, caCertBytes []byte) core.
Version: "v1",
},
"kube-system",
keys.NewTLSSecretFromBytes("kube-system", apiCertSecretName, privateKeyBytes, certBytes, caCertBytes, nil))
keystest.NewTLSSecretFromBytes("kube-system", apiCertSecretName, privateKeyBytes, certBytes, caCertBytes, nil))
}

func checkActions(expected []core.Action, actual []core.Action, t *testing.T) {
Expand Down Expand Up @@ -151,16 +153,22 @@ func TestKeyRetrieval(t *testing.T) {
t.Errorf("error generating keys: %v", err)
}

keyBytes := []byte("madeup")
certBytes := []byte("madeup")
signingCertBytes := []byte("madeup")
caKeyPair, err := triple.NewCA("myca")
if err != nil {
t.Errorf("Error creating CA key pair")
}

signingKeySecret, err := keys.NewPrivateKeySecret("kube-system", apiSigningKeySecretName, signingKey)
serverKeyPair, err := triple.NewServerKeyPair(caKeyPair, "commonname", "service", "kube-system", "cluster.local", []string{}, []string{})
if err != nil {
t.Errorf("Error creating server key pair")
}

signingKeySecret, err := keystest.NewPrivateKeySecret("kube-system", apiSigningKeySecretName, signingKey)
if err != nil {
t.Errorf("error creating secret: %v", err)
}

tlsSecret := keys.NewTLSSecretFromBytes("kube-system", apiCertSecretName, keyBytes, certBytes, signingCertBytes, nil)
tlsSecret := keystest.NewTLSSecret("kube-system", apiCertSecretName, serverKeyPair, caKeyPair.Cert, nil)

kubeobjects := []runtime.Object{}
kubeobjects = append(kubeobjects, tlsSecret)
Expand Down Expand Up @@ -201,7 +209,6 @@ func TestShouldGenerateCertsAndKeyFirstRun(t *testing.T) {
actions := []core.Action{}
actions = append(actions, tlsSecretGetAction())
actions = append(actions, tlsSecretCreateAction(app.keyBytes, app.certBytes, app.signingCertBytes))
actions = append(actions, tlsSecretGetAction())
actions = append(actions, signingKeySecretGetAction())
actions = append(actions, signingKeySecretCreateAction(app.privateSigningKey))

Expand Down
54 changes: 26 additions & 28 deletions pkg/controller/upload-controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,20 @@ const (

// cert/key annotations

// ServerCASecret is the secret containing the server CA
ServerCASecret = "cdi-upload-server-ca-key"
// ServerCAName is the name of the server CA
ServerCAName = "server.upload.cdi.kubevirt.io"
// uploadServerCASecret is the secret containing the server CA
uploadServerCASecret = "cdi-upload-server-ca-key"
// uploadServerCAName is the name of the server CA
uploadServerCAName = "server.upload.cdi.kubevirt.io"

// ClientCASecret is the secret containing the clent CA
ClientCASecret = "cdi-upload-client-ca-key"
// ClientCAName is the name of the client CA
ClientCAName = "client.upload.cdi.kubevirt.io"
// uploadServerClientCASecret is the secret containing the clent CA
uploadServerClientCASecret = "cdi-upload-server-client-ca-key"
// uploadServerClientCAName is the name of the client CA
uploadServerClientCAName = "client.upload-server.cdi.kubevirt.io"

// ClientKeySecret is the secret containing the client key/cert
ClientKeySecret = "cdi-upload-client-key"
// UploadProxyClientName is the CN for client cert
UploadProxyClientName = "uploadproxy.client.upload.cdi.kebevirt.io"
// uploadServerClientKeySecret is the secret containing the client key/cert
uploadServerClientKeySecret = "cdi-upload-server-client-key"
// uploadProxyClientName is the CN for client cert
uploadProxyClientName = "uploadproxy.client.upload-server.cdi.kebevirt.io"

uploadProxyCAName = "proxy.upload.cdi.kubevirt.io"
uploadProxyCASecret = "cdi-upload-proxy-ca-key"
Expand All @@ -74,7 +74,7 @@ const (

// UploadController members
type UploadController struct {
clientset kubernetes.Interface
client kubernetes.Interface
queue workqueue.RateLimitingInterface
pvcInformer, podInformer, serviceInformer cache.SharedIndexInformer
pvcLister corelisters.PersistentVolumeClaimLister
Expand Down Expand Up @@ -120,7 +120,7 @@ func NewUploadController(client kubernetes.Interface,
pullPolicy string,
verbose string) *UploadController {
c := &UploadController{
clientset: client,
client: client,
queue: workqueue.NewRateLimitingQueue(workqueue.DefaultControllerRateLimiter()),
pvcInformer: pvcInformer.Informer(),
podInformer: podInformer.Informer(),
Expand Down Expand Up @@ -221,45 +221,43 @@ func (c *UploadController) initCerts() error {
var err error

// CA for Upload Servers
c.serverCAKeyPair, err = keys.GetOrCreateCA(c.clientset, util.GetNamespace(), ServerCASecret, ServerCAName)
c.serverCAKeyPair, err = keys.GetOrCreateCA(c.client, util.GetNamespace(), uploadServerCASecret, uploadServerCAName)
if err != nil {
return errors.Wrap(err, "Couldn't get/create server CA")
}

// CA for Upload Client
c.clientCAKeyPair, err = keys.GetOrCreateCA(c.clientset, util.GetNamespace(), ClientCASecret, ClientCAName)
c.clientCAKeyPair, err = keys.GetOrCreateCA(c.client, util.GetNamespace(), uploadServerClientCASecret, uploadServerClientCAName)
if err != nil {
return errors.Wrap(err, "Couldn't get/create client CA")
}

// Upload Server Client Cert
err = keys.CreateClientKeyPairAndCert(c.clientset,
_, err = keys.GetOrCreateClientKeyPairAndCert(c.client,
util.GetNamespace(),
ClientKeySecret,
uploadServerClientKeySecret,
c.clientCAKeyPair,
c.serverCAKeyPair.Cert,
UploadProxyClientName,
uploadProxyClientName,
[]string{},
false, // okay if already exists
nil,
)
if err != nil {
return errors.Wrap(err, "Couldn't get/create client cert")
}

uploadProxyCAKeyPair, err := keys.GetOrCreateCA(c.clientset, util.GetNamespace(), uploadProxyCASecret, uploadProxyCAName)
uploadProxyCAKeyPair, err := keys.GetOrCreateCA(c.client, util.GetNamespace(), uploadProxyCASecret, uploadProxyCAName)
if err != nil {
return errors.Wrap(err, "Couldn't create upload proxy server cert")
}

err = keys.CreateServerKeyPairAndCert(c.clientset,
_, err = keys.GetOrCreateServerKeyPairAndCert(c.client,
util.GetNamespace(),
uploadProxyServerSecret,
uploadProxyCAKeyPair,
nil,
c.uploadProxyServiceName+"."+util.GetNamespace(),
c.uploadProxyServiceName,
false, // okay if already exists
nil,
)
if err != nil {
Expand Down Expand Up @@ -413,7 +411,7 @@ func (c *UploadController) syncHandler(key string) error {
if podPhase != podPhaseFromPVC(pvc) {
var labels map[string]string
annotations := map[string]string{AnnUploadPodPhase: string(podPhase)}
pvc, err = updatePVC(c.clientset, pvc, annotations, labels)
pvc, err = updatePVC(c.client, pvc, annotations, labels)
if err != nil {
return errors.Wrapf(err, "Error updating pvc %s, pod phase %s", key, podPhase)
}
Expand All @@ -439,7 +437,7 @@ func (c *UploadController) getOrCreateUploadPod(pvc *v1.PersistentVolumeClaim, n
pod, err := c.podLister.Pods(pvc.Namespace).Get(name)

if k8serrors.IsNotFound(err) {
pod, err = CreateUploadPod(c.clientset, c.serverCAKeyPair, c.clientCAKeyPair.Cert, c.uploadServiceImage, c.verbose, c.pullPolicy, name, pvc)
pod, err = CreateUploadPod(c.client, c.serverCAKeyPair, c.clientCAKeyPair.Cert, c.uploadServiceImage, c.verbose, c.pullPolicy, name, pvc)
}

if pod != nil && !metav1.IsControlledBy(pod, pvc) {
Expand All @@ -453,7 +451,7 @@ func (c *UploadController) getOrCreateUploadService(pvc *v1.PersistentVolumeClai
service, err := c.serviceLister.Services(pvc.Namespace).Get(name)

if k8serrors.IsNotFound(err) {
service, err = CreateUploadService(c.clientset, name, pvc)
service, err = CreateUploadService(c.client, name, pvc)
}

if service != nil && !metav1.IsControlledBy(service, pvc) {
Expand All @@ -469,7 +467,7 @@ func (c *UploadController) deletePod(namespace, name string) error {
return nil
}
if err == nil && pod.DeletionTimestamp == nil {
err = c.clientset.CoreV1().Pods(namespace).Delete(name, &metav1.DeleteOptions{})
err = c.client.CoreV1().Pods(namespace).Delete(name, &metav1.DeleteOptions{})
if k8serrors.IsNotFound(err) {
return nil
}
Expand All @@ -483,7 +481,7 @@ func (c *UploadController) deleteService(namespace, name string) error {
return nil
}
if err == nil && service.DeletionTimestamp == nil {
err = c.clientset.CoreV1().Services(namespace).Delete(name, &metav1.DeleteOptions{})
err = c.client.CoreV1().Services(namespace).Delete(name, &metav1.DeleteOptions{})
if k8serrors.IsNotFound(err) {
return nil
}
Expand Down
Loading

0 comments on commit b806786

Please sign in to comment.