Last active
December 29, 2015 05:48
-
-
Save jharmn/7624044 to your computer and use it in GitHub Desktop.
First shot at setting up a flask-oauthlib server for OAuth2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
from datetime import datetime, timedelta

from flask import Flask
from flask import session, request
from flask import render_template, redirect, jsonify
from flask_sqlalchemy import SQLAlchemy
from flask_oauthlib.provider import OAuth2Provider
from sqlalchemy.orm import relationship
from werkzeug.security import gen_salt
app = Flask(__name__, template_folder='templates')
oauth = OAuth2Provider(app)
app.debug = True
# A hard-coded key lets anyone who knows it forge the session cookie (and so
# session['id']) to impersonate any user. Prefer the SECRET_KEY environment
# variable; otherwise fall back to a random per-process key, in which case
# sessions do not survive a restart.
app.secret_key = os.environ.get('SECRET_KEY') or os.urandom(24)
app.config.update({
    'SQLALCHEMY_DATABASE_URI': 'sqlite:///db.sqlite',
})

# Send flask_oauthlib's log output to flask.log and to the console.
logger = logging.getLogger('flask_oauthlib')
# Handler levels only filter records that reach them; the logger itself must
# also be at DEBUG, otherwise it defers to the root logger's level (WARNING by
# default) and drops debug records before any handler sees them.
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler('flask.log')
fh.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
logger.addHandler(fh)
logger.addHandler(ch)
app.logger.addHandler(fh)
app.logger.addHandler(ch)

db = SQLAlchemy(app)
class Client(db.Model):
    """An OAuth2 client application registered with this provider."""
    # human readable name, not required
    name = db.Column(db.Unicode(40))
    # human readable description, not required
    description = db.Column(db.Unicode(400))
    # creator of the client, not required
    user_id = db.Column(db.ForeignKey('user.id'))
    # required if you need to support client credential
    user = relationship('User')
    client_id = db.Column(db.Unicode(40), primary_key=True)
    client_secret = db.Column(db.Unicode(55), unique=True, index=True,
                              nullable=False)
    # public or confidential
    is_confidential = db.Column(db.Boolean)
    # space-separated lists, exposed as Python lists by the properties below
    _redirect_uris = db.Column(db.UnicodeText)
    _default_scopes = db.Column(db.UnicodeText)

    @property
    def client_type(self):
        """'confidential' if is_confidential is truthy, else 'public'."""
        if self.is_confidential:
            return 'confidential'
        return 'public'

    @property
    def redirect_uris(self):
        """Registered redirect URIs as a list (empty if none are stored)."""
        if self._redirect_uris:
            return self._redirect_uris.split()
        return []

    @property
    def default_redirect_uri(self):
        """First registered redirect URI, or None if none is registered.

        Indexing the list directly raised IndexError for clients that were
        created without any redirect URI.
        """
        uris = self.redirect_uris
        return uris[0] if uris else None

    @property
    def default_scopes(self):
        """Default scope names as a list (empty if none are stored)."""
        if self._default_scopes:
            return self._default_scopes.split()
        return []
class Grant(db.Model):
    """Authorization code issued to a client on behalf of a user."""
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(
        db.Integer,
        db.ForeignKey('user.id', ondelete='CASCADE'),
    )
    user = relationship('User')
    client_id = db.Column(
        db.Unicode(40),
        db.ForeignKey('client.client_id'),
        nullable=False,
    )
    client = relationship('Client')
    code = db.Column(db.Unicode(255), index=True, nullable=False)
    redirect_uri = db.Column(db.Unicode(255))
    expires = db.Column(db.DateTime)
    # space-separated scope names; read them through `scopes`
    _scopes = db.Column(db.UnicodeText)

    def delete(self):
        """Delete this grant, commit the session, and return the grant."""
        sess = db.session
        sess.delete(self)
        sess.commit()
        return self

    @property
    def scopes(self):
        """Scope names stored on this grant, or an empty list."""
        return self._scopes.split() if self._scopes else []
