diff --git a/src/main/java/org/springframework/data/solr/server/support/HttpSolrClientFactory.java b/src/main/java/org/springframework/data/solr/server/support/HttpSolrClientFactory.java index ca9cda29f..4dfe82ed6 100644 --- a/src/main/java/org/springframework/data/solr/server/support/HttpSolrClientFactory.java +++ b/src/main/java/org/springframework/data/solr/server/support/HttpSolrClientFactory.java @@ -18,14 +18,20 @@ import org.apache.commons.lang3.StringUtils; import org.apache.http.auth.AuthScope; import org.apache.http.auth.Credentials; +import org.apache.http.auth.params.AuthPNames; import org.apache.http.client.CredentialsProvider; import org.apache.http.client.HttpClient; import org.apache.solr.client.solrj.SolrClient; import org.apache.solr.client.solrj.impl.HttpSolrClient; +import org.apache.solr.client.solrj.impl.CloudSolrClient; +import org.apache.solr.client.solrj.impl.LBHttpSolrClient; import org.springframework.beans.DirectFieldAccessor; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import java.util.Arrays; +import java.util.Set; + /** * The {@link HttpSolrClientFactory} replaces HttpSolrServerFactory from version 1.x and configures an * {@link HttpSolrClient} to work with the provided core. If provided Credentials eg. (@link @@ -67,11 +73,22 @@ private void appendAuthentication(Credentials credentials, String authPolicy, So HttpSolrClient httpSolrClient = (HttpSolrClient) solrClient; - if (credentials != null && StringUtils.isNotBlank(authPolicy) - && assertHttpClientInstance(httpSolrClient.getHttpClient())) { + if (credentials != null && StringUtils.isNotBlank(authPolicy) && assertHttpClientInstance( + httpSolrClient.getHttpClient())) { HttpClient httpClient = httpSolrClient.getHttpClient(); + DirectFieldAccessor df = new DirectFieldAccessor(httpClient); + CredentialsProvider provider = (CredentialsProvider) df.getPropertyValue("credentialsProvider"); + + provider.setCredentials(new AuthScope(AuthScope.ANY), credentials); + } + } else if (isCloudSolrClient(solrClient)) { + + CloudSolrClient cloudSolrClient = (CloudSolrClient) solrClient; + if (credentials != null && StringUtils.isNotBlank(authPolicy) && assertHttpClientInstance( + cloudSolrClient.getHttpClient())) { + HttpClient httpClient = cloudSolrClient.getHttpClient(); DirectFieldAccessor df = new DirectFieldAccessor(httpClient); CredentialsProvider provider = (CredentialsProvider) df.getPropertyValue("credentialsProvider"); diff --git a/src/main/java/org/springframework/data/solr/server/support/SolrClientFactoryBase.java b/src/main/java/org/springframework/data/solr/server/support/SolrClientFactoryBase.java index 8ffe2a065..b02b1acc1 100644 --- a/src/main/java/org/springframework/data/solr/server/support/SolrClientFactoryBase.java +++ b/src/main/java/org/springframework/data/solr/server/support/SolrClientFactoryBase.java @@ -16,6 +16,7 @@ package org.springframework.data.solr.server.support; import org.apache.solr.client.solrj.SolrClient; +import org.apache.solr.client.solrj.impl.CloudSolrClient; import org.apache.solr.client.solrj.impl.HttpSolrClient; import org.springframework.beans.factory.DisposableBean; import org.springframework.data.solr.server.SolrClientFactory; @@ -44,6 +45,10 @@ protected final boolean isHttpSolrClient(SolrClient solrClient) { return (solrClient instanceof HttpSolrClient); } + protected final boolean isCloudSolrClient(SolrClient solrClient) { + return (solrClient instanceof CloudSolrClient); + } + @Override public SolrClient getSolrClient() { diff --git a/src/test/java/org/springframework/data/solr/HttpSolrClientFactoryTests.java b/src/test/java/org/springframework/data/solr/HttpSolrClientFactoryTests.java index 77a49a0c7..d7adc4653 100644 --- a/src/test/java/org/springframework/data/solr/HttpSolrClientFactoryTests.java +++ b/src/test/java/org/springframework/data/solr/HttpSolrClientFactoryTests.java @@ -22,6 +22,7 @@ import org.apache.http.client.CredentialsProvider; import org.apache.http.client.HttpClient; import org.apache.solr.client.solrj.SolrClient; +import org.apache.solr.client.solrj.impl.CloudSolrClient; import org.apache.solr.client.solrj.impl.HttpSolrClient; import org.junit.After; import org.junit.Before; @@ -35,39 +36,37 @@ public class HttpSolrClientFactoryTests { private static final String URL = "https://solr.server.url"; + private static final String ZK = "localhost:9983"; private SolrClient solrClient; + private CloudSolrClient cloudSolrClient; - @Before - public void setUp() { + @Before public void setUp() { solrClient = new HttpSolrClient.Builder().withBaseSolrUrl(URL).build(); + cloudSolrClient = new CloudSolrClient.Builder().withZkHost(ZK).build(); } - @After - public void tearDown() { + @After public void tearDown() { solrClient = null; } - @Test - public void testInitFactory() { + @Test public void testInitFactory() { HttpSolrClientFactory factory = new HttpSolrClientFactory(solrClient); assertThat(factory.getSolrClient()).isEqualTo(solrClient); assertThat(((HttpSolrClient) factory.getSolrClient()).getBaseURL()).isEqualTo(URL); } - @Test(expected = IllegalArgumentException.class) - public void testInitFactoryWithNullServer() { + @Test(expected = IllegalArgumentException.class) public void testInitFactoryWithNullServer() { new HttpSolrClientFactory(null); } - @Test - public void testInitFactoryWithAuthentication() { + @Test public void testInitFactoryWithAuthentication() { HttpSolrClientFactory factory = new HttpSolrClientFactory(solrClient, new UsernamePasswordCredentials("username", "password"), "BASIC"); HttpClient solrHttpClient = ((HttpSolrClient) factory.getSolrClient()).getHttpClient(); - CredentialsProvider provider = (CredentialsProvider) ReflectionTestUtils.getField(solrHttpClient, - "credentialsProvider"); + CredentialsProvider provider = (CredentialsProvider) ReflectionTestUtils + .getField(solrHttpClient, "credentialsProvider"); assertThat(provider.getCredentials(AuthScope.ANY)).isNotNull(); @@ -78,8 +77,25 @@ public void testInitFactoryWithAuthentication() { .isEqualTo("password"); } - @Test(expected = IllegalArgumentException.class) - public void testInitFactoryWithoutAuthenticationSchema() { + @Test public void testInitCloudFactoryWithAuthentication() { + HttpSolrClientFactory factory = new HttpSolrClientFactory(cloudSolrClient, + new UsernamePasswordCredentials("username", "password"), "BASIC"); + + HttpClient solrCloudClient = ((CloudSolrClient) factory.getSolrClient()).getHttpClient(); + + CredentialsProvider provider = (CredentialsProvider) ReflectionTestUtils + .getField(solrCloudClient, "credentialsProvider"); + + assertThat(provider.getCredentials(AuthScope.ANY)).isNotNull(); + + assertThat(((UsernamePasswordCredentials) provider.getCredentials(AuthScope.ANY)).getUserName()) + .isEqualTo("username"); + + assertThat(((UsernamePasswordCredentials) provider.getCredentials(AuthScope.ANY)).getPassword()) + .isEqualTo("password"); + } + + @Test(expected = IllegalArgumentException.class) public void testInitFactoryWithoutAuthenticationSchema() { new HttpSolrClientFactory(solrClient, new UsernamePasswordCredentials("username", "password"), ""); }