Unverified Commit 9c4279da authored by Bryan Liles's avatar Bryan Liles Committed by GitHub
Browse files

Merge pull request #392 from bryanl/fetch-swagger-json

use provided tls configuration if possible when fetching api spec
parents 34db32e0 ff5aa404
......@@ -16,6 +16,8 @@
package client
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io/ioutil"
......@@ -29,6 +31,7 @@ import (
"github.com/ksonnet/ksonnet/metadata"
str "github.com/ksonnet/ksonnet/strings"
"github.com/ksonnet/ksonnet/utils"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"k8s.io/client-go/discovery"
......@@ -92,6 +95,27 @@ func (c *Config) GetAPISpec(server string) string {
Timeout: time.Second * 2,
}
restConfig, err := c.Config.ClientConfig()
if err != nil {
log.Debugf("Failed to retrieve REST config:\n%v", err)
}
if len(restConfig.TLSClientConfig.CAData) > 0 {
log.Info("Configuring TLS (from data) for retrieving cluster swagger.json")
client.Transport = buildTransportFromData(restConfig.TLSClientConfig.CAData)
}
if restConfig.TLSClientConfig.CAFile != "" {
log.Info("Configuring TLS (from file) for retrieving cluster swagger.json")
transport, err := buildTransportFromFile(restConfig.TLSClientConfig.CAFile)
if err != nil {
log.Debugf("Failed to read CA file: %v", err)
return defaultVersion
}
client.Transport = transport
}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
log.Debugf("Failed to create request at %s\n%s", url, err.Error())
......@@ -120,6 +144,21 @@ func (c *Config) GetAPISpec(server string) string {
return fmt.Sprintf("version:%s", spec.Info.Version)
}
func buildTransportFromData(data []byte) *http.Transport {
tlsConfig := &tls.Config{RootCAs: x509.NewCertPool()}
tlsConfig.RootCAs.AppendCertsFromPEM(data)
return &http.Transport{TLSClientConfig: tlsConfig}
}
func buildTransportFromFile(file string) (*http.Transport, error) {
data, err := ioutil.ReadFile(file)
if err != nil {
return nil, errors.Wrap(err, "unable to ready CA file")
}
return buildTransportFromData(data), nil
}
// Namespace returns the namespace for the provided ClientConfig.
func (c *Config) Namespace() (string, error) {
ns, _, err := c.Config.Namespace()
......@@ -179,7 +218,7 @@ func (c *Config) ResolveContext(context string) (server, namespace string, err e
if context == "" {
if rawConfig.CurrentContext == "" && len(rawConfig.Clusters) == 0 {
// User likely does not have a kubeconfig file.
return "", "", fmt.Errorf("No current context found. Make sure a kubeconfig file is present")
return "", "", errors.Errorf("No current context found. Make sure a kubeconfig file is present")
}
// Note: "" is a valid rawConfig.CurrentContext
context = rawConfig.CurrentContext
......@@ -187,13 +226,13 @@ func (c *Config) ResolveContext(context string) (server, namespace string, err e
ctx := rawConfig.Contexts[context]
if ctx == nil {
return "", "", fmt.Errorf("context '%s' does not exist in the kubeconfig file", context)
return "", "", errors.Errorf("context '%s' does not exist in the kubeconfig file", context)
}
log.Infof("Using context '%s' from the kubeconfig file specified at the environment variable $KUBECONFIG", context)
cluster, exists := rawConfig.Clusters[ctx.Cluster]
if !exists {
return "", "", fmt.Errorf("No cluster with name '%s' exists", ctx.Cluster)
return "", "", errors.Errorf("No cluster with name '%s' exists", ctx.Cluster)
}
return cluster.Server, ctx.Namespace, nil
......@@ -261,6 +300,6 @@ func (c *Config) overrideCluster(envName string) error {
return nil
}
return fmt.Errorf("Attempting to deploy to environment '%s' at '%s', but cannot locate a server at that address",
return errors.Errorf("Attempting to deploy to environment '%s' at '%s', but cannot locate a server at that address",
envName, destination.Server())
}
......@@ -16,28 +16,56 @@
package client
import (
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
restclient "k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
)
func TestConfig_GetAPISpec(t *testing.T) {
b, err := ioutil.ReadFile("testdata/swagger.json")
require.NoError(t, err)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, string(b))
}))
})
ts := httptest.NewServer(handler)
defer ts.Close()
tsTLS := httptest.NewTLSServer(handler)
defer tsTLS.Close()
tmpfile, err := ioutil.TempFile("", "")
require.NoError(t, err)
defer os.Remove(tmpfile.Name())
certPEM := buildPEM(tsTLS.Certificate())
fmt.Println(string(certPEM))
_, err = tmpfile.Write(certPEM)
require.NoError(t, err)
err = tmpfile.Close()
require.NoError(t, err)
cases := []struct {
name string
serverURL string
expected string
caData []byte
caFile string
}{
{
name: "invalid server URL",
......@@ -49,14 +77,67 @@ func TestConfig_GetAPISpec(t *testing.T) {
serverURL: ts.URL,
expected: "version:v1.9.3",
},
{
name: "TLS with file cert",
serverURL: tsTLS.URL,
expected: "version:v1.9.3",
caFile: tmpfile.Name(),
},
{
name: "TLS with data cert",
serverURL: tsTLS.URL,
expected: "version:v1.9.3",
caData: certPEM,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c := Config{}
c := Config{
Config: &clientConfig{
caFile: tc.caFile,
caData: tc.caData,
},
}
got := c.GetAPISpec(tc.serverURL)
require.Equal(t, tc.expected, got)
})
}
}
func buildPEM(cert *x509.Certificate) []byte {
b := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
return pem.EncodeToMemory(b)
}
type clientConfig struct {
caFile string
caData []byte
}
var _ clientcmd.ClientConfig = (*clientConfig)(nil)
func (c *clientConfig) RawConfig() (clientcmdapi.Config, error) {
return clientcmdapi.Config{}, errors.Errorf("not implemented")
}
func (c *clientConfig) ClientConfig() (*restclient.Config, error) {
rc := &restclient.Config{
TLSClientConfig: restclient.TLSClientConfig{
CAData: c.caData,
CAFile: c.caFile,
},
}
return rc, nil
}
func (c *clientConfig) Namespace() (string, bool, error) {
return "", false, errors.Errorf("not implemented")
}
func (c *clientConfig) ConfigAccess() clientcmd.ConfigAccess {
var ca clientcmd.ConfigAccess
return ca
}
......@@ -135,7 +135,7 @@ func (cs *clusterSpecVersion) OpenAPI() ([]byte, error) {
if resp.StatusCode != 200 {
return nil, fmt.Errorf(
"Recieved status code '%d' when trying to retrieve OpenAPI schema for cluster version '%s' from URL '%s'",
"Received status code '%d' when trying to retrieve OpenAPI schema for cluster version '%s' from URL '%s'",
resp.StatusCode, cs.k8sVersion, versionURL)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment