Nowadays, Microsoft, Google, Facebook, and OpenAI are sharing lots of state-of-the-art models in the field of Natural Language Processing. However, far less material exists on how to use these models from R. In this post, we will show how R users can access and benefit from these models as well.
The Transformers repository from Hugging Face contains many ready-to-use, state-of-the-art models, which are straightforward to download and fine-tune with TensorFlow and Keras.
For this purpose, users usually need to get the model itself, its pre-trained weights, and the matching tokenizer.
In this post, we will work on a classic binary classification task and fine-tune 3 models on our dataset: GPT-2, RoBERTa, and Electra.
However, readers should know that transformers can be used for a wide variety of downstream tasks, not only classification.
Our first job is to install the transformers package via reticulate:
reticulate::py_install('transformers', pip = TRUE)
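When the code is re-run, reinstalling the package each time is unnecessary. A small variation, sketched here as an option rather than a required step, installs transformers only if Python cannot import it yet, using reticulate's `py_module_available()`:

```r
library(reticulate)

# install transformers into the active Python environment only when it is missing
if (!py_module_available("transformers")) {
  py_install("transformers", pip = TRUE)
}
```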
Then, as usual, load the standard ‘keras’ and ‘tensorflow’ packages (TensorFlow >= 2.0) and some other classic R libraries, and import the transformers module.
library(keras)
library(tensorflow)
library(dplyr)
library(tfdatasets)
transformer = reticulate::import('transformers')
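Note that when running TensorFlow on a GPU, one can set the following options to avoid memory issues. The `if` check skips the GPU setting on machines without a GPU.
physical_devices = tf$config$list_physical_devices('GPU')
# let TensorFlow allocate GPU memory as needed instead of reserving all of it up front
if (length(physical_devices) > 0) {
  tf$config$experimental$set_memory_growth(physical_devices[[1]], TRUE)
}
tf$keras$backend$set_floatx('float32')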
As mentioned above, to train on a specific model, users need to download the model, its tokenizer, and its weights. For example, to get a RoBERTa model, one has to do the following:
# get Tokenizer
transformer$RobertaTokenizer$from_pretrained('roberta-base', do_lower_case=TRUE)
# get Model with weights
transformer$TFRobertaModel$from_pretrained('roberta-base')
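To see what the tokenizer actually produces, here is a short sketch that encodes a made-up sentence into token indices and decodes it back. `encode()` and `decode()` are standard tokenizer methods in transformers; the sentence and the `max_length` value are arbitrary choices for illustration.

```r
library(reticulate)
transformer = import('transformers')

# load (or download on first use) the RoBERTa tokenizer
tokenizer = transformer$RobertaTokenizer$from_pretrained('roberta-base')

# encode a made-up sentence into token indices; special tokens are added by default
ids = tokenizer$encode('R users can fine-tune transformers too.',
                       max_length = 50L, truncation = TRUE)
print(ids)

# decode the indices back to text, dropping the special tokens
decoded = tokenizer$decode(ids, skip_special_tokens = TRUE)
print(decoded)
```

These integer indices, not the raw text, are what the models receive as input.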
A dataset for binary classification is provided in the text2vec package. Let’s load the dataset and take a sample for faster model training.
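One way to do this is the following minimal sketch. It assumes text2vec's `movie_review` data (columns `id`, `sentiment`, `review`), renamed to the `comment_text` and `target` columns that the training loop reads later. The sample size of 2,000 and the seed are arbitrary choices.

```r
library(dplyr)
library(text2vec)

# movie_review ships with text2vec: IMDB reviews with a 0/1 sentiment label
data("movie_review", package = "text2vec")

set.seed(42)
# rename to the column names used later on and keep a small random sample
df = movie_review %>%
  rename(target = sentiment, comment_text = review) %>%
  sample_n(2000)
```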
Next, split the data into a training set (80%) and a test set (20%):
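# randomly draw 80% of the row indices for training
idx_train = sample.int(nrow(df), size = floor(nrow(df) * 0.8))
train = df[idx_train,]
# negative indices keep all rows that were not drawn for training
test = df[-idx_train,]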
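A quick base-R check of how such a split behaves, on a hypothetical toy data frame (`toy_df` and the other `toy_` names are placeholders, chosen so the real `df` is not overwritten). It also shows why negative indexing is needed: `!` applied to a vector of positive integers yields all `FALSE`, which selects no rows at all.

```r
set.seed(1)
# hypothetical toy data frame standing in for the real df
toy_df = data.frame(comment_text = paste("review", 1:100),
                    target = rep(0:1, 50))

# draw 80% of the row indices at random, without replacement
toy_idx = sample.int(nrow(toy_df), size = floor(nrow(toy_df) * 0.8))
toy_train = toy_df[toy_idx, ]
# negative indices drop the training rows, keeping the remaining 20%
toy_test = toy_df[-toy_idx, ]

nrow(toy_train)  # 80
nrow(toy_test)   # 20
# no row ends up in both parts
length(intersect(toy_train$comment_text, toy_test$comment_text))  # 0
# by contrast, `!` turns positive indices into all-FALSE and selects nothing
nrow(toy_df[!toy_idx, ])  # 0
```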
So far, we have only covered data import and the train-test split. To feed input to the network, we have to turn the raw text into token indices with the imported tokenizer, and then adapt the model to binary classification by adding a dense layer with a single unit at the end.
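Because texts have different lengths, their token indices also have to be arranged into a matrix with a fixed number of columns before they can go into a Keras input layer of shape `max_len`. A base-R sketch of that step, using made-up token ids and a hypothetical helper `to_index_matrix()`. Padding with 0 is an assumption made for illustration and does not correspond to any particular tokenizer's padding token.

```r
# hypothetical token indices for three texts of different lengths
ids_list = list(c(101L, 7L, 42L), c(101L, 5L), c(101L, 9L, 13L, 27L, 3L))

# one row per text: shorter rows are filled with pad_id, longer rows are cut at max_len
to_index_matrix = function(ids_list, max_len, pad_id = 0L) {
  x = matrix(pad_id, nrow = length(ids_list), ncol = max_len)
  for (j in seq_along(ids_list)) {
    ids = head(as.integer(ids_list[[j]]), max_len)
    x[j, seq_along(ids)] = ids
  }
  x
}

m = to_index_matrix(ids_list, 4L)
m  # row 2 becomes 101 5 0 0, row 3 is cut to 101 9 13 27
```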
Since we want to train 3 models (GPT-2, RoBERTa, and Electra), we need to write a loop.
Note: a single model generally requires 500-700 MB.
# list of 3 models
ai_m = list(
c('TFGPT2Model', 'GPT2Tokenizer', 'gpt2'),
c('TFRobertaModel', 'RobertaTokenizer', 'roberta-base'),
c('TFElectraModel', 'ElectraTokenizer', 'google/electra-small-generator')
)
# parameters
max_len = 50L
epochs = 2
batch_size = 10
# create a list for model results
gather_history = list()
for (i in seq_along(ai_m)) {
  model_class = ai_m[[i]][1]
  tokenizer_class = ai_m[[i]][2]
  weights_name = ai_m[[i]][3]

  # tokenizer: look the class up by name on the transformers module
  tokenizer = reticulate::py_get_attr(transformer, tokenizer_class)$from_pretrained(
    weights_name, do_lower_case = TRUE)
  # model with pre-trained weights
  model_ = reticulate::py_get_attr(transformer, model_class)$from_pretrained(weights_name)

  # inputs: a (documents x max_len) matrix of token indices, short sequences
  # padded with 0 (no attention mask is passed, so padding is averaged in too)
  # outputs: a one-column matrix of labels
  data_prep = function(data) {
    text = matrix(0L, nrow = nrow(data), ncol = max_len)
    for (j in seq_len(nrow(data))) {
      ids = tokenizer$encode(data[['comment_text']][j], max_length = max_len,
                             truncation = TRUE) %>%
        unlist() %>% as.integer()
      text[j, seq_along(ids)] = ids
    }
    label = matrix(as.integer(data[['target']]), ncol = 1)
    list(text, label)
  }
  train_ = data_prep(train)
  test_ = data_prep(test)

  # slice dataset: shuffle individual examples, then batch;
  # fit() already makes one pass over the data per epoch
  tf_train = tensor_slices_dataset(list(train_[[1]], train_[[2]])) %>%
    dataset_shuffle(128) %>%
    dataset_batch(batch_size = batch_size, drop_remainder = TRUE) %>%
    dataset_prefetch(tf$data$experimental$AUTOTUNE)
  tf_test = tensor_slices_dataset(list(test_[[1]], test_[[2]])) %>%
    dataset_batch(batch_size = batch_size)

  # create an input layer
  input = layer_input(shape = c(max_len), dtype = 'int32')
  # average the last hidden states over the sequence axis, then a dense layer
  hidden_mean = tf$reduce_mean(model_(input)[[1]], axis = 1L) %>%
    layer_dense(64, activation = 'relu')
  # create an output layer for binary classification
  output = hidden_mean %>% layer_dense(units = 1, activation = 'sigmoid')
  model = keras_model(inputs = input, outputs = output)

  # compile with AUC score
  model %>% compile(
    optimizer = tf$keras$optimizers$Adam(learning_rate = 3e-5, epsilon = 1e-08, clipnorm = 1.0),
    loss = tf$losses$BinaryCrossentropy(from_logits = FALSE),
    metrics = tf$metrics$AUC())

  print(model_class)
  # train the model
  history = model %>% keras::fit(tf_train, epochs = epochs, validation_data = tf_test)
  gather_history[[i]] = history
  names(gather_history)[i] = model_class
}
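The classification head built inside the loop averages the model's last hidden states over the sequence axis and passes the result to dense layers. A base-R sketch of the same arithmetic on a random toy tensor shows what the final single sigmoid unit computes. The sizes, weights, and names here are hypothetical, and the intermediate 64-unit ReLU layer is skipped for brevity.

```r
set.seed(123)
n_docs = 2; n_tokens = 5; n_hidden = 8   # toy sizes

# toy stand-in for the last hidden state: docs x tokens x hidden units
h = array(rnorm(n_docs * n_tokens * n_hidden), dim = c(n_docs, n_tokens, n_hidden))

# average over the token axis (R's 2nd dimension, i.e. axis = 1L in TensorFlow)
pooled = apply(h, c(1, 3), mean)          # n_docs x n_hidden

# one dense unit with a sigmoid activation: p = 1 / (1 + exp(-(x %*% w + b)))
w = rnorm(n_hidden)
b = 0
sigmoid = function(z) 1 / (1 + exp(-z))
p = sigmoid(pooled %*% w + b)             # n_docs x 1, values in (0, 1)
p
```

Each entry of `p` is read as the predicted probability that the corresponding document belongs to the positive class, which is what `BinaryCrossentropy(from_logits = FALSE)` expects.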
Extract results to see the benchmarks:
res = sapply(1:3, function(x) {
do.call(rbind,gather_history[[x]][["metrics"]]) %>%
as.data.frame() %>%
tibble::rownames_to_column() %>%
mutate(model_names = names(gather_history[x]))
}, simplify = F) %>% do.call(plyr::rbind.fill,.) %>%
mutate(rowname = stringr::str_extract(rowname, 'loss|val_loss|auc|val_auc')) %>%
rename(epoch_1 = V1, epoch_2 = V2)
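Why the regular expression? Keras typically gives each additional AUC metric created in the same R session a numeric suffix (`auc_1`, `auc_2`, ...), so the raw row names differ from model to model. A base-R sketch of the same normalization, applied to hypothetical metric names rather than names from an actual run:

```r
# metric names as they might appear for later models trained in one session
# (illustrative values, not taken from an actual run)
metric_names = c("loss", "auc_1", "val_loss", "val_auc_1", "auc_2", "val_auc_2")

# same pattern as in the code above: keep only the first match in each name
normalized = regmatches(metric_names,
                        regexpr("loss|val_loss|auc|val_auc", metric_names))
print(normalized)
# -> loss, auc, val_loss, val_auc, auc, val_auc
```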
Both RoBERTa and Electra show some additional improvement after 2 epochs of training, which cannot be said of GPT-2. This suggests that, for a task like this one, fine-tuning a state-of-the-art model for even a single epoch can be enough.
In this post, we showed how to use state-of-the-art NLP models from R. To understand how to apply them to more complex tasks, it is highly recommended to review the transformers tutorial.
We encourage readers to try out these models and share their results below in the comments section!
Text and figures are licensed under Creative Commons Attribution CC BY 4.0. Source code is available at https://github.com/henry090/transformers, unless otherwise noted. The figures that have been reused from other sources don't fall under this license and can be recognized by a note in their caption: "Figure from ...".
For attribution, please cite this work as
Abdullayev (2020, July 30). Posit AI Blog: State-of-the-art NLP models from R. Retrieved from https://blogs.rstudio.com/tensorflow/posts/2020-07-30-state-of-the-art-nlp-models-from-r/
BibTeX citation
@misc{abdullayev2020state-of-the-art, author = {Abdullayev, Turgut}, title = {Posit AI Blog: State-of-the-art NLP models from R}, url = {https://blogs.rstudio.com/tensorflow/posts/2020-07-30-state-of-the-art-nlp-models-from-r/}, year = {2020} }