zhang-ziang committed
Commit · 74503df
1 Parent(s): 6965bae
infer aug
app.py
CHANGED
@@ -8,6 +8,7 @@ import os
 import matplotlib.pyplot as plt
 import io
 from PIL import Image
+import random
 import rembg
 from typing import Any
 import torch.nn.functional as F
@@ -97,6 +98,37 @@ def remove_background(image: Image,
     image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
     return image
 
+def random_crop(image, crop_scale=(0.8, 0.95)):
+    """
+    Randomly crop the image.
+    image (PIL.Image.Image): input image.
+    crop_scale (tuple): (min_scale, max_scale).
+    """
+    assert isinstance(image, Image.Image), "input must be PIL.Image.Image"
+    assert len(crop_scale) == 2 and 0 < crop_scale[0] <= crop_scale[1] <= 1
+
+    width, height = image.size
+
+    # compute the crop width and height
+    crop_width = random.randint(int(width * crop_scale[0]), int(width * crop_scale[1]))
+    crop_height = random.randint(int(height * crop_scale[0]), int(height * crop_scale[1]))
+
+    # pick a random top-left corner for the crop
+    left = random.randint(0, width - crop_width)
+    top = random.randint(0, height - crop_height)
+
+    # crop the image
+    cropped_image = image.crop((left, top, left + crop_width, top + crop_height))
+
+    return cropped_image
+
+def get_crop_images(img, num=3):
+    cropped_images = []
+    for i in range(num):
+        cropped_images.append(random_crop(img))
+    return cropped_images
+
+
 def get_3angle(image):
 
     # image = Image.open(image_path).convert('RGB')
@@ -108,7 +140,7 @@ def get_3angle(image):
     gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
     gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
     gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
-    confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0]
+    confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0][0]
     angles = torch.zeros(4)
     angles[0] = gaus_ax_pred
     angles[1] = gaus_pl_pred - 90
@@ -116,18 +148,86 @@ def get_3angle(image):
     angles[3] = confidence
     return angles
 
+def remove_outliers_and_average(tensor, threshold=1.5):
+    assert tensor.dim() == 1, "dimension of input Tensor must equal to 1"
+
+    q1 = torch.quantile(tensor, 0.25)
+    q3 = torch.quantile(tensor, 0.75)
+    iqr = q3 - q1
+
+    lower_bound = q1 - threshold * iqr
+    upper_bound = q3 + threshold * iqr
+
+    non_outliers = tensor[(tensor >= lower_bound) & (tensor <= upper_bound)]
+
+    if len(non_outliers) == 0:
+        return tensor.mean().item()
+
+    return non_outliers.mean().item()
+
+
+def remove_outliers_and_average_circular(tensor, threshold=1.5):
+    assert tensor.dim() == 1, "dimension of input Tensor must equal to 1"
+
+    # map the angles to points on the unit circle
+    radians = tensor * torch.pi / 180.0
+    x_coords = torch.cos(radians)
+    y_coords = torch.sin(radians)
+
+    # mean direction vector
+    mean_x = torch.mean(x_coords)
+    mean_y = torch.mean(y_coords)
+
+    differences = torch.sqrt((x_coords - mean_x) * (x_coords - mean_x) + (y_coords - mean_y) * (y_coords - mean_y))
+
+    # quartiles and IQR of the distances to the mean direction
+    q1 = torch.quantile(differences, 0.25)
+    q3 = torch.quantile(differences, 0.75)
+    iqr = q3 - q1
+
+    # lower and upper bounds
+    lower_bound = q1 - threshold * iqr
+    upper_bound = q3 + threshold * iqr
+
+    # keep the non-outliers
+    non_outliers = tensor[(differences >= lower_bound) & (differences <= upper_bound)]
+
+    if len(non_outliers) == 0:
+        mean_angle = torch.atan2(mean_y, mean_x) * 180.0 / torch.pi
+        mean_angle = (mean_angle + 360) % 360
+        return mean_angle  # no non-outliers left: fall back to the circular mean of all angles
+
+    # recompute the mean vector over the non-outliers only
+    radians = non_outliers * torch.pi / 180.0
+    x_coords = torch.cos(radians)
+    y_coords = torch.sin(radians)
+
+    mean_x = torch.mean(x_coords)
+    mean_y = torch.mean(y_coords)
+
+    mean_angle = torch.atan2(mean_y, mean_x) * 180.0 / torch.pi
+    mean_angle = (mean_angle + 360) % 360
+
+    return mean_angle
+
 def get_3angle_infer_aug(image):
 
     # image = Image.open(image_path).convert('RGB')
+    image = get_crop_images(image, num=6)
     image_inputs = val_preprocess(images = image)
     image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
     with torch.no_grad():
         dino_pred = dino(image_inputs)
 
-    gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
-    gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
-    gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
-
+    gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1).to(torch.float32)
+    gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1).to(torch.float32)
+    gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1).to(torch.float32)
+
+    gaus_ax_pred = remove_outliers_and_average_circular(gaus_ax_pred)
+    gaus_pl_pred = remove_outliers_and_average(gaus_pl_pred)
+    gaus_ro_pred = remove_outliers_and_average(gaus_ro_pred)
+
+    confidence = torch.mean(F.softmax(dino_pred[:, -2:], dim=-1), dim=0)[0]
     angles = torch.zeros(4)
     angles[0] = gaus_ax_pred
     angles[1] = gaus_pl_pred - 90
@@ -221,7 +321,7 @@ def infer_func(img, do_rm_bkg, do_infer_aug):
 
     res_img = figure_to_img(fig)
     # axis_model = "axis.obj"
-    return [res_img, float(angles[0]), float(angles[1]), float(angles[2]), float(angles[3])]
+    return [res_img, round(float(angles[0]), 2), round(float(angles[1]), 2), round(float(angles[2]), 2), round(float(angles[3]), 2)]
 
 server = gr.Interface(
     flagging_mode='never',
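For context: the new get_3angle_infer_aug path predicts angles on several random crops of the input and then aggregates the per-crop predictions with the IQR-based helpers added above, using a circular mean for the azimuth. The following is a small standalone sketch (not part of app.py) that re-implements just that aggregation step on made-up per-crop values; the helper names iqr_mean / circular_iqr_mean and the sample numbers are illustrative only.

# Standalone sketch (not part of app.py): how the inference-time
# augmentation aggregates per-crop angle predictions. Helper names and
# the sample values below are illustrative, not taken from the Space.
import torch

def iqr_mean(t, threshold=1.5):
    # Mirror of remove_outliers_and_average: drop values outside
    # [Q1 - threshold*IQR, Q3 + threshold*IQR], then average the rest.
    q1, q3 = torch.quantile(t, 0.25), torch.quantile(t, 0.75)
    iqr = q3 - q1
    kept = t[(t >= q1 - threshold * iqr) & (t <= q3 + threshold * iqr)]
    return (kept if len(kept) > 0 else t).mean().item()

def circular_iqr_mean(deg, threshold=1.5):
    # Circular variant used for the azimuth: embed angles on the unit
    # circle, discard points far from the mean direction, then take the
    # circular mean of what remains.
    rad = deg * torch.pi / 180.0
    x, y = torch.cos(rad), torch.sin(rad)
    dist = torch.sqrt((x - x.mean()) ** 2 + (y - y.mean()) ** 2)
    q1, q3 = torch.quantile(dist, 0.25), torch.quantile(dist, 0.75)
    iqr = q3 - q1
    mask = (dist >= q1 - threshold * iqr) & (dist <= q3 + threshold * iqr)
    kept = deg[mask] if mask.any() else deg
    rad = kept * torch.pi / 180.0
    ang = torch.atan2(torch.sin(rad).mean(), torch.cos(rad).mean()) * 180.0 / torch.pi
    return float((ang + 360) % 360)

# Six hypothetical per-crop azimuth predictions (degrees); one bad crop at 200.
az_preds = torch.tensor([358.0, 1.0, 359.0, 2.0, 0.0, 200.0])
print(circular_iqr_mean(az_preds))                       # close to 0; the 200-degree crop is rejected
print(iqr_mean(torch.tensor([10.0, 12.0, 11.0, 80.0])))  # 11.0; the 80 is rejected

The circular mean avoids the wrap-around problem at 0/360 degrees that a plain arithmetic mean of azimuths would have, which is why the azimuth uses the circular helper while polar and rotation angles use the linear one.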