From 550f29bae4789e12442b93a9f923cf9cd10de5fc Mon Sep 17 00:00:00 2001 From: Joseph Anttila Hall Date: Wed, 10 Apr 2024 14:59:53 -0700 Subject: [PATCH] Refactor BackendManager / BackendStorage. This also fixes https://github.com/kubernetes-sigs/apiserver-network-proxy/issues/294 --- pkg/server/backend_manager.go | 363 ++++------------ pkg/server/backend_manager_test.go | 435 ++++++++------------ pkg/server/default_route_backend_manager.go | 59 --- pkg/server/desthost_backend_manager.go | 86 ---- pkg/server/readiness_manager.go | 9 +- pkg/server/server.go | 66 +-- pkg/server/server_test.go | 231 +---------- pkg/server/storage.go | 264 ++++++++++++ pkg/server/storage_test.go | 314 ++++++++++++++ tests/framework/proxy_server.go | 10 +- tests/proxy_test.go | 11 +- 11 files changed, 865 insertions(+), 983 deletions(-) delete mode 100644 pkg/server/default_route_backend_manager.go delete mode 100644 pkg/server/desthost_backend_manager.go create mode 100644 pkg/server/storage.go create mode 100644 pkg/server/storage_test.go diff --git a/pkg/server/backend_manager.go b/pkg/server/backend_manager.go index c19485642..6f9813305 100644 --- a/pkg/server/backend_manager.go +++ b/pkg/server/backend_manager.go @@ -17,23 +17,11 @@ limitations under the License. package server import ( - "context" "fmt" - "io" - "math/rand" "slices" "strings" - "sync" - "time" - "google.golang.org/grpc/metadata" - "k8s.io/klog/v2" - - commonmetrics "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/common/metrics" - client "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/proto/client" "sigs.k8s.io/apiserver-network-proxy/pkg/server/metrics" - "sigs.k8s.io/apiserver-network-proxy/proto/agent" - "sigs.k8s.io/apiserver-network-proxy/proto/header" ) type ProxyStrategy int @@ -98,284 +86,104 @@ func ParseProxyStrategies(proxyStrategies string) ([]ProxyStrategy, error) { return result, nil } -// Backend abstracts a connected Konnectivity agent. -// -// In the only currently supported case (gRPC), it wraps an -// agent.AgentService_ConnectServer, provides synchronization and -// emits common stream metrics. -type Backend interface { - Send(p *client.Packet) error - Recv() (*client.Packet, error) - Context() context.Context - GetAgentID() string - GetAgentIdentifiers() header.Identifiers -} - -var _ Backend = &backend{} - -type backend struct { - sendLock sync.Mutex - recvLock sync.Mutex - conn agent.AgentService_ConnectServer - - // cached from conn.Context() - id string - idents header.Identifiers -} - -func (b *backend) Send(p *client.Packet) error { - b.sendLock.Lock() - defer b.sendLock.Unlock() - - const segment = commonmetrics.SegmentToAgent - metrics.Metrics.ObservePacket(segment, p.Type) - err := b.conn.Send(p) - if err != nil && err != io.EOF { - metrics.Metrics.ObserveStreamError(segment, err, p.Type) - } - return err -} - -func (b *backend) Recv() (*client.Packet, error) { - b.recvLock.Lock() - defer b.recvLock.Unlock() - - const segment = commonmetrics.SegmentFromAgent - pkt, err := b.conn.Recv() - if err != nil { - if err != io.EOF { - metrics.Metrics.ObserveStreamErrorNoPacket(segment, err) - } - return nil, err - } - metrics.Metrics.ObservePacket(segment, pkt.Type) - return pkt, nil -} - -func (b *backend) Context() context.Context { - // TODO: does Context require lock protection? - return b.conn.Context() -} - -func (b *backend) GetAgentID() string { - return b.id -} - -func (b *backend) GetAgentIdentifiers() header.Identifiers { - return b.idents -} - -func getAgentID(stream agent.AgentService_ConnectServer) (string, error) { - md, ok := metadata.FromIncomingContext(stream.Context()) - if !ok { - return "", fmt.Errorf("failed to get context") - } - agentIDs := md.Get(header.AgentID) - if len(agentIDs) != 1 { - return "", fmt.Errorf("expected one agent ID in the context, got %v", agentIDs) - } - return agentIDs[0], nil -} - -func getAgentIdentifiers(conn agent.AgentService_ConnectServer) (header.Identifiers, error) { - var agentIdentifiers header.Identifiers - md, ok := metadata.FromIncomingContext(conn.Context()) - if !ok { - return agentIdentifiers, fmt.Errorf("failed to get metadata from context") - } - agentIdent := md.Get(header.AgentIdentifiers) - if len(agentIdent) > 1 { - return agentIdentifiers, fmt.Errorf("expected at most one set of agent identifiers in the context, got %v", agentIdent) - } - if len(agentIdent) == 0 { - return agentIdentifiers, nil - } - - return header.GenAgentIdentifiers(agentIdent[0]) -} - -func NewBackend(conn agent.AgentService_ConnectServer) (Backend, error) { - agentID, err := getAgentID(conn) - if err != nil { - return nil, err - } - agentIdentifiers, err := getAgentIdentifiers(conn) - if err != nil { - return nil, err - } - return &backend{conn: conn, id: agentID, idents: agentIdentifiers}, nil -} - -// BackendStorage is an interface to manage the storage of the backend -// connections, i.e., get, add and remove -type BackendStorage interface { - // addBackend adds a backend. - addBackend(identifier string, idType header.IdentifierType, backend Backend) - // removeBackend removes a backend. - removeBackend(identifier string, idType header.IdentifierType, backend Backend) - // NumBackends returns the number of backends. - NumBackends() int -} - // BackendManager is an interface to manage backend connections, i.e., // connection to the proxy agents. type BackendManager interface { - // Backend returns a single backend. - // WARNING: the context passed to the function should be a session-scoped - // context instead of a request-scoped context, as the backend manager will - // pick a backend for every tunnel session and each tunnel session may - // contains multiple requests. - Backend(ctx context.Context) (Backend, error) + // Backend returns a backend connection according to proxy strategies. + Backend(addr string) (Backend, error) // AddBackend adds a backend. AddBackend(backend Backend) // RemoveBackend adds a backend. RemoveBackend(backend Backend) - BackendStorage + // NumBackends returns the number of backends. + NumBackends() int ReadinessManager } -var _ BackendManager = &DefaultBackendManager{} - -// DefaultBackendManager is the default backend manager. type DefaultBackendManager struct { - *DefaultBackendStorage -} + proxyStrategies []ProxyStrategy -func (dbm *DefaultBackendManager) Backend(_ context.Context) (Backend, error) { - klog.V(5).InfoS("Get a random backend through the DefaultBackendManager") - return dbm.DefaultBackendStorage.GetRandomBackend() -} - -func (dbm *DefaultBackendManager) AddBackend(backend Backend) { - agentID := backend.GetAgentID() - klog.V(5).InfoS("Add the agent to DefaultBackendManager", "agentID", agentID) - dbm.addBackend(agentID, header.UID, backend) + // All backends by agentID. + all BackendStorage + // All backends by host identifier(s). Only used with ProxyStrategyDestHost. + byHost BackendStorage + // All default-route backends, by agentID. Only used with ProxyStrategyDefaultRoute. + byDefaultRoute BackendStorage } -func (dbm *DefaultBackendManager) RemoveBackend(backend Backend) { - agentID := backend.GetAgentID() - klog.V(5).InfoS("Remove the agent from the DefaultBackendManager", "agentID", agentID) - dbm.removeBackend(agentID, header.UID, backend) -} - -// DefaultBackendStorage is the default backend storage. -type DefaultBackendStorage struct { - mu sync.RWMutex //protects the following - // A map between agentID and its grpc connections. - // For a given agent, ProxyServer prefers backends[agentID][0] to send - // traffic, because backends[agentID][1:] are more likely to be closed - // by the agent to deduplicate connections to the same server. - // - // TODO: fix documentation. This is not always agentID, e.g. in - // the case of DestHostBackendManager. - backends map[string][]Backend - // agentID is tracked in this slice to enable randomly picking an - // agentID in the Backend() method. There is no reliable way to - // randomly pick a key from a map (in this case, the backends) in - // Golang. - agentIDs []string - random *rand.Rand - // idTypes contains the valid identifier types for this - // DefaultBackendStorage. The DefaultBackendStorage may only tolerate certain - // types of identifiers when associating to a specific BackendManager, - // e.g., when associating to the DestHostBackendManager, it can only use the - // identifiers of types, IPv4, IPv6 and Host. - idTypes []header.IdentifierType -} - -// NewDefaultBackendManager returns a DefaultBackendManager. -func NewDefaultBackendManager() *DefaultBackendManager { - return &DefaultBackendManager{ - DefaultBackendStorage: NewDefaultBackendStorage( - []header.IdentifierType{header.UID})} -} +var _ BackendManager = &DefaultBackendManager{} -// NewDefaultBackendStorage returns a DefaultBackendStorage -func NewDefaultBackendStorage(idTypes []header.IdentifierType) *DefaultBackendStorage { - // Set an explicit value, so that the metric is emitted even when - // no agent ever successfully connects. +// NewDefaultBackendManager returns a DefaultBackendStorage +func NewDefaultBackendManager(proxyStrategies []ProxyStrategy) *DefaultBackendManager { metrics.Metrics.SetBackendCount(0) - return &DefaultBackendStorage{ - backends: make(map[string][]Backend), - random: rand.New(rand.NewSource(time.Now().UnixNano())), - idTypes: idTypes, - } /* #nosec G404 */ + return &DefaultBackendManager{ + proxyStrategies: proxyStrategies, + all: NewDefaultBackendStorage(), + byHost: NewDefaultBackendStorage(), + byDefaultRoute: NewDefaultBackendStorage(), + } +} + +func (s *DefaultBackendManager) Backend(addr string) (Backend, error) { + for _, strategy := range s.proxyStrategies { + var b Backend + var e error + e = &ErrNotFound{} + switch strategy { + case ProxyStrategyDefault: + b, e = s.all.RandomBackend() + case ProxyStrategyDestHost: + b, e = s.byHost.Backend(addr) + case ProxyStrategyDefaultRoute: + b, e = s.byDefaultRoute.RandomBackend() + } + if e == nil { + return b, nil + } + } + return nil, &ErrNotFound{} } -func containIDType(idTypes []header.IdentifierType, idType header.IdentifierType) bool { - return slices.Contains(idTypes, idType) +func hostIdentifiers(backend Backend) []string { + hosts := []string{} + hosts = append(hosts, backend.GetAgentIdentifiers().IPv4...) + hosts = append(hosts, backend.GetAgentIdentifiers().IPv6...) + hosts = append(hosts, backend.GetAgentIdentifiers().Host...) + return hosts } -// addBackend adds a backend. -func (s *DefaultBackendStorage) addBackend(identifier string, idType header.IdentifierType, backend Backend) { - if !containIDType(s.idTypes, idType) { - klog.V(4).InfoS("fail to add backend", "backend", identifier, "error", &ErrWrongIDType{idType, s.idTypes}) - return - } - klog.V(5).InfoS("Register backend for agent", "agentID", identifier) - s.mu.Lock() - defer s.mu.Unlock() - _, ok := s.backends[identifier] - if ok { - for _, b := range s.backends[identifier] { - if b == backend { - klog.V(1).InfoS("This should not happen. Adding existing backend for agent", "agentID", identifier) - return - } +func (s *DefaultBackendManager) AddBackend(backend Backend) { + agentID := backend.GetAgentID() + count := s.all.AddBackend([]string{agentID}, backend) + if slices.Contains(s.proxyStrategies, ProxyStrategyDestHost) { + idents := hostIdentifiers(backend) + s.byHost.AddBackend(idents, backend) + } + if slices.Contains(s.proxyStrategies, ProxyStrategyDefaultRoute) { + if backend.GetAgentIdentifiers().DefaultRoute { + s.byDefaultRoute.AddBackend([]string{agentID}, backend) } - s.backends[identifier] = append(s.backends[identifier], backend) - return } - s.backends[identifier] = []Backend{backend} - metrics.Metrics.SetBackendCount(len(s.backends)) - s.agentIDs = append(s.agentIDs, identifier) + metrics.Metrics.SetBackendCount(count) } -// removeBackend removes a backend. -func (s *DefaultBackendStorage) removeBackend(identifier string, idType header.IdentifierType, backend Backend) { - if !containIDType(s.idTypes, idType) { - klog.ErrorS(&ErrWrongIDType{idType, s.idTypes}, "fail to remove backend") - return - } - klog.V(5).InfoS("Remove connection for agent", "agentID", identifier) - s.mu.Lock() - defer s.mu.Unlock() - backends, ok := s.backends[identifier] - if !ok { - klog.V(1).InfoS("Cannot find agent in backends", "identifier", identifier) - return - } - var found bool - for i, b := range backends { - if b == backend { - s.backends[identifier] = append(s.backends[identifier][:i], s.backends[identifier][i+1:]...) - if i == 0 && len(s.backends[identifier]) != 0 { - klog.V(1).InfoS("This should not happen. Removed connection that is not the first connection", "agentID", identifier) - } - found = true - } - } - if len(s.backends[identifier]) == 0 { - delete(s.backends, identifier) - for i := range s.agentIDs { - if s.agentIDs[i] == identifier { - s.agentIDs[i] = s.agentIDs[len(s.agentIDs)-1] - s.agentIDs = s.agentIDs[:len(s.agentIDs)-1] - break - } +func (s *DefaultBackendManager) RemoveBackend(backend Backend) { + agentID := backend.GetAgentID() + count := s.all.RemoveBackend([]string{agentID}, backend) + if slices.Contains(s.proxyStrategies, ProxyStrategyDestHost) { + idents := hostIdentifiers(backend) + s.byHost.RemoveBackend(idents, backend) + } + if slices.Contains(s.proxyStrategies, ProxyStrategyDefaultRoute) { + if backend.GetAgentIdentifiers().DefaultRoute { + s.byDefaultRoute.RemoveBackend([]string{agentID}, backend) } } - if !found { - klog.V(1).InfoS("Could not find connection matching identifier to remove", "agentID", identifier, "idType", idType) - } - metrics.Metrics.SetBackendCount(len(s.backends)) + metrics.Metrics.SetBackendCount(count) } -// NumBackends resturns the number of available backends -func (s *DefaultBackendStorage) NumBackends() int { - s.mu.RLock() - defer s.mu.RUnlock() - return len(s.backends) +func (s *DefaultBackendManager) NumBackends() int { + return s.all.NumKeys() } // ErrNotFound indicates that no backend can be found. @@ -386,32 +194,9 @@ func (e *ErrNotFound) Error() string { return "No agent available" } -type ErrWrongIDType struct { - got header.IdentifierType - expect []header.IdentifierType -} - -func (e *ErrWrongIDType) Error() string { - return fmt.Sprintf("incorrect id type: got %s, expect %s", e.got, e.expect) -} - -func ignoreNotFound(err error) error { - if _, ok := err.(*ErrNotFound); ok { - return nil - } - return err -} - -// GetRandomBackend returns a random backend connection from all connected agents. -func (s *DefaultBackendStorage) GetRandomBackend() (Backend, error) { - s.mu.Lock() - defer s.mu.Unlock() - if len(s.backends) == 0 { - return nil, &ErrNotFound{} +func (s *DefaultBackendManager) Ready() (bool, string) { + if s.NumBackends() == 0 { + return false, "no connection to any proxy agent" } - agentID := s.agentIDs[s.random.Intn(len(s.agentIDs))] - klog.V(5).InfoS("Pick agent as backend", "agentID", agentID) - // always return the first connection to an agent, because the agent - // will close later connections if there are multiple. - return s.backends[agentID][0], nil + return true, "" } diff --git a/pkg/server/backend_manager_test.go b/pkg/server/backend_manager_test.go index e786f0d24..6e489ccb5 100644 --- a/pkg/server/backend_manager_test.go +++ b/pkg/server/backend_manager_test.go @@ -43,346 +43,273 @@ func mockAgentConn(ctrl *gomock.Controller, agentID string, agentIdentifiers []s return agentConn } -func TestNewBackend(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - testCases := []struct { - desc string - ids []string - idents []string - wantErr bool - }{ - { - desc: "no agentID", - wantErr: true, - }, - { - desc: "multiple agentID", - ids: []string{"agent-id", "agent-id"}, - wantErr: true, - }, - { - desc: "multiple identifiers", - ids: []string{"agent-id"}, - idents: []string{"host=localhost", "host=localhost"}, - wantErr: true, - }, - { - desc: "invalid identifiers", - ids: []string{"agent-id"}, - idents: []string{";"}, - wantErr: true, - }, - { - desc: "success", - ids: []string{"agent-id"}, - }, - { - desc: "success with identifiers", - ids: []string{"agent-id"}, - idents: []string{"host=localhost&host=node1.mydomain.com&cidr=127.0.0.1/16&ipv4=1.2.3.4&ipv4=5.6.7.8&ipv6=:::::&default-route=true"}, - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - - agentConn := agentmock.NewMockAgentService_ConnectServer(ctrl) - agentConnMD := metadata.MD{ - ":authority": []string{"127.0.0.1:8091"}, - "agentid": tc.ids, - "agentidentifiers": tc.idents, - "content-type": []string{"application/grpc"}, - "user-agent": []string{"grpc-go/1.42.0"}, - } - agentConnCtx := metadata.NewIncomingContext(context.Background(), agentConnMD) - agentConn.EXPECT().Context().Return(agentConnCtx).AnyTimes() - - _, err := NewBackend(agentConn) - if gotErr := (err != nil); gotErr != tc.wantErr { - t.Errorf("NewBackend got err %q; wantErr = %t", err, tc.wantErr) - } - }) - } -} - -func TestDefaultBackendManager_AddRemoveBackends(t *testing.T) { +func TestBackendManagerAddRemoveBackends(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) - backend12, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{})) - backend22, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{})) backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{})) - p := NewDefaultBackendManager() + p := NewDefaultBackendManager([]ProxyStrategy{ProxyStrategyDefault}) p.AddBackend(backend1) - p.RemoveBackend(backend1) - expectedBackends := make(map[string][]Backend) - expectedAgentIDs := []string{} - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) - } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) - } - p = NewDefaultBackendManager() - p.AddBackend(backend1) - p.AddBackend(backend12) - // Adding the same connection again should be a no-op. - p.AddBackend(backend12) + input := "127.0.0.1" + if got, _ := p.Backend(input); got != backend1 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend1) + } p.AddBackend(backend2) - p.AddBackend(backend22) p.AddBackend(backend3) - p.RemoveBackend(backend22) - p.RemoveBackend(backend2) p.RemoveBackend(backend1) - expectedBackends = map[string][]Backend{ - "agent1": {backend12}, - "agent3": {backend3}, - } - expectedAgentIDs = []string{"agent1", "agent3"} - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + p.RemoveBackend(backend2) + + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + + p.RemoveBackend(backend3) + + if got, _ := p.Backend(input); got != nil { + t.Errorf("Backend(%v) = %v, want nil", input, got) } } -func TestDefaultRouteBackendManager_AddRemoveBackends(t *testing.T) { +func TestBackendManagerAddRemoveBackends_DefaultRouteStrategy(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"default-route=true"})) - backend12, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"default-route=true"})) - backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"default-route=true"})) - backend22, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"default-route=true"})) + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"default-route=false"})) backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"default-route=true"})) - p := NewDefaultRouteBackendManager() + p := NewDefaultBackendManager([]ProxyStrategy{ProxyStrategyDefaultRoute}) p.AddBackend(backend1) - p.RemoveBackend(backend1) - expectedBackends := make(map[string][]Backend) - expectedAgentIDs := []string{} - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) - } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + p.AddBackend(backend2) + + input := "127.0.0.1" + if got, _ := p.Backend(input); got != nil { + t.Errorf("Backend(%v) = %v, want nil", input, got) } - p = NewDefaultRouteBackendManager() - p.AddBackend(backend1) - p.AddBackend(backend12) - // Adding the same connection again should be a no-op. - p.AddBackend(backend12) - p.AddBackend(backend2) - p.AddBackend(backend22) p.AddBackend(backend3) - p.RemoveBackend(backend22) - p.RemoveBackend(backend2) - p.RemoveBackend(backend1) - expectedBackends = map[string][]Backend{ - "agent1": {backend12}, - "agent3": {backend3}, + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) } - expectedAgentIDs = []string{"agent1", "agent3"} - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) - } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + p.RemoveBackend(backend3) + + if got, _ := p.Backend(input); got != nil { + t.Errorf("Backend(%v) = %v, want nil", input, got) } } -func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) { +func TestBackendManagerAddRemoveBackends_DestHostStrategy(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"host=localhost&host=node1.mydomain.com&ipv4=1.2.3.4&ipv6=9878::7675:1292:9183:7562"})) - // backend2 has no desthost relevant identifiers backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"default-route=true"})) - // TODO: if backend3 is given conflicting identifiers with backend1, the wrong thing happens in RemoveBackend. backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"host=node2.mydomain.com&ipv4=5.6.7.8&ipv6=::"})) - p := NewDestHostBackendManager() + p := NewDefaultBackendManager([]ProxyStrategy{ProxyStrategyDestHost}) p.AddBackend(backend1) - p.RemoveBackend(backend1) - expectedBackends := make(map[string][]Backend) - expectedAgentIDs := []string{} - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + p.AddBackend(backend2) + p.AddBackend(backend3) + + input := "127.0.0.1" + if got, _ := p.Backend(input); got != nil { + t.Errorf("Backend(%v) = %v, want nil", input, got) } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + input = "localhost" + if got, _ := p.Backend(input); got != backend1 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend1) } - - p = NewDestHostBackendManager() - p.AddBackend(backend1) - - expectedBackends = map[string][]Backend{ - "localhost": {backend1}, - "1.2.3.4": {backend1}, - "9878::7675:1292:9183:7562": {backend1}, - "node1.mydomain.com": {backend1}, + input = "node1.mydomain.com" + if got, _ := p.Backend(input); got != backend1 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend1) } - expectedAgentIDs = []string{ - "1.2.3.4", - "9878::7675:1292:9183:7562", - "localhost", - "node1.mydomain.com", + input = "1.2.3.4" + if got, _ := p.Backend(input); got != backend1 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend1) } - - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + input = "9878::7675:1292:9183:7562" + if got, _ := p.Backend(input); got != backend1 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend1) + } + input = "node2.mydomain.com" + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) + } + input = "5.6.7.8" + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + input = "::" + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) } + p.RemoveBackend(backend1) + p.RemoveBackend(backend2) + p.RemoveBackend(backend3) + + input = "127.0.0.1" + if got, _ := p.Backend(input); got != nil { + t.Errorf("Backend(%v) = %v, want nil", input, got) + } +} + +func TestBackendManagerAddRemoveBackends_DestHostWithDefault(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"host=localhost&host=node1.mydomain.com&ipv4=1.2.3.4&ipv6=9878::7675:1292:9183:7562"})) + backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"default-route=false"})) + backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"host=node2.mydomain.com&ipv4=5.6.7.8&ipv6=::"})) + + p := NewDefaultBackendManager([]ProxyStrategy{ProxyStrategyDestHost, ProxyStrategyDefault}) + + p.AddBackend(backend1) p.AddBackend(backend2) p.AddBackend(backend3) - expectedBackends = map[string][]Backend{ - "localhost": {backend1}, - "node1.mydomain.com": {backend1}, - "node2.mydomain.com": {backend3}, - "1.2.3.4": {backend1}, - "5.6.7.8": {backend3}, - "9878::7675:1292:9183:7562": {backend1}, - "::": {backend3}, + input := "127.0.0.1" + if got, _ := p.Backend(input); got == nil { + t.Errorf("expected random fallback, got nil") } - expectedAgentIDs = []string{ - "1.2.3.4", - "9878::7675:1292:9183:7562", - "localhost", - "node1.mydomain.com", - "5.6.7.8", - "::", - "node2.mydomain.com", + input = "localhost" + if got, _ := p.Backend(input); got != backend1 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend1) } - - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + input = "node1.mydomain.com" + if got, _ := p.Backend(input); got != backend1 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend1) } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + input = "1.2.3.4" + if got, _ := p.Backend(input); got != backend1 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend1) } - - p.RemoveBackend(backend2) - p.RemoveBackend(backend1) - - expectedBackends = map[string][]Backend{ - "node2.mydomain.com": {backend3}, - "5.6.7.8": {backend3}, - "::": {backend3}, + input = "9878::7675:1292:9183:7562" + if got, _ := p.Backend(input); got != backend1 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend1) } - expectedAgentIDs = []string{ - "node2.mydomain.com", - "::", - "5.6.7.8", + input = "node2.mydomain.com" + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) } - - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + input = "5.6.7.8" + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + input = "::" + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) } - p.RemoveBackend(backend3) - expectedBackends = map[string][]Backend{} - expectedAgentIDs = []string{} + p.RemoveBackend(backend1) + p.RemoveBackend(backend2) - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + input = "127.0.0.1" + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + + p.RemoveBackend(backend3) + + input = "127.0.0.1" + if got, _ := p.Backend(input); got != nil { + t.Errorf("Backend(%v) = %v, want nil", input, got) } } -func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) { +func TestBackendManagerAddRemoveBackends_DuplicateIdents(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"host=localhost&host=node1.mydomain.com&ipv4=1.2.3.4&ipv6=9878::7675:1292:9183:7562"})) - backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"host=localhost&host=node1.mydomain.com&ipv4=1.2.3.4&ipv6=9878::7675:1292:9183:7562"})) - backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"host=localhost&host=node2.mydomain.com&ipv4=5.6.7.8&ipv6=::"})) + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"host=node1.mydomain.com&ipv4=1.2.3.4&ipv6=9878::7675:1292:9183:7562"})) + backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"host=node1.mydomain.com&ipv4=1.2.3.4&ipv6=9878::7675:1292:9183:7562"})) + backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"default-route=true"})) - p := NewDestHostBackendManager() + p := NewDefaultBackendManager([]ProxyStrategy{ProxyStrategyDestHost, ProxyStrategyDefaultRoute}) p.AddBackend(backend1) p.AddBackend(backend2) p.AddBackend(backend3) - expectedBackends := map[string][]Backend{ - "localhost": {backend1, backend2, backend3}, - "1.2.3.4": {backend1, backend2}, - "5.6.7.8": {backend3}, - "9878::7675:1292:9183:7562": {backend1, backend2}, - "::": {backend3}, - "node1.mydomain.com": {backend1, backend2}, - "node2.mydomain.com": {backend3}, + input := "node1.mydomain.com" + if got, _ := p.Backend(input); got == nil || got == backend3 { + t.Errorf("Backend(%v) = %v, want any other backend", got, backend3) } - expectedAgentIDs := []string{ - "1.2.3.4", - "9878::7675:1292:9183:7562", - "localhost", - "node1.mydomain.com", - "5.6.7.8", - "::", - "node2.mydomain.com", + input = "1.2.3.4" + if got, _ := p.Backend(input); got == nil || got == backend3 { + t.Errorf("Backend(%v) = %v, want any other backend", got, backend3) } - - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + input = "9878::7675:1292:9183:7562" + if got, _ := p.Backend(input); got == nil || got == backend3 { + t.Errorf("Backend(%v) = %v, want any other backend", got, backend3) + } + input = "node2.mydomain.com" + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + input = "5.6.7.8" + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) + } + input = "::" + if got, _ := p.Backend(input); got != backend3 { + t.Errorf("Backend(%v) = %v, want %v", input, got, backend3) } p.RemoveBackend(backend1) + p.RemoveBackend(backend2) p.RemoveBackend(backend3) - expectedBackends = map[string][]Backend{ - "localhost": {backend2}, - "1.2.3.4": {backend2}, - "9878::7675:1292:9183:7562": {backend2}, - "node1.mydomain.com": {backend2}, - } - expectedAgentIDs = []string{ - "1.2.3.4", - "9878::7675:1292:9183:7562", - "localhost", - "node1.mydomain.com", + input = "127.0.0.1" + if got, _ := p.Backend(input); got != nil { + t.Errorf("Backend(%v) = %v, want nil", input, got) } +} + +func TestBackendManagerNumBackends(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"host=node1.mydomain.com&ipv4=1.2.3.4&ipv6=9878::7675:1292:9183:7562"})) + backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{""})) + backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"default-route=true"})) - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + p := NewDefaultBackendManager([]ProxyStrategy{ProxyStrategyDestHost, ProxyStrategyDefaultRoute, ProxyStrategyDefault}) + + p.AddBackend(backend1) + p.AddBackend(backend2) + p.AddBackend(backend3) + if got := p.NumBackends(); got != 3 { + t.Errorf("NumBackends() = %v, want 3", got) } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + wantReady := true + wantMsg := "" + ready, msg := p.Ready() + if ready != wantReady || msg != wantMsg { + t.Errorf("Ready() = %t / %q, want %t / %q", ready, msg, wantReady, wantMsg) } + p.RemoveBackend(backend1) p.RemoveBackend(backend2) - expectedBackends = map[string][]Backend{} - expectedAgentIDs = []string{} + p.RemoveBackend(backend3) - if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + if got := p.NumBackends(); got != 0 { + t.Errorf("NumBackends() = %v, want 0", got) } - if e, a := expectedAgentIDs, p.agentIDs; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, got %v", e, a) + wantReady = false + wantMsg = "no connection to any proxy agent" + ready, msg = p.Ready() + if ready != wantReady || msg != wantMsg { + t.Errorf("Ready() = %t / %q, want %t / %q", ready, msg, wantReady, wantMsg) } } @@ -417,7 +344,7 @@ func TestProxyStrategy(t *testing.T) { } else { got := tc.input.String() if got != tc.want { - t.Errorf("ProxyStrategy.String(): got %v, want %v", got, tc.want) + t.Errorf("ProxyStrategy.String() = %v, want %v", got, tc.want) } } }) @@ -453,9 +380,9 @@ func TestParseProxyStrategy(t *testing.T) { } { t.Run(desc, func(t *testing.T) { got, err := ParseProxyStrategy(tc.input) - assert.Equal(t, tc.wantErr, err, "ParseProxyStrategy(%s): got error %q, want %v", tc.input, err, tc.wantErr) + assert.Equal(t, err, tc.wantErr, "ParseProxyStrategy(%s) = error %q, want %v", tc.input, err, tc.wantErr) if got != tc.want { - t.Errorf("ParseProxyStrategy(%s): got %v, want %v", tc.input, got, tc.want) + t.Errorf("ParseProxyStrategy(%s) = %v, want %v", tc.input, got, tc.want) } }) } @@ -498,9 +425,9 @@ func TestParseProxyStrategies(t *testing.T) { } { t.Run(desc, func(t *testing.T) { got, err := ParseProxyStrategies(tc.input) - assert.Equal(t, tc.wantErr, err, "ParseProxyStrategies(%s): got error %q, want %v", tc.input, err, tc.wantErr) + assert.Equal(t, err, tc.wantErr, "ParseProxyStrategies(%s) = error %q, want %v", tc.input, err, tc.wantErr) if !reflect.DeepEqual(got, tc.want) { - t.Errorf("ParseProxyStrategies(%s): got %v, want %v", tc.input, got, tc.want) + t.Errorf("ParseProxyStrategies(%s) = %v, want %v", tc.input, got, tc.want) } }) } diff --git a/pkg/server/default_route_backend_manager.go b/pkg/server/default_route_backend_manager.go deleted file mode 100644 index 4bd18749d..000000000 --- a/pkg/server/default_route_backend_manager.go +++ /dev/null @@ -1,59 +0,0 @@ -/* -Copyright 2021 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package server - -import ( - "context" - - "k8s.io/klog/v2" - "sigs.k8s.io/apiserver-network-proxy/proto/header" -) - -type DefaultRouteBackendManager struct { - *DefaultBackendStorage -} - -var _ BackendManager = &DefaultRouteBackendManager{} - -func NewDefaultRouteBackendManager() *DefaultRouteBackendManager { - return &DefaultRouteBackendManager{ - DefaultBackendStorage: NewDefaultBackendStorage( - []header.IdentifierType{header.DefaultRoute})} -} - -// Backend tries to get a backend that advertises default route, with random selection. -func (dibm *DefaultRouteBackendManager) Backend(_ context.Context) (Backend, error) { - return dibm.GetRandomBackend() -} - -func (dibm *DefaultRouteBackendManager) AddBackend(backend Backend) { - agentID := backend.GetAgentID() - agentIdentifiers := backend.GetAgentIdentifiers() - if agentIdentifiers.DefaultRoute { - klog.V(5).InfoS("Add the agent to DefaultRouteBackendManager", "agentID", agentID) - dibm.addBackend(agentID, header.DefaultRoute, backend) - } -} - -func (dibm *DefaultRouteBackendManager) RemoveBackend(backend Backend) { - agentID := backend.GetAgentID() - agentIdentifiers := backend.GetAgentIdentifiers() - if agentIdentifiers.DefaultRoute { - klog.V(5).InfoS("Remove the agent from the DefaultRouteBackendManager", "agentID", agentID) - dibm.removeBackend(agentID, header.DefaultRoute, backend) - } -} diff --git a/pkg/server/desthost_backend_manager.go b/pkg/server/desthost_backend_manager.go deleted file mode 100644 index 2857659c5..000000000 --- a/pkg/server/desthost_backend_manager.go +++ /dev/null @@ -1,86 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package server - -import ( - "context" - - "k8s.io/klog/v2" - "sigs.k8s.io/apiserver-network-proxy/proto/header" -) - -type DestHostBackendManager struct { - *DefaultBackendStorage -} - -var _ BackendManager = &DestHostBackendManager{} - -func NewDestHostBackendManager() *DestHostBackendManager { - return &DestHostBackendManager{ - DefaultBackendStorage: NewDefaultBackendStorage( - []header.IdentifierType{header.IPv4, header.IPv6, header.Host})} -} - -func (dibm *DestHostBackendManager) AddBackend(backend Backend) { - agentIdentifiers := backend.GetAgentIdentifiers() - for _, ipv4 := range agentIdentifiers.IPv4 { - klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", ipv4) - dibm.addBackend(ipv4, header.IPv4, backend) - } - for _, ipv6 := range agentIdentifiers.IPv6 { - klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", ipv6) - dibm.addBackend(ipv6, header.IPv6, backend) - } - for _, host := range agentIdentifiers.Host { - klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", host) - dibm.addBackend(host, header.Host, backend) - } -} - -func (dibm *DestHostBackendManager) RemoveBackend(backend Backend) { - agentIdentifiers := backend.GetAgentIdentifiers() - for _, ipv4 := range agentIdentifiers.IPv4 { - klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", ipv4) - dibm.removeBackend(ipv4, header.IPv4, backend) - } - for _, ipv6 := range agentIdentifiers.IPv6 { - klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", ipv6) - dibm.removeBackend(ipv6, header.IPv6, backend) - } - for _, host := range agentIdentifiers.Host { - klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", host) - dibm.removeBackend(host, header.Host, backend) - } -} - -// Backend tries to get a backend associating to the request destination host. -func (dibm *DestHostBackendManager) Backend(ctx context.Context) (Backend, error) { - dibm.mu.RLock() - defer dibm.mu.RUnlock() - if len(dibm.backends) == 0 { - return nil, &ErrNotFound{} - } - destHost := ctx.Value(destHostKey).(string) - if destHost != "" { - bes, exist := dibm.backends[destHost] - if exist && len(bes) > 0 { - klog.V(5).InfoS("Get the backend through the DestHostBackendManager", "destHost", destHost) - return dibm.backends[destHost][0], nil - } - } - return nil, &ErrNotFound{} -} diff --git a/pkg/server/readiness_manager.go b/pkg/server/readiness_manager.go index 6cc5c58cf..cc1560ee6 100644 --- a/pkg/server/readiness_manager.go +++ b/pkg/server/readiness_manager.go @@ -23,11 +23,4 @@ type ReadinessManager interface { Ready() (bool, string) } -var _ ReadinessManager = &DefaultBackendStorage{} - -func (s *DefaultBackendStorage) Ready() (bool, string) { - if s.NumBackends() == 0 { - return false, "no connection to any proxy agent" - } - return true, "" -} +var _ ReadinessManager = &DefaultBackendManager{} diff --git a/pkg/server/server.go b/pkg/server/server.go index 94a76feec..2f8a4ef46 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -106,10 +106,6 @@ type ProxyClientConnection struct { dialAddress string // cached for logging } -const ( - destHostKey key = iota -) - func (c *ProxyClientConnection) send(pkt *client.Packet) error { defer func(start time.Time) { metrics.Metrics.ObserveFrontendWriteLatency(time.Since(start)) }(time.Now()) if c.Mode == ModeGRPC { @@ -193,8 +189,7 @@ func (pm *PendingDialManager) removeForStream(streamUID string) []*ProxyClientCo // ProxyServer type ProxyServer struct { - // BackendManagers contains a list of BackendManagers - BackendManagers []BackendManager + BackendManager BackendManager // Readiness reports if the proxy server is ready, i.e., if the proxy // server has connections to proxy agents (backends). Note that the @@ -215,9 +210,6 @@ type ProxyServer struct { // agent authentication AgentAuthenticationOptions *AgentTokenAuthenticationOptions - - // TODO: move strategies into BackendStorage - proxyStrategies []ProxyStrategy } // AgentTokenAuthenticationOptions contains list of parameters required for agent token based authentication @@ -233,45 +225,17 @@ var _ agent.AgentServiceServer = &ProxyServer{} var _ client.ProxyServiceServer = &ProxyServer{} -func genContext(proxyStrategies []ProxyStrategy, reqHost string) context.Context { - ctx := context.Background() - for _, ps := range proxyStrategies { - switch ps { - case ProxyStrategyDestHost: - addr := util.RemovePortFromHost(reqHost) - ctx = context.WithValue(ctx, destHostKey, addr) - } - } - return ctx -} - func (s *ProxyServer) getBackend(reqHost string) (Backend, error) { - ctx := genContext(s.proxyStrategies, reqHost) - for _, bm := range s.BackendManagers { - be, err := bm.Backend(ctx) - if err == nil { - return be, nil - } - if ignoreNotFound(err) != nil { - // if can't find a backend through current BackendManager, move on - // to the next one - return nil, err - } - } - return nil, &ErrNotFound{} + addr := util.RemovePortFromHost(reqHost) + return s.BackendManager.Backend(addr) } func (s *ProxyServer) addBackend(backend Backend) { - // TODO: refactor BackendStorage to acquire lock once, not up to 3 times. - for _, bm := range s.BackendManagers { - bm.AddBackend(backend) - } + s.BackendManager.AddBackend(backend) } func (s *ProxyServer) removeBackend(backend Backend) { - for _, bm := range s.BackendManagers { - bm.RemoveBackend(backend) - } + s.BackendManager.RemoveBackend(backend) } func (s *ProxyServer) addEstablished(agentID string, connID int64, p *ProxyClientConnection) { @@ -377,30 +341,16 @@ func (s *ProxyServer) removeEstablishedForStream(streamUID string) []*ProxyClien // NewProxyServer creates a new ProxyServer instance func NewProxyServer(serverID string, proxyStrategies []ProxyStrategy, serverCount int, agentAuthenticationOptions *AgentTokenAuthenticationOptions) *ProxyServer { - var bms []BackendManager - for _, ps := range proxyStrategies { - switch ps { - case ProxyStrategyDestHost: - bms = append(bms, NewDestHostBackendManager()) - case ProxyStrategyDefault: - bms = append(bms, NewDefaultBackendManager()) - case ProxyStrategyDefaultRoute: - bms = append(bms, NewDefaultRouteBackendManager()) - default: - klog.ErrorS(nil, "Unknown proxy strategy", "strategy", ps) - } - } + bm := NewDefaultBackendManager(proxyStrategies) return &ProxyServer{ established: make(map[string](map[int64]*ProxyClientConnection)), PendingDial: NewPendingDialManager(), serverID: serverID, serverCount: serverCount, - BackendManagers: bms, + BackendManager: bm, AgentAuthenticationOptions: agentAuthenticationOptions, - // use the first backend-manager as the Readiness Manager - Readiness: bms[0], - proxyStrategies: proxyStrategies, + Readiness: bm, } } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index b96fdfe13..2c4f2cdb2 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -253,129 +253,7 @@ func TestAddRemoveFrontends(t *testing.T) { } } -func TestAddRemoveBackends_DefaultStrategy(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) - backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{})) - backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{})) - - p := NewProxyServer("", []ProxyStrategy{ProxyStrategyDefault}, 1, nil) - - p.addBackend(backend1) - - if got, _ := p.getBackend("127.0.0.1"); got != backend1 { - t.Errorf("expected %v, got %v", backend1, got) - } - - p.addBackend(backend2) - p.addBackend(backend3) - p.removeBackend(backend1) - p.removeBackend(backend2) - - if got, _ := p.getBackend("127.0.0.1"); got != backend3 { - t.Errorf("expected %v, got %v", backend3, got) - } - - p.removeBackend(backend3) - - if got, _ := p.getBackend("127.0.0.1"); got != nil { - t.Errorf("expected nil, got %v", got) - } -} - -func TestAddRemoveBackends_DefaultRouteStrategy(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) - backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"default-route=false"})) - backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"default-route=true"})) - - p := NewProxyServer("", []ProxyStrategy{ProxyStrategyDefaultRoute}, 1, nil) - - p.addBackend(backend1) - - if got, _ := p.getBackend("127.0.0.1"); got != nil { - t.Errorf("expected nil, got %v", got) - } - - p.addBackend(backend2) - - if got, _ := p.getBackend("127.0.0.1"); got != nil { - t.Errorf("expected nil, got %v", got) - } - - p.addBackend(backend3) - - if got, _ := p.getBackend("127.0.0.1"); got != backend3 { - t.Errorf("expected %v, got %v", backend3, got) - } - - p.removeBackend(backend1) - p.removeBackend(backend2) - - if got, _ := p.getBackend("127.0.0.1"); got != backend3 { - t.Errorf("expected %v, got %v", backend3, got) - } - - p.removeBackend(backend3) - - if got, _ := p.getBackend("127.0.0.1"); got != nil { - t.Errorf("expected nil, got %v", got) - } -} - -func TestAddRemoveBackends_DestHostStrategy(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"host=localhost&host=node1.mydomain.com&ipv4=1.2.3.4&ipv6=9878::7675:1292:9183:7562"})) - backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"default-route=true"})) - backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"host=node2.mydomain.com&ipv4=5.6.7.8&ipv6=::"})) - - p := NewProxyServer("", []ProxyStrategy{ProxyStrategyDestHost}, 1, nil) - - p.addBackend(backend1) - p.addBackend(backend2) - p.addBackend(backend3) - - if got, _ := p.getBackend("127.0.0.1"); got != nil { - t.Errorf("expected nil, got %v", got) - } - if got, _ := p.getBackend("localhost"); got != backend1 { - t.Errorf("expected %v, got %v", backend1, got) - } - if got, _ := p.getBackend("node1.mydomain.com"); got != backend1 { - t.Errorf("expected %v, got %v", backend1, got) - } - if got, _ := p.getBackend("1.2.3.4"); got != backend1 { - t.Errorf("expected %v, got %v", backend1, got) - } - if got, _ := p.getBackend("9878::7675:1292:9183:7562"); got != backend1 { - t.Errorf("expected %v, got %v", backend1, got) - } - if got, _ := p.getBackend("node2.mydomain.com"); got != backend3 { - t.Errorf("expected %v, got %v", backend3, got) - } - if got, _ := p.getBackend("5.6.7.8"); got != backend3 { - t.Errorf("expected %v, got %v", backend3, got) - } - if got, _ := p.getBackend("::"); got != backend3 { - t.Errorf("expected %v, got %v", backend3, got) - } - - p.removeBackend(backend1) - p.removeBackend(backend2) - p.removeBackend(backend3) - - if got, _ := p.getBackend("127.0.0.1"); got != nil { - t.Errorf("expected nil, got %v", got) - } -} - -func TestAddRemoveBackends_DestHostSanitizeRequest(t *testing.T) { +func TestAddRemoveBackends_SanitizeRequest(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -390,120 +268,45 @@ func TestAddRemoveBackends_DestHostSanitizeRequest(t *testing.T) { if got, _ := p.getBackend("127.0.0.1:443"); got != nil { t.Errorf("expected nil, got %v", got) } + if got, _ := p.getBackend("5.6.7.8:443"); got != backend2 { + t.Errorf("expected %v, got %v", backend2, got) + } if got, _ := p.getBackend("node1.mydomain.com:443"); got != backend1 { t.Errorf("expected %v, got %v", backend1, got) } if got, _ := p.getBackend("node2.mydomain.com:443"); got != backend2 { t.Errorf("expected %v, got %v", backend2, got) } -} - -func TestAddRemoveBackends_DestHostWithDefault(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"host=localhost&host=node1.mydomain.com&ipv4=1.2.3.4&ipv6=9878::7675:1292:9183:7562"})) - backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"default-route=false"})) - backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"host=node2.mydomain.com&ipv4=5.6.7.8&ipv6=::"})) - - p := NewProxyServer("", []ProxyStrategy{ProxyStrategyDestHost, ProxyStrategyDefault}, 1, nil) - - p.addBackend(backend1) - p.addBackend(backend2) - p.addBackend(backend3) - - if got, _ := p.getBackend("127.0.0.1"); got == nil { - t.Errorf("expected random fallback, got nil") - } - if got, _ := p.getBackend("localhost"); got != backend1 { + if got, _ := p.getBackend("[9878::7675:1292:9183:7562]:443"); got != backend1 { t.Errorf("expected %v, got %v", backend1, got) } - if got, _ := p.getBackend("node1.mydomain.com"); got != backend1 { - t.Errorf("expected %v, got %v", backend1, got) - } - if got, _ := p.getBackend("1.2.3.4"); got != backend1 { - t.Errorf("expected %v, got %v", backend1, got) - } - if got, _ := p.getBackend("9878::7675:1292:9183:7562"); got != backend1 { - t.Errorf("expected %v, got %v", backend1, got) - } - if got, _ := p.getBackend("node2.mydomain.com"); got != backend3 { - t.Errorf("expected %v, got %v", backend3, got) - } - if got, _ := p.getBackend("5.6.7.8"); got != backend3 { - t.Errorf("expected %v, got %v", backend3, got) - } - if got, _ := p.getBackend("::"); got != backend3 { - t.Errorf("expected %v, got %v", backend3, got) - } - - p.removeBackend(backend1) - p.removeBackend(backend2) - - if got, _ := p.getBackend("127.0.0.1"); got != backend3 { - t.Errorf("expected %v, got %v", backend3, got) - } - - p.removeBackend(backend3) - - if got, _ := p.getBackend("127.0.0.1"); got != nil { - t.Errorf("expected nil, got %v", got) + if got, _ := p.getBackend("[::]:443"); got != backend2 { + t.Errorf("expected %v, got %v", backend2, got) } } -func TestAddRemoveBackends_DestHostWithDuplicateIdents(t *testing.T) { +func TestAddRemoveBackends_Readiness(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"host=localhost&host=node1.mydomain.com&ipv4=1.2.3.4&ipv6=9878::7675:1292:9183:7562"})) - backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"host=localhost&host=node1.mydomain.com&ipv4=1.2.3.4&ipv6=9878::7675:1292:9183:7562"})) - backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"host=localhost&host=node2.mydomain.com&ipv4=5.6.7.8&ipv6=::"})) - p := NewProxyServer("", []ProxyStrategy{ProxyStrategyDestHost, ProxyStrategyDefault}, 1, nil) - - p.addBackend(backend1) - p.addBackend(backend2) - p.addBackend(backend3) + p := NewProxyServer("", []ProxyStrategy{ProxyStrategyDestHost}, 1, nil) - if got, _ := p.getBackend("127.0.0.1"); got == nil { - t.Errorf("expected random fallback, got nil") - } - if got, _ := p.getBackend("localhost"); got == nil { - t.Errorf("expected any backend, got nil") + if got, _ := p.Readiness.Ready(); got != false { + t.Errorf("Ready() = got %t, want false", got) } - p.removeBackend(backend1) - p.removeBackend(backend3) + p.addBackend(backend1) - if got, _ := p.getBackend("127.0.0.1"); got != backend2 { - t.Errorf("expected %v, got %v", backend2, got) - } - if got, _ := p.getBackend("localhost"); got != backend2 { - t.Errorf("expected %v, got %v", backend2, got) - } - if got, _ := p.getBackend("node1.mydomain.com"); got != backend2 { - t.Errorf("expected %v, got %v", backend2, got) - } - if got, _ := p.getBackend("1.2.3.4"); got != backend2 { - t.Errorf("expected %v, got %v", backend2, got) - } - if got, _ := p.getBackend("9878::7675:1292:9183:7562"); got != backend2 { - t.Errorf("expected %v, got %v", backend2, got) - } - if got, _ := p.getBackend("node2.mydomain.com"); got != backend2 { - t.Errorf("expected %v, got %v", backend2, got) - } - if got, _ := p.getBackend("5.6.7.8"); got != backend2 { - t.Errorf("expected %v, got %v", backend2, got) - } - if got, _ := p.getBackend("::"); got != backend2 { - t.Errorf("expected %v, got %v", backend2, got) + if got, _ := p.Readiness.Ready(); got != true { + t.Errorf("Ready() = got %t, want true", got) } - p.removeBackend(backend2) + p.removeBackend(backend1) - if got, _ := p.getBackend("127.0.0.1"); got != nil { - t.Errorf("expected nil, got %v", got) + if got, _ := p.Readiness.Ready(); got != false { + t.Errorf("Ready() = got %t, want false", got) } } diff --git a/pkg/server/storage.go b/pkg/server/storage.go new file mode 100644 index 000000000..36ab28466 --- /dev/null +++ b/pkg/server/storage.go @@ -0,0 +1,264 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package server + +import ( + "context" + "fmt" + "io" + "math/rand" + "slices" + "sync" + "time" + + "google.golang.org/grpc/metadata" + commonmetrics "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/common/metrics" + "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/proto/client" + "sigs.k8s.io/apiserver-network-proxy/pkg/server/metrics" + "sigs.k8s.io/apiserver-network-proxy/proto/agent" + "sigs.k8s.io/apiserver-network-proxy/proto/header" + // "k8s.io/klog/v2" +) + +// Backend abstracts a connected Konnectivity agent stream. +// +// In the only currently supported case (gRPC), it wraps an +// agent.AgentService_ConnectServer, provides synchronization and +// emits common stream metrics. +type Backend interface { + Send(p *client.Packet) error + Recv() (*client.Packet, error) + Context() context.Context + GetAgentID() string + GetAgentIdentifiers() header.Identifiers +} + +var _ Backend = &backend{} + +type backend struct { + sendLock sync.Mutex + recvLock sync.Mutex + conn agent.AgentService_ConnectServer + + // cached from conn.Context() + id string + idents header.Identifiers +} + +func (b *backend) Send(p *client.Packet) error { + b.sendLock.Lock() + defer b.sendLock.Unlock() + + const segment = commonmetrics.SegmentToAgent + metrics.Metrics.ObservePacket(segment, p.Type) + err := b.conn.Send(p) + if err != nil && err != io.EOF { + metrics.Metrics.ObserveStreamError(segment, err, p.Type) + } + return err +} + +func (b *backend) Recv() (*client.Packet, error) { + b.recvLock.Lock() + defer b.recvLock.Unlock() + + const segment = commonmetrics.SegmentFromAgent + pkt, err := b.conn.Recv() + if err != nil { + if err != io.EOF { + metrics.Metrics.ObserveStreamErrorNoPacket(segment, err) + } + return nil, err + } + metrics.Metrics.ObservePacket(segment, pkt.Type) + return pkt, nil +} + +func (b *backend) Context() context.Context { + return b.conn.Context() +} + +func (b *backend) GetAgentID() string { + return b.id +} + +func (b *backend) GetAgentIdentifiers() header.Identifiers { + return b.idents +} + +func getAgentID(stream agent.AgentService_ConnectServer) (string, error) { + md, ok := metadata.FromIncomingContext(stream.Context()) + if !ok { + return "", fmt.Errorf("failed to get context") + } + agentIDs := md.Get(header.AgentID) + if len(agentIDs) != 1 { + return "", fmt.Errorf("expected one agent ID in the context, got %v", agentIDs) + } + return agentIDs[0], nil +} + +func getAgentIdentifiers(conn agent.AgentService_ConnectServer) (header.Identifiers, error) { + var agentIdentifiers header.Identifiers + md, ok := metadata.FromIncomingContext(conn.Context()) + if !ok { + return agentIdentifiers, fmt.Errorf("failed to get metadata from context") + } + agentIdent := md.Get(header.AgentIdentifiers) + if len(agentIdent) > 1 { + return agentIdentifiers, fmt.Errorf("expected at most one set of agent identifiers in the context, got %v", agentIdent) + } + if len(agentIdent) == 0 { + return agentIdentifiers, nil + } + + return header.GenAgentIdentifiers(agentIdent[0]) +} + +func NewBackend(conn agent.AgentService_ConnectServer) (Backend, error) { + agentID, err := getAgentID(conn) + if err != nil { + return nil, err + } + agentIdentifiers, err := getAgentIdentifiers(conn) + if err != nil { + return nil, err + } + return &backend{conn: conn, id: agentID, idents: agentIdentifiers}, nil +} + +// BackendStorage is an interface for an in-memory storage of backend +// connections. +// +// A key may be associated with multiple Backend objects. For example, +// a given agent can have multiple re-connects in flight, and multiple +// agents could share a common host identifier. +type BackendStorage interface { + // AddBackend registers a backend, and returns the new number of backends in this storage. + AddBackend(keys []string, backend Backend) int + // RemoveBackend removes a backend, and returns the new number of backends in this storage. + RemoveBackend(keys []string, backend Backend) int + // Backend selects a backend by key. + Backend(key string) (Backend, error) + // RandomBackend selects a random backend. + RandomBackend() (Backend, error) + // NumKeys returns the distinct count of backend keys in this storage. + NumKeys() int +} + +// DefaultBackendStorage is the default BackendStorage +type DefaultBackendStorage struct { + mu sync.RWMutex //protects the following + // A map from key to grpc connections. + // For a given "backends []Backend", ProxyServer prefers backends[0] to send + // traffic, because backends[1:] are more likely to be closed + // by the agent to deduplicate connections to the same server. + backends map[string][]Backend + // Cache of backends keys, to efficiently select random. + backendKeys []string + + random *rand.Rand +} + +var _ BackendStorage = &DefaultBackendStorage{} + +// NewDefaultBackendStorage returns a DefaultBackendStorage +func NewDefaultBackendStorage() *DefaultBackendStorage { + // Set an explicit value, so that the metric is emitted even when + // no agent ever successfully connects. + return &DefaultBackendStorage{ + backends: make(map[string][]Backend), + random: rand.New(rand.NewSource(time.Now().UnixNano())), + } /* #nosec G404 */ +} + +func (s *DefaultBackendStorage) AddBackend(keys []string, backend Backend) int { + s.mu.Lock() + defer s.mu.Unlock() + for _, key := range keys { + if key == "" { + continue + } + _, ok := s.backends[key] + if ok { + if !slices.Contains(s.backends[key], backend) { + s.backends[key] = append(s.backends[key], backend) + } + continue + } + s.backends[key] = []Backend{backend} + s.backendKeys = append(s.backendKeys, key) + } + return len(s.backends) +} + +func (s *DefaultBackendStorage) RemoveBackend(keys []string, backend Backend) int { + s.mu.Lock() + defer s.mu.Unlock() + for _, key := range keys { + if key == "" { + continue + } + backends, ok := s.backends[key] + if !ok { + continue + } + for i, b := range backends { + if b == backend { + s.backends[key] = slices.Delete(backends, i, i+1) + } + } + if len(s.backends[key]) == 0 { + delete(s.backends, key) + s.backendKeys = slices.DeleteFunc(s.backendKeys, func(k string) bool { + return k == key + }) + } + } + return len(s.backends) +} + +func (s *DefaultBackendStorage) Backend(key string) (Backend, error) { + s.mu.RLock() + defer s.mu.RUnlock() + bes, exist := s.backends[key] + if exist && len(bes) > 0 { + // always return the first connection to an agent, because the agent + // will close later connections if there are multiple. + return bes[0], nil + } + return nil, &ErrNotFound{} +} + +func (s *DefaultBackendStorage) RandomBackend() (Backend, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if len(s.backends) == 0 { + return nil, &ErrNotFound{} + } + key := s.backendKeys[s.random.Intn(len(s.backendKeys))] + // always return the first connection to an agent, because the agent + // will close later connections if there are multiple. + return s.backends[key][0], nil +} + +func (s *DefaultBackendStorage) NumKeys() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.backends) +} diff --git a/pkg/server/storage_test.go b/pkg/server/storage_test.go new file mode 100644 index 000000000..1b51076f0 --- /dev/null +++ b/pkg/server/storage_test.go @@ -0,0 +1,314 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package server + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "google.golang.org/grpc/metadata" + + agentmock "sigs.k8s.io/apiserver-network-proxy/proto/agent/mocks" +) + +func TestNewBackend(t *testing.T) { + testCases := []struct { + desc string + ids []string + idents []string + wantErr error + }{ + { + desc: "no agentID", + wantErr: fmt.Errorf("expected one agent ID in the context, got []"), + }, + { + desc: "multiple agentID", + ids: []string{"agent-id", "agent-id"}, + wantErr: fmt.Errorf("expected one agent ID in the context, got [agent-id agent-id]"), + }, + { + desc: "multiple identifiers", + ids: []string{"agent-id"}, + idents: []string{"host=localhost", "host=localhost"}, + wantErr: fmt.Errorf("expected at most one set of agent identifiers in the context, got [host=localhost host=localhost]"), + }, + { + desc: "invalid identifiers", + ids: []string{"agent-id"}, + idents: []string{";"}, + wantErr: fmt.Errorf("fail to parse url encoded string: invalid semicolon separator in query"), + }, + { + desc: "success", + ids: []string{"agent-id"}, + }, + { + desc: "success with identifiers", + ids: []string{"agent-id"}, + idents: []string{"host=localhost&host=node1.mydomain.com&cidr=127.0.0.1/16&ipv4=1.2.3.4&ipv4=5.6.7.8&ipv6=:::::&default-route=true"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + agentConn := agentmock.NewMockAgentService_ConnectServer(ctrl) + agentConnMD := metadata.MD{ + ":authority": []string{"127.0.0.1:8091"}, + "agentid": tc.ids, + "agentidentifiers": tc.idents, + "content-type": []string{"application/grpc"}, + "user-agent": []string{"grpc-go/1.42.0"}, + } + agentConnCtx := metadata.NewIncomingContext(context.Background(), agentConnMD) + agentConn.EXPECT().Context().Return(agentConnCtx).AnyTimes() + + _, got := NewBackend(agentConn) + assert.Equal(t, got, tc.wantErr, "NewBackend() error %q, want %v", got, tc.wantErr) + }) + } +} + +func TestBackendStorage_AddBackend_Empty(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + s := NewDefaultBackendStorage() + + gotCount := s.AddBackend([]string{""}, backend1) + if gotCount != 0 { + t.Errorf("AddBackend() = %d, want 0", gotCount) + } +} + +func TestBackendStorage_GetBackend_Empty(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + s := NewDefaultBackendStorage() + + s.AddBackend([]string{"ident1"}, backend1) + + got, err := s.Backend("") + if got != nil { + t.Errorf("Backend() = %v, want nil", got) + } + if err == nil { + t.Errorf("Backend() = nil, want error") + } +} + +func TestBackendStorage_RemoveBackend_Empty(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + s := NewDefaultBackendStorage() + + s.AddBackend([]string{"ident1"}, backend1) + + gotCount := s.RemoveBackend([]string{""}, backend1) + if gotCount != 1 { + t.Errorf("RemoveBackend() = %v, want 1", gotCount) + } +} + +func TestBackendStorage_RemoveBackend_Unrecognized(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + s := NewDefaultBackendStorage() + + s.AddBackend([]string{"ident1"}, backend1) + + gotCount := s.RemoveBackend([]string{"ident2"}, backend1) + if gotCount != 1 { + t.Errorf("RemoveBackend() = %v, want 1", gotCount) + } +} + +func TestBackendStorage_RemoveBackend(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{})) + backend22, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{})) + backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{})) + s := NewDefaultBackendStorage() + + s.AddBackend([]string{"ident1"}, backend1) + s.AddBackend([]string{"ident1", "ident2"}, backend2) + s.AddBackend([]string{"ident1", "ident2"}, backend22) + s.AddBackend([]string{"ident2"}, backend3) + + wantBackends := map[string][]Backend{ + "ident1": {backend1, backend2, backend22}, + "ident2": {backend2, backend22, backend3}, + } + wantBackendKeys := []string{"ident1", "ident2"} + if !reflect.DeepEqual(s.backends, wantBackends) { + t.Errorf("s.backends = %v, want %v", s.backends, wantBackends) + } + if !reflect.DeepEqual(s.backendKeys, wantBackendKeys) { + t.Errorf("s.backendKeys = %v, want %v", s.backendKeys, wantBackendKeys) + } + + gotCount := s.RemoveBackend([]string{"ident1", "ident2"}, backend22) + if gotCount != 2 { + t.Errorf("RemoveBackend() = %v, want 2", gotCount) + } + wantBackends = map[string][]Backend{ + "ident1": {backend1, backend2}, + "ident2": {backend2, backend3}, + } + if !reflect.DeepEqual(s.backends, wantBackends) { + t.Errorf("s.backends = %v, want %v", s.backends, wantBackends) + } + if !reflect.DeepEqual(s.backendKeys, wantBackendKeys) { + t.Errorf("s.backendKeys = %v, want %v", s.backendKeys, wantBackendKeys) + } +} + +func TestBackendStorage_AddBackend(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + s := NewDefaultBackendStorage() + + gotCount := s.AddBackend([]string{"ident1"}, backend1) + if gotCount != 1 { + t.Errorf("AddBackend() = %v, want 1", gotCount) + } +} + +func TestBackendStorage_AddBackend_DuplicateAgentID(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + backend12, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + s := NewDefaultBackendStorage() + + s.AddBackend([]string{"ident1"}, backend1) + gotCount := s.AddBackend([]string{"ident1"}, backend12) + if gotCount != 1 { + t.Errorf("AddBackend() = %v, want 1", gotCount) + } + + got, _ := s.Backend("ident1") + if got != backend1 { + t.Errorf("Backend() = %v, want %v", got, backend1) + } +} + +func TestBackendStorage_AddBackend_DuplicateKey(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{})) + s := NewDefaultBackendStorage() + + s.AddBackend([]string{"hostname1"}, backend1) + gotCount := s.AddBackend([]string{"hostname1"}, backend2) + if gotCount != 1 { + t.Errorf("AddBackend() = %v, want 1", gotCount) + } +} + +func TestBackendStorage_AddBackend_SameBackend(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + s := NewDefaultBackendStorage() + + s.AddBackend([]string{"ident1"}, backend1) + gotCount := s.AddBackend([]string{"ident1"}, backend1) + if gotCount != 1 { + t.Errorf("AddBackend() = %v, want 1", gotCount) + } + gotLen := len(s.backends["ident1"]) + if gotLen != 1 { + t.Errorf("backends list = %v, want length 1", s.backends["ident1"]) + } +} + +func TestBackendStorage_NumKeys(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + s := NewDefaultBackendStorage() + + gotCount := s.NumKeys() + if gotCount != 0 { + t.Errorf("NumKeys() = %d, want 0", gotCount) + } + s.AddBackend([]string{"ident1"}, backend1) + gotCount = s.NumKeys() + if gotCount != 1 { + t.Errorf("NumKeys() = %d, want 1", gotCount) + } +} + +func TestBackendStorage_RandomBackend(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{})) + s := NewDefaultBackendStorage() + + got, err := s.RandomBackend() + if got != nil { + t.Errorf("RandomBackend() = %v, want nil", got) + } + if err == nil { + t.Errorf("RandomBackend() = error nil, want error") + } + + s.AddBackend([]string{"ident1"}, backend1) + got, err = s.RandomBackend() + if got != backend1 { + t.Errorf("RandomBackend() = %v, want %v", got, backend1) + } + if err != nil { + t.Errorf("RandomBackend() = error %v, want nil", err) + } + + s.AddBackend([]string{"ident1"}, backend2) + got, err = s.RandomBackend() + if got != backend1 { + t.Errorf("RandomBackend() = %v, want %v", got, backend1) + } + if err != nil { + t.Errorf("RandomBackend() = error %v, want nil", err) + } +} diff --git a/tests/framework/proxy_server.go b/tests/framework/proxy_server.go index 4ad740b68..4f9491ad2 100644 --- a/tests/framework/proxy_server.go +++ b/tests/framework/proxy_server.go @@ -46,7 +46,7 @@ type ProxyServerRunner interface { } type ProxyServer interface { - ConnectedBackends() (int, error) + ConnectedBackends() int AgentAddr() string FrontAddr() string Ready() bool @@ -119,12 +119,8 @@ func (ps *inProcessProxyServer) FrontAddr() string { return ps.frontAddr } -func (ps *inProcessProxyServer) ConnectedBackends() (int, error) { - numBackends := 0 - for _, bm := range ps.proxyServer.BackendManagers { - numBackends += bm.NumBackends() - } - return numBackends, nil +func (ps *inProcessProxyServer) ConnectedBackends() int { + return ps.proxyServer.BackendManager.NumBackends() } func (ps *inProcessProxyServer) Ready() bool { diff --git a/tests/proxy_test.go b/tests/proxy_test.go index 39fd363a3..3d8229cc4 100644 --- a/tests/proxy_test.go +++ b/tests/proxy_test.go @@ -901,20 +901,15 @@ func waitForConnectedServerCount(t testing.TB, expectedServerCount int, a framew func waitForConnectedAgentCount(t testing.TB, expectedAgentCount int, ps framework.ProxyServer) { t.Helper() err := wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) { - count, err := ps.ConnectedBackends() - if err != nil { - return false, err - } + count := ps.ConnectedBackends() if count == expectedAgentCount { return true, nil } return false, nil }) if err != nil { - if count, err := ps.ConnectedBackends(); err == nil { - t.Logf("got %d backends; expected %d", count, expectedAgentCount) - } - t.Fatalf("Error waiting for backend count: %v", err) + count := ps.ConnectedBackends() + t.Logf("got %d backends; expected %d", count, expectedAgentCount) } }