diff --git a/dotnet/src/Plugins/Plugins.Core/HttpPlugin.cs b/dotnet/src/Plugins/Plugins.Core/HttpPlugin.cs
index 5bd6ce3c21d8..ed3d9635f8a1 100644
--- a/dotnet/src/Plugins/Plugins.Core/HttpPlugin.cs
+++ b/dotnet/src/Plugins/Plugins.Core/HttpPlugin.cs
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
+using System;
+using System.Collections.Generic;
using System.ComponentModel;
using System.Net.Http;
using System.Threading;
@@ -36,6 +38,15 @@ public HttpPlugin() : this(null)
public HttpPlugin(HttpClient? client = null) =>
this._client = client ?? HttpClientProvider.GetHttpClient();
+ ///
+ /// List of allowed domains to download from.
+ ///
+ public IEnumerable? AllowedDomains
+ {
+ get => this._allowedDomains;
+ set => this._allowedDomains = value is null ? null : new HashSet(value, StringComparer.OrdinalIgnoreCase);
+ }
+
///
/// Sends an HTTP GET request to the specified URI and returns the response body as a string.
///
@@ -88,17 +99,38 @@ public Task DeleteAsync(
CancellationToken cancellationToken = default) =>
this.SendRequestAsync(uri, HttpMethod.Delete, requestContent: null, cancellationToken);
+ #region private
+ private HashSet? _allowedDomains;
+
+ ///
+ /// If a list of allowed domains has been provided, the host of the provided uri is checked
+ /// to verify it is in the allowed domain list.
+ ///
+ private bool IsUriAllowed(Uri uri)
+ {
+ Verify.NotNull(uri);
+
+ return this._allowedDomains is null || this._allowedDomains.Contains(uri.Host);
+ }
+
/// Sends an HTTP request and returns the response content as a string.
- /// The URI of the request.
+ /// The URI of the request.
/// The HTTP method for the request.
/// Optional request content.
/// The token to use to request cancellation.
- private async Task SendRequestAsync(string uri, HttpMethod method, HttpContent? requestContent, CancellationToken cancellationToken)
+ private async Task SendRequestAsync(string uriStr, HttpMethod method, HttpContent? requestContent, CancellationToken cancellationToken)
{
+ var uri = new Uri(uriStr);
+ if (!this.IsUriAllowed(uri))
+ {
+ throw new InvalidOperationException("Sending requests to the provided location is not allowed.");
+ }
+
using var request = new HttpRequestMessage(method, uri) { Content = requestContent };
request.Headers.Add("User-Agent", HttpHeaderConstant.Values.UserAgent);
request.Headers.Add(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(HttpPlugin)));
using var response = await this._client.SendWithSuccessCheckAsync(request, cancellationToken).ConfigureAwait(false);
return await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);
}
+ #endregion
}
diff --git a/dotnet/src/Plugins/Plugins.UnitTests/Core/HttpPluginTests.cs b/dotnet/src/Plugins/Plugins.UnitTests/Core/HttpPluginTests.cs
index 02e776761b43..b64d87671bb0 100644
--- a/dotnet/src/Plugins/Plugins.UnitTests/Core/HttpPluginTests.cs
+++ b/dotnet/src/Plugins/Plugins.UnitTests/Core/HttpPluginTests.cs
@@ -102,6 +102,25 @@ public async Task ItCanDeleteAsync()
this.VerifyMock(mockHandler, HttpMethod.Delete);
}
+ [Fact]
+ public async Task ItThrowsInvalidOperationExceptionForInvalidDomainAsync()
+ {
+ // Arrange
+ var mockHandler = this.CreateMock();
+ using var client = new HttpClient(mockHandler.Object);
+ var plugin = new HttpPlugin(client)
+ {
+ AllowedDomains = ["www.example.com"]
+ };
+ var invalidUri = "http://www.notexample.com";
+
+ // Act & Assert
+ await Assert.ThrowsAsync(async () => await plugin.GetAsync(invalidUri));
+ await Assert.ThrowsAsync(async () => await plugin.PostAsync(invalidUri, this._content));
+ await Assert.ThrowsAsync(async () => await plugin.PutAsync(invalidUri, this._content));
+ await Assert.ThrowsAsync(async () => await plugin.DeleteAsync(invalidUri));
+ }
+
private Mock CreateMock()
{
var mockHandler = new Mock();