Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions notion_client/api_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,54 @@ def send(self, file_upload_id: str, **kwargs: Any) -> SyncAsync[Any]:
form_data=pick(kwargs, "file", "part_number"),
auth=kwargs.get("auth"),
)


class OAuthEndpoint(Endpoint):
def token(
self, client_id: str, client_secret: str, **kwargs: Any
) -> SyncAsync[Any]:
"""Get token.

*[🔗 Endpoint documentation](https://developers.notion.com/reference/create-a-token)*
""" # noqa: E501
return self.parent.request(
path="oauth/token",
method="POST",
body=pick(
kwargs,
"grant_type",
"code",
"redirect_uri",
"external_account",
"refresh_token",
),
auth={"client_id": client_id, "client_secret": client_secret},
)

def introspect(
self, client_id: str, client_secret: str, **kwargs: Any
) -> SyncAsync[Any]:
"""Introspect token.

*[🔗 Endpoint documentation](https://developers.notion.com/reference/oauth-introspect)*
""" # noqa: E501
return self.parent.request(
path="oauth/introspect",
method="POST",
body=pick(kwargs, "token"),
auth={"client_id": client_id, "client_secret": client_secret},
)

def revoke(
self, client_id: str, client_secret: str, **kwargs: Any
) -> SyncAsync[Any]:
"""Revoke token.

*[🔗 Endpoint documentation](https://developers.notion.com/reference/oauth-revoke)*
""" # noqa: E501
return self.parent.request(
path="oauth/revoke",
method="POST",
body=pick(kwargs, "token"),
auth={"client_id": client_id, "client_secret": client_secret},
)
20 changes: 15 additions & 5 deletions notion_client/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Synchronous and asynchronous clients for Notion's API."""

