Archive

Posts Tagged ‘FastAPI’

FastAPI 使用JWT认证的中间件

May 25th, 2020 No comments

FastAPI 使用JWT认证的中间件
FastAPI的中间件还是太少,JWT认证需要单独开发;starlette本身提供了认证相关的实现,只需要自定义一个AuthenticationBackend即可。本次我们使用中间件的方式解析JWT令牌,获取payload里面的用户信息

私有定义的payload内容格式如下

{
"usid": "SkDQBhEjUfygRSeEBech", //UUID Short
"uname": "test user name",  //Username
"mid":"700010001" // Member ID
}

调用代码

# Create the application and install the authentication middleware, using the
# JWT backend with the secret key that tokens are verified against.
app = FastAPI()
app.add_middleware(
    AuthenticationMiddleware,
    backend=JWTAuthenticationBackend(secret_key="YOUR_SECRET_KEY"),
)

完整的代码

import jwt
 
from starlette.authentication import (
    AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
    UnauthenticatedUser )
 
 
class JWTUser(BaseUser):
    """User object populated from the claims of a decoded JWT."""

    def __init__(
        self,
        user_id_short: str,
        member_number: str,
        user_name: str,
        token: str,
        payload: dict,
    ) -> None:
        """Keep the identity values together with the raw token and its payload.

        :param user_id_short: short user id taken from the token payload
        :param member_number: member number taken from the token payload
        :param user_name: user name taken from the token payload
        :param token: the encoded token string
        :param payload: the decoded token payload
        """
        self.token = token
        self.payload = payload
        self.user_id_short = user_id_short
        self.member_number = member_number
        self.user_name = user_name

    @property
    def is_authenticated(self) -> bool:
        """Always ``True`` for this user type."""
        return True

    @property
    def display_name(self) -> str:
        """Return the user name unchanged."""
        # Open question from the author: should the name be masked with '*'?
        return self.user_name

    @property
    def display_user_id_short(self) -> str:
        """Return the short user id unchanged."""
        return self.user_id_short

    @property
    def display_member_number(self) -> str:
        """Return the member number unchanged."""
        # Open question from the author: should the number be masked with '*'?
        return self.member_number
 
 
class JWTAuthenticationBackend(AuthenticationBackend):
    """Authentication backend that reads a JWT from the ``Authorization`` header.

    The header is expected to look like ``<prefix> <token>``. The token is
    decoded with the configured secret and algorithm, and the configured
    payload fields are used to build a :class:`JWTUser`.
    """

    def __init__(self, secret_key: str, algorithm: str = 'HS256', prefix: str = 'JWT',
                 user_name_field: str = 'uname', user_id_field: str = 'usid',
                 member_number_field: str = 'mid'):
        """Store the decoding settings and the payload field names.

        :param secret_key: key passed to ``jwt.decode`` to verify the token
        :param algorithm: name of the signing algorithm to accept
        :param prefix: expected scheme in the ``Authorization`` header
        :param user_name_field: payload key holding the user name
        :param user_id_field: payload key holding the short user id
        :param member_number_field: payload key holding the member number
        """
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.prefix = prefix
        self.user_name_field = user_name_field
        self.user_id_field = user_id_field
        self.member_number_field = member_number_field

    @classmethod
    def get_token_from_header(cls, authorization: str, prefix: str):
        """
        Parse the Authorization header and return only the token.

        :param authorization: raw header value, expected as ``<scheme> <token>``
        :param prefix: scheme the header must use (compared case-insensitively)
        :return: the token part of the header
        :raises AuthenticationError: if the header does not split into exactly
            two parts, or if the scheme does not match ``prefix``
        """
        try:
            scheme, token = authorization.split()
        except ValueError:
            raise AuthenticationError('Could not separate Authorization scheme and token')
        if scheme.lower() != prefix.lower():
            raise AuthenticationError(f'Authorization scheme {scheme} is not supported')
        return token

    async def authenticate(self, request):
        """Authenticate a request from its ``Authorization`` header.

        :param request: connection object exposing ``headers``
        :return: ``None`` when the header is absent, otherwise a tuple of
            ``AuthCredentials(["jwt_authenticated"])`` and a :class:`JWTUser`
        :raises AuthenticationError: if the header is malformed, the token is
            rejected by ``jwt.decode``, or a configured payload field is missing
        """
        if "Authorization" not in request.headers:
            return None

        auth = request.headers["Authorization"]
        token = self.get_token_from_header(authorization=auth, prefix=self.prefix)

        # PyJWT documents ``algorithms`` as a list of allowed names, so wrap a
        # single configured name instead of passing the bare string.
        if isinstance(self.algorithm, str):
            algorithms = [self.algorithm]
        else:
            algorithms = list(self.algorithm)

        try:
            payload = jwt.decode(token, key=self.secret_key, algorithms=algorithms)
        except jwt.InvalidTokenError as e:
            raise AuthenticationError(str(e)) from e

        # A correctly signed token can still lack one of the expected claims;
        # report that as an authentication failure rather than a KeyError.
        try:
            user = JWTUser(
                user_id_short=payload[self.user_id_field],
                member_number=payload[self.member_number_field],
                user_name=payload[self.user_name_field],
                token=token,
                payload=payload,
            )
        except KeyError as e:
            raise AuthenticationError(f'Token payload is missing required claim {e}') from e

        return AuthCredentials(["jwt_authenticated"]), user
 
 
