This repository was archived by the owner on Mar 18, 2025. It is now read-only.

Commit 135cd4a

Committed Sep 21, 2023
Fix attention aggregation.
- Fix attention aggregation (for visualization).
- Removed some unused code.
1 parent: 9cac7d3 · commit: 135cd4a

File tree

1 file changed (+0, -53 lines)

 

train.py (-53 lines)
```diff
@@ -1391,7 +1391,6 @@ def aggregate_attention(
     ):
         out = []
         attention_maps = self.get_average_attention()
-        attention_maps = self.controller.attention_store
         num_pixels = res**2
         for location in from_where:
             for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
```
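The deleted line overwrote the step-averaged maps with the raw accumulator, so aggregation (and the visualizations built on it) operated on summed rather than averaged attention. As a minimal sketch of the averaging the fix restores, assuming the AttentionStore here mirrors the upstream prompt-to-prompt one, where `attention_store` holds per-layer maps summed across denoising steps and `cur_step` counts those steps:

```python
# Sketch only: assumes an AttentionStore shaped like the upstream
# prompt-to-prompt implementation, where `attention_store` accumulates
# per-layer attention maps summed over denoising steps.
class AttentionStore:
    def __init__(self):
        self.cur_step = 0          # number of denoising steps seen so far
        self.attention_store = {}  # e.g. {"up_cross": [map, ...], ...}

    def get_average_attention(self):
        # Divide each accumulated map by the step count, so callers such as
        # aggregate_attention see a per-step average, not a running sum.
        return {
            key: [item / self.cur_step for item in maps]
            for key, maps in self.attention_store.items()
        }
```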
```diff
@@ -1474,58 +1473,6 @@ def save_cross_attention_vis(self, prompt, attention_maps, path):
         vis = ptp_utils.view_images(np.stack(images, axis=0))
         vis.save(path)
 
-    def show_cross_attention(
-        self,
-        prompts,
-        attention_store: AttentionStore,
-        res: int,
-        from_where: List[str],
-        select: int = 0,
-    ):
-        tokens = self.tokenizer.encode(prompts[select])
-        decoder = self.tokenizer.decode
-        attention_maps = self.aggregate_attention(
-            prompts, attention_store, res, from_where, True, select
-        )
-        images = []
-        for i in range(len(tokens)):
-            image = attention_maps[:, :, i]
-            image = 255 * image / image.max()
-            image = image.unsqueeze(-1).expand(*image.shape, 3)
-            image = image.numpy().astype(np.uint8)
-            image = np.array(Image.fromarray(image).resize((256, 256)))
-            image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
-            images.append(image)
-        return ptp_utils.view_images(np.stack(images, axis=0))
-
-    def show_self_attention_comp(
-        self,
-        attention_store: AttentionStore,
-        res: int,
-        from_where: List[str],
-        max_com=10,
-        select: int = 0,
-    ):
-        attention_maps = (
-            self.aggregate_attention(attention_store, res, from_where, False, select)
-            .numpy()
-            .reshape((res**2, res**2))
-        )
-        u, s, vh = np.linalg.svd(
-            attention_maps - np.mean(attention_maps, axis=1, keepdims=True)
-        )
-        images = []
-        for i in range(max_com):
-            image = vh[i].reshape(res, res)
-            image = image - image.min()
-            image = 255 * image / image.max()
-            image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
-            image = Image.fromarray(image).resize((256, 256))
-            image = np.array(image)
-            images.append(image)
-        ptp_utils.view_images(np.concatenate(images, axis=1))
-
-
 class P2PCrossAttnProcessor:
     def __init__(self, controller, place_in_unet):
         super().__init__()
```
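Two notes on the deleted helpers. `show_cross_attention` duplicated the rendering already done by `save_cross_attention_vis` above it, and `show_self_attention_comp` called `aggregate_attention` without the `prompts` argument the other call site passes, so neither appears to have been live code. For reference, the SVD trick the latter used can be reproduced standalone; a hedged sketch with made-up data (`res` and `attn` are placeholders, not values from this repo):

```python
import numpy as np

# Stand-in for an aggregated (res**2, res**2) self-attention matrix;
# real maps would come from aggregate_attention(..., is_cross=False, ...).
res = 16
rng = np.random.default_rng(0)
attn = rng.random((res**2, res**2))
attn /= attn.sum(axis=1, keepdims=True)  # rows act like attention distributions

# Mean-center each row and take the top right-singular vectors: each one
# reshapes into a res x res spatial component of the self-attention.
u, s, vh = np.linalg.svd(attn - attn.mean(axis=1, keepdims=True))
components = vh[:10].reshape(10, res, res)  # top-10 maps, as in max_com=10
```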
