Skip to content

Commit

Permalink
Merge pull request byzer-org#236 from allwefantasy/mlsql
Browse files Browse the repository at this point in the history
添加加载图片类
  • Loading branch information
allwefantasy authored May 29, 2018
2 parents 62f6c63 + 3ecbff6 commit a3cc66b
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 2 deletions.
12 changes: 11 additions & 1 deletion docs/mlsql-data-processing-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,17 @@ OpenCVImage模块主要是对图像做处理。第一版仅仅能够做resize动
select crawler_request_image("https://tpc.googlesyndication.com/simgad/10310202961328364833") as imagePath
as images;

--
-- 或者你可能因为训练的原因,需要加载一个图片数据集。该表只有一个字段 image,但 image 是一个复杂字段,其中 origin 带有路径信息。
load image.`/Users/allwefantasy/CSDNWorkSpace/streamingpro/images`
options
-- 是不是需要递归查找图片
recursive="false"
-- 是不是丢弃掉解析失败的图片
and dropImageFailures="false"
-- 采样比例
and sampleRatio="1.0"
as images;

train images as OpenCVImage.`/tmp/word2vecinplace`
where inputCol="imagePath"
-- 宽度和高度重新设置为100
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ class SQLOpenCVImage extends SQLAlg with SQlBaseFunc {
ImageSchema.decode("", a).getOrElse(ImageSchema.invalidImageRow(""))
}
val imageRdd = df.rdd.map { f =>
val image = decodeImage(f.getAs[Array[Byte]](inputCol)).getStruct(0)
val index = f.schema.fieldNames.indexOf(inputCol)
val image = if (f.schema(index).dataType.getClass.getSimpleName == "StructType") {
f.getStruct(index)
} else {
decodeImage(f.getAs[Array[Byte]](inputCol)).getStruct(0)
}

var cvImage: IplImage = null
var targetImage: IplImage = null
var data: Array[Byte] = Array()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package streaming.dsl.mmlib.algs.processing.image

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, SaveMode, _}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType

/**
* Created by allwefantasy on 29/5/2018.
*/
/**
 * Spark SQL data source entry point for loading a directory of images
 * as a DataFrame (registered via the fully-qualified package name, e.g.
 * `format("streaming.dsl.mmlib.algs.processing.image")`).
 *
 * This source is read-only: only the scan-side `createRelation` is
 * implemented.
 */
class DefaultSource extends RelationProvider with CreatableRelationProvider with DataSourceRegister {

  /** Builds a read-only [[ImageRelation]] over the directory given by the `path` option. */
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    ImageRelation(parameters, None)(sqlContext)
  }

  // Writing images back is not supported. The original implementation
  // returned null here, which would surface later as an opaque
  // NullPointerException inside Spark; fail fast with a clear message
  // instead.
  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    throw new UnsupportedOperationException(
      "The image data source does not support writing (save mode: " + mode + ")")
  }

  // NOTE(review): Spark short names are conventionally lowercase; the
  // load path in this commit uses the full package name, so this value
  // appears unused — confirm before relying on `format("Image")`.
  override def shortName(): String = "Image"
}

/**
 * A read-only Spark SQL relation that scans a directory of image files
 * and exposes them as rows with the schema [[ImageSchema.imageDFSchema]].
 *
 * Recognised options (all optional except `path`):
 *  - `recursive`: descend into sub-directories (default "false")
 *  - `dropImageFailures`: drop rows whose image failed to decode (default "false")
 *  - `sampleRatio`: fraction of files to read (default "1.0")
 */
case class ImageRelation(
                          parameters: Map[String, String],
                          userSpecifiedschema: Option[StructType]
                        )(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with Logging {

  override def schema: StructType = ImageSchema.imageDFSchema

  override def buildScan(): RDD[Row] = {
    // `path` is required; Map#apply fails fast when it is absent.
    val imagesPath = parameters("path")
    ImageSchema.readImages(
      path = imagesPath,
      sparkSession = sqlContext.sparkSession,
      recursive = parameters.getOrElse("recursive", "false").toBoolean,
      sampleRatio = parameters.getOrElse("sampleRatio", "1.0").toDouble,
      dropImageFailures = parameters.getOrElse("dropImageFailures", "false").toBoolean
    ).rdd
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class BatchLoadAdaptor(scriptSQLExecListener: ScriptSQLExecListener,
table = reader.format("org.apache.spark.sql.execution.datasources.hbase").load()
case "crawlersql" =>
table = reader.option("path", cleanStr(path)).format("org.apache.spark.sql.execution.datasources.crawlersql").load()
case "image" =>
table = reader.option("path", cleanStr(path)).format("streaming.dsl.mmlib.algs.processing.image").load()
case _ =>
val owner = option.get("owner")
table = reader.format(format).load(withPathPrefix(scriptSQLExecListener.pathPrefix(owner), cleanStr(path)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,4 +442,17 @@ class AutoMLSpec extends BasicSparkOperation with SpecFunctions with BasicMLSQLC
spark.sql("select vec_image(jack(crawler_request_image(imagePath))) as image from orginal_text_corpus").show(false)
}
}

"image-read-path" should "work fine" in {
withBatchContext(setupBatchContext(batchParams, "classpath:///test/empty.json")) { runtime: SparkRuntime =>
//执行sql
implicit val spark = runtime.sparkSession
val sq = createSSEL
ScriptSQLExec.parse("load image.`/Users/allwefantasy/CSDNWorkSpace/streamingpro/images` as images;", sq)
val df = spark.sql("select * from images");
val newDF = new SQLOpenCVImage().interval_train(df, "/tmp/image", Map("inputCol" -> "image", "shape" -> "100,100,4"))
newDF.show()

}
}
}

0 comments on commit a3cc66b

Please sign in to comment.