Skip to content

created first version of api. ffi calls and example missing CU-861n1bn07#13

Open
CallMeMSL wants to merge 30 commits into
train_with_val_datafrom
new_api
Open

created first version of api. ffi calls and example missing CU-861n1bn07#13
CallMeMSL wants to merge 30 commits into
train_with_val_datafrom
new_api

Conversation

@CallMeMSL
Copy link
Copy Markdown

I created the first version of the new API and would like some feedback. The booster is built with the TypeState pattern, so logic errors when building the booster are caught by the compiler.
I also added add_pramas as a builder method, so that you can either have different params or datasets after duplicating a builder.

ffi code and tests are still missing.
You can also ignore the changes in old_booster.rs and old_dataset.rs, refactoring got a little messy.

@CallMeMSL CallMeMSL requested a review from leofidus May 18, 2023 20:24
Copy link
Copy Markdown

@leofidus leofidus left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is definitely on the right way

Comment thread src/booster/builder.rs
Comment on lines +20 to +29
/// Builder for the Booster.
///
/// Uses TypeState Pattern to make sure that Training Data is added
/// so that Validation can be synced properly and params are present for training.
#[derive(Default, Clone)]
pub struct BoosterBuilder<T: Clone, P: Clone> {
train_data: T,
val_data: Vec<DataSet>,
params: P, // after #3 should this be a struct
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a neat pattern, and it looks like Rustdoc doesn't have issues with it either.

Comment thread src/booster/builder.rs
Comment on lines +61 to +66
impl<P: Clone> BoosterBuilder<TrainDataAdded, P> {
pub fn add_val_data(mut self, val: DataSet) -> Self {
self.val_data.push(val);
self
}
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

People tend to add validation data after training data, but is there a reason this isn't just implemented on BoosterBuilder<T,P>?

I guess it helps with validation to restrict it

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My reason is that I'm not sure if loading the datasets in fit() is the best way to approach it. Restricting it like that always keeps the possibility open to load the dataset at the add_val_data call directly. This would make duplicate() also a lot more efficient.

Comment thread src/booster/mod.rs
Comment on lines +14 to +18
pub struct Booster {
handle: lightgbm_sys::BoosterHandle,
train_data: DataSet,
validation_data: Vec<DataSet>,
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these meant to be LoadedDatasets? I would assume fit loads the data

What about Boosters that were trained in advance and are loaded from file? What would their train_data and validation_data be? (and does it even make sense to hold onto these potentially huge datasets?)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you're right. This is a refactoring artifact and should just be the Dataset pointers. Now that you say it, it probably doesn't make sense to add them to the booster, if we build them with fit() anyway.

Comment thread src/dataset/mod.rs
Comment on lines +24 to +37
pub enum DataFormat {
File {
path: String,
},
Vecs {
x: InputMatrix,
y: OutputVec,
},
#[cfg(feature = "dataframe")]
DataFrame {
df: DataFrame,
y_column: String,
},
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is to make Datasets clonable?

It feels like it makes datasets and error handling a bit more complicated, compared to just directly loading them. It would also prevent future load_* functions that only take a reference to properly laid-out data (maybe loading nalgebra arrays,if the support that?).

I think (suspect) you can implement clone on Dataset by calling LGBM_DatasetCreateByReference(h_old, rows, &mut h_new) followed by LGBM_DatasetAddFeaturesFrom(h_new, h_old). But maybe that's completely wrong, the documentation is incredibly vague.

Copy link
Copy Markdown

@matthiasvedder matthiasvedder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks very promising overall.

Comment thread src/booster/builder.rs Outdated
#[derive(Clone)]
pub struct TrainDataAdded(DataSet); // this should not implement default, so it can safely be used for construction
#[derive(Default, Clone)]
pub struct TrainDataNotAdded;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about the resulting API, I was wondering whether these structs could have names where the important part stands out more? Like WithTrainData and NoTrainData. The difference is the very first word, instead of (not) having a Not added in the middle of a fairly long type name.

Same for Params.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem with the TypeState Pattern is, that you can't set error messages. However, structs appear in the error message when you try to call a function from a different implementation. Example:

25  | pub struct BoosterBuilder<T: Clone, P: Clone> {
    | --------------------------------------------- method `fit` not found for this struct
...
109 |         let builder = Booster::builder().fit();
    |                                          ^^^ method not found in `BoosterBuilder<TrainDataNotAdded, ParamsNotAdded>`
    |
    = note: the method was found for
            - `BoosterBuilder<TrainDataAdded, ParamsAdded>`

Since this is the only point where the user actually encounters the structs, I named them so that the Error message sounds natural.

But your suggestion would work as well.

Comment thread src/booster/builder.rs
/// Returns the Builder and a clone from it. Useful if you want to train 2 models with
/// only a couple differences
pub fn duplicate(self) -> (Self, Self) {
(self.clone(), self)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate why this function returns two instances of Self and why self is the second one?

Calling this like let (other, me) = me.duplicate(); feels a bit weird at first glance.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(self, self.clone()) would work as well.

I added this function, so that you can call it after you defined everything that 2 boosters have in common and then add the differences, like this:

        let (bst_low_lr, bst_high_lr) = Booster::builder()
            .add_train_data(dataset)
            .add_val_data(another_dataset)
            .add_val_data(also_a_dataset)
            .duplicate();
        let bst_low_lr = bst_low_lr.add_params(params_a).fit()?;
        let bst_high_lr = bst_high_lr.add_params(params_b).fit()?;

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your example does look clean.

It would restrict us to 2 boosters. I don't know if comparing 3 or more boosters does make any sense.

Eventually, examples like these should be part of the docs, they help understand the API better.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want more than 2 boosters, you'd probably use clone again, for example if you have a Vec of params you want to test you could do

let src_bst = Booster::builder()
  .add_train_data(dataset)
  .add_val_data(another_dataset)
  .add_val_data(also_a_dataset);
let boosters = params.map(|p| src_bst.clone().add_params(p).fit())
                     .filter_map(|booster| booster.ok());

duplicate() is a bit of a special case, but I think it's nice to have.

@CallMeMSL
Copy link
Copy Markdown
Author

I think the rewrite is so far done, that we can accept this pr. Any feedback?

Comment thread src/booster/mod.rs
validation_data: Vec<LoadedDataSet>,
}

// exchange params method as well? does this make sense?
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover comment from the development stage?

Comment thread src/booster/mod.rs
/// # Ok(())}
/// ```
pub fn predict(&self, x: &Matrixf64) -> Result<Matrixf64, LgbmError> {
let prediction_params = ""; // do we need this?
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we?

Comment thread src/booster/mod.rs
.collect())
}

/// this should take &mut self, because it changes the model
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really a doc comment.

Comment thread src/booster/mod.rs
/// This should not reset the already existing submodels.
/// Pass an empty array as validation data, if you don't want to validate the train results.
/// TODO validate this after implemented
pub fn finetune(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What should happen with this code? Delete it?

Comment thread src/dataset/mod.rs
/// The DatasetHandle is returned by the lightgbm ffi.
pub struct LoadedDataSet {
pub(crate) handle: DatasetHandle,
dataset: DataSet, // this can maybe be removed
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since clippy warns about it, it should be removed.

@CallMeMSL CallMeMSL changed the title created first version of api. ffi calls and example missing created first version of api. ffi calls and example missing CU-861n1bn07 Jul 24, 2023
@leofidus
Copy link
Copy Markdown

Task linked: CU-861n1bn07 LightGBM API Rewrite

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants