Skip to content

Commit 4a49db9

Browse files
authored
CLOUDP-329787: Make Service Account transport available in httpClient() (#4090)
1 parent eb23c5c commit 4a49db9

File tree

7 files changed

+119
-63
lines changed

7 files changed

+119
-63
lines changed

docs/command/atlas-setup.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@ atlas setup
1414

1515
Login, authenticate, create, and access an Atlas cluster.
1616

17-
Public Preview: The atlas api sub-command, automatically generated from the MongoDB Atlas Admin API, offers full coverage of the Admin API and is currently in Public Preview (please provide feedback at https://feedback.mongodb.com/forums/930808-atlas-cli).
18-
Admin API capabilities have their own release lifecycle, which you can check via the provided API endpoint documentation link.
19-
20-
21-
2217
This command takes you through login, default profile creation, creating your first free tier cluster and connecting to it using MongoDB Shell.
2318

2419
Syntax

internal/cli/auth/whoami.go

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,16 @@ func (opts *whoOpts) Run() error {
3636
return nil
3737
}
3838

39-
var ErrUnauthenticated = errors.New("not logged in with an Atlas account or API key")
40-
41-
func AccountWithAccessToken() (string, error) {
42-
if config.AccessToken() == "" {
43-
return "", ErrUnauthenticated
44-
}
45-
46-
return config.AccessTokenSubject()
47-
}
39+
var ErrUnauthenticated = errors.New("not logged in with an Atlas account, Service Account or API key")
4840

4941
func authTypeAndSubject() (string, string, error) {
50-
if config.PublicAPIKey() != "" {
42+
switch config.AuthType() {
43+
case config.APIKeys:
5144
return "key", config.PublicAPIKey(), nil
52-
}
53-
54-
if subject, err := AccountWithAccessToken(); err == nil {
45+
case config.ServiceAccount:
46+
return "service account", config.ClientID(), nil
47+
case config.UserAccount:
48+
subject, _ := config.AccessTokenSubject()
5549
return "account", subject, nil
5650
}
5751

internal/cli/commonerrors/errors.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
atlasClustersPinned "go.mongodb.org/atlas-sdk/v20240530005/admin"
2222
atlasv2 "go.mongodb.org/atlas-sdk/v20250312005/admin"
2323
atlas "go.mongodb.org/atlas/mongodbatlas"
24+
"golang.org/x/oauth2"
2425
)
2526

2627
var (
@@ -44,6 +45,7 @@ const (
4445
globalUserOutsideSubnetErrorCode = "GLOBAL_USER_OUTSIDE_SUBNET"
4546
unauthorizedErrorCode = "UNAUTHORIZED"
4647
invalidRefreshTokenErrorCode = "INVALID_REFRESH_TOKEN"
48+
invalidServiceAccountClient = "invalid_client"
4749
)
4850

4951
// Check checks the error and returns a more user-friendly error message if applicable.
@@ -65,6 +67,8 @@ func Check(err error) error {
6567
return errOutsideVPN
6668
case asymmetricShardUnsupportedErrorCode:
6769
return errAsymmetricShardUnsupported
70+
case invalidServiceAccountClient: // oauth2 error
71+
return ErrUnauthorized
6872
}
6973

7074
apiError := getError(err) // some `Unauthorized` errors do not have an error code, so we check the HTTP status code
@@ -77,8 +81,9 @@ func Check(err error) error {
7781
}
7882

7983
// getErrorCode extracts the error code from the error if it is an Atlas error.
80-
// This function checks for v2 SDK, the pinned clusters SDK and the old SDK errors.
81-
// If the error is not any of these Atlas errors, it returns "UNKNOWN_ERROR".
84+
// This function checks for v2 SDK, the pinned clusters SDK, the old SDK errors
85+
// and oauth2 errors.
86+
// If the error is not any of these errors, it returns "UNKNOWN_ERROR".
8287
func getErrorCode(err error) string {
8388
if err == nil {
8489
return unknownErrorCode
@@ -94,6 +99,10 @@ func getErrorCode(err error) string {
9499
if sdkPinnedError, ok := atlasClustersPinned.AsError(err); ok {
95100
return sdkPinnedError.GetErrorCode()
96101
}
102+
var oauth2Err *oauth2.RetrieveError
103+
if errors.As(err, &oauth2Err) {
104+
return oauth2Err.ErrorCode
105+
}
97106

98107
return unknownErrorCode
99108
}

internal/cli/setup/setup_cmd.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -579,16 +579,10 @@ func (opts *Opts) promptConnect() error {
579579
func (opts *Opts) PreRun(ctx context.Context) error {
580580
opts.skipLogin = true
581581

582-
if err := validate.NoAPIKeys(); err != nil {
583-
// Why are we ignoring the error?
584-
// Because if the user has API keys, we just want to proceed with the flow
585-
// Then why not remove the error?
586-
// The error is useful in other components that call `validate.NoAPIKeys()`
582+
switch config.AuthType() {
583+
case config.APIKeys, config.ServiceAccount:
587584
return nil
588-
}
589-
590-
// if profile has access token and refresh token is valid, we can skip login
591-
if _, err := auth.AccountWithAccessToken(); err == nil {
585+
case config.UserAccount:
592586
if err := opts.login.RefreshAccessToken(ctx); !commonerrors.IsInvalidRefreshToken(err) {
593587
return nil
594588
}

internal/mocks/mock_store.go

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/store/store.go

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,16 @@ const (
4040
var errUnsupportedService = errors.New("unsupported service")
4141

4242
type Store struct {
43-
service string
44-
baseURL string
45-
telemetry bool
46-
authType config.AuthMechanism
47-
username string
48-
password string
49-
accessToken *atlasauth.Token
50-
client *atlas.Client
43+
service string
44+
baseURL string
45+
telemetry bool
46+
authType config.AuthMechanism
47+
username string
48+
password string
49+
accessToken *atlasauth.Token
50+
clientID string
51+
clientSecret string
52+
client *atlas.Client
5153
// Latest release of the autogenerated Atlas V2 API Client
5254
clientv2 *atlasv2.APIClient
5355
// Pinnned version to the most recent version that's working for clusters
@@ -72,13 +74,14 @@ func (s *Store) httpClient(httpTransport http.RoundTripper) (*http.Client, error
7274

7375
return &http.Client{Transport: tr}, nil
7476
case config.ServiceAccount:
75-
// TODO: serviceAccount will be implemented in CLOUDP-329787
76-
return &http.Client{Transport: httpTransport}, nil
77+
tr, err := transport.NewServiceAccountTransport(s.clientID, s.clientSecret, httpTransport)
78+
if err != nil {
79+
return nil, err
80+
}
81+
return &http.Client{Transport: tr}, nil
7782
default:
7883
return &http.Client{Transport: httpTransport}, nil
7984
}
80-
81-
return &http.Client{Transport: httpTransport}, nil
8285
}
8386

8487
func (s *Store) transport() *http.Transport {
@@ -137,6 +140,8 @@ func Telemetry() Option {
137140
type CredentialsGetter interface {
138141
PublicAPIKey() string
139142
PrivateAPIKey() string
143+
ClientID() string
144+
ClientSecret() string
140145
Token() (*atlasauth.Token, error)
141146
AuthType() config.AuthMechanism
142147
}
@@ -145,10 +150,16 @@ type CredentialsGetter interface {
145150
func WithAuthentication(c CredentialsGetter) Option {
146151
return func(s *Store) error {
147152
s.authType = c.AuthType()
148-
if s.authType == config.APIKeys {
153+
switch s.authType {
154+
case config.APIKeys:
149155
s.username = c.PublicAPIKey()
150156
s.password = c.PrivateAPIKey()
151-
} else {
157+
case config.ServiceAccount:
158+
s.clientID = c.ClientID()
159+
s.clientSecret = c.ClientSecret()
160+
case config.UserAccount:
161+
fallthrough
162+
default:
152163
t, err := c.Token()
153164
if err != nil {
154165
return err

internal/store/store_test.go

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,32 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
//go:build unit
16-
1715
package store
1816

1917
import (
2018
"context"
2119
"testing"
2220

2321
"github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config"
22+
"github.com/stretchr/testify/require"
2423
atlasauth "go.mongodb.org/atlas/auth"
2524
)
2625

2726
type auth struct {
2827
username string
2928
password string
30-
token string
29+
refreshToken string
3130
clientID string
3231
clientSecret string
32+
accessToken *atlasauth.Token
3333
}
3434

35-
func (auth) Token() (*atlasauth.Token, error) {
36-
return nil, nil
35+
func (a auth) Token() (*atlasauth.Token, error) {
36+
return a.accessToken, nil
3737
}
3838

3939
func (a auth) RefreshToken() string {
40-
return a.token
40+
return a.refreshToken
4141
}
4242

4343
func (a auth) PublicAPIKey() string {
@@ -60,7 +60,7 @@ func (a auth) AuthType() config.AuthMechanism {
6060
if a.username != "" {
6161
return config.APIKeys
6262
}
63-
if a.token != "" {
63+
if a.accessToken != nil {
6464
return config.UserAccount
6565
}
6666
if a.clientID != "" {
@@ -117,21 +117,46 @@ func (c testConfig) OpsManagerURL() string {
117117
var _ AuthenticatedConfig = &testConfig{}
118118

119119
func TestWithAuthentication(t *testing.T) {
120-
a := auth{
121-
username: "username",
122-
password: "password",
123-
}
124-
c, err := New(Service("cloud"), WithAuthentication(a))
125-
126-
if err != nil {
127-
t.Fatalf("New() unexpected error: %v", err)
120+
tests := []struct {
121+
name string
122+
a auth
123+
}{
124+
{
125+
name: "api keys",
126+
a: auth{
127+
username: "username",
128+
password: "password",
129+
},
130+
},
131+
{
132+
name: "service account",
133+
a: auth{
134+
clientID: "id",
135+
clientSecret: "secret",
136+
},
137+
},
138+
{
139+
name: "user account",
140+
a: auth{
141+
refreshToken: "token",
142+
accessToken: &atlasauth.Token{
143+
AccessToken: "access",
144+
RefreshToken: "refresh",
145+
},
146+
},
147+
},
128148
}
129149

130-
if c.username != a.username {
131-
t.Errorf("New() username = %s; expected %s", c.username, a.username)
132-
}
133-
if c.password != a.password {
134-
t.Errorf("New() password = %s; expected %s", c.password, a.password)
150+
for _, tt := range tests {
151+
t.Run(tt.name, func(t *testing.T) {
152+
c, err := New(Service("cloud"), WithAuthentication(tt.a))
153+
require.NoError(t, err)
154+
require.Equal(t, c.username, tt.a.username)
155+
require.Equal(t, c.password, tt.a.password)
156+
require.Equal(t, c.clientID, tt.a.clientID)
157+
require.Equal(t, c.clientSecret, tt.a.clientSecret)
158+
require.Equal(t, c.accessToken, tt.a.accessToken)
159+
})
135160
}
136161
}
137162

0 commit comments

Comments
 (0)