
Commit

initial commit
mohammadhprp committed Aug 13, 2024
1 parent 36c5593 commit a251afc
Showing 9 changed files with 226 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
models/09_pretrained_effnetb2_feature_extractor_food101_20_percent.pth filter=lfs diff=lfs merge=lfs -text
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
.DS_Store
2 changes: 1 addition & 1 deletion README.md
@@ -1,6 +1,6 @@
---
title: Lens
-emoji: 🐢
+emoji: 📷
colorFrom: yellow
colorTo: indigo
sdk: gradio
80 changes: 80 additions & 0 deletions app.py
@@ -0,0 +1,80 @@
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch

from model import create_effnetb2_model
from timeit import default_timer as timer
from typing import Tuple, Dict

# Setup class names
with open("class_names.txt", "r") as f: # reading them in from class_names.txt
    class_names = [food_name.strip() for food_name in f.readlines()]

### 2. Model and transforms preparation ###

# Create model
effnetb2, effnetb2_transforms = create_effnetb2_model(
    num_classes=len(class_names),
)

# Load saved weights
effnetb2.load_state_dict(
    torch.load(
        f="models/09_pretrained_effnetb2_feature_extractor_food101_20_percent.pth",
        map_location=torch.device("cpu"),  # load to CPU
    )
)

### 3. Predict function ###

# Create predict function
def predict(img) -> Tuple[Dict, float]:
"""Transforms and performs a prediction on img and returns prediction and time taken.
"""
# Start the timer
start_time = timer()

# Transform the target image and add a batch dimension
img = effnetb2_transforms(img).unsqueeze(0)

# Put model into evaluation mode and turn on inference mode
effnetb2.eval()
with torch.inference_mode():
# Pass the transformed image through the model and turn the prediction logits into prediction probabilities
pred_probs = torch.softmax(effnetb2(img), dim=1)

# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}

# Calculate the prediction time
pred_time = round(timer() - start_time, 5)

# Return the prediction dictionary and prediction time
return pred_labels_and_probs, pred_time

### 4. Gradio app ###

# Create title, description and article strings
title = "Lens 📷"
description = "An EfficientNetB2 feature extractor computer vision model to classify images of food into 101 different classes."
article = "An EfficientNetB2 feature extractor trained on 20% of the Food101 dataset."  # defined here so the article argument passed to gr.Interface below is not an undefined name; the wording is a placeholder

# Create examples list from "examples/" directory
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Create Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the app!
demo.launch()
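Not part of this commit, but for local testing outside Hugging Face Spaces the launch call can be given explicit server options; a minimal sketch using standard Gradio parameters:

# Illustrative alternative, not in the commit: bind to all interfaces on Gradio's default port
demo.launch(server_name="0.0.0.0", server_port=7860)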
101 changes: 101 additions & 0 deletions class_names.txt
@@ -0,0 +1,101 @@
apple_pie
baby_back_ribs
baklava
beef_carpaccio
beef_tartare
beet_salad
beignets
bibimbap
bread_pudding
breakfast_burrito
bruschetta
caesar_salad
cannoli
caprese_salad
carrot_cake
ceviche
cheese_plate
cheesecake
chicken_curry
chicken_quesadilla
chicken_wings
chocolate_cake
chocolate_mousse
churros
clam_chowder
club_sandwich
crab_cakes
creme_brulee
croque_madame
cup_cakes
deviled_eggs
donuts
dumplings
edamame
eggs_benedict
escargots
falafel
filet_mignon
fish_and_chips
foie_gras
french_fries
french_onion_soup
french_toast
fried_calamari
fried_rice
frozen_yogurt
garlic_bread
gnocchi
greek_salad
grilled_cheese_sandwich
grilled_salmon
guacamole
gyoza
hamburger
hot_and_sour_soup
hot_dog
huevos_rancheros
hummus
ice_cream
lasagna
lobster_bisque
lobster_roll_sandwich
macaroni_and_cheese
macarons
miso_soup
mussels
nachos
omelette
onion_rings
oysters
pad_thai
paella
pancakes
panna_cotta
peking_duck
pho
pizza
pork_chop
poutine
prime_rib
pulled_pork_sandwich
ramen
ravioli
red_velvet_cake
risotto
samosa
sashimi
scallops
seaweed_salad
shrimp_and_grits
spaghetti_bolognese
spaghetti_carbonara
spring_rolls
steak
strawberry_shortcake
sushi
tacos
takoyaki
tiramisu
tuna_tartare
waffles
Binary file added examples/04-pizza.jpg
36 changes: 36 additions & 0 deletions model.py
@@ -0,0 +1,36 @@
import torch
import torchvision

from torch import nn


def create_effnetb2_model(num_classes: int = 3,
                          seed: int = 42):
    """Creates an EfficientNetB2 feature extractor model and transforms.

    Args:
        num_classes (int, optional): number of classes in the classifier head. Defaults to 3.
        seed (int, optional): random seed value. Defaults to 42.

    Returns:
        model (torch.nn.Module): EffNetB2 feature extractor model.
        transforms (torchvision.transforms): EffNetB2 image transforms.
    """
    # Create EffNetB2 pretrained weights, transforms and model
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze all layers in the base model
    for param in model.parameters():
        param.requires_grad = False

    # Change the classifier head, seeding for reproducibility
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=1408, out_features=num_classes),
    )

    return model, transforms
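As a side note (not part of the commit), the factory can be sanity-checked with a dummy forward pass; this sketch assumes the 101-class setup used in app.py, and 288x288 is the input size the default EffNetB2 transforms produce:

import torch

from model import create_effnetb2_model

# Build a 101-class EffNetB2 feature extractor and its paired transforms
model, transforms = create_effnetb2_model(num_classes=101)
model.eval()

# Dummy batch at the input size the default transforms produce
dummy_batch = torch.randn(1, 3, 288, 288)
with torch.inference_mode():
    logits = model(dummy_batch)

print(logits.shape)  # expected: torch.Size([1, 101])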
Binary file added models/09_pretrained_effnetb2_feature_extractor_food101_20_percent.pth (Git LFS file not shown)
3 changes: 3 additions & 0 deletions requirements.txt
@@ -0,0 +1,3 @@
torch==2.2.1
torchvision==0.17.1
gradio==4.22.0
