luz 0.3.0

luz version 0.3.0 is now on CRAN. luz is a high-level interface for torch.

Daniel Falbel (RStudio)
2022-08-24

We are happy to announce that luz version 0.3.0 is now on CRAN. This release brings a few improvements to the learning rate finder, first contributed by Chris McMaster. As we didn’t have a 0.2.0 release post, we will also highlight a few improvements that date back to that version.

What’s luz?

Since luz is a relatively new package, we are starting this blog post with a quick recap of how it works. If you already know what luz is, feel free to skip to the next section.

luz is a high-level API for torch that aims to encapsulate the training loop into a set of reusable pieces of code. It reduces the boilerplate required to train a model with torch, avoids the error-prone zero_grad() - backward() - step() sequence of calls, and also simplifies the process of moving data and models between CPUs and GPUs.

With luz you can take your torch nn_module(), for example the two-layer perceptron defined below:

modnn <- nn_module(
  initialize = function(input_size) {
    self$hidden <- nn_linear(input_size, 50)
    self$activation <- nn_relu()
    self$dropout <- nn_dropout(0.4)
    self$output <- nn_linear(50, 1)
  },
  forward = function(x) {
    x %>% 
      self$hidden() %>% 
      self$activation() %>% 
      self$dropout() %>% 
      self$output()
  }
)

and fit it to a specified dataset like so:

fitted <- modnn %>% 
  setup(
    loss = nn_mse_loss(),
    optimizer = optim_rmsprop,
    metrics = list(luz_metric_mae())
  ) %>% 
  set_hparams(input_size = 50) %>% 
  fit(
    data = list(x_train, y_train),
    valid_data = list(x_valid, y_valid),
    epochs = 20
  )

luz will automatically train your model on the GPU if it’s available, display a nice progress bar during training, and handle logging of metrics, all while making sure evaluation on validation data is performed in the correct way (e.g., disabling dropout).

luz can be extended at many different layers of abstraction, so you can pick up more advanced features gradually, as your project needs them. For example, you can implement custom metrics or callbacks, or even customize the internal training loop.
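For instance, here is a minimal sketch of a custom callback (a hypothetical example, not from this release) that prints a message at the end of every epoch. Inside callback methods, luz exposes the training context as ctx:

print_callback <- luz_callback(
  name = "print_callback",
  on_epoch_end = function() {
    # ctx$epoch holds the index of the epoch that just finished
    cat("Finished epoch ", ctx$epoch, "\n")
  }
)

You would then instantiate it and pass it to fit() via the callbacks argument, e.g. fit(..., callbacks = list(print_callback())).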

To learn about luz, read the getting started section on the website, and browse the examples gallery.

What’s new in luz?

Learning rate finder

In deep learning, finding a good learning rate is essential to be able to fit your model. If it’s too low, you will need too many iterations for your loss to converge, which might be impractical if your model takes long to run. If it’s too high, the loss can explode and you might never arrive at a minimum.

The lr_finder() function implements the algorithm detailed in Cyclical Learning Rates for Training Neural Networks (Smith 2015) and popularized in the fastai framework (Howard and Gugger 2020). It takes an nn_module() and some data, and produces a data frame with the losses and the learning rate at each step.

# net is an nn_module() defined elsewhere; setup() attaches the loss
# and the optimizer that lr_finder() will use
model <- net %>% setup(
  loss = torch::nn_cross_entropy_loss(),
  optimizer = torch::optim_adam
)

records <- lr_finder(
  object = model, 
  data = train_ds, 
  verbose = FALSE,
  dataloader_options = list(batch_size = 32),
  start_lr = 1e-6, # the smallest value that will be tried
  end_lr = 1 # the largest value to be experimented with
)

str(records)
#> Classes 'lr_records' and 'data.frame':   100 obs. of  2 variables:
#>  $ lr  : num  1.15e-06 1.32e-06 1.51e-06 1.74e-06 2.00e-06 ...
#>  $ loss: num  2.31 2.3 2.29 2.3 2.31 ...

You can use the built-in plot method to display the exact results, along with an exponentially smoothed value of the loss.

plot(records) +
  ggplot2::coord_cartesian(ylim = c(NA, 5))
Figure: Plot displaying the results of lr_finder().

If you want to learn how to interpret the results of this plot and more about the methodology, read the learning rate finder article on the luz website.
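Once you have picked a learning rate from the plot, you can pass it to the optimizer with set_opt_hparams() before fitting. A minimal sketch, where the value 0.01 is purely illustrative:

fitted <- model %>% 
  # forward lr to optim_adam; any optimizer argument can be set this way
  set_opt_hparams(lr = 0.01) %>% 
  fit(train_ds, epochs = 10, dataloader_options = list(batch_size = 32))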

Data handling

In the first release of luz, the only kind of object allowed as input data to fit was a torch dataloader(). As of version 0.2.0, luz also supports R matrices/arrays (or nested lists of them) as input data, as well as torch dataset()s.

Supporting low-level abstractions like dataloader() as input data is important, since they give the user full control over how input data is loaded. For example, you can create parallel dataloaders, change how shuffling is done, and more. However, having to manually define the dataloader is unnecessarily tedious when you don’t need to customize any of this.
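When you do need that control, here is a minimal sketch of building the dataloader yourself, assuming train_ds is a torch dataset():

library(torch)

# shuffle every epoch and load batches in 4 parallel worker processes
train_dl <- dataloader(
  train_ds,
  batch_size = 32,
  shuffle = TRUE,
  num_workers = 4
)

The resulting train_dl can then be passed directly as the data argument of fit().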

Another small improvement in version 0.2.0, inspired by Keras, is that you can pass a value between 0 and 1 to fit’s valid_data parameter, and luz will take a random sample of that proportion from the training set to be used as validation data.
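As a sketch combining both conveniences, assuming x is an R matrix and y a numeric vector:

fitted <- modnn %>% 
  setup(
    loss = nn_mse_loss(),
    optimizer = optim_rmsprop
  ) %>% 
  set_hparams(input_size = ncol(x)) %>% 
  fit(
    # raw R objects: luz builds the datasets and dataloaders internally
    data = list(x, y),
    # hold out a random 20% of the training observations for validation
    valid_data = 0.2,
    epochs = 20
  )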

Read more about this in the documentation of the fit() function.

New callbacks

In recent releases, new built-in callbacks were added to luz:

luz_callback_gradient_clip(): helps avoid loss divergence by clipping large gradients.

luz_callback_keep_best_model(): each epoch, if there is an improvement in the monitored metric, the model weights are serialized to a temporary file; when training is done, the weights from the best model are reloaded.

luz_callback_mixup(): an implementation of mixup (Zhang et al. 2017), a data augmentation technique that helps improve model consistency and overall performance.
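As a sketch, built-in callbacks are passed to fit() through its callbacks argument (the argument values below are illustrative):

fitted <- modnn %>% 
  setup(
    loss = nn_mse_loss(),
    optimizer = optim_rmsprop
  ) %>% 
  set_hparams(input_size = 50) %>% 
  fit(
    data = list(x_train, y_train),
    valid_data = list(x_valid, y_valid),
    epochs = 20,
    callbacks = list(
      # clip the gradient norm to help avoid loss divergence
      luz_callback_gradient_clip(max_norm = 1),
      # reload the weights from the best epoch once training finishes
      luz_callback_keep_best_model(monitor = "valid_loss")
    )
  )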

Final remarks

You can see the full changelog here.

We would also like to thank everyone who contributed to these releases, especially Chris McMaster for the initial implementation of the learning rate finder.

Thank you!

Photo by Dil on Unsplash

References

Howard, Jeremy, and Sylvain Gugger. 2020. “Fastai: A Layered API for Deep Learning.” Information 11 (2): 108. https://doi.org/10.3390/info11020108.

Smith, Leslie N. 2015. “Cyclical Learning Rates for Training Neural Networks.” https://doi.org/10.48550/ARXIV.1506.01186.

Zhang, Hongyi, Moustapha Cisse, Yann N. Dauphin, and David Lopez-Paz. 2017. “Mixup: Beyond Empirical Risk Minimization.” https://doi.org/10.48550/ARXIV.1710.09412.

Reuse

Text and figures are licensed under Creative Commons Attribution CC BY 4.0. The figures that have been reused from other sources don't fall under this license and can be recognized by a note in their caption: "Figure from ...".

Citation

For attribution, please cite this work as

Falbel (2022, Aug. 24). Posit AI Blog: luz 0.3.0. Retrieved from https://blogs.rstudio.com/tensorflow/posts/2022-08-24-luz-0-3/

BibTeX citation

@misc{luz-0-3-0,
  author = {Falbel, Daniel},
  title = {Posit AI Blog: luz 0.3.0},
  url = {https://blogs.rstudio.com/tensorflow/posts/2022-08-24-luz-0-3/},
  year = {2022}
}