@chenxiex
Last active April 28, 2026 15:07
Eval DTU on VGGT
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..22548d7
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "evaluation/DTUeval-python"]
+	path = evaluation/DTUeval-python
+	url = https://gh-proxy.org/https://github.com/chenxiex/DTUeval-python.git
diff --git a/evaluation/DTUeval-python b/evaluation/DTUeval-python
new file mode 160000
index 0000000..b07b411
--- /dev/null
+++ b/evaluation/DTUeval-python
@@ -0,0 +1 @@
+Subproject commit b07b41157716982e29fda4f367a68772a168bb20
diff --git a/evaluation/README.md b/evaluation/README.md
new file mode 100644
index 0000000..05e4639
--- /dev/null
+++ b/evaluation/README.md
@@ -0,0 +1,148 @@
+# VGGT Evaluation
+
+This repository contains code to reproduce the evaluation results presented in the VGGT paper.
+
+## Table of Contents
+
+- [Camera Pose Estimation on Co3D](#camera-pose-estimation-on-co3d)
+  - [Model Weights](#model-weights)
+  - [Setup](#setup)
+  - [Dataset Preparation](#dataset-preparation)
+  - [Running the Evaluation](#running-the-evaluation)
+  - [Expected Results](#expected-results)
+- [Dense MVS Estimation on the DTU Dataset](#dense-mvs-estimation-on-the-dtu-dataset)
+
+## Camera Pose Estimation on Co3D
+
+### Model Weights
+
+We have addressed a minor bug in the publicly released checkpoint related to the TrackHead configuration. Specifically, the `pos_embed` flag was incorrectly set to `False`. The following checkpoint incorporates this fix by fine-tuning the tracker head with `pos_embed` set to `True` while preserving all other parameters. This fix will be merged into the main branch in a future update.
+
+```bash
+wget https://huggingface.co/facebook/VGGT_tracker_fixed/resolve/main/model_tracker_fixed_e20.pt
+```
+
+Note: The default checkpoint remains functional, though you may observe a slight performance decrease (approximately 0.3% in AUC@30) when using Bundle Adjustment (BA). If using the default checkpoint, ensure you set `pos_embed` to `False` for the TrackHead. This modification only affects tracking-based evaluations and has no impact on feed-forward estimation performance, as tracking is not utilized in the feed-forward approach.
+
+### Setup
+
+Install the required dependencies:
+
+```bash
+# Install VGGT as a package
+pip install -e .
+
+# Install evaluation dependencies
+pip install pycolmap==3.10.0 pyceres==2.3
+
+# Install LightGlue for keypoint detection
+git clone https://github.com/cvg/LightGlue.git
+cd LightGlue
+python -m pip install -e .
+cd ..
+```
+
+### Dataset Preparation
+
+1. Download the Co3D dataset from the [official repository](https://github.com/facebookresearch/co3d)
+
+2. Preprocess the dataset (approximately 5 minutes):
+```bash
+python preprocess_co3d.py --category all \
+    --co3d_v2_dir /YOUR/CO3D/PATH \
+    --output_dir /YOUR/CO3D/ANNO/PATH
+```
+
+ Replace `/YOUR/CO3D/PATH` with the path to your downloaded Co3D dataset, and `/YOUR/CO3D/ANNO/PATH` with the desired output directory for the processed annotations. Note that the processed data here uses the PyTorch3D camera convention, while the annotation files we provided for training on Hugging Face have already been converted to the OpenCV convention.
+
+
+
+### Running the Evaluation
+
+Choose one of these evaluation modes:
+
+```bash
+# Standard VGGT evaluation
+python test_co3d.py \
+    --model_path /YOUR/MODEL/PATH \
+    --co3d_dir /YOUR/CO3D/PATH \
+    --co3d_anno_dir /YOUR/CO3D/ANNO/PATH \
+    --seed 0
+
+# VGGT with Bundle Adjustment
+python test_co3d.py \
+    --model_path /YOUR/MODEL/PATH \
+    --co3d_dir /YOUR/CO3D/PATH \
+    --co3d_anno_dir /YOUR/CO3D/ANNO/PATH \
+    --seed 0 \
+    --use_ba
+```
+
+
+
+
+### Expected Results
+
+#### Quick Evaluation
+Full evaluation on Co3D can take a long time. For faster trials, run with `--fast_eval`, which performs the same evaluation on at most 10 sequences per category.
+
+Expected results with `--fast_eval`:
+
+- Feed-forward estimation:
+  - AUC@30: 89.98
+  - AUC@15: 83.89
+  - AUC@5: 67.45
+  - AUC@3: 56.65
+
+- With Bundle Adjustment (`--use_ba`):
+  - AUC@30: 90.52
+  - AUC@15: 85.08
+  - AUC@5: 70.69
+  - AUC@3: 61.32
+
+#### Full Evaluation
+
+- Feed-forward estimation achieves a mean AUC@30 of 89.5% (slightly higher than the 88.2% reported in the paper, due to implementation differences).
+- With Bundle Adjustment, expect a mean AUC@30 between 90.5% and 92.5%.
+
+> **Note:** For simplicity, this script does not optimize inference speed, so timing results may differ from those reported in the paper. For example, when using BA, the keypoint extractor models are re-initialized for each sequence rather than loaded once.
+
+
+## Dense MVS Estimation on the DTU Dataset
+
+### Setup
+
+Install the required dependencies:
+
+```bash
+pip install -e ".[evaluation]"
+```
+
+### Dataset Preparation
+
+You can use the provided `download_dtu.py` script to download all the needed data:
+
+```bash
+python download_dtu.py --output "${DATASETS_PATH}" --cache "${CACHE_PATH}"
+```
+
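+By default the script downloads with `curl`. Datasets are extracted into `${DATASETS_PATH}`, and the zip archives downloaded to `${CACHE_PATH}` (default: same as `--output`) are deleted after extraction. Pass `--ms` to download from the ModelScope mirror instead, which requires the `modelscope` command-line tool.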
+
+### Running the Evaluation
+
+Bundle Adjustment is not yet supported for DTU. Run:
+
+```bash
+python test_dtu.py \
+    --dtu_test_1200_path "${DATASETS_PATH}/dtu-test-1200" \
+    --dtu_depths_path "${DATASETS_PATH}/dtu-depths-raw" \
+    --results_path "${RESULTS_PATH}" \
+    --model_path "${CKPT_FILE_PATH}"
+```
+
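+If `${CKPT_FILE_PATH}` does not exist, the VGGT-1B checkpoint is downloaded to that path automatically. The fused 3D point cloud of each scan is saved as `${RESULTS_PATH}/XXX.ply`, where `XXX` is the zero-padded scan number (e.g. `scan1` is saved as `001.ply`). You can then evaluate the point clouds with the official MATLAB evaluation code under `${DATASETS_PATH}/dtu-sample`, or with the Python implementation [DTUeval-python](https://github.com/chenxiex/DTUeval-python), which is included as a git submodule at `evaluation/DTUeval-python` (run `git submodule update --init` if the directory is empty):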
+
+```bash
+cd DTUeval-python
+python eval.py --scans true --dataset_dir "${DATASETS_PATH}/dtu-sample/SampleSet/MVS Data/" --input_dir "${RESULTS_PATH}" --mode pcd --vis_out_dir "${out_dir_for_visualization}" --result_file "${result_file_for_score}"
+```
\ No newline at end of file
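The README notes that `preprocess_co3d.py` writes cameras in the PyTorch3D convention, while the Hugging Face training annotations use the OpenCV convention. As an illustration of what that conversion involves for the extrinsics only, here is a minimal NumPy sketch. It assumes PyTorch3D's row-vector form `X_cam = X_world @ R + T` with +X left and +Y up, and OpenCV's column-vector form `x_cam = R @ X_world + t` with +X right and +Y down. The helper name `pytorch3d_to_opencv_extrinsics` is hypothetical, and intrinsics (PyTorch3D's NDC focal length and principal point) are deliberately not handled; PyTorch3D's own `pytorch3d.utils.opencv_from_cameras_projection` covers the full conversion.

```python
import numpy as np

# Hypothetical helper: extrinsics only; intrinsics need a separate conversion.
FLIP_XY = np.diag([-1.0, -1.0, 1.0])  # PyTorch3D (+X left, +Y up) -> OpenCV (+X right, +Y down)


def pytorch3d_to_opencv_extrinsics(R_p3d: np.ndarray, T_p3d: np.ndarray):
    """Map PyTorch3D (R, T), X_cam = X_world @ R + T, to OpenCV (R, t), x_cam = R @ X_world + t."""
    R_cv = FLIP_XY @ R_p3d.T  # transpose turns the row-vector rotation into a column-vector one
    t_cv = FLIP_XY @ T_p3d
    return R_cv, t_cv


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    R_p3d = q * np.sign(np.linalg.det(q))  # force det(R) = +1
    T_p3d = rng.normal(size=3)
    X_world = rng.normal(size=3)

    x_p3d = X_world @ R_p3d + T_p3d  # camera coordinates in PyTorch3D axes
    R_cv, t_cv = pytorch3d_to_opencv_extrinsics(R_p3d, T_p3d)
    x_cv = R_cv @ X_world + t_cv     # camera coordinates in OpenCV axes
    print(np.allclose(x_cv, FLIP_XY @ x_p3d))  # True
```

The check works because both conventions describe the same camera-space point; they differ only in the sign of the X and Y axes and in whether vectors multiply the rotation from the left or the right.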
diff --git a/evaluation/download_dtu.py b/evaluation/download_dtu.py
new file mode 100644
index 0000000..965cbfd
--- /dev/null
+++ b/evaluation/download_dtu.py
@@ -0,0 +1,230 @@
+#!/usr/bin/env python3
+
+import argparse
+import shutil
+import subprocess
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from zipfile import ZipFile
+
+
+@dataclass
+class DatasetSpec:
+    extract_dir: str  # subdirectory name under output
+    zip_name: str  # local zip filename (non-MS)
+    url: str  # direct download URL (non-MS)
+    ms_repo: str  # ModelScope repository (owner/name)
+    ms_file: str  # filename within the ModelScope repository
+
+
+DATASETS: list[DatasetSpec] = [
+    DatasetSpec(
+        extract_dir="dtu-test-1200",
+        zip_name="dtu-test-1200.zip",
+        url="https://www.kaggle.com/api/v1/datasets/download/chenxiex/dtu-test-1200",
+        ms_repo="anlorsp/dtu-test-1200",
+        ms_file="dtu-test-1200.zip",
+    ),
+    DatasetSpec(
+        extract_dir="dtu-depths-raw",
+        zip_name="dtu-depths-raw.zip",
+        url="https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip",
+        ms_repo="anlorsp/dtu-depths-raw",
+        ms_file="Depths_raw.zip",
+    ),
+    DatasetSpec(
+        extract_dir="dtu-sample",
+        zip_name="dtu-sample.zip",
+        url="http://roboimagedata2.compute.dtu.dk/data/MVS/SampleSet.zip",
+        ms_repo="anlorsp/dtu-sample",
+        ms_file="dtu-sample.zip",
+    ),
+]
+
+# Only needed for non-MS path: Points are already merged in the MS sample dataset
+_POINTS_SPEC = DatasetSpec(
+    extract_dir="dtu-points",
+    zip_name="dtu-points.zip",
+    url="http://roboimagedata2.compute.dtu.dk/data/MVS/Points.zip",
+    ms_repo="",
+    ms_file="",
+)
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Download and prepare DTU evaluation data")
+    parser.add_argument("--output", default="data", help="Output directory for extracted files")
+    parser.add_argument(
+        "--cache",
+        default=None,
+        help="Cache directory for temporary zip files (default: same as --output)",
+    )
+    parser.add_argument(
+        "--ms",
+        action="store_true",
+        help="Download from ModelScope mirror instead of original sources",
+    )
+    return parser.parse_args()
+
+
+def get_remote_content_length(url: str) -> int | None:
+    """Fetch the last Content-Length header (curl -L prints one header block per redirect)."""
+    try:
+        result = subprocess.run(
+            ["curl", "-sI", "-L", url],
+            capture_output=True,
+            text=True,
+            check=False,
+            timeout=10,
+        )
+        lengths = [int(line.split(":", 1)[1]) for line in result.stdout.splitlines()
+                   if line.lower().startswith("content-length:")]
+        return lengths[-1] if lengths else None
+    except Exception as e:
+        print(f"Warning: Could not fetch remote file size: {e}", file=sys.stderr)
+    return None
+
+
+def is_download_complete(url: str, local_path: Path) -> bool:
+    """Check if local file size matches the remote Content-Length."""
+    if not local_path.exists():
+        return False
+    remote_size = get_remote_content_length(url)
+    if remote_size is None:
+        return False
+    return local_path.stat().st_size == remote_size
+
+
+def download_with_resume(url: str, output_path: Path) -> None:
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    if output_path.exists():
+        if is_download_complete(url, output_path):
+            print(f"File already complete: {output_path}")
+            return
+        print(f"Resuming download {url} to {output_path}...")
+    else:
+        print(f"Downloading {url} to {output_path}...")
+    subprocess.run(["curl", "-L", "-C", "-", "-o", str(output_path), url], check=True)
+
+
+def download_from_modelscope(ms_repo: str, ms_file: str, cache: Path) -> Path:
+    """Download a file from ModelScope to cache and return its local path."""
+    local_zip = cache / Path(ms_file).name
+    if local_zip.exists():
+        print(f"File already in cache: {local_zip}")
+        return local_zip
+    print(f"Downloading {ms_repo}/{ms_file} from ModelScope...")
+    subprocess.run(
+        ["modelscope", "download", "--dataset", ms_repo, "--include", ms_file, "--local_dir", str(cache)],
+        check=True,
+    )
+    return local_zip
+
+
+def get_zip_path(spec: DatasetSpec, cache: Path, use_ms: bool) -> Path:
+    """Return the expected local zip path for a dataset spec."""
+    if use_ms:
+        return cache / Path(spec.ms_file).name
+    return cache / spec.zip_name
+
+
+def acquire_zip(spec: DatasetSpec, cache: Path, use_ms: bool) -> Path:
+    """Download zip file if needed and return its local path."""
+    if use_ms:
+        return download_from_modelscope(spec.ms_repo, spec.ms_file, cache)
+    zip_path = cache / spec.zip_name
+    download_with_resume(spec.url, zip_path)
+    return zip_path
+
+
+def flatten_nested_directory(extract_dir: Path, zip_path: Path) -> None:
+    nested_dir = extract_dir / zip_path.stem
+    if not nested_dir.is_dir():
+        return
+    for item in nested_dir.iterdir():
+        target = extract_dir / item.name
+        if target.exists():
+            if target.is_dir():
+                shutil.rmtree(target)
+            else:
+                target.unlink()
+        shutil.move(str(item), str(target))
+    try:
+        nested_dir.rmdir()
+    except OSError:
+        pass
+
+
+def extract_zip(zip_path: Path, extract_dir: Path) -> None:
+    """Extract zip and flatten top-level nested directory if present."""
+    extract_dir.mkdir(parents=True, exist_ok=True)
+    print(f"Extracting {zip_path.name}...")
+    with ZipFile(zip_path, "r") as zf:
+        zf.extractall(extract_dir)
+    flatten_nested_directory(extract_dir, zip_path)
+    print(f"Extraction complete: {extract_dir}")
+
+
+def process_dataset(spec: DatasetSpec, output: Path, cache: Path, use_ms: bool) -> None:
+    """Download and extract a dataset, skipping if already present."""
+    extract_dir = output / spec.extract_dir
+    if extract_dir.is_dir():
+        print(f"Directory {extract_dir} already exists. Skipping download and extraction.")
+        return
+    zip_path = acquire_zip(spec, cache, use_ms)
+    extract_zip(zip_path, extract_dir)
+    print(f"Deleting {zip_path.name}...")
+    zip_path.unlink(missing_ok=True)
+    print(f"Dataset {spec.extract_dir} processing complete.")
+
+
+def _merge_points_into_sample(sample_dir: Path, cache: Path) -> None:
+    """Download Points.zip, copy stl files into dtu-sample, then clean up."""
+    points_dir = cache / _POINTS_SPEC.extract_dir
+    points_zip = acquire_zip(_POINTS_SPEC, cache, use_ms=False)
+    extract_zip(points_zip, points_dir)
+    print("Merging Points data into sample directory...")
+    source_stl = points_dir / "Points" / "stl"
+    target_stl = sample_dir / "SampleSet" / "MVS Data" / "Points" / "stl"
+    target_stl.mkdir(parents=True, exist_ok=True)
+    for item in source_stl.iterdir():
+        shutil.copy2(item, target_stl / item.name)
+    shutil.rmtree(points_dir, ignore_errors=True)
+    print(f"Deleting {points_zip.name}...")
+    points_zip.unlink(missing_ok=True)
+    print("Points merge complete.")
+
+
+def main() -> int:
+    args = parse_args()
+    output = Path(args.output)
+    cache = Path(args.cache) if args.cache else output
+    cache.mkdir(parents=True, exist_ok=True)
+
+    *regular_specs, sample_spec = DATASETS
+    for spec in regular_specs:
+        process_dataset(spec, output, cache, use_ms=args.ms)
+
+    # dtu-sample: non-MS mode requires merging a separate Points.zip into the sample dir
+    sample_dir = output / sample_spec.extract_dir
+    if sample_dir.is_dir():
+        print(f"Directory {sample_dir} already exists. Skipping download and extraction.")
+    else:
+        sample_zip = acquire_zip(sample_spec, cache, use_ms=args.ms)
+        extract_zip(sample_zip, sample_dir)
+        print(f"Deleting {sample_zip.name}...")
+        sample_zip.unlink(missing_ok=True)
+        if not args.ms:
+            _merge_points_into_sample(sample_dir, cache)
+
+    print("All datasets processed successfully!")
+    return 0
+
+
+if __name__ == "__main__":
+    try:
+        raise SystemExit(main())
+    except subprocess.CalledProcessError as exc:
+        print(f"Command failed with exit code {exc.returncode}: {' '.join(exc.cmd)}", file=sys.stderr)
+        raise SystemExit(exc.returncode)
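`download_with_resume` decides whether to skip, resume, or start a download by comparing the local file size with the remote `Content-Length` reported by `curl -sI -L`. With `-L`, curl prints one header block per redirect hop, so only the last block describes the file itself. The network-free sketch below isolates that decision logic; `final_content_length` and `resume_action` are hypothetical helpers that operate on an already captured header dump, not functions from `download_dtu.py`.

```python
from typing import Optional


def final_content_length(header_dump: str) -> Optional[int]:
    """Return the Content-Length of the last response in `curl -sI -L` output, if any."""
    length = None
    for line in header_dump.splitlines():
        if line.startswith("HTTP/"):
            length = None  # a new status line starts a new response block
        elif line.lower().startswith("content-length:"):
            length = int(line.split(":", 1)[1])
    return length


def resume_action(local_size: Optional[int], remote_size: Optional[int]) -> str:
    """Decide what to do, mirroring the branches of download_with_resume."""
    if local_size is None:
        return "download"  # no local file yet
    if remote_size is not None and local_size == remote_size:
        return "skip"  # sizes match: treat the file as complete
    return "resume"  # partial or unverifiable: let `curl -C -` continue


if __name__ == "__main__":
    dump = (
        "HTTP/1.1 302 Found\r\n"
        "Content-Length: 0\r\n"
        "Location: https://example.com/file.zip\r\n"
        "\r\n"
        "HTTP/1.1 200 OK\r\n"
        "Content-Length: 1048576\r\n"
    )
    size = final_content_length(dump)
    print(size)                          # 1048576
    print(resume_action(1048576, size))  # skip
    print(resume_action(1000, size))     # resume
```

Taking the first `Content-Length` instead would return the redirect's `0` in this example, so a finished download would never be recognized as complete.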
diff --git a/evaluation/test_dtu.py b/evaluation/test_dtu.py
new file mode 100644
index 0000000..7246f47
--- /dev/null
+++ b/evaluation/test_dtu.py
@@ -0,0 +1,257 @@
+import torch
+from vggt.models.vggt import VGGT
+import random
+from pathlib import Path
+import os
+import numpy as np
+from PIL import Image
+import argparse
+import logging
+import sys
+
+from utils import load_model, predict, read_pfm, upsample_images, write_ply, open3d_filter
+
+logger = logging.getLogger(__name__)
+
+
+def configure_logging(level: int = logging.INFO):
+    logging.basicConfig(
+        level=level,
+        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        stream=sys.stdout,
+        force=True,
+    )
+
+
+def save_predictions(results_path: Path, scene_name: str, predictions, sample_no):
+    save_dict = {
+        "predictions": predictions,
+        "sample_no": sample_no
+    }
+
+    if not results_path.exists():
+        os.makedirs(results_path, exist_ok=True)
+
+    torch.save(save_dict, results_path/f"{scene_name}.pt")
+
+
+def load_predictions(results_path: Path, scene_name: str):
+    results = torch.load(
+        results_path/f"{scene_name}.pt", map_location=torch.device("cpu"))
+    sample_no = results["sample_no"]
+    predictions = results["predictions"]
+    return predictions, sample_no
+
+
+def load_gt_depth(gt_depths_path: Path, sample_no: list[int]):
+    '''
+    Args:
+        gt_depths_path: Directory containing the ground-truth depth maps, named gt_depths_path/f"depth_map_{i:04}.pfm"
+        sample_no: Indices of the ground-truth depth maps to load
+    Returns:
+        gt_depth: Tensor of shape (batch_size, H, W)
+    '''
+    sampled_gt_depth_paths = [gt_depths_path /
+                              f"depth_map_{i:04}.pfm" for i in sample_no]
+
+    gt_depth = []
+
+    for gt_depth_path in sampled_gt_depth_paths:
+        data, scale = read_pfm(gt_depth_path)
+        gt_depth.append(data*scale)
+
+    gt_depth = torch.from_numpy(np.stack(gt_depth, axis=0)).float()
+    return gt_depth
+
+
+def align_pred_to_gt(
+    pred_depth: np.ndarray,
+    gt_depth: np.ndarray,
+    valid_mask: np.ndarray,
+    min_valid_pixels: int = 100,
+):
+    """
+    Aligns a predicted depth map to a ground truth depth map using scale and shift.
+    The alignment is: gt_depth ≈ scale * pred_depth + shift, fitted by least squares over valid pixels.
+
+    Args:
+        pred_depth (np.ndarray): The predicted depth values (HxW or flattened).
+        gt_depth (np.ndarray): The ground truth depth values, same shape as pred_depth.
+        valid_mask (np.ndarray): A boolean mask of the valid pixels in the depth maps.
+        min_valid_pixels (int): The minimum number of valid pixels required for alignment.
+
+    Returns:
+        tuple[float, float]:
+            - scale (float): The calculated scale factor. (NaN if alignment failed)
+            - shift (float): The calculated shift offset. (NaN if alignment failed)
+    """
+    if pred_depth.shape != gt_depth.shape:
+        raise ValueError(
+            f"Predicted depth shape {pred_depth.shape} must match GT depth shape {gt_depth.shape}"
+        )
+
+    # Extract valid depth values
+    gt_masked = gt_depth[valid_mask]
+    pred_masked = pred_depth[valid_mask]
+
+    if len(gt_masked) < min_valid_pixels:
+        logger.warning(
+            f"Warning: Not enough valid pixels ({len(gt_masked)} < {min_valid_pixels}) to align. "
+            "Using all pixels."
+        )
+        gt_masked = gt_depth.reshape(-1)
+        pred_masked = pred_depth.reshape(-1)
+
+    # Handle case where pred_masked has no variance (e.g., all zeros or a constant value)
+    if np.std(pred_masked) < 1e-6:  # Small epsilon to check for near-constant values
+        logger.warning(
+            "Warning: Predicted depth values in the valid mask have near-zero variance. "
+            "Scale is ill-defined. Setting scale=1 and solving for shift only."
+        )
+        scale = 1.0
+        # or np.median(gt_masked) - np.median(pred_masked)
+        shift = np.mean(gt_masked) - np.mean(pred_masked)
+    else:
+        A = np.vstack([pred_masked, np.ones_like(pred_masked)]).T
+        try:
+            x, residuals, rank, s_values = np.linalg.lstsq(
+                A, gt_masked, rcond=None)
+            scale, shift = x[0], x[1]
+        except np.linalg.LinAlgError as e:
+            logger.warning(
+                f"Warning: Least squares alignment failed ({e}). Returning NaN scale and shift.")
+            return np.nan, np.nan
+
+    return scale, shift
+
+
+def parse_cam(cam_file: Path):
+    cam_txt = Path(cam_file).read_text().splitlines()
+    def parse_rows(rows): return [[float(v) for v in row.split()] for row in rows]
+
+    extr_mat = parse_rows(cam_txt[1:5])  # 4x4 world-to-camera extrinsics
+    intr_mat = parse_rows(cam_txt[7:10])  # 3x3 intrinsics
+
+    extr_mat = np.array(extr_mat, np.float32)
+    intr_mat = np.array(intr_mat, np.float32)
+
+    return extr_mat, intr_mat
+
+
+def load_data(dtu_test_1200_path: Path, scene_name: str, sample_no: list[int]):
+
+    projs = []
+    rgbs = []
+
+    for view in sample_no:
+        img_file = dtu_test_1200_path / \
+            f"Rectified/{scene_name}/rect_{view+1:03d}_3_r5000.png"
+        cam_file = dtu_test_1200_path/f"Cameras/{view:08}_cam.txt"
+
+        extr_mat, intr_mat = parse_cam(cam_file)
+        proj_mat = np.eye(4)
+        proj_mat[:3, :4] = intr_mat[:3, :3] @ extr_mat[:3, :4]
+        projs.append(torch.from_numpy(proj_mat))
+
+        rgb = np.array(Image.open(img_file))
+        rgbs.append(rgb)
+
+    projs = torch.stack(projs).float()
+
+    # Normalize to [0, 1] and permute from [H, W, C] to [C, H, W]
+    rgb_tensors = [torch.from_numpy(img.astype(np.float32) / 255.).permute(2, 0, 1)
+                   for img in rgbs]
+    rgbs = torch.stack(rgb_tensors)  # (B,3,H,W)
+
+    return projs, rgbs
+
+
+if __name__ == "__main__":
+    configure_logging()
+    logger.setLevel(logging.INFO)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dtu_test_1200_path", type=Path,
+                        required=True, help="Path to the DTU testing dataset")
+    parser.add_argument("--dtu_depths_path", type=Path,
+                        required=True, help="Path to the DTU raw depth maps")
+    parser.add_argument("--results_path", type=Path, required=True,
+                        help="Path to save the DTU testing results")
+    parser.add_argument("--model_path", type=Path,
+                        required=False, help="Path to the trained VGGT model")
+    parser.add_argument("--sample_size", type=int, default=49,
+                        help="Number of views (out of 49) randomly sampled per scan")
+    parser.add_argument("--no_pred", action="store_true",
+                        help="If set, skip prediction and only load existing predictions")
+    parser.add_argument("--scans", type=int, nargs='+', required=False,
+                        help="Scene ID numbers to evaluate (e.g., 1 2 3)")
+    args = parser.parse_args()
+
+    if not args.no_pred and not args.model_path:
+        raise ValueError(
+            "Model path must be provided if not skipping prediction.")
+
+    if args.scans:
+        scene_names = [f"scan{i}" for i in args.scans]
+    else:
+        with open(args.dtu_test_1200_path/"scan_list_test.txt") as f:
+            scene_names = [line.strip() for line in f if line.strip()]
+
+    for scene_name in scene_names:
+        # Inference
+        logger.info(f"Processing {scene_name}...")
+        if not args.no_pred:
+            logger.info("Predicting depth maps...")
+            model = load_model(args.model_path)
+            images_path = args.dtu_test_1200_path/"Rectified"/scene_name
+            sample_no = random.sample(range(0, 49), args.sample_size)
+            sampled_image_paths = [
+                images_path/f"rect_{i+1:03d}_3_r5000.png" for i in sample_no]
+            predictions = predict(sampled_image_paths, model)
+            save_predictions(args.results_path, scene_name,
+                             predictions, sample_no)
+            del model
+            torch.cuda.empty_cache()
+        else:
+            logger.info("Loading predictions...")
+            predictions, sample_no = load_predictions(
+                args.results_path, scene_name)
+
+        # Align predictions to ground truth
+        logger.info("Aligning predicted depth maps to ground truth...")
+        gt_depths_path = args.dtu_depths_path/"Depths"/scene_name
+        gt_depth = load_gt_depth(gt_depths_path, sample_no)
+        gt_depth_w, gt_depth_h = gt_depth[0].shape[:2]  # NB: this is (H, W); upsample_image() swaps the pair back, giving H x W output
+        depths = predictions['depth'][0]
+        conf = predictions['depth_conf'][0]
+        upsampled_pred_depth = upsample_images(
+            depths, gt_depth_w, gt_depth_h)
+        upsampled_depth_conf = upsample_images(
+            conf, gt_depth_w, gt_depth_h)
+
+        valid_mask = (gt_depth > 1e-3) & (upsampled_depth_conf > 3)
+
+        align_mask = valid_mask.reshape(-1)
+        align_depth_map = upsampled_pred_depth.reshape(-1)
+        align_gt_depth = gt_depth.reshape(-1)
+
+        scale_val, shift_val = align_pred_to_gt(
+            align_depth_map.cpu().numpy(),
+            align_gt_depth.cpu().numpy(),
+            align_mask.cpu().numpy()
+        )
+
+        scale = torch.tensor(scale_val, dtype=torch.float32)
+        shift = torch.tensor(shift_val, dtype=torch.float32)
+
+        aligned_upsampled_depth = upsampled_pred_depth * scale + shift
+        depths = aligned_upsampled_depth * (upsampled_depth_conf > 3)
+        depths = depths.unsqueeze(1)
+
+        # Fuse depth maps into a point cloud
+        logger.info("Fusing depth maps into point cloud and saving results...")
+        projs, rgbs = load_data(args.dtu_test_1200_path, scene_name, sample_no)
+        points = open3d_filter(depths, projs, rgbs,
+                               dist_thresh=1.0, batch_size=20, num_consist=4)
+        write_ply(args.results_path /
+                  f"{int(scene_name[4:]):03d}.ply", points)
+        logger.info(f"Finished processing {scene_name}, written to {int(scene_name[4:]):03d}.ply")
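`align_pred_to_gt` solves an ordinary least-squares problem for `scale` and `shift` so that `gt ≈ scale * pred + shift` on the pixels where the ground-truth depth exceeds 1e-3 and the predicted confidence exceeds 3. A self-contained NumPy sketch of that fit on synthetic data follows; `fit_scale_shift` is a hypothetical name, and the synthetic depths and random mask only stand in for real DTU data.

```python
import numpy as np


def fit_scale_shift(pred: np.ndarray, gt: np.ndarray, mask: np.ndarray):
    """Least-squares (scale, shift) minimising ||scale * pred + shift - gt||^2 over mask."""
    p = pred[mask]
    g = gt[mask]
    A = np.stack([p, np.ones_like(p)], axis=1)  # design matrix with columns [pred, 1]
    (scale, shift), *_ = np.linalg.lstsq(A, g, rcond=None)
    return float(scale), float(shift)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    pred = rng.uniform(0.5, 2.0, size=(64, 80))   # relative (up-to-scale) depth
    gt = 420.0 * pred + 35.0                      # synthetic metric depth
    mask = rng.random(pred.shape) > 0.3           # stands in for (gt > 1e-3) & (conf > 3)
    gt_with_holes = np.where(mask, gt, 0.0)       # invalid GT pixels are zero
    print(fit_scale_shift(pred, gt_with_holes, mask))  # approximately (420.0, 35.0)
```

Fitting over all pixels instead would pull the solution toward the zero-valued holes, which is why the script only falls back to all pixels when fewer than `min_valid_pixels` valid ones remain.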
diff --git a/evaluation/utils.py b/evaluation/utils.py
new file mode 100644
index 0000000..578afe4
--- /dev/null
+++ b/evaluation/utils.py
@@ -0,0 +1,289 @@
+from vggt.models.vggt import VGGT
+import torch
+from typing import List
+from pathlib import Path
+from vggt.utils.load_fn import load_and_preprocess_images
+import re
+import numpy as np
+from PIL import Image
+import torch.nn.functional as F
+import open3d as o3d
+from urllib.request import urlretrieve
+import os
+import logging
+
+HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
+MODEL_URL = f"{HF_ENDPOINT}/facebook/VGGT-1B/resolve/main/model.pt"
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if torch.cuda.is_available():
+    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+else:
+    dtype = torch.float32
+
+logger = logging.getLogger(__name__)
+
+def load_model(model_path: Path) -> VGGT:
+    if not model_path.exists():
+        logger.info(f"Model not found at {model_path}. Downloading from {MODEL_URL}...")
+        model_path.parent.mkdir(parents=True, exist_ok=True)
+        urlretrieve(MODEL_URL, model_path)
+
+    model = VGGT()
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.eval()
+    model = model.to(device)
+    return model
+
+
+def predict(images_path: List[Path], model: VGGT):
+    sampled_image_names = [str(p) for p in images_path]
+    images = load_and_preprocess_images(sampled_image_names).to(device)
+    with torch.no_grad():
+        with torch.amp.autocast('cuda', dtype=dtype, enabled=torch.cuda.is_available()):  # pyright: ignore[reportPrivateImportUsage]
+            # Predict attributes including cameras, depth maps, and point maps
+            predictions = model(images)
+    return predictions
+
+
+def read_pfm(filename):
+    with open(filename, 'rb') as file:  # 1. Open the file in binary read mode
+        color = None
+        width = None
+        height = None
+        scale = None
+        endian = None
+
+        header = file.readline().decode('utf-8').rstrip()  # 2. Read the header line
+        if header == 'PF':
+            color = True  # color image
+        elif header == 'Pf':
+            color = False  # grayscale image
+        else:
+            raise Exception('Not a PFM file.')
+
+        dim_match = re.match(r'^(\d+)\s(\d+)\s$',
+                             file.readline().decode('utf-8'))  # 3. Read the second line and parse width and height
+        if dim_match:
+            width, height = map(int, dim_match.groups())
+        else:
+            raise Exception('Malformed PFM header.')
+
+        scale = float(file.readline().rstrip())  # 4. Read the third line: scale factor, whose sign encodes byte order
+        if scale < 0:  # little-endian
+            endian = '<'
+            scale = -scale
+        else:
+            endian = '>'  # big-endian
+
+        # 5. Read the remaining binary data
+        # np.fromfile reads from the file with the given byte order and data type ('f' means float32)
+        data = np.fromfile(file, endian + 'f')
+    # 6. Determine the data shape depending on whether the image is color
+    shape = (height, width, 3) if color else (height, width)
+
+    # 7. Reshape the data and flip the rows
+    # PFM stores rows in reverse order (bottom to top), so np.flipud restores top-to-bottom order
+    data = np.reshape(data, shape)
+    data = np.flipud(data)
+
+    return data, scale  # 8. Return the parsed data and scale factor (the with-block has closed the file)
+
+def upsample_image(image: torch.Tensor, target_w: int, target_h: int) -> torch.Tensor:
+    """
+    Args:
+        image: Source image
+        target_w: Target size passed to PIL as the output height (see the resize call below)
+        target_h: Target size passed to PIL as the output width
+    Returns:
+        The upsampled image
+    """
+
+    image = image.squeeze()
+
+    # 1. Convert to a float32 NumPy array
+    image_np = image.detach().cpu().numpy().astype(np.float32)
+
+    # 2. Upsample with PIL bicubic interpolation
+    img = Image.fromarray(image_np)
+    img_up_np = np.array(
+        img.resize((target_h, target_w), Image.Resampling.BICUBIC)  # PIL size is (width, height)
+    )
+
+    # 3. Convert back to a torch Tensor
+    img_up = torch.from_numpy(img_up_np).float()
+
+    return img_up
+
+
+def upsample_images(images: torch.Tensor, target_w: int, target_h: int) -> torch.Tensor:
+    upsampled_images = []
+
+    for d in images:
+        d = upsample_image(d, target_w, target_h)
+        upsampled_images.append(d)
+
+    upsampled_images = torch.stack(upsampled_images, dim=0)
+    return upsampled_images
+
+
+def generate_points_from_depth(depth, proj):
+    '''
+    :param depth: (B, 1, H, W)
+    :param proj: (B, 4, 4)
+    :return: point_cloud (B, 3, H, W)
+    '''
+    batch, height, width = depth.shape[0], depth.shape[2], depth.shape[3]
+    inv_proj = torch.inverse(proj)
+
+    rot = inv_proj[:, :3, :3]  # [B,3,3]
+    trans = inv_proj[:, :3, 3:4]  # [B,3,1]
+
+    y, x = torch.meshgrid(torch.arange(0, height, dtype=torch.float32, device=depth.device),
+                          torch.arange(0, width, dtype=torch.float32, device=depth.device), indexing="ij")
+    y, x = y.contiguous(), x.contiguous()
+    y, x = y.view(height * width), x.view(height * width)
+    # [u,v,1]
+    xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]
+    xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)  # [B, 3, H*W]
+    # (KR)^{-1}*[u,v,1]
+    rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]
+    # (KR)^{-1}*[u,v,1]*d
+    rot_depth_xyz = rot_xyz * depth.view(batch, 1, -1)
+    # (KR)^{-1}*[u,v,1]*d - (KR)^{-1}*K*t
+    proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1)  # [B, 3, H*W]
+    proj_xyz = proj_xyz.view(batch, 3, height, width)
+
+    return proj_xyz
+
+
+def homo_warping(src_fea, src_proj, ref_proj, depth_values):
+    '''
+    Warp src_fea from the source view (src_proj) into the reference view (ref_proj): each reference pixel at depth_values is projected into the source view using src_proj and ref_proj, and src_fea is then bilinearly sampled at the projected coordinates.
+    Args:
+        src_fea: (B, C, H, W)
+        src_proj: (B, 4, 4)
+        ref_proj: (B, 4, 4)
+        depth_values: (B, H, W)
+    Returns:
+        warped_src_fea: (B, C, H, W)
+    '''
+    batch, channels = src_fea.shape[0], src_fea.shape[1]
+    height, width = src_fea.shape[2], src_fea.shape[3]
+
+    with torch.no_grad():
+        proj = torch.matmul(src_proj, torch.inverse(ref_proj))
+        rot = proj[:, :3, :3]  # [B,3,3]
+        trans = proj[:, :3, 3:4]  # [B,3,1]
+
+        y, x = torch.meshgrid(torch.arange(0, height, dtype=torch.float32, device=src_fea.device),
+                              torch.arange(0, width, dtype=torch.float32, device=src_fea.device), indexing="ij")
+        y, x = y.contiguous(), x.contiguous()
+        y, x = y.view(height * width), x.view(height * width)
+        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]
+        xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)  # [B, 3, H*W]
+        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]
+
+        rot_depth_xyz = rot_xyz.unsqueeze(
+            2) * depth_values.view(-1, 1, 1, height*width)  # [B, 3, 1, H*W]
+
+        proj_xyz = rot_depth_xyz + \
+            trans.view(batch, 3, 1, 1)  # [B, 3, Ndepth, H*W]
+        proj_xy = proj_xyz[:, :2, :, :] / \
+            proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]
+        proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1
+        proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
+        # [B, Ndepth, H*W, 2]
+        proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)
+        grid = proj_xy
+
+    warped_src_fea = F.grid_sample(src_fea, grid.view(batch, height, width, 2), mode='bilinear',
+                                   padding_mode='zeros', align_corners=True)  # matches the (size - 1) / 2 normalization above
+    warped_src_fea = warped_src_fea.view(batch, channels, height, width)
+
+    return warped_src_fea
+
+
+def filter_depth(ref_depth, src_depths, ref_proj, src_projs):
+    ref_pc = generate_points_from_depth(ref_depth, ref_proj)
+    src_pcs = generate_points_from_depth(src_depths, src_projs)
+
+    aligned_pcs = homo_warping(src_pcs, src_projs, ref_proj, ref_depth)
+
+    x_2 = (ref_pc[:, 0] - aligned_pcs[:, 0])**2
+    y_2 = (ref_pc[:, 1] - aligned_pcs[:, 1])**2
+    z_2 = (ref_pc[:, 2] - aligned_pcs[:, 2])**2
+    dist = torch.sqrt(x_2 + y_2 + z_2).unsqueeze(1)
+
+    return ref_pc, aligned_pcs, dist
+
+
+def write_ply(file: Path, points):
+    pcd = o3d.geometry.PointCloud()
+    pcd.points = o3d.utility.Vector3dVector(points[:, :3].astype(np.float64))
+    pcd.colors = o3d.utility.Vector3dVector(points[:, 3:].astype(np.float64))  # colors are already in [0, 1] (see load_data)
+    o3d.io.write_point_cloud(str(file), pcd, write_ascii=False)
+
+
+def extract_points(pc, mask, rgb):
+    pc = pc.cpu()
+    mask = mask.cpu()
+    rgb = rgb.cpu()
+
+    pc = pc.numpy()
+    mask = mask.numpy()
+    rgb = rgb.numpy()
+
+    mask = np.reshape(mask, (-1,))
+    pc = np.reshape(pc, (-1, 3))
+    rgb = np.reshape(rgb, (-1, 3))
+
+    points = pc[np.where(mask)]
+    colors = rgb[np.where(mask)]
+
+    points_with_color = np.concatenate([points, colors], axis=1)
+
+    return points_with_color
+
+
+def open3d_filter(depths: torch.Tensor, projs: torch.Tensor, rgbs: torch.Tensor, dist_thresh: float = 1.0, batch_size: int = 20, num_consist: int = 4):
+    with torch.no_grad():
+        tot_frame = depths.shape[0]
+        height, width = depths.shape[2], depths.shape[3]
+        points = []
+
+        for i in range(tot_frame):
+            pc_buff = torch.zeros((3, height, width),
+                                  device=depths.device, dtype=depths.dtype)
+            val_cnt = torch.zeros((1, height, width),
+                                  device=depths.device, dtype=depths.dtype)
+            j = 0
+
+            while True:
+                ref_pc, pcs, dist = filter_depth(
+                    ref_depth=depths[i:i+1],
+                    src_depths=depths[j:min(j+batch_size, tot_frame)],
+                    ref_proj=projs[i:i+1],
+                    src_projs=projs[j:min(j+batch_size, tot_frame)]
+                )
+
+                depth_mask = (dist < dist_thresh).float()
+
+                masks = depth_mask
+
+                masked_pc = pcs * masks
+                pc_buff += masked_pc.sum(dim=0, keepdim=False)
+                val_cnt += masks.sum(dim=0, keepdim=False)
+
+                j += batch_size
+                if j >= tot_frame:
+                    break
+
+            final_mask = (val_cnt >= num_consist).squeeze(0)
+            avg_points = torch.div(pc_buff, val_cnt).permute(1, 2, 0)
+
+            final_pc = extract_points(avg_points, final_mask, rgbs[i].permute(1, 2, 0))  # (3, H, W) -> (H, W, 3) to match avg_points
+            points.append(final_pc)
+
+        points = np.concatenate(points, axis=0)
+        return points
\ No newline at end of file
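`generate_points_from_depth` lifts every pixel `(u, v)` with depth `d` into world space using the inverse of the 4x4 matrix built in `load_data`, `P = [[K R, K t], [0, 0, 0, 1]]`. Writing `M = (K R)^{-1}`, this gives `X = M [u, v, 1]^T d - M K t`, which is the camera-space point `d K^{-1} [u, v, 1]^T` mapped back through the extrinsics. The NumPy sketch below checks that round trip for a single pixel; the camera parameters are made up, and `backproject` is a hypothetical single-pixel counterpart of the batched PyTorch function.

```python
import numpy as np


def backproject(u: float, v: float, d: float, proj: np.ndarray) -> np.ndarray:
    """Single-pixel version of generate_points_from_depth: inv(P)[:3, :3] @ [u, v, 1] * d + inv(P)[:3, 3]."""
    inv = np.linalg.inv(proj)
    rot, trans = inv[:3, :3], inv[:3, 3]
    return rot @ np.array([u, v, 1.0]) * d + trans


if __name__ == "__main__":
    K = np.array([[800.0, 0.0, 320.0], [0.0, 800.0, 240.0], [0.0, 0.0, 1.0]])
    a = np.deg2rad(10.0)
    R = np.array([[np.cos(a), 0.0, np.sin(a)], [0.0, 1.0, 0.0], [-np.sin(a), 0.0, np.cos(a)]])
    t = np.array([10.0, -5.0, 600.0])
    proj = np.eye(4)
    proj[:3, :4] = K @ np.hstack([R, t[:, None]])  # same layout as load_data

    X = backproject(100.0, 50.0, 650.0, proj)
    x_cam = R @ X + t                  # back into camera coordinates
    uv = (K @ x_cam)[:2] / x_cam[2]    # reproject to pixels
    print(np.isclose(x_cam[2], 650.0), np.allclose(uv, [100.0, 50.0]))  # True True
```

Because the last row of `K` is `[0, 0, 1]`, the recovered camera-space z equals `d`, so fused points are only metric if the depth maps are metric; this is presumably why `test_dtu.py` aligns VGGT's depth to the DTU ground truth before fusion.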
diff --git a/pyproject.toml b/pyproject.toml
index 81d4f1d..2db8d85 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,15 @@ demo = [
     "trimesh",
     "matplotlib",
 ]
+evaluation = [
+    "torch",
+    "torchvision",
+    "tqdm",
+    "open3d",
+    "scipy",
+    "scikit-learn",
+    "plyfile"
+]
# Using setuptools as the build backend
[build-system]