[Auth][enhance] Select Grammar Runtime Columns Auth Mechanism
cfmcgrady committed Mar 14, 2019
1 parent 33b5c94 commit 6fc689a
Showing 5 changed files with 53 additions and 115 deletions.
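
In short, this commit moves column-level authorization of `select` statements from parse time to run time: when `ENABLE_RUNTIME_SELECT_AUTH` is set, the new `SelectAdaptor.runtimeTableAuth` collects the relation names referenced by the unanalyzed plan, walks the analyzed plan to group every output attribute under its table qualifier, resolves each table's type through the catalog, and hands the resulting `MLSQLTable` list to the registered `TableAuth`. A self-contained sketch of the grouping step, using simplified stand-ins for the project's types (the real `MLSQLTable` lives in `streaming.dsl.auth` and also carries an operate type and table type):

```scala
import scala.collection.mutable

// Simplified stand-in for streaming.dsl.auth.MLSQLTable (illustration only).
case class MLSQLTable(db: Option[String], table: Option[String], columns: Option[Set[String]])

// `attrs` models the analyzed plan's output attributes as (qualifier, column) pairs;
// `referenced` models the relation names collected from the unanalyzed plan.
def groupByTable(attrs: Seq[(String, String)], referenced: Set[String]): List[MLSQLTable] = {
  val tableAndCols = mutable.HashMap.empty[String, mutable.HashSet[String]]
  for ((qualifier, name) <- attrs if referenced.contains(qualifier)) {
    tableAndCols.getOrElseUpdate(qualifier, mutable.HashSet.empty[String]) += name
  }
  tableAndCols.toList.map { case (table, cols) =>
    table.split("\\.", 2) match {
      case Array(db, t) => MLSQLTable(Some(db), Some(t), Some(cols.toSet))
      case Array(t)     => MLSQLTable(None, Some(t), Some(cols.toSet))
    }
  }
}

// groupByTable(Seq("db.t" -> "a", "db.t" -> "b"), Set("db.t"))
//   == List(MLSQLTable(Some("db"), Some("t"), Some(Set("a", "b"))))
```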
@@ -246,7 +246,7 @@ class AuthProcessListener(val listener: ScriptSQLExecListener) extends BaseParse
case "load" =>
new LoadAuth(this).auth(ctx)

case "select" if ENABLE_RUNTIME_SELECT_AUTH =>
case "select" if !ENABLE_RUNTIME_SELECT_AUTH =>
new SelectAuth(this).auth(ctx)

case "save" =>
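The only change in this file is the flipped guard: parse-time `SelectAuth` now runs for `select` only when runtime auth is disabled, so exactly one of the two mechanisms fires. A minimal sketch of the resulting dispatch (the `select` arms follow the diff; treating the runtime case as a simple fall-through to `SelectAdaptor.runtimeTableAuth` is an assumption):

```scala
val ENABLE_RUNTIME_SELECT_AUTH = true  // assumed to come from configuration

def authPath(statement: String): String = statement match {
  case "select" if !ENABLE_RUNTIME_SELECT_AUTH => "SelectAuth (parse time)"
  case "select"                                => "left to SelectAdaptor.runtimeTableAuth (run time)"
  case other                                   => s"$other handled by its own *Auth class"
}
```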
80 changes: 44 additions & 36 deletions streamingpro-mlsql/src/main/java/streaming/dsl/SelectAdaptor.scala
@@ -21,8 +21,9 @@ package streaming.dsl
import scala.collection.mutable

import org.antlr.v4.runtime.misc.Interval
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import streaming.dsl.auth.{MLSQLTable, OperateType, TableType}
import streaming.dsl.parser.DSLSQLLexer
import streaming.dsl.parser.DSLSQLParser.SqlContext
import streaming.dsl.template.TemplateMerge
@@ -54,45 +55,52 @@ class SelectAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAda

val df = scriptSQLExecListener.sparkSession.sql(sql)

// ScriptSQLExec._tableAuth.auth(authListener.tables().tables.toList)
// scriptSQLExecListener.getTableAuth.foreach(tableAuth => {
var r = Array.empty[String]
df.queryExecution.logical.map {
case sp: UnresolvedRelation =>
r +:= sp.tableIdentifier.unquotedString.toLowerCase
case h: HiveTableRelation =>
println(s"xxxxxxxx: ${h.tableMeta.identifier}")
case _ =>
}
println(r.mkString(","))
var tableAndCols = mutable.HashMap.empty[String, mutable.HashSet[String]]
df.queryExecution.analyzed.map(lp => {
lp.output.map(o => {
val qualifier = o.qualifier.mkString(".")
if (r.contains(o.qualifier.mkString("."))) {
val value = tableAndCols.getOrElse(qualifier, mutable.HashSet.empty[String])
value.add(o.name)
tableAndCols.update(qualifier, value)
}
})
// })
tableAndCols.foreach(println)
runtimeTableAuth(df)

df.createOrReplaceTempView(tableName)
scriptSQLExecListener.setLastSelectTable(tableName)
}

println("===================")
r.distinct.map(t => {
val tt = scriptSQLExecListener.sparkSession.catalog.getTable(t)
println(tt)
println(tt.tableType)
println(tt.isTemporary)
def runtimeTableAuth(df: DataFrame): Unit = {
// enable runtime select auth
if (ENABLE_RUNTIME_SELECT_AUTH) {
scriptSQLExecListener.getTableAuth.foreach(tableAuth => {
var r = Array.empty[String]
df.queryExecution.logical.map {
case sp: UnresolvedRelation =>
r +:= sp.tableIdentifier.unquotedString.toLowerCase
case _ =>
}
var tableAndCols = mutable.HashMap.empty[String, mutable.HashSet[String]]
df.queryExecution.analyzed.map(lp => {
lp.output.map(o => {
val qualifier = o.qualifier.mkString(".")
if (r.contains(o.qualifier.mkString("."))) {
val value = tableAndCols.getOrElse(qualifier, mutable.HashSet.empty[String])
value.add(o.name)
tableAndCols.update(qualifier, value)
}
})
})
println("===================")

var mlsqlTables = List.empty[MLSQLTable]

tableAndCols.foreach {
case (table, cols) =>
val stable = scriptSQLExecListener.sparkSession.catalog.getTable(table)
val db = Option(stable.database)
val tableStr = Option(stable.name)
val ttpe = if (stable.isTemporary) {
TableType.TEMP
} else {
TableType.HIVE
}
mlsqlTables ::= MLSQLTable(db, tableStr, Option(cols.toSet), OperateType.SELECT, None, ttpe)
}

tableAuth.auth(mlsqlTables)
})
}

})

// tableAuth.auth(authListener.tables().tables.toList)
df.createOrReplaceTempView(tableName)
scriptSQLExecListener.setLastSelectTable(tableName)
}
}
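
The new `runtimeTableAuth` hinges on two plan walks: the unanalyzed plan yields the relation names the query mentions, and the analyzed plan yields fully resolved output attributes whose qualifiers are matched back against those names — which is what makes column-level auditing of `select *` possible at run time. A minimal local demo of the same two walks (hypothetical view name `t1`; assumes Spark 2.4, where `Attribute.qualifier` is a `Seq[String]`):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation

val spark = SparkSession.builder().master("local[*]").appName("plan-walk-demo").getOrCreate()
import spark.implicits._

Seq((1, "a")).toDF("id", "name").createOrReplaceTempView("t1")
val df = spark.sql("select * from t1")

// 1) Relations named in the raw, unanalyzed plan:
val referenced = df.queryExecution.logical.collect {
  case r: UnresolvedRelation => r.tableIdentifier.unquotedString.toLowerCase
}

// 2) Output attributes of every analyzed-plan node, keyed by qualifier;
//    the analyzed plan has expanded `*`, so both id and name appear.
df.queryExecution.analyzed.foreach { lp =>
  lp.output.foreach { o =>
    val qualifier = o.qualifier.mkString(".")
    if (referenced.contains(qualifier)) println(s"$qualifier -> ${o.name}")
  }
}
spark.stop()
```

Note that expressions aliased in the select list come back without a table qualifier, so the `referenced.contains(qualifier)` check naturally attributes only columns that survive resolution under a referenced relation's name.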
@@ -18,16 +18,12 @@

package streaming.dsl.auth

import scala.collection.mutable

import org.antlr.v4.runtime.misc.Interval
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.execution.MLSQLAuthParser
import streaming.dsl.parser.DSLSQLLexer
import streaming.dsl.parser.DSLSQLParser._
import streaming.dsl.template.TemplateMerge
import streaming.dsl.{AuthProcessListener, DslTool, ScriptSQLExecListener}
import streaming.dsl.{AuthProcessListener, DslTool}


/**
@@ -57,59 +53,26 @@ class SelectAuth(authProcessListener: AuthProcessListener) extends MLSQLAuth wit

val tableRefs = MLSQLAuthParser.filterTables(sql, authProcessListener.listener.sparkSession)

var tables = Array.empty[MLSQLTable]
tableRefs.foreach { f =>
val tables = tableRefs.foreach { f =>
f.database match {
case Some(db) =>
val exists = authProcessListener.withDBs.filter(m => f.table == m.table.get && db == m.db.get).size > 0
if (!exists) {
tables +:= MLSQLTable(Some(db), Some(f.table) ,OperateType.SELECT , None, TableType.HIVE)
authProcessListener.addTable(MLSQLTable(Some(db), Some(f.table) ,OperateType.SELECT , None, TableType.HIVE))
}
case None =>
val exists = authProcessListener.withoutDBs.filter(m => f.table == m.table.get).size > 0
if (!exists) {
tables +:= MLSQLTable(Some("default"), Some(f.table) ,OperateType.SELECT , None, TableType.HIVE)
authProcessListener.addTable(MLSQLTable(Some("default"), Some(f.table) ,OperateType.SELECT , None, TableType.HIVE))
}
}
}

val exists = authProcessListener.withoutDBs.filter(m => tableName == m.table.get).size > 0
if (!exists) {
tables +:= MLSQLTable(None, Some(tableName) ,OperateType.SELECT , None, TableType.TEMP)
}

val df = authProcessListener.listener.sparkSession.sql(sql)
println("----final---")
var r = Array.empty[String]
df.queryExecution.logical.map {
case sp: UnresolvedRelation =>
r +:= sp.tableIdentifier.unquotedString.toLowerCase
case h: HiveTableRelation =>
println(s"xxxxxxxx: ${h.tableMeta.identifier}")

case _ =>
authProcessListener.addTable(MLSQLTable(None, Some(tableName) ,OperateType.SELECT , None, TableType.TEMP))
}
println(r.mkString(","))
var tableAndCols = mutable.HashMap.empty[String, mutable.HashSet[String]]
df.queryExecution.analyzed.map(lp => {
lp.output.map(o => {
val qualifier = o.qualifier.mkString(".")
if (r.contains(o.qualifier.mkString("."))) {
val value = tableAndCols.getOrElse(qualifier, mutable.HashSet.empty[String])
value.add(o.name)
tableAndCols.update(qualifier, value)
}
})
})
tableAndCols.foreach(println)
tables.foreach(table => {
tableAndCols.get(table.tableIdentifier)
.foreach(cols => {
authProcessListener.addTable(table.copy(columns = Option(cols.toSet)))
})
})

df.explain()

TableAuthResult.empty()

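After this change, parse-time `SelectAuth` only registers table-level entries (database plus table, no columns) discovered by `MLSQLAuthParser.filterTables`; the analyzed-plan column probing that previously lived here moved into `SelectAdaptor`. A self-contained sketch of the dedup-and-register loop that remains, with simplified stand-ins for the listener state (assumed shapes; the real `AuthProcessListener` tracks `withDBs` and `withoutDBs` separately):

```scala
case class TableRef(database: Option[String], table: String)
case class MLSQLTable(db: Option[String], table: Option[String], tpe: String)

def register(refs: Seq[TableRef], tempTable: String): Seq[MLSQLTable] = {
  var acc = Vector.empty[MLSQLTable]
  def exists(db: Option[String], t: String): Boolean =
    acc.exists(m => m.table.contains(t) && m.db == db)

  refs.foreach { f =>
    val db = Some(f.database.getOrElse("default"))  // missing db defaults to "default"
    if (!exists(db, f.table)) acc :+= MLSQLTable(db, Some(f.table), "HIVE")
  }
  // The select target itself is registered as a TEMP table:
  if (!exists(None, tempTable)) acc :+= MLSQLTable(None, Some(tempTable), "TEMP")
  acc
}
```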
@@ -36,7 +36,6 @@ object MLSQLAuthParser
lazy val parserInstance = new WowSparkSqlParser(session.sqlContext.conf)
parser.compareAndSet(null, parserInstance)
parser.get().tables(sql, t)
parser.get().columns(sql)
t
}
}
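
With the `columns(sql)` probe removed, `filterTables` is again a pure table-extraction helper: run the statement through `WowSparkSqlParser.tables` and return the collected identifiers. Hypothetical usage, mirroring the call site in `SelectAuth` (assumes an active `SparkSession` named `spark`):

```scala
// Returns the TableIdentifiers referenced by the statement; column
// information now comes from the analyzed plan at run time instead.
val refs = MLSQLAuthParser.filterTables("select a from db.t", spark)
refs.foreach(r => println(r.unquotedString))  // prints: db.t
```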
@@ -18,17 +18,15 @@

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.{TableIdentifier}
import org.apache.spark.sql.catalyst.parser._
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.internal.{SQLConf, VariableSubstitution}
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.StructType

import scala.collection.mutable.ArrayBuffer

/**
* Concrete parser for Spark SQL statements.
@@ -56,20 +54,6 @@ class WowSparkSqlParser(conf: SQLConf) extends AbstractSqlParser {
res
}

def columns(sqlText: String): Unit = {
println(s"sql: [ $sqlText ]")
val res = parse(sqlText) { parser =>
astBuilder.visitExpression(parser.expression()) match {
case ua: UnresolvedAttribute =>
// println("xxxxxxx")
// println(ua.name)
// ua
case _ =>
// throw new RuntimeException
}
}
}

}

/**
@@ -81,22 +65,6 @@ class WowSparkSqlAstBuilder(conf: SQLConf) extends SparkSqlAstBuilder(conf) {
TableHolder.tables.get() += ti
ti
}

override def visitExpression(ctx: ExpressionContext): AnyRef = {
val res = super.visitExpression(ctx)
println("yyyyyy")
println(res.getClass)
res match {
// case ua: UnresolvedAttribute =>
// println(ua.name)
// println(ua)
case e: ResolvedStar =>
println(e.expressions)
case _ =>
}
res
}

}

object TableHolder {
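The parser side is now lean as well: `visitTableIdentifier` records every identifier into a thread-local buffer as the statement parses, and the exploratory `columns`/`visitExpression` debugging code is deleted. The remainder of `TableHolder` is collapsed in this diff; a self-contained sketch of the pattern it implements (assumed shape, not the project's exact code):

```scala
import scala.collection.mutable.ArrayBuffer

// Thread-local accumulator: the visitTableIdentifier override appends to it
// as a side effect during parsing, and the caller drains it afterwards.
object TableHolder {
  val tables: ThreadLocal[ArrayBuffer[String]] =
    new ThreadLocal[ArrayBuffer[String]] {
      override def initialValue(): ArrayBuffer[String] = ArrayBuffer.empty[String]
    }
}

def recordTable(name: String): String = {
  TableHolder.tables.get() += name  // called once per identifier seen
  name
}

def drainTables(): Seq[String] = {
  val buf = TableHolder.tables.get()
  val out = buf.toList
  buf.clear()                       // reset so the next parse starts clean
  out
}
```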
