In [1]:
import os
import time
import shutil
import numpy as np
import rasterio
from rasterio.windows import from_bounds
import cv2
import ee
import geemap
from google.colab import drive
from sklearn.model_selection import train_test_split

os.system('pip install -q geedim geemap')

drive.mount('/content/drive', force_remount=True)

# CONFIGURATION
MY_PROJECT_ID = '[REDACTED_FOR_SECURITY]'
ASSET_ID = '[REDACTED_FOR_SECURITY]'
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_1/'

if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)

try:
    ee.Initialize(project=MY_PROJECT_ID)
    print(f"Earth Engine Initialized with Project: {MY_PROJECT_ID}")
except Exception as e:
    ee.Authenticate()
    ee.Initialize(project=MY_PROJECT_ID)
    print(f"Earth Engine Authenticated & Initialized.")

PATCH_SIZE = 224
S2_SCALE = 5000.0
TIME_WINDOWS = [
    ('2024-10-15', '2024-11-15'),
    ('2025-01-01', '2025-01-31'),
    ('2025-02-15', '2025-03-15')
]

def generate_satmae_scratch_data():
    print("Starting SatMAE Scratch Data Generation...")

    # 1. Mask
    mask_img = ee.Image(ASSET_ID)
    roi_geom = mask_img.geometry()
    mask_file = 'local_mask_satmae.tif'
    if not os.path.exists(mask_file):
        geemap.download_ee_image(mask_img, mask_file, region=roi_geom, scale=10, crs='EPSG:4326', overwrite=True)

    with rasterio.open(mask_file) as src:
        b = src.bounds
        cx, cy = (b.left + b.right)/2, (b.bottom + b.top)/2
        offset = 0.04
        window = from_bounds(cx-offset, cy-offset, cx+offset, cy+offset, src.transform)
        mask = src.read(1, window=window)
        mask = np.where(mask > 0, 1.0, 0.0).astype(np.float32)
        target_h, target_w = mask.shape
        small_roi = ee.Geometry.BBox(cx-offset, cy-offset, cx+offset, cy+offset)

    # 2. Imagery
    stack = []
    for i, (start, end) in enumerate(TIME_WINDOWS):
        fname = f'satmae_time_{i}.tif'
        attempts = 0
        while not os.path.exists(fname) and attempts < 3:
            try:
                print(f"Downloading T{i+1}: {start} to {end}...")
                img = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED').filterBounds(small_roi).filterDate(start, end).median().select(['B2','B3','B4','B8','B11','B12'])
                geemap.download_ee_image(img, fname, region=small_roi, scale=10, crs='EPSG:4326', overwrite=True)
            except:
                attempts += 1
                time.sleep(2)

        if not os.path.exists(fname):
            with rasterio.open(mask_file) as src:
                 profile = src.profile
                 profile.update(count=6, dtype=rasterio.float32)
                 with rasterio.open(fname, 'w', **profile) as dst:
                     dst.write(np.zeros((6, target_h, target_w), dtype=np.float32))

        with rasterio.open(fname) as src:
            arr = src.read()
            arr = np.transpose(arr, (1, 2, 0))
            if arr.shape[:2] != (target_h, target_w):
                arr = cv2.resize(arr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
            arr = np.clip(arr / S2_SCALE, 0, 1).astype(np.float32)
            stack.append(arr)

    # 3. Tiling
    full_cube = np.stack(stack, axis=2)
    x_out, y_out = [], []
    stride = PATCH_SIZE

    for y in range(0, target_h, stride):
        for x in range(0, target_w, stride):
            img_p = full_cube[y:y+stride, x:x+stride]
            mask_p = mask[y:y+stride, x:x+stride]
            if img_p.shape[0] != PATCH_SIZE or img_p.shape[1] != PATCH_SIZE: continue
            if np.mean(mask_p) < 0.01 or np.isnan(img_p).any(): continue
            x_out.append(img_p)
            y_out.append(mask_p)

    X = np.array(x_out, dtype=np.float32).transpose(0, 4, 3, 1, 2)
    y = np.array(y_out, dtype=np.float32)[:, None, :, :]

    print(f"Data Generated: {X.shape}")
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

    np.save(os.path.join(SAVE_DIR, 'train_x.npy'), X_train)
    np.save(os.path.join(SAVE_DIR, 'train_y.npy'), y_train)
    np.save(os.path.join(SAVE_DIR, 'val_x.npy'), X_val)
    np.save(os.path.join(SAVE_DIR, 'val_y.npy'), y_val)
    print("Done.")

generate_satmae_scratch_data()
Mounted at /content/drive
Earth Engine Authenticated & Initialized.
Starting SatMAE Scratch Data Generation...
/usr/local/lib/python3.12/dist-packages/geemap/common.py:12471: FutureWarning: 'BaseImage' is deprecated and will be removed in a future release.  Please use the 'ee.Image.gd' accessor instead.
  img = gd.download.BaseImage(image)
...tmae-2026/assets/Punjab_Mask_2024_NEW:   0%|          |0/585 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:googleapiclient.http:Sleeping 0.67 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/local-dialect-484618-b9/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 0.50 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/local-dialect-484618-b9/thumbnails?fields=name&alt=json, after 429
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:googleapiclient.http:Sleeping 1.86 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/local-dialect-484618-b9/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 1.30 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/local-dialect-484618-b9/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 0.29 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/local-dialect-484618-b9/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 1.50 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/local-dialect-484618-b9/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 1.06 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/local-dialect-484618-b9/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 1.27 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/local-dialect-484618-b9/thumbnails?fields=name&alt=json, after 429
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:googleapiclient.http:Sleeping 1.36 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/local-dialect-484618-b9/thumbnails?fields=name&alt=json, after 429
WARNING:googleapiclient.http:Sleeping 1.51 seconds before retry 1 of 5 for request: POST https://earthengine.googleapis.com/v1/projects/local-dialect-484618-b9/thumbnails?fields=name&alt=json, after 429
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
/usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'projects/satmae-2026/assets/Punjab_Mask_2024_NEW'.
  return STACClient().get(self.id)
Downloading T1: 2024-10-15 to 2024-11-15...
  0%|          |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
/usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'None'.
  return STACClient().get(self.id)
Downloading T2: 2025-01-01 to 2025-01-31...
  0%|          |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Downloading T3: 2025-02-15 to 2025-03-15...
  0%|          |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Data Generated: (9, 6, 3, 224, 224)
Done.
In [2]:
import torch
import torch.nn as nn

class SatMAEPatchEmbed(nn.Module):
    def __init__(self, in_chans=6, embed_dim=768, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, T, H, W = x.shape
        x = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = x.reshape(B, T, -1, x.shape[-1])
        return x

class SatMAEBackbone(nn.Module):
    def __init__(self, num_frames=3, in_chans=6, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.patch_embed = SatMAEPatchEmbed(in_chans=in_chans, embed_dim=embed_dim)
        num_patches = (224 // 16) ** 2

        self.pos_embed = nn.Parameter(torch.randn(1, 1, num_patches + 1, embed_dim) * 0.02)
        self.time_embed = nn.Parameter(torch.randn(1, num_frames, 1, embed_dim) * 0.02)
        self.cls_token = nn.Parameter(torch.randn(1, 1, 1, embed_dim) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim*4, activation="gelu", batch_first=True, norm_first=True)
        self.blocks = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.patch_embed(x)
        B, T, N, D = x.shape
        x = x + self.time_embed
        x = x.reshape(B, T*N, D)
        spatial_pos = self.pos_embed[:, :, 1:, :].expand(B, T, -1, -1).reshape(B, T*N, D)
        x = x + spatial_pos
        cls_token = self.cls_token.expand(B, -1, -1, -1).reshape(B, 1, D) + self.pos_embed[:, :, 0, :].expand(B, 1, D)
        x = torch.cat((cls_token, x), dim=1)
        x = self.blocks(x)
        x = self.norm(x)
        return x

class SatMAESegmentation(nn.Module):
    def __init__(self, num_frames=3, in_chans=6, embed_dim=768):
        super().__init__()
        self.num_frames = num_frames
        self.backbone = SatMAEBackbone(num_frames=num_frames, in_chans=in_chans, embed_dim=embed_dim)
        self.temporal_agg = nn.Conv2d(embed_dim * num_frames, embed_dim, kernel_size=1)

        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2), nn.Conv2d(embed_dim, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(256, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(64, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.GELU(),
            nn.Conv2d(32, 1, 1)
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.backbone(x)[:, 1:, :]
        B, L, D = features.shape
        H_p = 14
        features = features.view(B, self.num_frames, H_p, H_p, D)
        features = features.permute(0, 1, 4, 2, 3).reshape(B, self.num_frames * D, H_p, H_p)
        features = self.temporal_agg(features)
        return self.decoder(features)
In [4]:
import torch.optim as optim
import torch.optim.swa_utils as swa_utils
from torch.utils.data import TensorDataset, DataLoader, random_split
from scipy.ndimage import distance_transform_edt as distance
import glob

def apply_augmentation(x, y):
    if np.random.rand() > 0.5:
        x = torch.flip(x, [4]); y = torch.flip(y, [3])
    if np.random.rand() > 0.5:
        x = torch.flip(x, [3]); y = torch.flip(y, [2])
    k = np.random.randint(0, 4)
    x = torch.rot90(x, k, [3, 4]); y = torch.rot90(y, k, [2, 3])
    return x, y

def manage_rolling_checkpoints(save_dir, keep_k=5):
    files = sorted(glob.glob(os.path.join(save_dir, "epoch_*.pth")), key=os.path.getmtime)
    if len(files) > keep_k:
        for f in files[:-keep_k]:
            os.remove(f)

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__(); self.smooth = smooth
    def forward(self, inputs, targets):
        inputs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1)
        inter = (inputs * targets).sum()
        return 1 - (2. * inter + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)

class HausdorffDTLoss(nn.Module):
    def __init__(self, alpha=2.0):
        super().__init__(); self.alpha = alpha
    def forward(self, pred, gt):
        with torch.no_grad():
            gt_np = gt.cpu().numpy()
            dist_map = np.zeros_like(gt_np)
            for i in range(len(gt_np)):
                mask = (gt_np[i, 0] > 0.5).astype(np.uint8)
                if mask.sum() == 0: continue
                d_in = distance(mask); d_out = distance(1 - mask)
                dist_map[i, 0] = (d_out - d_in)
            dist_map = torch.tensor(dist_map, device=pred.device, dtype=torch.float32)
        probs = torch.sigmoid(pred)
        return torch.mean((probs - gt) ** 2 * (1 + self.alpha * torch.abs(dist_map)))

class CompoundLoss(nn.Module):
    def __init__(self):
        super().__init__(); self.dice = DiceLoss(); self.boundary = HausdorffDTLoss(alpha=2.0)
    def forward(self, p, t): return 0.7*self.dice(p, t) + 0.3*self.boundary(p, t)

SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_1/'
X_train = np.load(os.path.join(SAVE_DIR, 'train_x.npy'), mmap_mode='r')
y_train = np.load(os.path.join(SAVE_DIR, 'train_y.npy'), mmap_mode='r')
X_val = np.load(os.path.join(SAVE_DIR, 'val_x.npy'), mmap_mode='r')
y_val = np.load(os.path.join(SAVE_DIR, 'val_y.npy'), mmap_mode='r')

train_ds = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
val_ds = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))
/tmp/ipython-input-4243033762.py:58: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:206.)
  train_ds = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
In [ ]:
# Configuration
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_1/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_satmae_scratch.pth")
BATCH_SIZE = 8
EPOCHS = 5000
PATIENCE_TRIGGER = 150
SWA_DURATION = 50

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

model = SatMAESegmentation(num_frames=3, in_chans=6).to(device)
criterion = CompoundLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)

swa_model = swa_utils.AveragedModel(model)
swa_scheduler = swa_utils.SWALR(optimizer, swa_lr=5e-5)

# Loaders
train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, BATCH_SIZE, shuffle=False)

