manaswim committed on
Commit
48cbe57
·
1 Parent(s): 996b654

Add application file

Browse files
Files changed (1) hide show
  1. app.py +445 -0
app.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision.models import convnext_base, ConvNeXt_Base_Weights
4
+ from torchvision.models._api import WeightsEnum
5
+ from torch.hub import load_state_dict_from_url
6
+ from statistics import mean
7
+ import time, os
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+ from torchvision import datasets
11
+ from torchvision.transforms import ToTensor
12
+ import matplotlib.pyplot as plt
13
+ from torch.utils.data import DataLoader
14
+ import gradio as gr
15
+ from torchvision import transforms
16
+ import matplotlib.pyplot as plt
17
+
18
+ from typing import List, Tuple
19
+
20
+ from PIL import Image
21
+ from torch.utils.data import Subset
22
+ from torch import nn
23
+
24
+ from tqdm.auto import tqdm
25
+ from typing import Dict, List, Tuple
26
+
27
+ """
28
+ Contains functionality for creating PyTorch DataLoaders for
29
+ image classification data.
30
+ """
31
+ import os
32
+
33
+ from torchvision import datasets, transforms
34
+ from torch.utils.data import DataLoader
35
+
36
+ import torch
37
+
38
+ from tqdm.auto import tqdm
39
+ from typing import Dict, List, Tuple
40
+ import torch
41
+ import torchvision
42
+ from torchvision import transforms
43
+ import matplotlib.pyplot as plt
44
+
45
+ from typing import List, Tuple
46
+
47
+ from PIL import Image
48
+
49
# Default number of DataLoader worker processes.
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader does not accept num_workers=None, so fall back to 0 (which makes
# the DataLoader load batches in the main process).
NUM_WORKERS = os.cpu_count() or 0
50
+
51
def create_dataloaders(
    train_dir: str,
    test_dir: str,
    transform: transforms.Compose,
    batch_size: int,
    num_workers: int=NUM_WORKERS
):
    """Builds the training and testing DataLoaders from two image folders.

    Each directory is wrapped in a torchvision ``ImageFolder`` dataset (both
    using the same ``transform``) and then in a PyTorch ``DataLoader``. Only
    the training loader shuffles its samples.

    Args:
        train_dir: Path to training directory.
        test_dir: Path to testing directory.
        transform: torchvision transforms applied to both training and
            testing images.
        batch_size: Number of samples per batch in each of the DataLoaders.
        num_workers: Number of worker processes per DataLoader.

    Returns:
        A tuple of (train_dataloader, test_dataloader, class_names), where
        class_names is the ``classes`` attribute of the training ImageFolder.

    Example usage:
        train_dataloader, test_dataloader, class_names = create_dataloaders(
            train_dir="path/to/train_dir",
            test_dir="path/to/test_dir",
            transform=some_transform,
            batch_size=32,
            num_workers=4,
        )
    """
    # One ImageFolder dataset per split, sharing the same transform pipeline.
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform)

    # Settings common to both loaders; only the shuffle flag differs.
    shared_loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )

    train_dataloader = DataLoader(train_data, shuffle=True, **shared_loader_kwargs)
    test_dataloader = DataLoader(test_data, shuffle=False, **shared_loader_kwargs)

    return train_dataloader, test_dataloader, train_data.classes
