Download an OpenImages subset (single split) into an ImageFolder layout.
root/
train/
CLASS_ID/
image1.jpg
image2.jpg
...
val/
CLASS_ID/
...
Parameters
outdir : Path
Destination directory.
split : {"train", "validation"}
Which OpenImages split to download.
max_classes : int
How many classes to download. OpenImages is huge;
limiting classes makes a sane mid-size dataset.
num_workers : int
Parallel download workers.
Source code in src/ezpz/data/utils.py
| def download_openimages_subset(
outdir: str | Path,
split: str = "train",
max_classes: int = 50,
num_workers: int = 32,
) -> None:
"""
Download an OpenImages subset (single split) into an ImageFolder layout.
root/
train/
CLASS_ID/
image1.jpg
image2.jpg
...
val/
CLASS_ID/
...
Parameters
----------
outdir : Path
Destination directory.
split : {"train", "validation"}
Which OpenImages split to download.
max_classes : int
How many classes to download. OpenImages is huge;
limiting classes makes a sane mid-size dataset.
num_workers : int
Parallel download workers.
"""
outdir = Path(outdir)
split = split.lower()
assert split in {"train", "validation"}
# ---------------------------------------------------------------------
# STEP 1 β download class descriptions
# ---------------------------------------------------------------------
class_csv_path = outdir / "class-descriptions.csv"
if not class_csv_path.exists():
class_csv_path.parent.mkdir(parents=True, exist_ok=True)
print(f"Downloading class descriptions β {class_csv_path}")
urlretrieve(ANNOTATIONS_URLS["class-descriptions"], class_csv_path)
class_map = {}
with open(class_csv_path, "r") as f:
reader = csv.reader(f)
for cid, cname in reader:
class_map[cid] = cname.replace(" ", "_")
# ---------------------------------------------------------------------
# STEP 2 β download annotation CSV for this split
# ---------------------------------------------------------------------
ann_path = outdir / f"{split}-annotations.csv"
if not ann_path.exists():
print(f"Downloading annotations for {split} β {ann_path}")
urlretrieve(ANNOTATIONS_URLS[split], ann_path)
# ---------------------------------------------------------------------
# STEP 3 β collect image β class assignments
# ---------------------------------------------------------------------
print(f"Parsing annotations: {ann_path}")
image_to_classes = defaultdict(list)
with ann_path.open("r") as f:
reader = csv.DictReader(f)
for row in reader:
if row["Confidence"] == "1":
cid = row["LabelName"]
if cid in class_map:
image_to_classes[row["ImageID"]].append(cid)
# Keep only a subset for a manageable dataset size
print("Collecting top classes...")
class_counts = defaultdict(int)
for classes in image_to_classes.values():
for cid in classes:
class_counts[cid] += 1
# Select top frequent classes
top_classes = sorted(class_counts, key=class_counts.get, reverse=True)[
:max_classes
]
top_classes = set(top_classes)
# Filter: keep only images that contain at least one top class
filtered = {
img: [cid for cid in classes if cid in top_classes]
for img, classes in image_to_classes.items()
if any(cid in top_classes for cid in classes)
}
print(
f"Selected {len(filtered)} images across {len(top_classes)} classes."
)
# ---------------------------------------------------------------------
# STEP 4 β Create output directories
# ---------------------------------------------------------------------
for cid in top_classes:
cname = class_map[cid]
(outdir / split / cname).mkdir(parents=True, exist_ok=True)
# ---------------------------------------------------------------------
# STEP 5 β define download worker
# ---------------------------------------------------------------------
def download_image(image_id: str, cids: list[str]):
# All OpenImages JPEGs follow this pattern
url = IMAGES_BASE + f"{split}/{image_id}.jpg"
# Save to first class directory only (multi-label ignored)
class_id = cids[0]
class_name = class_map[class_id]
dst = outdir / split / class_name / f"{image_id}.jpg"
if dst.exists():
return
try:
urlretrieve(url, dst)
except Exception as e:
print(f"Failed {image_id}: {e}")
# ---------------------------------------------------------------------
# STEP 6 β parallel download
# ---------------------------------------------------------------------
print("Downloading images...")
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as ex:
futures = [
ex.submit(download_image, img_id, cids)
for img_id, cids in filtered.items()
]
for i, f in enumerate(concurrent.futures.as_completed(futures), 1):
if i % 500 == 0:
print(f"Downloaded {i} images...")
print("Finished!")
|