Skip to content

Commit 552cacc

Browse files
committed
Properly sign query string parameters included in request URI.
1 parent 24f894e commit 552cacc

File tree

2 files changed

+87
-40
lines changed

2 files changed

+87
-40
lines changed

oauth2/__init__.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -250,35 +250,33 @@ class Request(dict):
250250
251251
"""
252252

253-
http_method = HTTP_METHOD
254-
http_url = None
255253
version = VERSION
256254

257255
def __init__(self, method=HTTP_METHOD, url=None, parameters=None):
258-
if method is not None:
259-
self.method = method
260-
261-
if url is not None:
262-
self.url = url
263-
256+
self.method = method
257+
self.url = url
264258
if parameters is not None:
265259
self.update(parameters)
266260

267261
@setter
268262
def url(self, value):
269-
scheme, netloc, path, params, query, fragment = urlparse.urlparse(value)
270-
271-
# Exclude default port numbers.
272-
if scheme == 'http' and netloc[-3:] == ':80':
273-
netloc = netloc[:-3]
274-
elif scheme == 'https' and netloc[-4:] == ':443':
275-
netloc = netloc[:-4]
276-
277-
if scheme != 'http' and scheme != 'https':
278-
raise ValueError("Unsupported URL %s (%s)." % (value, scheme))
279-
280-
value = urlparse.urlunparse((scheme, netloc, path, params, query, fragment))
281263
self.__dict__['url'] = value
264+
if value is not None:
265+
scheme, netloc, path, params, query, fragment = urlparse.urlparse(value)
266+
267+
# Exclude default port numbers.
268+
if scheme == 'http' and netloc[-3:] == ':80':
269+
netloc = netloc[:-3]
270+
elif scheme == 'https' and netloc[-4:] == ':443':
271+
netloc = netloc[:-4]
272+
if scheme not in ('http', 'https'):
273+
raise ValueError("Unsupported URL %s (%s)." % (value, scheme))
274+
275+
# Normalized URL excludes params, query, and fragment.
276+
self.normalized_url = urlparse.urlunparse((scheme, netloc, path, None, None, None))
277+
else:
278+
self.normalized_url = None
279+
self.__dict__['url'] = None
282280

283281
@setter
284282
def method(self, value):
@@ -342,6 +340,11 @@ def get_normalized_parameters(self):
342340
items.extend((key, item) for item in value)
343341
else:
344342
items.append((key, value))
343+
344+
# Include any query string parameters from the provided URL
345+
query = urlparse.urlparse(self.url)[4]
346+
items.extend(self._split_url_string(query).items())
347+
345348
encoded_str = urllib.urlencode(sorted(items))
346349
# Encode signature parameters per Oauth Core 1.0 protocol
347350
# spec draft 7, section 3.6
@@ -600,15 +603,6 @@ def request(self, uri, method="GET", body=None, headers=None,
600603

601604
if body and method == "POST" and not is_multipart:
602605
parameters = dict(parse_qsl(body))
603-
elif method == "GET":
604-
parsed = urlparse.urlparse(uri)
605-
606-
try:
607-
query = parsed.query
608-
except AttributeError:
609-
query = parsed[4]
610-
611-
parameters = parse_qsl(query)
612606
else:
613607
parameters = None
614608

@@ -676,14 +670,15 @@ class SignatureMethod_HMAC_SHA1(SignatureMethod):
676670
def signing_base(self, request, consumer, token):
677671
sig = (
678672
escape(request.method),
679-
escape(request.url),
673+
escape(request.normalized_url),
680674
escape(request.get_normalized_parameters()),
681675
)
682676

683677
key = '%s&' % escape(consumer.secret)
684678
if token:
685679
key += escape(token.secret)
686680
raw = '&'.join(sig)
681+
print key, raw
687682
return key, raw
688683

689684
def sign(self, request, consumer, token):

tests/test_oauth.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,8 @@ def test_setter(self):
222222
url = "http://example.com"
223223
method = "GET"
224224
req = oauth.Request(method)
225-
226-
try:
227-
url = req.url
228-
self.fail("AttributeError should have been raised on empty url.")
229-
except AttributeError:
230-
pass
231-
except Exception, e:
232-
self.fail(str(e))
225+
self.assertTrue(req.url is None)
226+
self.assertTrue(req.normalized_url is None)
233227

234228
def test_deleter(self):
235229
url = "http://example.com"
@@ -253,17 +247,21 @@ def test_url(self):
253247
method = "GET"
254248

255249
req = oauth.Request(method, url1)
256-
self.assertEquals(req.url, exp1)
250+
self.assertEquals(req.normalized_url, exp1)
251+
self.assertEquals(req.url, url1)
257252

258253
req = oauth.Request(method, url2)
259-
self.assertEquals(req.url, exp2)
254+
self.assertEquals(req.normalized_url, exp2)
255+
self.assertEquals(req.url, url2)
260256

261257
def test_url_query(self):
262258
url = "https://www.google.com/m8/feeds/contacts/default/full/?alt=json&max-contacts=10"
259+
normalized_url = urlparse.urlunparse(urlparse.urlparse(url)[:3] + (None, None, None))
263260
method = "GET"
264261

265262
req = oauth.Request(method, url)
266263
self.assertEquals(req.url, url)
264+
self.assertEquals(req.normalized_url, normalized_url)
267265

268266
def test_get_parameter(self):
269267
url = "http://example.com"
@@ -400,6 +398,30 @@ def test_to_url_with_query(self):
400398
self.assertEquals(b['max-contacts'], ['10'])
401399
self.assertEquals(a, b)
402400

401+
def test_signature_base_string_with_query(self):
402+
url = "https://www.google.com/m8/feeds/contacts/default/full/?alt=json&max-contacts=10"
403+
params = {
404+
'oauth_version': "1.0",
405+
'oauth_nonce': "4572616e48616d6d65724c61686176",
406+
'oauth_timestamp': "137131200",
407+
'oauth_consumer_key': "0685bd9184jfhq22",
408+
'oauth_signature_method': "HMAC-SHA1",
409+
'oauth_token': "ad180jjd733klru7",
410+
'oauth_signature': "wOJIO9A2W5mFwDgiDvZbTSMK%2FPY%3D",
411+
}
412+
req = oauth.Request("GET", url, params)
413+
self.assertEquals(req.normalized_url, 'https://www.google.com/m8/feeds/contacts/default/full/')
414+
self.assertEquals(req.url, 'https://www.google.com/m8/feeds/contacts/default/full/?alt=json&max-contacts=10')
415+
normalized_params = parse_qsl(req.get_normalized_parameters())
416+
self.assertTrue(len(normalized_params), len(params) + 2)
417+
normalized_params = dict(normalized_params)
418+
for key, value in params.iteritems():
419+
if key == 'oauth_signature':
420+
continue
421+
self.assertEquals(value, normalized_params[key])
422+
self.assertEquals(normalized_params['alt'], 'json')
423+
self.assertEquals(normalized_params['max-contacts'], '10')
424+
403425
def test_get_normalized_parameters(self):
404426
url = "http://sp.example.com/"
405427

@@ -871,6 +893,36 @@ def test_multipart_post_does_not_alter_body(self):
871893
self.assertEqual(result, random_result)
872894
self.mox.VerifyAll()
873895

896+
def test_url_with_query_string(self):
897+
self.mox.StubOutWithMock(httplib2.Http, 'request')
898+
uri = 'http://example.com/foo/bar/?show=thundercats&character=snarf'
899+
client = oauth.Client(self.consumer, None)
900+
expected_kwargs = {
901+
'method': 'GET',
902+
'body': None,
903+
'redirections': httplib2.DEFAULT_MAX_REDIRECTS,
904+
'connection_type': None,
905+
'headers': mox.IsA(dict),
906+
}
907+
def oauth_verifier(url):
908+
req = oauth.Request.from_consumer_and_token(self.consumer, None,
909+
http_method='GET', http_url=uri, parameters={})
910+
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), self.consumer, None)
911+
expected = parse_qsl(urlparse.urlparse(req.to_url()).query)
912+
actual = parse_qsl(urlparse.urlparse(url).query)
913+
if len(expected) != len(actual):
914+
return False
915+
actual = dict(actual)
916+
for key, value in expected:
917+
if key not in ('oauth_signature', 'oauth_nonce', 'oauth_timestamp'):
918+
if actual[key] != value:
919+
return False
920+
return True
921+
httplib2.Http.request(client, mox.Func(oauth_verifier), **expected_kwargs)
922+
self.mox.ReplayAll()
923+
client.request(uri, 'GET')
924+
self.mox.VerifyAll()
925+
874926
if __name__ == "__main__":
875927
unittest.main()
876928

0 commit comments

Comments
 (0)