@@ -1391,7 +1391,6 @@ def aggregate_attention(
1391
1391
):
1392
1392
out = []
1393
1393
attention_maps = self .get_average_attention ()
1394
- attention_maps = self .controller .attention_store
1395
1394
num_pixels = res ** 2
1396
1395
for location in from_where :
1397
1396
for item in attention_maps [f"{ location } _{ 'cross' if is_cross else 'self' } " ]:
@@ -1474,58 +1473,6 @@ def save_cross_attention_vis(self, prompt, attention_maps, path):
1474
1473
vis = ptp_utils .view_images (np .stack (images , axis = 0 ))
1475
1474
vis .save (path )
1476
1475
1477
def show_cross_attention(
    self,
    prompts,
    attention_store: AttentionStore,
    res: int,
    from_where: List[str],
    select: int = 0,
):
    """Build a labelled cross-attention visualisation for one prompt.

    The prompt ``prompts[select]`` is tokenised with ``self.tokenizer``; for
    every resulting token id, the matching slice along the last axis of the
    aggregated cross-attention maps is rescaled so its maximum becomes 255,
    replicated to three channels, resized to 256x256 and captioned with the
    decoded token text.

    Args:
        prompts: Indexable collection of prompts; only ``prompts[select]``
            is tokenised here.
        attention_store: Forwarded unchanged to ``self.aggregate_attention``.
        res: Forwarded unchanged to ``self.aggregate_attention``.
        from_where: Forwarded unchanged to ``self.aggregate_attention``.
        select: Index of the prompt to visualise.

    Returns:
        Whatever ``ptp_utils.view_images`` returns for the stacked,
        captioned per-token images.
    """
    token_ids = self.tokenizer.encode(prompts[select])
    decode = self.tokenizer.decode
    # `True` selects the cross-attention maps (as opposed to self-attention).
    aggregated = self.aggregate_attention(
        prompts, attention_store, res, from_where, True, select
    )

    def render(position, token_id):
        # presumably a torch tensor (uses .unsqueeze/.expand/.numpy) - confirm
        heat = aggregated[:, :, position]
        heat = 255 * heat / heat.max()
        # Replicate the single-channel map across a trailing RGB axis.
        rgb = heat.unsqueeze(-1).expand(*heat.shape, 3)
        pixels = rgb.numpy().astype(np.uint8)
        pixels = np.array(Image.fromarray(pixels).resize((256, 256)))
        return ptp_utils.text_under_image(pixels, decode(int(token_id)))

    tiles = [render(position, token_id) for position, token_id in enumerate(token_ids)]
    return ptp_utils.view_images(np.stack(tiles, axis=0))
1500
-
1501
def show_self_attention_comp(
    self,
    attention_store: AttentionStore,
    res: int,
    from_where: List[str],
    max_com=10,
    select: int = 0,
):
    """Visualise the leading SVD components of the aggregated self-attention.

    The aggregated self-attention is reshaped to a ``(res**2, res**2)``
    matrix, each row is mean-centred, and the matrix is decomposed with
    ``np.linalg.svd``. The first ``max_com`` rows of ``vh`` are each reshaped
    to ``(res, res)``, min-max scaled towards 0..255, replicated to three
    channels, resized to 256x256 and concatenated side by side.

    Args:
        attention_store: Forwarded unchanged to ``self.aggregate_attention``.
        res: Side length of the attention map; the aggregated map is
            reshaped to ``(res**2, res**2)``.
        from_where: Forwarded unchanged to ``self.aggregate_attention``.
        max_com: Number of components (rows of ``vh``) to render.
        select: Forwarded unchanged to ``self.aggregate_attention``.

    Returns:
        Whatever ``ptp_utils.view_images`` returns for the concatenated
        component images. Previously the result was discarded and the method
        returned ``None``; returning it matches ``show_cross_attention``.

    Raises:
        IndexError: If ``max_com`` exceeds the number of rows in ``vh``.
    """
    # NOTE(review): show_cross_attention passes `prompts` as the first
    # argument to aggregate_attention; this call does not. The signature is
    # not visible from here - confirm which call form is correct.
    attention_maps = (
        self.aggregate_attention(attention_store, res, from_where, False, select)
        .numpy()
        .reshape((res**2, res**2))
    )
    # Centre each row before decomposing; only the right singular vectors
    # are used below.
    centered = attention_maps - np.mean(attention_maps, axis=1, keepdims=True)
    _, _, vh = np.linalg.svd(centered)

    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        # Min-max scale so the component spans 0..255 before the uint8 cast.
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        images.append(np.array(image))
    return ptp_utils.view_images(np.concatenate(images, axis=1))
1527
-
1528
-
1529
1476
class P2PCrossAttnProcessor :
1530
1477
def __init__ (self , controller , place_in_unet ):
1531
1478
super ().__init__ ()
0 commit comments