# Resume Logic
start_epoch = 0
best_loss = float('inf')
patience_counter = 0
swa_active = False
swa_epoch_counter = 0

if os.path.exists(CHECKPOINT_PATH):
    print("Resuming from Checkpoint...")
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    start_epoch = ckpt['epoch'] + 1
    best_loss = ckpt['best_loss']
    patience_counter = ckpt.get('patience_counter', 0)
    swa_active = ckpt.get('swa_active', False)
    swa_epoch_counter = ckpt.get('swa_epoch_counter', 0)

# Training Loop
print(f"Starting Training. Max Epochs: {EPOCHS}")

for ep in range(start_epoch, EPOCHS):
    model.train()
    train_loss = 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        x, y = apply_augmentation(x, y)

        optimizer.zero_grad()
        preds = model(x)
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x)
            val_loss += criterion(preds, y).item()

    avg_t = train_loss / len(train_loader)
    avg_v = val_loss / len(val_loader)
    status_msg = ""

    # --- SWA LOGIC ---
    if swa_active:
        # We are in the final 50 epochs of SWA
        swa_model.update_parameters(model)
        swa_scheduler.step()
        swa_epoch_counter += 1
        status_msg = f"SWA Mode ({swa_epoch_counter}/{SWA_DURATION})"

        if swa_epoch_counter >= SWA_DURATION:
            print("SWA Duration Complete. Saving Final Model and Stopping.")
            swa_utils.update_bn(train_loader, swa_model, device=device)
            torch.save(swa_model.state_dict(), os.path.join(SAVE_DIR, "satmae_scratch_swa_final.pth"))
            break  # <--- STOPS TRAINING HERE
    else:
        # Normal Training
        scheduler.step()

        if avg_v < best_loss:
            best_loss = avg_v
            patience_counter = 0
            torch.save(model.state_dict(), os.path.join(SAVE_DIR, "best_model.pth"))
            status_msg = "New Best Model!"
        else:
            patience_counter += 1
            status_msg = f"No Improvement ({patience_counter}/{PATIENCE_TRIGGER})"

        # Trigger Check
        if patience_counter >= PATIENCE_TRIGGER:
            print(f"Patience limit ({PATIENCE_TRIGGER}) reached. Triggering SWA for {SWA_DURATION} epochs.")
            swa_active = True
            swa_epoch_counter = 0

    print(f"Epoch {ep+1} | Train: {avg_t:.4f} | Val: {avg_v:.4f} | {status_msg}")

    # Save Progress
    epoch_ckpt = os.path.join(SAVE_DIR, f"epoch_{ep+1}.pth")
    torch.save(model.state_dict(), epoch_ckpt)
    manage_rolling_checkpoints(SAVE_DIR, keep_k=5)

    torch.save({
        'epoch': ep,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
        'patience_counter': patience_counter,
        'swa_active': swa_active,
        'swa_epoch_counter': swa_epoch_counter
    }, CHECKPOINT_PATH)