import base64
import json
import logging
from abc import abstractmethod
Expand All @@ -19,6 +20,7 @@
SearchEndpoint,
UsersEndpoint,
FileUploadsEndpoint,
OAuthEndpoint,
)
from notion_client.errors import (
APIResponseError,
Expand Down Expand Up @@ -82,6 +84,7 @@ def __init__(
self.search = SearchEndpoint(self)
self.comments = CommentsEndpoint(self)
self.file_uploads = FileUploadsEndpoint(self)
self.oauth = OAuthEndpoint(self)

@property
def client(self) -> Union[httpx.Client, httpx.AsyncClient]:
Expand All @@ -108,11 +111,18 @@ def _build_request(
query: Optional[Dict[Any, Any]] = None,
body: Optional[Dict[Any, Any]] = None,
form_data: Optional[Dict[Any, Any]] = None,
auth: Optional[str] = None,
auth: Optional[Union[str, Dict[str, str]]] = None,
) -> Request:
headers = httpx.Headers()
if auth:
headers["Authorization"] = f"Bearer {auth}"
if isinstance(auth, dict):
client_id = auth.get("client_id", "")
client_secret = auth.get("client_secret", "")
credentials = f"{client_id}:{client_secret}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"
else:
headers["Authorization"] = f"Bearer {auth}"
self.logger.info(f"{method} {self.client.base_url}{path}")
self.logger.debug(f"=> {query} -- {body} -- {form_data}")

Expand Down Expand Up @@ -182,7 +192,7 @@ def request(
query: Optional[Dict[Any, Any]] = None,
body: Optional[Dict[Any, Any]] = None,
form_data: Optional[Dict[Any, Any]] = None,
auth: Optional[str] = None,
auth: Optional[Union[str, Dict[str, str]]] = None,
) -> SyncAsync[Any]:
# noqa
pass
Expand Down Expand Up @@ -228,7 +238,7 @@ def request(
query: Optional[Dict[Any, Any]] = None,
body: Optional[Dict[Any, Any]] = None,
form_data: Optional[Dict[Any, Any]] = None,
auth: Optional[str] = None,
auth: Optional[Union[str, Dict[str, str]]] = None,
) -> Any:
"""Send an HTTP request."""
request = self._build_request(method, path, query, body, form_data, auth)
Expand Down Expand Up @@ -279,7 +289,7 @@ async def request(
query: Optional[Dict[Any, Any]] = None,
body: Optional[Dict[Any, Any]] = None,
form_data: Optional[Dict[Any, Any]] = None,
auth: Optional[str] = None,
auth: Optional[Union[str, Dict[str, str]]] = None,
) -> Any:
"""Send an HTTP request asynchronously."""
request = self._build_request(method, path, query, body, form_data, auth)
Expand Down
1 change: 1 addition & 0 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pytest
pytest-asyncio
pytest-cov
pytest-mock
pytest-timeout
pytest-vcr
vcrpy==6.0.2
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ def token() -> Optional[str]:
return os.environ.get("NOTION_TOKEN")


@pytest.fixture(scope="session")
def oauth_client_id() -> Optional[str]:
"""OAuth client ID for testing OAuth endpoints"""
return os.environ.get("NOTION_OAUTH_CLIENT_ID")


@pytest.fixture(scope="session")
def oauth_client_secret() -> Optional[str]:
"""OAuth client secret for testing OAuth endpoints"""
return os.environ.get("NOTION_OAUTH_CLIENT_SECRET")


@pytest.fixture(scope="session")
def oauth_token() -> Optional[str]:
"""OAuth token for testing OAuth introspect and revoke endpoints"""
return os.environ.get("NOTION_OAUTH_TOKEN")


@pytest.fixture(scope="module", autouse=True)
def parent_page_id(vcr) -> str:
"""this is the ID of the Notion page where the tests will be executed
Expand Down
167 changes: 167 additions & 0 deletions tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,170 @@ def test_file_uploads_complete(client, part_uploaded_file_upload_id):
assert response["content_type"] == "text/plain"
assert response["number_of_parts"]["total"] == 3
assert response["number_of_parts"]["sent"] == 3


def test_oauth_introspect(client, mocker):
"""Test OAuth token introspection with mock - tests Basic auth encoding"""
mock_response = {"active": False, "request_id": "test-request-id"}

mock_send = mocker.patch.object(
client.client,
"send",
return_value=mocker.Mock(
json=lambda: mock_response, raise_for_status=lambda: None
),
)

response = client.oauth.introspect(
client_id="test_client_id",
client_secret="test_client_secret",
token="test_token",
)

assert "active" in response
assert isinstance(response["active"], bool)

mock_send.assert_called_once()
request = mock_send.call_args[0][0]
assert "Authorization" in request.headers
assert request.headers["Authorization"].startswith("Basic ")
assert (
request.headers["Authorization"]
== "Basic dGVzdF9jbGllbnRfaWQ6dGVzdF9jbGllbnRfc2VjcmV0"
)


def test_oauth_token_with_basic_auth(client, mocker):
"""Test OAuth token exchange with Basic auth - exercises auth encoding path"""
mock_response = {
"access_token": "secret_test_token",
"token_type": "bearer",
"bot_id": "bot_123",
}

mock_send = mocker.patch.object(
client.client,
"send",
return_value=mocker.Mock(
json=lambda: mock_response, raise_for_status=lambda: None
),
)

response = client.oauth.token(
client_id="test_client_id",
client_secret="test_client_secret",
grant_type="authorization_code",
code="test_code",
redirect_uri="http://localhost:3000/callback",
)

assert response["access_token"] == "secret_test_token"

mock_send.assert_called_once()
request = mock_send.call_args[0][0]
assert "Authorization" in request.headers
assert request.headers["Authorization"].startswith("Basic ")
import base64

expected = base64.b64encode(b"test_client_id:test_client_secret").decode()
assert request.headers["Authorization"] == f"Basic {expected}"


def test_oauth_revoke_with_basic_auth(client, mocker):
"""Test OAuth revoke with Basic auth - exercises auth encoding path"""
mock_response = {}

mock_send = mocker.patch.object(
client.client,
"send",
return_value=mocker.Mock(
json=lambda: mock_response, raise_for_status=lambda: None
),
)

response = client.oauth.revoke(
client_id="test_client_id",
client_secret="test_client_secret",
token="test_token",
)

assert response == {}

mock_send.assert_called_once()
request = mock_send.call_args[0][0]
assert "Authorization" in request.headers
assert request.headers["Authorization"].startswith("Basic ")


def test_oauth_revoke(client, mocker):
"""Test OAuth token revocation with mock (can't use cassette - token becomes invalid)"""
mock_response = {}
mock_request = mocker.patch.object(client, "request", return_value=mock_response)

response = client.oauth.revoke(
client_id="test_client_id",
client_secret="test_client_secret",
token="test_token",
)

assert response == {}
mock_request.assert_called_once_with(
path="oauth/revoke",
method="POST",
body={"token": "test_token"},
auth={"client_id": "test_client_id", "client_secret": "test_client_secret"},
)


def test_oauth_token_authorization_code(client, mocker):
mock_response = {
"access_token": "secret_test_token",
"token_type": "bearer",
"bot_id": "bot_123",
"workspace_id": "ws_456",
"workspace_name": "Test Workspace",
"owner": {"type": "user", "user": {"object": "user", "id": "user_789"}},
}

mock_request = mocker.patch.object(client, "request", return_value=mock_response)

response = client.oauth.token(
client_id="test_client_id",
client_secret="test_client_secret",
grant_type="authorization_code",
code="test_code",
redirect_uri="http://localhost:3000/callback",
)

assert response["access_token"] == "secret_test_token"
assert response["bot_id"] == "bot_123"
mock_request.assert_called_once()
call_kwargs = mock_request.call_args[1]
assert call_kwargs["path"] == "oauth/token"
assert call_kwargs["method"] == "POST"
assert call_kwargs["auth"] == {
"client_id": "test_client_id",
"client_secret": "test_client_secret",
}


def test_oauth_token_refresh_token(client, mocker):
mock_response = {
"access_token": "secret_refreshed_token",
"token_type": "bearer",
"bot_id": "bot_123",
}

mock_request = mocker.patch.object(client, "request", return_value=mock_response)

response = client.oauth.token(
client_id="test_client_id",
client_secret="test_client_secret",
grant_type="refresh_token",
refresh_token="test_refresh_token",
)

assert response["access_token"] == "secret_refreshed_token"
mock_request.assert_called_once()
call_kwargs = mock_request.call_args[1]
assert call_kwargs["path"] == "oauth/token"