
Commit 17e81c1

Hhhhhhao authored
Added USB tutorial script (#2676)
* update USB tutorial script
* Update usb_semisup_learn.py
* Update usb_semisup_learn.py: fix typos
* fix usb PyTorch and add more intro on PyTorch
* add USB tutorial to index

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Carl Parker <carljparker@meta.com>
1 parent 37462ab commit 17e81c1

File tree

3 files changed: +233 -0 lines changed

advanced_source/usb_semisup_learn.py

+224
@@ -0,0 +1,224 @@
"""
Semi-Supervised Learning using USB built upon PyTorch
======================================================

**Author**: `Hao Chen <https://github.com/Hhhhhhao>`_

Introduction
------------

USB is a semi-supervised learning (SSL) framework built upon PyTorch.
Based on Datasets and Modules provided by PyTorch, USB is a flexible, modular, and easy-to-use framework for semi-supervised learning.
It supports a variety of semi-supervised learning algorithms, including FixMatch, FreeMatch, DeFixMatch, SoftMatch, and so on.
It also supports a variety of imbalanced semi-supervised learning algorithms.
Benchmark results of USB across different datasets in computer vision, natural language processing, and speech processing are included.

This tutorial will walk you through the basics of using the USB ``semilearn`` package.
Let's get started by training a FreeMatch/SoftMatch model on CIFAR-10 using a pre-trained Vision Transformer (ViT)!
We will also show how easy it is to change the semi-supervised algorithm and train on imbalanced datasets.


.. figure:: /_static/img/usb_semisup_learn/code.png
   :alt: USB framework illustration
"""
######################################################################
# Introduction to FreeMatch and SoftMatch in Semi-Supervised Learning
# --------------------------------------------------------------------
#
# Here we provide a brief introduction to FreeMatch and SoftMatch.
# First we introduce a famous baseline for semi-supervised learning called FixMatch.
# FixMatch is a very simple framework for semi-supervised learning: pseudo labels are
# generated from the model's predictions on weakly-augmented unlabeled data, and they
# supervise the model's predictions on strongly-augmented views of the same data.
# It adopts a confidence thresholding strategy to filter out low-confidence pseudo
# labels using a fixed, pre-defined threshold.
# FreeMatch and SoftMatch are two algorithms that improve upon FixMatch.
# FreeMatch proposes an adaptive thresholding strategy to replace the fixed
# thresholding strategy in FixMatch.
# The adaptive threshold progressively increases according to the learning status of
# the model on each class.
# SoftMatch absorbs the idea of confidence thresholding as a weighting mechanism.
# It proposes a Gaussian weighting mechanism to overcome the quantity-quality
# trade-off in pseudo labels.
# In this tutorial, we will use USB to train FreeMatch and SoftMatch; a simplified
# sketch of the two mechanisms follows.
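######################################################################
# The code below is an illustrative, simplified sketch of the two ideas above,
# not USB's actual implementation: ``probs``, ``tau``, and the EMA constant are
# hypothetical names and values chosen for the example.
#
import torch

# Fake predicted probabilities for a batch of 8 unlabeled images, 10 classes.
probs = torch.softmax(torch.randn(8, 10), dim=-1)
conf, _ = probs.max(dim=-1)  # confidence of each pseudo label

# FreeMatch-style adaptive thresholding (sketch): maintain an EMA of the mean
# confidence as a global threshold that rises as the model grows confident.
tau = 1.0 / 10  # initialized to 1 / num_classes
tau = 0.999 * tau + (1 - 0.999) * conf.mean()
freematch_mask = (conf >= tau).float()  # hard 0/1 mask over pseudo labels

# SoftMatch-style Gaussian weighting (sketch): instead of a hard cut-off,
# weight each pseudo label with a truncated Gaussian centered at the mean
# confidence, so lower-confidence samples still contribute, just less.
mu, var = conf.mean(), conf.var()
softmatch_weight = torch.where(
    conf >= mu,
    torch.ones_like(conf),
    torch.exp(-((conf - mu) ** 2) / (2 * var)),
)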
######################################################################
# Use USB to Train FreeMatch/SoftMatch on CIFAR-10 with only 40 labels
# --------------------------------------------------------------------
#
# USB is a PyTorch-based Python package for Semi-Supervised Learning (SSL).
# It is easy to use and extend, affordable to small groups, and comprehensive for
# developing and evaluating SSL algorithms.
# USB provides the implementation of 14 SSL algorithms based on consistency
# regularization, and 15 evaluation tasks from the computer vision, natural
# language processing, and audio domains.
# It has a modular design that allows users to easily extend the package by
# adding new algorithms and tasks.
# It also supports a Python API for easier adaptation to different SSL algorithms
# on new data.
#
#
# Now, let's use USB to train FreeMatch and SoftMatch on CIFAR-10.
# First, we need to install the USB package ``semilearn`` (see the installation
# note below) and import the necessary API functions from USB.
# Below is a list of functions we will use from ``semilearn``:
#
# - ``get_dataset`` to load the dataset; here we use CIFAR-10
# - ``get_data_loader`` to create train (labeled and unlabeled) and test data loaders; the unlabeled train loaders provide both strong and weak augmentations of the unlabeled data
# - ``get_net_builder`` to create a model; here we use a pre-trained ViT
# - ``get_algorithm`` to create the semi-supervised learning algorithm; here we use FreeMatch and SoftMatch
# - ``get_config`` to get the default configuration of the algorithm
# - ``Trainer``: a Trainer class for training and evaluating the algorithm on the dataset
#
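######################################################################
# .. note::
#
#    If ``semilearn`` is not installed yet, install it with
#    ``pip install semilearn``; this commit pins ``semilearn==0.3.2`` in
#    ``requirements.txt``.
#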
import semilearn
from semilearn import get_dataset, get_data_loader, get_net_builder, get_algorithm, get_config, Trainer

######################################################################
# After importing the necessary functions, we first set the hyper-parameters of
# the algorithm.
#
config = {
    'algorithm': 'freematch',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': True,
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',

    # optimization configs
    'epoch': 1,
    'num_train_iter': 4000,
    'num_eval_iter': 500,
    'num_log_iter': 50,
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,

    # dataset configs
    'dataset': 'cifar10',
    'num_labels': 40,  # total number of labeled training examples
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': './data',
    'ulb_samples_per_class': None,

    # algorithm specific configs
    'hard_label': True,
    'T': 0.5,
    'ema_p': 0.999,
    'ent_loss_ratio': 0.001,
    'uratio': 2,  # ratio of unlabeled to labeled batch size
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 0,
    'world_size': 1,
    'distributed': False,
    'num_workers': 4,
}
config = get_config(config)

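######################################################################
# ``get_config`` merges the dictionary above with ``semilearn``'s defaults, so
# options we did not set explicitly, such as ``include_lb_to_ulb`` used in the
# next cell, are filled in on the returned ``config`` object.
#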
######################################################################
# Then, we load the dataset, create the data loaders for training and testing,
# and specify the model and algorithm to use.
#
dataset_dict = get_dataset(config, config.algorithm, config.dataset, config.num_labels, config.num_classes, data_dir=config.data_dir, include_lb_to_ulb=config.include_lb_to_ulb)
train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)
algorithm = get_algorithm(config, get_net_builder(config.net, from_name=False), tb_log=None, logger=None)

######################################################################
# We can now start training the algorithm on CIFAR-10 with 40 labels.
# We train for 4000 iterations and evaluate every 500 iterations.
#
trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)

######################################################################
# Finally, let's evaluate the trained model on the validation set.
# After training for 4000 iterations with FreeMatch on only 40 labels of
# CIFAR-10, we obtain a classifier that achieves an accuracy above 93% on the
# validation set.
trainer.evaluate(eval_loader)

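######################################################################
# If you want raw predictions rather than aggregate metrics, the sketch below
# shows one possible route; ``Trainer.predict`` and its return values are an
# assumption about the ``semilearn`` API, so verify them against your installed
# version before relying on this:
#
# .. code-block:: python
#
#    # hypothetical: obtain predicted labels and logits for a loader
#    y_pred, y_logits = trainer.predict(eval_loader)
#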
######################################################################
# Use USB to Train SoftMatch with a specific imbalanced algorithm on imbalanced CIFAR-10
# ---------------------------------------------------------------------------------------
#
# Now suppose we have an imbalanced labeled set and an imbalanced unlabeled set
# of CIFAR-10, and we want to train a SoftMatch model on them.
# We create the imbalanced labeled and unlabeled sets of CIFAR-10 by setting
# ``lb_imb_ratio`` and ``ulb_imb_ratio`` to 10.
# We also replace the ``algorithm`` with ``softmatch`` and specify the sizes of
# the labeled and unlabeled sets through ``num_labels`` and ``ulb_num_labels``.
#
config = {
    'algorithm': 'softmatch',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': True,
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',

    # optimization configs
    'epoch': 1,
    'num_train_iter': 4000,
    'num_eval_iter': 500,
    'num_log_iter': 50,
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,

    # dataset configs
    'dataset': 'cifar10',
    'num_labels': 1500,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': './data',
    'ulb_samples_per_class': None,
    'lb_imb_ratio': 10,   # imbalance ratio of the labeled set
    'ulb_imb_ratio': 10,  # imbalance ratio of the unlabeled set
    'ulb_num_labels': 3000,

    # algorithm specific configs
    'hard_label': True,
    'T': 0.5,
    'ema_p': 0.999,
    'ent_loss_ratio': 0.001,
    'uratio': 2,
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 0,
    'world_size': 1,
    'distributed': False,
    'num_workers': 4,
}
config = get_config(config)

######################################################################
# Then, we re-load the dataset, re-create the data loaders for training and
# testing, and specify the model and algorithm to use.
#
dataset_dict = get_dataset(config, config.algorithm, config.dataset, config.num_labels, config.num_classes, data_dir=config.data_dir, include_lb_to_ulb=config.include_lb_to_ulb)
train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)
algorithm = get_algorithm(config, get_net_builder(config.net, from_name=False), tb_log=None, logger=None)

######################################################################
# We can now start training the algorithm on imbalanced CIFAR-10 with 1500
# labels.
# We train for 4000 iterations and evaluate every 500 iterations.
#
trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)

######################################################################
# Finally, let's evaluate the trained model on the validation set.
#
trainer.evaluate(eval_loader)

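######################################################################
# When the training data is imbalanced, overall accuracy can hide poor
# performance on rare classes. Below is a minimal, self-contained sketch of a
# per-class accuracy check; ``y_true`` and ``y_pred`` are hypothetical
# placeholder arrays, not outputs of the code above.
#
import numpy as np

y_true = np.array([0, 0, 0, 1, 1, 2])  # hypothetical ground-truth labels
y_pred = np.array([0, 0, 1, 1, 1, 2])  # hypothetical model predictions
for c in np.unique(y_true):
    mask = y_true == c
    acc_c = (y_pred[mask] == y_true[mask]).mean()
    print(f"class {c}: accuracy {acc_c:.2f} over {mask.sum()} samples")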
######################################################################
# References:
#
# - [1] USB: https://github.com/microsoft/Semi-supervised-learning
# - [2] Kihyuk Sohn et al. FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
# - [3] Yidong Wang et al. FreeMatch: Self-adaptive Thresholding for Semi-supervised Learning
# - [4] Hao Chen et al. SoftMatch: Addressing the Quantity-Quality Trade-off in Semi-supervised Learning

index.rst

+7
@@ -142,6 +142,13 @@ What's new in PyTorch tutorials?
    :link: intermediate/spatial_transformer_tutorial.html
    :tags: Image/Video

+.. customcarditem::
+   :header: Semi-Supervised Learning Tutorial Based on USB
+   :card_description: Learn how to train semi-supervised learning algorithms (on custom data) using USB and PyTorch.
+   :image: _static/img/usb_semisup_learn/code.png
+   :link: advanced/usb_semisup_learn.html
+   :tags: Image/Video
+
 .. Audio

 .. customcarditem::

requirements.txt

+2
@@ -59,4 +59,6 @@ gymnasium[mujoco]==0.27.0
 timm
 iopath
 pygame==2.1.2
+semilearn==0.3.2
+
 