

Introduction
While working on text-generation applications, we ran into various problems (which we’ll explain in this blog) and concluded that we should train GPT-2 from scratch on our own dataset. In this blog, we will understand GPT-2, its applications, and when and how to train a language model from scratch. We’ll also go through the challenges and solutions associated with training the GPT-2 model with Hugging Face (HF) from scratch. The prerequisite for this blog is a basic understanding of transformers and transformer-based language models.
What is GPT?
GPTs are transformer-based autoregressive language models: they predict a word based on the prior sequence of words (the prompt) and then append the predicted word to the prompt to predict the next one. In this way, they generate text given a small text prompt.
In sequence-to-sequence tasks like translation, summarisation, and question-answering, we need two sets of transformers: one to understand, i.e., encode, the given text, for which we use encoder models like BERT; and, once the text is understood, another set to generate new text, i.e., to decode the encoded text into a different form such as a summary, a translation, or an answer to a question. For this decoding, we can use GPTs, which are decoder-based transformers.
Many pre-trained decoder-based models exist, such as GPT, GPT-2, GPT-3, ChatGPT, and Transformer-XL. We can use or finetune these models for downstream tasks like translation, summarisation, and question-answering. These decoder-based transformers are trained with the self-supervised task of predicting the next word given a sequence of words, which requires a considerable amount of data. In this blog, we will focus on GPT-2, as the basic structure and idea of all GPTs are the same; they differ mainly in architecture and size.
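As a quick illustration of this autoregressive generation, here is a minimal sketch using the pre-trained GPT-2 through the Hugging Face pipeline API (the prompt and generation settings are purely illustrative):

from transformers import pipeline

# Load the pre-trained GPT-2 as a text-generation pipeline.
generator = pipeline("text-generation", model="gpt2")

# The model repeatedly predicts the next token and appends it to the prompt.
prompt = "The central bank raised interest rates because"
outputs = generator(prompt, max_new_tokens=30, num_return_sequences=1)
print(outputs[0]["generated_text"])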
Need for Training (From Scratch)
This is a critical question: if multiple pre-trained models already exist, and training a new model is costly and increases our carbon footprint, why would we need to train a new tokenizer and model from scratch?
To answer this question, one should understand tokenisation. In tokenisation, we break words into tokens/sub-words present in the tokenizer’s vocab so that every part of the sentence can be mapped to a token id.
Let’s look at a few examples.
- “YUBI is a fintech company founded by Mr. Gaurav Kumar”. Before sending this sentence into the model, we tokenize it. The tokenized form of this sentence is [‘Y’, ‘_U’, ‘_BI’, ‘is’, ‘a’, ‘fintech’, ‘company’, ‘found’, ‘_ed’, ‘by’, ‘Mr.’, ‘Gau’, ‘_rav’, ‘Kum’, ‘_ar’] because the words ‘YUBI’, ‘founded’, ‘Gaurav’, and ‘Kumar’ are not present in the pre-trained tokenizer’s vocab, so they get split into the best possible chunks available in the vocab.
- “Azithromycin is an antibiotic”. The tokenised form of this will be [‘Az’, ‘_ithr’, ‘_rom’, ‘_y’, ‘_cin’, ‘is’, ‘an’, ‘antibiotic’].
- “Меня зовут Сан Прит” (Russian for “My name is San Prit”). The tokenized form of this sentence will be [‘<unknown>’, ‘<unknown>’, ‘<unknown>’, ‘<unknown>’, ‘<unknown>’, ‘<unknown>’, ‘<unknown>’, ‘<unknown>’, ‘<unknown>’] because no such symbols or characters are present in the current/pre-trained vocab.
Through these three examples, we can see that if our dataset comes from a different domain, language, script, or style, the tokenisation of our sentences may be full of <unknown> tokens, or words may be split into tiny tokens, even single characters, increasing the number of tokens per sentence.
With this kind of tokenization, there are two problems:
- The word may lose its meaning, as we saw in the case of ‘Azithromycin’.
- There is a limit to the number of tokens we can pass into GPT, so we have to truncate the input after a certain length. If the token count becomes very high just because of bad tokenization and we truncate the sentence, the sentence may become incomplete and illogical. (The short snippet below lets you check how your own text gets tokenised.)
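Here is a small sketch you can use to check how the pre-trained GPT-2 tokenizer splits the example sentences above (the exact splits depend on the tokenizer version; the point is how many pieces out-of-vocabulary words produce):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
sentences = [
    "YUBI is a fintech company founded by Mr. Gaurav Kumar",
    "Azithromycin is an antibiotic",
    "Меня зовут Сан Прит",
]
for text in sentences:
    tokens = tokenizer.tokenize(text)
    # Domain-specific and non-English words explode into many small pieces.
    print(f"{len(tokens):3d} tokens: {tokens}")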
Since we now have a new kind of dataset and a new tokenizer, and it is not advisable to use pre-trained models in this scenario because their weights have never seen this kind of data, we decided to train our tokenizer and model from scratch. This applies not only to GPT-2 but to any transformer-based pre-trained model.
You can train your own tokenizer and model whenever you are working with a transformer-based model on a new, specific data domain (health care, EdTech, FinTech, sports, etc.).
What are Hugging Face and Fairseq?
Hugging Face is a community and data science platform that provides the following:
- Tools that enable users to build, train, and deploy ML models based on open-source code and technologies.
- A place where a broad community of data scientists, researchers, and ML engineers can come together and share ideas, get support, and contribute to open-source projects.
Fairseq is an open-source sequence modelling toolkit that allows researchers and developers to train custom models for translation, summarisation, language modelling, and other text generation tasks. The toolkit is based on PyTorch and supports distributed training across multiple GPUs and machines.
Exploring LLM Training With Hugging Face
Once we are convinced that we have to train our new tokenizer and model, we will focus on training GPT-2 with Hugging Face. Training any machine learning or deep learning model begins with data collection, preprocessing, and/or labeling. In this case, we will also need text data of that particular language/domain for which we are training our new model.
We have trained GPT-2 in the fintech domain for English and 13 Indian languages, e.g., Hindi, Telugu, Marathi, Tamil, Gujarati, Bengali, Assamese, etc.
For this, we extracted financial/business news in all the mentioned languages, plus the whole of English Wikipedia to teach the model general English. In total, we collected 240 GB of text data across the different languages (the largest share was English and the smallest was Assamese). For comparison, the original GPT-2 was trained on 40 GB of text from web pages linked in Reddit posts that had received at least three upvotes (source: https://en.wikipedia.org/wiki/GPT-2).
Once we gathered the data, we preprocessed it a bit. In preprocessing, we removed short and illogical sentences, since for a generative model we want to teach only contextual sentences. We also removed the headers and footers of the web pages so that only the news articles remained.
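The sketch below shows the kind of filtering described above; the minimum word count, the boilerplate patterns, and the file names are illustrative assumptions, not our exact rules:

def clean_corpus(lines, min_words=5):
    """Yield lines that look like real sentences, dropping short fragments
    and obvious header/footer boilerplate."""
    boilerplate = ("subscribe", "all rights reserved", "privacy policy")
    for line in lines:
        line = line.strip()
        if len(line.split()) < min_words:
            continue  # too short to carry useful context
        if any(pattern in line.lower() for pattern in boilerplate):
            continue  # header/footer style noise
        yield line

with open("raw_corpus.txt", encoding="utf-8") as f_in, \
     open("clean_corpus.txt", "w", encoding="utf-8") as f_out:
    for sentence in clean_corpus(f_in):
        f_out.write(sentence + "\n")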
Training
We can easily train a tokenizer with Hugging Face on our custom data:

from transformers import AutoTokenizer

old_tokenizer = AutoTokenizer.from_pretrained("gpt2")
new_tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 52000)
new_tokenizer.save_pretrained("tokenizer_folder/new_tokenizer")

# How to load it back
tokenizer = AutoTokenizer.from_pretrained("tokenizer_folder/new_tokenizer")
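Here, training_corpus is expected to be an iterator that yields batches of raw text. A minimal sketch of how it might be built from the cleaned corpus file (the file name and batch size are illustrative):

def get_training_corpus(path="clean_corpus.txt", batch_size=1000):
    # Yield the corpus in batches of raw text lines so the whole file
    # never has to sit in memory at once.
    batch = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            batch.append(line.strip())
            if len(batch) == batch_size:
                yield batch
                batch = []
    if batch:
        yield batch

training_corpus = get_training_corpus()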
This will train and save a new tokenizer similar to the GPT-2 tokenizer. In our case, we trained our tokenizer with Fairseq, as we needed the same tokenizer while training various transformer models, not just with Hugging Face but also with Fairseq.
But Hugging Face doesn’t support the fairseq tokenizer directly (refer to this); it mis-tokenizes text when given a fairseq vocabulary, because the HF tokenizer then does not care whether a particular token comes from the start or the middle of a word. So we had to make the Hugging Face tokenizer compatible with fairseq using this script, which checks whether a token comes from the start or the middle of a word before splitting into tokens.
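To make the mismatch concrete, here is a small check (this is not the compatibility script itself, just an illustration of the word-start marker):

from transformers import AutoTokenizer

gpt2_tok = AutoTokenizer.from_pretrained("gpt2")

# GPT-2's byte-level BPE encodes the leading space into the token itself
# (shown as a "Ġ" prefix), so word-initial and word-internal pieces are
# different vocabulary entries. Sentencepiece/fairseq-style tokenizers mark
# word starts with "▁" instead. If that marker is lost when loading a fairseq
# vocab into HF, the same piece can be applied at the wrong position in a word.
print(gpt2_tok.tokenize("fintech company"))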
Now, coming to the actual model training: HF provides scripts to train transformer models, including GPT-2. But we found an issue with the script. Let’s understand the issue with an example. Consider these lines of text in our training dataset:
- The government strives to train 10 lakh youth per annum through apprenticeship training and to fulfill this mission
- Such disruptions include higher-than-expected inflation, sudden spikes in interest rates to contain surging prices or a pandemic resurgence.
- यस बैंक समेत कई बैंकों ने FD पर ब्याज दरों में बढ़ोतरी की है। बता दें कि पिछले साल रिजर्व बैंक ऑफ इंडिया (RBI) ने रेपो रेट में बढ़ोतरी की थी।
- সেঞ্চুরি করার পর ইনিংসের বিরতিতে কোহলি বলেছিলেন, তিনি নাকি এই সিরিজের আগে যে বিশ্রাম পেয়েছেন, সেটাই ভালো পারফরম্যান্স করতে কাজে লেগেছে।
- राज्यातील सहकारी पतसंस्थांना आता बँकांमध्ये केलेल्या गुंतवणुकीवर मिळणाऱ्या व्याजावर यापुढे प्राप्तिकर द्यावा लागणार नाही. प्राप्तिकर अपिलीय प्राधिकरणाच्या पुणे खंडपीठाने दिलेल्या या निर्णयामुळे सहकारी पतसंस्थांना दिलासा मिळाला आहे
- நமக்கு கிடைக்கும் சிறிய வாய்ப்பை கூட பயன்படுத்திக் கொள்ள வேண்டும்.அப்போதுதான் கிரிக்கெட்டில் வெற்றி கிடைக்கும். உங்களுக்கு வெற்றி கிடைத்தால் அதில் அனைத்தும் சரியாக அமையாது.
Here we can see that each line is independent of the others (in fact, if we randomly shuffle the dataset, consecutive lines may even be in different languages), so each line should be treated independently. But after tokenization, the HF training script stacks all tokens into one long sequence and, since the model needs a fixed input size (512 in our case), forms groups of 512 consecutive tokens. As a result, a single group of 512 tokens can contain tokens from more than one line (possibly from different languages). The model will then try to learn context across unrelated lines, which we think is not a good practice.
Our suggestion is not to mix tokens from different lines unless the lines actually depend on each other (like a story). In our case, the lines come from different contexts, so within a group of 512 tokens we never take tokens from different lines. If the number of tokens in a line is not a multiple of 512, we pad the last group of tokens with PAD_TOKEN_ID. That’s why we modified the script accordingly.
Interestingly, their TensorFlow script does take a line_by_line argument indicating whether to consider each line independently, but it is not used while grouping the tokens.
Code Changes
We can achieve the above logic with a slight change in the Hugging Face code.
#old code
def group_texts(tokenized_dataset):
    concatenated_tokens = {k: list(chain(*tokenized_dataset[k])) for k in tokenized_dataset.keys()}
    total_length = len(concatenated_tokens[list(tokenized_dataset.keys())[0]])
    # We drop the small remainder; we could add padding if the model supported it
    # instead of this drop. You can customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_tokens.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
#new code
def group_texts(tokenized_dataset):
    # Divide lines longer than 512 tokens into chunks of 512 tokens and pad the
    # last chunk with the pad token id if its length is not a multiple of 512.
    # grouping_in_chunks is defined below.
    tokenized_dataset = grouping_in_chunks(tokenized_dataset, block_size)
    concatenated_tokens = {k: list(chain(*tokenized_dataset[k])) for k in tokenized_dataset.keys()}
    total_length = len(concatenated_tokens[list(tokenized_dataset.keys())[0]])
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_tokens.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


##defining grouping_in_chunks function
def grouping_in_chunks(examples, block_size):
    examples_2 = {"input_ids": [], "attention_mask": []}
    for key in examples.keys():
        for item in examples[key]:
            total_length = len(item)
            if total_length < block_size:
                # If the total length of tokens in a sentence is less than 512,
                # pad it with the pad token id (0 here) up to 512 tokens.
                examples_2[key].append(item.copy())
                examples_2[key][-1].extend(0 for _ in range(block_size - total_length))
            else:
                # Split longer sentences into 512-token chunks, padding the last one.
                no_of_chunks = total_length // block_size
                if (total_length % block_size) != 0:
                    no_of_chunks += 1
                remaining_length = (no_of_chunks * block_size) - total_length
                final_list = item.copy()
                final_list.extend(0 for _ in range(remaining_length))
                for i in range(no_of_chunks):
                    examples_2[key].append(final_list[i * block_size:(i + 1) * block_size])
    return examples_2
Also, since we have multi-domain/multi-language sentences, we decided to randomise the training dataset: if all the data of a new language/domain comes only after all the data of another, there will be a sudden shift in the weights, which may hurt learning. It is better to mix sentences of all domains/languages randomly so that the model sees all types of data together and there is no significant jump in the weights in the middle of training.
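With a Hugging Face Dataset, this randomisation is straightforward; a minimal sketch, assuming the cleaned corpus sits in a single text file (the file name and seed are illustrative):

from datasets import load_dataset

# Load the corpus as a line-by-line text dataset and shuffle it so that
# languages and domains are interleaved rather than seen in long runs.
dataset = load_dataset("text", data_files={"train": "clean_corpus.txt"})
dataset["train"] = dataset["train"].shuffle(seed=42)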
After preprocessing the data and building the tokenizer in Hugging Face format, we are ready to run the training script provided by Hugging Face, i.e., run_clm.py. You can learn about the training arguments from here. After making the changes mentioned above in the run_clm script, we renamed it run_clm_modified.py (this name is used later).
In this training, the dataset type is the Hugging Face Dataset (a dict-like dataset). We advise against converting to a pandas DataFrame or any other data structure at any point, because the Hugging Face dataset is designed to be processed in batches. With any other dataset type, you carry the extra load of writing your own batching logic, and if you do not process the data in batches, training becomes almost impossible: it will fill the entire memory of the instance, especially with the large datasets needed to train any transformer model.
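A minimal sketch of the batched tokenisation workflow this enables (similar in spirit to what the run_clm script does internally), reusing the tokenizer and dataset names from the snippets above; the "text" column name comes from the datasets text loader:

def tokenize_function(examples):
    # Tokenise a whole batch of lines at once; the datasets library streams
    # batches from disk (Arrow) instead of loading everything into RAM.
    return tokenizer(examples["text"])

tokenized_dataset = dataset["train"].map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)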
During training, a significant focus should be on using multiple GPUs. With the Hugging Face script, we observed that the GPUs were not being utilised 100% continuously; GPU memory usage fluctuated heavily. To keep all GPUs fully utilised at all times, instead of this command:
python3 transformers/examples/pytorch/language-modeling/run_clm_modified.py ...
we ran this command:
python3 -m torch.distributed.launch --nproc_per_node 8 transformers/examples/pytorch/language-modeling/run_clm_modified.py ...
Here, 8 is the number of available GPUs; replace it with however many GPUs your instance has (4, 16, etc.).
Challenges, Solutions, and Suggestions
- A large volume of data is required to train any large-architecture model, and GPT-2 is no exception. You will have to gather a vast amount of relevant text data.
- A cache folder is created during the tokenisation and grouping stages of training, and it becomes enormous: for training text files of around 20 MB, the final cache folder is about 140 GB. Even more interesting, midway through tokenisation and grouping the cache grows to about 500 GB; only after that does the script delete the unnecessary cache, bringing it back down to 140 GB. So you need roughly 500 GB of disk to train even on 20 MB of data, because that much is required midway. From this observation, you can estimate how much disk you will need to train your own GPT-2 from scratch on your data.
- If disk size or memory is less than required, you will get a ‘broken pipe error’.
- There are arguments (--per_device_train_batch_size and --per_device_eval_batch_size) to set the train and eval batch sizes, but when we tried to set the batch sizes with these arguments alone, we got a batch-size-related error. When we added one more argument (--auto_find_batch_size True) along with them, the batch sizes were set according to the first two arguments. So, set whatever batch sizes fit your GPU memory with the first two arguments and pass this third argument along with them (see the example command after this list).
- When we tried to use the maximum number of CPU threads (for tokenisation and grouping of tokens), we got a broken pipe error. After multiple attempts, we could only use around 60% of the available threads.
- Another challenge with Hugging Face is training time, which is very high compared to Fairseq. With 20 MB of data, one iteration takes around 8 days.
- Another challenge is documentation: there is very little documentation on training with Hugging Face, so your only option is to play with the scripts. If something is not working, you will not find a solution quickly, and the GitHub issues section could be more responsive.
- We suggest not keeping a large validation set or a high validation frequency (i.e., frequent validation steps), because it takes extra time.
- If you are using any other tokenizer, ensure it is tokenising the words correctly and mapping with the correct token ids.
- After running the training script, ensure all GPUs are utilised equally and up to their full potential.
- The most critical aspect of training any transformer model from scratch is COST. Earlier, we decided to train on 4 GPUs with 24 GB of memory each. The approximate cost for this instance on AWS is $150/day; on Lambda Labs, it was $108/day, and Lambda Labs GPUs are faster. For 238 GB of data, one epoch would take 97 days on AWS and 36 days on Lambda Labs. In a nutshell, training the model on 238 GB of data for 1 epoch would cost ~$15,000 on AWS and ~$4,000 on Lambda Labs.
- Even at that cost, training takes a very long time, so we finally decided to train on 8 GPUs with 40 GB of memory each. That instance costs $15,000/month plus taxes, and with it, training takes around 10 days, which works out to about $10,000 for just 1 epoch.
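For reference, the batch-size arguments mentioned in the list above would be passed to the training command like this (the value 8 is illustrative; pick whatever fits your GPU memory):

python3 -m torch.distributed.launch --nproc_per_node 8 \
    transformers/examples/pytorch/language-modeling/run_clm_modified.py \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --auto_find_batch_size True \
    ...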
Conclusion
Considering this experience, you can gauge when it is worth training a new transformer model from scratch, as it requires enormous amounts of data, brings huge costs, and affects the environment. So, if you do want to train from scratch, be very specific about your need, and reread the challenges and solutions section so you are prepared; otherwise, fixing these issues will take a lot of time and may break your momentum. Given how little documentation Hugging Face offers for this, we hope this article helps you train GPT-2 from scratch with fewer issues in hand.
References
- https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling
- https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
- https://huggingface.co/blog/how-to-train
- https://github.com/huggingface/tokenizers/blob/main/bindings/python/scripts/sentencepiece_extractor.py
- https://github.com/Yubi2Community/YubiAI/tree/master/nlp/tokenizer
Want to learn more?
- https://www.go-yubi.com/blog/yubibert-a-tiny-fintech-language-model/
- https://github.com/Yubi2Community/YubiAI/tree/master/yubiai