Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Backend interface with a struct. #608

Merged
merged 1 commit into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 22 additions & 32 deletions pkg/server/backend_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,7 @@ func GenProxyStrategiesFromStr(proxyStrategies string) ([]ProxyStrategy, error)
// In the only currently supported case (gRPC), it wraps an
// agent.AgentService_ConnectServer, provides synchronization and
// emits common stream metrics.
type Backend interface {
Send(p *client.Packet) error
Recv() (*client.Packet, error)
Context() context.Context
GetAgentID() string
GetAgentIdentifiers() header.Identifiers
}

var _ Backend = &backend{}

type backend struct {
type Backend struct {
sendLock sync.Mutex
recvLock sync.Mutex
conn agent.AgentService_ConnectServer
Expand All @@ -97,7 +87,7 @@ type backend struct {
idents header.Identifiers
}

func (b *backend) Send(p *client.Packet) error {
func (b *Backend) Send(p *client.Packet) error {
b.sendLock.Lock()
defer b.sendLock.Unlock()

Expand All @@ -110,7 +100,7 @@ func (b *backend) Send(p *client.Packet) error {
return err
}

func (b *backend) Recv() (*client.Packet, error) {
func (b *Backend) Recv() (*client.Packet, error) {
b.recvLock.Lock()
defer b.recvLock.Unlock()

Expand All @@ -126,16 +116,16 @@ func (b *backend) Recv() (*client.Packet, error) {
return pkt, nil
}

func (b *backend) Context() context.Context {
func (b *Backend) Context() context.Context {
// TODO: does Context require lock protection?
return b.conn.Context()
}

func (b *backend) GetAgentID() string {
func (b *Backend) GetAgentID() string {
return b.id
}

func (b *backend) GetAgentIdentifiers() header.Identifiers {
func (b *Backend) GetAgentIdentifiers() header.Identifiers {
return b.idents
}

Expand Down Expand Up @@ -168,7 +158,7 @@ func getAgentIdentifiers(conn agent.AgentService_ConnectServer) (header.Identifi
return header.GenAgentIdentifiers(agentIdent[0])
}

func NewBackend(conn agent.AgentService_ConnectServer) (Backend, error) {
func NewBackend(conn agent.AgentService_ConnectServer) (*Backend, error) {
agentID, err := getAgentID(conn)
if err != nil {
return nil, err
Expand All @@ -177,16 +167,16 @@ func NewBackend(conn agent.AgentService_ConnectServer) (Backend, error) {
if err != nil {
return nil, err
}
return &backend{conn: conn, id: agentID, idents: agentIdentifiers}, nil
return &Backend{conn: conn, id: agentID, idents: agentIdentifiers}, nil
}

// BackendStorage is an interface to manage the storage of the backend
// connections, i.e., get, add and remove
type BackendStorage interface {
// addBackend adds a backend.
addBackend(identifier string, idType header.IdentifierType, backend Backend)
addBackend(identifier string, idType header.IdentifierType, backend *Backend)
// removeBackend removes a backend.
removeBackend(identifier string, idType header.IdentifierType, backend Backend)
removeBackend(identifier string, idType header.IdentifierType, backend *Backend)
// NumBackends returns the number of backends.
NumBackends() int
}
Expand All @@ -199,11 +189,11 @@ type BackendManager interface {
// context instead of a request-scoped context, as the backend manager will
// pick a backend for every tunnel session and each tunnel session may
// contains multiple requests.
Backend(ctx context.Context) (Backend, error)
Backend(ctx context.Context) (*Backend, error)
// AddBackend adds a backend.
AddBackend(backend Backend)
AddBackend(backend *Backend)
// RemoveBackend adds a backend.
RemoveBackend(backend Backend)
RemoveBackend(backend *Backend)
BackendStorage
ReadinessManager
}
Expand All @@ -215,18 +205,18 @@ type DefaultBackendManager struct {
*DefaultBackendStorage
}

func (dbm *DefaultBackendManager) Backend(_ context.Context) (Backend, error) {
func (dbm *DefaultBackendManager) Backend(_ context.Context) (*Backend, error) {
klog.V(5).InfoS("Get a random backend through the DefaultBackendManager")
return dbm.DefaultBackendStorage.GetRandomBackend()
}

func (dbm *DefaultBackendManager) AddBackend(backend Backend) {
func (dbm *DefaultBackendManager) AddBackend(backend *Backend) {
agentID := backend.GetAgentID()
klog.V(5).InfoS("Add the agent to DefaultBackendManager", "agentID", agentID)
dbm.addBackend(agentID, header.UID, backend)
}

func (dbm *DefaultBackendManager) RemoveBackend(backend Backend) {
func (dbm *DefaultBackendManager) RemoveBackend(backend *Backend) {
agentID := backend.GetAgentID()
klog.V(5).InfoS("Remove the agent from the DefaultBackendManager", "agentID", agentID)
dbm.removeBackend(agentID, header.UID, backend)
Expand All @@ -242,7 +232,7 @@ type DefaultBackendStorage struct {
//
// TODO: fix documentation. This is not always agentID, e.g. in
// the case of DestHostBackendManager.
backends map[string][]Backend
backends map[string][]*Backend
// agentID is tracked in this slice to enable randomly picking an
// agentID in the Backend() method. There is no reliable way to
// randomly pick a key from a map (in this case, the backends) in
Expand Down Expand Up @@ -272,7 +262,7 @@ func NewDefaultBackendStorage(idTypes []header.IdentifierType) *DefaultBackendSt
// no agent ever successfully connects.
metrics.Metrics.SetBackendCount(0)
return &DefaultBackendStorage{
backends: make(map[string][]Backend),
backends: make(map[string][]*Backend),
random: rand.New(rand.NewSource(time.Now().UnixNano())),
idTypes: idTypes,
} /* #nosec G404 */
Expand All @@ -283,7 +273,7 @@ func containIDType(idTypes []header.IdentifierType, idType header.IdentifierType
}

// addBackend adds a backend.
func (s *DefaultBackendStorage) addBackend(identifier string, idType header.IdentifierType, backend Backend) {
func (s *DefaultBackendStorage) addBackend(identifier string, idType header.IdentifierType, backend *Backend) {
if !containIDType(s.idTypes, idType) {
klog.V(4).InfoS("fail to add backend", "backend", identifier, "error", &ErrWrongIDType{idType, s.idTypes})
return
Expand All @@ -302,7 +292,7 @@ func (s *DefaultBackendStorage) addBackend(identifier string, idType header.Iden
s.backends[identifier] = append(s.backends[identifier], backend)
return
}
s.backends[identifier] = []Backend{backend}
s.backends[identifier] = []*Backend{backend}
metrics.Metrics.SetBackendCount(len(s.backends))
s.agentIDs = append(s.agentIDs, identifier)
if idType == header.DefaultRoute {
Expand All @@ -311,7 +301,7 @@ func (s *DefaultBackendStorage) addBackend(identifier string, idType header.Iden
}

// removeBackend removes a backend.
func (s *DefaultBackendStorage) removeBackend(identifier string, idType header.IdentifierType, backend Backend) {
func (s *DefaultBackendStorage) removeBackend(identifier string, idType header.IdentifierType, backend *Backend) {
if !containIDType(s.idTypes, idType) {
klog.ErrorS(&ErrWrongIDType{idType, s.idTypes}, "fail to remove backend")
return
Expand Down Expand Up @@ -390,7 +380,7 @@ func ignoreNotFound(err error) error {
}

// GetRandomBackend returns a random backend connection from all connected agents.
func (s *DefaultBackendStorage) GetRandomBackend() (Backend, error) {
func (s *DefaultBackendStorage) GetRandomBackend() (*Backend, error) {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.backends) == 0 {
Expand Down
24 changes: 12 additions & 12 deletions pkg/server/backend_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func TestDefaultBackendManager_AddRemoveBackends(t *testing.T) {

p.AddBackend(backend1)
p.RemoveBackend(backend1)
expectedBackends := make(map[string][]Backend)
expectedBackends := make(map[string][]*Backend)
expectedAgentIDs := []string{}
expectedDefaultRouteAgentIDs := []string(nil)
if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
Expand All @@ -143,7 +143,7 @@ func TestDefaultBackendManager_AddRemoveBackends(t *testing.T) {
p.RemoveBackend(backend22)
p.RemoveBackend(backend2)
p.RemoveBackend(backend1)
expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"agent1": {backend12},
"agent3": {backend3},
}
Expand Down Expand Up @@ -174,7 +174,7 @@ func TestDefaultRouteBackendManager_AddRemoveBackends(t *testing.T) {

p.AddBackend(backend1)
p.RemoveBackend(backend1)
expectedBackends := make(map[string][]Backend)
expectedBackends := make(map[string][]*Backend)
expectedAgentIDs := []string{}
expectedDefaultRouteAgentIDs := []string{}
if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
Expand All @@ -199,7 +199,7 @@ func TestDefaultRouteBackendManager_AddRemoveBackends(t *testing.T) {
p.RemoveBackend(backend2)
p.RemoveBackend(backend1)

expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"agent1": {backend12},
"agent3": {backend3},
}
Expand Down Expand Up @@ -231,7 +231,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) {

p.AddBackend(backend1)
p.RemoveBackend(backend1)
expectedBackends := make(map[string][]Backend)
expectedBackends := make(map[string][]*Backend)
expectedAgentIDs := []string{}
expectedDefaultRouteAgentIDs := []string(nil)
if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
Expand All @@ -247,7 +247,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) {
p = NewDestHostBackendManager()
p.AddBackend(backend1)

expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"localhost": {backend1},
"1.2.3.4": {backend1},
"9878::7675:1292:9183:7562": {backend1},
Expand All @@ -273,7 +273,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) {
p.AddBackend(backend2)
p.AddBackend(backend3)

expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"localhost": {backend1},
"node1.mydomain.com": {backend1},
"node2.mydomain.com": {backend3},
Expand Down Expand Up @@ -306,7 +306,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) {
p.RemoveBackend(backend2)
p.RemoveBackend(backend1)

expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"node2.mydomain.com": {backend3},
"5.6.7.8": {backend3},
"::": {backend3},
Expand All @@ -328,7 +328,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) {
}

p.RemoveBackend(backend3)
expectedBackends = map[string][]Backend{}
expectedBackends = map[string][]*Backend{}
expectedAgentIDs = []string{}

if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
Expand Down Expand Up @@ -356,7 +356,7 @@ func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) {
p.AddBackend(backend2)
p.AddBackend(backend3)

expectedBackends := map[string][]Backend{
expectedBackends := map[string][]*Backend{
"localhost": {backend1, backend2, backend3},
"1.2.3.4": {backend1, backend2},
"5.6.7.8": {backend3},
Expand Down Expand Up @@ -389,7 +389,7 @@ func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) {
p.RemoveBackend(backend1)
p.RemoveBackend(backend3)

expectedBackends = map[string][]Backend{
expectedBackends = map[string][]*Backend{
"localhost": {backend2},
"1.2.3.4": {backend2},
"9878::7675:1292:9183:7562": {backend2},
Expand All @@ -413,7 +413,7 @@ func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) {
}

p.RemoveBackend(backend2)
expectedBackends = map[string][]Backend{}
expectedBackends = map[string][]*Backend{}
expectedAgentIDs = []string{}

if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/server/default_route_backend_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ func NewDefaultRouteBackendManager() *DefaultRouteBackendManager {
}

// Backend tries to get a backend that advertises default route, with random selection.
func (dibm *DefaultRouteBackendManager) Backend(_ context.Context) (Backend, error) {
func (dibm *DefaultRouteBackendManager) Backend(_ context.Context) (*Backend, error) {
return dibm.GetRandomBackend()
}

func (dibm *DefaultRouteBackendManager) AddBackend(backend Backend) {
func (dibm *DefaultRouteBackendManager) AddBackend(backend *Backend) {
agentID := backend.GetAgentID()
agentIdentifiers := backend.GetAgentIdentifiers()
if agentIdentifiers.DefaultRoute {
Expand All @@ -49,7 +49,7 @@ func (dibm *DefaultRouteBackendManager) AddBackend(backend Backend) {
}
}

func (dibm *DefaultRouteBackendManager) RemoveBackend(backend Backend) {
func (dibm *DefaultRouteBackendManager) RemoveBackend(backend *Backend) {
agentID := backend.GetAgentID()
agentIdentifiers := backend.GetAgentIdentifiers()
if agentIdentifiers.DefaultRoute {
Expand Down
6 changes: 3 additions & 3 deletions pkg/server/desthost_backend_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func NewDestHostBackendManager() *DestHostBackendManager {
[]header.IdentifierType{header.IPv4, header.IPv6, header.Host})}
}

func (dibm *DestHostBackendManager) AddBackend(backend Backend) {
func (dibm *DestHostBackendManager) AddBackend(backend *Backend) {
agentIdentifiers := backend.GetAgentIdentifiers()
for _, ipv4 := range agentIdentifiers.IPv4 {
klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", ipv4)
Expand All @@ -51,7 +51,7 @@ func (dibm *DestHostBackendManager) AddBackend(backend Backend) {
}
}

func (dibm *DestHostBackendManager) RemoveBackend(backend Backend) {
func (dibm *DestHostBackendManager) RemoveBackend(backend *Backend) {
agentIdentifiers := backend.GetAgentIdentifiers()
for _, ipv4 := range agentIdentifiers.IPv4 {
klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", ipv4)
Expand All @@ -68,7 +68,7 @@ func (dibm *DestHostBackendManager) RemoveBackend(backend Backend) {
}

// Backend tries to get a backend associating to the request destination host.
func (dibm *DestHostBackendManager) Backend(ctx context.Context) (Backend, error) {
func (dibm *DestHostBackendManager) Backend(ctx context.Context) (*Backend, error) {
dibm.mu.RLock()
defer dibm.mu.RUnlock()
if len(dibm.backends) == 0 {
Expand Down
Loading
Loading