@@ -206,12 +206,20 @@ func (p *Pool) TxPipelined(ctx context.Context, fn func(redis.Pipeliner) error)
206
206
}
207
207
208
208
func (p * Pool ) Ping (ctx context.Context ) * redis.StatusCmd {
209
- // FIXME: use config to determine whether no key would access the master
210
- conn , err := p .connFactory .getMasterConn ()
211
- if err != nil {
212
- return newErrorStatusCmd (err )
209
+ if _ , ok := p .connFactory .(* HAConnFactory ); ok {
210
+ conn , _ := p .connFactory .getMasterConn ()
211
+ return conn .Ping (ctx )
212
+ }
213
+ var result * redis.StatusCmd
214
+ factory := p .connFactory .(* ShardConnFactory )
215
+ for _ , shard := range factory .shards {
216
+ conn , _ := shard .getMasterConn ()
217
+ result = conn .Ping (ctx )
218
+ if result .Err () != nil {
219
+ return result
220
+ }
213
221
}
214
- return conn . Ping ( ctx )
222
+ return result
215
223
}
216
224
217
225
func (p * Pool ) Get (ctx context.Context , key string ) * redis.StringCmd {
@@ -425,18 +433,17 @@ func (p *Pool) MSetWithGD(ctx context.Context, values ...interface{}) []*redis.S
425
433
}
426
434
427
435
var wg sync.WaitGroup
428
- var mu sync. Mutex
429
- var result [] * redis. StatusCmd
436
+ result := make ([] * redis. StatusCmd , len ( index2Values ))
437
+ var i int
430
438
for ind , vals := range index2Values {
431
439
wg .Add (1 )
432
440
conn , _ := factory .shards [ind ].getMasterConn ()
433
- go func (conn * redis.Client , vals ... interface {}) {
434
- defer wg .Done ()
441
+ go func (i int , conn * redis.Client , vals ... interface {}) {
435
442
status := conn .MSet (ctx , vals ... )
436
- mu . Lock ()
437
- result = append ( result , status )
438
- mu . Unlock ( )
439
- }( conn , vals ... )
443
+ result [ i ] = status
444
+ wg . Done ( )
445
+ }( i , conn , vals ... )
446
+ i ++
440
447
}
441
448
wg .Wait ()
442
449
return result
@@ -519,6 +526,79 @@ func (p *Pool) Expire(ctx context.Context, key string, expiration time.Duration)
519
526
return conn .Expire (ctx , key , expiration )
520
527
}
521
528
529
+ // MExpire gives the result for each group of keys
530
+ func (p * Pool ) MExpire (ctx context.Context , expiration time.Duration , keys ... string ) map [string ]error {
531
+ keyErrorsMap := func (results []redis.Cmder ) map [string ]error {
532
+ if len (results ) == 0 {
533
+ return nil
534
+ }
535
+ keyErrors := make (map [string ]error , 0 )
536
+ for _ , result := range results {
537
+ if result .Err () != nil {
538
+ args := result .Args ()
539
+ for i , arg := range args {
540
+ if i == 0 || i == 2 {
541
+ continue
542
+ }
543
+ keyErrors [arg .(string )] = result .Err ()
544
+ }
545
+ }
546
+ }
547
+ return keyErrors
548
+ }
549
+
550
+ if _ , ok := p .connFactory .(* HAConnFactory ); ok {
551
+ conn , _ := p .connFactory .getMasterConn ()
552
+ pipe := conn .Pipeline ()
553
+ for _ , key := range keys {
554
+ pipe .Expire (ctx , key , expiration )
555
+ }
556
+ results , err := pipe .Exec (ctx )
557
+ _ = pipe .Close ()
558
+ if err != nil {
559
+ return keyErrorsMap (results )
560
+ }
561
+ return nil
562
+ }
563
+
564
+ factory := p .connFactory .(* ShardConnFactory )
565
+ index2Keys := make (map [uint32 ][]string )
566
+ for _ , key := range keys {
567
+ ind := factory .getShardIndex (key )
568
+ if _ , ok := index2Keys [ind ]; ! ok {
569
+ index2Keys [ind ] = make ([]string , 0 )
570
+ }
571
+ index2Keys [ind ] = append (index2Keys [ind ], key )
572
+ }
573
+
574
+ var wg sync.WaitGroup
575
+ var mu sync.Mutex
576
+ var results []redis.Cmder
577
+ var i int
578
+ for ind , keys := range index2Keys {
579
+ wg .Add (1 )
580
+ conn , _ := factory .shards [ind ].getMasterConn ()
581
+ go func (i int , conn * redis.Client , keys ... string ) {
582
+ pipe := conn .Pipeline ()
583
+ for _ , key := range keys {
584
+ pipe .Expire (ctx , key , expiration )
585
+ }
586
+ result , err := pipe .Exec (ctx )
587
+ _ = pipe .Close ()
588
+ if err != nil {
589
+ mu .Lock ()
590
+ results = append (results , result ... )
591
+ mu .Unlock ()
592
+ }
593
+ wg .Done ()
594
+ }(i , conn , keys ... )
595
+ i ++
596
+ }
597
+ wg .Wait ()
598
+
599
+ return keyErrorsMap (results )
600
+ }
601
+
522
602
func (p * Pool ) ExpireAt (ctx context.Context , key string , tm time.Time ) * redis.BoolCmd {
523
603
conn , err := p .connFactory .getMasterConn (key )
524
604
if err != nil {
@@ -527,6 +607,79 @@ func (p *Pool) ExpireAt(ctx context.Context, key string, tm time.Time) *redis.Bo
527
607
return conn .ExpireAt (ctx , key , tm )
528
608
}
529
609
610
+ // MExpireAt gives the result for each group of keys
611
+ func (p * Pool ) MExpireAt (ctx context.Context , tm time.Time , keys ... string ) map [string ]error {
612
+ keyErrorsMap := func (results []redis.Cmder ) map [string ]error {
613
+ if len (results ) == 0 {
614
+ return nil
615
+ }
616
+ keyErrors := make (map [string ]error , 0 )
617
+ for _ , result := range results {
618
+ if result .Err () != nil {
619
+ args := result .Args ()
620
+ for i , arg := range args {
621
+ if i == 0 || i == 2 {
622
+ continue
623
+ }
624
+ keyErrors [arg .(string )] = result .Err ()
625
+ }
626
+ }
627
+ }
628
+ return keyErrors
629
+ }
630
+
631
+ if _ , ok := p .connFactory .(* HAConnFactory ); ok {
632
+ conn , _ := p .connFactory .getMasterConn ()
633
+ pipe := conn .Pipeline ()
634
+ for _ , key := range keys {
635
+ pipe .ExpireAt (ctx , key , tm )
636
+ }
637
+ results , err := pipe .Exec (ctx )
638
+ _ = pipe .Close ()
639
+ if err != nil {
640
+ return keyErrorsMap (results )
641
+ }
642
+ return nil
643
+ }
644
+
645
+ factory := p .connFactory .(* ShardConnFactory )
646
+ index2Keys := make (map [uint32 ][]string )
647
+ for _ , key := range keys {
648
+ ind := factory .getShardIndex (key )
649
+ if _ , ok := index2Keys [ind ]; ! ok {
650
+ index2Keys [ind ] = make ([]string , 0 )
651
+ }
652
+ index2Keys [ind ] = append (index2Keys [ind ], key )
653
+ }
654
+
655
+ var wg sync.WaitGroup
656
+ var mu sync.Mutex
657
+ var results []redis.Cmder
658
+ var i int
659
+ for ind , keys := range index2Keys {
660
+ wg .Add (1 )
661
+ conn , _ := factory .shards [ind ].getMasterConn ()
662
+ go func (i int , conn * redis.Client , keys ... string ) {
663
+ pipe := conn .Pipeline ()
664
+ for _ , key := range keys {
665
+ pipe .ExpireAt (ctx , key , tm )
666
+ }
667
+ result , err := pipe .Exec (ctx )
668
+ _ = pipe .Close ()
669
+ if err != nil {
670
+ mu .Lock ()
671
+ results = append (results , result ... )
672
+ mu .Unlock ()
673
+ }
674
+ wg .Done ()
675
+ }(i , conn , keys ... )
676
+ i ++
677
+ }
678
+ wg .Wait ()
679
+
680
+ return keyErrorsMap (results )
681
+ }
682
+
530
683
func (p * Pool ) TTL (ctx context.Context , key string ) * redis.DurationCmd {
531
684
conn , err := p .connFactory .getSlaveConn (key )
532
685
if err != nil {
0 commit comments