Skip to content

Commit

Permalink
Fix: 1. remove RandomForestClassifier dependency by AmazonScanHarvester
Browse files Browse the repository at this point in the history
  • Loading branch information
platonai committed Nov 16, 2023
1 parent ae9f0d4 commit c9f324e
Showing 1 changed file with 18 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package ai.platon.exotic.examples.ml.supervised

//import ai.platon.exotic.ml.RandomForestClassifier
import ai.platon.exotic.crawl.common.AmazonAsinUrlNormalizer
import ai.platon.exotic.ml.RandomForestClassifier
import ai.platon.pulsar.common.AppPaths
import ai.platon.pulsar.dom.nodes.node.ext.isText
import ai.platon.pulsar.persist.gora.generated.GWebPage
import ai.platon.scent.common.clearMLLabels
import ai.platon.scent.common.mlLabels
import ai.platon.scent.dom.nodes.node.ext.nthScreen
import ai.platon.scent.ml.EncodeOptions
import ai.platon.scent.tools.VerboseCrawler
import java.lang.management.ManagementFactory
Expand Down Expand Up @@ -35,9 +33,9 @@ class AmazonScanHarvester {
private val datasetPath = AppPaths.getTmp("amazon.dataset.libsvm.txt")

val labels = listOf("stars", "stars_text", "ratings", "qas", "price_text", "brand")
private val classifier = RandomForestClassifier(labels.size, datasetPath)

// private val classifier = RandomForestClassifier(labels.size, datasetPath)

init {
session.context.urlNormalizer.add(AmazonAsinUrlNormalizer())
}
Expand Down Expand Up @@ -86,18 +84,18 @@ class AmazonScanHarvester {
// labeling & encoding
}

fun train() {
classifier.train()
}
fun predict(url: String) {
val document = session.loadDocument(url)
val encodeOptions = EncodeOptions(labels = labels)
val df = session.encodeNodes(document, encodeOptions) { it.isText && it.nthScreen <= 2 }
df.points.forEach { point ->
classifier.predict(point.dataRef)
}
}
// fun train() {
// classifier.train()
// }
//
// fun predict(url: String) {
// val document = session.loadDocument(url)
// val encodeOptions = EncodeOptions(labels = labels)
// val df = session.encodeNodes(document, encodeOptions) { it.isText && it.nthScreen <= 2 }
// df.points.forEach { point ->
// classifier.predict(point.dataRef)
// }
// }
}

fun main(args: Array<String>) {
Expand All @@ -124,8 +122,8 @@ fun main(args: Array<String>) {
"check" -> harvester.check(start, limit)
"harvest" -> harvester.harvest(start, limit, args2)
"encode" -> harvester.encode(start, limit, args2)
"train" -> harvester.train()
"predict" -> harvester.predict(url)
// "train" -> harvester.train()
// "predict" -> harvester.predict(url)
"clearAnnotations" -> harvester.clearAnnotations(start, limit)
}
}

0 comments on commit c9f324e

Please sign in to comment.