class JWTWebSocketAuthenticationBackend(AuthenticationBackend):
    """Authentication backend that reads a JWT from a query-string parameter.

    The token is taken from ``request.query_params`` under the configured
    parameter name, decoded with the configured secret and algorithm, and the
    configured payload fields are used to build a :class:`JWTUser`.
    """

    def __init__(self, secret_key: str, algorithm: str = 'HS256', query_param_name: str = 'jwt',
                 user_name_field: str = 'uname', user_id_field: str = 'usid',
                 member_number_field: str = 'mid'):
        """Store the decoding settings and the payload field names.

        :param secret_key: key passed to ``jwt.decode`` to verify the token
        :param algorithm: name of the signing algorithm to accept
        :param query_param_name: query parameter that carries the token
        :param user_name_field: payload key holding the user name
        :param user_id_field: payload key holding the short user id
        :param member_number_field: payload key holding the member number
        """
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.query_param_name = query_param_name
        self.user_name_field = user_name_field
        self.user_id_field = user_id_field
        self.member_number_field = member_number_field

    async def authenticate(self, request):
        """Authenticate a connection from its token query parameter.

        :param request: connection object exposing ``query_params``
        :return: ``(AuthCredentials(), UnauthenticatedUser())`` when the
            parameter is absent, otherwise a tuple of
            ``AuthCredentials(["jwt_authenticated"])`` and a :class:`JWTUser`
        :raises AuthenticationError: if the token is rejected by
            ``jwt.decode`` or a configured payload field is missing
        """
        if self.query_param_name not in request.query_params:
            return AuthCredentials(), UnauthenticatedUser()

        token = request.query_params[self.query_param_name]

        # PyJWT documents ``algorithms`` as a list of allowed names, so wrap a
        # single configured name instead of passing the bare string.
        if isinstance(self.algorithm, str):
            algorithms = [self.algorithm]
        else:
            algorithms = list(self.algorithm)

        try:
            payload = jwt.decode(token, key=self.secret_key, algorithms=algorithms)
        except jwt.InvalidTokenError as e:
            raise AuthenticationError(str(e)) from e

        # A correctly signed token can still lack one of the expected claims;
        # report that as an authentication failure rather than a KeyError.
        try:
            user = JWTUser(
                user_id_short=payload[self.user_id_field],
                member_number=payload[self.member_number_field],
                user_name=payload[self.user_name_field],
                token=token,
                payload=payload,
            )
        except KeyError as e:
            raise AuthenticationError(f'Token payload is missing required claim {e}') from e

        return AuthCredentials(["jwt_authenticated"]), user
Categories: 语言编程 Tags: , ,