Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions flagd/pkg/service/flag-sync/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
"github.com/open-feature/flagd/core/pkg/model"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"buf.build/gen/go/open-feature/flagd/grpc/go/flagd/sync/v1/syncv1grpc"
syncv1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/sync/v1"
"github.com/open-feature/flagd/core/pkg/logger"
"github.com/open-feature/flagd/core/pkg/store"
flagdService "github.com/open-feature/flagd/flagd/pkg/service"
"google.golang.org/protobuf/types/known/structpb"
)

Expand All @@ -32,7 +34,7 @@

func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.FlagSyncService_SyncFlagsServer) error {
watcher := make(chan store.FlagQueryResult, 1)
selectorExpression := req.GetSelector()
selectorExpression := s.getSelectorExpression(server.Context(), req)
selector := store.NewSelector(selectorExpression)
ctx := server.Context()

Expand Down Expand Up @@ -85,6 +87,39 @@
}
}

// getSelectorExpression extracts the selector expression from the request.
// It first checks the Flagd-Selector header (metadata), then falls back to the request body selector.
// A deprecation warning is logged when the request body selector is used.
//
// The req parameter accepts *syncv1.SyncFlagsRequest or *syncv1.FetchAllFlagsRequest.
// Using interface{} here is intentional as both protobuf-generated types implement GetSelector()
// but do not share a common interface.
func (s syncHandler) getSelectorExpression(ctx context.Context, req interface{}) string {
// Try to get selector from metadata (header)
if md, ok := metadata.FromIncomingContext(ctx); ok {
if values := md.Get(flagdService.FLAGD_SELECTOR_HEADER); len(values) > 0 {
return values[0]
}
}

// Fall back to request body selector for backward compatibility
var bodySelector string
switch r := req.(type) {
case *syncv1.SyncFlagsRequest:
bodySelector = r.GetSelector()
case *syncv1.FetchAllFlagsRequest:
bodySelector = r.GetSelector()
}

// Log deprecation warning if using request body selector
if bodySelector != "" {
s.log.Warn("Using selector from request body is deprecated. Please use the 'Flagd-Selector' header instead. " +
"Request body selector support will be removed in a future major version.")
}

return bodySelector
}

