Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ImageReward/ImageReward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from PIL import Image
from .models.BLIP.blip_pretrain import BLIP_Pretrain
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+ from transformers import BertTokenizer

try:
from torchvision.transforms import InterpolationMode
Expand Down Expand Up @@ -69,11 +70,11 @@ def forward(self, input):


class ImageReward(nn.Module):
- def __init__(self, med_config, device='cpu'):
+ def __init__(self, med_config, device='cpu', tokenizer=None):
super().__init__()
self.device = device

- self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
+ self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, tokenizer=tokenizer)
self.preprocess = _transform(224)
self.mlp = MLP(768)

Expand Down
5 changes: 3 additions & 2 deletions ImageReward/models/BLIP/blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from .vit import VisionTransformer, interpolate_pos_embed


- def init_tokenizer():
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ def init_tokenizer(tokenizer: BertTokenizer = None):
+ if tokenizer is None:
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
Expand Down
5 changes: 3 additions & 2 deletions ImageReward/models/BLIP/blip_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from .blip import create_vit, init_tokenizer

class BLIP_Pretrain(nn.Module):
- def __init__(self,
+ def __init__(self,
+ tokenizer,
med_config = "med_config.json",
image_size = 224,
vit = 'base',
Expand All @@ -31,7 +32,7 @@ def __init__(self,

self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)

- self.tokenizer = init_tokenizer()
+ self.tokenizer = init_tokenizer(tokenizer)
encoder_config = BertConfig.from_json_file(med_config)
encoder_config.encoder_width = vision_width
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from setuptools import setup, find_packages
import os
import pkg_resources
from pathlib import Path

long_description = (Path(__file__).parent / "README.md").read_text()
Expand Down