Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -59,7 +59,7 @@ model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
|
| 59 |
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
|
| 60 |
model_id_or_path,
|
| 61 |
revision="fp16",
|
| 62 |
-
torch_dtype=torch.
|
| 63 |
use_auth_token=auth_token
|
| 64 |
)
|
| 65 |
#self.register_buffer('n_', ...)
|
|
@@ -87,7 +87,7 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
| 87 |
elif(radio == "type what to keep"):
|
| 88 |
img = transform(dict["image"]).squeeze(0)
|
| 89 |
word_masks = [word_mask]
|
| 90 |
-
with torch.no_grad():
|
| 91 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
| 92 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
| 93 |
filename = f"{uuid.uuid4()}.png"
|
|
@@ -101,7 +101,7 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
| 101 |
else:
|
| 102 |
img = transform(dict["image"]).unsqueeze(0)
|
| 103 |
word_masks = [word_mask]
|
| 104 |
-
with torch.no_grad():
|
| 105 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
| 106 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
| 107 |
filename = f"{uuid.uuid4()}.png"
|
|
|
|
| 59 |
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
|
| 60 |
model_id_or_path,
|
| 61 |
revision="fp16",
|
| 62 |
+
torch_dtype=torch.float16, #float16
|
| 63 |
use_auth_token=auth_token
|
| 64 |
)
|
| 65 |
#self.register_buffer('n_', ...)
|
|
|
|
| 87 |
elif(radio == "type what to keep"):
|
| 88 |
img = transform(dict["image"]).squeeze(0)
|
| 89 |
word_masks = [word_mask]
|
| 90 |
+
with torch.cuda.amp.autocast(): #with torch.no_grad():
|
| 91 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
| 92 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
| 93 |
filename = f"{uuid.uuid4()}.png"
|
|
|
|
| 101 |
else:
|
| 102 |
img = transform(dict["image"]).unsqueeze(0)
|
| 103 |
word_masks = [word_mask]
|
| 104 |
+
with torch.cuda.amp.autocast(): #with torch.no_grad():
|
| 105 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
| 106 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
| 107 |
filename = f"{uuid.uuid4()}.png"
|