# custom_sqlite.py
#
# NOTE: The repository this file comes from was archived by its owner on
# Aug 25, 2024 and is now read-only.
import aiosqlite
from collections import OrderedDict
from typing import AsyncIterator
from dffml import config, Record, BaseSource, BaseSourceContext
@config
class CustomSQLiteSourceConfig:
    """
    Configuration for the custom SQLite source.
    """

    # Path of the SQLite database file; passed to aiosqlite.connect() when
    # the source is entered
    filename: str
class CustomSQLiteSourceContext(BaseSourceContext):
    """
    Reads and writes records through the SQLite connection held by the parent
    source (``self.parent.db``).

    Feature values are kept in the ``features`` table and the prediction for
    ``"target_name"`` in the ``prediction`` table. Both tables are keyed by
    the record's key.
    """

    async def _insert_or_replace(self, table, columns, key, values):
        """
        INSERT OR REPLACE the row for ``key`` in ``table``.

        ``columns`` fixes the order of the bound parameters. A column without
        an entry in ``values`` is bound as ``None``.
        """
        # Start from every column mapped to None so the bound values line up
        # with the column list in the statement
        row = OrderedDict.fromkeys(columns)
        row.update(values)
        await self.parent.db.execute(
            "INSERT OR REPLACE INTO "
            + table
            + " (key, "
            + ", ".join(columns)
            + ") "
            "VALUES(?, " + ", ".join("?" * len(columns)) + ")",
            [key] + list(row.values()),
        )

    async def update(self, record: Record):
        """
        Store the features of ``record`` and, when it has one, its prediction
        for ``"target_name"``.
        """
        # Store feature data
        feature_cols = self.parent.FEATURE_COLS
        await self._insert_or_replace(
            "features", feature_cols, record.key, record.features(feature_cols)
        )
        # Store prediction. Only the lookup is guarded: a KeyError here is
        # treated as "no prediction to store". Errors raised while writing the
        # row are no longer silently discarded.
        try:
            prediction = record.prediction("target_name")
        except KeyError:
            return
        await self._insert_or_replace(
            "prediction",
            self.parent.PREDICTION_COLS,
            record.key,
            prediction.dict(),
        )

    async def records(self) -> AsyncIterator[Record]:
        """
        Yield a Record for every key present in the ``features`` table.
        """
        # NOTE This logic probably isn't what you want. Only for demo purposes.
        # (One extra pair of queries is issued per key via self.record().)
        keys = await self.parent.db.execute("SELECT key FROM features")
        for row in await keys.fetchall():
            # Access by column name relies on the parent setting
            # row_factory = aiosqlite.Row on the connection
            yield await self.record(row["key"])

    async def record(self, key: str):
        """
        Build the Record for ``key`` from the ``features`` and ``prediction``
        tables. A Record is returned even if neither table has a row for it.
        """
        db = self.parent.db
        record = Record(key)
        # Get features
        features = await db.execute(
            "SELECT " + ", ".join(self.parent.FEATURE_COLS) + " "
            "FROM features WHERE key=?",
            (record.key,),
        )
        features = await features.fetchone()
        if features is not None:
            record.evaluated(features)
        # Get prediction
        prediction = await db.execute(
            "SELECT * FROM prediction WHERE " "key=?", (record.key,)
        )
        prediction = await prediction.fetchone()
        if prediction is not None:
            record.predicted(
                "target_name", prediction["value"], prediction["confidence"]
            )
        return record

    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Commit pending writes on the parent's connection when the context is
        exited.
        """
        await self.parent.db.commit()
class CustomSQLiteSource(BaseSource):
    """
    Source backed by the SQLite database file named by ``config.filename``.

    Entering the source opens the connection (exposed as ``self.db``) and
    creates the ``features`` and ``prediction`` tables if they do not exist.
    Exiting the source closes the connection.
    """

    CONFIG = CustomSQLiteSourceConfig
    CONTEXT = CustomSQLiteSourceContext
    # Columns of the features table, each created with type REAL
    FEATURE_COLS = ["PetalLength", "PetalWidth", "SepalLength", "SepalWidth"]
    # Columns of the prediction table, in the order values are bound on insert
    PREDICTION_COLS = ["value", "confidence"]

    async def __aenter__(self) -> "CustomSQLiteSource":
        """
        Open the database and make sure both tables exist.

        If setting up the schema fails, the connection is closed again before
        the error propagates, because ``__aexit__`` is not called when
        ``__aenter__`` raises.
        """
        self.__db = aiosqlite.connect(self.config.filename)
        self.db = await self.__db.__aenter__()
        try:
            # Rows must be addressable by column name (e.g. row["key"])
            self.db.row_factory = aiosqlite.Row
            # Create table for feature data
            feature_columns = ["key TEXT PRIMARY KEY NOT NULL"] + [
                col + " REAL" for col in self.FEATURE_COLS
            ]
            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS features ("
                + ", ".join(feature_columns)
                + ")"
            )
            # Create table for predictions
            await self.db.execute(
                "CREATE TABLE IF NOT EXISTS prediction ("
                "key TEXT PRIMARY KEY, "
                "value TEXT, "
                "confidence REAL"
                ")"
            )
        except BaseException as error:
            # Do not leak the open connection when setup fails
            await self.__db.__aexit__(
                type(error), error, error.__traceback__
            )
            raise
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Close the database connection opened in ``__aenter__``.
        """
        await self.__db.__aexit__(exc_type, exc_value, traceback)