1
+ using System . Collections . Concurrent ;
2
+ using System . Diagnostics ;
3
+ using System . Net ;
4
+ using System . Net . Sockets ;
5
+
6
+ // <DnsRoundRobinConnector>
7
+ // This is available as NuGet Package: https://www.nuget.org/packages/DnsRoundRobin/
8
+ // The original source code can be found also here: https://github.com/MihaZupan/DnsRoundRobin
9
+ public sealed class DnsRoundRobinConnector : IDisposable
10
+ // </DnsRoundRobinConnector>
11
+ {
12
+ private const int DefaultDnsRefreshIntervalSeconds = 2 * 60 ;
13
+ private const int MaxCleanupIntervalSeconds = 60 ;
14
+
15
+ public static DnsRoundRobinConnector Shared { get ; } = new ( ) ;
16
+
17
+ private readonly ConcurrentDictionary < string , HostRoundRobinState > _states = new ( StringComparer . Ordinal ) ;
18
+ private readonly Timer _cleanupTimer ;
19
+ private readonly TimeSpan _cleanupInterval ;
20
+ private readonly long _cleanupIntervalTicks ;
21
+ private readonly long _dnsRefreshTimeoutTicks ;
22
+ private readonly TimeSpan _endpointConnectTimeout ;
23
+
24
+ /// <summary>
25
+ /// Creates a new <see cref="DnsRoundRobinConnector"/>.
26
+ /// </summary>
27
+ /// <param name="dnsRefreshInterval">Maximum amount of time a Dns resolution is cached for. Default to 2 minutes.</param>
28
+ /// <param name="endpointConnectTimeout">Maximum amount of time allowed for a connection attempt to any individual endpoint. Defaults to infinite.</param>
29
+ public DnsRoundRobinConnector ( TimeSpan ? dnsRefreshInterval = null , TimeSpan ? endpointConnectTimeout = null )
30
+ {
31
+ dnsRefreshInterval = TimeSpan . FromSeconds ( Math . Max ( 1 , dnsRefreshInterval ? . TotalSeconds ?? DefaultDnsRefreshIntervalSeconds ) ) ;
32
+ _cleanupInterval = TimeSpan . FromSeconds ( Math . Clamp ( dnsRefreshInterval . Value . TotalSeconds / 2 , 1 , MaxCleanupIntervalSeconds ) ) ;
33
+ _cleanupIntervalTicks = ( long ) ( _cleanupInterval . TotalSeconds * Stopwatch . Frequency ) ;
34
+ _dnsRefreshTimeoutTicks = ( long ) ( dnsRefreshInterval . Value . TotalSeconds * Stopwatch . Frequency ) ;
35
+ _endpointConnectTimeout = endpointConnectTimeout is null || endpointConnectTimeout . Value . Ticks < 1 ? Timeout . InfiniteTimeSpan : endpointConnectTimeout . Value ;
36
+
37
+ bool restoreFlow = false ;
38
+ try
39
+ {
40
+ // Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever
41
+ if ( ! ExecutionContext . IsFlowSuppressed ( ) )
42
+ {
43
+ ExecutionContext . SuppressFlow ( ) ;
44
+ restoreFlow = true ;
45
+ }
46
+
47
+ // Ensure the Timer has a weak reference to the connector; otherwise, it
48
+ // can introduce a cycle that keeps the connector rooted by the Timer
49
+ _cleanupTimer = new Timer ( static state =>
50
+ {
51
+ var thisWeakRef = ( WeakReference < DnsRoundRobinConnector > ) state ! ;
52
+ if ( thisWeakRef . TryGetTarget ( out DnsRoundRobinConnector ? thisRef ) )
53
+ {
54
+ thisRef . Cleanup ( ) ;
55
+ thisRef . _cleanupTimer . Change ( thisRef . _cleanupInterval , Timeout . InfiniteTimeSpan ) ;
56
+ }
57
+ } , new WeakReference < DnsRoundRobinConnector > ( this ) , Timeout . Infinite , Timeout . Infinite ) ;
58
+
59
+ _cleanupTimer . Change ( _cleanupInterval , Timeout . InfiniteTimeSpan ) ;
60
+ }
61
+ finally
62
+ {
63
+ if ( restoreFlow )
64
+ {
65
+ ExecutionContext . RestoreFlow ( ) ;
66
+ }
67
+ }
68
+ }
69
+
70
+ private void Cleanup ( )
71
+ {
72
+ long minTimestamp = Stopwatch . GetTimestamp ( ) - _cleanupIntervalTicks ;
73
+
74
+ foreach ( KeyValuePair < string , HostRoundRobinState > state in _states )
75
+ {
76
+ if ( state . Value . LastAccessTimestamp < minTimestamp )
77
+ {
78
+ _states . TryRemove ( state ) ;
79
+ }
80
+ }
81
+ }
82
+
83
+ public void Dispose ( )
84
+ {
85
+ _states . Clear ( ) ;
86
+ }
87
+
88
+ public Task < Socket > ConnectAsync ( DnsEndPoint endPoint , CancellationToken cancellationToken )
89
+ {
90
+ if ( cancellationToken . IsCancellationRequested )
91
+ {
92
+ return Task . FromCanceled < Socket > ( cancellationToken ) ;
93
+ }
94
+
95
+ if ( IPAddress . TryParse ( endPoint . Host , out IPAddress ? address ) )
96
+ {
97
+ // Avoid the overhead of HostRoundRobinState if we're dealing with a single endpoint
98
+ return ConnectToIPAddressAsync ( address , endPoint . Port , cancellationToken ) ;
99
+ }
100
+
101
+ HostRoundRobinState state = _states . GetOrAdd (
102
+ endPoint . Host ,
103
+ static ( _ , thisRef ) => new HostRoundRobinState ( thisRef . _dnsRefreshTimeoutTicks , thisRef . _endpointConnectTimeout ) ,
104
+ this ) ;
105
+
106
+ return state . ConnectAsync ( endPoint , cancellationToken ) ;
107
+ }
108
+
109
+ private static async Task < Socket > ConnectToIPAddressAsync ( IPAddress address , int port , CancellationToken cancellationToken )
110
+ {
111
+ var socket = new Socket ( SocketType . Stream , ProtocolType . Tcp ) { NoDelay = true } ;
112
+ try
113
+ {
114
+ await socket . ConnectAsync ( address , port , cancellationToken ) ;
115
+ return socket ;
116
+ }
117
+ catch
118
+ {
119
+ socket . Dispose ( ) ;
120
+ throw ;
121
+ }
122
+ }
123
+
124
+ private sealed class HostRoundRobinState
125
+ {
126
+ private readonly long _dnsRefreshTimeoutTicks ;
127
+ private readonly TimeSpan _endpointConnectTimeout ;
128
+ private long _lastAccessTimestamp ;
129
+ private long _lastDnsTimestamp ;
130
+ private IPAddress [ ] ? _addresses ;
131
+ private uint _roundRobinIndex ;
132
+
133
+ public long LastAccessTimestamp => Volatile . Read ( ref _lastAccessTimestamp ) ;
134
+
135
+ private bool AddressesAreStale => Stopwatch . GetTimestamp ( ) - Volatile . Read ( ref _lastDnsTimestamp ) > _dnsRefreshTimeoutTicks ;
136
+
137
+ public HostRoundRobinState ( long dnsRefreshTimeoutTicks , TimeSpan endpointConnectTimeout )
138
+ {
139
+ _dnsRefreshTimeoutTicks = dnsRefreshTimeoutTicks ;
140
+ _endpointConnectTimeout = endpointConnectTimeout ;
141
+
142
+ _roundRobinIndex -- ; // Offset the first Increment to ensure we start with the first address in the list
143
+
144
+ RefreshLastAccessTimestamp ( ) ;
145
+ }
146
+
147
+ private void RefreshLastAccessTimestamp ( ) => Volatile . Write ( ref _lastAccessTimestamp , Stopwatch . GetTimestamp ( ) ) ;
148
+
149
+ public async Task < Socket > ConnectAsync ( DnsEndPoint endPoint , CancellationToken cancellationToken )
150
+ {
151
+ RefreshLastAccessTimestamp ( ) ;
152
+
153
+ uint sharedIndex = Interlocked . Increment ( ref _roundRobinIndex ) ;
154
+ IPAddress [ ] ? attemptedAddresses = null ;
155
+ IPAddress [ ] ? addresses = null ;
156
+ Exception ? lastException = null ;
157
+
158
+ while ( attemptedAddresses is null )
159
+ {
160
+ if ( addresses is null )
161
+ {
162
+ addresses = _addresses ;
163
+ }
164
+ else
165
+ {
166
+ attemptedAddresses = addresses ;
167
+
168
+ // Give each connection attempt a chance to do its own Dns call.
169
+ addresses = null ;
170
+ }
171
+
172
+ if ( addresses is null || AddressesAreStale )
173
+ {
174
+ // It's possible that multiple connection attempts are resolving the same host concurrently - that's okay.
175
+ _addresses = addresses = await Dns . GetHostAddressesAsync ( endPoint . Host , cancellationToken ) ;
176
+ Volatile . Write ( ref _lastDnsTimestamp , Stopwatch . GetTimestamp ( ) ) ;
177
+
178
+ if ( attemptedAddresses is not null && AddressListsAreEquivalent ( attemptedAddresses , addresses ) )
179
+ {
180
+ // We've already tried to connect to every address in the list, and a new Dns resolution returned the same list.
181
+ // Instead of attempting every address again, give up early.
182
+ break ;
183
+ }
184
+ }
185
+
186
+ for ( int i = 0 ; i < addresses . Length ; i ++ )
187
+ {
188
+ Socket ? attemptSocket = null ;
189
+ CancellationTokenSource ? endpointConnectTimeoutCts = null ;
190
+ try
191
+ {
192
+ IPAddress address = addresses [ ( int ) ( ( sharedIndex + i ) % addresses . Length ) ] ;
193
+
194
+ if ( Socket . OSSupportsIPv6 && address . AddressFamily == AddressFamily . InterNetworkV6 )
195
+ {
196
+ attemptSocket = new Socket ( AddressFamily . InterNetworkV6 , SocketType . Stream , ProtocolType . Tcp ) ;
197
+ if ( address . IsIPv4MappedToIPv6 )
198
+ {
199
+ attemptSocket . DualMode = true ;
200
+ }
201
+ }
202
+ else if ( Socket . OSSupportsIPv4 && address . AddressFamily == AddressFamily . InterNetwork )
203
+ {
204
+ attemptSocket = new Socket ( AddressFamily . InterNetwork , SocketType . Stream , ProtocolType . Tcp ) ;
205
+ }
206
+
207
+ if ( attemptSocket is not null )
208
+ {
209
+ attemptSocket . NoDelay = true ;
210
+
211
+ if ( _endpointConnectTimeout != Timeout . InfiniteTimeSpan )
212
+ {
213
+ endpointConnectTimeoutCts = CancellationTokenSource . CreateLinkedTokenSource ( cancellationToken ) ;
214
+ endpointConnectTimeoutCts . CancelAfter ( _endpointConnectTimeout ) ;
215
+ }
216
+
217
+ await attemptSocket . ConnectAsync ( address , endPoint . Port , endpointConnectTimeoutCts ? . Token ?? cancellationToken ) ;
218
+
219
+ RefreshLastAccessTimestamp ( ) ;
220
+ return attemptSocket ;
221
+ }
222
+ }
223
+ catch ( Exception ex )
224
+ {
225
+ attemptSocket ? . Dispose ( ) ;
226
+
227
+ if ( cancellationToken . IsCancellationRequested )
228
+ {
229
+ throw ;
230
+ }
231
+
232
+ if ( endpointConnectTimeoutCts ? . IsCancellationRequested == true )
233
+ {
234
+ ex = new TimeoutException ( $ "Failed to connect to any endpoint within the specified endpoint connect timeout of { _endpointConnectTimeout . TotalSeconds : N2} seconds.", ex ) ;
235
+ }
236
+
237
+ lastException = ex ;
238
+ }
239
+ finally
240
+ {
241
+ endpointConnectTimeoutCts ? . Dispose ( ) ;
242
+ }
243
+ }
244
+ }
245
+
246
+ throw lastException ?? new SocketException ( ( int ) SocketError . NoData ) ;
247
+ }
248
+
249
+ private static bool AddressListsAreEquivalent ( IPAddress [ ] left , IPAddress [ ] right )
250
+ {
251
+ if ( ReferenceEquals ( left , right ) )
252
+ {
253
+ return true ;
254
+ }
255
+
256
+ if ( left . Length != right . Length )
257
+ {
258
+ return false ;
259
+ }
260
+
261
+ for ( int i = 0 ; i < left . Length ; i ++ )
262
+ {
263
+ if ( ! left [ i ] . Equals ( right [ i ] ) )
264
+ {
265
+ return false ;
266
+ }
267
+ }
268
+
269
+ return true ;
270
+ }
271
+ }
272
+ }
0 commit comments