import logging
import time
from typing import Optional, Union
from urllib.parse import urlencode

import requests
from authlib.integrations.flask_client import OAuth
from flask import session, redirect, url_for
from jose import jwt
from jose.exceptions import ExpiredSignatureError, JWTError, JWTClaimsError

from mdvtools.auth.auth_provider import AuthProvider
# Auth0-backed implementation of the AuthProvider interface.
class Auth0Provider(AuthProvider):
    """
    AuthProvider implementation that authenticates users against Auth0.

    Uses Authlib's Flask OAuth client for the authorization-code flow, the
    Auth0 ``/userinfo`` endpoint for profile data, and python-jose for JWT
    verification against the tenant's JWKS.
    """

    # Upper bound (seconds) for every outbound HTTP call to Auth0, so a stalled
    # endpoint cannot block a worker indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self, app, oauth: OAuth, client_id: str, client_secret: str, domain: str):
        """
        Initializes the Auth0Provider class with application details.

        :param app: Flask app instance
        :param oauth: Authlib OAuth instance
        :param client_id: Auth0 Client ID
        :param client_secret: Auth0 Client Secret
        :param domain: Auth0 Domain
        :raises ValueError: if any of client_id, client_secret or domain is empty
        :raises RuntimeError: if the OAuth client could not be registered
        """
        try:
            if not all([client_id, client_secret, domain]):
                raise ValueError("Missing required Auth0 configuration parameters.")

            self.app = app
            self.oauth = oauth
            self.client_id = client_id
            self.client_secret = client_secret
            self.domain = domain

            self._initialize_oauth()
            logging.info("Auth0Provider initialized successfully.")
        except Exception as e:
            # Log at the boundary, then let the caller decide how to react.
            logging.critical(f"Failed to initialize Auth0Provider: {e}")
            raise

    def _initialize_oauth(self):
        """
        Registers the Auth0 OAuth provider and validates OpenID Connect metadata.

        The discovery document is fetched once up front so that a wrong domain
        or an unreachable tenant fails at start-up rather than on first login.

        :raises RuntimeError: if the metadata cannot be fetched, lacks
            ``jwks_uri``, or the provider registration fails
        """
        try:
            # Construct the server metadata URL for OpenID Connect discovery
            server_metadata_url = f'https://{self.domain}/.well-known/openid-configuration'

            # Attempt to fetch metadata to ensure it's accessible
            response = requests.get(server_metadata_url, timeout=self.REQUEST_TIMEOUT)
            if response.status_code != 200:
                logging.error(f"Failed to fetch OpenID configuration from {server_metadata_url}: {response.text}")
                raise RuntimeError(f"Unable to fetch OpenID Connect metadata from {server_metadata_url}")

            # Parse and check the existence of jwks_uri in the metadata
            metadata = response.json()
            jwks_uri = metadata.get('jwks_uri')
            if not jwks_uri:
                logging.error(f"The OpenID configuration is missing 'jwks_uri': {metadata}")
                raise RuntimeError("'jwks_uri' is missing in OpenID Connect metadata.")

            # Register the OAuth provider with server_metadata_url for dynamic metadata fetching
            self.oauth.register(
                'auth0',
                client_id=self.client_id,
                client_secret=self.client_secret,
                server_metadata_url=server_metadata_url,
                client_kwargs={'scope': 'openid profile email'},
            )
            logging.info("Auth0 OAuth provider registered successfully with OpenID Connect metadata.")
        except Exception as e:
            logging.error(f"Error while registering OAuth provider: {e}")
            raise RuntimeError("Failed to initialize OAuth.") from e

    def login(self) -> str:
        """
        Initiates the login process by redirecting to Auth0's authorization page.

        The callback URL is read from ``app.config["AUTH0_CALLBACK_URL"]``.

        :return: whatever ``authorize_redirect`` returns (a redirect response
            in practice; the ``str`` annotation is kept unchanged for callers)
        :raises RuntimeError: if the redirect could not be created
        """
        try:
            logging.info("Initiating login process.")
            redirect_uri = self.app.config["AUTH0_CALLBACK_URL"]
            # Call authorize_redirect exactly once: every call generates and
            # stores fresh OAuth state in the session.
            return self.oauth.auth0.authorize_redirect(redirect_uri=redirect_uri)
        except Exception as e:
            logging.error(f"Error during login process: {e}")
            raise RuntimeError("Login failed.") from e

    def logout(self) -> None:
        """
        Logs the user out by clearing the session and redirecting to Auth0's logout endpoint.

        Reads ``LOGIN_REDIRECT_URL``, ``AUTH0_DOMAIN`` and ``AUTH0_CLIENT_ID``
        from ``app.config``.

        :return: the Flask redirect response to Auth0's ``/v2/logout`` URL
            (the ``None`` annotation is kept unchanged for callers)
        :raises RuntimeError: if the logout redirect could not be built
        """
        try:
            logging.info("Logging out user from Auth0.")

            # Clear the server-side session to remove any stored tokens and user data
            session.clear()

            # Where Auth0 sends the user after its own logout has completed
            redirect_url = self.app.config["LOGIN_REDIRECT_URL"]

            # URL-encode the query so a returnTo value containing '?', '&' or
            # '#' cannot corrupt or truncate the logout URL.
            query = urlencode({
                "returnTo": redirect_url,
                "client_id": self.app.config['AUTH0_CLIENT_ID'],
            })
            logout_url = f"https://{self.app.config['AUTH0_DOMAIN']}/v2/logout?{query}"

            logging.info(f"Redirecting to Auth0 logout URL: {logout_url}")
            return redirect(logout_url)
        except Exception as e:
            logging.error(f"Error during logout process: {e}")
            raise RuntimeError("Auth0 logout failed.") from e

    def get_user(self, token: Union[dict, str]) -> Optional[dict]:
        """
        Retrieves the user information using the provided token.

        :param token: Either the token dictionary (containing ``access_token``)
            or the raw access-token string, as passed by ``is_authenticated``
        :return: User information dictionary with the keys ``name``, ``email``,
            ``association`` and ``avatarUrl``, or None if the token is missing,
            the request fails, or Auth0 answers with a non-200 status
        """
        try:
            logging.info("Fetching user information.")

            # Accept both shapes callers use: a raw token string or a token dict.
            if isinstance(token, str):
                access_token = token
            elif token:
                access_token = token.get("access_token")
            else:
                access_token = None

            if not access_token:
                logging.error("Access token is missing.")
                return None

            headers = {"Authorization": f"Bearer {access_token}"}
            user_info_url = f"https://{self.domain}/userinfo"
            response = requests.get(user_info_url, headers=headers, timeout=self.REQUEST_TIMEOUT)

            if response.status_code != 200:
                logging.warning(f"Failed to fetch user information: {response.status_code} {response.text}")
                return None

            logging.debug("User information retrieved successfully.")
            raw_data = response.json()

            # Transform response to match frontend expectations
            return {
                "name": raw_data.get("name", "Unknown User"),
                "email": raw_data.get("email", ""),
                "association": "Example Corp",  # Static or extract from another source
                "avatarUrl": raw_data.get("picture", ""),
            }
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body on requests versions whose
            # JSONDecodeError is not a RequestException.
            logging.error(f"Error while fetching user information: {e}")
            return None

    def get_token(self) -> Optional[str]:
        """
        Retrieves the token from the session.

        :return: The access-token string stored under ``session['token']``,
            or None if it is absent or cannot be read
        """
        try:
            logging.info("Retrieving token from session.")
            return session.get('token', {}).get('access_token')
        except Exception as e:
            logging.error(f"Error while retrieving token: {e}")
            return None

    def handle_callback(self) -> Optional[str]:
        """
        Handles the Auth0 callback and retrieves the access token.

        On success the full token is stored in ``session['token']`` and
        ``session['auth_method']`` is set to ``"auth0"``. On any failure the
        session is cleared.

        :return: Access token string
        :raises RuntimeError: if the code exchange fails or no access token
            is present in the response
        """
        try:
            logging.info("Handling callback from Auth0.")
            token = self.oauth.auth0.authorize_access_token()
            if 'access_token' not in token:
                raise ValueError("Access token not found in the response.")

            session['token'] = token
            session["auth_method"] = "auth0"
            logging.info("Access token retrieved and stored in session.")
            return token['access_token']
        except Exception as e:
            logging.error(f"Error during callback handling: {e}")
            session.clear()  # Clear session in case of failure
            raise RuntimeError("Callback handling failed.") from e

    def is_authenticated(self, token: str) -> bool:
        """
        Checks if the user is authenticated by verifying the token.

        The token counts as authenticated when ``get_user`` returns a profile
        for it.

        :param token: Access token
        :return: True if authenticated, False otherwise
        """
        try:
            logging.info("Checking authentication status.")
            return self.get_user(token) is not None
        except Exception as e:
            logging.error(f"Error while checking authentication: {e}")
            return False

    def _fetch_signing_key(self, kid: str) -> Optional[dict]:
        """
        Looks up the JWKS entry whose ``kid`` matches the given key id.

        The JWKS is downloaded from ``app.config['AUTH0_PUBLIC_KEY_URI']``.

        :param kid: Key id taken from the token header
        :return: Dictionary with the key's ``kty``/``kid``/``use``/``n``/``e``
            fields (those that are present), or None if the JWKS could not be
            fetched or contains no matching key; the reason is logged
        """
        try:
            response = requests.get(
                self.app.config['AUTH0_PUBLIC_KEY_URI'],
                timeout=self.REQUEST_TIMEOUT,
            )
            if response.status_code != 200:
                logging.error(f"Failed to fetch JWKS: {response.status_code}")
                return None

            jwks = response.json()
            for key in jwks['keys']:
                if key.get('kid') == kid:
                    # 'use' is optional in a JWK, so copy only fields that exist.
                    return {
                        field: key[field]
                        for field in ('kty', 'kid', 'use', 'n', 'e')
                        if field in key
                    }
        except Exception as e:
            logging.error(f"Error getting public keys from Auth0: {e}")
            return None

        logging.error("No valid key found in JWKS for token verification.")
        return None

    def is_token_valid(self, token: str) -> bool:
        """
        Validates the provided token by verifying its signature using Auth0's public keys
        and ensuring it's not expired.

        Audience and issuer are checked against ``AUTH0_AUDIENCE`` and
        ``AUTH0_DOMAIN`` from ``app.config``.

        :param token: Encoded JWT
        :return: True if the token verifies and carries an ``exp`` in the
            future, False in every other case (never raises)
        """
        try:
            # Step 1: Decode the token header without verification to extract the 'kid'
            unverified_header = jwt.get_unverified_header(token)
            if unverified_header is None:
                logging.error("Invalid token header.")
                return False

            # Step 2: Get the public key from Auth0's JWKS (JSON Web Key Set) endpoint
            if 'kid' not in unverified_header:
                logging.error("No valid key found in JWKS for token verification.")
                return False

            rsa_key = self._fetch_signing_key(unverified_header['kid'])
            if not rsa_key:
                # The helper has already logged the specific reason.
                return False

            # Step 3: Verify the JWT token using the public key
            payload = jwt.decode(
                token,
                rsa_key,
                algorithms=["RS256"],
                audience=self.app.config["AUTH0_AUDIENCE"],  # Your API audience
                issuer=f"https://{self.app.config['AUTH0_DOMAIN']}/"
            )

            # Step 4: Require an expiration claim that is still in the future.
            expires_at = payload.get('exp')
            if expires_at is None:
                logging.error("Token has no expiration claim.")
                return False
            if expires_at > time.time():
                return True

            logging.error("Token is expired.")
            return False

        except ExpiredSignatureError:
            logging.error("Token is expired.")
            return False
        except JWTClaimsError:
            logging.error("Invalid claims in token.")
            return False
        except JWTError as e:
            logging.error(f"Error decoding token: {e}")
            return False
        except Exception as e:
            logging.error(f"Error during token validation: {e}")
            return False