@@ -41,7 +41,9 @@ type Balancer struct {
4141 localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error)
4242
4343 mu xsync.RWMutex
44- connectionsState *connectionsState
44+ connectionsState *connectionsState[conn.Conn]
45+
46+ closed chan struct{}
4547
4648 onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info)
4749}
@@ -133,7 +135,7 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
133135 return nil
134136}
135137
136- func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Conn ) (
138+ func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Info ) (
137139 nodes []trace.EndpointInfo,
138140 added []trace.EndpointInfo,
139141 dropped []trace.EndpointInfo,
@@ -178,7 +180,7 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
178180 "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"),
179181 b.config.DetectLocalDC,
180182 )
181- previousConns []conn.Conn
183+ previousConns []conn.Info
182184 )
183185 defer func() {
184186 nodes, added, dropped := endpointsDiff(endpoints, previousConns)
@@ -187,7 +189,9 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
187189
188190 connections := endpointsToConnections(b.pool, endpoints)
189191 for _, c := range connections {
190- b.pool.Allow(ctx, c)
192+ if c.State() == conn.Banned {
193+ b.pool.Unban(ctx, c)
194+ }
191195 c.Endpoint().Touch()
192196 }
193197
@@ -201,7 +205,10 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
201205
202206 b.mu.WithLock(func() {
203207 if b.connectionsState != nil {
204- previousConns = b.connectionsState.all
208+ previousConns = make([]conn.Info, len(b.connectionsState.all))
209+ for i := range b.connectionsState.all {
210+ previousConns[i] = b.connectionsState.all[i]
211+ }
205212 }
206213 b.connectionsState = state
207214 for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints {
@@ -211,6 +218,8 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
211218}
212219
213220func (b *Balancer) Close(ctx context.Context) (err error) {
221+ close(b.closed)
222+
214223 onDone := trace.DriverOnBalancerClose(
215224 b.driverConfig.Trace(), &ctx,
216225 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).Close"),
@@ -223,6 +232,8 @@ func (b *Balancer) Close(ctx context.Context) (err error) {
223232 b.discoveryRepeater.Stop()
224233 }
225234
235+ b.applyDiscoveredEndpoints(ctx, nil, "")
236+
226237 if err = b.discoveryClient.Close(ctx); err != nil {
227238 return xerrors.WithStackTrace(err)
228239 }
@@ -258,6 +269,7 @@ func New(
258269 driverConfig: driverConfig,
259270 pool: pool,
260271 localDCDetector: detectLocalDC,
272+ closed: make(chan struct{}),
261273 }
262274 d := internalDiscovery.New(ctx, pool.Get(
263275 endpoint.New(driverConfig.Endpoint()),
@@ -300,9 +312,14 @@ func (b *Balancer) Invoke(
300312 reply interface{},
301313 opts ...grpc.CallOption,
302314) error {
303- return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
304- return cc.Invoke(ctx, method, args, reply, opts...)
305- })
315+ select {
316+ case <-b.closed:
317+ return xerrors.WithStackTrace(errBalancerClosed)
318+ default:
319+ return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
320+ return cc.Invoke(ctx, method, args, reply, opts...)
321+ })
322+ }
306323}
307324
308325func (b *Balancer) NewStream(
@@ -311,17 +328,22 @@ func (b *Balancer) NewStream(
311328 method string,
312329 opts ...grpc.CallOption,
313330) (_ grpc.ClientStream, err error) {
314- var client grpc.ClientStream
315- err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
316- client, err = cc.NewStream(ctx, desc, method, opts...)
331+ select {
332+ case <-b.closed:
333+ return nil, xerrors.WithStackTrace(errBalancerClosed)
334+ default:
335+ var client grpc.ClientStream
336+ err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
337+ client, err = cc.NewStream(ctx, desc, method, opts...)
338+
339+ return err
340+ })
341+ if err == nil {
342+ return client, nil
343+ }
317344
318- return err
319- })
320- if err == nil {
321- return client, nil
345+ return nil, err
322346 }
323-
324- return nil, err
325347}
326348
327349func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc conn.Conn) error) (err error) {
@@ -332,10 +354,8 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
332354
333355 defer func() {
334356 if err == nil {
335- if cc.GetState() == conn.Banned {
336- b.pool.Allow(ctx, cc)
337- }
338- } else if xerrors.MustPessimizeEndpoint(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
357+ b.pool.Unban(ctx, cc)
358+ } else if xerrors.MustBanConn(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
339359 b.pool.Ban(ctx, cc, err)
340360 }
341361 }()
@@ -363,7 +383,7 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
363383 return nil
364384}
365385
366- func (b *Balancer) connections() *connectionsState {
386+ func (b *Balancer) connections() *connectionsState[conn.Conn] {
367387 b.mu.RLock()
368388 defer b.mu.RUnlock()
369389
@@ -401,7 +421,7 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
401421 c, failedCount = state.GetConnection(ctx)
402422 if c == nil {
403423 return nil, xerrors.WithStackTrace(
404- fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount ),
424+ fmt.Errorf("cannot get connection from Balancer after %d attempts: %w ", failedCount, ErrNoEndpoints ),
405425 )
406426 }
407427
0 commit comments