class Token(db.Model):
    """Access/refresh token pair issued to a client."""
    id = db.Column(db.Integer, primary_key=True)
    client_id = db.Column(
        db.Unicode(40),
        db.ForeignKey('client.client_id'),
        nullable=False,
    )
    client = relationship('Client')
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'))
    user = relationship('User')
    # currently only bearer is supported
    token_type = db.Column(db.Unicode(40))
    access_token = db.Column(db.Unicode(255), unique=True)
    refresh_token = db.Column(db.Unicode(255), unique=True)
    expires = db.Column(db.DateTime)
    # space-separated scope names; read them through `scopes`
    _scopes = db.Column(db.UnicodeText)

    @property
    def scopes(self):
        """Scope names held by this token, or an empty list."""
        raw = self._scopes
        if not raw:
            return []
        return raw.split()
class User(db.Model):
    """A resource owner who logs in and authorizes clients."""
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(40), unique=True)
    # Not unique: home() gives every new account the same default password,
    # so a unique constraint rejected the second account created.
    # NOTE(review): stored and compared in plain text; hash it (e.g. with
    # werkzeug.security.generate_password_hash) before real use.
    password = db.Column(db.String(40))

    def check_password(self, password):
        """Return True if *password* equals this user's stored password.

        The previous version ignored its arguments and always returned True,
        so any password was accepted.
        """
        if self.password is None or password is None:
            return False
        return self.password == password
def current_user():
    """Return the User whose id is stored in the session, or None."""
    if 'id' not in session:
        return None
    return User.query.get(session['id'])
@app.route('/client')
def client():
    """Register a new OAuth2 client owned by the logged-in user.

    Optional query parameters:
      redirect_uri -- may be repeated; each value is stored as an allowed
                      redirect URI (the first is the default).
      scope        -- space-separated default scopes for the client.

    Redirects to '/' when nobody is logged in; otherwise returns the new
    client's id and secret as JSON.
    """
    user = current_user()
    if not user:
        return redirect('/')
    # A client with no registered redirect URI has nowhere to receive an
    # authorization code, so let the caller supply them. With no parameters
    # both columns stay NULL, as before.
    redirect_uris = ' '.join(request.args.getlist('redirect_uri')) or None
    default_scopes = request.args.get('scope') or None
    item = Client(
        client_id=gen_salt(40),
        client_secret=gen_salt(50),
        _redirect_uris=redirect_uris,
        _default_scopes=default_scopes,
        user_id=user.id,
    )
    db.session.add(item)
    db.session.commit()
    return jsonify(
        client_id=item.client_id,
        client_secret=item.client_secret
    )
@oauth.clientgetter
def load_client(client_id):
    """Look up a Client by its client_id; None if it does not exist."""
    return Client.query.filter_by(client_id=client_id).first()
@oauth.grantgetter
def load_grant(client_id, code):
    """Return the Grant matching *client_id* and *code*, or None.

    Replaces a Python-2-only print statement that also raised
    AttributeError when no grant matched.
    """
    grant = Grant.query.filter_by(client_id=client_id, code=code).first()
    if grant is not None:
        app.logger.debug("Grant is %s", grant.id)
    else:
        # don't log the code itself: it is a credential
        app.logger.debug("No grant found for client %s", client_id)
    return grant
@oauth.grantsetter
def save_grant(client_id, code, request, *args, **kwargs):
    """Persist an authorization code for *client_id* and return the Grant.

    ``code['code']`` is stored as the code. The redirect URI and scopes are
    copied from the OAuth ``request``, and the grant belongs to the user
    logged into the Flask session. It expires 100 seconds from now (UTC).
    """
    # decide the expires time yourself
    expires = datetime.utcnow() + timedelta(seconds=100)
    grant = Grant(
        client_id=client_id,
        code=code['code'],
        redirect_uri=request.redirect_uri,
        _scopes=' '.join(request.scopes),
        # current_user() is this module's session helper; the previously
        # called get_current_user() is not defined here (NameError).
        user=current_user(),
        expires=expires,
    )
    db.session.add(grant)
    db.session.commit()
    return grant
