-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
36c5593
commit a251afc
Showing
9 changed files
with
226 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
--- | ||
title: Lens | ||
emoji: 🐢 | ||
emoji: 📷 | ||
colorFrom: yellow | ||
colorTo: indigo | ||
sdk: gradio | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
### 1. Imports and class names setup ### | ||
import gradio as gr | ||
import os | ||
import torch | ||
|
||
from model import create_effnetb2_model | ||
from timeit import default_timer as timer | ||
from typing import Tuple, Dict | ||
|
||
# Read the class names (one per line) from class_names.txt
with open("class_names.txt") as class_names_file:
    class_names = [line.strip() for line in class_names_file]

### 2. Model and transforms preparation ###

# Build the EffNetB2 model with one output per class name, along with its transforms
effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=len(class_names))

# Restore the saved weights into the model, mapping the stored tensors onto the CPU
effnetb2.load_state_dict(
    torch.load(
        f="models/09_pretrained_effnetb2_feature_extractor_food101_20_percent.pth",
        map_location=torch.device("cpu"),
    )
)
|
||
### 3. Predict function ### | ||
|
||
# Create predict function | ||
def predict(img) -> Tuple[Dict, float]:
    """Run the EffNetB2 model on ``img`` and time the whole call.

    Returns:
        A tuple ``(probabilities, seconds)`` where ``probabilities`` maps every
        entry of ``class_names`` to its predicted probability and ``seconds`` is
        the elapsed time rounded to 5 decimal places.
    """
    began = timer()

    # Apply the model's transforms, then prepend a batch dimension
    batch = effnetb2_transforms(img).unsqueeze(0)

    # Evaluation mode + inference mode for the forward pass
    effnetb2.eval()
    with torch.inference_mode():
        # Softmax over dim=1 turns the logits into prediction probabilities
        logits = effnetb2(batch)
        probabilities = torch.softmax(logits, dim=1)

    # {label: probability} dictionary (the required format for Gradio's output parameter)
    label_to_prob = {
        name: float(probabilities[0][idx]) for idx, name in enumerate(class_names)
    }

    elapsed = round(timer() - began, 5)
    return label_to_prob, elapsed
|
||
### 4. Gradio app ### | ||
|
||
# Create title, description and article strings
title = "Lens 📷"
description = "An EfficientNetB2 feature extractor computer vision model to classify images of food into 101 different classes"
# `article` is passed to gr.Interface below; it must be defined here, otherwise
# the script stops with a NameError before the app is ever launched.
article = "Built with PyTorch and Gradio."

# Create examples list from "examples/" directory.
# Hidden entries (e.g. macOS ".DS_Store") are not images, so they are skipped;
# sorting keeps the example order stable, since os.listdir order is arbitrary.
example_list = [
    ["examples/" + example]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
]

# Create Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the app!
demo.launch()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
apple_pie | ||
baby_back_ribs | ||
baklava | ||
beef_carpaccio | ||
beef_tartare | ||
beet_salad | ||
beignets | ||
bibimbap | ||
bread_pudding | ||
breakfast_burrito | ||
bruschetta | ||
caesar_salad | ||
cannoli | ||
caprese_salad | ||
carrot_cake | ||
ceviche | ||
cheese_plate | ||
cheesecake | ||
chicken_curry | ||
chicken_quesadilla | ||
chicken_wings | ||
chocolate_cake | ||
chocolate_mousse | ||
churros | ||
clam_chowder | ||
club_sandwich | ||
crab_cakes | ||
creme_brulee | ||
croque_madame | ||
cup_cakes | ||
deviled_eggs | ||
donuts | ||
dumplings | ||
edamame | ||
eggs_benedict | ||
escargots | ||
falafel | ||
filet_mignon | ||
fish_and_chips | ||
foie_gras | ||
french_fries | ||
french_onion_soup | ||
french_toast | ||
fried_calamari | ||
fried_rice | ||
frozen_yogurt | ||
garlic_bread | ||
gnocchi | ||
greek_salad | ||
grilled_cheese_sandwich | ||
grilled_salmon | ||
guacamole | ||
gyoza | ||
hamburger | ||
hot_and_sour_soup | ||
hot_dog | ||
huevos_rancheros | ||
hummus | ||
ice_cream | ||
lasagna | ||
lobster_bisque | ||
lobster_roll_sandwich | ||
macaroni_and_cheese | ||
macarons | ||
miso_soup | ||
mussels | ||
nachos | ||
omelette | ||
onion_rings | ||
oysters | ||
pad_thai | ||
paella | ||
pancakes | ||
panna_cotta | ||
peking_duck | ||
pho | ||
pizza | ||
pork_chop | ||
poutine | ||
prime_rib | ||
pulled_pork_sandwich | ||
ramen | ||
ravioli | ||
red_velvet_cake | ||
risotto | ||
samosa | ||
sashimi | ||
scallops | ||
seaweed_salad | ||
shrimp_and_grits | ||
spaghetti_bolognese | ||
spaghetti_carbonara | ||
spring_rolls | ||
steak | ||
strawberry_shortcake | ||
sushi | ||
tacos | ||
takoyaki | ||
tiramisu | ||
tuna_tartare | ||
waffles |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
import torchvision | ||
|
||
from torch import nn | ||
|
||
|
||
def create_effnetb2_model(num_classes: int = 3, seed: int = 42):
    """Build an EfficientNetB2 feature extractor and its matching image transforms.

    Args:
        num_classes (int, optional): number of outputs of the replacement
            classifier head. Defaults to 3.
        seed (int, optional): value given to ``torch.manual_seed`` right before
            the new head is created. Defaults to 42.

    Returns:
        model (torch.nn.Module): EffNetB2 with every pre-existing parameter
            frozen and a new classifier head.
        transforms: the transforms object returned by the weights' ``transforms()``.
    """
    # Default EffNetB2 weights, the model built from them, and their transforms
    pretrained_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    model = torchvision.models.efficientnet_b2(weights=pretrained_weights)
    transforms = pretrained_weights.transforms()

    # Freeze everything that exists so far; the head added below stays trainable
    for parameter in model.parameters():
        parameter.requires_grad_(False)

    # Seed first so the new head's initial weights are reproducible
    torch.manual_seed(seed)
    head = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=1408, out_features=num_classes),
    )
    model.classifier = head

    return model, transforms
3 changes: 3 additions & 0 deletions
3
models/09_pretrained_effnetb2_feature_extractor_food101_20_percent.pth
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
torch==2.2.1 | ||
torchvision==0.17.1 | ||
gradio==4.22.0 |