-
Notifications
You must be signed in to change notification settings - Fork 334
Implemented Coca architecture #2371
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
VarunS1997
wants to merge
13
commits into
master
Choose a base branch
from
model-impl/CoCa
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
20cdf41
Implemented Coca architecture
VarunS1997 b8c0ba4
Minor clean-up
VarunS1997 bbe17c4
Fixed depth of decoders
VarunS1997 202526f
Updated config to match args
VarunS1997 367dd39
Moved layer definitions to build and added build calls for each layer
VarunS1997 80ea7d3
Unabbreviated 'contrastive' and 'captioning'
VarunS1997 3feacb6
Improved documentation and added output sizing to call(), also built …
VarunS1997 f15408f
Lowercased coca model directory and added to kokoro build
VarunS1997 960873f
Addressed comments by Matt; reformatted as well
VarunS1997 33cff54
Addressed comments related to attn pooling size, attn pooling name
VarunS1997 145d7b5
Wrote a test for coca saving and loading, which prompted some model c…
VarunS1997 e8623a9
Updated to functional model
VarunS1997 c9e1ec1
added size inputs for functional model
VarunS1997 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from keras import layers | ||
|
||
|
||
class AttentionPooling(layers.Layer): | ||
"""Implements the Pooled Attention Layer used in "CoCa": Contrastive Captioners are Image-Text Foundation Models" | ||
(https://arxiv.org/pdf/2205.01917.pdf), consisting of a Multiheaded Attention followed by Layer Normalization. | ||
VarunS1997 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
:param proj_dim: The dimensions of the attention heads | ||
:param num_heads: The number of attention heads in the multi-headed attention layer | ||
""" | ||
def __init__(self, | ||
VarunS1997 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
proj_dim, | ||
num_heads, | ||
**kwargs): | ||
super().__init__(self, **kwargs) | ||
|
||
self.proj_dim = proj_dim | ||
self.num_heads = num_heads | ||
|
||
def build(self, input_shape): | ||
self.multi_head_attn = layers.MultiHeadAttention( | ||
VarunS1997 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
self.num_heads, | ||
self.proj_dim | ||
) | ||
|
||
self.layer_norm = layers.LayerNormalization() | ||
VarunS1997 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
def call(self, query, value): | ||
x = self.multi_head_attn(query, value) | ||
return self.layer_norm(x) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# Copyright 2024 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import numpy as np | ||
from keras import Sequential | ||
from keras_cv.api_export import keras_cv_export | ||
from keras_nlp.layers import RotaryEmbedding, TransformerDecoder | ||
from keras_cv.layers import TransformerEncoder as CVTransformerEncoder | ||
from keras_cv.models.task import Task | ||
from keras_cv.layers.attention_pooling import AttentionPooling | ||
from keras_cv.layers.vit_layers import PatchingAndEmbedding | ||
|
||
|
||
@keras_cv_export(["keras_cv.models.CoCa"]) | ||
class CoCa(Task): | ||
VarunS1997 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
def __init__(self, | ||
img_query_dim, | ||
text_proj_dim, | ||
img_patch_size=18, | ||
encoder_depth=40, | ||
encoder_heads=16, | ||
encoder_intermediate_dim=6144, | ||
encoder_width=1408, | ||
unimodal_decoder_depth=18, | ||
multimodal_decoder_depth=18, | ||
decoder_intermediate_dim=5632, | ||
unimodal_decoder_heads=16, | ||
multimodal_decoder_heads=16, | ||
con_queries=1, | ||
VarunS1997 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
cap_queries=256, | ||
con_heads=16, | ||
cap_heads=16, | ||
cap_loss_weight=0.5, | ||
con_loss_weight=0.5, | ||
**kwargs): | ||
super().__init__(**kwargs) | ||
|
||
self.img_patch_size = img_patch_size | ||
self.img_query_dim = img_query_dim | ||
|
||
self.encoder_depth = encoder_depth | ||
self.encoder_heads = encoder_heads | ||
self.encoder_width = encoder_width | ||
self.encoder_intermediate_dim = encoder_intermediate_dim | ||
|
||
self.text_proj_dim = text_proj_dim | ||
self.unimodal_decoder_depth = unimodal_decoder_depth | ||
self.multimodal_decoder_depth = multimodal_decoder_depth | ||
self.decoder_intermediate_dim = decoder_intermediate_dim | ||
self.unimodal_decoder_heads = unimodal_decoder_heads | ||
self.multimodal_decoder_heads = multimodal_decoder_heads | ||
|
||
self.con_queries = con_queries | ||
self.con_heads = con_heads | ||
self.con_loss_weight = con_loss_weight | ||
|
||
self.cap_queries = cap_queries | ||
self.cap_heads = cap_heads | ||
self.cap_loss_weight = cap_loss_weight | ||
|
||
def build(self, input_shape): | ||
super().build(input_shape) | ||
|
||
self.image_patching = PatchingAndEmbedding(self.encoder_width, self.img_patch_size) | ||
VarunS1997 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
self.image_encoder = Sequential([ | ||
CVTransformerEncoder(self.img_query_dim, self.encoder_heads, self.encoder_intermediate_dim) | ||
for _ in range(self.encoder_depth) | ||
VarunS1997 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
]) | ||
|
||
self.cls_token = self.add_weight(shape=[1, 1, self.text_proj_dim], name="cls_token", trainable=True) | ||
|
||
self.text_embedding = RotaryEmbedding() | ||
self.unimodal_text_decoder = Sequential([ | ||
TransformerDecoder(self.decoder_intermediate_dim, self.unimodal_decoder_heads) | ||
for _ in range(self.unimodal_decoder_depth) | ||
]) | ||
self.multimodal_text_decoder = Sequential([ | ||
TransformerDecoder(self.decoder_intermediate_dim, self.multimodal_decoder_heads) | ||
for _ in range(self.multimodal_decoder_depth) | ||
]) | ||
|
||
self.con_query = self.add_weight(shape=[1, 1, self.con_queries], trainable=True) | ||
self.cap_query = self.add_weight(shape=[1, 1, self.cap_queries], trainable=True) | ||
|
||
self.con_attn_pooling = AttentionPooling(self.img_query_dim, self.con_heads) | ||
self.cap_attn_pooling = AttentionPooling(self.img_query_dim, self.cap_heads) | ||
|
||
def call(self, images, texts): | ||
""" | ||
Forward pass of the Coca Model | ||
|
||
:param images: [batch_size, height, width, channels] representing images | ||
:param texts: Tensor, typically represented as [batch_size, sequence_length, feature_length] or | ||
[batch_size, sequence_length, num_heads, feature_length]. The sequence_length and/or feature_length | ||
VarunS1997 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
are required. | ||
:return: output of the captioning Transformer Decoder with captioning cross-attention | ||
""" | ||
img_encoding = self.image_patching(images) | ||
img_encoding = self.image_encoder(img_encoding) # [batch, patches_len+1, img_query_dim] | ||
|
||
# This is only needed for loss calculations | ||
# con_feature = self.con_attn_pooling(self.con_query, img_encoding) | ||
cap_feature = self.cap_attn_pooling(self.cap_query, img_encoding) | ||
|
||
text_tokens = np.concatenate(texts, self.cls_token) | ||
mask = np.concatenate((np.ones_like(texts), np.zeros_like(self.cls_token))) | ||
|
||
embed_text = self.text_embedding(text_tokens) | ||
unimodal_out = self.unimodal_text_decoder(embed_text, attention_mask=mask) | ||
multimodal_out = self.multimodal_text_decoder(unimodal_out[:, :-1, :], | ||
encoder_sequence=cap_feature, | ||
decoder_attention_mask=mask) | ||
|
||
return multimodal_out | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"img_patch_size": self.img_patch_size, | ||
"img_query_dim": self.img_query_dim, | ||
"encoder_depth": self.encoder_depth, | ||
"encoder_heads": self.encoder_heads, | ||
"encoder_width": self.encoder_width, | ||
"encoder_intermediate_dim": self.encoder_intermediate_dim, | ||
"text_proj_dim": self.text_proj_dim, | ||
"unimodal_decoder_depth": self.unimodal_decoder_depth, | ||
"multimodal_decoder_depth": self.multimodal_decoder_depth, | ||
"decoder_intermediate_dim": self.decoder_intermediate_dim, | ||
"unimodal_decoder_heads": self.unimodal_decoder_heads, | ||
"multimodal_decoder_heads": self.multimodal_decoder_heads, | ||
"con_queries": self.con_queries, | ||
"con_heads": self.con_heads, | ||
"con_loss_weight": self.con_loss_weight, | ||
"cap_queries": self.cap_queries, | ||
"cap_heads": self.cap_heads, | ||
"cap_loss_weight": self.cap_loss_weight, | ||
} | ||
) | ||
return config |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.