Skip to content

Instantly share code, notes, and snippets.

@ericacm
Created July 2, 2011 13:11
Show Gist options
  • Save ericacm/1060126 to your computer and use it in GitHub Desktop.
Save ericacm/1060126 to your computer and use it in GitHub Desktop.
Security OAuth 2 Provider and JPA Domain Classes
/**
 * OAuth 2 provider token services whose access-token storage and lookup go through the JPA
 * [[OauthAccessToken]] entity (via `domainManager`) rather than the superclass's own statements.
 * Operations not overridden here (e.g. `removeAccessToken`) keep the inherited
 * [[JdbcOAuth2ProviderTokenServices]] behavior, presumably against `dataSource`.
 *
 * When initialized, it registers a scheduled job ("pruneTokens") that deletes expired access tokens.
 *
 * @param dataSource handed to the [[JdbcOAuth2ProviderTokenServices]] superclass.
 */
class ProviderTokenServices(dataSource: DataSource) extends JdbcOAuth2ProviderTokenServices(dataSource) with InitializingBean {
  var log: Logger = LoggerFactory.getLogger(this.getClass.getName)

  @Autowired var domainManager: DomainManager = _
  @Autowired var schedulerService: SchedulerService = _

  /**
   * Schedule passed to the prune job. Defaults to `0 0 12 * * *` when the property is not set.
   * NOTE(review): the format is whatever `SchedulerJob` expects (looks cron-like) — confirm.
   */
  @Value("${providerTokenServices.pruneSchedule:0 0 12 * * *}")
  var pruneSchedule: String = _

  /**
   * Persists `token` and the serialized `authentication` as a new [[OauthAccessToken]] row,
   * keyed by the token value, inside a transaction.
   *
   * @throws IllegalStateException if `authentication` carries no user authentication, or if its
   *                               principal is not an [[AuthUserDetails]].
   */
  override def storeAccessToken(token: OAuth2AccessToken, authentication: OAuth2Authentication[_ <: Authentication, _ <: Authentication]): Unit = {
    val refreshToken = Option(token.getRefreshToken).map(_.getValue).orNull

    // The original dereferenced getUserAuthentication unconditionally, so a null there surfaced
    // as an NPE. It is reported explicitly instead (presumably possible for client-only grants).
    Option(authentication.getUserAuthentication).map(_.getPrincipal) match {
      case Some(userDetails: AuthUserDetails) =>
        domainManager.withTransaction { _ =>
          OauthAccessToken.create(token.getValue, token.getExpiration, userDetails.userId, userDetails.ssoGroupId,
            SerializationUtils.serialize(token), SerializationUtils.serialize(authentication), refreshToken)
        }
      case Some(_) =>
        // Also covers a null principal: a type pattern never matches null.
        throw new IllegalStateException("User Authentication Details is not an AuthUserDetails")
      case None =>
        throw new IllegalStateException("OAuth2Authentication has no user Authentication")
    }
  }

  /**
   * Returns the authentication serialized alongside the stored token with the same value.
   * Returns `null` when no row exists, as the overridden Java API expects.
   */
  override def readAuthentication(token: OAuth2AccessToken): OAuth2Authentication[_ <: Authentication, _ <: Authentication] =
    findAccessToken(token.getValue) match {
      case Some(accessToken) =>
        SerializationUtils.deserialize(accessToken.authentication)
      case None =>
        log.info("No OauthAccessToken found for: "+token.getValue)
        null
    }

  /**
   * Returns the deserialized access token stored under `tokenValue`.
   * Returns `null` when no row exists, as the overridden Java API expects.
   */
  override def readAccessToken(tokenValue: String): OAuth2AccessToken =
    findAccessToken(tokenValue) match {
      case Some(accessToken) =>
        SerializationUtils.deserialize(accessToken.token)
      case None =>
        log.info("No OauthAccessToken found for: "+tokenValue)
        null
    }

  /** Delegates to the superclass implementation. */
  override def removeAccessToken(tokenValue: String) = super.removeAccessToken(tokenValue)

  /** Deletes every stored access token belonging to `userId`, inside a transaction. */
  def deleteAllAccessTokensByUserId(userId: Long): Unit = {
    domainManager.withTransaction { _ => OauthAccessToken.deleteAllByUserId(userId) }
  }

  /** Registers the "pruneTokens" job with the scheduler using `pruneSchedule`. */
  override def afterPropertiesSet(): Unit = {
    log.debug("Adding token prune job to scheduler")
    schedulerService.addJob(new SchedulerJob("pruneTokens", pruneTokens _, pruneSchedule))
  }

  /** Deletes access tokens whose expiration is at or before the current time. */
  def pruneTokens(): Unit = {
    log.debug("Pruning expired tokens")
    domainManager.withTransaction { _ => OauthAccessToken.deleteExpired(new Date) }
  }

  /**
   * Looks up a stored access token by value in a transaction.
   * The second `withTransaction` argument is `true` here, presumably the read-only flag — confirm.
   * The cast is kept from the original because `withTransaction`'s result type is not visible here.
   */
  private def findAccessToken(tokenValue: String): Option[OauthAccessToken] =
    domainManager.withTransaction({ _ => OauthAccessToken.lookup(tokenValue) }, true).asInstanceOf[Option[OauthAccessToken]]
}
/**
 * JPA entity for a stored OAuth verification code (table `oauth_code`), keyed by the code itself.
 *
 * The named queries must select from this entity. They previously targeted `OauthAccessToken`,
 * so `deleteExpired` removed access tokens and `findAllByCode` referenced a `code` property that
 * `OauthAccessToken` does not have.
 */
