Commit f2a08f6b authored by Clark Lin's avatar Clark Lin
Browse files

Merge branch 'development' of https://linyinghao.cn/gitlab/clark/fastapi-for-apex into development

parents 1204ec96 753bfe8e
# Fastapi for APEX
...@@ -55,4 +55,3 @@ ...@@ -55,4 +55,3 @@
} }
} }
{
"oauth2_secret_key": "<generate random value by openssl rand -hex 32>",
"client_db": {
"client_id: generate random value by openssl rand -hex 32": {
"hashed_client_secret": "<generate by CryptContext(schemes=["bcrypt"], deprecated="auto").hash(password)>",
"name": "fastapi_service_test",
"description": "This client is for fastapi service test"
}
}
}
# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------------
# File Name: fastapi_security_util.py
# Original Author: Clark Lin
# Email: clark_lin@outlook.com
#
# Change History
# Version Date By Description
# 0.01 2024-04-24 Clark Lin Initial version
#
# Main features summary:
# - Implementation of OAuth2 Authentication
#
# Copyright Information:
# Copyright © 2024 Oasis
# Licensed TBD
# ------------------------------------------------------------------------------
from datetime import datetime, timedelta, timezone
from typing import Annotated
from pydantic import BaseModel
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
# Import your custom logging manager
import os
import logging
from common.script.logging_manager import LoggingManager
import json
# Initialize logging manager
curr_module = os.path.basename(__file__)
# lm = LoggingManager()
lm = LoggingManager.get_instance()
# ------------------------------------------------
# Init Global Variables
# ------------------------------------------------
# Credential file
credential_file = './credential/oauth2.json'
# to get a string like this run:
# openssl rand -hex 32
secret_key = ""
algorithm = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# Init crypt context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Init Oauth2 schema, specify token endpoint
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# ------------------------------------------------
# Sample Repository for Client Verification
# ------------------------------------------------
client_db = {}
# ------------------------------------------------
# Get credentials
# ------------------------------------------------
def get_credentials(file: str) -> tuple:
global secret_key
global client_db
try:
# Open and read the JSON file
with open(file, "r") as credential_file:
data = json.load(credential_file)
# Access data from the JSON
secret_key = data['oauth2_secret_key']
client_db = data['client_db']
return secret_key, client_db
except Exception as e:
print(e)
return None, None
# ------------------------------------------------
# Model Definition
# ------------------------------------------------
class Token(BaseModel):
access_token: str
token_type: str
# ------------------------------------------------
# Sub Function - Verify Password
# ------------------------------------------------
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
# ------------------------------------------------
# Sub Function - Use Cript Context to Get Hash
# ------------------------------------------------
def get_password_hash(password):
return pwd_context.hash(password)
# ------------------------------------------------
# Sub Function - Get Client from Dict by Client ID
# ------------------------------------------------
def get_client(db: dict, client_id: str):
if client_id in db:
client_dict = db[client_id]
return client_dict
# ------------------------------------------------
# Sub Function - Authentication Process
# ------------------------------------------------
def authenticate_client(client_db: dict, client_id: str, client_secret: str):
client = get_client(client_db, client_id)
if not client:
return False
if not verify_password(client_secret, client["hashed_client_secret"]):
return False
return client
# ------------------------------------------------
# Sub Function - Generate Access Token
# ------------------------------------------------
def create_access_token(data: dict, expires_delta: timedelta | None = None):
global secret_key
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
return encoded_jwt
# ------------------------------------------------
# Sub Function - Generate Access Token
# ------------------------------------------------
def get_access_token(
form_data: OAuth2PasswordRequestForm,
without_credential: bool = False) -> Token:
global secret_key
global client_db
# Get credentials
secret_key, client_db = get_credentials(credential_file)
if secret_key == None or client_db == None:
lm.log(logging.INFO, curr_module, "Failed to get credentials from file [" + credential_file + "]")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get credentials from file [" + credential_file + "]",
)
# Call authentication process
if (not without_credential):
client = authenticate_client(client_db, form_data.username, form_data.password)
if not client:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
if (not without_credential):
access_token = create_access_token(
data = {
"sub": form_data.username,
"name": client["name"],
"description": client["description"]
},
expires_delta = access_token_expires
)
else:
client_id = list(client_db.keys())[0]
access_token = create_access_token(
data = {
"sub": client_id,
"name": client_db[client_id]["name"],
"description": client_db[client_id]["description"]
},
expires_delta = access_token_expires
)
lm.log(logging.INFO, curr_module, 'get_access_token completed with normal')
return Token(access_token=access_token, token_type="bearer")
# ------------------------------------------------
# Sub Function - Generate Access Token
# ------------------------------------------------
def get_access_token_without_credential() -> Token:
return get_access_token(None, True);
# Import FastAPI Libs
from fastapi import FastAPI, Query
from pydantic import BaseModel
import qcloud_cos_service
import qcloud_ocr_service
import tongyi_genai_service
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import Annotated
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
import fastapi_security_util
from fastapi_security_util import Token
# ------------------------------------------------
# Init Global Variables
# ------------------------------------------------
# Init Oauth2 schema, specify token endpoint
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# ------------------------------------------------
# Define FastAPI
# ------------------------------------------------
app = FastAPI()
# ------------------------------------------------
# Call Token Service
# ------------------------------------------------
@app.post("/token")
def get_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]) -> Token:
return fastapi_security_util.get_access_token(form_data)
# ------------------------------------------------
# Call Token Service without Credential
# ------------------------------------------------
@app.post("/token_without_credential")
def get_access_token_without_credential() -> Token:
return fastapi_security_util.get_access_token_without_credential()
# ------------------------------------------------
# Define CORS Options
# ------------------------------------------------
# origins = [
# "http://localhost",
# "http://localhost:8000",
# "https://yourdomain.com",
# # Add more origins if needed
# ]
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # List of allowed origins
allow_credentials=True, # Whether to allow credentials (cookies, authorization headers, etc.)
allow_methods=["*"], # List of allowed methods (GET, POST, PUT, DELETE, etc.)
allow_headers=["*"], # List of allowed headers
)
# ------------------------------------------------
# Call File Upload Service Process
# ------------------------------------------------
@app.post("/upload/")
async def upload(file_uplaod_req: qcloud_cos_service.file_upload_req_basemodel):
return qcloud_cos_service.upload(file_uplaod_req)
# ------------------------------------------------
# Call File Download URL Service Process
# ------------------------------------------------
@app.post("/get_download_url/")
async def get_download_url(file_download_req: qcloud_cos_service.file_download_req_basemodel):
return qcloud_cos_service.get_download_url(file_download_req)
# ------------------------------------------------
# Call OCR Service Process
# ------------------------------------------------
@app.post("/get_detected_text/")
async def get_detected_text(ocr_req: qcloud_ocr_service.ocr_req_basemodel):
if ocr_req.tool == 'Q':
# Call Qcloud API
return qcloud_ocr_service.get_qcloud_detected_text(ocr_req)
elif ocr_req.tool == 'LP':
# Call Local PaddleOCR API
return qcloud_ocr_service.get_paddleocr_detected_text(ocr_req)
# ------------------------------------------------
# Call Gen AI Service Process
# ------------------------------------------------
@app.post("/send_message_to_gen_ai/tongyi/")
async def send_message(send_message_req: tongyi_genai_service.send_message_req_basemodel):
return tongyi_genai_service.send_message(send_message_req)
# ------------------------------------------------
# Call Gen AI Service Process (send question)
# ------------------------------------------------
@app.post("/send_message_to_gen_ai_async/tongyi/send_question")
async def send_question(
token: Annotated[str, Depends(oauth2_scheme)],
send_message_req: tongyi_genai_service.send_message_req_basemodel):
print('token', token)
return tongyi_genai_service.send_question(token, send_message_req)
# ------------------------------------------------
# Call Gen AI Service Process (get answer)
# ------------------------------------------------
@app.get("/send_message_to_gen_ai_async/tongyi/get_answer")
async def get_answer(
uuid: str = Query(None, title="UUID", description="UUID linked to request payload")):
return StreamingResponse(tongyi_genai_service.get_answer(uuid), media_type="text/event-stream")
...@@ -203,4 +203,3 @@ def get_detected_text(ocr_req: ocr_req_basemodel) -> tuple: ...@@ -203,4 +203,3 @@ def get_detected_text(ocr_req: ocr_req_basemodel) -> tuple:
return ocr_res return ocr_res
...@@ -203,4 +203,3 @@ def get_detected_text(asr_req: asr_req_basemodel) -> tuple: ...@@ -203,4 +203,3 @@ def get_detected_text(asr_req: asr_req_basemodel) -> tuple:
return asr_res return asr_res
...@@ -277,3 +277,4 @@ def get_download_url(file_download_req: file_download_req_basemodel): ...@@ -277,3 +277,4 @@ def get_download_url(file_download_req: file_download_req_basemodel):
) )
return file_download_res return file_download_res
...@@ -132,5 +132,3 @@ def get_detected_text(ocr_req: ocr_req_basemodel) -> tuple: ...@@ -132,5 +132,3 @@ def get_detected_text(ocr_req: ocr_req_basemodel) -> tuple:
return ocr_res return ocr_res
# Import FastAPI Libs # Import FastAPI Libs
from pydantic import BaseModel from pydantic import BaseModel
<<<<<<< HEAD
=======
import fastapi_security_util
from jose import JWTError, jwt
from fastapi import HTTPException, status
>>>>>>> 753bfe8e87baa5e77942adc934d18c6febdf34fa
# Import Dashscope Service Libs # Import Dashscope Service Libs
from http import HTTPStatus from http import HTTPStatus
...@@ -8,6 +14,11 @@ from dashscope.api_entities.dashscope_response import Role ...@@ -8,6 +14,11 @@ from dashscope.api_entities.dashscope_response import Role
import dashscope import dashscope
import json import json
import os import os
<<<<<<< HEAD
=======
import asyncio
import uuid
>>>>>>> 753bfe8e87baa5e77942adc934d18c6febdf34fa
# Import your custom logging manager # Import your custom logging manager
import logging import logging
...@@ -47,6 +58,25 @@ def get_credentials(file: str, gen_ai_id: str) -> str: ...@@ -47,6 +58,25 @@ def get_credentials(file: str, gen_ai_id: str) -> str:
return None return None
# ------------------------------------------------
# Sub Function - Verify Access Token
# ------------------------------------------------
def verify_token(token: str):
secret_key, client_db = fastapi_security_util.get_credentials(fastapi_security_util.credential_file)
try:
payload = jwt.decode(token, secret_key, algorithms=[fastapi_security_util.algorithm])
lm.log(logging.ERROR, curr_module, 'payload: ', str(payload))
username: str = payload.get("sub")
if username is None:
return False
return True
except JWTError:
lm.log(logging.ERROR, curr_module, 'JWTError: ', str(JWTError))
return False
# ------------------------------------------------ # ------------------------------------------------
# Define GenAI Message Service Request and Response # Define GenAI Message Service Request and Response
# ------------------------------------------------ # ------------------------------------------------
...@@ -62,6 +92,10 @@ class send_message_res_basemodel(BaseModel): ...@@ -62,6 +92,10 @@ class send_message_res_basemodel(BaseModel):
result: int result: int
result_message: str result_message: str
message_res: str message_res: str
<<<<<<< HEAD
=======
uuid: str
>>>>>>> 753bfe8e87baa5e77942adc934d18c6febdf34fa
# ------------------------------------------------ # ------------------------------------------------
...@@ -89,7 +123,11 @@ def send_message(messages_req: send_message_req_basemodel) -> send_message_res_b ...@@ -89,7 +123,11 @@ def send_message(messages_req: send_message_req_basemodel) -> send_message_res_b
dashscope.api_key = api_key dashscope.api_key = api_key
# Set system prompt # Set system prompt
<<<<<<< HEAD
messages = [{'role': Role.SYSTEM, 'content': 'You are a helpful assistant.'}] messages = [{'role': Role.SYSTEM, 'content': 'You are a helpful assistant.'}]
=======
messages = [{'role': Role.SYSTEM, 'content': 'You are a helpful assistant. If you are not sure the answer, please don''t reply with wrong answer.'}]
>>>>>>> 753bfe8e87baa5e77942adc934d18c6febdf34fa
# Loop input message list # Loop input message list
for message in messages_req.messages: for message in messages_req.messages:
...@@ -109,11 +147,34 @@ def send_message(messages_req: send_message_req_basemodel) -> send_message_res_b ...@@ -109,11 +147,34 @@ def send_message(messages_req: send_message_req_basemodel) -> send_message_res_b
model = Generation.Models.qwen_plus model = Generation.Models.qwen_plus
# Use SDK to get answer # Use SDK to get answer
<<<<<<< HEAD
response = Generation.call( response = Generation.call(
model, # Generation.Models.qwen_turbo, model, # Generation.Models.qwen_turbo,
messages=messages, messages=messages,
result_format='message', # set the result to be 'message' format. result_format='message', # set the result to be 'message' format.
) )
=======
responses = Generation.call(
model, # Generation.Models.qwen_turbo,
messages=messages,
result_format='message', # set the result to be 'message' format.
# Use stream style
stream=True,
incremental_output=True
)
for debug_response in responses:
lm.log(logging.INFO, curr_module, 'Normal response', str(debug_response.output))
response = debug_response
# if response.status_code == HTTPStatus.OK:
# print(response.output.choices[0]['message']['content'],end='')
# else:
# print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
# response.request_id, response.status_code,
# response.code, response.message
# ))
>>>>>>> 753bfe8e87baa5e77942adc934d18c6febdf34fa
if response.status_code == HTTPStatus.OK: if response.status_code == HTTPStatus.OK:
lm.log(logging.INFO, curr_module, 'Normal response', str(response)) lm.log(logging.INFO, curr_module, 'Normal response', str(response))
...@@ -151,3 +212,150 @@ def send_message(messages_req: send_message_req_basemodel) -> send_message_res_b ...@@ -151,3 +212,150 @@ def send_message(messages_req: send_message_req_basemodel) -> send_message_res_b
return send_message_res return send_message_res
<<<<<<< HEAD
=======
# ------------------------------------------------
# Function to store question (async)
# ------------------------------------------------
def send_question(token: str, send_message_req: send_message_req_basemodel):
if not verify_token(token = token):
raise HTTPException(
status_code = status.HTTP_401_UNAUTHORIZED,
detail = "Authentication Failed",
headers={"WWW-Authenticate": "Bearer"},
)
result = c_ret_code_success
result_message = ""
# Store questions and uuid at somewhere
# Write to file named by UUID
uuid_value = str(uuid.uuid4())
uuid_file = os.path.join('/tmp/genai-stage', uuid_value)
with open(uuid_file, 'w') as file:
file.write(json.dumps(send_message_req.dict()))
send_message_res = send_message_res_basemodel(
result = result,
result_message = result_message,
message_res = '',
uuid = uuid_value
)
return send_message_res
# ------------------------------------------------
# Function to call dashscope SDK (async)
# ------------------------------------------------
async def get_answer(uuid: str):
result = c_ret_code_success
result_message = ""
# Try to read uuid file
uuid_file = os.path.join('/tmp/genai-stage', uuid)
with open(uuid_file, 'r') as file:
content = file.read()
lm.log(logging.INFO, curr_module, 'content: ', content)
file.close()
# Init basemodel
messages = []
messages_req = send_message_req_basemodel(
model = "",
messages = messages
)
# Try to deserialize
try:
data = json.loads(content)
messages_req = send_message_req_basemodel(**data)
lm.log(logging.INFO, curr_module, 'Deserilized: ', str(messages_req))
except json.JSONDecodeError as e:
lm.log(logging.INFO, curr_module, 'Failed to decode JSON: ', str(e))
except Exception as e:
lm.log(logging.INFO, curr_module, 'Failed to create model instance: ', str(e))
try:
# Get API Key
api_key = get_credentials(credential_file, gen_ai_id)
if (api_key == None):
result = c_ret_code_error
# Set return tuple for normal result
send_message_res = send_message_res_basemodel(
result = result,
result_message = 'Failed to get API key',
message_res = ''
)
# return send_message_res
yield send_message_res
return
else:
dashscope.api_key = api_key
# Set system prompt
messages = [{'role': Role.SYSTEM, 'content': 'You are a helpful assistant. If you are not sure the answer, please don''t reply with wrong answer.'}]
# Loop input message list
for message in messages_req.messages:
# Set user and assistant prompt
messages.append({'role': message.role, 'content': message.text})
lm.log(logging.INFO, curr_module, 'message: ', str(messages))
# Get Model
# model = Models()
match messages_req.model:
case 'qwen-max':
model = Generation.Models.qwen_max
case 'qwen-turbo':
model = Generation.Models.qwen_turbo
case 'qwen-plus':
model = Generation.Models.qwen_plus
# Use SDK to get answer
responses = Generation.call(
model, # Generation.Models.qwen_turbo,
messages=messages,
result_format='message', # set the result to be 'message' format.
# Use stream style
stream=True,
incremental_output=True
)
for debug_response in responses:
lm.log(logging.INFO, curr_module, 'Normal response', str(debug_response.output))
# yield f"data: {debug_response.output.choices[0]['message']['content']}\n\n"
raw_data = debug_response.output.choices[0]['message']['content']
# 针对Markdown格式进行优化
raw_data = raw_data.replace('\n\n', '\n')
raw_data = raw_data.replace('\n', '\n\n')
if raw_data.startswith('\n'):
raw_data = raw_data[1:]
lines = raw_data.splitlines()
for line in lines:
yield f"data: {line}\n\n"
# response = debug_response
await asyncio.sleep(0.1)
# if response.status_code == HTTPStatus.OK:
# print(response.output.choices[0]['message']['content'],end='')
# else:
# print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
# response.request_id, response.status_code,
# response.code, response.message
# ))
# yield 'data: null\n\n'
yield 'event: end\ndata: The stream is about to end\n\n'
os.remove(uuid_file)
except Exception as e:
result = c_ret_code_error
result_message = str(e)
lm.log(logging.ERROR, 'Exception', result_message)
yield f"data: Exception - {result_message}\n\n"
>>>>>>> 753bfe8e87baa5e77942adc934d18c6febdf34fa
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment