
Commit d7c2eb4

Added some load functions
1 parent 8501b42 commit d7c2eb4

8 files changed

Lines changed: 121 additions & 75 deletions


src/dsff/VERSION.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-1.2.2
+1.2.3

src/dsff/formats/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
 
 __all__ = ["DSFF"]
 for k in list(globals().keys()):
-    if k.startswith("is_"):
+    if k.startswith("is_") or k.startswith("load_"):
         __all__.append(k)
 
 _FORMAT_TEXT_ALIAS = {'arff': "ARFF", 'csv': "CSV", 'db': "SQL", 'orc': "ORC"}
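
This loop re-exports every is_* checker, and now every load_* loader, from the formats package. A minimal consumer-side sketch (the import path is an assumption based on the package layout):

    # Hypothetical usage; assumes dsff.formats re-exports these names via __all__.
    from dsff.formats import is_csv, load_csv

    if is_csv("dataset.csv"):                      # file name is illustrative
        print(load_csv("dataset.csv")['features'])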

src/dsff/formats/arff.py

Lines changed: 21 additions & 13 deletions
@@ -2,20 +2,20 @@
 from .__common__ import *
 
 
-__all__ = ["from_arff", "is_arff", "to_arff"]
+__all__ = ["from_arff", "is_arff", "load_arff", "to_arff"]
 
 
-def _parse(text_or_fh, target=TARGET_NAME, missing=MISSING_TOKEN):
+def _parse(text_or_fh, target=TARGET_NAME, missing=MISSING_TOKEN, **kw):
     d, features, metadata, title = [], {}, {}, ""
     relation, attributes, data = False, [False, False], False
     for n, l in enumerate(t.splitlines() if isinstance(t := text_or_fh, str) else t):
         l, pf = l.strip(), f"Line {n}: "
         # the file shall start with "@RELATION"
         if not relation:
-            if l.startswith("@RELATION "):
+            if l.upper().startswith("@RELATION "):
                 relation = True
                 try:
-                    title = re.match(r"@RELATION\s+('[^']*'|\"[^\"]*\")$", l).group(1).strip("'\"")
+                    title = re.match(r"@RELATION\s+('[^']*'|\"[^\"]*\")$", l, re.I).group(1).strip("'\"")
                     continue
                 except Exception as e:
                     raise BadInputData(f"{pf}failed on @RELATION ({e})")
@@ -34,9 +34,8 @@ def _parse(text_or_fh, target=TARGET_NAME, missing=MISSING_TOKEN):
         if attributes[0] and not attributes[1]:
             # close the attributes block
             attributes[1] = True
-            n_cols = len(d[0])
             continue
-        if l.startswith("@ATTRIBUTE "):
+        if l.upper().startswith("@ATTRIBUTE "):
             if not attributes[0]:
                 attributes[0] = True
                 if len(d) == 0:
@@ -45,21 +44,23 @@ def _parse(text_or_fh, target=TARGET_NAME, missing=MISSING_TOKEN):
             if attributes[1]:
                 raise BadInputData(f"{pf}found @ATTRIBUTE out of the attributes block)")
             try:
-                header = re.match(r"@ATTRIBUTE\s+([^\s]+)\s+[A-Z]+$", l).group(1)
+                header = re.match(r"@ATTRIBUTE\s+([^\s]+)\s+(?:[a-zA-Z]+|\{.*?\})$", l, re.I).group(1).strip("'\"")
                 if header == "class":
                     header = target
+                else:
+                    features.setdefault(header, "")
                 d[0].append(header)
                 continue
             except AttributeError:
                 raise BadInputData(f"{pf}failed on @ATTRIBUTE (bad type)")
         if not data:
-            if l == "@DATA":
+            if l.upper() == "@DATA":
                 data = True
+                n_cols = len(d[0])
                 continue
             else:
                 raise BadInputData(f"{pf}did not find @DATA where expected")
-        row = list(map(lambda x: x.strip("'\""), re.split(r",\s+", l)))
-        if len(row) != n_cols:
+        if len(row := list(map(lambda x: x.strip("'\""), re.split(r",\s*", l)))) != n_cols:
             raise BadInputData(f"{pf}this row does not match the number of columns")
         d.append(row)
     for i in range(n_cols):
@@ -78,7 +79,7 @@ def _parse(text_or_fh, target=TARGET_NAME, missing=MISSING_TOKEN):
     return d, features, metadata, title
 
 
-def from_arff(dsff, path=None, target=TARGET_NAME, missing=MISSING_TOKEN):
+def from_arff(dsff, path=None, target=TARGET_NAME, missing=MISSING_TOKEN, **kw):
     """ Populate the DSFF file from an ARFF file. """
     with open(path) as f:
         d, ft, md, t = _parse(f, target, missing)
@@ -87,7 +88,7 @@ def from_arff(dsff, path=None, target=TARGET_NAME, missing=MISSING_TOKEN):
 
 
 @text_or_path
-def is_arff(text):
+def is_arff(text, target=TARGET_NAME, missing=MISSING_TOKEN, **kw):
     """ Check if the input text or path is a valid ARFF. """
     try:
         _parse(ensure_str(text))
@@ -96,7 +97,14 @@ def is_arff(text):
     return False
 
 
-def to_arff(dsff, path=None, target=TARGET_NAME, exclude=DEFAULT_EXCL, missing=MISSING_TOKEN, text=False):
+def load_arff(path, target=TARGET_NAME, missing=MISSING_TOKEN, **kw):
+    """ Load an ARFF file as a dictionary with data, features and metadata. """
+    with open(path) as f:
+        d, ft, md, _ = _parse(f, target, missing)
+    return {'data': d, 'features': ft, 'metadata': md}
+
+
+def to_arff(dsff, path=None, target=TARGET_NAME, exclude=DEFAULT_EXCL, missing=MISSING_TOKEN, text=False, **kw):
     """ Output the dataset in ARFF format, suitable for use with the Weka framework, saved as a file or output as a
         string. """
     name = splitext(basename(path))[0]
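
Net effect of the changes above: keyword matching is now case-insensitive (@relation, @attribute, @data all work), nominal attribute types such as {a,b,c} are accepted, quoted attribute names are unquoted, rows may be separated by commas without spaces, and the parsing logic becomes reusable on its own through load_arff. A minimal usage sketch (file name and import path are assumptions):

    # Hypothetical standalone use of the new ARFF loader.
    from dsff.formats import load_arff   # import path assumed

    ds = load_arff("dataset.arff")       # {'data': [...], 'features': {...}, 'metadata': {...}}
    headers, rows = ds['data'][0], ds['data'][1:]
    print(headers, len(rows), ds['features'])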

src/dsff/formats/csv.py

Lines changed: 14 additions & 13 deletions
@@ -2,24 +2,16 @@
 from .__common__ import *
 
 
-__all__ = ["from_csv", "is_csv", "to_csv"]
+__all__ = ["from_csv", "is_csv", "load_csv", "to_csv"]
 
 
-def from_csv(dsff, path=None, exclude=DEFAULT_EXCL):
+def from_csv(dsff, path=None, exclude=DEFAULT_EXCL, **kw):
     """ Populate the DSFF file from a CSV file. """
-    dsff.write(path)
-    features = {}
-    for headers in dsff['data'].rows:
-        for header in headers:
-            if header.value in exclude:
-                continue
-            features[header.value] = ""
-        break
-    dsff.write(features=features)
+    dsff.write(**load_csv(path))
 
 
 @text_or_path
-def is_csv(text):
+def is_csv(text, **kw):
     """ Check if the input text or path is a valid CSV. """
     try:
         dialect = csvmod.Sniffer().sniff(text := ensure_str(text))
@@ -29,7 +21,16 @@ def is_csv(text):
     return False
 
 
-def to_csv(dsff, path=None, text=False):
+def load_csv(path, exclude=DEFAULT_EXCL, **kw):
+    """ Load a CSV file as a dictionary with data, features and metadata. """
+    data = {'metadata': {}}
+    with open(expanduser(path)) as f:
+        data['data'] = [r for r in csvmod.reader(f, delimiter=CSV_DELIMITER)]
+    data['features'] = {h: "" for h in data['data'][0] if h not in exclude}
+    return data
+
+
+def to_csv(dsff, path=None, text=False, **kw):
     """ Create a CSV from the data worksheet, saved as a file or output as a string. """
     with (StringIO() if text else open(path, 'w+')) as f:
         writer = csvmod.writer(f, delimiter=";")
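
from_csv now delegates to load_csv, so a single code path reads the file with csvmod.reader and derives the feature map from the header row. A minimal usage sketch (file name and import path are assumptions; CSV_DELIMITER comes from __common__):

    # Hypothetical standalone use of the new CSV loader.
    from dsff.formats import load_csv   # import path assumed

    ds = load_csv("dataset.csv")
    assert set(ds) == {"data", "features", "metadata"}
    print(ds['data'][0])                # header row, as parsed by csvmod.reader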

src/dsff/formats/dataset.py

Lines changed: 22 additions & 8 deletions
@@ -1,25 +1,25 @@
 # -*- coding: UTF-8 -*-
 from .__common__ import *
+from .csv import load_csv
 
 
-__all__ = ["from_dataset", "is_dataset", "to_dataset"]
+__all__ = ["from_dataset", "is_dataset", "load_dataset", "to_dataset"]
 
 
 def _parse(path):
-    if not isdir(path):
+    if not isdir(expanduser(path)):
         raise BadInputData("Not a folder")
-    else:
-        if len(missing := [f for f in ["data.csv", "features.json", "metadata.json"] if not isfile(join(path, f))]) > 0:
-            raise BadInputData(f"Not a valid dataset folder (missing: {', '.join(missing)})")
+    if len(missing := [f for f in ["data.csv", "features.json", "metadata.json"] if not isfile(join(path, f))]) > 0:
+        raise BadInputData(f"Not a valid dataset folder (missing: {', '.join(missing)})")
 
 
-def from_dataset(dsff, path=None):
+def from_dataset(dsff, path=None, **kw):
     """ Populate the DSFF file from a Dataset structure. """
     _parse(path)
     dsff.write(path)
 
 
-def is_dataset(path):
+def is_dataset(path, **kw):
     """ Check if the input path is a valid Dataset. """
     try:
         _parse(path)
@@ -28,7 +28,21 @@ def is_dataset(path):
     return False
 
 
-def to_dataset(dsff, path=None):
+def load_dataset(path, **kw):
+    """ Load a dataset folder as a dictionary with data, features and metadata. """
+    if not isdir(d := expanduser(str(path))):
+        raise BadInputData("Not a folder")
+    dp, fp, mp = join(d, "data.csv"), join(d, "features.json"), join(d, "metadata.json")
+    data = {}
+    data['data'] = load_csv(dp)['data']
+    with open(fp) as f:
+        data['features'] = json.load(f)
+    with open(mp) as f:
+        data['metadata'] = json.load(f)
+    return data
+
+
+def to_dataset(dsff, path=None, **kw):
     """ Create a dataset folder according to the following structure ;
         name
          +-- data.csv
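
load_dataset reuses load_csv for data.csv and json.load for the two JSON files, raising BadInputData when the path is not a folder. A minimal usage sketch (folder name, import paths and the BadInputData location are assumptions):

    # Hypothetical standalone use of the new dataset-folder loader.
    from dsff.formats import load_dataset             # import path assumed
    from dsff.formats.__common__ import BadInputData  # location assumed

    try:
        ds = load_dataset("~/datasets/test")          # expanduser() is applied
        print(sorted(ds['features']))
    except BadInputData as e:
        print(f"not a dataset folder: {e}")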

src/dsff/formats/db.py

Lines changed: 32 additions & 25 deletions
@@ -2,36 +2,16 @@
 from .__common__ import *
 
 
-__all__ = ["from_db", "is_db", "to_db"]
+__all__ = ["from_db", "is_db", "load_db", "to_db"]
 
 
-def from_db(dsff, path=None, exclude=DEFAULT_EXCL):
+def from_db(dsff, path=None, **kw):
     """ Populate the DSFF file from a SQLDB file. """
-    from json import loads
-    from sqlite3 import connect
-    conn = connect(path)
-    cursor = conn.cursor()
-    # list tables
-    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
-    tables = [table[0] for table in cursor.fetchall()]
-    if not all(t in tables for t in ["data", "features", "metadata"]):  # pragma: no cover
-        raise BadInputData("The target SQLDB does not have the right format")
-    # import data
-    cursor.execute("PRAGMA table_info('data')")
-    headers = [[col[1] for col in cursor.fetchall()]]
-    cursor.execute("SELECT * FROM data;")
-    dsff.write(headers + [r for r in cursor.fetchall()])
-    # import feature definitions
-    cursor.execute("SELECT name,description FROM features;")
-    dsff.write(features={r[0]: r[1] for r in cursor.fetchall()})
-    # import metadata
-    cursor.execute("SELECT key,value FROM metadata;")
-    dsff.write(metadata={r[0]: loads(r[1]) if isinstance(r[1], str) else r[1] for r in cursor.fetchall()})
-    conn.close()
+    dsff.write(**load_db(path))
 
 
 @text_or_path
-def is_db(data):
+def is_db(data, **kw):
     """ Check if the input data or path is a valid SQL database. """
     from sqlite3 import connect, Error
     from sys import version_info
@@ -60,7 +40,34 @@ def is_db(data):
     return False
 
 
-def to_db(dsff, path=None, text=False, primary_index=0):
+def load_db(path, **kw):
+    """ Load a SQLDB file as a dictionary with data, features and metadata. """
+    from json import loads
+    from os.path import basename, splitext
+    from sqlite3 import connect
+    conn = connect(path)
+    cursor, data = conn.cursor(), {}
+    # list tables
+    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+    tables = [table[0] for table in cursor.fetchall()]
+    if not all(t in tables for t in ["data", "features", "metadata"]):  # pragma: no cover
+        raise BadInputData("The target SQLDB does not have the right format")
+    # import data
+    cursor.execute("PRAGMA table_info('data')")
+    headers = [[col[1] for col in cursor.fetchall()]]
+    cursor.execute("SELECT * FROM data;")
+    data['data'] = headers + [r for r in cursor.fetchall()]
+    # import feature definitions
+    cursor.execute("SELECT name,description FROM features;")
+    data['features'] = {r[0]: r[1] for r in cursor.fetchall()}
+    # import metadata
+    cursor.execute("SELECT key,value FROM metadata;")
+    data['metadata'] = {r[0]: loads(r[1]) if isinstance(r[1], str) else r[1] for r in cursor.fetchall()}
+    conn.close()
+    return data
+
+
+def to_db(dsff, path=None, text=False, primary_index=0, **kw):
     """ Create a SQLDB from the data worksheet, saved as a file or output as a string. """
     from json import dumps
     from sqlite3 import connect
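
The SQLite import logic moved wholesale from from_db into load_db, which returns the same three-key dictionary, so from_db reduces to dsff.write(**load_db(path)). A minimal usage sketch (file name and import path are assumptions):

    # Hypothetical standalone use of the new SQLite loader.
    from dsff.formats import load_db   # import path assumed

    ds = load_db("dataset.db")         # expects tables: data, features, metadata
    print(ds['data'][0])               # column names from PRAGMA table_info('data')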

src/dsff/formats/pa.py

Lines changed: 20 additions & 10 deletions
@@ -9,28 +9,38 @@ def _nowrite(m):
     raise NotImplementedError(f"none of {m}.write_table and {m}.write_{m} is implemented")
 
 
+def _parse(ds):
+    return {
+        'data': [ds.schema.names] + [list(r.values()) for r in ds.to_pylist()],
+        'features': {k.decode(): v.decode() for k, v in ds.schema.metadata.items()},
+        'metadata': literal_eval(ds.schema.metadata.pop(b'__metadata__', b"{}").decode()),
+    }
+
+
 for module in ["feather", "orc", "parquet"]:
-    __all__ += [f"from_{module}", f"is_{module}", f"to_{module}"]
+    __all__ += [f"from_{module}", f"is_{module}", f"load_{module}", f"to_{module}"]
     def gen_func(m):
-        def from_(dsff, path=None, exclude=DEFAULT_EXCL):
-            dataset = globals()[m].read_table(path)
-            dsff.write(data=[dataset.schema.names] + [list(r.values()) for r in dataset.to_pylist()],
-                       metadata=literal_eval(dataset.schema.metadata.pop(b'__metadata__', b"{}").decode()),
-                       features={k.decode(): v.decode() for k, v in dataset.schema.metadata.items()})
+        def from_(dsff, path=None, **kw):
+            from os.path import basename, splitext
+            dsff.write(**_parse(globals()[m].read_table(path)))
         from_.__name__ = f"from_{m}"
-        def is_(data):
+        def is_(data, **kw):
             try:
                 globals()[m].read_table(pyarrow.BufferReader(data))
                 return True
             except Exception:
                 return False
         is_.__name__ = f"is_{m}"
-        def to_(dsff, path=None, text=False):
+        def load_(path, **kw):
+            return _parse(globals()[m].read_table(path))
+        load_.__name__ = f"load_{m}"
+        def to_(dsff, path=None, text=False, **kw):
             with (BytesIO() if text else open(path, 'wb+')) as f:
                 getattr(globals()[m], "write_table", getattr(globals()[m], f"write_{m}", _nowrite))(dsff._to_table(), f)
             if text:
                 return f.getvalue()
         to_.__name__ = f"to_{m}"
-        return from_, text_or_path(is_), to_
-    globals()[f'from_{module}'], globals()[f'is_{module}'], globals()[f'to_{module}'] = gen_func(module)
+        return from_, text_or_path(is_), load_, to_
+    globals()[f'from_{module}'], globals()[f'is_{module}'], globals()[f'load_{module}'], globals()[f'to_{module}'] = \
+        gen_func(module)
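
Because gen_func now also emits a load_ closure around the shared _parse helper, each pyarrow-backed format (feather, orc, parquet) gains a loader for free. A minimal usage sketch (file name and import path are assumptions):

    # Hypothetical use of a generated pyarrow-backed loader.
    from dsff.formats import load_parquet   # load_feather and load_orc work alike

    ds = load_parquet("dataset.parquet")
    print(ds['data'][0], ds['metadata'])    # schema names, then decoded __metadata__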

tests/test_dsff.py

Lines changed: 10 additions & 4 deletions
@@ -89,6 +89,7 @@ def test_conversion_arff(self):
         with open(arff := f"{TEST_BASENAME}.arff", 'w') as f:
             f.write(TEST_ARFF)
         self.assertTrue(is_arff(arff))
+        self.assertIsInstance(load_arff(arff), dict)
         with DSFF() as f:
             f.from_arff(TEST_BASENAME)
         # test for multiple error scenarios
@@ -130,7 +131,8 @@ def test_conversion_csv(self):
         with DSFF(TEST) as f:
             self.assertIsNotNone(f.to_csv(text=True))
             self.assertIsNone(f.to_csv())
-        self.assertTrue(is_csv(f"{TEST}.csv"))
+        self.assertTrue(is_csv(csv := f"{TEST_BASENAME}.csv"))
+        self.assertIsInstance(load_csv(csv), dict)
         # CSV to DSFF
         with DSFF() as f:
             f.from_csv(TEST_BASENAME)
@@ -143,7 +145,9 @@ def test_conversion_dataset(self):
         # FilelessDataset to DSFF
         with DSFF() as f:
             f.from_dataset(TEST_BASENAME)
-        self.assertTrue(is_dataset(f"{TEST_BASENAME}"))
+        self.assertTrue(is_dataset(TEST_BASENAME))
+        self.assertRaises(BadInputData, load_dataset, 0)
+        self.assertIsInstance(load_dataset(TEST_BASENAME), dict)
         # FilelessDataset to DSFF (bad input dataset)
         os.remove(os.path.join(TEST_BASENAME, "metadata.json"))
         with DSFF() as f:
@@ -167,7 +171,8 @@ def test_conversion_db(self):
         with DSFF(TEST) as f:
            self.assertIsNotNone(f.to_db(text=True))
            self.assertIsNone(f.to_db())
-        self.assertTrue(is_db(f"{TEST_BASENAME}.db"))
+        self.assertTrue(is_db(db := f"{TEST_BASENAME}.db"))
+        self.assertIsInstance(load_db(db), dict)
         # SQL database to DSFF
         with DSFF() as f:
             f.from_db(TEST_BASENAME)
@@ -183,8 +188,9 @@ def test_conversion_pyarrow_formats(self):
             self.assertIsNone(getattr(f, f"to_{fmt}")(TEST_BASENAME))
             self.assertIsNotNone(getattr(f, f"to_{fmt}")(text=True))
             is_ = globals()[f'is_{fmt}']
-            self.assertTrue(is_(f"{TEST_BASENAME}.{fmt}"))
+            self.assertTrue(is_(fn := f"{TEST_BASENAME}.{fmt}"))
             self.assertFalse(is_(b"PK\x03\x04\x14\x00\x00\x00\x08\x00P\xb3T\\F\xc7MH"))
+            self.assertIsInstance(globals()[f'load_{fmt}'](fn), dict)
             with DSFF(INMEMORY) as f:
                 self.assertIsNone(getattr(f, f"from_{fmt}")(TEST_BASENAME))
             f.to_dataset(path=TEST_BASENAME)