@Entity
@Table(name="oauth_code")
@NamedQueries(Array(
  new NamedQuery(name="OauthVerificationCode.findAllByCode", query="from OauthVerificationCode where code = :code"),
  new NamedQuery(name="OauthVerificationCode.deleteExpired", query="delete from OauthVerificationCode where expiration <= :asOf")
))
class OauthVerificationCode {
  /** The verification code value; primary key. */
  @Id
  var code: String = _

  /** Expiry instant; rows at or before a given instant are removed by `deleteExpired`. */
  // NOTE(review): OauthAccessToken also declares @Index(name="expiration"). Databases with
  // schema-wide index names (e.g. PostgreSQL) may reject the duplicate — confirm.
  @Index(name="expiration")
  var expiration: Date = _

  /** Serialized authentication bytes stored as a LOB. */
  @Lob
  var authentication: Array[Byte] = _
}
/** Persistence helpers for [[OauthVerificationCode]], delegating to an injected `DomainManager`. */
object OauthVerificationCode {
  @Autowired var domainManager: DomainManager = null
  // Presumably fills this singleton's @Autowired fields from the application context — confirm.
  AppConfig.inject(this)

  /** Creates and saves a verification code; `authBytes` is copied so later caller mutation is not persisted. */
  def create(code: String, expiration: Date, authBytes: Array[Byte]): OauthVerificationCode = {
    val newCode = new OauthVerificationCode
    newCode.code = code
    newCode.expiration = expiration
    newCode.authentication = authBytes.clone()
    domainManager.save(newCode)
    newCode
  }

  /** Finds a verification code by its primary key. */
  def lookup(code: String): Option[OauthVerificationCode] =
    domainManager.get[OauthVerificationCode](code)

  /** Deletes the given verification code. */
  def delete(ovcode: OauthVerificationCode): Unit =
    domainManager.delete(ovcode)

  /**
   * Deletes verification codes whose expiration is at or before `expiration`.
   * Uses this entity's own named query. It previously ran "OauthAccessToken.deleteExpired",
   * which deleted access tokens instead.
   */
  def deleteExpired(expiration: Date) =
    domainManager.executeUpdate("OauthVerificationCode.deleteExpired", Map("asOf"->expiration))
}
/**
 * JPA entity for a stored OAuth 2 access token (table `oauth_access_token`), keyed by token value.
 * The named queries select or delete rows by owning user id or by expiration instant.
 */
@Entity
@Table(name="oauth_access_token")
@NamedQueries(Array(
new NamedQuery(name="OauthAccessToken.findAllByUserId", query="from OauthAccessToken where user_id = :userId"),
new NamedQuery(name="OauthAccessToken.deleteAllByUserId", query="delete from OauthAccessToken where user_id = :userId"),
new NamedQuery(name="OauthAccessToken.deleteExpired", query="delete from OauthAccessToken where expiration <= :asOf")
))
class OauthAccessToken {
// Token value; primary key.
@Id
var token_id: String = _
// Expiry instant, used by the deleteExpired query.
// NOTE(review): OauthVerificationCode also declares @Index(name="expiration"). Databases with
// schema-wide index names (e.g. PostgreSQL) may reject the duplicate — confirm.
@Index(name="expiration")
var expiration: Date = _
// Id of the user the token was issued for; indexed for the per-user queries.
@Index(name="user_id")
var user_id: Long = _
var sso_group_id: Long = _
// Serialized OAuth2AccessToken bytes.
@Lob
var token: Array[Byte] = _
// Serialized OAuth2Authentication bytes.
@Lob
var authentication: Array[Byte] = _
// Refresh token value, if any.
var refresh_token: String = _
}
/** Persistence helpers for [[OauthAccessToken]], delegating to an injected `DomainManager`. */
object OauthAccessToken {
  @Autowired var domainManager: DomainManager = null
  // Presumably fills this singleton's @Autowired fields from the application context — confirm.
  AppConfig.inject(this)

  /**
   * Creates and saves an access-token row. The byte arrays are copied so later caller mutation
   * is not persisted.
   *
   * @param refreshToken refresh token value, or `null` when there is none. It is now stored in
   *                     `refresh_token`; previously it was accepted but silently dropped.
   */
  def create(tokenId: String, expiration: Date, userId: Long, ssoGroupId: Long, tokenBytes: Array[Byte], authBytes: Array[Byte], refreshToken: String): OauthAccessToken = {
    val newToken = new OauthAccessToken
    newToken.token_id = tokenId
    newToken.expiration = expiration
    newToken.user_id = userId
    newToken.sso_group_id = ssoGroupId
    newToken.token = tokenBytes.clone()
    newToken.authentication = authBytes.clone()
    newToken.refresh_token = refreshToken
    domainManager.save(newToken)
    newToken
  }

  /** Finds an access token by its primary key (the token value). */
  def lookup(tokenId: String): Option[OauthAccessToken] =
    domainManager.get[OauthAccessToken](tokenId)

  /** Returns all access tokens issued for `userId`. */
  def findAllByUserId(userId: Long): Buffer[OauthAccessToken] =
    domainManager.findAll[OauthAccessToken]("OauthAccessToken.findAllByUserId", Map("userId"->userId))

  /** Deletes all access tokens issued for `userId`. */
  def deleteAllByUserId(userId: Long): Unit = {
    domainManager.executeUpdate("OauthAccessToken.deleteAllByUserId", Map("userId"->userId))
  }

  /** Deletes access tokens whose expiration is at or before `expiration`. */
  def deleteExpired(expiration: Date) =
    domainManager.executeUpdate("OauthAccessToken.deleteExpired", Map("asOf"->expiration))
}
@aschwin
Copy link

aschwin commented May 22, 2014

This is in Scala, but how do I use it properly in Grails? I'm new to Grails and Java, and I need to implement an OAuth provider that can handle the 2-legged flow.

Can you help me with that? Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment