Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions PyRDF/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,17 @@ def use(backend_name, conf={}):
necessary configuration parameters. Its default value is an empty
dictionary {}.
"""
future_backends = [
"dask"
]

global current_backend

if backend_name in future_backends:
msg = "This backend environment will be considered in the future !"
raise NotImplementedError(msg)
elif backend_name == "local":
if backend_name == "local":
current_backend = Local(conf)
elif backend_name == "spark":
from PyRDF.backend.Spark import Spark
current_backend = Spark(conf)
elif backend_name == "dask":
from PyRDF.backend.Dask import Dask
current_backend = Dask(conf)
else:
msg = "Incorrect backend environment \"{}\"".format(backend_name)
raise Exception(msg)
Expand Down
93 changes: 93 additions & 0 deletions PyRDF/backend/Dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import print_function

import logging
from pprint import pformat

from PyRDF.backend.Dist import Dist

import dask
from dask.distributed import Client

logger = logging.getLogger(__name__)


class Dask(Dist):
    """
    Dask backend for PyRDF.

    Builds a graph of ``dask.delayed`` mapper/reducer tasks over the ranges
    returned by ``self.build_ranges()`` and computes it through a
    ``dask.distributed.Client``.

    Attributes:
        MIN_NPARTITIONS (int): Number of partitions used when no other
            value is available.
    """

    MIN_NPARTITIONS = 2

    def __init__(self, config=None):
        """
        Creates the backend. The Dask client itself is only connected on the
        first call to `ProcessAndMerge`.

        Args:
            config (dict, optional): Backend configuration. When omitted, a
                new empty dictionary is used. The keys read by this class
                are:

                * ``"scheduler_address"``: if set to a truthy value, it is
                  passed to ``dask.distributed.Client`` to connect to that
                  scheduler. Otherwise a client is created with
                  ``processes=False``.
                * ``"visualize_dask_graph"``: if truthy, ``dask.visualize``
                  is called on the final task graph before it is computed.

                The same dictionary is forwarded to the parent ``Dist``
                initializer.
                NOTE(review): ``self.npartitions`` is read in
                `_get_partitions` before this class assigns it, so it is
                presumably set by ``Dist.__init__`` (likely from
                ``config``) - confirm against the parent class.
        """
        # Use a fresh dict per instance: a literal ``{}`` default would be a
        # single object shared (and possibly mutated) by every instance
        # created without an explicit configuration.
        if config is None:
            config = {}

        super(Dask, self).__init__(config)

        self.config = config
        self.client = None  # Connected lazily in ProcessAndMerge
        self.npartitions = self._get_partitions()

        logger.debug("Creating {} instance with {} partitions".format(
            type(self), self.npartitions))
        logger.debug("Dask configuration:\n{}".format(
            pformat(dask.config.config)))

    def _get_partitions(self):
        """
        Estimate partitions of the dataset.

        Returns:
            int: The current value of ``self.npartitions`` if it is truthy,
            otherwise ``Dask.MIN_NPARTITIONS``.
        """
        npartitions = (self.npartitions or Dask.MIN_NPARTITIONS)
        return int(npartitions)

    def ProcessAndMerge(self, mapper, reducer):
        """
        Performs map-reduce using Dask framework.

        Args:
            mapper (function): A function that runs the computational graph
                and returns a list of values.

            reducer (function): A function that merges two lists that were
                returned by the mapper.

        Returns:
            list: A list representing the values of action nodes returned
            after computation (Map-Reduce).

        Raises:
            IndexError: If ``self.build_ranges()`` returns no ranges, since
                there is then no task left to compute.
        """
        # Imported locally so this method does not rely on a module-level
        # import being present.
        from collections import deque

        ranges = self.build_ranges()  # Get range pairs

        # The Dask client has to be initialized inside some context and not on
        # global scope since it's using Python Multiprocessing and each process
        # fork needs independent environment (e.g. otherwise each process would
        # try recreating a connection to the Dask client).
        if self.client is None:
            logger.debug("Connecting to Dask client.")
            scheduler_address = self.config.get("scheduler_address")
            if scheduler_address:
                self.client = Client(scheduler_address)
            else:
                # TODO: Investigate the case where processes=True
                # (multiprocessing was reported to trigger a segfault)
                self.client = Client(processes=False)
            logger.debug(
                "Succesfully connected to client {}".format(self.client))

        dmapper = dask.delayed(mapper)
        dreducer = dask.delayed(reducer)

        # One delayed mapper task per range. A deque gives O(1) removal from
        # the front, unlike list.pop(0).
        pending = deque(dmapper(current_range) for current_range in ranges)

        # Repeatedly merge the two oldest pending tasks and queue the merged
        # task at the back, until a single task remains.
        while len(pending) > 1:
            first = pending.popleft()
            second = pending.popleft()
            pending.append(dreducer(first, second))

        if self.config.get("visualize_dask_graph"):
            dask.visualize(pending[0])

        return pending.pop().compute()

    def distribute_files(self, includes_list):
        """
        TODO: Implement file distribution to Dask workers.

        Currently a no-op.

        Args:
            includes_list (list): A list consisting of all necessary C++
                files as strings, created by one of the `include` functions
                of the PyRDF API.
        """
        pass
81 changes: 81 additions & 0 deletions tests/integration/dask/test_histo_write_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os
import unittest
from array import array

import PyRDF

import ROOT


class DaskHistoWriteTest(unittest.TestCase):
    """
    Integration tests to check writing histograms to a `TFile` distributedly.
    """

    @classmethod
    def setUpClass(cls):
        """
        Parameter initialization for the histogram.
        """
        cls.nentries = 10000  # Number of fills
        cls.gaus_mean = 10  # Mean of the gaussian distribution
        cls.gaus_stdev = 1  # Standard deviation of the gaussian distribution
        cls.delta_equal = 0.01  # Delta to check for float equality

    @staticmethod
    def _remove_if_exists(path):
        """Deletes `path`, doing nothing if the file was never created."""
        if os.path.exists(path):
            os.remove(path)

    def create_tree_with_data(self):
        """Creates a .root file with some data"""
        f = ROOT.TFile("tree_gaus.root", "recreate")
        T = ROOT.TTree("Events", "Gaus(10,1)")

        x = array("f", [0])
        T.Branch("x", x, "x/F")

        r = ROOT.TRandom()
        # Fill branch "x" with values drawn from a gaussian distribution with
        # the configured mean and standard deviation
        for _ in range(self.nentries):
            x[0] = r.Gaus(self.gaus_mean, self.gaus_stdev)
            T.Fill()

        f.Write()
        f.Close()

    def test_write_histo(self):
        """
        Tests that an histogram is correctly written to a .root file created
        before the execution of the event loop.
        """
        # Register the cleanup first, so the .root files are removed even if
        # a later step or assertion fails
        self.addCleanup(self._remove_if_exists, "tree_gaus.root")
        self.addCleanup(self._remove_if_exists, "out_file.root")

        self.create_tree_with_data()

        # Create a new file where the histogram will be written
        outfile = ROOT.TFile("out_file.root", "recreate")
        try:
            # Create a PyRDF RDataFrame on the tree, using the Dask backend
            PyRDF.use("dask")
            df = PyRDF.RDataFrame("Events", "tree_gaus.root")

            # Create histogram
            histo = df.Histo1D(("x", "x", 100, 0, 20), "x")

            # Write histogram to out_file.root
            histo.Write()
        finally:
            # Close the output file whether or not the event loop succeeded
            outfile.Close()

        # Reopen file to check that histogram was correctly stored
        reopen_file = ROOT.TFile("out_file.root", "read")
        try:
            reopen_histo = reopen_file.Get("x")

            # Check histogram statistics
            self.assertEqual(reopen_histo.GetEntries(), self.nentries)
            self.assertAlmostEqual(reopen_histo.GetMean(), self.gaus_mean,
                                   delta=self.delta_equal)
            self.assertAlmostEqual(reopen_histo.GetStdDev(), self.gaus_stdev,
                                   delta=self.delta_equal)
        finally:
            reopen_file.Close()


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()