Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions meegkit/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@ class ASR:
component-based artifact removal method for removing transient or
large-amplitude artifacts in multi-channel EEG recordings [1]_.

The key parameter of the method is ``cutoff``.

Parameters
----------
sfreq : float
Sampling rate of the data, in Hz.

The following are optional parameters (the key parameter of the method is
the ``cutoff``):

cutoff: float
Standard deviation cutoff for rejection. X portions whose variance
is larger than this threshold relative to the calibration data are
Expand Down Expand Up @@ -58,16 +56,16 @@ class ASR:
method : {'riemann', 'euclid'}
Method to use. If riemann, use the riemannian-modified version of
ASR [2]_.
memory : float
Memory size (s), regulates the number of covariance matrices to store.
estimator : str in {'scm', 'lwf', 'oas', 'mcd'}
memory : float | None
Memory size (samples), regulates the number of covariance matrices to
store.
If None (default), will use twice the sampling frequency.
estimator : {'scm', 'lwf', 'oas', 'mcd'}
Covariance estimator (default: 'scm' which computes the sample
covariance). Use 'lwf' if you need regularization (requires pyriemann).

Attributes
----------
``state_`` : dict
Initial state of the ASR filter.
``zi_``: array, shape=(n_channels, filter_order)
Filter initial conditions.
``ab_``: 2-tuple
Expand Down Expand Up @@ -98,9 +96,9 @@ class ASR:

"""

def __init__(self, sfreq=250, cutoff=5, blocksize=100, win_len=0.5,
def __init__(self, *, sfreq=250, cutoff=5, blocksize=100, win_len=0.5,
win_overlap=0.66, max_dropout_fraction=0.1,
min_clean_fraction=0.25, name="asrfilter", method="euclid",
min_clean_fraction=0.25, method="euclid", memory=None,
estimator="scm", **kwargs):

if pyriemann is None and method == "riemann":
Expand All @@ -115,7 +113,10 @@ def __init__(self, sfreq=250, cutoff=5, blocksize=100, win_len=0.5,
self.min_clean_fraction = min_clean_fraction
self.max_bad_chans = 0.3
self.method = method
self.memory = int(2 * sfreq) # smoothing window for covariances
if memory is None:
self.memory = int(2 * sfreq) # smoothing window for covariances
else:
self.memory = memory
self.sample_weight = np.geomspace(0.05, 1, num=self.memory + 1)
self.sfreq = sfreq
self.estimator = estimator
Expand All @@ -141,10 +142,10 @@ def fit(self, X, y=None, **kwargs):
"""Calibration for the Artifact Subspace Reconstruction method.

The input to this data is a multi-channel time series of calibration
data. In typical uses the calibration data is clean resting EEG data of
data if the fraction of artifact content is below the breakdown point
data. In typical uses the calibration data is clean resting EEG data.
The fraction of artifact content should be below the breakdown point
of the robust statistics used for estimation (50% theoretical, ~30%
practical). If the data has a proportion of more than 30-50% artifacts
practical). If the data has a proportion of more than 30-50% artifacts,
then bad time windows should be removed beforehand. This data is used
to estimate the thresholds that are used by the ASR processing function
to identify and remove artifact components.
Expand All @@ -164,6 +165,12 @@ def fit(self, X, y=None, **kwargs):
reasonably clean not less than 30 seconds (this method is typically
used with 1 minute or more).

Returns
-------
clean : array, shape=(n_channels, n_samples)
Dataset with bad time periods removed.
sample_mask : boolean array, shape=(1, n_samples)
Mask of retained samples (logical array).
"""
if X.ndim == 3:
X = X.squeeze()
Expand Down Expand Up @@ -468,6 +475,9 @@ def asr_calibrate(X, sfreq, cutoff=5, blocksize=100, win_len=0.5,
estimation (default=0.25).
method : {'euclid', 'riemann'}
Metric to compute the covariance matrix average.
estimator : {'scm', 'lwf', 'oas', 'mcd'}
Covariance estimator (default: 'scm' which computes the sample
covariance). Use 'lwf' if you need regularization (requires pyriemann).

Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion tests/test_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_asr_class(method, reref, show=False):
blah = ASR(method=method, estimator="scm")
blah.fit(raw2[:, train_idx])

asr = ASR(method=method, estimator="lwf")
asr = ASR(method=method, estimator="lwf", memory=int(2 * sfreq))
asr.fit(raw2[:, train_idx])
else:
asr = ASR(method=method, estimator="scm")
Expand Down