Skip to content

Commit 1ad978d

Browse files
authored
Move SigV4 config and middleware here (#132)
1 parent a93b9c5 commit 1ad978d

File tree

6 files changed

+238
-0
lines changed

6 files changed

+238
-0
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## v0.25.0
6+
7+
- Add SigV4 middleware from Grafana core.
8+
59
## v0.24.0
610

711
- Sessions: Use STS regional endpoint in assume role for opt-in regions in [#129](https://github.com/grafana/grafana-aws-sdk/pull/129)

pkg/awsds/authSettings.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ const (
2828
// ListMetricsPageLimitKeyName is the string literal for the cloudwatch list metrics page limit key name
2929
ListMetricsPageLimitKeyName = "AWS_CW_LIST_METRICS_PAGE_LIMIT"
3030

31+
// SigV4AuthEnabledEnvVarKeyName is the string literal for the sigv4 auth enabled environment variable key name
32+
SigV4AuthEnabledEnvVarKeyName = "AWS_SIGV4_AUTH_ENABLED"
33+
34+
// SigV4VerboseLoggingEnvVarKeyName is the string literal for the sigv4 verbose logging environment variable key name
35+
SigV4VerboseLoggingEnvVarKeyName = "AWS_SIGV4_VERBOSE_LOGGING"
36+
3137
defaultAssumeRoleEnabled = true
3238
defaultListMetricsPageLimit = 500
3339
defaultSecureSocksDSProxyEnabled = false
@@ -193,3 +199,12 @@ func ReadAuthSettingsFromEnvironmentVariables() *AuthSettings {
193199

194200
return authSettings
195201
}
202+
203+
// ReadSigV4Settings gets the SigV4 settings from the context if its available
204+
func ReadSigV4Settings(ctx context.Context) *SigV4Settings {
205+
cfg := backend.GrafanaConfigFromContext(ctx)
206+
return &SigV4Settings{
207+
Enabled: cfg.Get(SigV4AuthEnabledEnvVarKeyName) == "true",
208+
VerboseLogging: cfg.Get(SigV4VerboseLoggingEnvVarKeyName) == "true",
209+
}
210+
}

pkg/awsds/authSettings_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,39 @@ func TestReadAuthSettings(t *testing.T) {
155155
}
156156
}
157157

158+
func TestReadSigV4Settings(t *testing.T) {
159+
tcs := []struct {
160+
name string
161+
cfg *backend.GrafanaCfg
162+
expectedSettings *SigV4Settings
163+
}{
164+
{
165+
name: "empty config map",
166+
cfg: backend.NewGrafanaCfg(make(map[string]string)),
167+
expectedSettings: &SigV4Settings{},
168+
},
169+
{
170+
name: "aws settings in config",
171+
cfg: backend.NewGrafanaCfg(map[string]string{
172+
SigV4AuthEnabledEnvVarKeyName: "true",
173+
SigV4VerboseLoggingEnvVarKeyName: "true",
174+
}),
175+
expectedSettings: &SigV4Settings{
176+
Enabled: true,
177+
VerboseLogging: true,
178+
},
179+
},
180+
}
181+
for _, tc := range tcs {
182+
t.Run(tc.name, func(t *testing.T) {
183+
ctx := backend.WithGrafanaConfig(context.Background(), tc.cfg)
184+
settings := ReadSigV4Settings(ctx)
185+
186+
require.Equal(t, tc.expectedSettings, settings)
187+
})
188+
}
189+
}
190+
158191
func unsetEnvironmentVariables() {
159192
os.Unsetenv(AllowedAuthProvidersEnvVarKeyName)
160193
os.Unsetenv(AssumeRoleEnabledEnvVarKeyName)

pkg/awsds/types.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ type AuthSettings struct {
2929
SecureSocksDSProxyEnabled bool
3030
}
3131

32+
// SigV4Settings stores the settings for SigV4 authentication
33+
type SigV4Settings struct {
34+
Enabled bool
35+
VerboseLogging bool
36+
}
37+
3238
// QueryStatus represents the status of an async query
3339
type QueryStatus uint32
3440

pkg/sigv4/sigv4_middleware.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package sigv4
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
7+
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
8+
)
9+
10+
// SigV4MiddlewareName the middleware name used by SigV4Middleware.
11+
const SigV4MiddlewareName = "sigv4"
12+
13+
var newSigV4Func = New
14+
15+
// SigV4Middleware applies AWS Signature Version 4 request signing for the outgoing request.
16+
func SigV4Middleware(verboseLogging bool) httpclient.Middleware {
17+
return httpclient.NamedMiddlewareFunc(SigV4MiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
18+
if opts.SigV4 == nil {
19+
return next
20+
}
21+
22+
conf := &Config{
23+
Service: opts.SigV4.Service,
24+
AccessKey: opts.SigV4.AccessKey,
25+
SecretKey: opts.SigV4.SecretKey,
26+
Region: opts.SigV4.Region,
27+
AssumeRoleARN: opts.SigV4.AssumeRoleARN,
28+
AuthType: opts.SigV4.AuthType,
29+
ExternalID: opts.SigV4.ExternalID,
30+
Profile: opts.SigV4.Profile,
31+
}
32+
33+
rt, err := newSigV4Func(conf, next, Opts{VerboseMode: verboseLogging})
34+
if err != nil {
35+
return invalidSigV4Config(err)
36+
}
37+
38+
return rt
39+
})
40+
}
41+
42+
func invalidSigV4Config(err error) http.RoundTripper {
43+
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
44+
return nil, fmt.Errorf("invalid SigV4 configuration: %w", err)
45+
})
46+
}