@oauth.tokengetter
def load_token(access_token=None, refresh_token=None):
    """Find a Token by access token, or else by refresh token.

    The access token wins when both are given; returns None when neither
    is given or nothing matches.
    """
    if access_token:
        criteria = {'access_token': access_token}
    elif refresh_token:
        criteria = {'refresh_token': refresh_token}
    else:
        return None
    return Token.query.filter_by(**criteria).first()
@oauth.tokensetter
def save_token(token, request, *args, **kwargs):
    """Persist a newly issued token and return the Token row.

    Tokens previously stored for the same client and user are deleted first,
    so each (client, user) pair holds at most one token.
    """
    client_id = request.client.client_id
    # Some grants carry no resource owner; don't crash if request.user is
    # missing or None.
    user = getattr(request, 'user', None)
    user_id = user.id if user is not None else None

    # make sure that every client has only one token connected to a user
    stale = Token.query.filter_by(client_id=client_id, user_id=user_id)
    for old in stale:
        db.session.delete(old)
    # Flush the deletes before inserting, so a reused refresh token cannot
    # collide with the row being replaced on the unique constraint.
    db.session.flush()

    # Read rather than pop: popping removed expires_in from the token dict
    # this function was handed.
    expires_in = token['expires_in']
    expires = datetime.utcnow() + timedelta(seconds=expires_in)
    tok = Token(
        access_token=token['access_token'],
        # not every grant issues a refresh token
        refresh_token=token.get('refresh_token'),
        token_type=token['token_type'],
        _scopes=token.get('scope'),
        expires=expires,
        client_id=client_id,
        user_id=user_id,
    )
    db.session.add(tok)
    db.session.commit()
    return tok
@oauth.usergetter
def get_user(username, password, *args, **kwargs):
    """Return the User matching *username* and *password*, or None.

    An unknown username used to raise AttributeError on ``None``.
    """
    user = User.query.filter_by(username=username).first()
    if user is not None and user.check_password(password):
        return user
    return None
@app.route('/oauth/token', methods=['POST'])
@oauth.token_handler
def access_token():
    """Token endpoint view.

    Returns None; the HTTP response is built by the ``oauth.token_handler``
    decorator.
    """
@app.route('/api/me')
# require_oauth takes each scope as a separate argument; the single string
# 'standard,advanced' demanded one scope with that literal (comma) name.
@oauth.require_oauth('standard', 'advanced')
def me(req):
    """Return the username of the token's owner as JSON.

    *req* is the validated OAuth request passed in by ``require_oauth``.
    """
    return jsonify(username=req.user.username)
@app.route('/', methods=('GET', 'POST'))
def home():
    """Log in by username on POST (creating the user if needed); render home.

    On POST, a user with that username is fetched or created with the
    default password, and their id is stored in the session. A POST without
    a username just redirects back instead of creating a nameless user.
    """
    if request.method == 'POST':
        username = request.form.get('username')
        if not username:
            return redirect('/')
        user = User.query.filter_by(username=username).first()
        default_password = "password"
        if not user:
            user = User(username=username, password=default_password)
            db.session.add(user)
            db.session.commit()
        session['id'] = user.id
        return redirect('/')
    user = current_user()
    return render_template('home.html', user=user)
@app.route('/oauth/authorize', methods=['GET', 'POST'])
@oauth.authorize_handler
def authorize(*args, **kwargs):
    """Show the consent page (GET) or report the user's decision (POST).

    The user must be logged in; otherwise the grant would be saved with no
    owner. Anonymous GETs are sent to the login page, and anonymous POSTs
    are treated as a denial.
    """
    user = current_user()
    if request.method == 'GET':
        if user is None:
            return redirect('/')
        client_id = kwargs.get('client_id')
        client = Client.query.filter_by(client_id=client_id).first()
        kwargs['client'] = client
        return render_template('oauthorize.html', **kwargs)
    if user is None:
        return False
    confirm = request.form.get('confirm', 'no')
    return confirm == 'yes'
if __name__ == '__main__':
    db.create_all()
    # app.debug is on, and the Werkzeug debugger allows running arbitrary
    # code from the browser, so listen on loopback unless an address is
    # explicitly requested via FLASK_HOST (e.g. FLASK_HOST=0.0.0.0).
    app.run(host=os.environ.get('FLASK_HOST', '127.0.0.1'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment