# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------------
# File Name:        fastapi_security_util.py
# Original Author:  Clark Lin
# Email:            clark_lin@outlook.com
#
# Change History
# Version   Date         By          Description
# 0.01      2024-04-24   Clark Lin   Initial version
#
# Main features summary:
#   - Implementation of OAuth2 Authentication
#
# Copyright Information:
#   Copyright © 2024 Oasis
#   Licensed TBD
# ------------------------------------------------------------------------------

import json
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Annotated

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel

# Import your custom logging manager
from common.script.logging_manager import LoggingManager

# Initialize logging manager
curr_module = os.path.basename(__file__)
lm = LoggingManager.get_instance()

# ------------------------------------------------
# Init Global Variables
# ------------------------------------------------

# Credential file (a relative path, resolved against the process working directory)
credential_file = './credential/oauth2.json'

# JWT signing key; populated from the credential file by get_credentials().
# To generate a suitable value run:
#   openssl rand -hex 32
secret_key = ""
algorithm = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# Init crypt context (bcrypt hashing for client secrets)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Init OAuth2 scheme, specifying the token endpoint
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# ------------------------------------------------
# Sample Repository for Client Verification
# ------------------------------------------------
# Mapping of client_id -> client record; populated by get_credentials().
client_db = {}


# ------------------------------------------------
# Get credentials
# ------------------------------------------------
def get_credentials(file: str) -> tuple:
    """Load the JWT secret key and the client repository from a JSON file.

    The file must be a JSON object containing the keys ``oauth2_secret_key``
    and ``client_db``. On success, the module-level ``secret_key`` and
    ``client_db`` globals are updated together. On failure, neither global is
    modified.

    Args:
        file: Path of the JSON credential file.

    Returns:
        ``(secret_key, client_db)`` on success. Returns ``(None, None)`` if the
        file cannot be read or parsed, or a required key is missing. Such
        failures are logged rather than raised.
    """
    global secret_key
    global client_db

    try:
        # Named `fp` so it does not shadow the module-level `credential_file`.
        with open(file, "r", encoding="utf-8") as fp:
            data = json.load(fp)
        loaded_secret_key = data['oauth2_secret_key']
        loaded_client_db = data['client_db']
    except (OSError, ValueError, KeyError, TypeError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # TypeError covers a JSON document whose top level is not an object.
        lm.log(logging.ERROR, curr_module,
               f"Failed to load credentials from file [{file}]: {e!r}")
        return None, None

    secret_key = loaded_secret_key
    client_db = loaded_client_db
    return secret_key, client_db


# ------------------------------------------------
# Model Definition
# ------------------------------------------------
class Token(BaseModel):
    """Response body returned by the token endpoint."""
    access_token: str
    token_type: str


# ------------------------------------------------
# Sub Function - Verify Password
# ------------------------------------------------
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches *hashed_password* under pwd_context."""
    return pwd_context.verify(plain_password, hashed_password)


# ------------------------------------------------
# Sub Function - Use Crypt Context to Get Hash
# ------------------------------------------------
def get_password_hash(password: str) -> str:
    """Hash *password* with pwd_context (bcrypt)."""
    return pwd_context.hash(password)


# ------------------------------------------------
# Sub Function - Get Client from Dict by Client ID
# ------------------------------------------------
def get_client(db: dict, client_id: str) -> dict | None:
    """Return the client record stored under *client_id*, or None if absent."""
    return db.get(client_id)


# ------------------------------------------------
# Sub Function - Authentication Process
# ------------------------------------------------
def authenticate_client(client_db: dict, client_id: str, client_secret: str):
    """Verify a client's id/secret pair against *client_db*.

    Args:
        client_db: Mapping of client_id -> record. Each record must contain
            ``hashed_client_secret``.
        client_id: Client identifier supplied by the caller.
        client_secret: Plain-text secret supplied by the caller.

    Returns:
        The client record on success. Returns ``False`` if the client is
        unknown or the secret does not match.
    """
    client = get_client(client_db, client_id)
    if not client:
        return False
    if not verify_password(client_secret, client["hashed_client_secret"]):
        return False
    return client


# ------------------------------------------------
# Sub Function - Generate Access Token
# ------------------------------------------------
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Encode *data* as a signed JWT with an ``exp`` claim.

    Args:
        data: Claims to encode. The dict is copied and not mutated.
        expires_delta: Token lifetime. When None, defaults to 15 minutes.
            A zero or negative delta is honoured as given.

    Returns:
        The encoded JWT string, signed with the module ``secret_key`` using
        ``algorithm``.

    Raises:
        RuntimeError: If no secret key has been loaded. Signing with an empty
            HMAC key would produce forgeable tokens.
    """
    if not secret_key:
        raise RuntimeError(
            "OAuth2 secret key is not loaded; call get_credentials() first")

    to_encode = data.copy()
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, secret_key, algorithm=algorithm)


# ------------------------------------------------
# Sub Function - Authenticate and Issue Access Token
# ------------------------------------------------
def get_access_token(
        form_data: OAuth2PasswordRequestForm,
        without_credential: bool = False) -> Token:
    """Load credentials, optionally authenticate the caller, and issue a JWT.

    Credentials are re-read from ``credential_file`` on every call, so edits
    to the file are picked up without a restart.

    Args:
        form_data: OAuth2 password form. ``username`` is the client id and
            ``password`` is the client secret. It may be None when
            *without_credential* is True.
        without_credential: When True, skip authentication and issue a token
            for the first client in the credential file.

    Returns:
        A bearer ``Token``.

    Raises:
        HTTPException: 500 if the credentials cannot be loaded or no client is
            configured. 401 if the client id/secret pair is invalid.
    """
    global secret_key
    global client_db

    # Get credentials
    secret_key, client_db = get_credentials(credential_file)
    # An empty secret is rejected too: HS256 with an empty key is forgeable.
    if not secret_key or client_db is None:
        msg = "Failed to get credentials from file [" + credential_file + "]"
        lm.log(logging.ERROR, curr_module, msg)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=msg,
        )

    if without_credential:
        # NOTE(review): this path issues a token without verifying any secret.
        # Confirm it is only reachable from trusted callers.
        if not client_db:
            msg = "No client configured in credential file [" + credential_file + "]"
            lm.log(logging.ERROR, curr_module, msg)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=msg,
            )
        client_id = next(iter(client_db))
        client = client_db[client_id]
    else:
        # Call authentication process
        client_id = form_data.username
        client = authenticate_client(client_db, client_id, form_data.password)
        if not client:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )

    access_token = create_access_token(
        data={
            "sub": client_id,
            "name": client["name"],
            "description": client["description"],
        },
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )

    lm.log(logging.INFO, curr_module, 'get_access_token completed with normal')
    return Token(access_token=access_token, token_type="bearer")


# ------------------------------------------------
# Sub Function - Issue Access Token Without Client Credentials
# ------------------------------------------------
def get_access_token_without_credential() -> Token:
    """Issue a token for the first configured client, skipping authentication."""
    return get_access_token(None, without_credential=True)