pkg/sigv4/sigv4_middleware_test.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package sigv4
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"io"
7+
"net/http"
8+
"testing"
9+
10+
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
type testContext struct {
15+
callChain []string
16+
}
17+
18+
func (c *testContext) createRoundTripper(name string) http.RoundTripper {
19+
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
20+
c.callChain = append(c.callChain, name)
21+
return &http.Response{
22+
StatusCode: http.StatusOK,
23+
Request: req,
24+
Body: io.NopCloser(bytes.NewBufferString("")),
25+
}, nil
26+
})
27+
}
28+
29+
func TestSigV4Middleware(t *testing.T) {
30+
t.Run("Without sigv4 options set should return next http.RoundTripper", func(t *testing.T) {
31+
origSigV4Func := newSigV4Func
32+
newSigV4Called := false
33+
middlewareCalled := false
34+
newSigV4Func = func(config *Config, next http.RoundTripper, opts ...Opts) (http.RoundTripper, error) {
35+
newSigV4Called = true
36+
return httpclient.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
37+
middlewareCalled = true
38+
return next.RoundTrip(r)
39+
}), nil
40+
}
41+
t.Cleanup(func() {
42+
newSigV4Func = origSigV4Func
43+
})
44+
45+
ctx := &testContext{}
46+
finalRoundTripper := ctx.createRoundTripper("finalrt")
47+
mw := SigV4Middleware(false)
48+
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
49+
require.NotNil(t, rt)
50+
middlewareName, ok := mw.(httpclient.MiddlewareName)
51+
require.True(t, ok)
52+
require.Equal(t, SigV4MiddlewareName, middlewareName.MiddlewareName())
53+
54+
req, err := http.NewRequest(http.MethodGet, "http://", nil)
55+
require.NoError(t, err)
56+
res, err := rt.RoundTrip(req)
57+
require.NoError(t, err)
58+
require.NotNil(t, res)
59+
if res.Body != nil {
60+
require.NoError(t, res.Body.Close())
61+
}
62+
require.Len(t, ctx.callChain, 1)
63+
require.ElementsMatch(t, []string{"finalrt"}, ctx.callChain)
64+
require.False(t, newSigV4Called)
65+
require.False(t, middlewareCalled)
66+
})
67+
68+
t.Run("With sigv4 options set should call sigv4 http.RoundTripper", func(t *testing.T) {
69+
origSigV4Func := newSigV4Func
70+
newSigV4Called := false
71+
middlewareCalled := false
72+
newSigV4Func = func(config *Config, next http.RoundTripper, opts ...Opts) (http.RoundTripper, error) {
73+
newSigV4Called = true
74+
return httpclient.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
75+
middlewareCalled = true
76+
return next.RoundTrip(r)
77+
}), nil
78+
}
79+
t.Cleanup(func() {
80+
newSigV4Func = origSigV4Func
81+
})
82+
83+
ctx := &testContext{}
84+
finalRoundTripper := ctx.createRoundTripper("final")
85+
mw := SigV4Middleware(false)
86+
rt := mw.CreateMiddleware(httpclient.Options{SigV4: &httpclient.SigV4Config{}}, finalRoundTripper)
87+
require.NotNil(t, rt)
88+
middlewareName, ok := mw.(httpclient.MiddlewareName)
89+
require.True(t, ok)
90+
require.Equal(t, SigV4MiddlewareName, middlewareName.MiddlewareName())
91+
92+
req, err := http.NewRequest(http.MethodGet, "http://", nil)
93+
require.NoError(t, err)
94+
res, err := rt.RoundTrip(req)
95+
require.NoError(t, err)
96+
require.NotNil(t, res)
97+
if res.Body != nil {
98+
require.NoError(t, res.Body.Close())
99+
}
100+
require.Len(t, ctx.callChain, 1)
101+
require.ElementsMatch(t, []string{"final"}, ctx.callChain)
102+
103+
require.True(t, newSigV4Called)
104+
require.True(t, middlewareCalled)
105+
})
106+
107+
t.Run("With sigv4 error returned", func(t *testing.T) {
108+
origSigV4Func := newSigV4Func
109+
newSigV4Func = func(config *Config, next http.RoundTripper, opts ...Opts) (http.RoundTripper, error) {
110+
return nil, fmt.Errorf("problem")
111+
}
112+
t.Cleanup(func() {
113+
newSigV4Func = origSigV4Func
114+
})
115+
116+
ctx := &testContext{}
117+
finalRoundTripper := ctx.createRoundTripper("final")
118+
mw := SigV4Middleware(false)
119+
rt := mw.CreateMiddleware(httpclient.Options{SigV4: &httpclient.SigV4Config{}}, finalRoundTripper)
120+
require.NotNil(t, rt)
121+
middlewareName, ok := mw.(httpclient.MiddlewareName)
122+
require.True(t, ok)
123+
require.Equal(t, SigV4MiddlewareName, middlewareName.MiddlewareName())
124+
125+
req, err := http.NewRequest(http.MethodGet, "http://", nil)
126+
require.NoError(t, err)
127+
// response is nil
128+
// nolint:bodyclose
129+
res, err := rt.RoundTrip(req)
130+
require.Error(t, err)
131+
require.Nil(t, res)
132+
require.Empty(t, ctx.callChain)
133+
})
134+
}

0 commit comments

Comments
 (0)