# alf_utils.py
import os
import yaml
import torchvision.transforms as T
from alfworld.agents.environment.alfred_thor_env import AlfredThorEnv
import gymnasium as gym
from gymnasium import spaces
import alfworld.agents.environment as environment
from typing import Optional
import numpy as np
import torch
import random

ALF_ACTION_LIST = ["pass", "goto", "pick", "put", "open", "close", "toggle",
                   "heat", "clean", "cool", "slice", "inventory", "examine", "look"]
# ALF_ITEM_LIST =

def load_config_file(path):
    """Load a YAML config file, asserting that the path exists."""
    assert os.path.exists(path), "Invalid config file"
    with open(path) as reader:
        config = yaml.safe_load(reader)
    return config

def get_obs_image(env):
    """Fetch the current frames from the env and return a batched (H, W, 3) int image tensor."""
    transform = T.Compose([T.ToTensor()])
    current_frames = env.get_frames()
    image_tensors = [transform(i).cuda() for i in current_frames]
    for i in range(len(image_tensors)):
        # ToTensor yields CHW floats in [0, 1]; convert to HWC ints in [0, 255]
        image_tensors[i] = image_tensors[i].permute(1, 2, 0)
        image_tensors[i] *= 255
        image_tensors[i] = image_tensors[i].int()
        # Reverse the channel order (BGR -> RGB)
        image_tensors[i] = image_tensors[i][:, :, [2, 1, 0]]
    image_tensors = torch.stack(image_tensors, dim=0)
    return image_tensors

class AlfEnv(gym.Env):
    def __init__(self, config_file):
        config = load_config_file(config_file)
        env_type = config['env']['type']
        env = getattr(environment, env_type)(config, train_eval='train')
        self.env = env.init_env(batch_size=1)
        self.action_space = spaces.Discrete(len(ALF_ACTION_LIST))
        self.observation_space = spaces.Box(low=0, high=255, shape=(300, 300, 3), dtype=np.uint8)
        # Cache the admissible commands from the previous step for process_action
        self.prev_admissible_commands = None
        self.num_envs = 1

    def step(self, action):
        # Sanity-check the action against the admissible commands;
        # illegal actions are penalized in the reward
        action, legal_action = process_action(self.env, action, self.prev_admissible_commands)
        obs, scores, dones, infos = self.env.step(action)
        infos['observation_text'] = obs
        reward = compute_reward(infos, legal_action)
        self.prev_admissible_commands = list(infos['admissible_commands'])[0]
        return self._get_obs(), reward, dones, infos

    def reset(self, seed=42):
        self.env.seed(seed)
        obs, infos = self.env.reset()
        infos['observation_text'] = obs
        self.prev_admissible_commands = list(infos['admissible_commands'])[0]
        return self._get_obs(), infos

    def _get_obs(self):
        image = get_obs_image(self.env)
        return image
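
# A minimal usage sketch (hedged: the config path below is hypothetical, and
# running this requires a working ALFWorld installation and a CUDA device):
#
#   env = AlfEnv("configs/alfworld_config.yaml")
#   obs, infos = env.reset(seed=0)                  # obs: (1, 300, 300, 3) image tensor
#   obs, reward, dones, infos = env.step(["look"])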

def process_action(env, action=None, action_list=None):
    """
    Process raw model outputs into legal environment actions.
    env: the environment; expected to be of type AlfEnv or AlfredThorEnv
    action: the list of actions to be processed; each entry is a string
    action_list: the admissible commands to match against
    Returns the processed action list and whether a legal action was found.
    """
    legal_action = False
    if not isinstance(env, (AlfEnv, AlfredThorEnv)):
        # Unsupported env type: return the actions unchanged
        return action, legal_action
    for i in range(len(action)):
        action[i] = action[i].lower()
        # TODO: need to figure this out
        if len(action[i]) == 0:
            print("Action is empty!!!!")
            # Randomly choose an action from the action list if illegal
            action[i] = random.choice(action_list)
        else:
            try:
                action_index = action[i].find('"action":')
                # The string has the format '"action": "look"\n}'
                if action_index == -1:
                    # If we cannot find '"action":', fall back to the last 30 characters
                    string = action[i][-30:]
                else:
                    string = action[i][action_index:]
                # Match the extracted substring against the admissible commands
                for act in action_list:
                    if act in string:
                        action[i] = act
                        # Found a legal action
                        legal_action = True
                        break
            except Exception:
                # Randomly choose an action from the action list if parsing fails
                action[i] = random.choice(action_list)
    return action, legal_action
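
# Example of the parsing above (a sketch; the raw string is illustrative):
# given a model output containing '"action": "look"', process_action extracts
# the matching admissible command:
#
#   raw = ['{\n  "thoughts": "...",\n  "action": "look"\n}']
#   acts, legal = process_action(env, raw, env.prev_admissible_commands)
#   # -> acts == ['look'], legal == True   (assuming 'look' is admissible)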

def compute_reward(infos, legal_action):
    """
    Compute the shaped reward for the ALFWorld environment.
    infos: the info dict returned by the environment
    legal_action: whether the executed action was legal
    Tentative shaping: r = 50 * success + goal_condition_success_rate - 1 * illegal_action
    """
    reward = 50 * float(infos['won'][0]) + float(infos['goal_condition_success_rate'][0])
    if not legal_action:
        # Penalize illegal actions
        reward -= 1
    reward = [reward]
    return torch.tensor(reward)
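
# Worked example: a won episode (won=[True]) with goal_condition_success_rate
# of [0.8] and a legal action yields 50 * 1.0 + 0.8 = 50.8 -> tensor([50.8000]);
# an illegal action would subtract 1, giving tensor([49.8000]).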

def get_encoded_text(observation_text, tokenizer, model):
    """Encode text with a BERT-style model and return the [CLS] token embedding."""
    encoded_input = tokenizer(observation_text, return_tensors='pt')
    outputs = model(**encoded_input)
    # The first token's hidden state serves as the sentence-level representation
    cls_embeddings = outputs.last_hidden_state[:, 0, :]
    return cls_embeddings

def get_concat(obs, infos, tokenizer, model, device):
    """Concatenate the flattened image observation with the encoded observation text."""
    assert 'observation_text' in infos.keys(), 'observation_text not in infos!'
    obs_text = infos['observation_text']
    obs_text_encode = get_encoded_text(obs_text, tokenizer, model)
    obs_text_encode = obs_text_encode.to(device)
    obs_cat = torch.cat((obs.flatten(start_dim=1), obs_text_encode), dim=1)
    return obs_cat

def get_cards_concat(obs, infos, tokenizer, model, device):
    """Concatenate the flattened image observation with the encoded card formula text."""
    # TODO: move this code into a CNN utils module
    assert 'Formula' in infos[0].keys(), 'Formula not in infos!'
    infos = infos[0]
    formula_list = infos['Formula']
    formula = "".join(str(x) for x in formula_list)
    obs_text_encode = get_encoded_text(formula, tokenizer, model).to(device)
    obs_cat = torch.cat((obs.flatten(start_dim=1), obs_text_encode), dim=1)
    return obs_cat
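
# A minimal sketch of wiring get_concat to a text encoder (assumes the
# `transformers` package; the model choice is illustrative, and obs is cast to
# float so it can be concatenated with the float text embedding):
#
#   from transformers import AutoTokenizer, AutoModel
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   model = AutoModel.from_pretrained("bert-base-uncased")
#   obs, infos = env.reset()
#   fused = get_concat(obs.float(), infos, tokenizer, model, device="cuda")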