Demo: Fine-Tuning NM Results Management Language Model with a Custom Dataset

In this notebook, we use a sample of 10 radiology reports to show how to preprocess the data, load the NM Results Management language model checkpoints, and fine-tune them on your in-house data.

Load the Data

First, we load the data. In this example, we train the model on a three-class classification problem: determining whether a report contains lung findings, adrenal findings, or no findings.

[1]:
import os
import joblib
from IPython.display import display, HTML

# Define the path to the data (notebooks do not define __file__, so resolve relative to the working directory)
base_path = os.getcwd()
data_path = os.path.abspath(os.path.join(base_path, "..", "demo_data.gz"))

# Import data
modeling_df = joblib.load(data_path)

display(HTML(modeling_df.head(3).to_html()))
index: 0
rpt_num: 1
note: PROCEDURE: CT CHEST WO CONTRAST. HISTORY: Wheezing TECHNIQUE: Non-contrast helical thoracic CT was performed. COMPARISON: There is no prior chest CT for comparison. FINDINGS: Support Devices: None. Heart/Pericardium/Great Vessels: Cardiac size is normal. There is no calcific coronary artery atherosclerosis. There is no pericardial effusion. The aorta is normal in diameter. The main pulmonary artery is normal in diameter. Pleural Spaces: Few small pleural calcifications are present in the right pleura for example on 2/62 and 3/76. The pleural spaces are otherwise clear. Mediastinum/Hila: There is no mediastinal or hilar lymph node enlargement. Subcentimeter minimally calcified paratracheal lymph nodes are likely related to prior granulomas infection. Neck Base/Chest Wall/Diaphragm/Upper Abdomen: There is no supraclavicular or axillary lymph node enlargement. Limited, non-contrast imaging through the upper abdomen is within normal limits. Mild degenerative change is present in the spine. Lungs/Central Airways: There is a 15 mm nodular density in the nondependent aspect of the bronchus intermedius on 2/52. The trachea and central airways are otherwise clear. There is mild diffuse bronchial wall thickening. There is a calcified granuloma in the posterior right upper lobe. The lungs are otherwise clear. CONCLUSIONS: 1. There is mild diffuse bronchial wall thickening suggesting small airways disease such as asthma or bronchitis in the appropriate clinical setting. 2. A 3 mm nodular soft tissue attenuation in the nondependent aspect of the right bronchus intermedius is nonspecific, which could be mucus or abnormal soft tissue. A follow-up CT in 6 months might be considered to evaluate the growth. 3. Stigmata of old granulomatous disease is present. FINAL REPORT Attending Radiologist:
selected_finding: Lung Findings
selected_proc: CT Chest
selected_label: A 3 mm nodular soft tissue attenuation in the nondependent aspect of the right bronchus intermedius is nonspecific, which could be mucus or abnormal soft tissue. A follow-up CT in 6 months might be considered to evaluate the growth.
new_note: support devices: none. heart/pericardium/great vessels: cardiac size is normal. there is no calcific coronary artery atherosclerosis. there is no pericardial effusion. the aorta is normal in diameter. the main pulmonary artery is normal in diameter. pleural spaces: few small pleural calcifications are present in the right pleura for example on 2/62 and 3/76. the pleural spaces are otherwise clear. mediastinum/hila: there is no mediastinal or hilar lymph node enlargement. subcentimeter minimally calcified paratracheal lymph nodes are likely related to prior granulomas infection. neck base/chest wall/diaphragm/upper abdomen: there is no supraclavicular or axillary lymph node enlargement. limited, non-contrast imaging through the upper abdomen is within normal limits. mild degenerative change is present in the spine. lungs/central airways: there is a 15 mm nodular density in the nondependent aspect of the bronchus intermedius on 2/52. the trachea and central airways are otherwise clear. there is mild diffuse bronchial wall thickening. there is a calcified granuloma in the posterior right upper lobe. the lungs are otherwise clear. conclusions: 1. there is mild diffuse bronchial wall thickening suggesting small airways disease such as asthma or bronchitis in the appropriate clinical setting. 2. a 3 mm nodular soft tissue attenuation in the nondependent aspect of the right bronchus intermedius is nonspecific, which could be mucus or abnormal soft tissue. a follow-up ct in 6 months might be considered to evaluate the growth. 3. stigmata of old granulomatous disease is present.

index: 1
rpt_num: 2
note: PROCEDURE: CT ABDOMEN PELVIS W CONTRAST COMPARISON: date INDICATIONS: Lower abdominal/flank pain on the right TECHNIQUE: After obtaining the patients consent, CT images were created with intravenous iodinated contrast. FINDINGS: LIVER: The liver is normal in size. No suspicious liver lesion is seen. The portal and hepatic veins are patent. BILIARY: No biliary duct dilation. The biliary system is otherwise unremarkable. PANCREAS: No focal pancreatic lesion. No pancreatic duct dilation. SPLEEN: No suspicious splenic lesion is seen. The spleen is normal in size. KIDNEYS: No suspicious renal lesion is seen. No hydronephrosis. ADRENALS: No adrenal gland nodule or thickening. AORTA/VASCULAR: No aneurysm. RETROPERITONEUM: No lymphadenopathy. BOWEL/MESENTERY: The appendix is normal. No bowel wall thickening or bowel dilation. ABDOMINAL WALL: No hernia. URINARY BLADDER: Incomplete bladder distension limits evaluation, but no focal wall thickening or calculus is seen. PELVIC NODES: No lymphadenopathy. PELVIC ORGANS: Status post hysterectomy. No pelvic mass. BONES: No acute fracture or suspicious osseous lesion. LUNG BASES: No pleural effusion or consolidation. OTHER: Small hiatal hernia. CONCLUSION: 1. No acute process is detected. 2. Small hiatal hernia FINAL REPORT Attending Radiologist:
selected_finding: No Findings
selected_proc: NaN
selected_label: No label
new_note: liver: the liver is normal in size. no suspicious liver lesion is seen. the portal and hepatic veins are patent. biliary: no biliary duct dilation. the biliary system is otherwise unremarkable. pancreas: no focal pancreatic lesion. no pancreatic duct dilation. spleen: no suspicious splenic lesion is seen. the spleen is normal in size. kidneys: no suspicious renal lesion is seen. no hydronephrosis. adrenals: no adrenal gland nodule or thickening. aorta/vascular: no aneurysm. retroperitoneum: no lymphadenopathy. bowel/mesentery: the appendix is normal. no bowel wall thickening or bowel dilation. abdominal wall: no hernia. urinary bladder: incomplete bladder distension limits evaluation, but no focal wall thickening or calculus is seen. pelvic nodes: no lymphadenopathy. pelvic organs: status post hysterectomy. no pelvic mass. bones: no acute fracture or suspicious osseous lesion. lung bases: no pleural effusion or consolidation. other: small hiatal hernia. conclusion: 1. no acute process is detected. 2. small hiatal hernia

index: 2
rpt_num: 3
note: EXAM: MRI ABDOMEN W WO CONTRAST CLINICAL INDICATION: Cirrhosis of liver without ascites, unspecified hepatic cirrhosis type (CMS-HCC) TECHNIQUE: MRI of the abdomen was performed with and without contrast. Multiplanar imaging was performed. 8.5 cc of Gadavist was administered. COMPARISON: DATE and priors FINDINGS: On limited views of the lung bases, no acute abnormality is noted. There may be mild distal esophageal wall thickening. On the out of phase series, there is suggestion of some signal gain within the hepatic parenchyma. This is stable. A tiny cystic nonenhancing focus is seen anteriorly in the right hepatic lobe (9/10), unchanged. A subtly micronodular hepatic periphery is noted. There are few subtle hypervascular lesions in the right hepatic lobe, without significant washout. The portal vein is patent. Some splenorenal shunting is redemonstrated, similar to the comparison exam. The spleen measures 12.4 cm in length. No focal splenic lesion is appreciated. There are several small renal lesions again seen, many of which again demonstrate T1 shortening. On the postcontrast subtraction series, no obvious enhancement is noted. The adrenal glands and pancreas are intact. There is mild cholelithiasis, without gallbladder wall thickening or pericholecystic fluid. No free abdominal fluid is visualized. IMPRESSION: 1. Stable cirrhotic appearance of the liver. Few subtly hypervascular hepatic lesions do not demonstrate washout, and probably relate to perfusion variants. No particularly suspicious hepatic mass is seen. 2. Mild splenomegaly to 12.4 cm redemonstrated. Splenorenal shunting is again seen. 3. Scattered simple and complex renal cystic lesions, nonenhancing, stable from March 2040. 4. Incidentally, there is evidence of signal gain in the liver on the out of phase series. This occasionally may represent iron overload. FINAL REPORT Attending Radiologist:
selected_finding: No Findings
selected_proc: NaN
selected_label: No label
new_note: on limited views of the lung bases, no acute abnormality is noted. there may be mild distal esophageal wall thickening. on the out of phase series, there is suggestion of some signal gain within the hepatic parenchyma. this is stable. a tiny cystic nonenhancing focus is seen anteriorly in the right hepatic lobe (9/10), unchanged. a subtly micronodular hepatic periphery is noted. there are few subtle hypervascular lesions in the right hepatic lobe, without significant washout. the portal vein is patent. some splenorenal shunting is redemonstrated, similar to the comparison exam. the spleen measures 12.4 cm in length. no focal splenic lesion is appreciated. there are several small renal lesions again seen, many of which again demonstrate t1 shortening. on the postcontrast subtraction series, no obvious enhancement is noted. the adrenal glands and pancreas are intact. there is mild cholelithiasis, without gallbladder wall thickening or pericholecystic fluid. no free abdominal fluid is visualized. impression: 1. stable cirrhotic appearance of the liver. few subtly hypervascular hepatic lesions do not demonstrate washout, and probably relate to perfusion variants. no particularly suspicious hepatic mass is seen. 2. mild splenomegaly to 12.4 cm redemonstrated. splenorenal shunting is again seen. 3. scattered simple and complex renal cystic lesions, nonenhancing, stable from march 2040. 4. incidentally, there is evidence of signal gain in the liver on the out of phase series. this occasionally may represent iron overload.

Preprocess the Data

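First, the impression (i.e., the findings/conclusions section) of the report is extracted, any doctor signature is removed, and the text is lowercased. keyword_split returns the text after the first keyword in the list that is followed by some text (for the impression) or the text before it (for the signature), so the order of impression_keywords sets their priority. You may need to modify this preprocessing step to accommodate your healthcare system's report formatting. The preprocess_note function is adapted from nmrezman.utils.preprocess_input.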

[2]:
def keyword_split(x, keywords, return_idx: int=2):
    """
    Extract portion of string given a list of possible delimiters (keywords) via partition method
    """
    for keyword in keywords:
        if x.partition(keyword)[2] !='':
            return x.partition(keyword)[return_idx]
    return x

def preprocess_note(note):
    """
    Get the impression from the note, remove doctor signature, and lowercase
    """
    impression_keywords = [
            "impression:",
            "conclusion(s):",
            "conclusions:",
            "conclusion:",
            "finding:",
            "findings:",
    ]
    signature_keywords = [
        "&#x20",
        "final report attending radiologist:",
    ]
    impressions = keyword_split(str(note).lower(), impression_keywords)
    impressions = keyword_split(impressions, signature_keywords, return_idx=0)
    return impressions

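# Preprocess the note
modeling_df["impression"] = modeling_df["note"].apply(preprocess_note)
modeling_df = modeling_df[modeling_df["impression"].notnull()].copy()  # .copy() avoids pandas' SettingWithCopyWarning

# str(bytes) keeps the bytes repr (e.g., "b'...'"), so the text passed to the tokenizer includes that wrapper
modeling_df["impression"] = modeling_df["impression"].apply(lambda x: str(x.encode("utf-8")) + "\n" + "\n")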
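To see what this preprocessing does to a single report, the sketch below restates keyword_split and preprocess_note with the same logic and keyword lists as the cell above and applies them to a short toy report. The toy report text is made up for illustration and is not from the demo data.

```python
def keyword_split(x, keywords, return_idx=2):
    """Return part `return_idx` of str.partition for the first keyword that has text after it."""
    for keyword in keywords:
        parts = x.partition(keyword)
        if parts[2] != "":
            return parts[return_idx]
    return x  # no keyword matched: return the text unchanged


def preprocess_note(note):
    """Keep the impression section, drop the signature block, and lowercase."""
    impression_keywords = [
        "impression:", "conclusion(s):", "conclusions:",
        "conclusion:", "finding:", "findings:",
    ]
    signature_keywords = ["&#x20", "final report attending radiologist:"]
    impressions = keyword_split(str(note).lower(), impression_keywords)
    return keyword_split(impressions, signature_keywords, return_idx=0)


# Hypothetical toy report (not from the demo data)
toy_report = (
    "EXAM: CT CHEST. FINDINGS: Small nodule. "
    "IMPRESSION: 1. Small nodule. FINAL REPORT Attending Radiologist: Dr. X"
)
impression = preprocess_note(toy_report)
print(repr(impression))  # ' 1. small nodule. '
```

"impression:" is checked before "findings:", so only the impression survives, and everything from the signature keyword onward is dropped. One consequence of the `x.partition(keyword)[2] != ''` check is that a signature keyword appearing as the very last text of a note is not stripped, because nothing follows it.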

Next, we encode the finding labels (Lung, Adrenal, and No Findings) as integers that the model can use as class targets.

[3]:
from sklearn import preprocessing

# Encode the Lung, Adrenal, and No Finding into integer labels
le = preprocessing.LabelEncoder()
le.fit(modeling_df["selected_finding"])
modeling_df["int_labels"] = le.transform(modeling_df["selected_finding"])
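scikit-learn's LabelEncoder assigns integers in sorted order of the class names (the fitted order is stored in le.classes_), so which integer a class receives depends on its spelling, and predictions can be mapped back with le.inverse_transform. The torch-free sketch below mimics that behavior with a hypothetical SimpleLabelEncoder; the label strings are assumptions, so check modeling_df["selected_finding"].unique() for the real ones.

```python
class SimpleLabelEncoder:
    """Hypothetical stand-in for sklearn's LabelEncoder: sort classes, then index them."""

    def fit(self, labels):
        self.classes_ = sorted(set(labels))
        self._index = {c: i for i, c in enumerate(self.classes_)}
        return self

    def transform(self, labels):
        return [self._index[label] for label in labels]

    def inverse_transform(self, ints):
        return [self.classes_[i] for i in ints]


# Assumed label strings, for illustration only
labels = ["Lung Findings", "No Findings", "Adrenal Findings", "No Findings"]
enc = SimpleLabelEncoder().fit(labels)
print(enc.classes_)            # ['Adrenal Findings', 'Lung Findings', 'No Findings']
print(enc.transform(labels))   # [1, 2, 0, 2]
```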

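The data is then split into train and test sets with a stratified split (so both sets keep similar class proportions), and each column is converted to a list so it can be passed to the tokenizer and the Dataset.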

[4]:
from sklearn.model_selection import train_test_split

# Split the data into train and test
train_df, test_df = train_test_split(modeling_df, test_size=0.3, stratify=modeling_df["selected_finding"], random_state=37)
train_note = list(train_df["impression"])
train_label = list(train_df["int_labels"])
test_note = list(test_df["impression"])
test_label = list(test_df["int_labels"])
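With only a handful of reports, two details of stratified splitting matter. For a float test_size, scikit-learn rounds the test-set size up (ceil(test_size * n_samples)), and stratification raises a ValueError if any class has only one example. The sketch below illustrates both with hypothetical helpers; simple_stratified_split is a simplified per-class split written for illustration, not scikit-learn's exact allocation algorithm.

```python
import math
import random
from collections import defaultdict


def split_sizes(n_samples, test_size):
    """Train/test sizes for a float test_size: the test size is rounded up."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


def simple_stratified_split(items, labels, test_size, seed=37):
    """Simplified stratified split (illustrative, NOT sklearn's exact algorithm)."""
    by_class = defaultdict(list)
    for item, label in zip(items, labels):
        by_class[label].append(item)
    # Each class needs at least one example in each split
    too_small = [c for c, members in by_class.items() if len(members) < 2]
    if too_small:
        raise ValueError(f"classes with fewer than 2 examples: {too_small}")
    rng = random.Random(seed)
    train, test = [], []
    for label, members in sorted(by_class.items()):
        members = members[:]
        rng.shuffle(members)
        # At least one test example per class, but never the whole class
        n_test = min(max(1, round(test_size * len(members))), len(members) - 1)
        test += [(m, label) for m in members[:n_test]]
        train += [(m, label) for m in members[n_test:]]
    return train, test


# Toy data: 10 notes across three made-up classes
notes = [f"note {i}" for i in range(10)]
toy_labels = ["A"] * 4 + ["B"] * 3 + ["C"] * 3
train, test = simple_stratified_split(notes, toy_labels, test_size=0.3)
print(len(train), len(test))   # 7 3
print(split_sizes(10, 0.3))    # (7, 3)
```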

Tokenize and Define the Datasets

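First, we define a tokenizer that splits the text into words or word fragments (tokens) and maps them to integer IDs. Here, we use the tokenizer from 🤗's pretrained roberta-base checkpoint. Padding is done on the left side since NM radiology reports generally have the findings at the end of the report. Note that padding_side only controls where pad tokens are added; when a report is longer than the model's maximum length, truncation=True removes tokens from the right by default, and recent 🤗 Transformers versions accept truncation_side="left" if the end of long reports must be kept. You can also swap in a different RoBERTa checkpoint (e.g., roberta-large) for both the tokenizer and the model.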

[5]:
from transformers import AutoTokenizer

# Define the tokenizer (from a pre-trained checkpoint) and tokenize the notes
tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=True, padding_side="left")
train_encodings = tokenizer(train_note, truncation=True, padding=True)
val_encodings = tokenizer(test_note, truncation=True, padding=True)
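The difference between padding side and truncation side is easiest to see on toy token IDs. The sketch below uses hypothetical pad_batch and truncate helpers (not 🤗 functions) on made-up ID sequences; only the pad ID 1, which is roberta-base's <pad> ID, comes from the real vocabulary. Left padding puts pad tokens (and 0s in the attention mask) before the text, while truncation decides which end of an over-long sequence is discarded.

```python
def pad_batch(sequences, pad_id, side="right"):
    """Pad token-id lists to the longest length and build matching attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        if side == "left":
            input_ids.append([pad_id] * n_pad + seq)
            attention_mask.append([0] * n_pad + [1] * len(seq))
        else:
            input_ids.append(seq + [pad_id] * n_pad)
            attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask


def truncate(seq, max_len, side="right"):
    """Keep at most max_len tokens, dropping the excess from the given side."""
    if len(seq) <= max_len:
        return seq
    return seq[len(seq) - max_len:] if side == "left" else seq[:max_len]


PAD_ID = 1  # id of <pad> in the roberta-base vocabulary
# Toy sequences wrapped in <s> (0) ... </s> (2); the middle ids are made up
batch = [[0, 10, 11, 2], [0, 20, 2]]
ids, mask = pad_batch(batch, PAD_ID, side="left")
print(ids)   # [[0, 10, 11, 2], [1, 0, 20, 2]]
print(mask)  # [[1, 1, 1, 1], [0, 1, 1, 1]]

# Truncating a made-up 5-token body to 3 tokens
print(truncate([5, 6, 7, 8, 9], 3))               # [5, 6, 7]
print(truncate([5, 6, 7, 8, 9], 3, side="left"))  # [7, 8, 9]
```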

Next, we define a custom PyTorch Dataset class that returns the tokenized report text and integer label for a given index. The 🤗 Trainer accepts custom PyTorch Datasets like this one directly as training and evaluation data.

[6]:
import torch

class Reports_Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings: dict, labels: list) -> None:
        self.encodings = encodings
        self.labels = labels
        return

    def __getitem__(self, idx: int) -> dict:
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self) -> int:
        return len(self.labels)

# Define the training and validation datasets with the tokenized notes and labels
train_dataset = Reports_Dataset(train_encodings, train_label)
val_dataset = Reports_Dataset(val_encodings, test_label)
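The Trainer relies only on __len__ and on __getitem__ returning a dict whose keys match the model's forward arguments (input_ids, attention_mask, labels); a data collator then groups those per-example dicts into per-key batches. The torch-free sketch below mirrors Reports_Dataset with plain lists, and the hypothetical collate function shows the grouping step (the real default collator stacks tensors instead of building lists). The encodings are toy values.

```python
class ToyReportsDataset:
    """Torch-free stand-in for Reports_Dataset with the same indexing contract."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)


def collate(items):
    """Group per-example dicts into one dict of lists (one entry per example)."""
    return {key: [item[key] for item in items] for key in items[0]}


# Toy encodings shaped like the tokenizer output
encodings = {"input_ids": [[0, 10, 2], [0, 20, 2]], "attention_mask": [[1, 1, 1], [1, 1, 1]]}
ds = ToyReportsDataset(encodings, [1, 2])
print(ds[0])  # {'input_ids': [0, 10, 2], 'attention_mask': [1, 1, 1], 'labels': 1}
print(collate([ds[i] for i in range(len(ds))]))
```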

Fine-Tune the Model

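First, we load the pretrained model. This checkpoint is similar to the one produced by the notebook Demo: Phase 02 Pretraining the NM Results Management Language Model with Custom Corpus, except that it was pretrained on thousands of reports. We then fine-tune it on a specific task: here, a three-class classification problem that determines whether a report has Lung Findings, Adrenal Findings, or No Findings. Because the checkpoint was pretrained as a language model, the classification head added by AutoModelForSequenceClassification (with num_labels=3) starts from random weights and is learned during fine-tuning.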

[7]:
from transformers import AutoModelForSequenceClassification

# TODO: point to the model produced by the pretraining step
# Here, we use a checkpoint pretrained on thousands of reports (rather than the weights produced by the notebook ``demo_pretrain``)
# To use the model trained directly by that notebook, use "/path/to/results/phase02/demo/checkpoint-4"
model_pretrained_path = "/path/to/results/phase02/demo/checkpoint-14500"

# Fine-tune the model from the pre-trained checkpoint
model = AutoModelForSequenceClassification.from_pretrained(model_pretrained_path, num_labels=3)
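With num_labels=3, the model outputs three logits per report. A softmax turns them into class probabilities, the argmax gives the class index, and le.inverse_transform maps that index back to a finding name. In this notebook, logits for the evaluation set could come from trainer.predict(val_dataset).predictions; the sketch below uses made-up logits and assumed class names to show just the post-processing step in plain Python.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits, classes):
    """Return the highest-probability class name and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return classes[best], probs[best]


# Assumed class names, in the sorted order LabelEncoder would assign
classes = ["Adrenal Findings", "Lung Findings", "No Findings"]
# Made-up logits for one report
label, prob = predict_label([-1.0, 2.0, 0.5], classes)
print(label, round(prob, 3))
```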

Here we train with the 🤗 Trainer, which runs the training loop for us according to the parameters in the 🤗 TrainingArguments. Because load_best_model_at_end=True, when training finishes the Trainer reloads the checkpoint with the best validation score (by default, the lowest validation loss) rather than simply keeping the last one; that model is then used to classify reports as Lung, Adrenal, or No Findings.
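In the next cell, warmup_steps=100 is larger than the 40 total optimization steps reported in the training log (7 training examples, batch size 16, 40 epochs gives one step per epoch). Under the Trainer's default linear schedule, the learning rate therefore never finishes warming up. The sketch below re-derives that linear warmup/decay rule in plain Python (an illustration, not the 🤗 source) and assumes the default learning_rate of 5e-5, since the cell does not set one. TrainingArguments also accepts warmup_ratio, a fraction of total steps, which scales with the dataset size.

```python
def linear_warmup_multiplier(step, warmup_steps, total_steps):
    """LR multiplier: linear warmup to 1.0, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 5e-5      # assumed: TrainingArguments default learning_rate
warmup_steps = 100  # as set in the TrainingArguments cell
total_steps = 40    # "Total optimization steps = 40" in the training log

peak = max(linear_warmup_multiplier(s, warmup_steps, total_steps) for s in range(total_steps))
print(f"highest multiplier reached: {peak:.2f} (lr {peak * base_lr:.2e})")
```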

[8]:
from transformers import Trainer, TrainingArguments

# Define the training parameters and 🤗 Trainer
training_args = TrainingArguments(
                    output_dir="/path/to/results/phase02/demo/findings",    # output directory
                    num_train_epochs=40,                                    # total number of training epochs
                    per_device_train_batch_size=16,                         # batch size per device during training
                    per_device_eval_batch_size=8,                           # batch size per device during evaluation
                    warmup_steps=100,                                       # number of warmup steps for learning rate scheduler
                    weight_decay=0.015,                                     # strength of weight decay
                    fp16=True,                                              # mixed precision training
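                    do_predict=True,                                        # flag for your own scripts; Trainer itself does not act on it
                    load_best_model_at_end=True,                            # reload the best checkpoint (lowest eval loss) at the end, e.g., for a confusion matrix
                    logging_steps=2,                                        # remaining args control logging, checkpointing, and evaluation
                    save_total_limit=2,
                    evaluation_strategy="epoch",                            # named eval_strategy in newer 🤗 Transformers versions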
                    save_strategy="epoch",
                    report_to="none",
)
trainer = Trainer(
                    model=model,                                            # the instantiated 🤗 Transformers model to be trained
                    args=training_args,                                     # training arguments, defined above
                    train_dataset=train_dataset,                            # training dataset
                    eval_dataset=val_dataset,                               # test (evaluation) dataset: save and eval strategy to match
)

# Train!
trainer.train()
Using amp half precision backend
/usr/local/lib/python3.8/dist-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
***** Running training *****
  Num examples = 7
  Num Epochs = 40
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 40
[40/40 01:11, Epoch 40/40]
Epoch Training Loss Validation Loss
1 No log 1.456469
2 0.004000 1.459345
3 0.004000 1.463199
4 0.003300 1.468944
5 0.003300 1.477593
6 0.004200 1.487210
7 0.004200 1.499711
8 0.003200 1.513178
9 0.003200 1.528578
10 0.003100 1.545932
11 0.003100 1.564255
12 0.002400 1.582583
13 0.002400 1.600918
14 0.001900 1.622189
15 0.001900 1.643487
16 0.002200 1.663810
17 0.002200 1.685127
18 0.001600 1.706464
19 0.001600 1.725871
20 0.001400 1.744312
21 0.001400 1.758865
22 0.001400 1.771481
23 0.001400 1.783134
24 0.001100 1.794791
25 0.001100 1.806458
26 0.001100 1.817151
27 0.001100 1.826876
28 0.000900 1.835629
29 0.000900 1.844386
30 0.000800 1.851195
31 0.000800 1.858008
32 0.000800 1.864821
33 0.000800 1.871639
34 0.000800 1.878457
35 0.000800 1.886249
36 0.000700 1.895020
37 0.000700 1.903792
38 0.000600 1.912562
39 0.000600 1.923288
40 0.000500 1.934011
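Because load_best_model_at_end=True and metric_for_best_model is not set, the Trainer treats the checkpoint with the lowest validation loss as the best one. The short sketch below applies that rule to the validation losses from the table above. The loss rises every epoch, so the best checkpoint is the first one (checkpoint-1), which is consistent with checkpoint-1 never being deleted in the log below despite save_total_limit=2. With only 7 training examples, training loss quickly approaches zero while validation loss rises, so treat these numbers as a smoke test of the pipeline rather than a measure of model quality.

```python
# Validation loss per epoch, copied from the training table above
val_losses = [
    1.456469, 1.459345, 1.463199, 1.468944, 1.477593, 1.487210, 1.499711, 1.513178,
    1.528578, 1.545932, 1.564255, 1.582583, 1.600918, 1.622189, 1.643487, 1.663810,
    1.685127, 1.706464, 1.725871, 1.744312, 1.758865, 1.771481, 1.783134, 1.794791,
    1.806458, 1.817151, 1.826876, 1.835629, 1.844386, 1.851195, 1.858008, 1.864821,
    1.871639, 1.878457, 1.886249, 1.895020, 1.903792, 1.912562, 1.923288, 1.934011,
]

# With metric_for_best_model unset, "best" means the lowest eval loss
best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__) + 1
print(f"best epoch: {best_epoch}, eval loss {val_losses[best_epoch - 1]:.6f}")

# One optimizer step per epoch in this run, so epoch N is saved as checkpoint-N
steps_per_epoch = 1
print(f"best checkpoint: checkpoint-{best_epoch * steps_per_epoch}")

# Count epochs where validation loss went up versus the previous epoch
increases = sum(b > a for a, b in zip(val_losses, val_losses[1:]))
print(f"validation loss increased in {increases} of {len(val_losses) - 1} epochs")
```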

***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-1
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-1/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-1/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-2
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-2/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-2/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-3
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-3/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-3/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-2] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-4
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-4/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-4/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-3] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-5
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-5/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-5/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-4] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-6
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-6/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-6/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-5] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-7
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-7/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-7/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-6] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-8
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-8/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-8/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-7] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-9
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-9/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-9/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-8] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-10
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-10/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-10/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-9] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-11
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-11/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-11/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-10] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-12
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-12/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-12/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-11] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-13
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-13/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-13/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-12] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-14
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-14/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-14/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-13] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-15
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-15/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-15/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-14] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-16
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-16/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-16/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-15] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-17
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-17/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-17/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-16] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-18
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-18/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-18/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-17] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-19
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-19/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-19/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-18] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-20
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-20/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-20/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-19] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-21
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-21/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-21/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-20] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-22
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-22/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-22/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-21] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-23
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-23/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-23/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-22] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-24
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-24/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-24/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-23] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-25
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-25/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-25/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-24] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-26
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-26/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-26/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-25] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-27
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-27/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-27/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-26] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-28
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-28/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-28/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-27] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-29
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-29/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-29/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-28] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-30
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-30/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-30/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-29] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-31
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-31/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-31/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-30] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-32
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-32/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-32/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-31] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-33
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-33/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-33/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-32] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-34
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-34/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-34/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-33] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-35
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-35/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-35/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-34] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-36
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-36/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-36/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-35] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-37
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-37/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-37/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-36] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-38
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-38/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-38/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-37] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-39
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-39/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-39/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-38] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4
  Batch size = 8
Saving model checkpoint to /path/to/results/phase02/demo/findings/checkpoint-40
Configuration saved in /path/to/results/phase02/demo/findings/checkpoint-40/config.json
Model weights saved in /path/to/results/phase02/demo/findings/checkpoint-40/pytorch_model.bin
Deleting older checkpoint [/path/to/results/phase02/demo/findings/checkpoint-39] due to args.save_total_limit


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from /path/to/results/phase02/demo/findings/checkpoint-1 (score: 1.4564694166183472).
[8]:
TrainOutput(global_step=40, training_loss=0.0017979430325794966, metrics={'train_runtime': 71.8828, 'train_samples_per_second': 3.895, 'train_steps_per_second': 0.556, 'total_flos': 17314211733840.0, 'train_loss': 0.0017979430325794966, 'epoch': 40.0})
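The log shows the best model being restored from checkpoint-1. That means no later evaluation beat the first epoch's score of about 1.456, even though the final training loss is close to zero. One way to see the whole evaluation curve is to read the trainer_state.json file that the Hugging Face Trainer writes into each checkpoint folder.

The sketch below assumes that file's usual layout: top-level best_model_checkpoint and best_metric keys, plus a log_history list whose evaluation entries carry step and eval_loss. The helper names load_trainer_state, eval_history, and summarize are hypothetical. Whether the reported score is the evaluation loss depends on the metric_for_best_model setting chosen earlier in the notebook.

```python
import json
from pathlib import Path


def load_trainer_state(checkpoint_dir):
    """Read the trainer_state.json file saved inside a Trainer checkpoint folder."""
    with open(Path(checkpoint_dir) / "trainer_state.json") as f:
        return json.load(f)


def eval_history(state):
    """Return (step, eval_loss) pairs from the state's log_history.

    Only entries that contain an 'eval_loss' key are kept; entries that
    log the training loss are skipped.
    """
    return [
        (entry["step"], entry["eval_loss"])
        for entry in state.get("log_history", [])
        if "eval_loss" in entry
    ]


def summarize(state):
    """Print the best checkpoint, its score, and the evaluation-loss curve."""
    print("best checkpoint:", state.get("best_model_checkpoint"))
    print("best metric:", state.get("best_metric"))
    for step, loss in eval_history(state):
        print(f"step {step:>3}: eval_loss = {loss:.4f}")
```

For this run, summarize(load_trainer_state("/path/to/results/phase02/demo/findings/checkpoint-40")) should print the full curve, because checkpoint-40 is the most recent checkpoint the log did not delete.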

Evaluate the Results

Using scikit-learn’s classification_report and confusion_matrix, we can evaluate how well the fine-tuned model performs on the held-out test examples.
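With only four held-out reports, small-sample edge cases can show up in both scikit-learn functions:

- A class that is never predicted has undefined precision. scikit-learn warns and reports 0 for it.
- A class missing from both the true and predicted labels disappears from the table unless labels is passed explicitly.

The following is a minimal sketch, assuming the three classes are encoded as the integers 0, 1, and 2. The mapping of these integers to lung, adrenal, and no findings is not shown in this section. The label vectors are toy values, not results from this notebook.

```python
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

CLASS_IDS = [0, 1, 2]  # assumed integer encoding of the three finding classes

# Toy labels: class 1 occurs once in y_true but is never predicted.
y_true = np.array([0, 1, 2, 2])
y_pred = np.array([0, 2, 2, 2])

# labels= fixes the row/column order and keeps every class in the output;
# zero_division=0 reports 0.0 instead of warning when precision is undefined.
report = classification_report(y_true, y_pred, labels=CLASS_IDS, zero_division=0)
matrix = confusion_matrix(y_true, y_pred, labels=CLASS_IDS)

print(report)
print(matrix)
```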

[9]:
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

# Predict on the held-out set and keep the highest-scoring class for each report
predictions = trainer.predict(val_dataset)
y_pred = np.argmax(predictions.predictions, axis=1)

# Compare the predicted classes with the true labels
report = classification_report(test_label, y_pred)
matrix = confusion_matrix(test_label, y_pred)
print(report)
print(matrix)
***** Running Prediction *****
  Num examples = 4
  Batch size = 8
              precision    recall  f1-score   support

           0       1.00      1.00      1.00         1
           1       0.50      1.00      0.67         1
           2       1.00      0.50      0.67         2

    accuracy                           0.75         4
   macro avg       0.83      0.83      0.78         4
weighted avg       0.88      0.75      0.75         4

[[1 0 0]
 [0 1 0]
 [0 1 1]]
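Every number in the report above can be recovered from the confusion matrix. In scikit-learn's convention, its rows are true classes and its columns are predicted classes. The single error is a class-2 report predicted as class 1. That error halves class 1's precision and class 2's recall. A short NumPy check that rebuilds the report from the printed matrix:

```python
import numpy as np

# Confusion matrix printed above: rows = true class, columns = predicted class.
cm = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 1],
])

tp = np.diag(cm)                  # correct predictions per class
precision = tp / cm.sum(axis=0)   # column totals: how often each class was predicted
recall = tp / cm.sum(axis=1)      # row totals: support of each class
# Harmonic mean of precision and recall (no class here has both equal to 0).
f1 = 2 * precision * recall / (precision + recall)
support = cm.sum(axis=1)

accuracy = tp.sum() / cm.sum()
macro_f1 = f1.mean()                            # unweighted mean over classes
weighted_f1 = np.average(f1, weights=support)   # mean weighted by class support

print("precision:", precision.round(2))
print("recall:   ", recall.round(2))
print("f1:       ", f1.round(2))
print(f"accuracy={accuracy:.2f} macro_f1={macro_f1:.2f} weighted_f1={weighted_f1:.2f}")
```

The recomputed values agree with the report: accuracy 0.75, macro-averaged F1 0.78, and weighted F1 0.75. With four test reports, one misclassification moves accuracy by 0.25, so these figures say little about performance on a larger dataset.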