func (s syncHandler) convertMap(flags []model.Flag) map[string]model.Flag {
flagMap := make(map[string]model.Flag, len(flags))
for _, flag := range flags {
Expand All @@ -96,7 +131,7 @@
func (s syncHandler) FetchAllFlags(ctx context.Context, req *syncv1.FetchAllFlagsRequest) (
*syncv1.FetchAllFlagsResponse, error,
) {
selectorExpression := req.GetSelector()
selectorExpression := s.getSelectorExpression(ctx, req)
selector := store.NewSelector(selectorExpression)
flags, _, err := s.store.GetAll(ctx, &selector)
if err != nil {
Expand All @@ -117,8 +152,8 @@

// Deprecated - GetMetadata is deprecated and will be removed in a future release.
// Use the sync_context field in syncv1.SyncFlagsResponse, providing same info.
func (s syncHandler) GetMetadata(_ context.Context, _ *syncv1.GetMetadataRequest) (

Check failure on line 155 in flagd/pkg/service/flag-sync/handler.go

View workflow job for this annotation

GitHub Actions / lint

SA1019: syncv1.GetMetadataRequest is deprecated: Marked as deprecated in flagd/sync/v1/sync.proto. (staticcheck)

Check failure on line 155 in flagd/pkg/service/flag-sync/handler.go

View workflow job for this annotation

GitHub Actions / lint

SA1019: syncv1.GetMetadataRequest is deprecated: Marked as deprecated in flagd/sync/v1/sync.proto. (staticcheck)
*syncv1.GetMetadataResponse, error,

Check failure on line 156 in flagd/pkg/service/flag-sync/handler.go

View workflow job for this annotation

GitHub Actions / lint

SA1019: syncv1.GetMetadataResponse is deprecated: Marked as deprecated in flagd/sync/v1/sync.proto. (staticcheck)

Check failure on line 156 in flagd/pkg/service/flag-sync/handler.go

View workflow job for this annotation

GitHub Actions / lint

SA1019: syncv1.GetMetadataResponse is deprecated: Marked as deprecated in flagd/sync/v1/sync.proto. (staticcheck)
) {
if s.disableSyncMetadata {
return nil, status.Error(codes.Unimplemented, "metadata endpoint disabled")
Expand All @@ -134,7 +169,7 @@
return nil, fmt.Errorf("error constructing metadata response")
}

return &syncv1.GetMetadataResponse{

Check failure on line 172 in flagd/pkg/service/flag-sync/handler.go

View workflow job for this annotation

GitHub Actions / lint

SA1019: syncv1.GetMetadataResponse is deprecated: Marked as deprecated in flagd/sync/v1/sync.proto. (staticcheck)

Check failure on line 172 in flagd/pkg/service/flag-sync/handler.go

View workflow job for this annotation

GitHub Actions / lint

SA1019: syncv1.GetMetadataResponse is deprecated: Marked as deprecated in flagd/sync/v1/sync.proto. (staticcheck)
Metadata: metadata,
},
nil
Expand Down
222 changes: 222 additions & 0 deletions flagd/pkg/service/flag-sync/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@ import (
syncv1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/sync/v1"
"github.com/open-feature/flagd/core/pkg/logger"
"github.com/open-feature/flagd/core/pkg/store"
flagdService "github.com/open-feature/flagd/flagd/pkg/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
"google.golang.org/grpc/metadata"
)

func TestSyncHandler_SyncFlags(t *testing.T) {
Expand Down Expand Up @@ -128,3 +133,220 @@ func (m *mockSyncFlagsServer) GetLastResponse() *syncv1.SyncFlagsResponse {
defer m.mu.Unlock()
return m.lastResp
}

// TestSyncHandler_SelectorFromHeader tests that the selector is correctly extracted from the header
func TestSyncHandler_SelectorFromHeader(t *testing.T) {
flagStore, err := store.NewStore(logger.NewLogger(nil, false), []string{})
require.NoError(t, err)

// Create a logger with observer to capture log messages
observedZapCore, observedLogs := observer.New(zapcore.WarnLevel)
observedLogger := zap.New(observedZapCore)
log := logger.NewLogger(observedLogger, false)

handler := syncHandler{
store: flagStore,
log: log,
contextValues: map[string]any{},
}

// Create context with metadata containing the selector header
md := metadata.New(map[string]string{
flagdService.FLAGD_SELECTOR_HEADER: "source:my-source",
})
ctx := metadata.NewIncomingContext(context.Background(), md)

// Test with SyncFlags
stream := &mockSyncFlagsServer{
ctx: ctx,
mu: sync.Mutex{},
respReady: make(chan struct{}, 1),
}

go func() {
// Use empty request body selector to verify header is used
err := handler.SyncFlags(&syncv1.SyncFlagsRequest{Selector: ""}, stream)
assert.NoError(t, err)
}()

select {
case <-stream.respReady:
// Verify no deprecation warning was logged
logs := observedLogs.All()
for _, log := range logs {
assert.NotContains(t, log.Message, "deprecated", "Should not log deprecation warning when using header")
}
Comment on lines +175 to +178
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic to verify that no deprecation warning was logged is duplicated across TestSyncHandler_SelectorHeaderTakesPrecedence and TestSyncHandler_FetchAllFlags_SelectorFromHeader. Consider extracting it into a shared helper function to improve code reuse and maintainability.

For example:

func assertNoDeprecationWarning(t *testing.T, logs *observer.ObservedLogs) {
	t.Helper()
	for _, entry := range logs.All() {
		assert.NotContains(t, entry.Message, "deprecated", "Should not log deprecation warning")
	}
}

You could then call assertNoDeprecationWarning(t, observedLogs) in all relevant tests.

case <-time.After(time.Second):
t.Fatal("timeout waiting for response")
}
}

// TestSyncHandler_SelectorFromRequestBody tests backward compatibility with request body selector
func TestSyncHandler_SelectorFromRequestBody(t *testing.T) {
flagStore, err := store.NewStore(logger.NewLogger(nil, false), []string{})
require.NoError(t, err)

// Create a logger with observer to capture log messages
observedZapCore, observedLogs := observer.New(zapcore.WarnLevel)
observedLogger := zap.New(observedZapCore)
log := logger.NewLogger(observedLogger, false)

handler := syncHandler{
store: flagStore,
log: log,
contextValues: map[string]any{},
}

// Create context without metadata (no header)
ctx := context.Background()

// Test with SyncFlags
stream := &mockSyncFlagsServer{
ctx: ctx,
mu: sync.Mutex{},
respReady: make(chan struct{}, 1),
}

go func() {
// Use request body selector
err := handler.SyncFlags(&syncv1.SyncFlagsRequest{Selector: "source:legacy-source"}, stream)
assert.NoError(t, err)
}()

select {
case <-stream.respReady:
// Verify deprecation warning was logged
logs := observedLogs.All()
require.Greater(t, len(logs), 0, "Expected at least one log entry")
found := false
for _, log := range logs {
if log.Level == zapcore.WarnLevel {
assert.Contains(t, log.Message, "deprecated", "Should log deprecation warning when using request body selector")
assert.Contains(t, log.Message, "Flagd-Selector", "Deprecation message should mention the header name")
found = true
break
}
}
assert.True(t, found, "Expected to find deprecation warning in logs")
Comment on lines +219 to +230
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic to verify the deprecation warning is duplicated in TestSyncHandler_FetchAllFlags_SelectorFromRequestBody. To improve maintainability and reduce code duplication, consider extracting this block into a shared helper function.

For example:

func assertDeprecationWarning(t *testing.T, logs *observer.ObservedLogs) {
	t.Helper()
	
	entries := logs.All()
	require.NotEmpty(t, entries, "Expected at least one log entry")

	found := false
	for _, entry := range entries {
		if entry.Level == zapcore.WarnLevel {
			assert.Contains(t, entry.Message, "deprecated")
			assert.Contains(t, entry.Message, "Flagd-Selector")
			found = true
			break
		}
	}
	assert.True(t, found, "Expected to find deprecation warning in logs")
}

You could then call assertDeprecationWarning(t, observedLogs) here and in the other test.

case <-time.After(time.Second):
t.Fatal("timeout waiting for response")
}
}

// TestSyncHandler_SelectorHeaderTakesPrecedence tests that header takes precedence over request body
func TestSyncHandler_SelectorHeaderTakesPrecedence(t *testing.T) {
flagStore, err := store.NewStore(logger.NewLogger(nil, false), []string{})
require.NoError(t, err)

// Create a logger with observer to capture log messages
observedZapCore, observedLogs := observer.New(zapcore.WarnLevel)
observedLogger := zap.New(observedZapCore)
log := logger.NewLogger(observedLogger, false)

handler := syncHandler{
store: flagStore,
log: log,
contextValues: map[string]any{},
}

// Create context with metadata containing the selector header
md := metadata.New(map[string]string{
flagdService.FLAGD_SELECTOR_HEADER: "source:header-source",
})
ctx := metadata.NewIncomingContext(context.Background(), md)

// Test with SyncFlags
stream := &mockSyncFlagsServer{
ctx: ctx,
mu: sync.Mutex{},
respReady: make(chan struct{}, 1),
}

go func() {
// Provide both header and request body selector
err := handler.SyncFlags(&syncv1.SyncFlagsRequest{Selector: "source:body-source"}, stream)
assert.NoError(t, err)
}()

select {
case <-stream.respReady:
// Verify no deprecation warning was logged (header was used)
logs := observedLogs.All()
for _, log := range logs {
assert.NotContains(t, log.Message, "deprecated", "Should not log deprecation warning when header is present")
}
case <-time.After(time.Second):
t.Fatal("timeout waiting for response")
}
}

// TestSyncHandler_FetchAllFlags_SelectorFromHeader tests FetchAllFlags with header selector
func TestSyncHandler_FetchAllFlags_SelectorFromHeader(t *testing.T) {
flagStore, err := store.NewStore(logger.NewLogger(nil, false), []string{})
require.NoError(t, err)

// Create a logger with observer to capture log messages
observedZapCore, observedLogs := observer.New(zapcore.WarnLevel)
observedLogger := zap.New(observedZapCore)
log := logger.NewLogger(observedLogger, false)

handler := syncHandler{
store: flagStore,
log: log,
contextValues: map[string]any{},
}

// Create context with metadata containing the selector header
md := metadata.New(map[string]string{
flagdService.FLAGD_SELECTOR_HEADER: "source:my-source",
})
ctx := metadata.NewIncomingContext(context.Background(), md)

// Call FetchAllFlags with empty request body selector
_, err = handler.FetchAllFlags(ctx, &syncv1.FetchAllFlagsRequest{Selector: ""})
require.NoError(t, err)

// Verify no deprecation warning was logged
logs := observedLogs.All()
for _, log := range logs {
assert.NotContains(t, log.Message, "deprecated", "Should not log deprecation warning when using header")
}
}

// TestSyncHandler_FetchAllFlags_SelectorFromRequestBody tests FetchAllFlags with request body selector
func TestSyncHandler_FetchAllFlags_SelectorFromRequestBody(t *testing.T) {
flagStore, err := store.NewStore(logger.NewLogger(nil, false), []string{})
require.NoError(t, err)

// Create a logger with observer to capture log messages
observedZapCore, observedLogs := observer.New(zapcore.WarnLevel)
observedLogger := zap.New(observedZapCore)
log := logger.NewLogger(observedLogger, false)

handler := syncHandler{
store: flagStore,
log: log,
contextValues: map[string]any{},
}

// Create context without metadata (no header)
ctx := context.Background()

// Call FetchAllFlags with request body selector
_, err = handler.FetchAllFlags(ctx, &syncv1.FetchAllFlagsRequest{Selector: "source:legacy-source"})
require.NoError(t, err)

// Verify deprecation warning was logged
logs := observedLogs.All()
require.Greater(t, len(logs), 0, "Expected at least one log entry")
found := false
for _, log := range logs {
if log.Level == zapcore.WarnLevel {
assert.Contains(t, log.Message, "deprecated", "Should log deprecation warning when using request body selector")
assert.Contains(t, log.Message, "Flagd-Selector", "Deprecation message should mention the header name")
found = true
break
}
}
assert.True(t, found, "Expected to find deprecation warning in logs")
}
Loading