Skip to content

Commit d840489

Browse files
committed
Program to convert [h]json to tsv.
1 parent e875bf5 commit d840489

File tree

3 files changed

+111
-1
lines changed

3 files changed

+111
-1
lines changed

ragability/ragability_2tsv.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Module for the CLI to convert data files from json/hjson to tsv formt.
5+
6+
This will create a tsv file where for each array in the json/hjson file, the tsv row contains
7+
as many columns as the maximum number of elements in any of the arrays. The columns are named
8+
with the array index as a suffix, starting at 0. If the array has fewer elements than the maximum, the
9+
missing columns are filled with empty strings. The first row of the tsv file contains the column names.
10+
Nested fields in the json/hjson file are represented by a column name that is the concatenation of the
11+
field names separated by a dot.
12+
"""
13+
14+
import os, sys
15+
from typing import List, Dict, Optional
16+
import json
17+
import argparse
18+
import pandas as pd
19+
from pandas.core.frame import DataFrame
20+
import hjson
21+
import sklearn as sk
22+
from collections import defaultdict, Counter
23+
from logging import DEBUG
24+
from ragability.logging import logger, set_logging_level, add_logging_file
25+
from ragability.data import read_input_file
26+
from ragability.utils import pp_config
27+
from ragability.checks import CHECKS
28+
29+
30+
31+
def get_args():
32+
"""
33+
Get the command line arguments
34+
"""
35+
parser = argparse.ArgumentParser(description='Convert json/hjson to tsv')
36+
parser.add_argument('--input', '-i', type=str, help='Input json/hjson file', required=True)
37+
parser.add_argument('--output', '-o', type=str, help='Output tsv file (same as input but with tsv extension)', required=False)
38+
args_tmp = parser.parse_args()
39+
args = {}
40+
args.update(vars(args_tmp))
41+
return args
42+
43+
44+
45+
def run(config: dict):
46+
indata = read_input_file(config["input"])
47+
logger.info(f"Read {len(indata)} records from {config['input']}")
48+
# indata is a list of nested dictionaries where each of the values in the dictionaries could
49+
# be a scalar value, a nested dictionary, a scalar value or a list of scalar values or
50+
# nested dictionaries. We need to flatten this structure into a list of flat dictionaries where
51+
# each dictionary contains only scalar values. We will use a recursive function to do this.
52+
# The names use dots to separate nested fields and underscores to separate array indices.
53+
# Example:
54+
# indata = [ { "a": 1, "b": [2, 3], "c": { "d": 4, "e": 5 }, "e": [{ "f": 6 },{ "f": 7 }] } ]
55+
# outdata = [ { "a": 1, "b_0": 2, "b_1": 3, "c.d": 4, "c.e": 5, "e_0.f": 6, "e_1.f": 7 } ]
56+
# First we analyse all the nested dictionaries in the list to find all the field names
57+
# and the maximum number of elements in any array.
58+
# We also need to make sure that none of the text fields contain any new lines or tabs before
59+
# we write the tsv file.
60+
def analyse(indata):
61+
fieldnames = set()
62+
maxarraysize = 0
63+
for item in indata:
64+
for k, v in item.items():
65+
fieldnames.add(k)
66+
if isinstance(v, list):
67+
maxarraysize = max(maxarraysize, len(v))
68+
return fieldnames, maxarraysize
69+
# now we actually convert the list of nested dictionaries into a list of flat dictionaries
70+
def flatten(indata):
71+
fieldnames, maxarraysize = analyse(indata)
72+
flatdata = []
73+
for item in indata:
74+
flatitem = {}
75+
for k in fieldnames:
76+
v = item.get(k)
77+
if isinstance(v, list):
78+
for i, vi in enumerate(v):
79+
flatitem[f"{k}_{i}"] = vi
80+
elif isinstance(v, dict):
81+
for k1, v1 in v.items():
82+
flatitem[f"{k}.{k1}"] = v1
83+
else:
84+
flatitem[k] = v
85+
flatdata.append(flatitem)
86+
return flatdata
87+
flatdata = flatten(indata)
88+
# make sure there are no new lines or tabs in the text fields
89+
for item in flatdata:
90+
for k, v in item.items():
91+
if isinstance(v, str):
92+
item[k] = v.replace("\n", " ").replace("\t", " ")
93+
df = pd.DataFrame(flatdata)
94+
logger.info(f"Converted to dataframe with {df.shape[0]} rows and {df.shape[1]} columns")
95+
# Now we have the dataframe, we can write it to the output file
96+
outputfile = config["output"]
97+
if not outputfile:
98+
outputfile = os.path.splitext(config["input"])[0] + ".tsv"
99+
df.to_csv(outputfile, sep="\t", index=False)
100+
logger.info(f"Output written to {outputfile}")
101+
102+
103+
def main():
104+
args = get_args()
105+
run(args)
106+
107+
108+
if __name__ == '__main__':
109+
main()

ragability/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
import importlib.metadata
2-
__version__ = "0.7.2"
2+
__version__ = "0.7.3"

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def versionfromfile(*filepath):
5757
"ragability_hjson_cat=ragability.ragability_hjson_cat:main",
5858
"ragability_hjson_info=ragability.ragability_hjson_info:main",
5959
"ragability_test_llms=ragability.ragability_test_llms:main",
60+
"ragability_2tsv=ragability.ragability_2tsv:main",
6061
]},
6162
classifiers=[
6263
# "Development Status :: 6 - Mature",

0 commit comments

Comments
 (0)