@@ -61,7 +61,7 @@ static DEFINE_MUTEX(tcpv6_prot_mutex);
61
61
static const struct proto * saved_tcpv4_prot ;
62
62
static DEFINE_MUTEX (tcpv4_prot_mutex );
63
63
static struct proto tls_prots [TLS_NUM_PROTS ][TLS_NUM_CONFIG ][TLS_NUM_CONFIG ];
64
- static struct proto_ops tls_sw_proto_ops ;
64
+ static struct proto_ops tls_proto_ops [ TLS_NUM_PROTS ][ TLS_NUM_CONFIG ][ TLS_NUM_CONFIG ] ;
65
65
static void build_protos (struct proto prot [TLS_NUM_CONFIG ][TLS_NUM_CONFIG ],
66
66
const struct proto * base );
67
67
@@ -71,6 +71,8 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx)
71
71
72
72
WRITE_ONCE (sk -> sk_prot ,
73
73
& tls_prots [ip_ver ][ctx -> tx_conf ][ctx -> rx_conf ]);
74
+ WRITE_ONCE (sk -> sk_socket -> ops ,
75
+ & tls_proto_ops [ip_ver ][ctx -> tx_conf ][ctx -> rx_conf ]);
74
76
}
75
77
76
78
int wait_on_pending_writer (struct sock * sk , long * timeo )
@@ -669,8 +671,6 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
669
671
if (tx ) {
670
672
ctx -> sk_write_space = sk -> sk_write_space ;
671
673
sk -> sk_write_space = tls_write_space ;
672
- } else {
673
- sk -> sk_socket -> ops = & tls_sw_proto_ops ;
674
674
}
675
675
goto out ;
676
676
@@ -728,6 +728,39 @@ struct tls_context *tls_ctx_create(struct sock *sk)
728
728
return ctx ;
729
729
}
730
730
731
+ static void build_proto_ops (struct proto_ops ops [TLS_NUM_CONFIG ][TLS_NUM_CONFIG ],
732
+ const struct proto_ops * base )
733
+ {
734
+ ops [TLS_BASE ][TLS_BASE ] = * base ;
735
+
736
+ ops [TLS_SW ][TLS_BASE ] = ops [TLS_BASE ][TLS_BASE ];
737
+ ops [TLS_SW ][TLS_BASE ].sendpage_locked = tls_sw_sendpage_locked ;
738
+
739
+ ops [TLS_BASE ][TLS_SW ] = ops [TLS_BASE ][TLS_BASE ];
740
+ ops [TLS_BASE ][TLS_SW ].splice_read = tls_sw_splice_read ;
741
+
742
+ ops [TLS_SW ][TLS_SW ] = ops [TLS_SW ][TLS_BASE ];
743
+ ops [TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read ;
744
+
745
+ #ifdef CONFIG_TLS_DEVICE
746
+ ops [TLS_HW ][TLS_BASE ] = ops [TLS_BASE ][TLS_BASE ];
747
+ ops [TLS_HW ][TLS_BASE ].sendpage_locked = NULL ;
748
+
749
+ ops [TLS_HW ][TLS_SW ] = ops [TLS_BASE ][TLS_SW ];
750
+ ops [TLS_HW ][TLS_SW ].sendpage_locked = NULL ;
751
+
752
+ ops [TLS_BASE ][TLS_HW ] = ops [TLS_BASE ][TLS_SW ];
753
+
754
+ ops [TLS_SW ][TLS_HW ] = ops [TLS_SW ][TLS_SW ];
755
+
756
+ ops [TLS_HW ][TLS_HW ] = ops [TLS_HW ][TLS_SW ];
757
+ ops [TLS_HW ][TLS_HW ].sendpage_locked = NULL ;
758
+ #endif
759
+ #ifdef CONFIG_TLS_TOE
760
+ ops [TLS_HW_RECORD ][TLS_HW_RECORD ] = * base ;
761
+ #endif
762
+ }
763
+
731
764
static void tls_build_proto (struct sock * sk )
732
765
{
733
766
int ip_ver = sk -> sk_family == AF_INET6 ? TLSV6 : TLSV4 ;
@@ -739,6 +772,8 @@ static void tls_build_proto(struct sock *sk)
739
772
mutex_lock (& tcpv6_prot_mutex );
740
773
if (likely (prot != saved_tcpv6_prot )) {
741
774
build_protos (tls_prots [TLSV6 ], prot );
775
+ build_proto_ops (tls_proto_ops [TLSV6 ],
776
+ sk -> sk_socket -> ops );
742
777
smp_store_release (& saved_tcpv6_prot , prot );
743
778
}
744
779
mutex_unlock (& tcpv6_prot_mutex );
@@ -749,6 +784,8 @@ static void tls_build_proto(struct sock *sk)
749
784
mutex_lock (& tcpv4_prot_mutex );
750
785
if (likely (prot != saved_tcpv4_prot )) {
751
786
build_protos (tls_prots [TLSV4 ], prot );
787
+ build_proto_ops (tls_proto_ops [TLSV4 ],
788
+ sk -> sk_socket -> ops );
752
789
smp_store_release (& saved_tcpv4_prot , prot );
753
790
}
754
791
mutex_unlock (& tcpv4_prot_mutex );
@@ -959,10 +996,6 @@ static int __init tls_register(void)
959
996
if (err )
960
997
return err ;
961
998
962
- tls_sw_proto_ops = inet_stream_ops ;
963
- tls_sw_proto_ops .splice_read = tls_sw_splice_read ;
964
- tls_sw_proto_ops .sendpage_locked = tls_sw_sendpage_locked ;
965
-
966
999
tls_device_init ();
967
1000
tcp_register_ulp (& tcp_tls_ulp_ops );
968
1001
0 commit comments