Device: cuda
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True
  warnings.warn(
Starting Training. Max Epochs: 5000
Epoch 1 | Train: 1.0452 | Val: 1.7441 | New Best Model!
Epoch 2 | Train: 1.0197 | Val: 1.3315 | New Best Model!
Epoch 3 | Train: 0.9954 | Val: 1.2687 | New Best Model!
Epoch 4 | Train: 0.9730 | Val: 1.1613 | New Best Model!
Epoch 5 | Train: 0.9534 | Val: 1.1272 | New Best Model!
Epoch 6 | Train: 0.9176 | Val: 0.9812 | New Best Model!
Epoch 7 | Train: 0.8920 | Val: 0.8549 | New Best Model!
Epoch 8 | Train: 0.8766 | Val: 0.8088 | New Best Model!
Epoch 9 | Train: 0.8470 | Val: 0.8388 | No Improvement (1/150)
Epoch 10 | Train: 0.8069 | Val: 0.8481 | No Improvement (2/150)
Epoch 11 | Train: 0.7661 | Val: 0.8464 | No Improvement (3/150)
Epoch 12 | Train: 0.7129 | Val: 0.8223 | No Improvement (4/150)
Epoch 13 | Train: 0.6361 | Val: 0.7406 | New Best Model!
Epoch 14 | Train: 0.5759 | Val: 0.5354 | New Best Model!
Epoch 15 | Train: 0.5291 | Val: 0.4312 | New Best Model!
Epoch 16 | Train: 0.4932 | Val: 0.4486 | No Improvement (1/150)
Epoch 17 | Train: 0.4633 | Val: 0.5174 | No Improvement (2/150)
Epoch 18 | Train: 0.4441 | Val: 0.5416 | No Improvement (3/150)
Epoch 19 | Train: 0.4281 | Val: 0.4875 | No Improvement (4/150)
Epoch 20 | Train: 0.4180 | Val: 0.3987 | New Best Model!
Epoch 21 | Train: 0.4042 | Val: 0.3323 | New Best Model!
Epoch 22 | Train: 0.3939 | Val: 0.3038 | New Best Model!
Epoch 23 | Train: 0.3844 | Val: 0.3032 | New Best Model!
Epoch 24 | Train: 0.3841 | Val: 0.3082 | No Improvement (1/150)
Epoch 25 | Train: 0.3757 | Val: 0.3072 | No Improvement (2/150)
Epoch 26 | Train: 0.3702 | Val: 0.3034 | No Improvement (3/150)
Epoch 27 | Train: 0.3682 | Val: 0.3012 | New Best Model!
Epoch 28 | Train: 0.3594 | Val: 0.3024 | No Improvement (1/150)
Epoch 29 | Train: 0.3559 | Val: 0.3059 | No Improvement (2/150)
Epoch 30 | Train: 0.3582 | Val: 0.3086 | No Improvement (3/150)
Epoch 31 | Train: 0.3521 | Val: 0.3087 | No Improvement (4/150)
Epoch 32 | Train: 0.3547 | Val: 0.3070 | No Improvement (5/150)
Epoch 33 | Train: 0.3487 | Val: 0.3046 | No Improvement (6/150)
Epoch 34 | Train: 0.3465 | Val: 0.3032 | No Improvement (7/150)
Epoch 35 | Train: 0.3486 | Val: 0.3028 | No Improvement (8/150)
Epoch 36 | Train: 0.3472 | Val: 0.3034 | No Improvement (9/150)
Epoch 37 | Train: 0.3452 | Val: 0.3043 | No Improvement (10/150)
Epoch 38 | Train: 0.3465 | Val: 0.3046 | No Improvement (11/150)
Epoch 39 | Train: 0.3441 | Val: 0.3043 | No Improvement (12/150)
Epoch 40 | Train: 0.3446 | Val: 0.3035 | No Improvement (13/150)
Epoch 41 | Train: 0.3438 | Val: 0.3024 | No Improvement (14/150)
Epoch 42 | Train: 0.3439 | Val: 0.3011 | New Best Model!
Epoch 43 | Train: 0.3433 | Val: 0.2998 | New Best Model!
Epoch 44 | Train: 0.3418 | Val: 0.2988 | New Best Model!
Epoch 45 | Train: 0.3395 | Val: 0.2980 | New Best Model!
Epoch 46 | Train: 0.3425 | Val: 0.2974 | New Best Model!
Epoch 47 | Train: 0.3390 | Val: 0.2970 | New Best Model!
Epoch 48 | Train: 0.3399 | Val: 0.2969 | New Best Model!
Epoch 49 | Train: 0.3399 | Val: 0.2969 | No Improvement (1/150)
Epoch 50 | Train: 0.3436 | Val: 0.2971 | No Improvement (2/150)
Epoch 51 | Train: 0.3426 | Val: 0.2833 | New Best Model!
Epoch 52 | Train: 0.3380 | Val: 0.2877 | No Improvement (1/150)
Epoch 53 | Train: 0.3373 | Val: 0.2827 | New Best Model!
Epoch 54 | Train: 0.3341 | Val: 0.2828 | No Improvement (1/150)
Epoch 55 | Train: 0.3283 | Val: 0.3026 | No Improvement (2/150)
Epoch 56 | Train: 0.3269 | Val: 0.3021 | No Improvement (3/150)
Epoch 57 | Train: 0.3261 | Val: 0.2792 | New Best Model!
Epoch 58 | Train: 0.3215 | Val: 0.2710 | New Best Model!
Epoch 59 | Train: 0.3199 | Val: 0.2716 | No Improvement (1/150)
Epoch 60 | Train: 0.3198 | Val: 0.2679 | New Best Model!
Epoch 61 | Train: 0.3210 | Val: 0.2709 | No Improvement (1/150)
Epoch 62 | Train: 0.3118 | Val: 0.2849 | No Improvement (2/150)
Epoch 63 | Train: 0.3149 | Val: 0.2904 | No Improvement (3/150)
Epoch 64 | Train: 0.3107 | Val: 0.2786 | No Improvement (4/150)
Epoch 65 | Train: 0.3081 | Val: 0.2664 | New Best Model!
Epoch 66 | Train: 0.3120 | Val: 0.2630 | New Best Model!
Epoch 67 | Train: 0.3068 | Val: 0.2648 | No Improvement (1/150)
Epoch 68 | Train: 0.3094 | Val: 0.2741 | No Improvement (2/150)
Epoch 69 | Train: 0.3077 | Val: 0.2776 | No Improvement (3/150)
Epoch 70 | Train: 0.3032 | Val: 0.2704 | No Improvement (4/150)
Epoch 71 | Train: 0.3073 | Val: 0.2608 | New Best Model!
Epoch 72 | Train: 0.3016 | Val: 0.2566 | New Best Model!
Epoch 73 | Train: 0.2986 | Val: 0.2552 | New Best Model!
Epoch 74 | Train: 0.3008 | Val: 0.2557 | No Improvement (1/150)
Epoch 75 | Train: 0.2975 | Val: 0.2571 | No Improvement (2/150)
Epoch 76 | Train: 0.3024 | Val: 0.2578 | No Improvement (3/150)
Epoch 77 | Train: 0.2991 | Val: 0.2585 | No Improvement (4/150)
Epoch 78 | Train: 0.3020 | Val: 0.2575 | No Improvement (5/150)
Epoch 79 | Train: 0.3000 | Val: 0.2565 | No Improvement (6/150)
Epoch 80 | Train: 0.2971 | Val: 0.2562 | No Improvement (7/150)
Epoch 81 | Train: 0.2985 | Val: 0.2575 | No Improvement (8/150)
Epoch 82 | Train: 0.2935 | Val: 0.2578 | No Improvement (9/150)
Epoch 83 | Train: 0.2942 | Val: 0.2562 | No Improvement (10/150)
Epoch 84 | Train: 0.2965 | Val: 0.2543 | New Best Model!
Epoch 85 | Train: 0.2966 | Val: 0.2532 | New Best Model!
Epoch 86 | Train: 0.2974 | Val: 0.2540 | No Improvement (1/150)
Epoch 87 | Train: 0.2961 | Val: 0.2574 | No Improvement (2/150)
Epoch 88 | Train: 0.2946 | Val: 0.2606 | No Improvement (3/150)
Epoch 89 | Train: 0.2947 | Val: 0.2604 | No Improvement (4/150)
Epoch 90 | Train: 0.2961 | Val: 0.2591 | No Improvement (5/150)
Epoch 91 | Train: 0.2946 | Val: 0.2584 | No Improvement (6/150)
Epoch 92 | Train: 0.2898 | Val: 0.2589 | No Improvement (7/150)
Epoch 93 | Train: 0.2937 | Val: 0.2587 | No Improvement (8/150)
Epoch 94 | Train: 0.2897 | Val: 0.2582 | No Improvement (9/150)
Epoch 95 | Train: 0.2921 | Val: 0.2572 | No Improvement (10/150)
Epoch 96 | Train: 0.2938 | Val: 0.2561 | No Improvement (11/150)
Epoch 97 | Train: 0.2919 | Val: 0.2558 | No Improvement (12/150)
Epoch 98 | Train: 0.2923 | Val: 0.2552 | No Improvement (13/150)
Epoch 99 | Train: 0.2899 | Val: 0.2563 | No Improvement (14/150)
Epoch 100 | Train: 0.2899 | Val: 0.2582 | No Improvement (15/150)
Epoch 101 | Train: 0.2912 | Val: 0.2590 | No Improvement (16/150)
Epoch 102 | Train: 0.2902 | Val: 0.2585 | No Improvement (17/150)
Epoch 103 | Train: 0.2906 | Val: 0.2579 | No Improvement (18/150)
Epoch 104 | Train: 0.2896 | Val: 0.2577 | No Improvement (19/150)
Epoch 105 | Train: 0.2898 | Val: 0.2600 | No Improvement (20/150)
Epoch 106 | Train: 0.2894 | Val: 0.2628 | No Improvement (21/150)
Epoch 107 | Train: 0.2934 | Val: 0.2627 | No Improvement (22/150)
Epoch 108 | Train: 0.2904 | Val: 0.2615 | No Improvement (23/150)
Epoch 109 | Train: 0.2899 | Val: 0.2596 | No Improvement (24/150)
Epoch 110 | Train: 0.2890 | Val: 0.2570 | No Improvement (25/150)
Epoch 111 | Train: 0.2886 | Val: 0.2559 | No Improvement (26/150)
Epoch 112 | Train: 0.2860 | Val: 0.2560 | No Improvement (27/150)
Epoch 113 | Train: 0.2886 | Val: 0.2570 | No Improvement (28/150)
Epoch 114 | Train: 0.2860 | Val: 0.2587 | No Improvement (29/150)
Epoch 115 | Train: 0.2874 | Val: 0.2599 | No Improvement (30/150)
Epoch 116 | Train: 0.2880 | Val: 0.2599 | No Improvement (31/150)
Epoch 117 | Train: 0.2850 | Val: 0.2587 | No Improvement (32/150)
Epoch 118 | Train: 0.2868 | Val: 0.2578 | No Improvement (33/150)
Epoch 119 | Train: 0.2845 | Val: 0.2567 | No Improvement (34/150)
Epoch 120 | Train: 0.2843 | Val: 0.2564 | No Improvement (35/150)
Epoch 121 | Train: 0.2873 | Val: 0.2567 | No Improvement (36/150)
Epoch 122 | Train: 0.2866 | Val: 0.2577 | No Improvement (37/150)
Epoch 123 | Train: 0.2920 | Val: 0.2591 | No Improvement (38/150)
Epoch 124 | Train: 0.2877 | Val: 0.2600 | No Improvement (39/150)
Epoch 125 | Train: 0.2860 | Val: 0.2598 | No Improvement (40/150)
Epoch 126 | Train: 0.2863 | Val: 0.2592 | No Improvement (41/150)
Epoch 127 | Train: 0.2881 | Val: 0.2585 | No Improvement (42/150)
Epoch 128 | Train: 0.2862 | Val: 0.2572 | No Improvement (43/150)
Epoch 129 | Train: 0.2870 | Val: 0.2562 | No Improvement (44/150)
Epoch 130 | Train: 0.2868 | Val: 0.2552 | No Improvement (45/150)
Epoch 131 | Train: 0.2869 | Val: 0.2549 | No Improvement (46/150)
Epoch 132 | Train: 0.2858 | Val: 0.2552 | No Improvement (47/150)
Epoch 133 | Train: 0.2874 | Val: 0.2558 | No Improvement (48/150)
Epoch 134 | Train: 0.2853 | Val: 0.2564 | No Improvement (49/150)
Epoch 135 | Train: 0.2905 | Val: 0.2569 | No Improvement (50/150)
Epoch 136 | Train: 0.2882 | Val: 0.2576 | No Improvement (51/150)
Epoch 137 | Train: 0.2875 | Val: 0.2580 | No Improvement (52/150)
Epoch 138 | Train: 0.2855 | Val: 0.2583 | No Improvement (53/150)
Epoch 139 | Train: 0.2864 | Val: 0.2583 | No Improvement (54/150)
Epoch 140 | Train: 0.2864 | Val: 0.2582 | No Improvement (55/150)
Epoch 141 | Train: 0.2863 | Val: 0.2579 | No Improvement (56/150)
Epoch 142 | Train: 0.2856 | Val: 0.2576 | No Improvement (57/150)
Epoch 143 | Train: 0.2868 | Val: 0.2573 | No Improvement (58/150)
Epoch 144 | Train: 0.2870 | Val: 0.2569 | No Improvement (59/150)
Epoch 145 | Train: 0.2863 | Val: 0.2566 | No Improvement (60/150)
Epoch 146 | Train: 0.2880 | Val: 0.2564 | No Improvement (61/150)
Epoch 147 | Train: 0.2867 | Val: 0.2563 | No Improvement (62/150)
Epoch 148 | Train: 0.2864 | Val: 0.2561 | No Improvement (63/150)
Epoch 149 | Train: 0.2863 | Val: 0.2561 | No Improvement (64/150)
Epoch 150 | Train: 0.2894 | Val: 0.2560 | No Improvement (65/150)
Epoch 151 | Train: 0.2867 | Val: 0.2615 | No Improvement (66/150)
Epoch 152 | Train: 0.2934 | Val: 0.2572 | No Improvement (67/150)
Epoch 153 | Train: 0.2872 | Val: 0.3040 | No Improvement (68/150)
Epoch 154 | Train: 0.2921 | Val: 0.2545 | No Improvement (69/150)
Epoch 155 | Train: 0.2856 | Val: 0.2534 | No Improvement (70/150)
Epoch 156 | Train: 0.2882 | Val: 0.2519 | New Best Model!
Epoch 157 | Train: 0.2834 | Val: 0.2776 | No Improvement (1/150)
Epoch 158 | Train: 0.2882 | Val: 0.2684 | No Improvement (2/150)
Epoch 159 | Train: 0.2841 | Val: 0.2537 | No Improvement (3/150)
Epoch 160 | Train: 0.2860 | Val: 0.2535 | No Improvement (4/150)
Epoch 161 | Train: 0.2846 | Val: 0.2605 | No Improvement (5/150)
Epoch 162 | Train: 0.2833 | Val: 0.2796 | No Improvement (6/150)
Epoch 163 | Train: 0.2852 | Val: 0.2700 | No Improvement (7/150)
Epoch 164 | Train: 0.2822 | Val: 0.2515 | New Best Model!
Epoch 165 | Train: 0.2862 | Val: 0.2486 | New Best Model!
Epoch 166 | Train: 0.2832 | Val: 0.2519 | No Improvement (1/150)
Epoch 167 | Train: 0.2858 | Val: 0.2633 | No Improvement (2/150)
Epoch 168 | Train: 0.2798 | Val: 0.2611 | No Improvement (3/150)
Epoch 169 | Train: 0.2810 | Val: 0.2520 | No Improvement (4/150)
Epoch 170 | Train: 0.2802 | Val: 0.2506 | No Improvement (5/150)
Epoch 171 | Train: 0.2793 | Val: 0.2564 | No Improvement (6/150)
Epoch 172 | Train: 0.2836 | Val: 0.2681 | No Improvement (7/150)
Epoch 173 | Train: 0.2818 | Val: 0.2662 | No Improvement (8/150)
Epoch 174 | Train: 0.2812 | Val: 0.2603 | No Improvement (9/150)
Epoch 175 | Train: 0.2775 | Val: 0.2621 | No Improvement (10/150)
Epoch 176 | Train: 0.2853 | Val: 0.2785 | No Improvement (11/150)
Epoch 177 | Train: 0.2777 | Val: 0.2800 | No Improvement (12/150)
Epoch 178 | Train: 0.2777 | Val: 0.2605 | No Improvement (13/150)
Epoch 179 | Train: 0.2773 | Val: 0.2532 | No Improvement (14/150)
Epoch 180 | Train: 0.2770 | Val: 0.2544 | No Improvement (15/150)
Epoch 181 | Train: 0.2803 | Val: 0.2665 | No Improvement (16/150)
Epoch 182 | Train: 0.2795 | Val: 0.2675 | No Improvement (17/150)
Epoch 183 | Train: 0.2771 | Val: 0.2609 | No Improvement (18/150)
Epoch 184 | Train: 0.2798 | Val: 0.2598 | No Improvement (19/150)
Epoch 185 | Train: 0.2730 | Val: 0.2555 | No Improvement (20/150)
Epoch 186 | Train: 0.2722 | Val: 0.2581 | No Improvement (21/150)
Epoch 187 | Train: 0.2724 | Val: 0.2665 | No Improvement (22/150)
Epoch 188 | Train: 0.2758 | Val: 0.2728 | No Improvement (23/150)
Epoch 189 | Train: 0.2758 | Val: 0.2631 | No Improvement (24/150)
Epoch 190 | Train: 0.2749 | Val: 0.2524 | No Improvement (25/150)
Epoch 191 | Train: 0.2747 | Val: 0.2586 | No Improvement (26/150)
Epoch 192 | Train: 0.2697 | Val: 0.2663 | No Improvement (27/150)
Epoch 193 | Train: 0.2695 | Val: 0.2607 | No Improvement (28/150)
Epoch 194 | Train: 0.2662 | Val: 0.2506 | No Improvement (29/150)
Epoch 195 | Train: 0.2637 | Val: 0.2522 | No Improvement (30/150)
Epoch 196 | Train: 0.2585 | Val: 0.2460 | New Best Model!
Epoch 197 | Train: 0.2563 | Val: 0.2652 | No Improvement (1/150)
Epoch 198 | Train: 0.2504 | Val: 0.3090 | No Improvement (2/150)
Epoch 199 | Train: 0.2464 | Val: 0.2504 | No Improvement (3/150)
Epoch 200 | Train: 0.2460 | Val: 0.2536 | No Improvement (4/150)
Epoch 201 | Train: 0.2462 | Val: 0.2315 | New Best Model!
Epoch 202 | Train: 0.2492 | Val: 0.2714 | No Improvement (1/150)
Epoch 203 | Train: 0.2434 | Val: 0.2370 | No Improvement (2/150)
Epoch 204 | Train: 0.2389 | Val: 0.2096 | New Best Model!
Epoch 205 | Train: 0.2400 | Val: 0.2229 | No Improvement (1/150)
Epoch 206 | Train: 0.2323 | Val: 0.2334 | No Improvement (2/150)
Epoch 207 | Train: 0.2325 | Val: 0.2176 | No Improvement (3/150)
Epoch 208 | Train: 0.2392 | Val: 0.2127 | No Improvement (4/150)
Epoch 209 | Train: 0.2317 | Val: 0.2203 | No Improvement (5/150)
Epoch 210 | Train: 0.2258 | Val: 0.2128 | No Improvement (6/150)
Epoch 211 | Train: 0.2349 | Val: 0.2124 | No Improvement (7/150)
Epoch 212 | Train: 0.2247 | Val: 0.2098 | No Improvement (8/150)
Epoch 213 | Train: 0.2260 | Val: 0.2072 | New Best Model!
Epoch 214 | Train: 0.2179 | Val: 0.2050 | New Best Model!
Epoch 215 | Train: 0.2193 | Val: 0.2029 | New Best Model!
Epoch 216 | Train: 0.2287 | Val: 0.2141 | No Improvement (1/150)
Epoch 217 | Train: 0.2135 | Val: 0.2259 | No Improvement (2/150)
Epoch 218 | Train: 0.2235 | Val: 0.2033 | No Improvement (3/150)
Epoch 219 | Train: 0.2182 | Val: 0.1985 | New Best Model!
Epoch 220 | Train: 0.2176 | Val: 0.1969 | New Best Model!
Epoch 221 | Train: 0.2159 | Val: 0.2054 | No Improvement (1/150)
Epoch 222 | Train: 0.2170 | Val: 0.2039 | No Improvement (2/150)
Epoch 223 | Train: 0.2143 | Val: 0.1945 | New Best Model!
Epoch 224 | Train: 0.2069 | Val: 0.1963 | No Improvement (1/150)
Epoch 225 | Train: 0.2045 | Val: 0.1951 | No Improvement (2/150)
Epoch 226 | Train: 0.2076 | Val: 0.1951 | No Improvement (3/150)
Epoch 227 | Train: 0.2160 | Val: 0.1917 | New Best Model!
Epoch 228 | Train: 0.2114 | Val: 0.1952 | No Improvement (1/150)
Epoch 229 | Train: 0.2041 | Val: 0.1944 | No Improvement (2/150)
Epoch 230 | Train: 0.2071 | Val: 0.1945 | No Improvement (3/150)
Epoch 231 | Train: 0.2088 | Val: 0.1864 | New Best Model!
Epoch 232 | Train: 0.2056 | Val: 0.1869 | No Improvement (1/150)
Epoch 233 | Train: 0.2047 | Val: 0.1892 | No Improvement (2/150)
Epoch 234 | Train: 0.2051 | Val: 0.1888 | No Improvement (3/150)
Epoch 235 | Train: 0.2037 | Val: 0.1873 | No Improvement (4/150)
Epoch 236 | Train: 0.2044 | Val: 0.1871 | No Improvement (5/150)
Epoch 237 | Train: 0.2041 | Val: 0.1896 | No Improvement (6/150)
Epoch 238 | Train: 0.1974 | Val: 0.1926 | No Improvement (7/150)
Epoch 239 | Train: 0.2017 | Val: 0.1923 | No Improvement (8/150)
Epoch 240 | Train: 0.2015 | Val: 0.1870 | No Improvement (9/150)
Epoch 241 | Train: 0.1994 | Val: 0.1877 | No Improvement (10/150)
Epoch 242 | Train: 0.1981 | Val: 0.1961 | No Improvement (11/150)
Epoch 243 | Train: 0.2008 | Val: 0.1947 | No Improvement (12/150)
Epoch 244 | Train: 0.2020 | Val: 0.1898 | No Improvement (13/150)
Epoch 245 | Train: 0.2012 | Val: 0.1899 | No Improvement (14/150)
Epoch 246 | Train: 0.1990 | Val: 0.1925 | No Improvement (15/150)
Epoch 247 | Train: 0.1971 | Val: 0.1919 | No Improvement (16/150)
Epoch 248 | Train: 0.2031 | Val: 0.1849 | New Best Model!
Epoch 249 | Train: 0.1976 | Val: 0.1816 | New Best Model!
Epoch 250 | Train: 0.1987 | Val: 0.1847 | No Improvement (1/150)
Epoch 251 | Train: 0.1946 | Val: 0.1914 | No Improvement (2/150)
Epoch 252 | Train: 0.1952 | Val: 0.1911 | No Improvement (3/150)
Epoch 253 | Train: 0.1964 | Val: 0.1858 | No Improvement (4/150)
Epoch 254 | Train: 0.1964 | Val: 0.1825 | No Improvement (5/150)
Epoch 255 | Train: 0.1978 | Val: 0.1844 | No Improvement (6/150)
Epoch 256 | Train: 0.1936 | Val: 0.1902 | No Improvement (7/150)
Epoch 257 | Train: 0.1939 | Val: 0.1888 | No Improvement (8/150)
Epoch 258 | Train: 0.1946 | Val: 0.1842 | No Improvement (9/150)
Epoch 259 | Train: 0.1931 | Val: 0.1823 | No Improvement (10/150)
Epoch 260 | Train: 0.1944 | Val: 0.1822 | No Improvement (11/150)
Epoch 261 | Train: 0.1928 | Val: 0.1843 | No Improvement (12/150)
Epoch 262 | Train: 0.1929 | Val: 0.1856 | No Improvement (13/150)
Epoch 263 | Train: 0.1922 | Val: 0.1831 | No Improvement (14/150)
Epoch 264 | Train: 0.1915 | Val: 0.1798 | New Best Model!
Epoch 265 | Train: 0.1910 | Val: 0.1800 | No Improvement (1/150)
Epoch 266 | Train: 0.1901 | Val: 0.1817 | No Improvement (2/150)
Epoch 267 | Train: 0.1909 | Val: 0.1839 | No Improvement (3/150)
Epoch 268 | Train: 0.1933 | Val: 0.1836 | No Improvement (4/150)
Epoch 269 | Train: 0.1924 | Val: 0.1840 | No Improvement (5/150)
Epoch 270 | Train: 0.1877 | Val: 0.1835 | No Improvement (6/150)
Epoch 271 | Train: 0.1922 | Val: 0.1821 | No Improvement (7/150)
Epoch 272 | Train: 0.1906 | Val: 0.1801 | No Improvement (8/150)
Epoch 273 | Train: 0.1935 | Val: 0.1778 | New Best Model!
Epoch 274 | Train: 0.1953 | Val: 0.1781 | No Improvement (1/150)
Epoch 275 | Train: 0.1892 | Val: 0.1807 | No Improvement (2/150)
Epoch 276 | Train: 0.1912 | Val: 0.1844 | No Improvement (3/150)
Epoch 277 | Train: 0.1872 | Val: 0.1847 | No Improvement (4/150)
Epoch 278 | Train: 0.1908 | Val: 0.1816 | No Improvement (5/150)
Epoch 279 | Train: 0.1865 | Val: 0.1795 | No Improvement (6/150)
Epoch 280 | Train: 0.1884 | Val: 0.1789 | No Improvement (7/150)
Epoch 281 | Train: 0.1899 | Val: 0.1799 | No Improvement (8/150)
Epoch 282 | Train: 0.1917 | Val: 0.1811 | No Improvement (9/150)
Epoch 283 | Train: 0.1906 | Val: 0.1818 | No Improvement (10/150)
Epoch 284 | Train: 0.1875 | Val: 0.1830 | No Improvement (11/150)
Epoch 285 | Train: 0.1880 | Val: 0.1816 | No Improvement (12/150)
Epoch 286 | Train: 0.1862 | Val: 0.1810 | No Improvement (13/150)
Epoch 287 | Train: 0.1908 | Val: 0.1801 | No Improvement (14/150)
Epoch 288 | Train: 0.1881 | Val: 0.1801 | No Improvement (15/150)
Epoch 289 | Train: 0.1908 | Val: 0.1815 | No Improvement (16/150)
Epoch 290 | Train: 0.1839 | Val: 0.1845 | No Improvement (17/150)
Epoch 291 | Train: 0.1840 | Val: 0.1878 | No Improvement (18/150)
Epoch 292 | Train: 0.1873 | Val: 0.1871 | No Improvement (19/150)
Epoch 293 | Train: 0.1870 | Val: 0.1834 | No Improvement (20/150)
Epoch 294 | Train: 0.1847 | Val: 0.1802 | No Improvement (21/150)
Epoch 295 | Train: 0.1846 | Val: 0.1788 | No Improvement (22/150)
Epoch 296 | Train: 0.1852 | Val: 0.1794 | No Improvement (23/150)
Epoch 297 | Train: 0.1889 | Val: 0.1807 | No Improvement (24/150)
Epoch 298 | Train: 0.1857 | Val: 0.1832 | No Improvement (25/150)
Epoch 299 | Train: 0.1918 | Val: 0.1843 | No Improvement (26/150)
Epoch 300 | Train: 0.1839 | Val: 0.1847 | No Improvement (27/150)
Epoch 301 | Train: 0.1864 | Val: 0.1840 | No Improvement (28/150)
Epoch 302 | Train: 0.1906 | Val: 0.1831 | No Improvement (29/150)
Epoch 303 | Train: 0.1828 | Val: 0.1828 | No Improvement (30/150)
Epoch 304 | Train: 0.1872 | Val: 0.1823 | No Improvement (31/150)
Epoch 305 | Train: 0.1895 | Val: 0.1820 | No Improvement (32/150)
Epoch 306 | Train: 0.1891 | Val: 0.1821 | No Improvement (33/150)
Epoch 307 | Train: 0.1879 | Val: 0.1823 | No Improvement (34/150)
Epoch 308 | Train: 0.1888 | Val: 0.1823 | No Improvement (35/150)
Epoch 309 | Train: 0.1848 | Val: 0.1822 | No Improvement (36/150)
Epoch 310 | Train: 0.1872 | Val: 0.1820 | No Improvement (37/150)
Epoch 311 | Train: 0.1856 | Val: 0.1812 | No Improvement (38/150)
Epoch 312 | Train: 0.1843 | Val: 0.1803 | No Improvement (39/150)
Epoch 313 | Train: 0.1857 | Val: 0.1796 | No Improvement (40/150)
Epoch 314 | Train: 0.1924 | Val: 0.1785 | No Improvement (41/150)
Epoch 315 | Train: 0.1923 | Val: 0.1776 | New Best Model!
Epoch 316 | Train: 0.1837 | Val: 0.1775 | New Best Model!
Epoch 317 | Train: 0.1867 | Val: 0.1777 | No Improvement (1/150)
Epoch 318 | Train: 0.1879 | Val: 0.1780 | No Improvement (2/150)
Epoch 319 | Train: 0.1850 | Val: 0.1784 | No Improvement (3/150)
Epoch 320 | Train: 0.1830 | Val: 0.1793 | No Improvement (4/150)
Epoch 321 | Train: 0.1832 | Val: 0.1804 | No Improvement (5/150)
Epoch 322 | Train: 0.1855 | Val: 0.1812 | No Improvement (6/150)
Epoch 323 | Train: 0.1882 | Val: 0.1819 | No Improvement (7/150)
Epoch 324 | Train: 0.1875 | Val: 0.1823 | No Improvement (8/150)
Epoch 325 | Train: 0.1874 | Val: 0.1824 | No Improvement (9/150)
Epoch 326 | Train: 0.1892 | Val: 0.1825 | No Improvement (10/150)
Epoch 327 | Train: 0.1876 | Val: 0.1823 | No Improvement (11/150)
Epoch 328 | Train: 0.1886 | Val: 0.1821 | No Improvement (12/150)
Epoch 329 | Train: 0.1832 | Val: 0.1819 | No Improvement (13/150)
Epoch 330 | Train: 0.1876 | Val: 0.1818 | No Improvement (14/150)
Epoch 331 | Train: 0.1854 | Val: 0.1815 | No Improvement (15/150)
Epoch 332 | Train: 0.1840 | Val: 0.1813 | No Improvement (16/150)
Epoch 333 | Train: 0.1835 | Val: 0.1811 | No Improvement (17/150)
Epoch 334 | Train: 0.1866 | Val: 0.1810 | No Improvement (18/150)
Epoch 335 | Train: 0.1869 | Val: 0.1808 | No Improvement (19/150)
Epoch 336 | Train: 0.1858 | Val: 0.1806 | No Improvement (20/150)
Epoch 337 | Train: 0.1820 | Val: 0.1806 | No Improvement (21/150)
Epoch 338 | Train: 0.1839 | Val: 0.1806 | No Improvement (22/150)
Epoch 339 | Train: 0.1822 | Val: 0.1807 | No Improvement (23/150)
Epoch 340 | Train: 0.1863 | Val: 0.1807 | No Improvement (24/150)
Epoch 341 | Train: 0.1824 | Val: 0.1807 | No Improvement (25/150)
Epoch 342 | Train: 0.1853 | Val: 0.1807 | No Improvement (26/150)
Epoch 343 | Train: 0.1834 | Val: 0.1807 | No Improvement (27/150)
Epoch 344 | Train: 0.1861 | Val: 0.1807 | No Improvement (28/150)
Epoch 345 | Train: 0.1886 | Val: 0.1807 | No Improvement (29/150)
Epoch 346 | Train: 0.1819 | Val: 0.1806 | No Improvement (30/150)
Epoch 347 | Train: 0.1871 | Val: 0.1806 | No Improvement (31/150)
Epoch 348 | Train: 0.1887 | Val: 0.1806 | No Improvement (32/150)
Epoch 349 | Train: 0.1857 | Val: 0.1806 | No Improvement (33/150)
Epoch 350 | Train: 0.1834 | Val: 0.1806 | No Improvement (34/150)
Epoch 351 | Train: 0.1874 | Val: 0.1904 | No Improvement (35/150)
Epoch 352 | Train: 0.1914 | Val: 0.1972 | No Improvement (36/150)
Epoch 353 | Train: 0.1955 | Val: 0.1938 | No Improvement (37/150)
Epoch 354 | Train: 0.1866 | Val: 0.1889 | No Improvement (38/150)
Epoch 355 | Train: 0.1905 | Val: 0.1871 | No Improvement (39/150)
Epoch 356 | Train: 0.1929 | Val: 0.1832 | No Improvement (40/150)
Epoch 357 | Train: 0.1869 | Val: 0.1915 | No Improvement (41/150)
Epoch 358 | Train: 0.1900 | Val: 0.2037 | No Improvement (42/150)
Epoch 359 | Train: 0.1937 | Val: 0.1814 | No Improvement (43/150)
Epoch 360 | Train: 0.1921 | Val: 0.1774 | New Best Model!
Epoch 361 | Train: 0.1876 | Val: 0.1825 | No Improvement (1/150)
Epoch 362 | Train: 0.1867 | Val: 0.1796 | No Improvement (2/150)
Epoch 363 | Train: 0.1915 | Val: 0.1808 | No Improvement (3/150)
Epoch 364 | Train: 0.1831 | Val: 0.1924 | No Improvement (4/150)
Epoch 365 | Train: 0.1900 | Val: 0.1793 | No Improvement (5/150)
Epoch 366 | Train: 0.1889 | Val: 0.1934 | No Improvement (6/150)
Epoch 367 | Train: 0.1898 | Val: 0.1790 | No Improvement (7/150)
Epoch 368 | Train: 0.1832 | Val: 0.1826 | No Improvement (8/150)
Epoch 369 | Train: 0.1861 | Val: 0.1876 | No Improvement (9/150)
Epoch 370 | Train: 0.1850 | Val: 0.1750 | New Best Model!
Epoch 371 | Train: 0.1875 | Val: 0.1858 | No Improvement (1/150)
Epoch 372 | Train: 0.1817 | Val: 0.2006 | No Improvement (2/150)
Epoch 373 | Train: 0.1847 | Val: 0.1831 | No Improvement (3/150)
Epoch 374 | Train: 0.1813 | Val: 0.1843 | No Improvement (4/150)
Epoch 375 | Train: 0.1832 | Val: 0.1838 | No Improvement (5/150)
Epoch 376 | Train: 0.1829 | Val: 0.1901 | No Improvement (6/150)
Epoch 377 | Train: 0.1822 | Val: 0.1784 | No Improvement (7/150)
Epoch 378 | Train: 0.1815 | Val: 0.1765 | No Improvement (8/150)
Epoch 379 | Train: 0.1769 | Val: 0.1897 | No Improvement (9/150)
Epoch 380 | Train: 0.1838 | Val: 0.1745 | New Best Model!
Epoch 381 | Train: 0.1849 | Val: 0.1739 | New Best Model!
Epoch 382 | Train: 0.1777 | Val: 0.1791 | No Improvement (1/150)
Epoch 383 | Train: 0.1749 | Val: 0.1730 | New Best Model!
Epoch 384 | Train: 0.1745 | Val: 0.1830 | No Improvement (1/150)
Epoch 385 | Train: 0.1818 | Val: 0.1879 | No Improvement (2/150)
Epoch 386 | Train: 0.1789 | Val: 0.1808 | No Improvement (3/150)
Epoch 387 | Train: 0.1723 | Val: 0.1811 | No Improvement (4/150)
Epoch 388 | Train: 0.1721 | Val: 0.1852 | No Improvement (5/150)
Epoch 389 | Train: 0.1863 | Val: 0.1810 | No Improvement (6/150)
Epoch 390 | Train: 0.1822 | Val: 0.1848 | No Improvement (7/150)
Epoch 391 | Train: 0.1798 | Val: 0.1935 | No Improvement (8/150)
Epoch 392 | Train: 0.1789 | Val: 0.1863 | No Improvement (9/150)
Epoch 393 | Train: 0.1794 | Val: 0.1804 | No Improvement (10/150)
Epoch 394 | Train: 0.1779 | Val: 0.1771 | No Improvement (11/150)
Epoch 395 | Train: 0.1741 | Val: 0.1773 | No Improvement (12/150)
Epoch 396 | Train: 0.1679 | Val: 0.1861 | No Improvement (13/150)
Epoch 397 | Train: 0.1739 | Val: 0.1745 | No Improvement (14/150)
Epoch 398 | Train: 0.1778 | Val: 0.1886 | No Improvement (15/150)
Epoch 399 | Train: 0.1810 | Val: 0.1812 | No Improvement (16/150)
Epoch 400 | Train: 0.1803 | Val: 0.1734 | No Improvement (17/150)
Epoch 401 | Train: 0.1772 | Val: 0.1790 | No Improvement (18/150)
Epoch 402 | Train: 0.1802 | Val: 0.1785 | No Improvement (19/150)
Epoch 403 | Train: 0.1749 | Val: 0.1692 | New Best Model!
Epoch 404 | Train: 0.1675 | Val: 0.1715 | No Improvement (1/150)
Epoch 405 | Train: 0.1709 | Val: 0.1721 | No Improvement (2/150)
Epoch 406 | Train: 0.1742 | Val: 0.1854 | No Improvement (3/150)
Epoch 407 | Train: 0.1666 | Val: 0.1715 | No Improvement (4/150)
Epoch 408 | Train: 0.1771 | Val: 0.1852 | No Improvement (5/150)
Epoch 409 | Train: 0.1741 | Val: 0.2025 | No Improvement (6/150)
Epoch 410 | Train: 0.1730 | Val: 0.1801 | No Improvement (7/150)
Epoch 411 | Train: 0.1782 | Val: 0.1688 | New Best Model!
Epoch 412 | Train: 0.1752 | Val: 0.1720 | No Improvement (1/150)
Epoch 413 | Train: 0.1743 | Val: 0.1795 | No Improvement (2/150)
Epoch 414 | Train: 0.1786 | Val: 0.1719 | No Improvement (3/150)
Epoch 415 | Train: 0.1768 | Val: 0.1731 | No Improvement (4/150)
Epoch 416 | Train: 0.1694 | Val: 0.1746 | No Improvement (5/150)
Epoch 417 | Train: 0.1743 | Val: 0.1782 | No Improvement (6/150)
Epoch 418 | Train: 0.1693 | Val: 0.1734 | No Improvement (7/150)
Epoch 419 | Train: 0.1718 | Val: 0.1972 | No Improvement (8/150)
Epoch 420 | Train: 0.1687 | Val: 0.1692 | No Improvement (9/150)
Epoch 421 | Train: 0.1674 | Val: 0.1620 | New Best Model!
Epoch 422 | Train: 0.1636 | Val: 0.1705 | No Improvement (1/150)
Epoch 423 | Train: 0.1691 | Val: 0.1859 | No Improvement (2/150)
In [3]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, jaccard_score
from torch.utils.data import DataLoader, Dataset

# CONFIGURATION
# Ensure this matches where you saved your scratch model
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_1/'
BATCH_SIZE = 8
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. DATA LOADER
class MmapDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.data = np.load(x_path, mmap_mode='r')
        self.target = np.load(y_path, mmap_mode='r')
    def __len__(self): return len(self.data)
    def __getitem__(self, index):
        return torch.from_numpy(self.data[index].copy()), torch.from_numpy(self.target[index].copy())

val_ds = MmapDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

# 2. MODEL ARCHITECTURE (Must match training exactly)
class SatMAEPatchEmbed(nn.Module):
    def __init__(self, in_chans=6, embed_dim=768, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        B, C, T, H, W = x.shape
        x = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = x.reshape(B, T, -1, x.shape[-1])
        return x

class SatMAEBackbone(nn.Module):
    def __init__(self, num_frames=3, in_chans=6, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.patch_embed = SatMAEPatchEmbed(in_chans=in_chans, embed_dim=embed_dim)
        num_patches = (224 // 16) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, 1, num_patches + 1, embed_dim))
        self.time_embed = nn.Parameter(torch.zeros(1, num_frames, 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim*4, activation="gelu", batch_first=True, norm_first=True)
        self.blocks = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)
    def forward(self, x):
        x = self.patch_embed(x)
        B, T, N, D = x.shape
        x = x + self.time_embed
        x = x.reshape(B, T*N, D)
        spatial_pos = self.pos_embed[:, :, 1:, :].expand(B, T, -1, -1).reshape(B, T*N, D)
        x = x + spatial_pos
        cls_token = self.cls_token.expand(B, -1, -1, -1).reshape(B, 1, D) + self.pos_embed[:, :, 0, :].expand(B, 1, D)
        x = torch.cat((cls_token, x), dim=1)
        x = self.blocks(x)
        x = self.norm(x)
        return x

class SatMAESegmentation(nn.Module):
    def __init__(self, num_frames=3, in_chans=6, embed_dim=768):
        super().__init__()
        self.num_frames = num_frames
        self.backbone = SatMAEBackbone(num_frames=num_frames, in_chans=in_chans, embed_dim=embed_dim)
        self.temporal_agg = nn.Conv2d(embed_dim * num_frames, embed_dim, kernel_size=1)
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2), nn.Conv2d(embed_dim, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(256, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(64, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.GELU(),
            nn.Conv2d(32, 1, 1)
        )
    def forward(self, x):
        features = self.backbone(x)[:, 1:, :]
        B, L, D = features.shape
        features = features.view(B, self.num_frames, 14, 14, D)
        features = features.permute(0, 1, 4, 2, 3).reshape(B, self.num_frames * D, 14, 14)
        features = self.temporal_agg(features)
        return self.decoder(features)

# 3. LOAD WEIGHTS
model = SatMAESegmentation(num_frames=3, in_chans=6).to(DEVICE)
weights_path = os.path.join(SAVE_DIR, "best_model.pth")

if not os.path.exists(weights_path):
    print(f"Warning: best_model.pth not found in {SAVE_DIR}. Looking for SWA model...")
    weights_path = os.path.join(SAVE_DIR, "satmae_scratch_swa_final.pth")

try:
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    print(f"Loaded weights from: {weights_path}")
except Exception as e:
    print(f"Error loading weights: {e}")
    exit()

# 4. METRICS CALCULATION
model.eval()
all_preds = []
all_targets = []

print("Running Inference on Validation Set...")
with torch.no_grad():
    for x, y in val_loader:
        x = x.to(DEVICE)
        preds = torch.sigmoid(model(x))
        preds_bin = (preds > 0.5).cpu().numpy().astype(np.uint8).flatten()
        targets_bin = y.numpy().astype(np.uint8).flatten()
        all_preds.append(preds_bin)
        all_targets.append(targets_bin)

y_pred = np.concatenate(all_preds)
y_true = np.concatenate(all_targets)

print("\n" + "="*40)
print("       MODEL PERFORMANCE METRICS       ")
print("="*40)
print(f"Pixel Accuracy:    {accuracy_score(y_true, y_pred):.4f}")
print(f"IoU (Jaccard):     {jaccard_score(y_true, y_pred):.4f}")
print(f"F1-Score (Dice):   {f1_score(y_true, y_pred):.4f}")
print(f"Precision:         {precision_score(y_true, y_pred):.4f}")
print(f"Recall:            {recall_score(y_true, y_pred):.4f}")
print("="*40 + "\n")

# 5. VISUALIZATION
def visualize_samples(model, loader, num_samples=3):
    model.eval()
    data_iter = iter(loader)
    x, y = next(data_iter)
    x = x.to(DEVICE)
    with torch.no_grad():
        preds = torch.sigmoid(model(x)).cpu().numpy()

    x = x.cpu().numpy()
    y = y.cpu().numpy()

    plt.figure(figsize=(15, 5 * num_samples))
    for i in range(num_samples):
        # Input RGB (Time Step 0)
        # Channels: 0=Blue, 1=Green, 2=Red -> RGB = [2, 1, 0]
        rgb = x[i, [2, 1, 0], 0, :, :].transpose(1, 2, 0)
        # Normalize for display
        rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)

        # Ground Truth
        gt = y[i, 0, :, :]

        # Prediction
        pred = preds[i, 0, :, :] > 0.5

        plt.subplot(num_samples, 3, i*3 + 1)
        plt.imshow(rgb)
        plt.title(f"Sample {i+1}: Input (T0 RGB)")
        plt.axis('off')

        plt.subplot(num_samples, 3, i*3 + 2)
        plt.imshow(gt, cmap='gray')
        plt.title(f"Sample {i+1}: Ground Truth")
        plt.axis('off')

        plt.subplot(num_samples, 3, i*3 + 3)
        plt.imshow(pred, cmap='gray')
        plt.title(f"Sample {i+1}: Prediction (IoU: {jaccard_score(gt.flatten(), pred.flatten()):.2f})")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

print("Generating Visualizations...")
visualize_samples(model, val_loader, num_samples=3)
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
/tmp/ipython-input-4193441673.py in <cell line: 0>()
     22         return torch.from_numpy(self.data[index].copy()), torch.from_numpy(self.target[index].copy())
     23 
---> 24 val_ds = MmapDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))
     25 val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
     26 

/tmp/ipython-input-4193441673.py in __init__(self, x_path, y_path)
     16 class MmapDataset(Dataset):
     17     def __init__(self, x_path, y_path):
---> 18         self.data = np.load(x_path, mmap_mode='r')
     19         self.target = np.load(y_path, mmap_mode='r')
     20     def __len__(self): return len(self.data)

/usr/local/lib/python3.12/dist-packages/numpy/lib/_npyio_impl.py in load(file, mmap_mode, allow_pickle, fix_imports, encoding, max_header_size)
    453             own_fid = False
    454         else:
--> 455             fid = stack.enter_context(open(os.fspath(file), "rb"))
    456             own_fid = True
    457 

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/SatMAE_Scratch_Results_1/val_x.npy'
In [4]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, jaccard_score
from torch.utils.data import DataLoader, Dataset
from google.colab import drive

# 1. REMOUNT DRIVE (Fixes the Disconnect Error)
print("Mounting Drive...")
drive.mount('/content/drive', force_remount=True)

# CONFIGURATION
# This is where your MODEL (best_model.pth) is saved
MODEL_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_1/'
BATCH_SIZE = 8
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. SMART DATA PATH FINDER
# If data isn't in the model folder, it checks the original folder
if os.path.exists(os.path.join(MODEL_DIR, 'val_x.npy')):
    DATA_DIR = MODEL_DIR
else:
    # Try removing the "_1" to find the original data folder
    DATA_DIR = MODEL_DIR.replace('_1', '')

print(f"Model Path: {MODEL_DIR}")
print(f"Data Path:  {DATA_DIR}")

if not os.path.exists(os.path.join(DATA_DIR, 'val_x.npy')):
    raise FileNotFoundError(f"CRITICAL: Could not find 'val_x.npy' in {MODEL_DIR} or {DATA_DIR}. Please check your folder path.")

# 3. DATA LOADER
class MmapDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.data = np.load(x_path, mmap_mode='r')
        self.target = np.load(y_path, mmap_mode='r')
    def __len__(self): return len(self.data)
    def __getitem__(self, index):
        return torch.from_numpy(self.data[index].copy()), torch.from_numpy(self.target[index].copy())

val_ds = MmapDataset(os.path.join(DATA_DIR, 'val_x.npy'), os.path.join(DATA_DIR, 'val_y.npy'))
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

# 4. MODEL ARCHITECTURE
class SatMAEPatchEmbed(nn.Module):
    def __init__(self, in_chans=6, embed_dim=768, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        B, C, T, H, W = x.shape
        x = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = x.reshape(B, T, -1, x.shape[-1])
        return x

class SatMAEBackbone(nn.Module):
    def __init__(self, num_frames=3, in_chans=6, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.patch_embed = SatMAEPatchEmbed(in_chans=in_chans, embed_dim=embed_dim)
        num_patches = (224 // 16) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, 1, num_patches + 1, embed_dim))
        self.time_embed = nn.Parameter(torch.zeros(1, num_frames, 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim*4, activation="gelu", batch_first=True, norm_first=True)
        self.blocks = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)
    def forward(self, x):
        x = self.patch_embed(x)
        B, T, N, D = x.shape
        x = x + self.time_embed
        x = x.reshape(B, T*N, D)
        spatial_pos = self.pos_embed[:, :, 1:, :].expand(B, T, -1, -1).reshape(B, T*N, D)
        x = x + spatial_pos
        cls_token = self.cls_token.expand(B, -1, -1, -1).reshape(B, 1, D) + self.pos_embed[:, :, 0, :].expand(B, 1, D)
        x = torch.cat((cls_token, x), dim=1)
        x = self.blocks(x)
        x = self.norm(x)
        return x

class SatMAESegmentation(nn.Module):
    def __init__(self, num_frames=3, in_chans=6, embed_dim=768):
        super().__init__()
        self.num_frames = num_frames
        self.backbone = SatMAEBackbone(num_frames=num_frames, in_chans=in_chans, embed_dim=embed_dim)
        self.temporal_agg = nn.Conv2d(embed_dim * num_frames, embed_dim, kernel_size=1)
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2), nn.Conv2d(embed_dim, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(256, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(64, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.GELU(),
            nn.Conv2d(32, 1, 1)
        )
    def forward(self, x):
        features = self.backbone(x)[:, 1:, :]
        B, L, D = features.shape
        features = features.view(B, self.num_frames, 14, 14, D)
        features = features.permute(0, 1, 4, 2, 3).reshape(B, self.num_frames * D, 14, 14)
        features = self.temporal_agg(features)
        return self.decoder(features)

# 5. LOAD WEIGHTS & EVALUATE
model = SatMAESegmentation(num_frames=3, in_chans=6).to(DEVICE)
weights_path = os.path.join(MODEL_DIR, "best_model.pth")

if not os.path.exists(weights_path):
    print(f"Warning: best_model.pth not found in {MODEL_DIR}. Looking for SWA model...")
    weights_path = os.path.join(MODEL_DIR, "satmae_scratch_swa_final.pth")

if os.path.exists(weights_path):
    print(f"Loading weights from: {weights_path}")
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
else:
    print("CRITICAL ERROR: No trained model file found. Please check MODEL_DIR path.")
    # Attempt to use checkpoint if model not found
    ckpt_path = os.path.join(MODEL_DIR, "checkpoint_satmae_scratch.pth")
    if os.path.exists(ckpt_path):
        print(f"Fallback: Loading from Checkpoint {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location=DEVICE)
        model.load_state_dict(ckpt['model_state_dict'])

model.eval()
all_preds = []
all_targets = []

print("Running Inference on Validation Set...")
with torch.no_grad():
    for x, y in val_loader:
        x = x.to(DEVICE)
        preds = torch.sigmoid(model(x))
        preds_bin = (preds > 0.5).cpu().numpy().astype(np.uint8).flatten()
        targets_bin = y.numpy().astype(np.uint8).flatten()
        all_preds.append(preds_bin)
        all_targets.append(targets_bin)

y_pred = np.concatenate(all_preds)
y_true = np.concatenate(all_targets)

print("\n" + "="*40)
print("       SCRATCH MODEL RESULTS       ")
print("="*40)
print(f"Pixel Accuracy:    {accuracy_score(y_true, y_pred):.4f}")
print(f"IoU (Jaccard):     {jaccard_score(y_true, y_pred):.4f}")
print(f"F1-Score (Dice):   {f1_score(y_true, y_pred):.4f}")
print(f"Precision:         {precision_score(y_true, y_pred):.4f}")
print(f"Recall:            {recall_score(y_true, y_pred):.4f}")
print("="*40 + "\n")

# 6. VISUALIZATION
def visualize_samples(model, loader, num_samples=3):
    model.eval()
    data_iter = iter(loader)
    x, y = next(data_iter)
    x = x.to(DEVICE)
    with torch.no_grad():
        preds = torch.sigmoid(model(x)).cpu().numpy()

    x = x.cpu().numpy()
    y = y.cpu().numpy()

    plt.figure(figsize=(15, 5 * num_samples))
    for i in range(num_samples):
        # Input RGB (Time Step 0) [2,1,0] for RGB from [B, G, R, NIR...]
        rgb = x[i, [2, 1, 0], 0, :, :].transpose(1, 2, 0)
        rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)

        gt = y[i, 0, :, :]
        pred = preds[i, 0, :, :] > 0.5

        plt.subplot(num_samples, 3, i*3 + 1)
        plt.imshow(rgb)
        plt.title(f"Sample {i+1}: Input (RGB)")
        plt.axis('off')

        plt.subplot(num_samples, 3, i*3 + 2)
        plt.imshow(gt, cmap='gray')
        plt.title(f"Sample {i+1}: Ground Truth")
        plt.axis('off')

        plt.subplot(num_samples, 3, i*3 + 3)
        plt.imshow(pred, cmap='gray')
        plt.title(f"Sample {i+1}: Prediction")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

print("Generating Visualizations...")
visualize_samples(model, val_loader, num_samples=3)
Mounting Drive...
Mounted at /content/drive
Model Path: /content/drive/MyDrive/SatMAE_Scratch_Results_1/
Data Path:  /content/drive/MyDrive/SatMAE_Scratch_Results_1/
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True
  warnings.warn(
Loading weights from: /content/drive/MyDrive/SatMAE_Scratch_Results_1/best_model.pth
Running Inference on Validation Set...

========================================
       SCRATCH MODEL RESULTS       
========================================
Pixel Accuracy:    0.9000
IoU (Jaccard):     0.8704
F1-Score (Dice):   0.9307
Precision:         0.9140
Recall:            0.9480
========================================

Generating Visualizations...
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipython-input-1416508178.py in <cell line: 0>()
    188 
    189 print("Generating Visualizations...")
--> 190 visualize_samples(model, val_loader, num_samples=3)

/tmp/ipython-input-1416508178.py in visualize_samples(model, loader, num_samples)
    163     for i in range(num_samples):
    164         # Input RGB (Time Step 0) [2,1,0] for RGB from [B, G, R, NIR...]
--> 165         rgb = x[i, [2, 1, 0], 0, :, :].transpose(1, 2, 0)
    166         rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)
    167 

IndexError: index 2 is out of bounds for axis 0 with size 2
No description has been provided for this image
In [5]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import os
import time
from scipy.spatial.distance import directed_hausdorff
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, jaccard_score
from torch.utils.data import DataLoader, Dataset
from google.colab import drive
import cv2

# 1. REMOUNT DRIVE
print("Mounting Drive...")
drive.mount('/content/drive', force_remount=True)

# CONFIGURATION
MODEL_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_1/'
BATCH_SIZE = 8
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. SMART PATH FINDER
if os.path.exists(os.path.join(MODEL_DIR, 'val_x.npy')):
    DATA_DIR = MODEL_DIR
else:
    DATA_DIR = MODEL_DIR.replace('_1', '')

print(f"Model Path: {MODEL_DIR}")
print(f"Data Path:  {DATA_DIR}")

if not os.path.exists(os.path.join(DATA_DIR, 'val_x.npy')):
    raise FileNotFoundError(f"CRITICAL: Could not find 'val_x.npy'. Check paths.")

# 3. DATA LOADER
class MmapDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.data = np.load(x_path, mmap_mode='r')
        self.target = np.load(y_path, mmap_mode='r')
    def __len__(self): return len(self.data)
    def __getitem__(self, index):
        return torch.from_numpy(self.data[index].copy()), torch.from_numpy(self.target[index].copy())

val_ds = MmapDataset(os.path.join(DATA_DIR, 'val_x.npy'), os.path.join(DATA_DIR, 'val_y.npy'))
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

# 4. MODEL ARCHITECTURE (SatMAE Scratch)
class SatMAEPatchEmbed(nn.Module):
    def __init__(self, in_chans=6, embed_dim=768, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        B, C, T, H, W = x.shape
        x = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = x.reshape(B, T, -1, x.shape[-1])
        return x

class SatMAEBackbone(nn.Module):
    def __init__(self, num_frames=3, in_chans=6, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.patch_embed = SatMAEPatchEmbed(in_chans=in_chans, embed_dim=embed_dim)
        num_patches = (224 // 16) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, 1, num_patches + 1, embed_dim))
        self.time_embed = nn.Parameter(torch.zeros(1, num_frames, 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim*4, activation="gelu", batch_first=True, norm_first=True)
        self.blocks = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)
    def forward(self, x):
        x = self.patch_embed(x)
        B, T, N, D = x.shape
        x = x + self.time_embed
        x = x.reshape(B, T*N, D)
        spatial_pos = self.pos_embed[:, :, 1:, :].expand(B, T, -1, -1).reshape(B, T*N, D)
        x = x + spatial_pos
        cls_token = self.cls_token.expand(B, -1, -1, -1).reshape(B, 1, D) + self.pos_embed[:, :, 0, :].expand(B, 1, D)
        x = torch.cat((cls_token, x), dim=1)
        x = self.blocks(x)
        x = self.norm(x)
        return x

class SatMAESegmentation(nn.Module):
    def __init__(self, num_frames=3, in_chans=6, embed_dim=768):
        super().__init__()
        self.num_frames = num_frames
        self.backbone = SatMAEBackbone(num_frames=num_frames, in_chans=in_chans, embed_dim=embed_dim)
        self.temporal_agg = nn.Conv2d(embed_dim * num_frames, embed_dim, kernel_size=1)
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2), nn.Conv2d(embed_dim, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(256, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(64, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.GELU(),
            nn.Conv2d(32, 1, 1)
        )
    def forward(self, x):
        features = self.backbone(x)[:, 1:, :]
        B, L, D = features.shape
        features = features.view(B, self.num_frames, 14, 14, D)
        features = features.permute(0, 1, 4, 2, 3).reshape(B, self.num_frames * D, 14, 14)
        features = self.temporal_agg(features)
        return self.decoder(features)

# 5. LOAD MODEL
model = SatMAESegmentation(num_frames=3, in_chans=6).to(DEVICE)
weights_path = os.path.join(MODEL_DIR, "satmae_scratch_swa_final.pth")
if not os.path.exists(weights_path):
    weights_path = os.path.join(MODEL_DIR, "best_model.pth")

print(f"Loading weights from: {weights_path}")
try:
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
except Exception as e:
    print(f"Error loading weights: {e}")
    print("Trying fallback to checkpoint...")
    ckpt_path = os.path.join(MODEL_DIR, "checkpoint_satmae_scratch.pth")
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=DEVICE)
        model.load_state_dict(ckpt['model_state_dict'])

# 6. HELPER FUNCTIONS
def calculate_boundary_iou(gt_mask, pred_mask):
    gt_edges = cv2.Canny(gt_mask.astype(np.uint8)*255, 100, 200) > 0
    pred_edges = cv2.Canny(pred_mask.astype(np.uint8)*255, 100, 200) > 0
    intersection = np.logical_and(gt_edges, pred_edges).sum()
    union = np.logical_or(gt_edges, pred_edges).sum()
    if union == 0: return 0.0
    return intersection / union

def calculate_hausdorff(gt_mask, pred_mask):
    gt_coords = np.argwhere(gt_mask)
    pred_coords = np.argwhere(pred_mask)
    if len(gt_coords) == 0 or len(pred_coords) == 0: return 0.0
    d1 = directed_hausdorff(gt_coords, pred_coords)[0]
    d2 = directed_hausdorff(pred_coords, gt_coords)[0]
    return max(d1, d2)

# 7. INFERENCE & METRICS
model.eval()
all_preds, all_targets = [], []
boundary_ious, hausdorff_dists = [], []
total_frames = 0
start_time = time.time()

print("Running Inference...")
with torch.no_grad():
    for x, y in val_loader:
        x = x.to(DEVICE)
        preds = torch.sigmoid(model(x))
        preds_bin = (preds > 0.5).cpu().numpy().astype(np.uint8)
        targets_bin = y.numpy().astype(np.uint8)
        total_frames += x.size(0)

        for i in range(preds_bin.shape[0]):
            p, t = preds_bin[i, 0], targets_bin[i, 0]
            if t.sum() > 0 and p.sum() > 0:
                boundary_ious.append(calculate_boundary_iou(t, p))
                hausdorff_dists.append(calculate_hausdorff(t, p))
            else:
                boundary_ious.append(0.0)

        all_preds.append(preds_bin.flatten())
        all_targets.append(targets_bin.flatten())

inference_time = time.time() - start_time
fps = total_frames / inference_time
y_pred_flat = np.concatenate(all_preds)
y_true_flat = np.concatenate(all_targets)

# 8. PRINT REPORT
print("\n" + "="*50)
print("       SATMAE SCRATCH - FINAL METRICS       ")
print("="*50)
print(f"Pixel Accuracy:      {accuracy_score(y_true_flat, y_pred_flat):.4f}")
print(f"IoU (Jaccard):       {jaccard_score(y_true_flat, y_pred_flat):.4f}")
print(f"F1-Score (Dice):     {f1_score(y_true_flat, y_pred_flat):.4f}")
print(f"Precision:           {precision_score(y_true_flat, y_pred_flat):.4f}")
print(f"Recall:              {recall_score(y_true_flat, y_pred_flat):.4f}")
print("-" * 30)
print(f"Boundary IoU:        {np.mean(boundary_ious):.4f}")
print(f"Hausdorff Dist (px): {np.mean(hausdorff_dists):.2f}")
print(f"Inference Speed:     {fps:.2f} FPS")
print("="*50 + "\n")

# 9. ROBUST VISUALIZATION (Fixes IndexError)
def visualize_samples(model, loader, num_samples=3):
    model.eval()
    data_iter = iter(loader)
    x, y = next(data_iter)
    x = x.to(DEVICE)
    with torch.no_grad():
        preds = torch.sigmoid(model(x)).cpu().numpy()

    x = x.cpu().numpy()
    y = y.cpu().numpy()

    # Debug: Print dimensions to understand the error
    print(f"DEBUG: Input Tensor Shape: {x.shape}")
    num_channels = x.shape[1]

    plt.figure(figsize=(15, 5 * num_samples))
    for i in range(num_samples):

        # Safe RGB Extraction
        if num_channels >= 3:
            # We have enough channels for RGB
            rgb = x[i, [2, 1, 0], 0, :, :].transpose(1, 2, 0)
        else:
            # Not enough channels! Fallback to single channel grayscale
            print(f"Warning: Only {num_channels} channels found. Displaying Channel 0.")
            single_band = x[i, 0, 0, :, :]
            rgb = np.stack([single_band, single_band, single_band], axis=2) # replicate to 3 dims

        # Normalize
        rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)

        gt = y[i, 0, :, :]
        pred = preds[i, 0, :, :] > 0.5

        plt.subplot(num_samples, 3, i*3 + 1)
        plt.imshow(rgb)
        plt.title(f"Sample {i+1}: Input")
        plt.axis('off')

        plt.subplot(num_samples, 3, i*3 + 2)
        plt.imshow(gt, cmap='gray')
        plt.title(f"Sample {i+1}: Ground Truth")
        plt.axis('off')

        plt.subplot(num_samples, 3, i*3 + 3)
        plt.imshow(pred, cmap='gray')
        plt.title(f"Sample {i+1}: Prediction")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

print("Generating Visualizations...")
visualize_samples(model, val_loader, num_samples=3)
Mounting Drive...
Mounted at /content/drive
Model Path: /content/drive/MyDrive/SatMAE_Scratch_Results_1/
Data Path:  /content/drive/MyDrive/SatMAE_Scratch_Results_1/
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True
  warnings.warn(
Loading weights from: /content/drive/MyDrive/SatMAE_Scratch_Results_1/best_model.pth
Running Inference...

==================================================
       SATMAE SCRATCH - FINAL METRICS       
==================================================
Pixel Accuracy:      0.9000
IoU (Jaccard):       0.8704
F1-Score (Dice):     0.9307
Precision:           0.9140
Recall:              0.9480
------------------------------
Boundary IoU:        0.0884
Hausdorff Dist (px): 13.66
Inference Speed:     0.28 FPS
==================================================

Generating Visualizations...
DEBUG: Input Tensor Shape: (2, 6, 3, 224, 224)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipython-input-3798441792.py in <cell line: 0>()
    236 
    237 print("Generating Visualizations...")
--> 238 visualize_samples(model, val_loader, num_samples=3)

/tmp/ipython-input-3798441792.py in visualize_samples(model, loader, num_samples)
    204         if num_channels >= 3:
    205             # We have enough channels for RGB
--> 206             rgb = x[i, [2, 1, 0], 0, :, :].transpose(1, 2, 0)
    207         else:
    208             # Not enough channels! Fallback to single channel grayscale

IndexError: index 2 is out of bounds for axis 0 with size 2
No description has been provided for this image
In [6]:
# 6. VISUALIZATION (CORRECTED)
def visualize_samples(model, loader, num_samples=3):
    model.eval()
    data_iter = iter(loader)

    try:
        x, y = next(data_iter)
    except StopIteration:
        print(" Validation loader is empty!")
        return

    # FIX: Check actual batch size vs requested samples
    current_batch_size = x.shape[0]
    actual_samples = min(num_samples, current_batch_size)

    print(f" Visualizing {actual_samples} samples (Available in batch: {current_batch_size})")

    x = x.to(DEVICE)
    with torch.no_grad():
        preds = torch.sigmoid(model(x)).cpu().numpy()

    x = x.cpu().numpy()
    y = y.cpu().numpy()

    if actual_samples == 0:
        return

    plt.figure(figsize=(15, 5 * actual_samples))

    for i in range(actual_samples):
        # Safety check for channels
        if x.shape[1] >= 3:
            # Input RGB (Time Step 0) [2,1,0] for RGB from [B, G, R, NIR...]
            rgb = x[i, [2, 1, 0], 0, :, :].transpose(1, 2, 0)
        else:
            # Fallback for grayscale/single channel
            gray = x[i, 0, 0, :, :]
            rgb = np.stack([gray, gray, gray], axis=2)

        # Normalize 0-1 for display
        rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)

        gt = y[i, 0, :, :]
        pred = preds[i, 0, :, :] > 0.5

        # Plot RGB
        plt.subplot(actual_samples, 3, i*3 + 1)
        plt.imshow(rgb)
        plt.title(f"Sample {i+1}: Input (RGB)")
        plt.axis('off')

        # Plot Ground Truth
        plt.subplot(actual_samples, 3, i*3 + 2)
        plt.imshow(gt, cmap='gray')
        plt.title(f"Sample {i+1}: Ground Truth")
        plt.axis('off')

        # Plot Prediction
        plt.subplot(actual_samples, 3, i*3 + 3)
        plt.imshow(pred, cmap='gray')
        plt.title(f"Sample {i+1}: Prediction")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Run the visualization
print("Generating Visualizations...")
visualize_samples(model, val_loader, num_samples=3)
Generating Visualizations...
ℹ️ Visualizing 2 samples (Available in batch: 2)
No description has been provided for this image