105
+
106
+ """
107
+ Contains functions for training and testing a PyTorch model.
108
+ """
109
+
110
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    """Trains a PyTorch model for a single epoch.

    Puts the model in training mode and, for every batch, runs the forward
    pass, computes the loss, back-propagates and steps the optimizer.

    Batches are counted while iterating, so ``dataloader`` may be any
    iterable of ``(X, y)`` batches, including one that does not support
    ``len()``.

    Args:
        model: A PyTorch model to be trained.
        dataloader: An iterable of (inputs, targets) batches to train on.
        loss_fn: A PyTorch loss function to minimize.
        optimizer: A PyTorch optimizer to help minimize the loss function.
        device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        A tuple (train_loss, train_accuracy), each being the mean of the
        per-batch values. For example:

        (0.1112, 0.8743)

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Put model in train mode
    model.train()

    # Running sums of the per-batch loss and accuracy
    train_loss, train_acc = 0.0, 0.0
    num_batches = 0

    for X, y in dataloader:
        # Send data to target device
        X, y = X.to(device), y.to(device)

        # 1. Forward pass
        y_pred = model(X)

        # 2. Calculate and accumulate loss
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()

        # 3. Reset gradients, 4. back-propagate, 5. update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Softmax is monotonic, so the argmax of the raw logits is already
        # the predicted class.
        y_pred_class = y_pred.argmax(dim=1)
        train_acc += (y_pred_class == y).sum().item() / len(y_pred)

        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_step() received a dataloader that yielded no batches")

    # Average the accumulated metrics over the number of batches seen
    return train_loss / num_batches, train_acc / num_batches
169
+
170
+ def test_step(model: torch.nn.Module,
171
+ dataloader: torch.utils.data.DataLoader,
172
+ loss_fn: torch.nn.Module,
173
+ device: torch.device) -> Tuple[float, float]:
174
+ """Tests a PyTorch model for a single epoch.
175
+
176
+ Turns a target PyTorch model to "eval" mode and then performs
177
+ a forward pass on a testing dataset.
178
+
179
+ Args:
180
+ model: A PyTorch model to be tested.
181
+ dataloader: A DataLoader instance for the model to be tested on.
182
+ loss_fn: A PyTorch loss function to calculate loss on the test data.
183
+ device: A target device to compute on (e.g. "cuda" or "cpu").
184
+
185
+ Returns:
186
+ A tuple of testing loss and testing accuracy metrics.
187
+ In the form (test_loss, test_accuracy). For example:
188
+
189
+ (0.0223, 0.8985)
190
+ """
191
+ # Put model in eval mode
192
+ model.eval()
193
+
194
+ # Setup test loss and test accuracy values
195
+ test_loss, test_acc = 0, 0
196
+
197
+ # Turn on inference context manager
198
+ with torch.inference_mode():
199
+ # Loop through DataLoader batches
200
+ for batch, (X, y) in enumerate(dataloader):
201
+ # Send data to target device
202
+ X, y = X.to(device), y.to(device)
203
+
204
+ # 1. Forward pass
205
+ test_pred_logits = model(X)
206
+
207
+ # 2. Calculate and accumulate loss
208
+ loss = loss_fn(test_pred_logits, y)
209
+ test_loss += loss.item()
210
+
211
+ # Calculate and accumulate accuracy
212
+ test_pred_labels = test_pred_logits.argmax(dim=1)
213
+ test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))
214
+
215
+ # Adjust metrics to get average loss and accuracy per batch
216
+ test_loss = test_loss / len(dataloader)
217
+ test_acc = test_acc / len(dataloader)
218
+ return test_loss, test_acc
219
+
220
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device) -> Dict[str, List]:
    """Runs the full train/evaluate loop for a PyTorch model.

    For each epoch the model goes through train_step() and then test_step().
    The four resulting metrics are printed and appended to a results
    dictionary.

    Args:
        model: A PyTorch model to be trained and tested.
        train_dataloader: A DataLoader instance for the model to be trained on.
        test_dataloader: A DataLoader instance for the model to be tested on.
        optimizer: A PyTorch optimizer to help minimize the loss function.
        loss_fn: A PyTorch loss function to calculate loss on both datasets.
        epochs: An integer indicating how many epochs to train for.
        device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        A dictionary with one list per metric, each list holding one value
        per epoch:
            {train_loss: [...],
             train_acc: [...],
             test_loss: [...],
             test_acc: [...]}
        For example if training for epochs=2:
            {train_loss: [2.0616, 1.0537],
             train_acc: [0.3945, 0.3945],
             test_loss: [1.2641, 1.5706],
             test_acc: [0.3400, 0.2973]}
    """
    # Metric names double as the keys of the returned dictionary.
    metric_names = ("train_loss", "train_acc", "test_loss", "test_acc")
    results = {name: [] for name in metric_names}

    # Make sure the model lives on the target device before training starts
    model.to(device)

    for epoch in tqdm(range(epochs)):
        # Training pass first, then evaluation, within the same epoch
        epoch_metrics = (
            *train_step(model=model,
                        dataloader=train_dataloader,
                        loss_fn=loss_fn,
                        optimizer=optimizer,
                        device=device),
            *test_step(model=model,
                       dataloader=test_dataloader,
                       loss_fn=loss_fn,
                       device=device),
        )
        train_loss, train_acc, test_loss, test_acc = epoch_metrics

        # Report progress for this epoch
        print(
            f"Epoch: {epoch+1} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | test_loss: {test_loss:.4f} | test_acc: {test_acc:.4f}"
        )

        # Record this epoch's metrics
        for name, value in zip(metric_names, epoch_metrics):
            results[name].append(value)

    return results
297
+
298
+ """
299
+ Utility functions to make predictions.
300
+
301
+ Main reference for code creation: https://www.learnpytorch.io/06_pytorch_transfer_learning/#6-make-predictions-on-images-from-the-test-set
302
+ """
303
+
304
# Default compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
306
+
307
+ # Predict on a target image with a target model
308
+ # Function created in: https://www.learnpytorch.io/06_pytorch_transfer_learning/#6-make-predictions-on-images-from-the-test-set
309
def pred_and_plot_image(
    model: torch.nn.Module,
    class_names: List[str],
    image_path: str,
    image_size: Tuple[int, int] = (224, 224),
    transform: torchvision.transforms = None,
    device: torch.device = device,
):
    """Predicts on a target image with a target model and plots the result.

    The image is always converted to RGB before it is transformed, so
    grayscale, palette and RGBA files work with the 3-channel normalization.
    This matches how the ImageFolder datasets used for training load images.

    Args:
        model (torch.nn.Module): A trained (or untrained) PyTorch model to predict on an image.
        class_names (List[str]): A list of target classes to map predictions to.
        image_path (str): Filepath to target image to predict on.
        image_size (Tuple[int, int], optional): Size to resize the image to when no
            transform is given. Defaults to (224, 224).
        transform (torchvision.transforms, optional): Transform to perform on image.
            Defaults to None, which resizes, converts to a tensor and applies
            ImageNet normalization.
        device (torch.device, optional): Target device to perform prediction on.
            Defaults to the module-level device.

    Returns:
        None. The function draws a matplotlib figure showing the image, titled
        with the predicted class name and its probability.
    """
    # Read the pixel data and release the file handle right away.
    with Image.open(image_path) as opened_image:
        img = opened_image.convert("RGB")

    # Use the caller's transform, or fall back to the ImageNet-style default
    if transform is not None:
        image_transform = transform
    else:
        image_transform = transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    ### Predict on image ###

    # Make sure the model is on the target device
    model.to(device)

    # Turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # Add a leading batch dimension: the model is fed [batch_size, color_channels, height, width]
        transformed_image = image_transform(img).unsqueeze(dim=0)

        # Run the forward pass on the target device
        target_image_pred = model(transformed_image.to(device))

    # Convert logits -> prediction probabilities (softmax for multi-class classification)
    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

    # Convert prediction probabilities -> a plain int so it can index a Python list
    pred_index = int(torch.argmax(target_image_pred_probs, dim=1).item())
    pred_prob = target_image_pred_probs.max().item()

    # Plot image with predicted label and probability
    plt.figure()
    plt.imshow(img)
    plt.title(f"Pred: {class_names[pred_index]} | Prob: {pred_prob:.3f}")
    plt.axis(False)
372
+
373
# Batch size used for the CIFAR-10 DataLoaders created further down.
BATCH_SIZE = 32

device = "cuda" if torch.cuda.is_available() else "cpu"

# CIFAR-10 train/test splits (downloaded into ./data if missing), each
# preprocessed with the transforms bundled with the default ConvNeXt-Base weights.
training_datab = torchvision.datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ConvNeXt_Base_Weights.DEFAULT.transforms(),
)
test_datab = torchvision.datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=ConvNeXt_Base_Weights.DEFAULT.transforms(),
)

# TODO: temporary limit -- only the first 1/1000 of each split is kept.
# Remove these subsets to train and evaluate on the full datasets.
subset_train = Subset(training_datab, indices=range(len(training_datab) // 1000))
subset_test = Subset(test_datab, indices=range(len(test_datab) // 1000))
379
+
380
def get_state_dict(self, *args, **kwargs):
    """Downloads (or loads from cache) the state dict found at ``self.url``.

    Intended as a replacement for ``WeightsEnum.get_state_dict`` that never
    forwards ``check_hash`` to ``load_state_dict_from_url``, so no hash
    verification of the downloaded file is requested.

    Args:
        *args: Positional arguments forwarded to ``load_state_dict_from_url``.
        **kwargs: Keyword arguments forwarded to ``load_state_dict_from_url``;
            a ``check_hash`` entry, if present, is discarded.

    Returns:
        The object returned by ``load_state_dict_from_url``.
    """
    # A default is needed here: not every caller passes check_hash, and a bare
    # pop() would raise KeyError in that case.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)
383
# Route all torchvision weight loading through the module-level
# get_state_dict wrapper, which drops the check_hash argument.
WeightsEnum.get_state_dict = get_state_dict

# ConvNeXt-Base with its default pretrained weights. The weights are passed by
# keyword because passing them positionally is deprecated in torchvision.
modeld = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT)

# Replace the classification head with a freshly initialised one that has
# 10 outputs (one per CIFAR-10 class).
modeld.classifier = nn.Sequential(
    nn.LayerNorm((1024, 1, 1), eps=1e-06, elementwise_affine=True),
    nn.Flatten(start_dim=1, end_dim=-1),
    nn.Linear(in_features=1024, out_features=10, bias=True)
)

optimizerd = torch.optim.Adam(modeld.parameters(), lr=0.001)

loss_fn = nn.CrossEntropyLoss()
epochs = 5

# TODO: these loaders wrap the reduced subsets; switch to the full datasets
# once the subset limit is removed.
train_dataloaderd = DataLoader(subset_train, batch_size=BATCH_SIZE, shuffle=True)
test_dataloaderd = DataLoader(subset_test, batch_size=BATCH_SIZE, shuffle=False)

# NOTE(review): training is disabled below, so the new classifier head is never
# fitted in this file -- predictions presumably come from an untrained head
# unless weights are loaded elsewhere; confirm. The original call also went
# through an `engine` module that this file does not import.
# engine.train(modeld, train_dataloaderd, test_dataloaderd, optimizerd, loss_fn, epochs, device)
401
+
402
def pred_image(image_path: str, model: torch.nn.Module = modeld, class_names: List[str] = training_datab.classes, image_size: Tuple[int, int] = (224, 224), transform: torchvision.transforms = ConvNeXt_Base_Weights.DEFAULT.transforms(), device: torch.device = device):
    """Classifies a single image file and returns the label with its probability.

    This is the callback used by the Gradio interface. The image is always
    converted to RGB before it is transformed, so grayscale, palette and RGBA
    uploads work with a 3-channel model.

    Args:
        image_path (str): Filepath to the image to predict on.
        model (torch.nn.Module, optional): Model used for the prediction.
            Defaults to the module-level ConvNeXt model.
        class_names (List[str], optional): Class names indexed by predicted label.
            Defaults to the CIFAR-10 training dataset's classes.
        image_size (Tuple[int, int], optional): Size to resize to when ``transform``
            is None. Defaults to (224, 224).
        transform (torchvision.transforms, optional): Transform applied to the image.
            Defaults to the transforms of the default ConvNeXt-Base weights; pass
            None to use resize + ToTensor + ImageNet normalization instead.
        device (torch.device, optional): Device to run the prediction on.
            Defaults to the module-level device.

    Returns:
        A tuple (label, probability): the predicted class name and the highest
        softmax probability as a plain Python float, so it displays as a number
        rather than as ``tensor(...)``.
    """
    # Read the pixel data and release the file handle right away.
    with Image.open(image_path) as opened_image:
        img = opened_image.convert("RGB")

    # Use the given transform, or fall back to the ImageNet-style default
    if transform is not None:
        image_transform = transform
    else:
        image_transform = transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    ### Predict on image ###

    # Make sure the model is on the target device
    model.to(device)

    # Turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # Add a leading batch dimension: the model is fed [batch_size, color_channels, height, width]
        transformed_image = image_transform(img).unsqueeze(dim=0)

        # Run the forward pass on the target device
        target_image_pred = model(transformed_image.to(device))

    # Convert logits -> prediction probabilities (softmax for multi-class classification)
    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

    # Convert to plain Python values: an int to index the class list and a
    # float for display.
    pred_index = int(torch.argmax(target_image_pred_probs, dim=1).item())
    pred_prob = target_image_pred_probs.max().item()

    return class_names[pred_index], pred_prob
441
+
442
+
443
# Gradio UI: one image input (handed to pred_image as a file path) and two
# text outputs for the predicted label and its probability.
demo = gr.Interface(
    fn=pred_image,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Textbox(label="label"),
        gr.Textbox(label="probability"),
    ],
    examples=[
        "apple.jpg",
        "bird.jpg",
        "car.jpg",
        "ocean.jpg",
    ],
)

# Start the app and request a public share link.
demo.launch(share=True)